feat(backend): Support flexible RPC client (#9842)

Using sync code in the async route often introduces a blocking
event-loop code that impacts stability.

The current RPC system only provides a synchronous client to call the
service endpoints.
The scope of this PR is to provide an entirely decoupled signature
between client and server, allowing the client can mix & match async &
sync options on the client code while not changing the async/sync nature
of the server.

### Changes 🏗️

* Add support for flexible async/sync RPC client.
* Migrate scheduler client to all-async client.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Scheduler route test.
  - [x] Modified service_test.py
  - [x] Run normal agent executions
This commit is contained in:
Zamil Majdy
2025-05-01 11:38:06 +07:00
committed by GitHub
parent 602f887623
commit 86d5cfe60b
16 changed files with 353 additions and 186 deletions

View File

@@ -26,10 +26,10 @@ logger = logging.getLogger(__name__)
@thread_cached
def get_database_manager_client():
from backend.executor import DatabaseManager
from backend.executor import DatabaseManagerClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManager)
return get_service_client(DatabaseManagerClient)
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:

View File

@@ -1,9 +1,8 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timezone
from typing import cast
from typing import Any, cast
import stripe
from autogpt_libs.utils.cache import thread_cached
@@ -21,6 +20,7 @@ from prisma.types import (
CreditTransactionCreateInput,
CreditTransactionWhereInput,
)
from pydantic import BaseModel
from backend.data import db
from backend.data.block_cost_config import BLOCK_COSTS
@@ -34,8 +34,7 @@ from backend.data.model import (
)
from backend.data.notifications import NotificationEventDTO, RefundRequestData
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.executor.utils import UsageTransactionMetadata
from backend.notifications import NotificationManager
from backend.notifications import NotificationManagerClient
from backend.server.model import Pagination
from backend.server.v2.admin.model import UserHistoryResponse
from backend.util.exceptions import InsufficientBalanceError
@@ -49,6 +48,17 @@ logger = logging.getLogger(__name__)
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: dict[str, Any] | None = None
reason: str | None = None
class UserCreditBase(ABC):
@abstractmethod
async def get_credits(self, user_id: str) -> int:
@@ -365,21 +375,19 @@ class UserCreditBase(ABC):
class UserCredit(UserCreditBase):
@thread_cached
def notification_client(self) -> NotificationManager:
return get_service_client(NotificationManager)
def notification_client(self) -> NotificationManagerClient:
return get_service_client(NotificationManagerClient)
async def _send_refund_notification(
self,
notification_request: RefundRequestData,
notification_type: NotificationType,
):
await asyncio.to_thread(
lambda: self.notification_client().queue_notification(
NotificationEventDTO(
user_id=notification_request.user_id,
type=notification_type,
data=notification_request.model_dump(),
)
await self.notification_client().queue_notification(
NotificationEventDTO(
user_id=notification_request.user_id,
type=notification_type,
data=notification_request.model_dump(),
)
)

View File

@@ -1,9 +1,10 @@
from .database import DatabaseManager
from .database import DatabaseManager, DatabaseManagerClient
from .manager import ExecutionManager
from .scheduler import Scheduler
__all__ = [
"DatabaseManager",
"DatabaseManagerClient",
"ExecutionManager",
"Scheduler",
]

View File

@@ -1,4 +1,5 @@
import logging
from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
@@ -39,12 +40,14 @@ from backend.data.user import (
update_user_integrations,
update_user_metadata,
)
from backend.util.service import AppService, exposed_run_and_wait
from backend.util.service import AppService, AppServiceClient, endpoint_to_sync, expose
from backend.util.settings import Config
config = Config()
_user_credit_model = get_user_credit_model()
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
async def _spend_credits(
@@ -69,58 +72,107 @@ class DatabaseManager(AppService):
def get_port(cls) -> int:
return config.database_api_port
@staticmethod
def _(f: Callable[P, R]) -> Callable[Concatenate[object, P], R]:
return cast(Callable[Concatenate[object, P], R], expose(f))
# Executions
get_graph_execution = exposed_run_and_wait(get_graph_execution)
create_graph_execution = exposed_run_and_wait(create_graph_execution)
get_node_execution_results = exposed_run_and_wait(get_node_execution_results)
get_incomplete_node_executions = exposed_run_and_wait(
get_incomplete_node_executions
)
get_latest_node_execution = exposed_run_and_wait(get_latest_node_execution)
update_node_execution_status = exposed_run_and_wait(update_node_execution_status)
update_node_execution_status_batch = exposed_run_and_wait(
update_node_execution_status_batch
)
update_graph_execution_start_time = exposed_run_and_wait(
update_graph_execution_start_time
)
update_graph_execution_stats = exposed_run_and_wait(update_graph_execution_stats)
update_node_execution_stats = exposed_run_and_wait(update_node_execution_stats)
upsert_execution_input = exposed_run_and_wait(upsert_execution_input)
upsert_execution_output = exposed_run_and_wait(upsert_execution_output)
get_graph_execution = _(get_graph_execution)
create_graph_execution = _(create_graph_execution)
get_node_execution_results = _(get_node_execution_results)
get_incomplete_node_executions = _(get_incomplete_node_executions)
get_latest_node_execution = _(get_latest_node_execution)
update_node_execution_status = _(update_node_execution_status)
update_node_execution_status_batch = _(update_node_execution_status_batch)
update_graph_execution_start_time = _(update_graph_execution_start_time)
update_graph_execution_stats = _(update_graph_execution_stats)
update_node_execution_stats = _(update_node_execution_stats)
upsert_execution_input = _(upsert_execution_input)
upsert_execution_output = _(upsert_execution_output)
# Graphs
get_node = exposed_run_and_wait(get_node)
get_graph = exposed_run_and_wait(get_graph)
get_connected_output_nodes = exposed_run_and_wait(get_connected_output_nodes)
get_graph_metadata = exposed_run_and_wait(get_graph_metadata)
get_node = _(get_node)
get_graph = _(get_graph)
get_connected_output_nodes = _(get_connected_output_nodes)
get_graph_metadata = _(get_graph_metadata)
# Credits
spend_credits = exposed_run_and_wait(_spend_credits)
spend_credits = _(_spend_credits)
# User + User Metadata + User Integrations
get_user_metadata = exposed_run_and_wait(get_user_metadata)
update_user_metadata = exposed_run_and_wait(update_user_metadata)
get_user_integrations = exposed_run_and_wait(get_user_integrations)
update_user_integrations = exposed_run_and_wait(update_user_integrations)
get_user_metadata = _(get_user_metadata)
update_user_metadata = _(update_user_metadata)
get_user_integrations = _(get_user_integrations)
update_user_integrations = _(update_user_integrations)
# User Comms - async
get_active_user_ids_in_timerange = exposed_run_and_wait(
get_active_user_ids_in_timerange
)
get_user_email_by_id = exposed_run_and_wait(get_user_email_by_id)
get_user_email_verification = exposed_run_and_wait(get_user_email_verification)
get_user_notification_preference = exposed_run_and_wait(
get_user_notification_preference
)
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
get_user_email_by_id = _(get_user_email_by_id)
get_user_email_verification = _(get_user_email_verification)
get_user_notification_preference = _(get_user_notification_preference)
# Notifications - async
create_or_add_to_user_notification_batch = exposed_run_and_wait(
create_or_add_to_user_notification_batch = _(
create_or_add_to_user_notification_batch
)
empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
get_all_batches_by_type = exposed_run_and_wait(get_all_batches_by_type)
get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = exposed_run_and_wait(
empty_user_notification_batch = _(empty_user_notification_batch)
get_all_batches_by_type = _(get_all_batches_by_type)
get_user_notification_batch = _(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = _(
get_user_notification_oldest_message_in_batch
)
class DatabaseManagerClient(AppServiceClient):
d = DatabaseManager
_ = endpoint_to_sync
@classmethod
def get_service_type(cls):
return DatabaseManager
# Executions
get_graph_execution = _(d.get_graph_execution)
create_graph_execution = _(d.create_graph_execution)
get_node_execution_results = _(d.get_node_execution_results)
get_incomplete_node_executions = _(d.get_incomplete_node_executions)
get_latest_node_execution = _(d.get_latest_node_execution)
update_node_execution_status = _(d.update_node_execution_status)
update_node_execution_status_batch = _(d.update_node_execution_status_batch)
update_graph_execution_start_time = _(d.update_graph_execution_start_time)
update_graph_execution_stats = _(d.update_graph_execution_stats)
update_node_execution_stats = _(d.update_node_execution_stats)
upsert_execution_input = _(d.upsert_execution_input)
upsert_execution_output = _(d.upsert_execution_output)
# Graphs
get_node = _(d.get_node)
get_graph = _(d.get_graph)
get_connected_output_nodes = _(d.get_connected_output_nodes)
get_graph_metadata = _(d.get_graph_metadata)
# Credits
spend_credits = _(d.spend_credits)
# User + User Metadata + User Integrations
get_user_metadata = _(d.get_user_metadata)
update_user_metadata = _(d.update_user_metadata)
get_user_integrations = _(d.get_user_integrations)
update_user_integrations = _(d.update_user_integrations)
# User Comms - async
get_active_user_ids_in_timerange = _(d.get_active_user_ids_in_timerange)
get_user_email_by_id = _(d.get_user_email_by_id)
get_user_email_verification = _(d.get_user_email_verification)
get_user_notification_preference = _(d.get_user_notification_preference)
# Notifications - async
create_or_add_to_user_notification_batch = _(
d.create_or_add_to_user_notification_batch
)
empty_user_notification_batch = _(d.empty_user_notification_batch)
get_all_batches_by_type = _(d.get_all_batches_by_type)
get_user_notification_batch = _(d.get_user_notification_batch)
get_user_notification_oldest_message_in_batch = _(
d.get_user_notification_oldest_message_in_batch
)

View File

@@ -27,8 +27,8 @@ from backend.executor.utils import create_execution_queue_config
from backend.util.exceptions import InsufficientBalanceError
if TYPE_CHECKING:
from backend.executor import DatabaseManager
from backend.notifications.notifications import NotificationManager
from backend.executor import DatabaseManagerClient
from backend.notifications.notifications import NotificationManagerClient
from autogpt_libs.utils.cache import thread_cached
from prometheus_client import Gauge, start_http_server
@@ -36,6 +36,7 @@ from prometheus_client import Gauge, start_http_server
from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis
from backend.data.block import BlockData, BlockInput, BlockSchema, get_block
from backend.data.credit import UsageTransactionMetadata
from backend.data.execution import (
ExecutionQueue,
ExecutionStatus,
@@ -49,7 +50,6 @@ from backend.executor.utils import (
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
GRAPH_EXECUTION_QUEUE_NAME,
CancelExecutionEvent,
UsageTransactionMetadata,
block_usage_cost,
execution_usage_cost,
get_execution_event_bus,
@@ -64,7 +64,7 @@ from backend.util.file import clean_exec_files
from backend.util.logging import configure_logging
from backend.util.process import AppProcess, set_service_name
from backend.util.retry import func_retry
from backend.util.service import close_service_client, get_service_client
from backend.util.service import get_service_client
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -135,7 +135,7 @@ ExecutionStream = Generator[NodeExecutionEntry, None, None]
def execute_node(
db_client: "DatabaseManager",
db_client: "DatabaseManagerClient",
creds_manager: IntegrationCredentialsManager,
data: NodeExecutionEntry,
execution_stats: NodeExecutionStats | None = None,
@@ -284,7 +284,7 @@ def execute_node(
def _enqueue_next_nodes(
db_client: "DatabaseManager",
db_client: "DatabaseManagerClient",
node: Node,
output: BlockData,
user_id: str,
@@ -461,7 +461,7 @@ class Executor:
log(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
redis.disconnect()
log(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB manager...")
close_service_client(cls.db_client)
cls.db_client.close()
log(f"[on_node_executor_stop {cls.pid}] ✅ Finished NodeExec cleanup")
sys.exit(0)
@@ -1064,26 +1064,22 @@ class ExecutionManager(AppProcess):
log(f"{prefix} ✅ Finished GraphExec cleanup")
@property
def db_client(self) -> "DatabaseManager":
return get_db_client()
# ------- UTILITIES ------- #
@thread_cached
def get_db_client() -> "DatabaseManager":
from backend.executor import DatabaseManager
def get_db_client() -> "DatabaseManagerClient":
from backend.executor import DatabaseManagerClient
return get_service_client(DatabaseManager)
return get_service_client(DatabaseManagerClient)
@thread_cached
def get_notification_service() -> "NotificationManager":
from backend.notifications import NotificationManager
def get_notification_service() -> "NotificationManagerClient":
from backend.notifications import NotificationManagerClient
return get_service_client(NotificationManager)
return get_service_client(NotificationManagerClient)
def send_execution_update(entry: GraphExecution | NodeExecutionResult):

View File

@@ -17,8 +17,14 @@ from sqlalchemy import MetaData, create_engine
from backend.data.block import BlockInput
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManager
from backend.util.service import AppService, expose, get_service_client
from backend.notifications.notifications import NotificationManagerClient
from backend.util.service import (
AppService,
AppServiceClient,
endpoint_to_async,
expose,
get_service_client,
)
from backend.util.settings import Config
@@ -59,9 +65,7 @@ def job_listener(event):
@thread_cached
def get_notification_client():
from backend.notifications import NotificationManager
return get_service_client(NotificationManager)
return get_service_client(NotificationManagerClient)
def execute_graph(**kwargs):
@@ -159,11 +163,6 @@ class Scheduler(AppService):
def db_pool_size(cls) -> int:
return config.scheduler_db_pool_size
@property
@thread_cached
def notification_client(self) -> NotificationManager:
return get_service_client(NotificationManager)
def run_service(self):
load_dotenv()
db_schema, db_url = _extract_schema_from_url(os.getenv("DIRECT_URL"))
@@ -300,3 +299,15 @@ class Scheduler(AppService):
),
job,
)
class SchedulerClient(AppServiceClient):
@classmethod
def get_service_type(cls):
return Scheduler
add_execution_schedule = endpoint_to_async(Scheduler.add_execution_schedule)
delete_schedule = endpoint_to_async(Scheduler.delete_schedule)
get_execution_schedules = endpoint_to_async(Scheduler.get_execution_schedules)
add_batched_notification_schedule = Scheduler.add_batched_notification_schedule
add_weekly_notification_schedule = Scheduler.add_weekly_notification_schedule

View File

@@ -41,7 +41,7 @@ from backend.util.settings import Config
from backend.util.type import convert
if TYPE_CHECKING:
from backend.executor import DatabaseManager
from backend.executor import DatabaseManagerClient
from backend.integrations.credentials_store import IntegrationCredentialsStore
config = Config()
@@ -82,26 +82,15 @@ def get_integration_credentials_store() -> "IntegrationCredentialsStore":
@thread_cached
def get_db_client() -> "DatabaseManager":
from backend.executor import DatabaseManager
def get_db_client() -> "DatabaseManagerClient":
from backend.executor import DatabaseManagerClient
return get_service_client(DatabaseManager)
return get_service_client(DatabaseManagerClient)
# ============ Execution Cost Helpers ============ #
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: BlockInput | None = None
reason: str | None = None
def execution_usage_cost(execution_count: int) -> tuple[int, int]:
"""
Calculate the cost of executing a graph based on the number of executions.

View File

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional
from pydantic import SecretStr
if TYPE_CHECKING:
from backend.executor.database import DatabaseManager
from backend.executor.database import DatabaseManagerClient
from autogpt_libs.utils.cache import thread_cached
from autogpt_libs.utils.synchronize import RedisKeyedMutex
@@ -210,11 +210,11 @@ class IntegrationCredentialsStore:
@property
@thread_cached
def db_manager(self) -> "DatabaseManager":
from backend.executor.database import DatabaseManager
def db_manager(self) -> "DatabaseManagerClient":
from backend.executor.database import DatabaseManagerClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManager)
return get_service_client(DatabaseManagerClient)
def add_creds(self, user_id: str, credentials: Credentials) -> None:
with self.locked_user_integrations(user_id):

View File

@@ -1,5 +1,6 @@
from .notifications import NotificationManager
from .notifications import NotificationManager, NotificationManagerClient
__all__ = [
"NotificationManager",
"NotificationManagerClient",
]

View File

@@ -31,7 +31,13 @@ from backend.data.notifications import (
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import generate_unsubscribe_link
from backend.notifications.email import EmailSender
from backend.util.service import AppService, expose, get_service_client
from backend.util.service import (
AppService,
AppServiceClient,
endpoint_to_async,
expose,
get_service_client,
)
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -108,16 +114,16 @@ def create_notification_config() -> RabbitMQConfig:
@thread_cached
def get_scheduler():
from backend.executor import Scheduler
from backend.executor.scheduler import SchedulerClient
return get_service_client(Scheduler)
return get_service_client(SchedulerClient)
@thread_cached
def get_db():
from backend.executor.database import DatabaseManager
from backend.executor.database import DatabaseManagerClient
return get_service_client(DatabaseManager)
return get_service_client(DatabaseManagerClient)
class NotificationManager(AppService):
@@ -774,3 +780,13 @@ class NotificationManager(AppService):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
self.run_and_wait(self.rabbitmq_service.disconnect())
class NotificationManagerClient(AppServiceClient):
@classmethod
def get_service_type(cls):
return NotificationManager
queue_notification = endpoint_to_async(NotificationManager.queue_notification)
process_existing_batches = NotificationManager.process_existing_batches
queue_weekly_summary = NotificationManager.queue_weekly_summary

View File

@@ -57,7 +57,7 @@ from backend.data.user import (
update_user_email,
update_user_notification_preference,
)
from backend.executor import Scheduler, scheduler
from backend.executor import scheduler
from backend.executor import utils as execution_utils
from backend.executor.utils import create_execution_queue_config
from backend.integrations.creds_manager import IntegrationCredentialsManager
@@ -83,8 +83,8 @@ if TYPE_CHECKING:
@thread_cached
def execution_scheduler_client() -> Scheduler:
return get_service_client(Scheduler)
def execution_scheduler_client() -> scheduler.SchedulerClient:
return get_service_client(scheduler.SchedulerClient)
@thread_cached
@@ -779,14 +779,12 @@ async def create_schedule(
detail=f"Graph #{schedule.graph_id} v.{schedule.graph_version} not found.",
)
return await asyncio.to_thread(
lambda: execution_scheduler_client().add_execution_schedule(
graph_id=schedule.graph_id,
graph_version=graph.version,
cron=schedule.cron,
input_data=schedule.input_data,
user_id=user_id,
)
return await execution_scheduler_client().add_execution_schedule(
graph_id=schedule.graph_id,
graph_version=graph.version,
cron=schedule.cron,
input_data=schedule.input_data,
user_id=user_id,
)
@@ -795,11 +793,11 @@ async def create_schedule(
tags=["schedules"],
dependencies=[Depends(auth_middleware)],
)
def delete_schedule(
async def delete_schedule(
schedule_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
return {"id": schedule_id}
@@ -808,11 +806,11 @@ def delete_schedule(
tags=["schedules"],
dependencies=[Depends(auth_middleware)],
)
def get_execution_schedules(
async def get_execution_schedules(
user_id: Annotated[str, Depends(get_user_id)],
graph_id: str | None = None,
) -> list[scheduler.ExecutionJobInfo]:
return execution_scheduler_client().get_execution_schedules(
return await execution_scheduler_client().get_execution_schedules(
user_id=user_id,
graph_id=graph_id,
)

View File

@@ -6,7 +6,6 @@ from typing import Protocol
import uvicorn
from autogpt_libs.auth import parse_jwt_token
from autogpt_libs.logging.utils import generate_uvicorn_config
from autogpt_libs.utils.cache import thread_cached
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from starlette.middleware.cors import CORSMiddleware
@@ -19,7 +18,7 @@ from backend.server.model import (
WSSubscribeGraphExecutionRequest,
WSSubscribeGraphExecutionsRequest,
)
from backend.util.service import AppProcess, get_service_client
from backend.util.service import AppProcess
from backend.util.settings import AppEnvironment, Config, Settings
logger = logging.getLogger(__name__)
@@ -46,13 +45,6 @@ def get_connection_manager():
return _connection_manager
@thread_cached
def get_db_client():
from backend.executor import DatabaseManager
return get_service_client(DatabaseManager)
async def event_broadcaster(manager: ConnectionManager):
try:
event_queue = AsyncRedisExecutionEventBus()

View File

@@ -5,8 +5,10 @@ import os
import threading
import time
from abc import ABC, abstractmethod
from functools import cached_property, update_wrapper
from typing import (
Any,
Awaitable,
Callable,
Concatenate,
Coroutine,
@@ -42,24 +44,15 @@ api_call_timeout = config.rpc_client_call_timeout
P = ParamSpec("P")
R = TypeVar("R")
EXPOSED_FLAG = "__exposed__"
def expose(func: C) -> C:
func = getattr(func, "__func__", func)
setattr(func, "__exposed__", True)
setattr(func, EXPOSED_FLAG, True)
return func
def exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
# TODO:
# This function lies about its return type to make the DynamicClient
# call the function synchronously, fix this when DynamicClient can choose
# to call a function synchronously or asynchronously.
return expose(f) # type: ignore
# --------------------------------------------------
# AppService for IPC service based on HTTP request through FastAPI
# --------------------------------------------------
@@ -203,7 +196,7 @@ class AppService(BaseAppService, ABC):
# Register the exposed API routes.
for attr_name, attr in vars(type(self)).items():
if getattr(attr, "__exposed__", False):
if getattr(attr, EXPOSED_FLAG, False):
route_path = f"/{attr_name}"
self.fastapi_app.add_api_route(
route_path,
@@ -234,31 +227,52 @@ class AppService(BaseAppService, ABC):
AS = TypeVar("AS", bound=AppService)
def close_service_client(client: Any) -> None:
if hasattr(client, "close"):
client.close()
else:
logger.warning(f"Client {client} is not closable")
class AppServiceClient(ABC):
@classmethod
@abstractmethod
def get_service_type(cls) -> Type[AppService]:
pass
def health_check(self):
pass
def close(self):
pass
@conn_retry("FastAPI client", "Creating service client", max_retry=api_comm_retry)
ASC = TypeVar("ASC", bound=AppServiceClient)
@conn_retry("AppService client", "Creating service client", max_retry=api_comm_retry)
def get_service_client(
service_type: Type[AS],
service_client_type: Type[ASC],
call_timeout: int | None = api_call_timeout,
) -> AS:
) -> ASC:
class DynamicClient:
def __init__(self):
service_type = service_client_type.get_service_type()
host = service_type.get_host()
port = service_type.get_port()
self.base_url = f"http://{host}:{port}".rstrip("/")
self.client = httpx.Client(
@cached_property
def sync_client(self) -> httpx.Client:
return httpx.Client(
base_url=self.base_url,
timeout=call_timeout,
)
def _call_method(self, method_name: str, **kwargs) -> Any:
@cached_property
def async_client(self) -> httpx.AsyncClient:
return httpx.AsyncClient(
base_url=self.base_url,
timeout=call_timeout,
)
def _handle_call_method_response(
self, response: httpx.Response, method_name: str
) -> Any:
try:
response = self.client.post(method_name, json=to_dict(kwargs))
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
@@ -269,36 +283,102 @@ def get_service_client(
*(error.args or [str(e)])
)
def _call_method_sync(self, method_name: str, **kwargs) -> Any:
return self._handle_call_method_response(
method_name=method_name,
response=self.sync_client.post(method_name, json=to_dict(kwargs)),
)
async def _call_method_async(self, method_name: str, **kwargs) -> Any:
return self._handle_call_method_response(
method_name=method_name,
response=await self.async_client.post(
method_name, json=to_dict(kwargs)
),
)
async def aclose(self):
self.sync_client.close()
await self.async_client.aclose()
def close(self):
self.client.close()
self.sync_client.close()
def _get_params(self, signature: inspect.Signature, *args, **kwargs) -> dict:
if args:
arg_names = list(signature.parameters.keys())
if arg_names[0] in ("self", "cls"):
arg_names = arg_names[1:]
kwargs.update(dict(zip(arg_names, args)))
return kwargs
def _get_return(self, expected_return: TypeAdapter | None, result: Any) -> Any:
if expected_return:
return expected_return.validate_python(result)
return result
def __getattr__(self, name: str) -> Callable[..., Any]:
# Try to get the original function from the service type.
orig_func = getattr(service_type, name, None)
if orig_func is None:
raise AttributeError(f"Method {name} not found in {service_type}")
original_func = getattr(service_client_type, name, None)
if original_func is None:
raise AttributeError(
f"Method {name} not found in {service_client_type}"
)
else:
name = original_func.__name__
sig = inspect.signature(orig_func)
sig = inspect.signature(original_func)
ret_ann = sig.return_annotation
if ret_ann != inspect.Signature.empty:
expected_return = TypeAdapter(ret_ann)
else:
expected_return = None
def method(*args, **kwargs) -> Any:
if args:
arg_names = list(sig.parameters.keys())
if arg_names[0] in ("self", "cls"):
arg_names = arg_names[1:]
kwargs.update(dict(zip(arg_names, args)))
result = self._call_method(name, **kwargs)
if expected_return:
return expected_return.validate_python(result)
return result
if inspect.iscoroutinefunction(original_func):
return method
async def async_method(*args, **kwargs) -> Any:
params = self._get_params(sig, *args, **kwargs)
result = await self._call_method_async(name, **params)
return self._get_return(expected_return, result)
client = cast(AS, DynamicClient())
return async_method
else:
def sync_method(*args, **kwargs) -> Any:
params = self._get_params(sig, *args, **kwargs)
result = self._call_method_sync(name, **params)
return self._get_return(expected_return, result)
return sync_method
client = cast(ASC, DynamicClient())
client.health_check()
return cast(AS, client)
return client
def endpoint_to_sync(
func: Callable[Concatenate[Any, P], Awaitable[R]],
) -> Callable[Concatenate[Any, P], R]:
"""
Produce a *typed* stub that **looks** synchronous to the typechecker.
"""
def _stub(*args: P.args, **kwargs: P.kwargs) -> R: # pragma: no cover
raise RuntimeError("should be intercepted by __getattr__")
update_wrapper(_stub, func)
return cast(Callable[Concatenate[Any, P], R], _stub)
def endpoint_to_async(
func: Callable[Concatenate[Any, P], R],
) -> Callable[Concatenate[Any, P], Awaitable[R]]:
"""
The async mirror of `to_sync`.
"""
async def _stub(*args: P.args, **kwargs: P.kwargs) -> R: # pragma: no cover
raise RuntimeError("should be intercepted by __getattr__")
update_wrapper(_stub, func)
return cast(Callable[Concatenate[Any, P], Awaitable[R]], _stub)

View File

@@ -6,10 +6,10 @@ from prisma.models import CreditTransaction
from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block
from backend.data.credit import BetaUserCredit
from backend.data.credit import BetaUserCredit, UsageTransactionMetadata
from backend.data.execution import NodeExecutionEntry
from backend.data.user import DEFAULT_USER_ID
from backend.executor.utils import UsageTransactionMetadata, block_usage_cost
from backend.executor.utils import block_usage_cost
from backend.integrations.credentials_store import openai_credentials
from backend.util.test import SpinTestServer

View File

@@ -1,7 +1,7 @@
import pytest
from backend.data import db
from backend.executor import Scheduler
from backend.executor.scheduler import SchedulerClient
from backend.server.model import CreateGraph
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.service import get_service_client
@@ -17,11 +17,11 @@ async def test_agent_schedule(server: SpinTestServer):
user_id=test_user.id,
)
scheduler = get_service_client(Scheduler)
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
scheduler = get_service_client(SchedulerClient)
schedules = await scheduler.get_execution_schedules(test_graph.id, test_user.id)
assert len(schedules) == 0
schedule = scheduler.add_execution_schedule(
schedule = await scheduler.add_execution_schedule(
graph_id=test_graph.id,
user_id=test_user.id,
graph_version=1,
@@ -30,10 +30,12 @@ async def test_agent_schedule(server: SpinTestServer):
)
assert schedule
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
schedules = await scheduler.get_execution_schedules(test_graph.id, test_user.id)
assert len(schedules) == 1
assert schedules[0].cron == "0 0 * * *"
scheduler.delete_schedule(schedule.id, user_id=test_user.id)
schedules = scheduler.get_execution_schedules(test_graph.id, user_id=test_user.id)
await scheduler.delete_schedule(schedule.id, user_id=test_user.id)
schedules = await scheduler.get_execution_schedules(
test_graph.id, user_id=test_user.id
)
assert len(schedules) == 0

View File

@@ -1,6 +1,12 @@
import pytest
from backend.util.service import AppService, expose, get_service_client
from backend.util.service import (
AppService,
AppServiceClient,
endpoint_to_async,
expose,
get_service_client,
)
TEST_SERVICE_PORT = 8765
@@ -32,10 +38,25 @@ class ServiceTest(AppService):
return self.run_and_wait(add_async(a, b))
class ServiceTestClient(AppServiceClient):
@classmethod
def get_service_type(cls):
return ServiceTest
add = ServiceTest.add
subtract = ServiceTest.subtract
fun_with_async = ServiceTest.fun_with_async
add_async = endpoint_to_async(ServiceTest.add)
subtract_async = endpoint_to_async(ServiceTest.subtract)
@pytest.mark.asyncio(loop_scope="session")
async def test_service_creation(server):
with ServiceTest():
client = get_service_client(ServiceTest)
client = get_service_client(ServiceTestClient)
assert client.add(5, 3) == 8
assert client.subtract(10, 4) == 6
assert client.fun_with_async(5, 3) == 8
assert await client.add_async(5, 3) == 8
assert await client.subtract_async(10, 4) == 6