feat(backend): Pyro to FastAPI migration for micro service (#9508)

Due to legacy reasons, we've been using Pyro for our inter-process
communication channel. While it fulfilled our initial needs, there were
a few limitations that have been encountered:
* Each connection reserves one thread; once the threads run out, the
service accepts no further connections.
* No asynchronous execution mode: we are locked into synchronous
execution, which wastes time on I/O-bound workloads. Moving away
from this will unlock async execution support for agent blocks.
* Low throughput: while the database is still the main bottleneck, we've
started seeing instances where requests are denied due to high
traffic on the Pyro service.

### Changes 🏗️

Replace the usage of Pyro with a FastAPI REST HTTP server and adapt the
code to work with it.

Introduced the new config:
`use_http_based_rpc`: Whether to use HTTP-based RPC for communication
between services.

If it is enabled, FastAPI is used; if it is disabled, the existing Pyro
implementation is used.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
- [x] Create from scratch and execute an agent with at least 3 blocks
with cost (AI blocks).
- [x] Import an agent from file upload, and confirm it executes
correctly

<details>
  <summary>Example test plan</summary>
  
  - [ ] Create from scratch and execute an agent with at least 3 blocks
- [ ] Import an agent from file upload, and confirm it executes
correctly
  - [ ] Upload agent to marketplace
- [ ] Import an agent from marketplace and confirm it executes correctly
  - [ ] Edit an agent from monitor, and confirm it executes correctly
</details>

#### For configuration changes:
- [x] `.env.example` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

<details>
  <summary>Examples of configuration changes</summary>

  - Changing ports
  - Adding new services that need to communicate with each other
  - Secrets or environment variable changes
  - New or infrastructure changes such as databases
</details>
This commit is contained in:
Zamil Majdy
2025-02-25 12:04:10 +07:00
committed by GitHub
parent a694cf1e9d
commit 1d59fc869d
5 changed files with 284 additions and 67 deletions

View File

@@ -1,6 +1,3 @@
from functools import wraps
from typing import Any, Callable, Concatenate, Coroutine, ParamSpec, TypeVar, cast
from backend.data.credit import get_user_credit_model
from backend.data.execution import (
ExecutionResult,
@@ -23,12 +20,15 @@ from backend.data.user import (
update_user_integrations,
update_user_metadata,
)
from backend.util.service import AppService, expose, register_pydantic_serializers
from backend.util.service import AppService, expose, exposed_run_and_wait
from backend.util.settings import Config
P = ParamSpec("P")
R = TypeVar("R")
config = Config()
_user_credit_model = get_user_credit_model()
async def _spend_credits(entry: NodeExecutionEntry) -> int:
    """
    Charge the user for a single node execution.

    Single-argument wrapper around the credit model's ``spend_credits`` so
    that remote callers only have to send the execution entry.

    :param entry: The node execution to charge for.
    :return: Whatever ``spend_credits`` returns for this entry.
    """
    # NOTE(review): the two zeros presumably stand for data size and run
    # time (no block is charged by either today) — confirm against
    # `spend_credits`.
    charged = await _user_credit_model.spend_credits(entry, 0, 0)
    return charged
class DatabaseManager(AppService):
@@ -46,22 +46,6 @@ class DatabaseManager(AppService):
def send_execution_update(self, execution_result: ExecutionResult):
self.event_queue.publish(execution_result)
@staticmethod
def exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
@expose
@wraps(f)
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
coroutine = f(*args, **kwargs)
res = self.run_and_wait(coroutine)
return res
# Register serializers for annotations on bare function
register_pydantic_serializers(f)
return wrapper
# Executions
create_graph_execution = exposed_run_and_wait(create_graph_execution)
get_execution_results = exposed_run_and_wait(get_execution_results)
@@ -78,11 +62,7 @@ class DatabaseManager(AppService):
get_graph = exposed_run_and_wait(get_graph)
# Credits
user_credit_model = get_user_credit_model()
spend_credits = cast(
Callable[[Any, NodeExecutionEntry, float, float], int],
exposed_run_and_wait(user_credit_model.spend_credits),
)
spend_credits = exposed_run_and_wait(_spend_credits)
# User + User Metadata + User Integrations
get_user_metadata = exposed_run_and_wait(get_user_metadata)

View File

@@ -202,9 +202,7 @@ def execute_node(
output_size = 0
try:
# Charge the user for the execution before running the block.
# TODO: We assume the block is executed within 0 seconds.
# This is fine because for now, there is no block that is charged by time.
cost = db_client.spend_credits(data, input_size + output_size, 0)
cost = db_client.spend_credits(data)
outputs: dict[str, Any] = {}
for output_name, output_data in node_block.execute(
@@ -263,7 +261,7 @@ def execute_node(
raise e
finally:
# Ensure credentials are released even if execution fails
if creds_lock:
if creds_lock and creds_lock.locked():
try:
creds_lock.release()
except Exception as e:

View File

@@ -54,11 +54,11 @@ class AppProcess(ABC):
"""
pass
def health_check(self):
def health_check(self) -> str:
"""
A method to check the health of the process.
"""
pass
return "OK"
def execute_run_command(self, silent):
signal.signal(signal.SIGTERM, self._self_terminate)

View File

@@ -1,5 +1,6 @@
import asyncio
import builtins
import inspect
import logging
import os
import threading
@@ -7,18 +8,21 @@ import time
import typing
from abc import ABC, abstractmethod
from enum import Enum
from functools import wraps
from types import NoneType, UnionType
from typing import (
Annotated,
Any,
Awaitable,
Callable,
Concatenate,
Coroutine,
Dict,
FrozenSet,
Iterator,
List,
Optional,
ParamSpec,
Set,
Tuple,
Type,
@@ -29,12 +33,16 @@ from typing import (
get_origin,
)
import httpx
import Pyro5.api
from pydantic import BaseModel
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, TypeAdapter, create_model
from Pyro5 import api as pyro
from Pyro5 import config as pyro_config
from backend.data import db, rabbitmq, redis
from backend.util.json import to_dict
from backend.util.process import AppProcess
from backend.util.retry import conn_retry
from backend.util.settings import Config, Secrets
@@ -47,9 +55,31 @@ config = Config()
pyro_host = config.pyro_host
pyro_config.MAX_RETRIES = config.pyro_client_comm_retry # type: ignore
pyro_config.COMMTIMEOUT = config.pyro_client_comm_timeout # type: ignore
api_host = config.pyro_host
def expose(func: C) -> C:
P = ParamSpec("P")
R = TypeVar("R")
def fastapi_expose(func: C) -> C:
    """
    Mark a callable as remotely callable over the HTTP-based RPC.

    If ``func`` carries a ``__func__`` attribute (e.g. a bound method or a
    ``staticmethod``/``classmethod`` object), the underlying function is
    unwrapped, flagged and returned instead of the wrapper.

    :param func: The callable to expose.
    :return: The (possibly unwrapped) callable with ``__exposed__ = True``.
    """
    target = getattr(func, "__func__", func)
    target.__exposed__ = True
    return target
def fastapi_exposed_run_and_wait(
    f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
    """
    Expose a coroutine function through the currently selected ``expose``.

    Unlike the Pyro variant, the coroutine function is not wrapped in a
    synchronous adapter; it is handed to ``expose`` as-is.

    :param f: The coroutine function to expose.
    :return: The result of ``expose(f)``, typed as a synchronous method.
    """
    # TODO:
    # This function lies about its return type to make the DynamicClient
    # call the function synchronously, fix this when DynamicClient can choose
    # to call a function synchronously or asynchronously.
    exposed = expose(f)
    return exposed  # type: ignore
# ----- Begin Pyro Expose Block ---- #
def pyro_expose(func: C) -> C:
"""
Decorator to mark a method or class to be exposed for remote calls.
@@ -113,7 +143,36 @@ def _make_custom_deserializer(model: Type[BaseModel]):
return custom_dict_to_class
class AppService(AppProcess, ABC):
def pyro_exposed_run_and_wait(
    f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
    """
    Turn a coroutine function into an exposed, synchronous service method.

    The returned method ignores ``self`` when calling ``f`` and uses it only
    to reach ``run_and_wait``, which blocks until the coroutine has a result.

    :param f: The coroutine function to adapt.
    :return: An exposed method ``(self, *args, **kwargs)`` returning f's result.
    """

    def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
        return self.run_and_wait(f(*args, **kwargs))

    # Same application order as stacking `@expose` above `@wraps(f)`.
    exposed_wrapper = expose(wraps(f)(wrapper))

    # Register serializers for annotations on bare function
    register_pydantic_serializers(f)

    return exposed_wrapper
# Pick the expose decorators once, at import time, based on the configured
# RPC backend: FastAPI (HTTP) when `use_http_based_rpc` is set, Pyro otherwise.
expose = fastapi_expose if config.use_http_based_rpc else pyro_expose
exposed_run_and_wait = (
    fastapi_exposed_run_and_wait
    if config.use_http_based_rpc
    else pyro_exposed_run_and_wait
)
# ----- End Pyro Expose Block ---- #
# --------------------------------------------------
# AppService for IPC service based on HTTP request through FastAPI
# --------------------------------------------------
class BaseAppService(AppProcess, ABC):
shared_event_loop: asyncio.AbstractEventLoop
use_db: bool = False
use_redis: bool = False
@@ -121,9 +180,6 @@ class AppService(AppProcess, ABC):
rabbitmq_service: Optional[rabbitmq.AsyncRabbitMQ] = None
use_supabase: bool = False
def __init__(self):
self.uri = None
@classmethod
@abstractmethod
def get_port(cls) -> int:
@@ -131,7 +187,7 @@ class AppService(AppProcess, ABC):
@classmethod
def get_host(cls) -> str:
return os.environ.get(f"{cls.service_name.upper()}_HOST", config.pyro_host)
return os.environ.get(f"{cls.service_name.upper()}_HOST", api_host)
@property
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
@@ -151,12 +207,8 @@ class AppService(AppProcess, ABC):
while True:
time.sleep(10)
def __run_async(self, coro: Coroutine[Any, Any, T]):
return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop)
def run_and_wait(self, coro: Coroutine[Any, Any, T]) -> T:
future = self.__run_async(coro)
return future.result()
return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop).result()
def run(self):
self.shared_event_loop = asyncio.get_event_loop()
@@ -166,12 +218,8 @@ class AppService(AppProcess, ABC):
redis.connect()
if self.rabbitmq_config:
logger.info(f"[{self.__class__.__name__}] ⏳ Configuring RabbitMQ...")
# if self.use_async:
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
self.shared_event_loop.run_until_complete(self.rabbitmq_service.connect())
# else:
# self.rabbitmq_service = rabbitmq.SyncRabbitMQ(self.rabbitmq_config)
# self.rabbitmq_service.connect()
if self.use_supabase:
from supabase import create_client
@@ -180,19 +228,6 @@ class AppService(AppProcess, ABC):
secrets.supabase_url, secrets.supabase_service_role_key
)
# Initialize the async loop.
async_thread = threading.Thread(target=self.__start_async_loop)
async_thread.daemon = True
async_thread.start()
# Initialize pyro service
daemon_thread = threading.Thread(target=self.__start_pyro)
daemon_thread.daemon = True
daemon_thread.start()
# Run the main service (if it's not implemented, just sleep).
self.run_service()
def cleanup(self):
if self.use_db:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting DB...")
@@ -203,6 +238,105 @@ class AppService(AppProcess, ABC):
if self.rabbitmq_config:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting RabbitMQ...")
class FastApiAppService(BaseAppService, ABC):
    """
    Service base class whose exposed methods are served as HTTP POST routes.

    On ``run()``, every class attribute flagged with ``__exposed__`` is
    registered on a FastAPI application under ``/<attribute name>``; the JSON
    request body carries the call's keyword arguments. The server is started
    in a daemon thread, after which ``run_service()`` blocks the caller.
    """

    fastapi_app: FastAPI

    def _create_fastapi_endpoint(self, func: Callable) -> Callable:
        """
        Generates a FastAPI endpoint for the given function, handling default and optional parameters.

        :param func: The original function (sync/async, bound or unbound)
        :return: A FastAPI endpoint function.
        """
        sig = inspect.signature(func)
        fields = {}

        is_bound_method = False
        for name, param in sig.parameters.items():
            if name in ("self", "cls"):
                is_bound_method = True
                continue

            # Use the provided annotation or fallback to str if not specified.
            # The `empty` sentinel is compared by identity so that annotations
            # or defaults with a custom `__eq__` cannot confuse the check.
            annotation = (
                param.annotation
                if param.annotation is not inspect.Parameter.empty
                else str
            )

            # If a default value is provided, use it; otherwise, mark the field as required with '...'
            default = (
                param.default if param.default is not inspect.Parameter.empty else ...
            )

            fields[name] = (annotation, default)

        # Dynamically create a Pydantic model for the request body
        RequestBodyModel = create_model("RequestBodyModel", **fields)
        f = func.__get__(self) if is_bound_method else func

        def call_kwargs(body: BaseModel) -> dict:
            # Pass the validated attribute values (not a dumped dict), so
            # nested models reach `f` as model instances.
            return {name: getattr(body, name) for name in fields}

        if asyncio.iscoroutinefunction(f):

            async def async_endpoint(body: RequestBodyModel):  # type: ignore #RequestBodyModel being variable
                try:
                    return await f(**call_kwargs(body))
                except Exception as e:
                    logger.exception(f"Error in {func.__name__}: {e}")
                    # `detail` ends up in a JSON response, so it must be
                    # JSON-serializable; a raw exception object is not.
                    raise HTTPException(status_code=500, detail=str(e)) from e

            return async_endpoint
        else:

            def sync_endpoint(body: RequestBodyModel):  # type: ignore #RequestBodyModel being variable
                try:
                    return f(**call_kwargs(body))
                except Exception as e:
                    logger.exception(f"Error in {func.__name__}: {e}")
                    # See async_endpoint: `detail` must be JSON-serializable.
                    raise HTTPException(status_code=500, detail=str(e)) from e

            return sync_endpoint

    @conn_retry("FastAPI server", "Starting FastAPI server")
    def __start_fastapi(self):
        """Serve `fastapi_app` with uvicorn on the shared event loop (blocking)."""
        port = self.get_port()
        host = self.get_host()
        logger.info(
            f"[{self.service_name}] Starting RPC server at http://{host}:{port}"
        )
        server = uvicorn.Server(uvicorn.Config(self.fastapi_app, host=host, port=port))
        self.shared_event_loop.run_until_complete(server.serve())

    def run(self):
        """
        Set up the service, publish its exposed methods and start serving.

        Runs the base-class setup, builds the FastAPI app with one POST route
        per exposed attribute plus ``/health_check``, starts the HTTP server
        in a daemon thread and finally blocks in ``run_service()``.
        """
        super().run()
        self.fastapi_app = FastAPI()

        # Collect class attributes across the whole MRO (base classes first so
        # that subclass definitions win). Looking only at `vars(type(self))`
        # would skip exposed methods inherited from a base class, although the
        # service client resolves those names through `getattr`.
        class_attrs: dict[str, Any] = {}
        for klass in reversed(type(self).__mro__):
            class_attrs.update(vars(klass))

        # Register the exposed API routes.
        for attr_name, attr in class_attrs.items():
            if getattr(attr, "__exposed__", False):
                route_path = f"/{attr_name}"
                self.fastapi_app.add_api_route(
                    route_path,
                    self._create_fastapi_endpoint(attr),
                    methods=["POST"],
                )
        self.fastapi_app.add_api_route(
            "/health_check", self.health_check, methods=["POST"]
        )

        # Start the FastAPI server in a separate thread.
        api_thread = threading.Thread(target=self.__start_fastapi, daemon=True)
        api_thread.start()

        # Run the main service loop (blocking).
        self.run_service()
# ----- Begin Pyro AppService Block ---- #
class PyroAppService(BaseAppService, ABC):
@conn_retry("Pyro", "Starting Pyro Service")
def __start_pyro(self):
maximum_connection_thread_count = max(
@@ -216,28 +350,119 @@ class AppService(AppProcess, ABC):
logger.info(f"[{self.service_name}] Connected to Pyro; URI = {self.uri}")
daemon.requestLoop()
def __start_async_loop(self):
self.shared_event_loop.run_forever()
def run(self):
super().run()
# Initialize the async loop.
async_thread = threading.Thread(target=self.shared_event_loop.run_forever)
async_thread.daemon = True
async_thread.start()
# Initialize pyro service
daemon_thread = threading.Thread(target=self.__start_pyro)
daemon_thread.daemon = True
daemon_thread.start()
# Run the main service loop (blocking).
self.run_service()
# --------- UTILITIES --------- #
if config.use_http_based_rpc:

    class AppService(FastApiAppService, ABC):  # type: ignore #AppService defined twice
        """Service base class backed by the FastAPI (HTTP) RPC server."""

else:

    class AppService(PyroAppService, ABC):
        """Service base class backed by the Pyro RPC daemon."""
# ----- End Pyro AppService Block ---- #
# --------------------------------------------------
# HTTP Client utilities for dynamic service client abstraction
# --------------------------------------------------
AS = TypeVar("AS", bound=AppService)
def fastapi_close_service_client(client: Any) -> None:
    """
    Close a service client created for the HTTP-based RPC.

    Clients without a ``close`` attribute are left untouched and a warning
    is logged instead.

    :param client: The client object to close.
    """
    if not hasattr(client, "close"):
        logger.warning(f"Client {client} is not closable")
        return
    client.close()
@conn_retry("FastAPI client", "Creating service client")
def fastapi_get_service_client(service_type: Type[AS]) -> AS:
    """
    Build an HTTP client that mimics the interface of ``service_type``.

    Attribute access on the returned object yields a function that POSTs its
    arguments as JSON to ``/<method name>`` on the service and validates the
    response against the method's return annotation, if it has one.

    :param service_type: The service class whose methods should be callable.
    :return: A dynamic client, typed as ``service_type`` for the caller.
    :raises Exception: Whatever the initial ``health_check`` call raises; the
        underlying HTTP client is closed before the error propagates.
    """
    service_name = service_type.service_name

    class DynamicClient:
        def __init__(self):
            host = os.environ.get(f"{service_name.upper()}_HOST", api_host)
            port = service_type.get_port()
            self.base_url = f"http://{host}:{port}".rstrip("/")
            self.client = httpx.Client()

        def _call_method(self, method_name: str, **kwargs) -> Any:
            """POST `kwargs` to the method's route and return the decoded JSON."""
            try:
                url = f"{self.base_url}/{method_name}"
                response = self.client.post(url, json=to_dict(kwargs))
                response.raise_for_status()
                return response.json()
            except httpx.HTTPStatusError as e:
                logger.error(f"HTTP error in {method_name}: {e.response.text}")
                raise Exception(e.response.text) from e

        def close(self):
            self.client.close()

        def __getattr__(self, name: str) -> Callable[..., Any]:
            # Try to get the original function from the service type.
            orig_func = getattr(service_type, name, None)
            # Non-callable attributes cannot be invoked remotely; report them
            # as missing (AttributeError) rather than failing with a TypeError
            # inside `inspect.signature`.
            if orig_func is None or not callable(orig_func):
                raise AttributeError(f"Method {name} not found in {service_type}")

            sig = inspect.signature(orig_func)
            ret_ann = sig.return_annotation
            if ret_ann is not inspect.Signature.empty:
                expected_return = TypeAdapter(ret_ann)
            else:
                expected_return = None

            # Names used to map positional arguments onto keyword arguments;
            # the implicit receiver is never sent over the wire.
            arg_names = list(sig.parameters.keys())
            if arg_names and arg_names[0] in ("self", "cls"):
                arg_names = arg_names[1:]

            def method(*args, **kwargs) -> Any:
                if args:
                    kwargs.update(dict(zip(arg_names, args)))
                result = self._call_method(name, **kwargs)
                if expected_return is not None:
                    return expected_return.validate_python(result)
                return result

            # Cache on the instance so later lookups bypass __getattr__ and
            # do not rebuild the signature and TypeAdapter on every call.
            self.__dict__[name] = method
            return method

    client = DynamicClient()
    try:
        client.health_check()
    except BaseException:
        # This function is retried by `conn_retry`; without closing here,
        # every failed attempt would leak an open httpx.Client.
        client.close()
        raise

    return cast(AS, client)
# ----- Begin Pyro Client Block ---- #
class PyroClient:
    """Base type identifying Pyro-backed service clients."""

    # Pyro proxy to the remote service; released by pyro_close_service_client.
    proxy: Pyro5.api.Proxy
def close_service_client(client: AppService) -> None:
def pyro_close_service_client(client: BaseAppService) -> None:
    """
    Release the Pyro proxy held by a service client.

    :param client: A client object; must be a ``PyroClient`` instance.
    :raises RuntimeError: If ``client`` is not a ``PyroClient``.
    """
    if not isinstance(client, PyroClient):
        raise RuntimeError(f"Client {client.__class__} is not a Pyro client.")
    client.proxy._pyroRelease()
def get_service_client(service_type: Type[AS]) -> AS:
def pyro_get_service_client(service_type: Type[AS]) -> AS:
service_name = service_type.service_name
class DynamicClient(PyroClient):
@@ -304,3 +529,13 @@ def _pydantic_models_from_type_annotation(annotation) -> Iterator[type[BaseModel
yield annotype
elif annotype not in builtin_types and not issubclass(annotype, Enum):
raise TypeError(f"Unsupported type encountered: {annotype}")
# Pick the client helpers once, at import time, based on the configured
# RPC backend: FastAPI (HTTP) when `use_http_based_rpc` is set, Pyro otherwise.
close_service_client = (
    fastapi_close_service_client
    if config.use_http_based_rpc
    else pyro_close_service_client
)
get_service_client = (
    fastapi_get_service_client
    if config.use_http_based_rpc
    else pyro_get_service_client
)
# ----- End Pyro Client Block ---- #

View File

@@ -65,6 +65,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
le=1000,
description="Maximum number of workers to use for node execution within a single graph.",
)
use_http_based_rpc: bool = Field(
default=True,
description="Whether to use HTTP-based RPC for communication between services.",
)
pyro_host: str = Field(
default="localhost",
description="The default hostname of the Pyro server.",