mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-10 07:38:04 -05:00
feat(backend): Pyro to FastAPI migration for micro service (#9508)
For legacy reasons, we've been using Pyro for our inter-process communication channel. While it fulfilled our initial needs, we have run into a few limitations: * Each connection reserves one thread; when the threads run out, the service stops accepting new connections. * There is no asynchronous execution mode: we are locked into synchronous execution, which wastes time on I/O-bound workloads. Moving away from Pyro will unlock async execution support for agent blocks. * Low throughput: while the database is still the main bottleneck, we've started seeing instances where requests are denied because of the high traffic on the Pyro service. ### Changes 🏗️ Replace the usage of Pyro with a FastAPI REST HTTP server and make the code work. Introduced a new config, `use_http_based_rpc`: whether to use HTTP-based RPC for communication between services. If it's enabled, FastAPI is used; if it's disabled, the existing Pyro implementation is used. ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [x] Create from scratch and execute an agent with at least 3 blocks with cost (AI blocks). 
- [x] Import an agent from file upload, and confirm it executes correctly <details> <summary>Example test plan</summary> - [ ] Create from scratch and execute an agent with at least 3 blocks - [ ] Import an agent from file upload, and confirm it executes correctly - [ ] Upload agent to marketplace - [ ] Import an agent from marketplace and confirm it executes correctly - [ ] Edit an agent from monitor, and confirm it executes correctly </details> #### For configuration changes: - [x] `.env.example` is updated or already compatible with my changes - [x] `docker-compose.yml` is updated or already compatible with my changes - [x] I have included a list of my configuration changes in the PR description (under **Changes**) <details> <summary>Examples of configuration changes</summary> - Changing ports - Adding new services that need to communicate with each other - Secrets or environment variable changes - New or infrastructure changes such as databases </details>
This commit is contained in:
@@ -1,6 +1,3 @@
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Concatenate, Coroutine, ParamSpec, TypeVar, cast
|
||||
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
ExecutionResult,
|
||||
@@ -23,12 +20,15 @@ from backend.data.user import (
|
||||
update_user_integrations,
|
||||
update_user_metadata,
|
||||
)
|
||||
from backend.util.service import AppService, expose, register_pydantic_serializers
|
||||
from backend.util.service import AppService, expose, exposed_run_and_wait
|
||||
from backend.util.settings import Config
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
config = Config()
|
||||
_user_credit_model = get_user_credit_model()
|
||||
|
||||
|
||||
async def _spend_credits(entry: NodeExecutionEntry) -> int:
    """Charge credits for a single node execution.

    Delegates to the module-level credit model, passing ``0`` for both trailing
    arguments.
    NOTE(review): the two zeros look like "data size" and "run time" -- confirm
    against ``spend_credits`` on the credit model.

    :param entry: The node execution to charge for.
    :return: Whatever the credit model's ``spend_credits`` returns.
    """
    charged = await _user_credit_model.spend_credits(entry, 0, 0)
    return charged
|
||||
|
||||
|
||||
class DatabaseManager(AppService):
|
||||
@@ -46,22 +46,6 @@ class DatabaseManager(AppService):
|
||||
def send_execution_update(self, execution_result: ExecutionResult):
|
||||
self.event_queue.publish(execution_result)
|
||||
|
||||
@staticmethod
|
||||
def exposed_run_and_wait(
|
||||
f: Callable[P, Coroutine[None, None, R]]
|
||||
) -> Callable[Concatenate[object, P], R]:
|
||||
@expose
|
||||
@wraps(f)
|
||||
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
coroutine = f(*args, **kwargs)
|
||||
res = self.run_and_wait(coroutine)
|
||||
return res
|
||||
|
||||
# Register serializers for annotations on bare function
|
||||
register_pydantic_serializers(f)
|
||||
|
||||
return wrapper
|
||||
|
||||
# Executions
|
||||
create_graph_execution = exposed_run_and_wait(create_graph_execution)
|
||||
get_execution_results = exposed_run_and_wait(get_execution_results)
|
||||
@@ -78,11 +62,7 @@ class DatabaseManager(AppService):
|
||||
get_graph = exposed_run_and_wait(get_graph)
|
||||
|
||||
# Credits
|
||||
user_credit_model = get_user_credit_model()
|
||||
spend_credits = cast(
|
||||
Callable[[Any, NodeExecutionEntry, float, float], int],
|
||||
exposed_run_and_wait(user_credit_model.spend_credits),
|
||||
)
|
||||
spend_credits = exposed_run_and_wait(_spend_credits)
|
||||
|
||||
# User + User Metadata + User Integrations
|
||||
get_user_metadata = exposed_run_and_wait(get_user_metadata)
|
||||
|
||||
@@ -202,9 +202,7 @@ def execute_node(
|
||||
output_size = 0
|
||||
try:
|
||||
# Charge the user for the execution before running the block.
|
||||
# TODO: We assume the block is executed within 0 seconds.
|
||||
# This is fine because for now, there is no block that is charged by time.
|
||||
cost = db_client.spend_credits(data, input_size + output_size, 0)
|
||||
cost = db_client.spend_credits(data)
|
||||
|
||||
outputs: dict[str, Any] = {}
|
||||
for output_name, output_data in node_block.execute(
|
||||
@@ -263,7 +261,7 @@ def execute_node(
|
||||
raise e
|
||||
finally:
|
||||
# Ensure credentials are released even if execution fails
|
||||
if creds_lock:
|
||||
if creds_lock and creds_lock.locked():
|
||||
try:
|
||||
creds_lock.release()
|
||||
except Exception as e:
|
||||
|
||||
@@ -54,11 +54,11 @@ class AppProcess(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def health_check(self):
|
||||
def health_check(self) -> str:
|
||||
"""
|
||||
A method to check the health of the process.
|
||||
"""
|
||||
pass
|
||||
return "OK"
|
||||
|
||||
def execute_run_command(self, silent):
|
||||
signal.signal(signal.SIGTERM, self._self_terminate)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import builtins
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
@@ -7,18 +8,21 @@ import time
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from types import NoneType, UnionType
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Concatenate,
|
||||
Coroutine,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
ParamSpec,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
@@ -29,12 +33,16 @@ from typing import (
|
||||
get_origin,
|
||||
)
|
||||
|
||||
import httpx
|
||||
import Pyro5.api
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, TypeAdapter, create_model
|
||||
from Pyro5 import api as pyro
|
||||
from Pyro5 import config as pyro_config
|
||||
|
||||
from backend.data import db, rabbitmq, redis
|
||||
from backend.util.json import to_dict
|
||||
from backend.util.process import AppProcess
|
||||
from backend.util.retry import conn_retry
|
||||
from backend.util.settings import Config, Secrets
|
||||
@@ -47,9 +55,31 @@ config = Config()
|
||||
pyro_host = config.pyro_host
|
||||
pyro_config.MAX_RETRIES = config.pyro_client_comm_retry # type: ignore
|
||||
pyro_config.COMMTIMEOUT = config.pyro_client_comm_timeout # type: ignore
|
||||
api_host = config.pyro_host
|
||||
|
||||
|
||||
def expose(func: C) -> C:
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def fastapi_expose(func: C) -> C:
    """Mark ``func`` as exposed so the FastAPI service registers a route for it.

    If ``func`` carries a ``__func__`` attribute (e.g. a bound method), the
    underlying function is unwrapped first; the ``__exposed__`` flag is set on
    that underlying object.

    :param func: Function or method to expose.
    :return: The (unwrapped) function, flagged with ``__exposed__ = True``.
    """
    target = getattr(func, "__func__", func)
    target.__exposed__ = True
    return target
|
||||
|
||||
|
||||
def fastapi_exposed_run_and_wait(
    f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
    """Expose the coroutine function ``f`` as an HTTP RPC method.

    No synchronous wrapper is built here: ``f`` itself is flagged as exposed
    and returned unchanged.

    :param f: Coroutine function to expose.
    :return: ``f`` itself, marked as exposed (the declared return type is
        deliberately inaccurate -- see the TODO below).
    """
    # TODO:
    #   This function lies about its return type to make the DynamicClient
    #   call the function synchronously, fix this when DynamicClient can choose
    #   to call a function synchronously or asynchronously.
    #
    # Use the FastAPI decorator directly rather than the config-selected
    # `expose` alias, so this helper always applies its own backend's marker.
    return fastapi_expose(f)  # type: ignore
|
||||
|
||||
|
||||
# ----- Begin Pyro Expose Block ---- #
|
||||
def pyro_expose(func: C) -> C:
|
||||
"""
|
||||
Decorator to mark a method or class to be exposed for remote calls.
|
||||
|
||||
@@ -113,7 +143,36 @@ def _make_custom_deserializer(model: Type[BaseModel]):
|
||||
return custom_dict_to_class
|
||||
|
||||
|
||||
class AppService(AppProcess, ABC):
|
||||
def pyro_exposed_run_and_wait(
    f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
    """Wrap the coroutine function ``f`` as a synchronous, Pyro-exposed method.

    The returned wrapper accepts the service instance as ``self`` (which is
    not forwarded to ``f``), calls ``f`` with the remaining arguments, and
    blocks on ``self.run_and_wait`` until the coroutine has a result.

    :param f: Coroutine function to expose.
    :return: A synchronous method-style wrapper around ``f``.
    """

    # Use the Pyro decorator directly rather than the config-selected `expose`
    # alias, so this helper always applies its own backend's decorator.
    @pyro_expose
    @wraps(f)
    def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
        return self.run_and_wait(f(*args, **kwargs))

    # Register serializers for annotations on bare function
    register_pydantic_serializers(f)

    return wrapper
|
||||
|
||||
|
||||
# Pick the RPC backend once, at import time, based on `use_http_based_rpc`.
expose = fastapi_expose if config.use_http_based_rpc else pyro_expose
exposed_run_and_wait = (
    fastapi_exposed_run_and_wait
    if config.use_http_based_rpc
    else pyro_exposed_run_and_wait
)
|
||||
|
||||
# ----- End Pyro Expose Block ---- #
|
||||
|
||||
|
||||
# --------------------------------------------------
|
||||
# AppService for IPC service based on HTTP request through FastAPI
|
||||
# --------------------------------------------------
|
||||
class BaseAppService(AppProcess, ABC):
|
||||
shared_event_loop: asyncio.AbstractEventLoop
|
||||
use_db: bool = False
|
||||
use_redis: bool = False
|
||||
@@ -121,9 +180,6 @@ class AppService(AppProcess, ABC):
|
||||
rabbitmq_service: Optional[rabbitmq.AsyncRabbitMQ] = None
|
||||
use_supabase: bool = False
|
||||
|
||||
def __init__(self):
|
||||
self.uri = None
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_port(cls) -> int:
|
||||
@@ -131,7 +187,7 @@ class AppService(AppProcess, ABC):
|
||||
|
||||
@classmethod
|
||||
def get_host(cls) -> str:
|
||||
return os.environ.get(f"{cls.service_name.upper()}_HOST", config.pyro_host)
|
||||
return os.environ.get(f"{cls.service_name.upper()}_HOST", api_host)
|
||||
|
||||
@property
|
||||
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
|
||||
@@ -151,12 +207,8 @@ class AppService(AppProcess, ABC):
|
||||
while True:
|
||||
time.sleep(10)
|
||||
|
||||
def __run_async(self, coro: Coroutine[Any, Any, T]):
|
||||
return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop)
|
||||
|
||||
def run_and_wait(self, coro: Coroutine[Any, Any, T]) -> T:
|
||||
future = self.__run_async(coro)
|
||||
return future.result()
|
||||
return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop).result()
|
||||
|
||||
def run(self):
|
||||
self.shared_event_loop = asyncio.get_event_loop()
|
||||
@@ -166,12 +218,8 @@ class AppService(AppProcess, ABC):
|
||||
redis.connect()
|
||||
if self.rabbitmq_config:
|
||||
logger.info(f"[{self.__class__.__name__}] ⏳ Configuring RabbitMQ...")
|
||||
# if self.use_async:
|
||||
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
|
||||
self.shared_event_loop.run_until_complete(self.rabbitmq_service.connect())
|
||||
# else:
|
||||
# self.rabbitmq_service = rabbitmq.SyncRabbitMQ(self.rabbitmq_config)
|
||||
# self.rabbitmq_service.connect()
|
||||
if self.use_supabase:
|
||||
from supabase import create_client
|
||||
|
||||
@@ -180,19 +228,6 @@ class AppService(AppProcess, ABC):
|
||||
secrets.supabase_url, secrets.supabase_service_role_key
|
||||
)
|
||||
|
||||
# Initialize the async loop.
|
||||
async_thread = threading.Thread(target=self.__start_async_loop)
|
||||
async_thread.daemon = True
|
||||
async_thread.start()
|
||||
|
||||
# Initialize pyro service
|
||||
daemon_thread = threading.Thread(target=self.__start_pyro)
|
||||
daemon_thread.daemon = True
|
||||
daemon_thread.start()
|
||||
|
||||
# Run the main service (if it's not implemented, just sleep).
|
||||
self.run_service()
|
||||
|
||||
def cleanup(self):
|
||||
if self.use_db:
|
||||
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting DB...")
|
||||
@@ -203,6 +238,105 @@ class AppService(AppProcess, ABC):
|
||||
if self.rabbitmq_config:
|
||||
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting RabbitMQ...")
|
||||
|
||||
|
||||
class FastApiAppService(BaseAppService, ABC):
    """Service base class that publishes its exposed methods over HTTP via FastAPI.

    Every attribute defined on the concrete class and flagged with
    ``__exposed__`` becomes a ``POST /<attribute name>`` route whose JSON body
    carries the call's keyword arguments.
    """

    fastapi_app: FastAPI

    def _create_fastapi_endpoint(self, func: Callable) -> Callable:
        """
        Generates a FastAPI endpoint for the given function, handling default and optional parameters.

        :param func: The original function (sync/async, bound or unbound)
        :return: A FastAPI endpoint function.
        """
        sig = inspect.signature(func)
        fields = {}

        is_bound_method = False
        for name, param in sig.parameters.items():
            if name in ("self", "cls"):
                is_bound_method = True
                continue

            # Use the provided annotation or fallback to str if not specified.
            # `empty` is a sentinel, so compare by identity rather than equality.
            annotation = (
                param.annotation
                if param.annotation is not inspect.Parameter.empty
                else str
            )

            # If a default value is provided, use it; otherwise, mark the field as required with '...'
            default = (
                param.default if param.default is not inspect.Parameter.empty else ...
            )

            fields[name] = (annotation, default)

        # Dynamically create a Pydantic model for the request body
        RequestBodyModel = create_model("RequestBodyModel", **fields)
        f = func.__get__(self) if is_bound_method else func

        def call_kwargs(body) -> dict:
            # Unpack the validated body back into keyword arguments for `f`.
            return {name: getattr(body, name) for name in fields}

        if asyncio.iscoroutinefunction(f):

            async def async_endpoint(body: RequestBodyModel):  # type: ignore #RequestBodyModel being variable
                try:
                    return await f(**call_kwargs(body))
                except Exception as e:
                    logger.exception(f"Error in {func.__name__}: {e}")
                    # `detail` is rendered as JSON; a raw exception object is
                    # not serializable, so send its message instead.
                    raise HTTPException(status_code=500, detail=str(e)) from e

            return async_endpoint
        else:

            def sync_endpoint(body: RequestBodyModel):  # type: ignore #RequestBodyModel being variable
                try:
                    return f(**call_kwargs(body))
                except Exception as e:
                    logger.exception(f"Error in {func.__name__}: {e}")
                    # See async_endpoint: the detail must be JSON-serializable.
                    raise HTTPException(status_code=500, detail=str(e)) from e

            return sync_endpoint

    @conn_retry("FastAPI server", "Starting FastAPI server")
    def __start_fastapi(self):
        """Serve ``self.fastapi_app`` with uvicorn on the shared event loop (blocking)."""
        port = self.get_port()
        host = self.get_host()
        logger.info(
            f"[{self.service_name}] Starting RPC server at http://{host}:{port}"
        )
        server = uvicorn.Server(uvicorn.Config(self.fastapi_app, host=host, port=port))
        self.shared_event_loop.run_until_complete(server.serve())

    def run(self):
        """Build the FastAPI app, start serving it in a daemon thread, then run the service loop."""
        super().run()
        self.fastapi_app = FastAPI()

        # Register the exposed API routes.
        # NOTE(review): `vars(type(self))` only sees attributes defined directly
        # on the concrete class; exposed methods inherited from a parent service
        # would not get a route -- confirm this is intended.
        for attr_name, attr in vars(type(self)).items():
            if getattr(attr, "__exposed__", False):
                route_path = f"/{attr_name}"
                self.fastapi_app.add_api_route(
                    route_path,
                    self._create_fastapi_endpoint(attr),
                    methods=["POST"],
                )
        self.fastapi_app.add_api_route(
            "/health_check", self.health_check, methods=["POST"]
        )

        # Start the FastAPI server in a separate thread.
        api_thread = threading.Thread(target=self.__start_fastapi, daemon=True)
        api_thread.start()

        # Run the main service loop (blocking).
        self.run_service()
|
||||
|
||||
|
||||
# ----- Begin Pyro AppService Block ---- #
|
||||
|
||||
|
||||
class PyroAppService(BaseAppService, ABC):
|
||||
|
||||
@conn_retry("Pyro", "Starting Pyro Service")
|
||||
def __start_pyro(self):
|
||||
maximum_connection_thread_count = max(
|
||||
@@ -216,28 +350,119 @@ class AppService(AppProcess, ABC):
|
||||
logger.info(f"[{self.service_name}] Connected to Pyro; URI = {self.uri}")
|
||||
daemon.requestLoop()
|
||||
|
||||
def __start_async_loop(self):
|
||||
self.shared_event_loop.run_forever()
|
||||
def run(self):
|
||||
super().run()
|
||||
|
||||
# Initialize the async loop.
|
||||
async_thread = threading.Thread(target=self.shared_event_loop.run_forever)
|
||||
async_thread.daemon = True
|
||||
async_thread.start()
|
||||
|
||||
# Initialize pyro service
|
||||
daemon_thread = threading.Thread(target=self.__start_pyro)
|
||||
daemon_thread.daemon = True
|
||||
daemon_thread.start()
|
||||
|
||||
# Run the main service loop (blocking).
|
||||
self.run_service()
|
||||
|
||||
|
||||
# --------- UTILITIES --------- #
|
||||
if config.use_http_based_rpc:

    class AppService(FastApiAppService, ABC):  # type: ignore #AppService defined twice
        """Service base class bound to the FastAPI (HTTP) RPC backend."""

else:

    class AppService(PyroAppService, ABC):
        """Service base class bound to the Pyro RPC backend."""
|
||||
|
||||
|
||||
# ----- End Pyro AppService Block ---- #
|
||||
|
||||
|
||||
# --------------------------------------------------
|
||||
# HTTP Client utilities for dynamic service client abstraction
|
||||
# --------------------------------------------------
|
||||
AS = TypeVar("AS", bound=AppService)
|
||||
|
||||
|
||||
def fastapi_close_service_client(client: Any) -> None:
    """Close ``client`` if it offers a ``close`` attribute; otherwise log a warning.

    :param client: A service client, typically one built by the FastAPI client factory.
    """
    if not hasattr(client, "close"):
        logger.warning(f"Client {client} is not closable")
        return
    client.close()
|
||||
|
||||
|
||||
@conn_retry("FastAPI client", "Creating service client")
def fastapi_get_service_client(service_type: Type[AS]) -> AS:
    """Build an HTTP client that mimics the interface of ``service_type``.

    Attribute access on the returned object yields a callable that POSTs the
    call's arguments as JSON to ``/<method name>`` on the service and, when
    the original method declares a return annotation, validates the JSON
    response against it. A ``health_check`` call is made before returning.

    :param service_type: The service class whose methods should be callable remotely.
    :return: A dynamic client, typed as ``service_type`` for the caller's benefit.
    :raises Exception: From a call on the client, when the service answers
        with an HTTP error status (the message is the response body).
    """
    service_name = service_type.service_name

    class DynamicClient:
        """Proxy translating method calls into POST requests against the service."""

        def __init__(self):
            host = os.environ.get(f"{service_name.upper()}_HOST", api_host)
            port = service_type.get_port()
            self.base_url = f"http://{host}:{port}".rstrip("/")
            self.client = httpx.Client()

        def _call_method(self, method_name: str, **kwargs) -> Any:
            """POST ``kwargs`` to the route named ``method_name`` and return the decoded JSON."""
            try:
                url = f"{self.base_url}/{method_name}"
                response = self.client.post(url, json=to_dict(kwargs))
                response.raise_for_status()
                return response.json()
            except httpx.HTTPStatusError as e:
                logger.error(f"HTTP error in {method_name}: {e.response.text}")
                raise Exception(e.response.text) from e

        def close(self):
            """Close the underlying HTTP client."""
            self.client.close()

        def __getattr__(self, name: str) -> Callable[..., Any]:
            # Only reached for names not found on the instance/class, i.e. the
            # remote methods. Try to get the original function from the service type.
            orig_func = getattr(service_type, name, None)
            if orig_func is None:
                raise AttributeError(f"Method {name} not found in {service_type}")

            sig = inspect.signature(orig_func)
            ret_ann = sig.return_annotation
            # `Signature.empty` is a sentinel: compare by identity.
            expected_return = (
                TypeAdapter(ret_ann) if ret_ann is not inspect.Signature.empty else None
            )

            def method(*args, **kwargs) -> Any:
                if args:
                    # Map positional arguments onto parameter names, skipping
                    # the receiver; positional values win over same-named kwargs.
                    arg_names = list(sig.parameters)
                    # Guard against methods with no parameters at all.
                    if arg_names and arg_names[0] in ("self", "cls"):
                        arg_names = arg_names[1:]
                    kwargs.update(zip(arg_names, args))
                result = self._call_method(name, **kwargs)
                if expected_return is not None:
                    return expected_return.validate_python(result)
                return result

            return method

    client = cast(AS, DynamicClient())
    # Fail fast (and let the retry decorator kick in) if the service is not up yet.
    client.health_check()

    return client
|
||||
|
||||
|
||||
# ----- Begin Pyro Client Block ---- #
|
||||
class PyroClient:
|
||||
proxy: Pyro5.api.Proxy
|
||||
|
||||
|
||||
def close_service_client(client: AppService) -> None:
|
||||
def pyro_close_service_client(client: BaseAppService) -> None:
    """Release the Pyro proxy held by ``client``.

    :param client: A client expected to be a ``PyroClient`` instance.
    :raises RuntimeError: If ``client`` is not a ``PyroClient``.
    """
    if not isinstance(client, PyroClient):
        raise RuntimeError(f"Client {client.__class__} is not a Pyro client.")
    client.proxy._pyroRelease()
|
||||
|
||||
|
||||
def get_service_client(service_type: Type[AS]) -> AS:
|
||||
def pyro_get_service_client(service_type: Type[AS]) -> AS:
|
||||
service_name = service_type.service_name
|
||||
|
||||
class DynamicClient(PyroClient):
|
||||
@@ -304,3 +529,13 @@ def _pydantic_models_from_type_annotation(annotation) -> Iterator[type[BaseModel
|
||||
yield annotype
|
||||
elif annotype not in builtin_types and not issubclass(annotype, Enum):
|
||||
raise TypeError(f"Unsupported type encountered: {annotype}")
|
||||
|
||||
|
||||
# Pick the client helpers matching the RPC backend chosen by `use_http_based_rpc`.
close_service_client = (
    fastapi_close_service_client
    if config.use_http_based_rpc
    else pyro_close_service_client
)
get_service_client = (
    fastapi_get_service_client
    if config.use_http_based_rpc
    else pyro_get_service_client
)
|
||||
|
||||
# ----- End Pyro Client Block ---- #
|
||||
|
||||
@@ -65,6 +65,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
le=1000,
|
||||
description="Maximum number of workers to use for node execution within a single graph.",
|
||||
)
|
||||
use_http_based_rpc: bool = Field(
|
||||
default=True,
|
||||
description="Whether to use HTTP-based RPC for communication between services.",
|
||||
)
|
||||
pyro_host: str = Field(
|
||||
default="localhost",
|
||||
description="The default hostname of the Pyro server.",
|
||||
|
||||
Reference in New Issue
Block a user