mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(backend): Support flexible RPC client (#9842)
Using sync code in the async route often introduces a blocking event-loop code that impacts stability. The current RPC system only provides a synchronous client to call the service endpoints. The scope of this PR is to provide an entirely decoupled signature between client and server, allowing the client can mix & match async & sync options on the client code while not changing the async/sync nature of the server. ### Changes 🏗️ * Add support for flexible async/sync RPC client. * Migrate scheduler client to all-async client. ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Scheduler route test. - [x] Modified service_test.py - [x] Run normal agent executions
This commit is contained in:
@@ -26,10 +26,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@thread_cached
|
||||
def get_database_manager_client():
|
||||
from backend.executor import DatabaseManager
|
||||
from backend.executor import DatabaseManagerClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
|
||||
|
||||
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
import stripe
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
@@ -21,6 +20,7 @@ from prisma.types import (
|
||||
CreditTransactionCreateInput,
|
||||
CreditTransactionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
@@ -34,8 +34,7 @@ from backend.data.model import (
|
||||
)
|
||||
from backend.data.notifications import NotificationEventDTO, RefundRequestData
|
||||
from backend.data.user import get_user_by_id, get_user_email_by_id
|
||||
from backend.executor.utils import UsageTransactionMetadata
|
||||
from backend.notifications import NotificationManager
|
||||
from backend.notifications import NotificationManagerClient
|
||||
from backend.server.model import Pagination
|
||||
from backend.server.v2.admin.model import UserHistoryResponse
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
@@ -49,6 +48,17 @@ logger = logging.getLogger(__name__)
|
||||
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
|
||||
|
||||
class UsageTransactionMetadata(BaseModel):
|
||||
graph_exec_id: str | None = None
|
||||
graph_id: str | None = None
|
||||
node_id: str | None = None
|
||||
node_exec_id: str | None = None
|
||||
block_id: str | None = None
|
||||
block: str | None = None
|
||||
input: dict[str, Any] | None = None
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class UserCreditBase(ABC):
|
||||
@abstractmethod
|
||||
async def get_credits(self, user_id: str) -> int:
|
||||
@@ -365,21 +375,19 @@ class UserCreditBase(ABC):
|
||||
|
||||
class UserCredit(UserCreditBase):
|
||||
@thread_cached
|
||||
def notification_client(self) -> NotificationManager:
|
||||
return get_service_client(NotificationManager)
|
||||
def notification_client(self) -> NotificationManagerClient:
|
||||
return get_service_client(NotificationManagerClient)
|
||||
|
||||
async def _send_refund_notification(
|
||||
self,
|
||||
notification_request: RefundRequestData,
|
||||
notification_type: NotificationType,
|
||||
):
|
||||
await asyncio.to_thread(
|
||||
lambda: self.notification_client().queue_notification(
|
||||
NotificationEventDTO(
|
||||
user_id=notification_request.user_id,
|
||||
type=notification_type,
|
||||
data=notification_request.model_dump(),
|
||||
)
|
||||
await self.notification_client().queue_notification(
|
||||
NotificationEventDTO(
|
||||
user_id=notification_request.user_id,
|
||||
type=notification_type,
|
||||
data=notification_request.model_dump(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from .database import DatabaseManager
|
||||
from .database import DatabaseManager, DatabaseManagerClient
|
||||
from .manager import ExecutionManager
|
||||
from .scheduler import Scheduler
|
||||
|
||||
__all__ = [
|
||||
"DatabaseManager",
|
||||
"DatabaseManagerClient",
|
||||
"ExecutionManager",
|
||||
"Scheduler",
|
||||
]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
@@ -39,12 +40,14 @@ from backend.data.user import (
|
||||
update_user_integrations,
|
||||
update_user_metadata,
|
||||
)
|
||||
from backend.util.service import AppService, exposed_run_and_wait
|
||||
from backend.util.service import AppService, AppServiceClient, endpoint_to_sync, expose
|
||||
from backend.util.settings import Config
|
||||
|
||||
config = Config()
|
||||
_user_credit_model = get_user_credit_model()
|
||||
logger = logging.getLogger(__name__)
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
async def _spend_credits(
|
||||
@@ -69,58 +72,107 @@ class DatabaseManager(AppService):
|
||||
def get_port(cls) -> int:
|
||||
return config.database_api_port
|
||||
|
||||
@staticmethod
|
||||
def _(f: Callable[P, R]) -> Callable[Concatenate[object, P], R]:
|
||||
return cast(Callable[Concatenate[object, P], R], expose(f))
|
||||
|
||||
# Executions
|
||||
get_graph_execution = exposed_run_and_wait(get_graph_execution)
|
||||
create_graph_execution = exposed_run_and_wait(create_graph_execution)
|
||||
get_node_execution_results = exposed_run_and_wait(get_node_execution_results)
|
||||
get_incomplete_node_executions = exposed_run_and_wait(
|
||||
get_incomplete_node_executions
|
||||
)
|
||||
get_latest_node_execution = exposed_run_and_wait(get_latest_node_execution)
|
||||
update_node_execution_status = exposed_run_and_wait(update_node_execution_status)
|
||||
update_node_execution_status_batch = exposed_run_and_wait(
|
||||
update_node_execution_status_batch
|
||||
)
|
||||
update_graph_execution_start_time = exposed_run_and_wait(
|
||||
update_graph_execution_start_time
|
||||
)
|
||||
update_graph_execution_stats = exposed_run_and_wait(update_graph_execution_stats)
|
||||
update_node_execution_stats = exposed_run_and_wait(update_node_execution_stats)
|
||||
upsert_execution_input = exposed_run_and_wait(upsert_execution_input)
|
||||
upsert_execution_output = exposed_run_and_wait(upsert_execution_output)
|
||||
get_graph_execution = _(get_graph_execution)
|
||||
create_graph_execution = _(create_graph_execution)
|
||||
get_node_execution_results = _(get_node_execution_results)
|
||||
get_incomplete_node_executions = _(get_incomplete_node_executions)
|
||||
get_latest_node_execution = _(get_latest_node_execution)
|
||||
update_node_execution_status = _(update_node_execution_status)
|
||||
update_node_execution_status_batch = _(update_node_execution_status_batch)
|
||||
update_graph_execution_start_time = _(update_graph_execution_start_time)
|
||||
update_graph_execution_stats = _(update_graph_execution_stats)
|
||||
update_node_execution_stats = _(update_node_execution_stats)
|
||||
upsert_execution_input = _(upsert_execution_input)
|
||||
upsert_execution_output = _(upsert_execution_output)
|
||||
|
||||
# Graphs
|
||||
get_node = exposed_run_and_wait(get_node)
|
||||
get_graph = exposed_run_and_wait(get_graph)
|
||||
get_connected_output_nodes = exposed_run_and_wait(get_connected_output_nodes)
|
||||
get_graph_metadata = exposed_run_and_wait(get_graph_metadata)
|
||||
get_node = _(get_node)
|
||||
get_graph = _(get_graph)
|
||||
get_connected_output_nodes = _(get_connected_output_nodes)
|
||||
get_graph_metadata = _(get_graph_metadata)
|
||||
|
||||
# Credits
|
||||
spend_credits = exposed_run_and_wait(_spend_credits)
|
||||
spend_credits = _(_spend_credits)
|
||||
|
||||
# User + User Metadata + User Integrations
|
||||
get_user_metadata = exposed_run_and_wait(get_user_metadata)
|
||||
update_user_metadata = exposed_run_and_wait(update_user_metadata)
|
||||
get_user_integrations = exposed_run_and_wait(get_user_integrations)
|
||||
update_user_integrations = exposed_run_and_wait(update_user_integrations)
|
||||
get_user_metadata = _(get_user_metadata)
|
||||
update_user_metadata = _(update_user_metadata)
|
||||
get_user_integrations = _(get_user_integrations)
|
||||
update_user_integrations = _(update_user_integrations)
|
||||
|
||||
# User Comms - async
|
||||
get_active_user_ids_in_timerange = exposed_run_and_wait(
|
||||
get_active_user_ids_in_timerange
|
||||
)
|
||||
get_user_email_by_id = exposed_run_and_wait(get_user_email_by_id)
|
||||
get_user_email_verification = exposed_run_and_wait(get_user_email_verification)
|
||||
get_user_notification_preference = exposed_run_and_wait(
|
||||
get_user_notification_preference
|
||||
)
|
||||
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
|
||||
get_user_email_by_id = _(get_user_email_by_id)
|
||||
get_user_email_verification = _(get_user_email_verification)
|
||||
get_user_notification_preference = _(get_user_notification_preference)
|
||||
|
||||
# Notifications - async
|
||||
create_or_add_to_user_notification_batch = exposed_run_and_wait(
|
||||
create_or_add_to_user_notification_batch = _(
|
||||
create_or_add_to_user_notification_batch
|
||||
)
|
||||
empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
|
||||
get_all_batches_by_type = exposed_run_and_wait(get_all_batches_by_type)
|
||||
get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)
|
||||
get_user_notification_oldest_message_in_batch = exposed_run_and_wait(
|
||||
empty_user_notification_batch = _(empty_user_notification_batch)
|
||||
get_all_batches_by_type = _(get_all_batches_by_type)
|
||||
get_user_notification_batch = _(get_user_notification_batch)
|
||||
get_user_notification_oldest_message_in_batch = _(
|
||||
get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
|
||||
class DatabaseManagerClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
_ = endpoint_to_sync
|
||||
|
||||
@classmethod
|
||||
def get_service_type(cls):
|
||||
return DatabaseManager
|
||||
|
||||
# Executions
|
||||
get_graph_execution = _(d.get_graph_execution)
|
||||
create_graph_execution = _(d.create_graph_execution)
|
||||
get_node_execution_results = _(d.get_node_execution_results)
|
||||
get_incomplete_node_executions = _(d.get_incomplete_node_executions)
|
||||
get_latest_node_execution = _(d.get_latest_node_execution)
|
||||
update_node_execution_status = _(d.update_node_execution_status)
|
||||
update_node_execution_status_batch = _(d.update_node_execution_status_batch)
|
||||
update_graph_execution_start_time = _(d.update_graph_execution_start_time)
|
||||
update_graph_execution_stats = _(d.update_graph_execution_stats)
|
||||
update_node_execution_stats = _(d.update_node_execution_stats)
|
||||
upsert_execution_input = _(d.upsert_execution_input)
|
||||
upsert_execution_output = _(d.upsert_execution_output)
|
||||
|
||||
# Graphs
|
||||
get_node = _(d.get_node)
|
||||
get_graph = _(d.get_graph)
|
||||
get_connected_output_nodes = _(d.get_connected_output_nodes)
|
||||
get_graph_metadata = _(d.get_graph_metadata)
|
||||
|
||||
# Credits
|
||||
spend_credits = _(d.spend_credits)
|
||||
|
||||
# User + User Metadata + User Integrations
|
||||
get_user_metadata = _(d.get_user_metadata)
|
||||
update_user_metadata = _(d.update_user_metadata)
|
||||
get_user_integrations = _(d.get_user_integrations)
|
||||
update_user_integrations = _(d.update_user_integrations)
|
||||
|
||||
# User Comms - async
|
||||
get_active_user_ids_in_timerange = _(d.get_active_user_ids_in_timerange)
|
||||
get_user_email_by_id = _(d.get_user_email_by_id)
|
||||
get_user_email_verification = _(d.get_user_email_verification)
|
||||
get_user_notification_preference = _(d.get_user_notification_preference)
|
||||
|
||||
# Notifications - async
|
||||
create_or_add_to_user_notification_batch = _(
|
||||
d.create_or_add_to_user_notification_batch
|
||||
)
|
||||
empty_user_notification_batch = _(d.empty_user_notification_batch)
|
||||
get_all_batches_by_type = _(d.get_all_batches_by_type)
|
||||
get_user_notification_batch = _(d.get_user_notification_batch)
|
||||
get_user_notification_oldest_message_in_batch = _(
|
||||
d.get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
@@ -27,8 +27,8 @@ from backend.executor.utils import create_execution_queue_config
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManager
|
||||
from backend.notifications.notifications import NotificationManager
|
||||
from backend.executor import DatabaseManagerClient
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
@@ -36,6 +36,7 @@ from prometheus_client import Gauge, start_http_server
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data import redis
|
||||
from backend.data.block import BlockData, BlockInput, BlockSchema, get_block
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.execution import (
|
||||
ExecutionQueue,
|
||||
ExecutionStatus,
|
||||
@@ -49,7 +50,6 @@ from backend.executor.utils import (
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
CancelExecutionEvent,
|
||||
UsageTransactionMetadata,
|
||||
block_usage_cost,
|
||||
execution_usage_cost,
|
||||
get_execution_event_bus,
|
||||
@@ -64,7 +64,7 @@ from backend.util.file import clean_exec_files
|
||||
from backend.util.logging import configure_logging
|
||||
from backend.util.process import AppProcess, set_service_name
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.service import close_service_client, get_service_client
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -135,7 +135,7 @@ ExecutionStream = Generator[NodeExecutionEntry, None, None]
|
||||
|
||||
|
||||
def execute_node(
|
||||
db_client: "DatabaseManager",
|
||||
db_client: "DatabaseManagerClient",
|
||||
creds_manager: IntegrationCredentialsManager,
|
||||
data: NodeExecutionEntry,
|
||||
execution_stats: NodeExecutionStats | None = None,
|
||||
@@ -284,7 +284,7 @@ def execute_node(
|
||||
|
||||
|
||||
def _enqueue_next_nodes(
|
||||
db_client: "DatabaseManager",
|
||||
db_client: "DatabaseManagerClient",
|
||||
node: Node,
|
||||
output: BlockData,
|
||||
user_id: str,
|
||||
@@ -461,7 +461,7 @@ class Executor:
|
||||
log(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
|
||||
redis.disconnect()
|
||||
log(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB manager...")
|
||||
close_service_client(cls.db_client)
|
||||
cls.db_client.close()
|
||||
log(f"[on_node_executor_stop {cls.pid}] ✅ Finished NodeExec cleanup")
|
||||
sys.exit(0)
|
||||
|
||||
@@ -1064,26 +1064,22 @@ class ExecutionManager(AppProcess):
|
||||
|
||||
log(f"{prefix} ✅ Finished GraphExec cleanup")
|
||||
|
||||
@property
|
||||
def db_client(self) -> "DatabaseManager":
|
||||
return get_db_client()
|
||||
|
||||
|
||||
# ------- UTILITIES ------- #
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_client() -> "DatabaseManager":
|
||||
from backend.executor import DatabaseManager
|
||||
def get_db_client() -> "DatabaseManagerClient":
|
||||
from backend.executor import DatabaseManagerClient
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_notification_service() -> "NotificationManager":
|
||||
from backend.notifications import NotificationManager
|
||||
def get_notification_service() -> "NotificationManagerClient":
|
||||
from backend.notifications import NotificationManagerClient
|
||||
|
||||
return get_service_client(NotificationManager)
|
||||
return get_service_client(NotificationManagerClient)
|
||||
|
||||
|
||||
def send_execution_update(entry: GraphExecution | NodeExecutionResult):
|
||||
|
||||
@@ -17,8 +17,14 @@ from sqlalchemy import MetaData, create_engine
|
||||
|
||||
from backend.data.block import BlockInput
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.notifications.notifications import NotificationManager
|
||||
from backend.util.service import AppService, expose, get_service_client
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
endpoint_to_async,
|
||||
expose,
|
||||
get_service_client,
|
||||
)
|
||||
from backend.util.settings import Config
|
||||
|
||||
|
||||
@@ -59,9 +65,7 @@ def job_listener(event):
|
||||
|
||||
@thread_cached
|
||||
def get_notification_client():
|
||||
from backend.notifications import NotificationManager
|
||||
|
||||
return get_service_client(NotificationManager)
|
||||
return get_service_client(NotificationManagerClient)
|
||||
|
||||
|
||||
def execute_graph(**kwargs):
|
||||
@@ -159,11 +163,6 @@ class Scheduler(AppService):
|
||||
def db_pool_size(cls) -> int:
|
||||
return config.scheduler_db_pool_size
|
||||
|
||||
@property
|
||||
@thread_cached
|
||||
def notification_client(self) -> NotificationManager:
|
||||
return get_service_client(NotificationManager)
|
||||
|
||||
def run_service(self):
|
||||
load_dotenv()
|
||||
db_schema, db_url = _extract_schema_from_url(os.getenv("DIRECT_URL"))
|
||||
@@ -300,3 +299,15 @@ class Scheduler(AppService):
|
||||
),
|
||||
job,
|
||||
)
|
||||
|
||||
|
||||
class SchedulerClient(AppServiceClient):
|
||||
@classmethod
|
||||
def get_service_type(cls):
|
||||
return Scheduler
|
||||
|
||||
add_execution_schedule = endpoint_to_async(Scheduler.add_execution_schedule)
|
||||
delete_schedule = endpoint_to_async(Scheduler.delete_schedule)
|
||||
get_execution_schedules = endpoint_to_async(Scheduler.get_execution_schedules)
|
||||
add_batched_notification_schedule = Scheduler.add_batched_notification_schedule
|
||||
add_weekly_notification_schedule = Scheduler.add_weekly_notification_schedule
|
||||
|
||||
@@ -41,7 +41,7 @@ from backend.util.settings import Config
|
||||
from backend.util.type import convert
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManager
|
||||
from backend.executor import DatabaseManagerClient
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
config = Config()
|
||||
@@ -82,26 +82,15 @@ def get_integration_credentials_store() -> "IntegrationCredentialsStore":
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_client() -> "DatabaseManager":
|
||||
from backend.executor import DatabaseManager
|
||||
def get_db_client() -> "DatabaseManagerClient":
|
||||
from backend.executor import DatabaseManagerClient
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
|
||||
|
||||
# ============ Execution Cost Helpers ============ #
|
||||
|
||||
|
||||
class UsageTransactionMetadata(BaseModel):
|
||||
graph_exec_id: str | None = None
|
||||
graph_id: str | None = None
|
||||
node_id: str | None = None
|
||||
node_exec_id: str | None = None
|
||||
block_id: str | None = None
|
||||
block: str | None = None
|
||||
input: BlockInput | None = None
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
def execution_usage_cost(execution_count: int) -> tuple[int, int]:
|
||||
"""
|
||||
Calculate the cost of executing a graph based on the number of executions.
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
from pydantic import SecretStr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor.database import DatabaseManager
|
||||
from backend.executor.database import DatabaseManagerClient
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from autogpt_libs.utils.synchronize import RedisKeyedMutex
|
||||
@@ -210,11 +210,11 @@ class IntegrationCredentialsStore:
|
||||
|
||||
@property
|
||||
@thread_cached
|
||||
def db_manager(self) -> "DatabaseManager":
|
||||
from backend.executor.database import DatabaseManager
|
||||
def db_manager(self) -> "DatabaseManagerClient":
|
||||
from backend.executor.database import DatabaseManagerClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
|
||||
def add_creds(self, user_id: str, credentials: Credentials) -> None:
|
||||
with self.locked_user_integrations(user_id):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from .notifications import NotificationManager
|
||||
from .notifications import NotificationManager, NotificationManagerClient
|
||||
|
||||
__all__ = [
|
||||
"NotificationManager",
|
||||
"NotificationManagerClient",
|
||||
]
|
||||
|
||||
@@ -31,7 +31,13 @@ from backend.data.notifications import (
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.data.user import generate_unsubscribe_link
|
||||
from backend.notifications.email import EmailSender
|
||||
from backend.util.service import AppService, expose, get_service_client
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
endpoint_to_async,
|
||||
expose,
|
||||
get_service_client,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -108,16 +114,16 @@ def create_notification_config() -> RabbitMQConfig:
|
||||
|
||||
@thread_cached
|
||||
def get_scheduler():
|
||||
from backend.executor import Scheduler
|
||||
from backend.executor.scheduler import SchedulerClient
|
||||
|
||||
return get_service_client(Scheduler)
|
||||
return get_service_client(SchedulerClient)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db():
|
||||
from backend.executor.database import DatabaseManager
|
||||
from backend.executor.database import DatabaseManagerClient
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
|
||||
|
||||
class NotificationManager(AppService):
|
||||
@@ -774,3 +780,13 @@ class NotificationManager(AppService):
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
|
||||
self.run_and_wait(self.rabbitmq_service.disconnect())
|
||||
|
||||
|
||||
class NotificationManagerClient(AppServiceClient):
|
||||
@classmethod
|
||||
def get_service_type(cls):
|
||||
return NotificationManager
|
||||
|
||||
queue_notification = endpoint_to_async(NotificationManager.queue_notification)
|
||||
process_existing_batches = NotificationManager.process_existing_batches
|
||||
queue_weekly_summary = NotificationManager.queue_weekly_summary
|
||||
|
||||
@@ -57,7 +57,7 @@ from backend.data.user import (
|
||||
update_user_email,
|
||||
update_user_notification_preference,
|
||||
)
|
||||
from backend.executor import Scheduler, scheduler
|
||||
from backend.executor import scheduler
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.executor.utils import create_execution_queue_config
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
@@ -83,8 +83,8 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_scheduler_client() -> Scheduler:
|
||||
return get_service_client(Scheduler)
|
||||
def execution_scheduler_client() -> scheduler.SchedulerClient:
|
||||
return get_service_client(scheduler.SchedulerClient)
|
||||
|
||||
|
||||
@thread_cached
|
||||
@@ -779,14 +779,12 @@ async def create_schedule(
|
||||
detail=f"Graph #{schedule.graph_id} v.{schedule.graph_version} not found.",
|
||||
)
|
||||
|
||||
return await asyncio.to_thread(
|
||||
lambda: execution_scheduler_client().add_execution_schedule(
|
||||
graph_id=schedule.graph_id,
|
||||
graph_version=graph.version,
|
||||
cron=schedule.cron,
|
||||
input_data=schedule.input_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
return await execution_scheduler_client().add_execution_schedule(
|
||||
graph_id=schedule.graph_id,
|
||||
graph_version=graph.version,
|
||||
cron=schedule.cron,
|
||||
input_data=schedule.input_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -795,11 +793,11 @@ async def create_schedule(
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
def delete_schedule(
|
||||
async def delete_schedule(
|
||||
schedule_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[Any, Any]:
|
||||
execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
|
||||
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
|
||||
return {"id": schedule_id}
|
||||
|
||||
|
||||
@@ -808,11 +806,11 @@ def delete_schedule(
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
def get_execution_schedules(
|
||||
async def get_execution_schedules(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
graph_id: str | None = None,
|
||||
) -> list[scheduler.ExecutionJobInfo]:
|
||||
return execution_scheduler_client().get_execution_schedules(
|
||||
return await execution_scheduler_client().get_execution_schedules(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Protocol
|
||||
import uvicorn
|
||||
from autogpt_libs.auth import parse_jwt_token
|
||||
from autogpt_libs.logging.utils import generate_uvicorn_config
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
@@ -19,7 +18,7 @@ from backend.server.model import (
|
||||
WSSubscribeGraphExecutionRequest,
|
||||
WSSubscribeGraphExecutionsRequest,
|
||||
)
|
||||
from backend.util.service import AppProcess, get_service_client
|
||||
from backend.util.service import AppProcess
|
||||
from backend.util.settings import AppEnvironment, Config, Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -46,13 +45,6 @@ def get_connection_manager():
|
||||
return _connection_manager
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_client():
|
||||
from backend.executor import DatabaseManager
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
async def event_broadcaster(manager: ConnectionManager):
|
||||
try:
|
||||
event_queue = AsyncRedisExecutionEventBus()
|
||||
|
||||
@@ -5,8 +5,10 @@ import os
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property, update_wrapper
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Concatenate,
|
||||
Coroutine,
|
||||
@@ -42,24 +44,15 @@ api_call_timeout = config.rpc_client_call_timeout
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
EXPOSED_FLAG = "__exposed__"
|
||||
|
||||
|
||||
def expose(func: C) -> C:
|
||||
func = getattr(func, "__func__", func)
|
||||
setattr(func, "__exposed__", True)
|
||||
setattr(func, EXPOSED_FLAG, True)
|
||||
return func
|
||||
|
||||
|
||||
def exposed_run_and_wait(
|
||||
f: Callable[P, Coroutine[None, None, R]]
|
||||
) -> Callable[Concatenate[object, P], R]:
|
||||
# TODO:
|
||||
# This function lies about its return type to make the DynamicClient
|
||||
# call the function synchronously, fix this when DynamicClient can choose
|
||||
# to call a function synchronously or asynchronously.
|
||||
return expose(f) # type: ignore
|
||||
|
||||
|
||||
# --------------------------------------------------
|
||||
# AppService for IPC service based on HTTP request through FastAPI
|
||||
# --------------------------------------------------
|
||||
@@ -203,7 +196,7 @@ class AppService(BaseAppService, ABC):
|
||||
|
||||
# Register the exposed API routes.
|
||||
for attr_name, attr in vars(type(self)).items():
|
||||
if getattr(attr, "__exposed__", False):
|
||||
if getattr(attr, EXPOSED_FLAG, False):
|
||||
route_path = f"/{attr_name}"
|
||||
self.fastapi_app.add_api_route(
|
||||
route_path,
|
||||
@@ -234,31 +227,52 @@ class AppService(BaseAppService, ABC):
|
||||
AS = TypeVar("AS", bound=AppService)
|
||||
|
||||
|
||||
def close_service_client(client: Any) -> None:
|
||||
if hasattr(client, "close"):
|
||||
client.close()
|
||||
else:
|
||||
logger.warning(f"Client {client} is not closable")
|
||||
class AppServiceClient(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_service_type(cls) -> Type[AppService]:
|
||||
pass
|
||||
|
||||
def health_check(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
@conn_retry("FastAPI client", "Creating service client", max_retry=api_comm_retry)
|
||||
ASC = TypeVar("ASC", bound=AppServiceClient)
|
||||
|
||||
|
||||
@conn_retry("AppService client", "Creating service client", max_retry=api_comm_retry)
|
||||
def get_service_client(
|
||||
service_type: Type[AS],
|
||||
service_client_type: Type[ASC],
|
||||
call_timeout: int | None = api_call_timeout,
|
||||
) -> AS:
|
||||
) -> ASC:
|
||||
class DynamicClient:
|
||||
def __init__(self):
|
||||
service_type = service_client_type.get_service_type()
|
||||
host = service_type.get_host()
|
||||
port = service_type.get_port()
|
||||
self.base_url = f"http://{host}:{port}".rstrip("/")
|
||||
self.client = httpx.Client(
|
||||
|
||||
@cached_property
|
||||
def sync_client(self) -> httpx.Client:
|
||||
return httpx.Client(
|
||||
base_url=self.base_url,
|
||||
timeout=call_timeout,
|
||||
)
|
||||
|
||||
def _call_method(self, method_name: str, **kwargs) -> Any:
|
||||
@cached_property
|
||||
def async_client(self) -> httpx.AsyncClient:
|
||||
return httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
timeout=call_timeout,
|
||||
)
|
||||
|
||||
def _handle_call_method_response(
|
||||
self, response: httpx.Response, method_name: str
|
||||
) -> Any:
|
||||
try:
|
||||
response = self.client.post(method_name, json=to_dict(kwargs))
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
@@ -269,36 +283,102 @@ def get_service_client(
|
||||
*(error.args or [str(e)])
|
||||
)
|
||||
|
||||
def _call_method_sync(self, method_name: str, **kwargs) -> Any:
|
||||
return self._handle_call_method_response(
|
||||
method_name=method_name,
|
||||
response=self.sync_client.post(method_name, json=to_dict(kwargs)),
|
||||
)
|
||||
|
||||
async def _call_method_async(self, method_name: str, **kwargs) -> Any:
|
||||
return self._handle_call_method_response(
|
||||
method_name=method_name,
|
||||
response=await self.async_client.post(
|
||||
method_name, json=to_dict(kwargs)
|
||||
),
|
||||
)
|
||||
|
||||
async def aclose(self):
|
||||
self.sync_client.close()
|
||||
await self.async_client.aclose()
|
||||
|
||||
def close(self):
|
||||
self.client.close()
|
||||
self.sync_client.close()
|
||||
|
||||
def _get_params(self, signature: inspect.Signature, *args, **kwargs) -> dict:
|
||||
if args:
|
||||
arg_names = list(signature.parameters.keys())
|
||||
if arg_names[0] in ("self", "cls"):
|
||||
arg_names = arg_names[1:]
|
||||
kwargs.update(dict(zip(arg_names, args)))
|
||||
return kwargs
|
||||
|
||||
def _get_return(self, expected_return: TypeAdapter | None, result: Any) -> Any:
|
||||
if expected_return:
|
||||
return expected_return.validate_python(result)
|
||||
return result
|
||||
|
||||
def __getattr__(self, name: str) -> Callable[..., Any]:
|
||||
# Try to get the original function from the service type.
|
||||
orig_func = getattr(service_type, name, None)
|
||||
if orig_func is None:
|
||||
raise AttributeError(f"Method {name} not found in {service_type}")
|
||||
original_func = getattr(service_client_type, name, None)
|
||||
if original_func is None:
|
||||
raise AttributeError(
|
||||
f"Method {name} not found in {service_client_type}"
|
||||
)
|
||||
else:
|
||||
name = original_func.__name__
|
||||
|
||||
sig = inspect.signature(orig_func)
|
||||
sig = inspect.signature(original_func)
|
||||
ret_ann = sig.return_annotation
|
||||
if ret_ann != inspect.Signature.empty:
|
||||
expected_return = TypeAdapter(ret_ann)
|
||||
else:
|
||||
expected_return = None
|
||||
|
||||
def method(*args, **kwargs) -> Any:
|
||||
if args:
|
||||
arg_names = list(sig.parameters.keys())
|
||||
if arg_names[0] in ("self", "cls"):
|
||||
arg_names = arg_names[1:]
|
||||
kwargs.update(dict(zip(arg_names, args)))
|
||||
result = self._call_method(name, **kwargs)
|
||||
if expected_return:
|
||||
return expected_return.validate_python(result)
|
||||
return result
|
||||
if inspect.iscoroutinefunction(original_func):
|
||||
|
||||
return method
|
||||
async def async_method(*args, **kwargs) -> Any:
|
||||
params = self._get_params(sig, *args, **kwargs)
|
||||
result = await self._call_method_async(name, **params)
|
||||
return self._get_return(expected_return, result)
|
||||
|
||||
client = cast(AS, DynamicClient())
|
||||
return async_method
|
||||
else:
|
||||
|
||||
def sync_method(*args, **kwargs) -> Any:
|
||||
params = self._get_params(sig, *args, **kwargs)
|
||||
result = self._call_method_sync(name, **params)
|
||||
return self._get_return(expected_return, result)
|
||||
|
||||
return sync_method
|
||||
|
||||
client = cast(ASC, DynamicClient())
|
||||
client.health_check()
|
||||
|
||||
return cast(AS, client)
|
||||
return client
|
||||
|
||||
|
||||
def endpoint_to_sync(
|
||||
func: Callable[Concatenate[Any, P], Awaitable[R]],
|
||||
) -> Callable[Concatenate[Any, P], R]:
|
||||
"""
|
||||
Produce a *typed* stub that **looks** synchronous to the type‑checker.
|
||||
"""
|
||||
|
||||
def _stub(*args: P.args, **kwargs: P.kwargs) -> R: # pragma: no cover
|
||||
raise RuntimeError("should be intercepted by __getattr__")
|
||||
|
||||
update_wrapper(_stub, func)
|
||||
return cast(Callable[Concatenate[Any, P], R], _stub)
|
||||
|
||||
|
||||
def endpoint_to_async(
|
||||
func: Callable[Concatenate[Any, P], R],
|
||||
) -> Callable[Concatenate[Any, P], Awaitable[R]]:
|
||||
"""
|
||||
The async mirror of `to_sync`.
|
||||
"""
|
||||
|
||||
async def _stub(*args: P.args, **kwargs: P.kwargs) -> R: # pragma: no cover
|
||||
raise RuntimeError("should be intercepted by __getattr__")
|
||||
|
||||
update_wrapper(_stub, func)
|
||||
return cast(Callable[Concatenate[Any, P], Awaitable[R]], _stub)
|
||||
|
||||
@@ -6,10 +6,10 @@ from prisma.models import CreditTransaction
|
||||
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.data.block import get_block
|
||||
from backend.data.credit import BetaUserCredit
|
||||
from backend.data.credit import BetaUserCredit, UsageTransactionMetadata
|
||||
from backend.data.execution import NodeExecutionEntry
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.executor.utils import UsageTransactionMetadata, block_usage_cost
|
||||
from backend.executor.utils import block_usage_cost
|
||||
from backend.integrations.credentials_store import openai_credentials
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from backend.data import db
|
||||
from backend.executor import Scheduler
|
||||
from backend.executor.scheduler import SchedulerClient
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.service import get_service_client
|
||||
@@ -17,11 +17,11 @@ async def test_agent_schedule(server: SpinTestServer):
|
||||
user_id=test_user.id,
|
||||
)
|
||||
|
||||
scheduler = get_service_client(Scheduler)
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
|
||||
scheduler = get_service_client(SchedulerClient)
|
||||
schedules = await scheduler.get_execution_schedules(test_graph.id, test_user.id)
|
||||
assert len(schedules) == 0
|
||||
|
||||
schedule = scheduler.add_execution_schedule(
|
||||
schedule = await scheduler.add_execution_schedule(
|
||||
graph_id=test_graph.id,
|
||||
user_id=test_user.id,
|
||||
graph_version=1,
|
||||
@@ -30,10 +30,12 @@ async def test_agent_schedule(server: SpinTestServer):
|
||||
)
|
||||
assert schedule
|
||||
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
|
||||
schedules = await scheduler.get_execution_schedules(test_graph.id, test_user.id)
|
||||
assert len(schedules) == 1
|
||||
assert schedules[0].cron == "0 0 * * *"
|
||||
|
||||
scheduler.delete_schedule(schedule.id, user_id=test_user.id)
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id, user_id=test_user.id)
|
||||
await scheduler.delete_schedule(schedule.id, user_id=test_user.id)
|
||||
schedules = await scheduler.get_execution_schedules(
|
||||
test_graph.id, user_id=test_user.id
|
||||
)
|
||||
assert len(schedules) == 0
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
import pytest
|
||||
|
||||
from backend.util.service import AppService, expose, get_service_client
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
endpoint_to_async,
|
||||
expose,
|
||||
get_service_client,
|
||||
)
|
||||
|
||||
TEST_SERVICE_PORT = 8765
|
||||
|
||||
@@ -32,10 +38,25 @@ class ServiceTest(AppService):
|
||||
return self.run_and_wait(add_async(a, b))
|
||||
|
||||
|
||||
class ServiceTestClient(AppServiceClient):
|
||||
@classmethod
|
||||
def get_service_type(cls):
|
||||
return ServiceTest
|
||||
|
||||
add = ServiceTest.add
|
||||
subtract = ServiceTest.subtract
|
||||
fun_with_async = ServiceTest.fun_with_async
|
||||
|
||||
add_async = endpoint_to_async(ServiceTest.add)
|
||||
subtract_async = endpoint_to_async(ServiceTest.subtract)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_service_creation(server):
|
||||
with ServiceTest():
|
||||
client = get_service_client(ServiceTest)
|
||||
client = get_service_client(ServiceTestClient)
|
||||
assert client.add(5, 3) == 8
|
||||
assert client.subtract(10, 4) == 6
|
||||
assert client.fun_with_async(5, 3) == 8
|
||||
assert await client.add_async(5, 3) == 8
|
||||
assert await client.subtract_async(10, 4) == 6
|
||||
|
||||
Reference in New Issue
Block a user