refactor(backend): Un-share resource initializations from AppService + Remove Pyro (#9750)

This is a prerequisite infra change for
https://github.com/Significant-Gravitas/AutoGPT/issues/9714.

We will need a service where we can maintain our own clients (db, redis,
rabbitmq — whether async or sync) and configure our own cadence of
initialization and cleanup.

While refactoring service.py, the option to use Pyro as an RPC
protocol was also removed.

### Changes 🏗️

* Decouple resource initialization and cleanup from the parent
AppService logic.
* Removed Pyro.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] CI
This commit is contained in:
Zamil Majdy
2025-04-08 21:47:22 +02:00
committed by GitHub
parent d316ed23d4
commit 7fedb5e2fd
12 changed files with 76 additions and 349 deletions

View File

@@ -1,3 +1,6 @@
import logging
from backend.data import db, redis
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
GraphExecution,
@@ -44,6 +47,7 @@ from backend.util.settings import Config
config = Config()
_user_credit_model = get_user_credit_model()
logger = logging.getLogger(__name__)
async def _spend_credits(
@@ -55,10 +59,22 @@ async def _spend_credits(
class DatabaseManager(AppService):
def __init__(self):
super().__init__()
self.use_db = True
self.use_redis = True
self.execution_event_bus = RedisExecutionEventBus()
def run_service(self) -> None:
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
self.run_and_wait(db.connect())
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
redis.connect()
super().run_service()
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
redis.disconnect()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
self.run_and_wait(db.disconnect())
@classmethod
def get_port(cls) -> int:
return config.database_api_port

View File

@@ -930,8 +930,6 @@ class Executor:
class ExecutionManager(AppService):
def __init__(self):
super().__init__()
self.use_redis = True
self.use_supabase = True
self.pool_size = settings.config.num_graph_workers
self.queue = ExecutionQueue[GraphExecutionEntry]()
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
@@ -944,14 +942,17 @@ class ExecutionManager(AppService):
from backend.integrations.credentials_store import IntegrationCredentialsStore
self.credentials_store = IntegrationCredentialsStore()
logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
self.executor = ProcessPoolExecutor(
max_workers=self.pool_size,
initializer=Executor.on_graph_executor_start,
)
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
redis.connect()
sync_manager = multiprocessing.Manager()
logger.info(
f"[{self.service_name}] Started with max-{self.pool_size} graph workers"
)
while True:
graph_exec_data = self.queue.get()
graph_exec_id = graph_exec_data.graph_exec_id
@@ -968,10 +969,13 @@ class ExecutionManager(AppService):
)
def cleanup(self):
logger.info(f"[{__class__.__name__}] ⏳ Shutting down graph executor pool...")
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down graph executor pool...")
self.executor.shutdown(cancel_futures=True)
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
redis.disconnect()
@property
def db_client(self) -> "DatabaseManager":

View File

@@ -206,6 +206,12 @@ class Scheduler(AppService):
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
self.scheduler.start()
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down scheduler...")
if self.scheduler:
self.scheduler.shutdown(wait=False)
@expose
def add_execution_schedule(
self,

View File

@@ -9,6 +9,7 @@ from autogpt_libs.utils.cache import thread_cached
from prisma.enums import NotificationType
from pydantic import BaseModel
from backend.data import rabbitmq
from backend.data.notifications import (
BaseSummaryData,
BaseSummaryParams,
@@ -128,6 +129,20 @@ class NotificationManager(AppService):
self.running = True
self.email_sender = EmailSender()
@property
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
"""Access the RabbitMQ service. Will raise if not configured."""
if not self.rabbitmq_service:
raise RuntimeError("RabbitMQ not configured for this service")
return self.rabbitmq_service
@property
def rabbit_config(self) -> rabbitmq.RabbitMQConfig:
"""Access the RabbitMQ config. Will raise if not configured."""
if not self.rabbitmq_config:
raise RuntimeError("RabbitMQ not configured for this service")
return self.rabbitmq_config
@classmethod
def get_port(cls) -> int:
return settings.config.notification_service_port
@@ -688,10 +703,14 @@ class NotificationManager(AppService):
)
def run_service(self):
logger.info(f"[{self.service_name}] ⏳ Configuring RabbitMQ...")
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
self.run_and_wait(self.rabbitmq_service.connect())
logger.info(f"[{self.service_name}] Started notification service")
# Set up scheduler for batch processing of all notification types
# this can be changed later to spawn differnt cleanups on different schedules
# this can be changed later to spawn different cleanups on different schedules
try:
get_scheduler().add_batched_notification_schedule(
notification_types=list(NotificationType),
@@ -753,3 +772,5 @@ class NotificationManager(AppService):
"""Cleanup service resources"""
self.running = False
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
self.run_and_wait(self.rabbitmq_service.disconnect())

View File

@@ -145,6 +145,10 @@ class AgentServer(backend.util.service.AppProcess):
log_config=generate_uvicorn_config(),
)
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down Agent Server...")
@staticmethod
async def test_execute_graph(
graph_id: str,

View File

@@ -10,7 +10,6 @@ from autogpt_libs.utils.cache import thread_cached
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from starlette.middleware.cors import CORSMiddleware
from backend.data import redis
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
@@ -56,15 +55,12 @@ def get_db_client():
async def event_broadcaster(manager: ConnectionManager):
try:
redis.connect()
event_queue = AsyncRedisExecutionEventBus()
async for event in event_queue.listen("*"):
await manager.send_execution_update(event)
except Exception as e:
logger.exception(f"Event broadcaster error: {e}")
raise
finally:
redis.disconnect()
async def authenticate_websocket(websocket: WebSocket) -> str:
@@ -294,3 +290,7 @@ class WebsocketServer(AppProcess):
port=Config().websocket_server_port,
log_config=generate_uvicorn_config(),
)
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down WebSocket Server...")

View File

@@ -48,6 +48,7 @@ class AppProcess(ABC):
def service_name(cls) -> str:
return cls.__name__
@abstractmethod
def cleanup(self):
"""
Implement this method on a subclass to do post-execution cleanup,

View File

@@ -1,52 +1,33 @@
import asyncio
import builtins
import inspect
import logging
import os
import threading
import time
import typing
from abc import ABC, abstractmethod
from enum import Enum
from functools import wraps
from types import NoneType, UnionType
from typing import (
Annotated,
Any,
Awaitable,
Callable,
Concatenate,
Coroutine,
Dict,
FrozenSet,
Iterator,
List,
Optional,
ParamSpec,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
get_args,
get_origin,
)
import httpx
import Pyro5.api
import uvicorn
from fastapi import FastAPI, Request, responses
from pydantic import BaseModel, TypeAdapter, create_model
from Pyro5 import api as pyro
from Pyro5 import config as pyro_config
from backend.data import db, rabbitmq, redis
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import to_dict
from backend.util.process import AppProcess, get_service_name
from backend.util.retry import conn_retry
from backend.util.settings import Config, Secrets
from backend.util.settings import Config
logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -57,21 +38,18 @@ api_host = config.pyro_host
api_comm_retry = config.pyro_client_comm_retry
api_comm_timeout = config.pyro_client_comm_timeout
api_call_timeout = config.rpc_client_call_timeout
pyro_config.MAX_RETRIES = api_comm_retry # type: ignore
pyro_config.COMMTIMEOUT = api_comm_timeout # type: ignore
P = ParamSpec("P")
R = TypeVar("R")
def fastapi_expose(func: C) -> C:
def expose(func: C) -> C:
func = getattr(func, "__func__", func)
setattr(func, "__exposed__", True)
return func
def fastapi_exposed_run_and_wait(
def exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
# TODO:
@@ -81,107 +59,11 @@ def fastapi_exposed_run_and_wait(
return expose(f) # type: ignore
# ----- Begin Pyro Expose Block ---- #
def pyro_expose(func: C) -> C:
"""
Decorator to mark a method or class to be exposed for remote calls.
## ⚠️ Gotcha
Aside from "simple" types, only Pydantic models are passed unscathed *if annotated*.
Any other passed or returned class objects are converted to dictionaries by Pyro.
"""
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
msg = f"Error in {func.__name__}: {e}"
if isinstance(e, ValueError):
logger.warning(msg)
else:
logger.exception(msg)
raise
register_pydantic_serializers(func)
return pyro.expose(wrapper) # type: ignore
def register_pydantic_serializers(func: Callable):
"""Register custom serializers and deserializers for annotated Pydantic models"""
for name, annotation in func.__annotations__.items():
try:
pydantic_types = _pydantic_models_from_type_annotation(annotation)
except Exception as e:
raise TypeError(f"Error while exposing {func.__name__}: {e}")
for model in pydantic_types:
logger.debug(
f"Registering Pyro (de)serializers for {func.__name__} annotation "
f"'{name}': {model.__qualname__}"
)
pyro.register_class_to_dict(model, _make_custom_serializer(model))
pyro.register_dict_to_class(
model.__qualname__, _make_custom_deserializer(model)
)
def _make_custom_serializer(model: Type[BaseModel]):
def custom_class_to_dict(obj):
data = {
"__class__": obj.__class__.__qualname__,
**obj.model_dump(),
}
logger.debug(f"Serializing {obj.__class__.__qualname__} with data: {data}")
return data
return custom_class_to_dict
def _make_custom_deserializer(model: Type[BaseModel]):
def custom_dict_to_class(qualname, data: dict):
logger.debug(f"Deserializing {model.__qualname__} from data: {data}")
return model(**data)
return custom_dict_to_class
def pyro_exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
@expose
@wraps(f)
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
coroutine = f(*args, **kwargs)
res = self.run_and_wait(coroutine)
return res
# Register serializers for annotations on bare function
register_pydantic_serializers(f)
return wrapper
if config.use_http_based_rpc:
expose = fastapi_expose
exposed_run_and_wait = fastapi_exposed_run_and_wait
else:
expose = pyro_expose
exposed_run_and_wait = pyro_exposed_run_and_wait
# ----- End Pyro Expose Block ---- #
# --------------------------------------------------
# AppService for IPC service based on HTTP request through FastAPI
# --------------------------------------------------
class BaseAppService(AppProcess, ABC):
shared_event_loop: asyncio.AbstractEventLoop
use_db: bool = False
use_redis: bool = False
rabbitmq_config: Optional[rabbitmq.RabbitMQConfig] = None
rabbitmq_service: Optional[rabbitmq.AsyncRabbitMQ] = None
use_supabase: bool = False
@classmethod
@abstractmethod
@@ -202,20 +84,6 @@ class BaseAppService(AppProcess, ABC):
return target_host
@property
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
"""Access the RabbitMQ service. Will raise if not configured."""
if not self.rabbitmq_service:
raise RuntimeError("RabbitMQ not configured for this service")
return self.rabbitmq_service
@property
def rabbit_config(self) -> rabbitmq.RabbitMQConfig:
"""Access the RabbitMQ config. Will raise if not configured."""
if not self.rabbitmq_config:
raise RuntimeError("RabbitMQ not configured for this service")
return self.rabbitmq_config
def run_service(self) -> None:
while True:
time.sleep(10)
@@ -225,31 +93,6 @@ class BaseAppService(AppProcess, ABC):
def run(self):
self.shared_event_loop = asyncio.get_event_loop()
if self.use_db:
self.shared_event_loop.run_until_complete(db.connect())
if self.use_redis:
redis.connect()
if self.rabbitmq_config:
logger.info(f"[{self.__class__.__name__}] ⏳ Configuring RabbitMQ...")
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
self.shared_event_loop.run_until_complete(self.rabbitmq_service.connect())
if self.use_supabase:
from supabase import create_client
secrets = Secrets()
self.supabase = create_client(
secrets.supabase_url, secrets.supabase_service_role_key
)
def cleanup(self):
if self.use_db:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting DB...")
self.run_and_wait(db.disconnect())
if self.use_redis:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting Redis...")
redis.disconnect()
if self.rabbitmq_config:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting RabbitMQ...")
class RemoteCallError(BaseModel):
@@ -268,7 +111,7 @@ EXCEPTION_MAPPING = {
}
class FastApiAppService(BaseAppService, ABC):
class AppService(BaseAppService, ABC):
fastapi_app: FastAPI
@staticmethod
@@ -383,62 +226,13 @@ class FastApiAppService(BaseAppService, ABC):
self.run_service()
# ----- Begin Pyro AppService Block ---- #
class PyroAppService(BaseAppService, ABC):
@conn_retry("Pyro", "Starting Pyro Service")
def __start_pyro(self):
maximum_connection_thread_count = max(
Pyro5.config.THREADPOOL_SIZE,
config.num_node_workers * config.num_graph_workers,
)
Pyro5.config.THREADPOOL_SIZE = maximum_connection_thread_count # type: ignore
daemon = Pyro5.api.Daemon(host=api_host, port=self.get_port())
self.uri = daemon.register(self, objectId=self.service_name)
logger.info(f"[{self.service_name}] Connected to Pyro; URI = {self.uri}")
daemon.requestLoop()
def run(self):
super().run()
# Initialize the async loop.
async_thread = threading.Thread(target=self.shared_event_loop.run_forever)
async_thread.daemon = True
async_thread.start()
# Initialize pyro service
daemon_thread = threading.Thread(target=self.__start_pyro)
daemon_thread.daemon = True
daemon_thread.start()
# Run the main service loop (blocking).
self.run_service()
if config.use_http_based_rpc:
class AppService(FastApiAppService, ABC): # type: ignore #AppService defined twice
pass
else:
class AppService(PyroAppService, ABC):
pass
# ----- End Pyro AppService Block ---- #
# --------------------------------------------------
# HTTP Client utilities for dynamic service client abstraction
# --------------------------------------------------
AS = TypeVar("AS", bound=AppService)
def fastapi_close_service_client(client: Any) -> None:
def close_service_client(client: Any) -> None:
if hasattr(client, "close"):
client.close()
else:
@@ -446,7 +240,7 @@ def fastapi_close_service_client(client: Any) -> None:
@conn_retry("FastAPI client", "Creating service client", max_retry=api_comm_retry)
def fastapi_get_service_client(
def get_service_client(
service_type: Type[AS],
call_timeout: int | None = api_call_timeout,
) -> AS:
@@ -506,93 +300,3 @@ def fastapi_get_service_client(
client.health_check()
return cast(AS, client)
# ----- Begin Pyro Client Block ---- #
class PyroClient:
proxy: Pyro5.api.Proxy
def pyro_close_service_client(client: BaseAppService) -> None:
if isinstance(client, PyroClient):
client.proxy._pyroRelease()
else:
raise RuntimeError(f"Client {client.__class__} is not a Pyro client.")
def pyro_get_service_client(service_type: Type[AS]) -> AS:
service_name = service_type.service_name
class DynamicClient(PyroClient):
@conn_retry("Pyro", f"Connecting to [{service_name}]")
def __init__(self):
uri = f"PYRO:{service_type.service_name}@{service_type.get_host()}:{service_type.get_port()}"
logger.debug(f"Connecting to service [{service_name}]. URI = {uri}")
self.proxy = Pyro5.api.Proxy(uri)
# Attempt to bind to ensure the connection is established
self.proxy._pyroBind()
logger.debug(f"Successfully connected to service [{service_name}]")
def __getattr__(self, name: str) -> Callable[..., Any]:
res = getattr(self.proxy, name)
return res
return cast(AS, DynamicClient())
builtin_types = [*vars(builtins).values(), NoneType, Enum]
def _pydantic_models_from_type_annotation(annotation) -> Iterator[type[BaseModel]]:
# Peel Annotated parameters
if (origin := get_origin(annotation)) and origin is Annotated:
annotation = get_args(annotation)[0]
origin = get_origin(annotation)
args = get_args(annotation)
if origin in (
Union,
UnionType,
list,
List,
tuple,
Tuple,
set,
Set,
frozenset,
FrozenSet,
):
for arg in args:
yield from _pydantic_models_from_type_annotation(arg)
elif origin in (dict, Dict):
key_type, value_type = args
yield from _pydantic_models_from_type_annotation(key_type)
yield from _pydantic_models_from_type_annotation(value_type)
elif origin in (Awaitable, Coroutine):
# For coroutines and awaitables, check the return type
return_type = args[-1]
yield from _pydantic_models_from_type_annotation(return_type)
else:
annotype = annotation if origin is None else origin
# Exclude generic types and aliases
if (
annotype is not None
and not hasattr(typing, getattr(annotype, "__name__", ""))
and isinstance(annotype, type)
):
if issubclass(annotype, BaseModel):
yield annotype
elif annotype not in builtin_types and not issubclass(annotype, Enum):
raise TypeError(f"Unsupported type encountered: {annotype}")
if config.use_http_based_rpc:
close_service_client = fastapi_close_service_client
get_service_client = fastapi_get_service_client
else:
close_service_client = pyro_close_service_client
get_service_client = pyro_get_service_client
# ----- End Pyro Client Block ---- #

View File

@@ -65,10 +65,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
le=1000,
description="Maximum number of workers to use for node execution within a single graph.",
)
use_http_based_rpc: bool = Field(
default=True,
description="Whether to use HTTP-based RPC for communication between services.",
)
pyro_host: str = Field(
default="localhost",
description="The default hostname of the Pyro server.",

View File

@@ -4098,21 +4098,6 @@ all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"]
dev = ["twine (>=3.4.1)"]
nodejs = ["nodejs-wheel-binaries"]
[[package]]
name = "pyro5"
version = "5.15"
description = "Remote object communication library, fifth major version"
optional = false
python-versions = ">=3.7"
groups = ["main"]
files = [
{file = "Pyro5-5.15-py3-none-any.whl", hash = "sha256:4d85428ed75985e63f159d2486ad5680743ea76f766340fd30b65dd20f83d471"},
{file = "Pyro5-5.15.tar.gz", hash = "sha256:82c3dfc9860b49f897b28ff24fe6716c841672c600af8fe40d0e3a7fac9a3f5e"},
]
[package.dependencies]
serpent = ">=1.41"
[[package]]
name = "pytest"
version = "8.3.5"
@@ -4963,18 +4948,6 @@ statsig = ["statsig (>=0.55.3)"]
tornado = ["tornado (>=6)"]
unleash = ["UnleashClient (>=6.0.1)"]
[[package]]
name = "serpent"
version = "1.41"
description = "Serialization based on ast.literal_eval"
optional = false
python-versions = ">=3.2"
groups = ["main"]
files = [
{file = "serpent-1.41-py3-none-any.whl", hash = "sha256:5fd776b3420441985bc10679564c2c9b4a19f77bea59f018e473441d98ae5dd7"},
{file = "serpent-1.41.tar.gz", hash = "sha256:0407035fe3c6644387d48cff1467d5aa9feff814d07372b78677ed0ee3ed7095"},
]
[[package]]
name = "setuptools"
version = "75.8.0"
@@ -6337,4 +6310,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.13"
content-hash = "648b5a17a37ea000d19d74a65a4b9aaa8f5c6bf73b05f13fc202648cb2b287f0"
content-hash = "781f77ec77cfce78b34fb57063dcc81df8e9c5a4be9a644033a0c197e0063730"

View File

@@ -45,7 +45,6 @@ psutil = "^7.0.0"
psycopg2-binary = "^2.9.10"
pydantic = { extras = ["email"], version = "^2.11.1" }
pydantic-settings = "^2.8.1"
pyro5 = "^5.15"
pytest = "^8.3.5"
pytest-asyncio = "^0.26.0"
python-dotenv = "^1.1.0"

View File

@@ -9,6 +9,9 @@ class ServiceTest(AppService):
def __init__(self):
super().__init__()
def cleanup(self):
pass
@classmethod
def get_port(cls) -> int:
return TEST_SERVICE_PORT