refactor(backend): Refactor log client and resource cleanup (#10558)

## Summary
- Created centralized service client helpers with thread caching in
`util/clients.py`
- Refactored service client management to eliminate health checks and
improve performance
- Enhanced logging in process cleanup to include error details
- Improved retry mechanisms and resource cleanup across the platform
- Updated multiple services to use new centralized client patterns

## Key Changes
### New Centralized Client Factory (`util/clients.py`)
- Added thread-cached factory functions for all major service clients:
  - Database managers (sync and async)
  - Scheduler client
  - Notification manager
  - Execution event bus (Redis-based)
  - RabbitMQ execution queue (sync and async)
  - Integration credentials store
- All clients use `@thread_cached` decorator for performance
optimization

### Service Client Improvements
- **Removed health checks**: Eliminated unnecessary health check calls
from `get_service_client()` to reduce startup overhead
- **Enhanced retry support**: Database manager clients now use request
retry by default
- **Better error handling**: Improved error propagation and logging

### Enhanced Logging and Cleanup
- **Process termination logs**: Added error details to termination
messages in `util/process.py`
- **Retry mechanism updates**: Improved retry logic with better error
handling in `util/retry.py`
- **Resource cleanup**: Better resource management across executors and
monitoring services

### Updated Service Usage
- Refactored 21+ files to use new centralized client patterns
- Updated all executor, monitoring, and notification services
- Maintained backward compatibility while improving performance

## Files Changed
- **Created**: `backend/util/clients.py` - Centralized client factory
with thread caching
- **Modified**: 21 files across blocks, executor, monitoring, and
utility modules
- **Key areas**: Service client initialization, resource cleanup, retry
mechanisms

## Test Plan
- [x] Verify all existing tests pass
- [x] Validate service startup and client initialization  
- [x] Test resource cleanup on process termination
- [x] Confirm retry mechanisms work correctly
- [x] Validate thread caching performance improvements
- [x] Ensure no breaking changes to existing functionality

## Breaking Changes
None - all changes maintain backward compatibility.

## Additional Notes
This refactoring centralizes client management patterns that were
scattered across the codebase, making them more consistent and
performant through thread caching. The removal of health checks reduces
startup time while maintaining reliability through improved retry
mechanisms.

🤖 Generated with [Claude Code](https://claude.ai/code)
This commit is contained in:
Zamil Majdy
2025-08-06 10:53:01 +04:00
committed by GitHub
parent fa2d968458
commit 3fe88b6106
22 changed files with 420 additions and 361 deletions

View File

@@ -14,7 +14,8 @@ from backend.data.block import (
)
from backend.data.execution import ExecutionStatus
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util import json, retry
from backend.util.json import validate_with_jsonschema
from backend.util.retry import func_retry
_logger = logging.getLogger(__name__)
@@ -48,7 +49,7 @@ class AgentExecutorBlock(Block):
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return json.validate_with_jsonschema(cls.get_input_schema(data), data)
return validate_with_jsonschema(cls.get_input_schema(data), data)
class Output(BlockSchema):
pass
@@ -180,7 +181,7 @@ class AgentExecutorBlock(Block):
)
yield output_name, output_data
@retry.func_retry
@func_retry
async def _stop(
self,
graph_exec_id: str,

View File

@@ -1,26 +1,18 @@
from datetime import datetime
from typing import Optional
from autogpt_libs.utils.cache import thread_cached
from pydantic import BaseModel, Field
from backend.data.block import BlockSchema
from backend.data.model import SchemaField, UserIntegrations
from backend.integrations.ayrshare import AyrshareClient
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import MissingConfigError
@thread_cached
def get_database_manager_client():
from backend.executor import DatabaseManagerAsyncClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManagerAsyncClient, health_check=False)
async def get_profile_key(user_id: str):
user_integrations: UserIntegrations = (
await get_database_manager_client().get_user_integrations(user_id)
await get_database_manager_async_client().get_user_integrations(user_id)
)
return user_integrations.managed_credentials.ayrshare_profile_key

View File

@@ -1,22 +1,13 @@
import logging
from typing import Any, Literal
from autogpt_libs.utils.cache import thread_cached
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__)
@thread_cached
def get_database_manager_client():
from backend.executor import DatabaseManagerAsyncClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManagerAsyncClient, health_check=False)
StorageScope = Literal["within_agent", "across_agents"]
@@ -88,7 +79,7 @@ class PersistInformationBlock(Block):
async def _store_data(
self, user_id: str, node_exec_id: str, key: str, data: Any
) -> Any | None:
return await get_database_manager_client().set_execution_kv_data(
return await get_database_manager_async_client().set_execution_kv_data(
user_id=user_id,
node_exec_id=node_exec_id,
key=key,
@@ -149,7 +140,7 @@ class RetrieveInformationBlock(Block):
yield "value", input_data.default_value
async def _retrieve_data(self, user_id: str, key: str) -> Any | None:
return await get_database_manager_client().get_execution_kv_data(
return await get_database_manager_async_client().get_execution_kv_data(
user_id=user_id,
key=key,
)

View File

@@ -3,8 +3,6 @@ import re
from collections import Counter
from typing import TYPE_CHECKING, Any
from autogpt_libs.utils.cache import thread_cached
import backend.blocks.llm as llm
from backend.blocks.agent import AgentExecutorBlock
from backend.data.block import (
@@ -17,6 +15,7 @@ from backend.data.block import (
)
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util import json
from backend.util.clients import get_database_manager_async_client
if TYPE_CHECKING:
from backend.data.graph import Link, Node
@@ -24,14 +23,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@thread_cached
def get_database_manager_client():
from backend.executor import DatabaseManagerAsyncClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManagerAsyncClient, health_check=False)
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
"""
Return a list of tool_call_ids if the entry is a tool request.
@@ -333,7 +324,7 @@ class SmartDecisionMakerBlock(Block):
if not graph_id or not graph_version:
raise ValueError("Graph ID or Graph Version not found in sink node.")
db_client = get_database_manager_client()
db_client = get_database_manager_async_client()
sink_graph_meta = await db_client.get_graph_metadata(graph_id, graph_version)
if not sink_graph_meta:
raise ValueError(
@@ -393,7 +384,7 @@ class SmartDecisionMakerBlock(Block):
ValueError: If no tool links are found for the specified node_id, or if a sink node
or its metadata cannot be found.
"""
db_client = get_database_manager_client()
db_client = get_database_manager_async_client()
tools = [
(link, node)
for link, node in await db_client.get_connected_output_nodes(node_id)

View File

@@ -39,6 +39,7 @@ from pydantic.fields import Field
from backend.server.v2.store.exceptions import DatabaseError
from backend.util import type as type_utils
from backend.util.json import SafeJson
from backend.util.retry import func_retry
from backend.util.settings import Config
from backend.util.truncate import truncate
@@ -883,15 +884,15 @@ class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
def publish(self, res: GraphExecution | NodeExecutionResult):
if isinstance(res, GraphExecution):
self.publish_graph_exec_update(res)
self._publish_graph_exec_update(res)
else:
self.publish_node_exec_update(res)
self._publish_node_exec_update(res)
def publish_node_exec_update(self, res: NodeExecutionResult):
def _publish_node_exec_update(self, res: NodeExecutionResult):
event = NodeExecutionEvent.model_validate(res.model_dump())
self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
def publish_graph_exec_update(self, res: GraphExecution):
def _publish_graph_exec_update(self, res: GraphExecution):
event = GraphExecutionEvent.model_validate(res.model_dump())
self._publish(event, f"{res.user_id}/{res.graph_id}/{res.id}")
@@ -923,17 +924,18 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
def event_bus_name(self) -> str:
return config.execution_event_bus_name
@func_retry
async def publish(self, res: GraphExecutionMeta | NodeExecutionResult):
if isinstance(res, GraphExecutionMeta):
await self.publish_graph_exec_update(res)
await self._publish_graph_exec_update(res)
else:
await self.publish_node_exec_update(res)
await self._publish_node_exec_update(res)
async def publish_node_exec_update(self, res: NodeExecutionResult):
async def _publish_node_exec_update(self, res: NodeExecutionResult):
event = NodeExecutionEvent.model_validate(res.model_dump())
await self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
async def publish_graph_exec_update(self, res: GraphExecutionMeta):
async def _publish_graph_exec_update(self, res: GraphExecutionMeta):
# GraphExecutionEvent requires inputs and outputs fields that GraphExecutionMeta doesn't have
# Add default empty values for compatibility
event_data = res.model_dump()

View File

@@ -4,20 +4,12 @@ from enum import Enum
from typing import Awaitable, Optional
import aio_pika
import aio_pika.exceptions as aio_ex
import pika
import pika.adapters.blocking_connection
from pika.exceptions import AMQPError
from pika.spec import BasicProperties
from pydantic import BaseModel
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from backend.util.retry import conn_retry
from backend.util.retry import conn_retry, func_retry
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -198,12 +190,7 @@ class SyncRabbitMQ(RabbitMQBase):
routing_key=queue.routing_key or queue.name,
)
@retry(
retry=retry_if_exception_type((AMQPError, ConnectionError)),
wait=wait_random_exponential(multiplier=1, max=5),
stop=stop_after_attempt(5),
reraise=True,
)
@func_retry
def publish_message(
self,
routing_key: str,
@@ -302,12 +289,7 @@ class AsyncRabbitMQ(RabbitMQBase):
exchange, routing_key=queue.routing_key or queue.name
)
@retry(
retry=retry_if_exception_type((aio_ex.AMQPError, ConnectionError)),
wait=wait_random_exponential(multiplier=1, max=5),
stop=stop_after_attempt(5),
reraise=True,
)
@func_retry
async def publish_message(
self,
routing_key: str,

View File

@@ -13,6 +13,7 @@ from backend.blocks.llm import LlmModel, llm_call
from backend.data.block import get_block
from backend.data.execution import ExecutionStatus, NodeExecutionResult
from backend.data.model import APIKeyCredentials, GraphExecutionStats
from backend.util.retry import func_retry
from backend.util.settings import Settings
from backend.util.truncate import truncate
@@ -415,6 +416,7 @@ def _build_execution_summary(
}
@func_retry
async def _call_llm_direct(
credentials: APIKeyCredentials, prompt: list[dict[str, str]]
) -> str:

View File

@@ -26,14 +26,13 @@ from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.activity_status_generator import (
generate_activity_status_for_execution,
)
from backend.executor.utils import LogMetadata, create_execution_queue_config
from backend.executor.utils import LogMetadata
from backend.notifications.notifications import queue_notification
from backend.util.exceptions import InsufficientBalanceError
if TYPE_CHECKING:
from backend.executor import DatabaseManagerClient, DatabaseManagerAsyncClient
from autogpt_libs.utils.cache import thread_cached
from prometheus_client import Gauge, start_http_server
from backend.blocks.agent import AgentExecutorBlock
@@ -63,14 +62,19 @@ from backend.executor.utils import (
ExecutionOutputEntry,
NodeExecutionProgress,
block_usage_cost,
create_execution_queue_config,
execution_usage_cost,
get_async_execution_event_bus,
get_execution_event_bus,
parse_execution_output,
validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.clients import (
get_async_execution_event_bus,
get_database_manager_async_client,
get_database_manager_client,
get_execution_event_bus,
)
from backend.util.decorator import (
async_error_logged,
async_time_measured,
@@ -81,7 +85,6 @@ from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.process import AppProcess, set_service_name
from backend.util.retry import continuous_retry, func_retry
from backend.util.service import get_service_client
from backend.util.settings import Settings
_logger = logging.getLogger(__name__)
@@ -1088,11 +1091,33 @@ class ExecutionManager(AppProcess):
super().__init__()
self.pool_size = settings.config.num_graph_workers
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
self._stop_consuming = None
self._executor = None
self._stop_consuming = None
self._cancel_thread = None
self._cancel_client = None
self._run_thread = None
self._run_client = None
@property
def cancel_thread(self) -> threading.Thread:
if self._cancel_thread is None:
self._cancel_thread = threading.Thread(
target=lambda: self._consume_execution_cancel(),
daemon=True,
)
return self._cancel_thread
@property
def run_thread(self) -> threading.Thread:
if self._run_thread is None:
self._run_thread = threading.Thread(
target=lambda: self._consume_execution_run(),
daemon=True,
)
return self._run_thread
@property
def stop_consuming(self) -> threading.Event:
if self._stop_consuming is None:
@@ -1108,44 +1133,55 @@ class ExecutionManager(AppProcess):
)
return self._executor
@property
def cancel_client(self) -> SyncRabbitMQ:
if self._cancel_client is None:
self._cancel_client = SyncRabbitMQ(create_execution_queue_config())
return self._cancel_client
@property
def run_client(self) -> SyncRabbitMQ:
if self._run_client is None:
self._run_client = SyncRabbitMQ(create_execution_queue_config())
return self._run_client
def run(self):
logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
pool_size_gauge.set(self.pool_size)
self._update_prompt_metrics()
start_http_server(settings.config.execution_manager_port)
threading.Thread(
target=lambda: self._consume_execution_cancel(),
daemon=True,
).start()
self.cancel_thread.start()
self.run_thread.start()
threading.Thread(
target=lambda: self._consume_execution_run(),
daemon=True,
).start()
threading.Thread(
target=start_http_server,
args=(settings.config.execution_manager_port,),
daemon=True,
).start()
while not self.stop_consuming.is_set():
while True:
time.sleep(1e5)
@continuous_retry()
def _consume_execution_cancel(self):
self._cancel_client = SyncRabbitMQ(create_execution_queue_config())
self._cancel_client.connect()
cancel_channel = self._cancel_client.get_channel()
logger.info(f"[{self.service_name}] ⏳ Starting cancel message consumer...")
if self.stop_consuming.is_set() and not self.active_graph_runs:
logger.info(
f"[{self.service_name}] Stop reconnecting cancel consumer since the service is cleaned up."
)
return
self.cancel_client.connect()
cancel_channel = self.cancel_client.get_channel()
cancel_channel.basic_consume(
queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
on_message_callback=self._handle_cancel_message,
auto_ack=True,
)
logger.info(f"[{self.service_name}] ⏳ Starting cancel message consumer...")
cancel_channel.start_consuming()
raise RuntimeError(f"❌ cancel message consumer is stopped: {cancel_channel}")
if not self.stop_consuming.is_set() or self.active_graph_runs:
raise RuntimeError(
f"[{self.service_name}] ❌ cancel message consumer is stopped: {cancel_channel}"
)
logger.info(
f"[{self.service_name}] ✅ Cancel message consumer stopped gracefully"
)
@continuous_retry()
def _consume_execution_run(self):
@@ -1159,9 +1195,8 @@ class ExecutionManager(AppProcess):
)
return
self._run_client = SyncRabbitMQ(create_execution_queue_config())
self._run_client.connect()
run_channel = self._run_client.get_channel()
self.run_client.connect()
run_channel = self.run_client.get_channel()
run_channel.basic_qos(prefetch_count=self.pool_size)
# Configure consumer for long-running graph executions
@@ -1173,21 +1208,12 @@ class ExecutionManager(AppProcess):
consumer_tag="graph_execution_consumer",
)
run_channel.confirm_delivery()
logger.info(f"[{self.service_name}] ⏳ Starting to consume run messages...")
# Continue consuming messages until stop flag is set
# This keeps the connection alive but rejects new messages in _handle_run_message
while not self.stop_consuming.is_set():
try:
run_channel.connection.process_data_events(time_limit=1)
except Exception as e:
if self.stop_consuming.is_set():
# Expected during shutdown
break
logger.error(f"[{self.service_name}] Error processing events: {e}")
raise
run_channel.start_consuming()
if not self.stop_consuming.is_set():
raise RuntimeError(
f"[{self.service_name}] ❌ run message consumer is stopped: {run_channel}"
)
logger.info(f"[{self.service_name}] ✅ Run message consumer stopped gracefully")
@error_logged(swallow=True)
@@ -1233,18 +1259,30 @@ class ExecutionManager(AppProcess):
):
delivery_tag = method.delivery_tag
@func_retry
def _ack_message(reject: bool = False):
"""Acknowledge or reject the message based on execution status."""
if reject:
channel.connection.add_callback_threadsafe(
lambda: channel.basic_nack(delivery_tag, requeue=True)
)
else:
channel.connection.add_callback_threadsafe(
lambda: channel.basic_ack(delivery_tag)
)
# Check if we're shutting down - reject new messages but keep connection alive
if self.stop_consuming.is_set():
logger.info(
f"[{self.service_name}] Rejecting new execution during shutdown"
)
channel.basic_nack(delivery_tag, requeue=True)
_ack_message(reject=True)
return
# Check if we can accept more runs
self._cleanup_completed_runs()
if len(self.active_graph_runs) >= self.pool_size:
channel.basic_nack(delivery_tag, requeue=True)
_ack_message(reject=True)
return
try:
@@ -1253,7 +1291,7 @@ class ExecutionManager(AppProcess):
logger.error(
f"[{self.service_name}] Could not parse run message: {e}, body={body}"
)
channel.basic_nack(delivery_tag, requeue=False)
_ack_message(reject=True)
return
graph_exec_id = graph_exec_entry.graph_exec_id
@@ -1265,7 +1303,7 @@ class ExecutionManager(AppProcess):
logger.error(
f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
)
channel.basic_nack(delivery_tag, requeue=False)
_ack_message(reject=True)
return
cancel_event = multiprocessing.Manager().Event()
@@ -1283,23 +1321,9 @@ class ExecutionManager(AppProcess):
logger.error(
f"[{self.service_name}] Execution for {graph_exec_id} failed: {type(exec_error)} {exec_error}"
)
try:
channel.connection.add_callback_threadsafe(
lambda: channel.basic_nack(delivery_tag, requeue=True)
)
except Exception as ack_error:
logger.error(
f"[{self.service_name}] Failed to NACK message for {graph_exec_id}: {ack_error}"
)
_ack_message(reject=True)
else:
try:
channel.connection.add_callback_threadsafe(
lambda: channel.basic_ack(delivery_tag)
)
except Exception as ack_error:
logger.error(
f"[{self.service_name}] Failed to ACK message for {graph_exec_id}: {ack_error}"
)
_ack_message(reject=False)
except BaseException as e:
logger.exception(
f"[{self.service_name}] Error in run completion callback: {e}"
@@ -1326,7 +1350,7 @@ class ExecutionManager(AppProcess):
def _update_prompt_metrics(self):
active_count = len(self.active_graph_runs)
active_runs_gauge.set(active_count)
if self._stop_consuming and self._stop_consuming.is_set():
if self.stop_consuming.is_set():
utilization_gauge.set(1.0)
else:
utilization_gauge.set(active_count / self.pool_size)
@@ -1337,8 +1361,15 @@ class ExecutionManager(AppProcess):
logger.info(f"{prefix} 🧹 Starting graceful shutdown...")
# Signal the consumer thread to stop (thread-safe)
self.stop_consuming.set()
logger.info(f"{prefix} ✅ Signaled execution message consumer to stop")
try:
self.stop_consuming.set()
run_channel = self.run_client.get_channel()
run_channel.connection.add_callback_threadsafe(
lambda: run_channel.stop_consuming()
)
logger.info(f"{prefix} ✅ Exec consumer has been signaled to stop")
except Exception as e:
logger.error(f"{prefix} ⚠️ Error signaling consumer to stop: {type(e)} {e}")
# Wait for active executions to complete
if self.active_graph_runs:
@@ -1371,34 +1402,34 @@ class ExecutionManager(AppProcess):
else:
logger.info(f"{prefix} ✅ All executions completed gracefully")
# NOW shutdown executor pool after all executions and cleanup are complete
if self._executor:
logger.info(f"{prefix} ⏳ Shutting down GraphExec pool...")
try:
# All active executions are done, safe to shutdown workers
self._executor.shutdown(cancel_futures=True, wait=False)
logger.info(f"{prefix} ✅ Executor shutdown completed")
except Exception as e:
logger.error(f"{prefix} ⚠️ Error during executor shutdown: {e}")
# Shutdown the executor
try:
self.executor.shutdown(cancel_futures=True, wait=False)
logger.info(f"{prefix} ✅ Executor shutdown completed")
except Exception as e:
logger.error(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}")
# Clean up RabbitMQ connections
if self._cancel_client:
logger.info(f"{prefix} ⏳ Disconnecting cancel RabbitMQ client...")
try:
self._cancel_client.disconnect()
logger.info(f"{prefix} ✅ Cancel RabbitMQ client disconnected")
except Exception as e:
logger.error(
f"{prefix} ⚠️ Error disconnecting cancel RabbitMQ client: {e}"
)
# Disconnect the run execution consumer
try:
run_channel = self.run_client.get_channel()
run_channel.connection.add_callback_threadsafe(
lambda: self.run_client.disconnect()
)
self.run_thread.join()
logger.info(f"{prefix} ✅ Run client disconnected")
except Exception as e:
logger.error(f"{prefix} ⚠️ Error disconnecting run client: {type(e)} {e}")
if self._run_client:
logger.info(f"{prefix} ⏳ Disconnecting run RabbitMQ client...")
try:
self._run_client.disconnect()
logger.info(f"{prefix} ✅ Run RabbitMQ client disconnected")
except Exception as e:
logger.error(f"{prefix} ⚠️ Error disconnecting run RabbitMQ client: {e}")
# Disconnect the cancel execution consumer
try:
cancel_channel = self.cancel_client.get_channel()
cancel_channel.connection.add_callback_threadsafe(
lambda: self.cancel_client.disconnect()
)
self.cancel_thread.join()
logger.info(f"{prefix} ✅ Cancel client disconnected")
except Exception as e:
logger.error(f"{prefix} ⚠️ Error disconnecting cancel client: {type(e)} {e}")
logger.info(f"{prefix} ✅ Finished GraphExec cleanup")
@@ -1406,26 +1437,15 @@ class ExecutionManager(AppProcess):
# ------- UTILITIES ------- #
@thread_cached
def get_db_client() -> "DatabaseManagerClient":
from backend.executor import DatabaseManagerClient
# Disable health check for the service client to avoid breaking process initializer.
return get_service_client(
DatabaseManagerClient, health_check=False, request_retry=True
)
return get_database_manager_client()
@thread_cached
def get_db_async_client() -> "DatabaseManagerAsyncClient":
from backend.executor import DatabaseManagerAsyncClient
# Disable health check for the service client to avoid breaking process initializer.
return get_service_client(
DatabaseManagerAsyncClient, health_check=False, request_retry=True
)
return get_database_manager_async_client()
@func_retry
async def send_async_execution_update(
entry: GraphExecution | NodeExecutionResult | None,
) -> None:
@@ -1434,6 +1454,7 @@ async def send_async_execution_update(
await get_async_execution_event_bus().publish(entry)
@func_retry
def send_execution_update(entry: GraphExecution | NodeExecutionResult | None):
if entry is None:
return

View File

@@ -93,6 +93,7 @@ async def _execute_graph(**kwargs):
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id}"
)
except Exception as e:
# TODO: We need to communicate this error to the user somehow.
logger.error(f"Error executing graph {args.graph_id}: {e}")

View File

@@ -1,10 +1,9 @@
import pytest
from backend.data import db
from backend.executor.scheduler import SchedulerClient
from backend.server.model import CreateGraph
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.service import get_service_client
from backend.util.clients import get_scheduler_client
from backend.util.test import SpinTestServer
@@ -17,7 +16,7 @@ async def test_agent_schedule(server: SpinTestServer):
user_id=test_user.id,
)
scheduler = get_service_client(SchedulerClient)
scheduler = get_scheduler_client()
schedules = await scheduler.get_execution_schedules(test_graph.id, test_user.id)
assert len(schedules) == 0

View File

@@ -4,9 +4,8 @@ import threading
import time
from collections import defaultdict
from concurrent.futures import Future
from typing import TYPE_CHECKING, Any, Optional
from typing import Any, Optional
from autogpt_libs.utils.cache import thread_cached
from pydantic import BaseModel, JsonValue, ValidationError
from backend.data import execution as execution_db
@@ -16,33 +15,25 @@ from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCostType
from backend.data.db import prisma
from backend.data.execution import (
AsyncRedisExecutionEventBus,
ExecutionStatus,
GraphExecutionStats,
GraphExecutionWithNodes,
RedisExecutionEventBus,
)
from backend.data.graph import GraphModel, Node
from backend.data.model import CredentialsMetaInput
from backend.data.rabbitmq import (
AsyncRabbitMQ,
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
SyncRabbitMQ,
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.clients import (
get_async_execution_event_bus,
get_async_execution_queue,
get_database_manager_async_client,
get_integration_credentials_store,
)
from backend.util.exceptions import GraphValidationError, NotFoundError
from backend.util.logging import TruncatedLogger
from backend.util.mock import MockObject
from backend.util.service import get_service_client
from backend.util.settings import Config
from backend.util.type import convert
if TYPE_CHECKING:
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
from backend.integrations.credentials_store import IntegrationCredentialsStore
config = Config()
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
@@ -79,51 +70,6 @@ class LogMetadata(TruncatedLogger):
)
@thread_cached
def get_execution_event_bus() -> RedisExecutionEventBus:
return RedisExecutionEventBus()
@thread_cached
def get_async_execution_event_bus() -> AsyncRedisExecutionEventBus:
return AsyncRedisExecutionEventBus()
@thread_cached
def get_execution_queue() -> SyncRabbitMQ:
client = SyncRabbitMQ(create_execution_queue_config())
client.connect()
return client
@thread_cached
async def get_async_execution_queue() -> AsyncRabbitMQ:
client = AsyncRabbitMQ(create_execution_queue_config())
await client.connect()
return client
@thread_cached
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
from backend.integrations.credentials_store import IntegrationCredentialsStore
return IntegrationCredentialsStore()
@thread_cached
def get_db_client() -> "DatabaseManagerClient":
from backend.executor import DatabaseManagerClient
return get_service_client(DatabaseManagerClient)
@thread_cached
def get_db_async_client() -> "DatabaseManagerAsyncClient":
from backend.executor import DatabaseManagerAsyncClient
return get_service_client(DatabaseManagerAsyncClient)
# ============ Execution Cost Helpers ============ #
@@ -450,7 +396,7 @@ def validate_exec(
# Last validation: Validate the input values against the schema.
if error := schema.get_mismatch_error(data):
error_message = f"{error_prefix} {error}"
logger.error(error_message)
logger.warning(error_message)
return None, error_message
return data, node_block.name
@@ -685,11 +631,6 @@ def _merge_nodes_input_masks(
# ============ Execution Queue Helpers ============ #
class CancelExecutionEvent(BaseModel):
graph_exec_id: str
GRAPH_EXECUTION_EXCHANGE = Exchange(
name="graph_execution",
type=ExchangeType.DIRECT,
@@ -750,6 +691,10 @@ def create_execution_queue_config() -> RabbitMQConfig:
)
class CancelExecutionEvent(BaseModel):
graph_exec_id: str
async def stop_graph_execution(
user_id: str,
graph_exec_id: str,
@@ -763,7 +708,7 @@ async def stop_graph_execution(
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
"""
queue_client = await get_async_execution_queue()
db = execution_db if prisma.is_connected() else get_db_async_client()
db = execution_db if prisma.is_connected() else get_database_manager_async_client()
await queue_client.publish_message(
routing_key="",
message=CancelExecutionEvent(graph_exec_id=graph_exec_id).model_dump_json(),
@@ -849,8 +794,8 @@ async def add_graph_execution(
gdb = graph_db
edb = execution_db
else:
gdb = get_db_async_client()
edb = get_db_async_client()
gdb = get_database_manager_async_client()
edb = get_database_manager_async_client()
graph: GraphModel | None = await gdb.get_graph(
graph_id=graph_id,
@@ -903,7 +848,7 @@ async def add_graph_execution(
except BaseException as e:
err = str(e) or type(e).__name__
if not graph_exec:
logger.error(f"Graph execution #{graph_id} failed: {err}")
logger.error(f"Unable to execute graph #{graph_id} failed: {err}")
raise
logger.error(

View File

@@ -5,7 +5,6 @@ from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import Optional
from autogpt_libs.utils.cache import thread_cached
from autogpt_libs.utils.synchronize import AsyncRedisKeyedMutex
from pydantic import SecretStr
@@ -229,17 +228,15 @@ class IntegrationCredentialsStore:
return self._locks
@property
@thread_cached
def db_manager(self):
if prisma.is_connected():
from backend.data import user
return user
else:
from backend.executor.database import DatabaseManagerAsyncClient
from backend.util.service import get_service_client
from backend.util.clients import get_database_manager_async_client
return get_service_client(DatabaseManagerAsyncClient)
return get_database_manager_async_client()
# =============== USER-MANAGED CREDENTIALS =============== #
async def add_creds(self, user_id: str, credentials: Credentials) -> None:

View File

@@ -8,10 +8,11 @@ from pydantic import BaseModel
from backend.data.block import get_block
from backend.data.execution import ExecutionStatus, NodeExecutionResult
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManagerClient
from backend.util.clients import (
get_database_manager_client,
get_notification_manager_client,
)
from backend.util.metrics import sentry_capture_error
from backend.util.service import get_service_client
from backend.util.settings import Config
logger = logging.getLogger(__name__)
@@ -40,7 +41,7 @@ class BlockErrorMonitor:
def __init__(self, include_top_blocks: int | None = None):
self.config = config
self.notification_client = get_service_client(NotificationManagerClient)
self.notification_client = get_notification_manager_client()
self.include_top_blocks = (
include_top_blocks
if include_top_blocks is not None
@@ -107,7 +108,7 @@ class BlockErrorMonitor:
) -> dict[str, BlockStatsWithSamples]:
"""Get block execution stats using efficient SQL aggregation."""
result = execution_utils.get_db_client().get_block_error_stats(
result = get_database_manager_client().get_block_error_stats(
start_time, end_time
)
@@ -197,7 +198,7 @@ class BlockErrorMonitor:
) -> list[str]:
"""Get error samples for a specific block - just a few recent ones."""
# Only fetch a small number of recent failed executions for this specific block
executions = execution_utils.get_db_client().get_node_executions(
executions = get_database_manager_client().get_node_executions(
block_ids=[block_id],
statuses=[ExecutionStatus.FAILED],
created_time_gte=start_time,

View File

@@ -4,10 +4,11 @@ import logging
from datetime import datetime, timedelta, timezone
from backend.data.execution import ExecutionStatus
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManagerClient
from backend.util.clients import (
get_database_manager_client,
get_notification_manager_client,
)
from backend.util.metrics import sentry_capture_error
from backend.util.service import get_service_client
from backend.util.settings import Config
logger = logging.getLogger(__name__)
@@ -25,13 +26,13 @@ class LateExecutionMonitor:
def __init__(self):
self.config = config
self.notification_client = get_service_client(NotificationManagerClient)
self.notification_client = get_notification_manager_client()
def check_late_executions(self) -> str:
"""Check for late executions and send alerts if found."""
# Check for QUEUED executions
queued_late_executions = execution_utils.get_db_client().get_graph_executions(
queued_late_executions = get_database_manager_client().get_graph_executions(
statuses=[ExecutionStatus.QUEUED],
created_time_gte=datetime.now(timezone.utc)
- timedelta(
@@ -43,7 +44,7 @@ class LateExecutionMonitor:
)
# Check for RUNNING executions stuck for more than 24 hours
running_late_executions = execution_utils.get_db_client().get_graph_executions(
running_late_executions = get_database_manager_client().get_graph_executions(
statuses=[ExecutionStatus.RUNNING],
created_time_gte=datetime.now(timezone.utc)
- timedelta(hours=24)

View File

@@ -2,12 +2,10 @@
import logging
from autogpt_libs.utils.cache import thread_cached
from prisma.enums import NotificationType
from pydantic import BaseModel
from backend.notifications.notifications import NotificationManagerClient
from backend.util.service import get_service_client
from backend.util.clients import get_notification_manager_client
logger = logging.getLogger(__name__)
@@ -17,11 +15,6 @@ class NotificationJobArgs(BaseModel):
cron: str
@thread_cached
def get_notification_manager_client():
return get_service_client(NotificationManagerClient)
def process_existing_batches(**kwargs):
"""Process existing notification batches."""
args = NotificationJobArgs(**kwargs)

View File

@@ -5,7 +5,6 @@ from datetime import datetime, timedelta, timezone
from typing import Callable
import aio_pika
from autogpt_libs.utils.cache import thread_cached
from prisma.enums import NotificationType
from backend.data import rabbitmq
@@ -26,26 +25,14 @@ from backend.data.notifications import (
get_notif_data_type,
get_summary_params_type,
)
from backend.data.rabbitmq import (
AsyncRabbitMQ,
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
SyncRabbitMQ,
)
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import generate_unsubscribe_link
from backend.notifications.email import EmailSender
from backend.util.clients import get_database_manager_client
from backend.util.logging import TruncatedLogger
from backend.util.metrics import discord_send_alert
from backend.util.retry import continuous_retry
from backend.util.service import (
AppService,
AppServiceClient,
endpoint_to_sync,
expose,
get_service_client,
)
from backend.util.service import AppService, AppServiceClient, endpoint_to_sync, expose
from backend.util.settings import Settings
logger = TruncatedLogger(logging.getLogger(__name__), "[NotificationManager]")
@@ -116,27 +103,6 @@ def create_notification_config() -> RabbitMQConfig:
)
@thread_cached
def get_db():
from backend.executor.database import DatabaseManagerClient
return get_service_client(DatabaseManagerClient)
@thread_cached
def get_notification_queue() -> SyncRabbitMQ:
client = SyncRabbitMQ(create_notification_config())
client.connect()
return client
@thread_cached
async def get_async_notification_queue() -> AsyncRabbitMQ:
client = AsyncRabbitMQ(create_notification_config())
await client.connect()
return client
def get_routing_key(event_type: NotificationType) -> str:
strategy = NotificationTypeOverride(event_type).strategy
"""Get the appropriate routing key for an event"""
@@ -161,6 +127,8 @@ def queue_notification(event: NotificationEventModel) -> NotificationResult:
exchange = "notifications"
routing_key = get_routing_key(event.type)
from backend.util.clients import get_notification_queue
queue = get_notification_queue()
queue.publish_message(
routing_key=routing_key,
@@ -186,6 +154,8 @@ async def queue_notification_async(event: NotificationEventModel) -> Notificatio
exchange = "notifications"
routing_key = get_routing_key(event.type)
from backend.util.clients import get_async_notification_queue
queue = await get_async_notification_queue()
await queue.publish_message(
routing_key=routing_key,
@@ -241,7 +211,7 @@ class NotificationManager(AppService):
processed_count = 0
current_time = datetime.now(tz=timezone.utc)
start_time = current_time - timedelta(days=7)
users = get_db().get_active_user_ids_in_timerange(
users = get_database_manager_client().get_active_user_ids_in_timerange(
end_time=current_time.isoformat(),
start_time=start_time.isoformat(),
)
@@ -275,14 +245,14 @@ class NotificationManager(AppService):
for notification_type in notification_types:
# Get all batches for this notification type
batches = get_db().get_all_batches_by_type(notification_type)
batches = get_database_manager_client().get_all_batches_by_type(
notification_type
)
for batch in batches:
# Check if batch has aged out
oldest_message = (
get_db().get_user_notification_oldest_message_in_batch(
batch.user_id, notification_type
)
oldest_message = get_database_manager_client().get_user_notification_oldest_message_in_batch(
batch.user_id, notification_type
)
if not oldest_message:
@@ -296,7 +266,11 @@ class NotificationManager(AppService):
# If batch has aged out, process it
if oldest_message.created_at + max_delay < current_time:
recipient_email = get_db().get_user_email_by_id(batch.user_id)
recipient_email = (
get_database_manager_client().get_user_email_by_id(
batch.user_id
)
)
if not recipient_email:
logger.error(
@@ -313,13 +287,15 @@ class NotificationManager(AppService):
f"User {batch.user_id} does not want to receive {notification_type} notifications"
)
# Clear the batch
get_db().empty_user_notification_batch(
get_database_manager_client().empty_user_notification_batch(
batch.user_id, notification_type
)
continue
batch_data = get_db().get_user_notification_batch(
batch.user_id, notification_type
batch_data = (
get_database_manager_client().get_user_notification_batch(
batch.user_id, notification_type
)
)
if not batch_data or not batch_data.notifications:
@@ -327,7 +303,7 @@ class NotificationManager(AppService):
f"Batch data not found for user {batch.user_id}"
)
# Clear the batch
get_db().empty_user_notification_batch(
get_database_manager_client().empty_user_notification_batch(
batch.user_id, notification_type
)
continue
@@ -363,7 +339,7 @@ class NotificationManager(AppService):
)
# Clear the batch
get_db().empty_user_notification_batch(
get_database_manager_client().empty_user_notification_batch(
batch.user_id, notification_type
)
@@ -412,9 +388,11 @@ class NotificationManager(AppService):
self, user_id: str, event_type: NotificationType
) -> bool:
"""Check if a user wants to receive a notification based on their preferences and email verification status"""
validated_email = get_db().get_user_email_verification(user_id)
validated_email = get_database_manager_client().get_user_email_verification(
user_id
)
preference = (
get_db()
get_database_manager_client()
.get_user_notification_preference(user_id)
.preferences.get(event_type, True)
)
@@ -505,10 +483,14 @@ class NotificationManager(AppService):
self, user_id: str, event_type: NotificationType, event: NotificationEventModel
) -> bool:
get_db().create_or_add_to_user_notification_batch(user_id, event_type, event)
get_database_manager_client().create_or_add_to_user_notification_batch(
user_id, event_type, event
)
oldest_message = get_db().get_user_notification_oldest_message_in_batch(
user_id, event_type
oldest_message = (
get_database_manager_client().get_user_notification_oldest_message_in_batch(
user_id, event_type
)
)
if not oldest_message:
logger.error(
@@ -559,7 +541,9 @@ class NotificationManager(AppService):
return False
logger.debug(f"Processing immediate notification: {event}")
recipient_email = get_db().get_user_email_by_id(event.user_id)
recipient_email = get_database_manager_client().get_user_email_by_id(
event.user_id
)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
@@ -594,7 +578,9 @@ class NotificationManager(AppService):
return False
logger.info(f"Processing batch notification: {event}")
recipient_email = get_db().get_user_email_by_id(event.user_id)
recipient_email = get_database_manager_client().get_user_email_by_id(
event.user_id
)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
@@ -613,7 +599,9 @@ class NotificationManager(AppService):
if not should_send:
logger.info("Batch not old enough to send")
return False
batch = get_db().get_user_notification_batch(event.user_id, event.type)
batch = get_database_manager_client().get_user_notification_batch(
event.user_id, event.type
)
if not batch or not batch.notifications:
logger.error(f"Batch not found for user {event.user_id}")
return False
@@ -714,7 +702,9 @@ class NotificationManager(AppService):
logger.info(
f"Successfully sent all {successfully_sent_count} notifications, clearing batch"
)
get_db().empty_user_notification_batch(event.user_id, event.type)
get_database_manager_client().empty_user_notification_batch(
event.user_id, event.type
)
else:
logger.warning(
f"Only sent {successfully_sent_count} of {len(batch_messages)} notifications. "
@@ -736,7 +726,9 @@ class NotificationManager(AppService):
logger.info(f"Processing summary notification: {model}")
recipient_email = get_db().get_user_email_by_id(event.user_id)
recipient_email = get_database_manager_client().get_user_email_by_id(
event.user_id
)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False

View File

@@ -9,7 +9,6 @@ import pydantic
import stripe
from autogpt_libs.auth.middleware import auth_middleware
from autogpt_libs.feature_flag.client import feature_flag
from autogpt_libs.utils.cache import thread_cached
from fastapi import (
APIRouter,
Body,
@@ -51,7 +50,6 @@ from backend.data.credit import (
get_user_credit_model,
set_auto_top_up,
)
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.model import CredentialsMetaInput
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
@@ -84,18 +82,13 @@ from backend.server.model import (
UploadFileResponse,
)
from backend.server.utils import get_user_id
from backend.util.clients import get_scheduler_client
from backend.util.cloud_storage import get_cloud_storage_handler
from backend.util.exceptions import GraphValidationError, NotFoundError
from backend.util.service import get_service_client
from backend.util.settings import Settings
from backend.util.virus_scanner import scan_content_safe
@thread_cached
def execution_scheduler_client() -> scheduler.SchedulerClient:
return get_service_client(scheduler.SchedulerClient, health_check=False)
def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
"""Create standardized file size error response."""
return HTTPException(
@@ -104,11 +97,6 @@ def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
)
@thread_cached
def execution_event_bus() -> AsyncRedisExecutionEventBus:
return AsyncRedisExecutionEventBus()
settings = Settings()
logger = logging.getLogger(__name__)
@@ -924,7 +912,7 @@ async def create_graph_execution_schedule(
detail=f"Graph #{graph_id} v{schedule_params.graph_version} not found.",
)
return await execution_scheduler_client().add_execution_schedule(
return await get_scheduler_client().add_execution_schedule(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
@@ -945,7 +933,7 @@ async def list_graph_execution_schedules(
user_id: Annotated[str, Depends(get_user_id)],
graph_id: str = Path(),
) -> list[scheduler.GraphExecutionJobInfo]:
return await execution_scheduler_client().get_execution_schedules(
return await get_scheduler_client().get_execution_schedules(
user_id=user_id,
graph_id=graph_id,
)
@@ -960,7 +948,7 @@ async def list_graph_execution_schedules(
async def list_all_graphs_execution_schedules(
user_id: Annotated[str, Depends(get_user_id)],
) -> list[scheduler.GraphExecutionJobInfo]:
return await execution_scheduler_client().get_execution_schedules(user_id=user_id)
return await get_scheduler_client().get_execution_schedules(user_id=user_id)
@v1_router.delete(
@@ -974,7 +962,7 @@ async def delete_graph_execution_schedule(
schedule_id: str = Path(..., description="ID of the schedule to delete"),
) -> dict[str, Any]:
try:
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
await get_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
except NotFoundError:
raise HTTPException(
status_code=HTTP_404_NOT_FOUND,

View File

@@ -0,0 +1,134 @@
"""
Centralized service client helpers with thread caching.
"""
from typing import TYPE_CHECKING
from autogpt_libs.utils.cache import thread_cached
if TYPE_CHECKING:
from backend.data.execution import (
AsyncRedisExecutionEventBus,
RedisExecutionEventBus,
)
from backend.data.rabbitmq import AsyncRabbitMQ, SyncRabbitMQ
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
from backend.executor.scheduler import SchedulerClient
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.notifications.notifications import NotificationManagerClient
@thread_cached
def get_database_manager_client() -> "DatabaseManagerClient":
"""Get a thread-cached DatabaseManagerClient with request retry enabled."""
from backend.executor import DatabaseManagerClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManagerClient, request_retry=True)
@thread_cached
def get_database_manager_async_client() -> "DatabaseManagerAsyncClient":
"""Get a thread-cached DatabaseManagerAsyncClient with request retry enabled."""
from backend.executor import DatabaseManagerAsyncClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManagerAsyncClient, request_retry=True)
@thread_cached
def get_scheduler_client() -> "SchedulerClient":
"""Get a thread-cached SchedulerClient."""
from backend.executor.scheduler import SchedulerClient
from backend.util.service import get_service_client
return get_service_client(SchedulerClient)
@thread_cached
def get_notification_manager_client() -> "NotificationManagerClient":
"""Get a thread-cached NotificationManagerClient."""
from backend.notifications.notifications import NotificationManagerClient
from backend.util.service import get_service_client
return get_service_client(NotificationManagerClient)
# ============ Execution Event Bus Helpers ============ #
@thread_cached
def get_execution_event_bus() -> "RedisExecutionEventBus":
"""Get a thread-cached RedisExecutionEventBus."""
from backend.data.execution import RedisExecutionEventBus
return RedisExecutionEventBus()
@thread_cached
def get_async_execution_event_bus() -> "AsyncRedisExecutionEventBus":
"""Get a thread-cached AsyncRedisExecutionEventBus."""
from backend.data.execution import AsyncRedisExecutionEventBus
return AsyncRedisExecutionEventBus()
# ============ Execution Queue Helpers ============ #
@thread_cached
def get_execution_queue() -> "SyncRabbitMQ":
"""Get a thread-cached SyncRabbitMQ execution queue client."""
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.utils import create_execution_queue_config
client = SyncRabbitMQ(create_execution_queue_config())
client.connect()
return client
@thread_cached
async def get_async_execution_queue() -> "AsyncRabbitMQ":
"""Get a thread-cached AsyncRabbitMQ execution queue client."""
from backend.data.rabbitmq import AsyncRabbitMQ
from backend.executor.utils import create_execution_queue_config
client = AsyncRabbitMQ(create_execution_queue_config())
await client.connect()
return client
# ============ Integration Credentials Store ============ #
@thread_cached
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
"""Get a thread-cached IntegrationCredentialsStore."""
from backend.integrations.credentials_store import IntegrationCredentialsStore
return IntegrationCredentialsStore()
# ============ Notification Queue Helpers ============ #
@thread_cached
def get_notification_queue() -> "SyncRabbitMQ":
"""Get a thread-cached SyncRabbitMQ notification queue client."""
from backend.data.rabbitmq import SyncRabbitMQ
from backend.notifications.notifications import create_notification_config
client = SyncRabbitMQ(create_notification_config())
client.connect()
return client
@thread_cached
async def get_async_notification_queue() -> "AsyncRabbitMQ":
"""Get a thread-cached AsyncRabbitMQ notification queue client."""
from backend.data.rabbitmq import AsyncRabbitMQ
from backend.notifications.notifications import create_notification_config
client = AsyncRabbitMQ(create_notification_config())
await client.connect()
return client

View File

@@ -75,7 +75,7 @@ class AppProcess(ABC):
self.run()
except BaseException as e:
logger.warning(
f"[{self.service_name}] Termination request: {type(e).__name__}; executing cleanup."
f"[{self.service_name}] Termination request: {type(e).__name__}; {e} executing cleanup."
)
finally:
self.cleanup()

View File

@@ -161,14 +161,21 @@ def continuous_retry(*, retry_delay: float = 1.0):
@wraps(func)
def sync_wrapper(*args, **kwargs):
counter = 0
while True:
try:
return func(*args, **kwargs)
except Exception as exc:
logger.exception(
"%s failed with %s — retrying in %.2f s",
counter += 1
if counter % 10 == 0:
log = logger.exception
else:
log = logger.warning
log(
"%s failed for the %s times, error: [%s] — retrying in %.2fs",
func.__name__,
exc,
counter,
str(exc) or type(exc).__name__,
retry_delay,
)
time.sleep(retry_delay)
@@ -176,13 +183,20 @@ def continuous_retry(*, retry_delay: float = 1.0):
@wraps(func)
async def async_wrapper(*args, **kwargs):
while True:
counter = 0
try:
return await func(*args, **kwargs)
except Exception as exc:
logger.exception(
"%s failed with %s — retrying in %.2f s",
counter += 1
if counter % 10 == 0:
log = logger.exception
else:
log = logger.warning
log(
"%s failed for the %s times, error: [%s] — retrying in %.2fs",
func.__name__,
exc,
counter,
str(exc) or type(exc).__name__,
retry_delay,
)
await asyncio.sleep(retry_delay)

View File

@@ -277,7 +277,6 @@ ASC = TypeVar("ASC", bound=AppServiceClient)
def get_service_client(
service_client_type: Type[ASC],
call_timeout: int | None = api_call_timeout,
health_check: bool = True,
request_retry: bool = False,
) -> ASC:
@@ -461,8 +460,6 @@ def get_service_client(
return sync_method
client = cast(ASC, DynamicClient())
if health_check and hasattr(client, "health_check"):
client.health_check()
return client

View File

@@ -16,6 +16,12 @@ from backend.util.service import (
TEST_SERVICE_PORT = 8765
def wait_for_service_ready(service_client_type, timeout_seconds=30):
"""Helper method to wait for a service to be ready using health check with retry."""
client = get_service_client(service_client_type, request_retry=True)
client.health_check() # This will retry until service is ready
class ServiceTest(AppService):
def __init__(self):
super().__init__()
@@ -28,6 +34,15 @@ class ServiceTest(AppService):
def get_port(cls) -> int:
return TEST_SERVICE_PORT
def __enter__(self):
# Start the service
result = super().__enter__()
# Wait for the service to be ready
wait_for_service_ready(ServiceTestClient)
return result
@expose
def add(self, a: int, b: int) -> int:
return a + b
@@ -48,13 +63,13 @@ class ServiceTest(AppService):
"""Method that fails 2 times then succeeds - for testing retry logic"""
self.fail_count += 1
if self.fail_count <= 2:
raise RuntimeError("Database connection failed")
raise RuntimeError(f"Intended error for testing {self.fail_count}/2")
return a + b
@expose
def always_failing_add(self, a: int, b: int) -> int:
"""Method that always fails - for testing no retry when disabled"""
raise RuntimeError("Database connection failed")
raise RuntimeError("Intended error for testing")
class ServiceTestClient(AppServiceClient):
@@ -349,5 +364,5 @@ def test_service_no_retry_when_disabled(server):
client = get_service_client(ServiceTestClient, request_retry=False)
# This should fail immediately without retry
with pytest.raises(RuntimeError, match="Database connection failed"):
with pytest.raises(RuntimeError, match="Intended error for testing"):
client.always_failing_add(5, 3)