mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
refactor(backend): Un-share resource initializations from AppService + Remove Pyro (#9750)
This is a prerequisite infra change for https://github.com/Significant-Gravitas/AutoGPT/issues/9714. We will need a service where we can maintain our own client (db, redis, rabbitmq, be it async/sync) and configure our own cadence of initialization and cleanup. While refactoring the service.py, an option to use Pyro as an RPC protocol is also removed. ### Changes 🏗️ * Decouple resource initialization and cleanup from the parent AppService logic. * Removed Pyro. ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [x] CI
This commit is contained in:
@@ -1,3 +1,6 @@
|
||||
import logging
|
||||
|
||||
from backend.data import db, redis
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
GraphExecution,
|
||||
@@ -44,6 +47,7 @@ from backend.util.settings import Config
|
||||
|
||||
config = Config()
|
||||
_user_credit_model = get_user_credit_model()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _spend_credits(
|
||||
@@ -55,10 +59,22 @@ async def _spend_credits(
|
||||
class DatabaseManager(AppService):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.use_db = True
|
||||
self.use_redis = True
|
||||
self.execution_event_bus = RedisExecutionEventBus()
|
||||
|
||||
def run_service(self) -> None:
|
||||
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
|
||||
self.run_and_wait(db.connect())
|
||||
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
|
||||
redis.connect()
|
||||
super().run_service()
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
|
||||
redis.disconnect()
|
||||
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
|
||||
self.run_and_wait(db.disconnect())
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return config.database_api_port
|
||||
|
||||
@@ -930,8 +930,6 @@ class Executor:
|
||||
class ExecutionManager(AppService):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.use_redis = True
|
||||
self.use_supabase = True
|
||||
self.pool_size = settings.config.num_graph_workers
|
||||
self.queue = ExecutionQueue[GraphExecutionEntry]()
|
||||
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
|
||||
@@ -944,14 +942,17 @@ class ExecutionManager(AppService):
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
self.credentials_store = IntegrationCredentialsStore()
|
||||
|
||||
logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
|
||||
self.executor = ProcessPoolExecutor(
|
||||
max_workers=self.pool_size,
|
||||
initializer=Executor.on_graph_executor_start,
|
||||
)
|
||||
|
||||
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
|
||||
redis.connect()
|
||||
|
||||
sync_manager = multiprocessing.Manager()
|
||||
logger.info(
|
||||
f"[{self.service_name}] Started with max-{self.pool_size} graph workers"
|
||||
)
|
||||
while True:
|
||||
graph_exec_data = self.queue.get()
|
||||
graph_exec_id = graph_exec_data.graph_exec_id
|
||||
@@ -968,10 +969,13 @@ class ExecutionManager(AppService):
|
||||
)
|
||||
|
||||
def cleanup(self):
|
||||
logger.info(f"[{__class__.__name__}] ⏳ Shutting down graph executor pool...")
|
||||
super().cleanup()
|
||||
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down graph executor pool...")
|
||||
self.executor.shutdown(cancel_futures=True)
|
||||
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
|
||||
redis.disconnect()
|
||||
|
||||
@property
|
||||
def db_client(self) -> "DatabaseManager":
|
||||
|
||||
@@ -206,6 +206,12 @@ class Scheduler(AppService):
|
||||
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
||||
self.scheduler.start()
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down scheduler...")
|
||||
if self.scheduler:
|
||||
self.scheduler.shutdown(wait=False)
|
||||
|
||||
@expose
|
||||
def add_execution_schedule(
|
||||
self,
|
||||
|
||||
@@ -9,6 +9,7 @@ from autogpt_libs.utils.cache import thread_cached
|
||||
from prisma.enums import NotificationType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data import rabbitmq
|
||||
from backend.data.notifications import (
|
||||
BaseSummaryData,
|
||||
BaseSummaryParams,
|
||||
@@ -128,6 +129,20 @@ class NotificationManager(AppService):
|
||||
self.running = True
|
||||
self.email_sender = EmailSender()
|
||||
|
||||
@property
|
||||
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
|
||||
"""Access the RabbitMQ service. Will raise if not configured."""
|
||||
if not self.rabbitmq_service:
|
||||
raise RuntimeError("RabbitMQ not configured for this service")
|
||||
return self.rabbitmq_service
|
||||
|
||||
@property
|
||||
def rabbit_config(self) -> rabbitmq.RabbitMQConfig:
|
||||
"""Access the RabbitMQ config. Will raise if not configured."""
|
||||
if not self.rabbitmq_config:
|
||||
raise RuntimeError("RabbitMQ not configured for this service")
|
||||
return self.rabbitmq_config
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return settings.config.notification_service_port
|
||||
@@ -688,10 +703,14 @@ class NotificationManager(AppService):
|
||||
)
|
||||
|
||||
def run_service(self):
|
||||
logger.info(f"[{self.service_name}] ⏳ Configuring RabbitMQ...")
|
||||
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
|
||||
self.run_and_wait(self.rabbitmq_service.connect())
|
||||
|
||||
logger.info(f"[{self.service_name}] Started notification service")
|
||||
|
||||
# Set up scheduler for batch processing of all notification types
|
||||
# this can be changed later to spawn differnt cleanups on different schedules
|
||||
# this can be changed later to spawn different cleanups on different schedules
|
||||
try:
|
||||
get_scheduler().add_batched_notification_schedule(
|
||||
notification_types=list(NotificationType),
|
||||
@@ -753,3 +772,5 @@ class NotificationManager(AppService):
|
||||
"""Cleanup service resources"""
|
||||
self.running = False
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
|
||||
self.run_and_wait(self.rabbitmq_service.disconnect())
|
||||
|
||||
@@ -145,6 +145,10 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
log_config=generate_uvicorn_config(),
|
||||
)
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down Agent Server...")
|
||||
|
||||
@staticmethod
|
||||
async def test_execute_graph(
|
||||
graph_id: str,
|
||||
|
||||
@@ -10,7 +10,6 @@ from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from backend.data import redis
|
||||
from backend.data.execution import AsyncRedisExecutionEventBus
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
@@ -56,15 +55,12 @@ def get_db_client():
|
||||
|
||||
async def event_broadcaster(manager: ConnectionManager):
|
||||
try:
|
||||
redis.connect()
|
||||
event_queue = AsyncRedisExecutionEventBus()
|
||||
async for event in event_queue.listen("*"):
|
||||
await manager.send_execution_update(event)
|
||||
except Exception as e:
|
||||
logger.exception(f"Event broadcaster error: {e}")
|
||||
raise
|
||||
finally:
|
||||
redis.disconnect()
|
||||
|
||||
|
||||
async def authenticate_websocket(websocket: WebSocket) -> str:
|
||||
@@ -294,3 +290,7 @@ class WebsocketServer(AppProcess):
|
||||
port=Config().websocket_server_port,
|
||||
log_config=generate_uvicorn_config(),
|
||||
)
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down WebSocket Server...")
|
||||
|
||||
@@ -48,6 +48,7 @@ class AppProcess(ABC):
|
||||
def service_name(cls) -> str:
|
||||
return cls.__name__
|
||||
|
||||
@abstractmethod
|
||||
def cleanup(self):
|
||||
"""
|
||||
Implement this method on a subclass to do post-execution cleanup,
|
||||
|
||||
@@ -1,52 +1,33 @@
|
||||
import asyncio
|
||||
import builtins
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from types import NoneType, UnionType
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Concatenate,
|
||||
Coroutine,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
ParamSpec,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
import httpx
|
||||
import Pyro5.api
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request, responses
|
||||
from pydantic import BaseModel, TypeAdapter, create_model
|
||||
from Pyro5 import api as pyro
|
||||
from Pyro5 import config as pyro_config
|
||||
|
||||
from backend.data import db, rabbitmq, redis
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.json import to_dict
|
||||
from backend.util.process import AppProcess, get_service_name
|
||||
from backend.util.retry import conn_retry
|
||||
from backend.util.settings import Config, Secrets
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T")
|
||||
@@ -57,21 +38,18 @@ api_host = config.pyro_host
|
||||
api_comm_retry = config.pyro_client_comm_retry
|
||||
api_comm_timeout = config.pyro_client_comm_timeout
|
||||
api_call_timeout = config.rpc_client_call_timeout
|
||||
pyro_config.MAX_RETRIES = api_comm_retry # type: ignore
|
||||
pyro_config.COMMTIMEOUT = api_comm_timeout # type: ignore
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def fastapi_expose(func: C) -> C:
|
||||
def expose(func: C) -> C:
|
||||
func = getattr(func, "__func__", func)
|
||||
setattr(func, "__exposed__", True)
|
||||
return func
|
||||
|
||||
|
||||
def fastapi_exposed_run_and_wait(
|
||||
def exposed_run_and_wait(
|
||||
f: Callable[P, Coroutine[None, None, R]]
|
||||
) -> Callable[Concatenate[object, P], R]:
|
||||
# TODO:
|
||||
@@ -81,107 +59,11 @@ def fastapi_exposed_run_and_wait(
|
||||
return expose(f) # type: ignore
|
||||
|
||||
|
||||
# ----- Begin Pyro Expose Block ---- #
|
||||
def pyro_expose(func: C) -> C:
|
||||
"""
|
||||
Decorator to mark a method or class to be exposed for remote calls.
|
||||
|
||||
## ⚠️ Gotcha
|
||||
Aside from "simple" types, only Pydantic models are passed unscathed *if annotated*.
|
||||
Any other passed or returned class objects are converted to dictionaries by Pyro.
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
msg = f"Error in {func.__name__}: {e}"
|
||||
if isinstance(e, ValueError):
|
||||
logger.warning(msg)
|
||||
else:
|
||||
logger.exception(msg)
|
||||
raise
|
||||
|
||||
register_pydantic_serializers(func)
|
||||
|
||||
return pyro.expose(wrapper) # type: ignore
|
||||
|
||||
|
||||
def register_pydantic_serializers(func: Callable):
|
||||
"""Register custom serializers and deserializers for annotated Pydantic models"""
|
||||
for name, annotation in func.__annotations__.items():
|
||||
try:
|
||||
pydantic_types = _pydantic_models_from_type_annotation(annotation)
|
||||
except Exception as e:
|
||||
raise TypeError(f"Error while exposing {func.__name__}: {e}")
|
||||
|
||||
for model in pydantic_types:
|
||||
logger.debug(
|
||||
f"Registering Pyro (de)serializers for {func.__name__} annotation "
|
||||
f"'{name}': {model.__qualname__}"
|
||||
)
|
||||
pyro.register_class_to_dict(model, _make_custom_serializer(model))
|
||||
pyro.register_dict_to_class(
|
||||
model.__qualname__, _make_custom_deserializer(model)
|
||||
)
|
||||
|
||||
|
||||
def _make_custom_serializer(model: Type[BaseModel]):
|
||||
def custom_class_to_dict(obj):
|
||||
data = {
|
||||
"__class__": obj.__class__.__qualname__,
|
||||
**obj.model_dump(),
|
||||
}
|
||||
logger.debug(f"Serializing {obj.__class__.__qualname__} with data: {data}")
|
||||
return data
|
||||
|
||||
return custom_class_to_dict
|
||||
|
||||
|
||||
def _make_custom_deserializer(model: Type[BaseModel]):
|
||||
def custom_dict_to_class(qualname, data: dict):
|
||||
logger.debug(f"Deserializing {model.__qualname__} from data: {data}")
|
||||
return model(**data)
|
||||
|
||||
return custom_dict_to_class
|
||||
|
||||
|
||||
def pyro_exposed_run_and_wait(
|
||||
f: Callable[P, Coroutine[None, None, R]]
|
||||
) -> Callable[Concatenate[object, P], R]:
|
||||
@expose
|
||||
@wraps(f)
|
||||
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
coroutine = f(*args, **kwargs)
|
||||
res = self.run_and_wait(coroutine)
|
||||
return res
|
||||
|
||||
# Register serializers for annotations on bare function
|
||||
register_pydantic_serializers(f)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
if config.use_http_based_rpc:
|
||||
expose = fastapi_expose
|
||||
exposed_run_and_wait = fastapi_exposed_run_and_wait
|
||||
else:
|
||||
expose = pyro_expose
|
||||
exposed_run_and_wait = pyro_exposed_run_and_wait
|
||||
|
||||
# ----- End Pyro Expose Block ---- #
|
||||
|
||||
|
||||
# --------------------------------------------------
|
||||
# AppService for IPC service based on HTTP request through FastAPI
|
||||
# --------------------------------------------------
|
||||
class BaseAppService(AppProcess, ABC):
|
||||
shared_event_loop: asyncio.AbstractEventLoop
|
||||
use_db: bool = False
|
||||
use_redis: bool = False
|
||||
rabbitmq_config: Optional[rabbitmq.RabbitMQConfig] = None
|
||||
rabbitmq_service: Optional[rabbitmq.AsyncRabbitMQ] = None
|
||||
use_supabase: bool = False
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
@@ -202,20 +84,6 @@ class BaseAppService(AppProcess, ABC):
|
||||
|
||||
return target_host
|
||||
|
||||
@property
|
||||
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
|
||||
"""Access the RabbitMQ service. Will raise if not configured."""
|
||||
if not self.rabbitmq_service:
|
||||
raise RuntimeError("RabbitMQ not configured for this service")
|
||||
return self.rabbitmq_service
|
||||
|
||||
@property
|
||||
def rabbit_config(self) -> rabbitmq.RabbitMQConfig:
|
||||
"""Access the RabbitMQ config. Will raise if not configured."""
|
||||
if not self.rabbitmq_config:
|
||||
raise RuntimeError("RabbitMQ not configured for this service")
|
||||
return self.rabbitmq_config
|
||||
|
||||
def run_service(self) -> None:
|
||||
while True:
|
||||
time.sleep(10)
|
||||
@@ -225,31 +93,6 @@ class BaseAppService(AppProcess, ABC):
|
||||
|
||||
def run(self):
|
||||
self.shared_event_loop = asyncio.get_event_loop()
|
||||
if self.use_db:
|
||||
self.shared_event_loop.run_until_complete(db.connect())
|
||||
if self.use_redis:
|
||||
redis.connect()
|
||||
if self.rabbitmq_config:
|
||||
logger.info(f"[{self.__class__.__name__}] ⏳ Configuring RabbitMQ...")
|
||||
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
|
||||
self.shared_event_loop.run_until_complete(self.rabbitmq_service.connect())
|
||||
if self.use_supabase:
|
||||
from supabase import create_client
|
||||
|
||||
secrets = Secrets()
|
||||
self.supabase = create_client(
|
||||
secrets.supabase_url, secrets.supabase_service_role_key
|
||||
)
|
||||
|
||||
def cleanup(self):
|
||||
if self.use_db:
|
||||
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting DB...")
|
||||
self.run_and_wait(db.disconnect())
|
||||
if self.use_redis:
|
||||
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting Redis...")
|
||||
redis.disconnect()
|
||||
if self.rabbitmq_config:
|
||||
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting RabbitMQ...")
|
||||
|
||||
|
||||
class RemoteCallError(BaseModel):
|
||||
@@ -268,7 +111,7 @@ EXCEPTION_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
class FastApiAppService(BaseAppService, ABC):
|
||||
class AppService(BaseAppService, ABC):
|
||||
fastapi_app: FastAPI
|
||||
|
||||
@staticmethod
|
||||
@@ -383,62 +226,13 @@ class FastApiAppService(BaseAppService, ABC):
|
||||
self.run_service()
|
||||
|
||||
|
||||
# ----- Begin Pyro AppService Block ---- #
|
||||
|
||||
|
||||
class PyroAppService(BaseAppService, ABC):
|
||||
|
||||
@conn_retry("Pyro", "Starting Pyro Service")
|
||||
def __start_pyro(self):
|
||||
maximum_connection_thread_count = max(
|
||||
Pyro5.config.THREADPOOL_SIZE,
|
||||
config.num_node_workers * config.num_graph_workers,
|
||||
)
|
||||
|
||||
Pyro5.config.THREADPOOL_SIZE = maximum_connection_thread_count # type: ignore
|
||||
daemon = Pyro5.api.Daemon(host=api_host, port=self.get_port())
|
||||
self.uri = daemon.register(self, objectId=self.service_name)
|
||||
logger.info(f"[{self.service_name}] Connected to Pyro; URI = {self.uri}")
|
||||
daemon.requestLoop()
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
|
||||
# Initialize the async loop.
|
||||
async_thread = threading.Thread(target=self.shared_event_loop.run_forever)
|
||||
async_thread.daemon = True
|
||||
async_thread.start()
|
||||
|
||||
# Initialize pyro service
|
||||
daemon_thread = threading.Thread(target=self.__start_pyro)
|
||||
daemon_thread.daemon = True
|
||||
daemon_thread.start()
|
||||
|
||||
# Run the main service loop (blocking).
|
||||
self.run_service()
|
||||
|
||||
|
||||
if config.use_http_based_rpc:
|
||||
|
||||
class AppService(FastApiAppService, ABC): # type: ignore #AppService defined twice
|
||||
pass
|
||||
|
||||
else:
|
||||
|
||||
class AppService(PyroAppService, ABC):
|
||||
pass
|
||||
|
||||
|
||||
# ----- End Pyro AppService Block ---- #
|
||||
|
||||
|
||||
# --------------------------------------------------
|
||||
# HTTP Client utilities for dynamic service client abstraction
|
||||
# --------------------------------------------------
|
||||
AS = TypeVar("AS", bound=AppService)
|
||||
|
||||
|
||||
def fastapi_close_service_client(client: Any) -> None:
|
||||
def close_service_client(client: Any) -> None:
|
||||
if hasattr(client, "close"):
|
||||
client.close()
|
||||
else:
|
||||
@@ -446,7 +240,7 @@ def fastapi_close_service_client(client: Any) -> None:
|
||||
|
||||
|
||||
@conn_retry("FastAPI client", "Creating service client", max_retry=api_comm_retry)
|
||||
def fastapi_get_service_client(
|
||||
def get_service_client(
|
||||
service_type: Type[AS],
|
||||
call_timeout: int | None = api_call_timeout,
|
||||
) -> AS:
|
||||
@@ -506,93 +300,3 @@ def fastapi_get_service_client(
|
||||
client.health_check()
|
||||
|
||||
return cast(AS, client)
|
||||
|
||||
|
||||
# ----- Begin Pyro Client Block ---- #
|
||||
class PyroClient:
|
||||
proxy: Pyro5.api.Proxy
|
||||
|
||||
|
||||
def pyro_close_service_client(client: BaseAppService) -> None:
|
||||
if isinstance(client, PyroClient):
|
||||
client.proxy._pyroRelease()
|
||||
else:
|
||||
raise RuntimeError(f"Client {client.__class__} is not a Pyro client.")
|
||||
|
||||
|
||||
def pyro_get_service_client(service_type: Type[AS]) -> AS:
|
||||
service_name = service_type.service_name
|
||||
|
||||
class DynamicClient(PyroClient):
|
||||
@conn_retry("Pyro", f"Connecting to [{service_name}]")
|
||||
def __init__(self):
|
||||
uri = f"PYRO:{service_type.service_name}@{service_type.get_host()}:{service_type.get_port()}"
|
||||
logger.debug(f"Connecting to service [{service_name}]. URI = {uri}")
|
||||
self.proxy = Pyro5.api.Proxy(uri)
|
||||
# Attempt to bind to ensure the connection is established
|
||||
self.proxy._pyroBind()
|
||||
logger.debug(f"Successfully connected to service [{service_name}]")
|
||||
|
||||
def __getattr__(self, name: str) -> Callable[..., Any]:
|
||||
res = getattr(self.proxy, name)
|
||||
return res
|
||||
|
||||
return cast(AS, DynamicClient())
|
||||
|
||||
|
||||
builtin_types = [*vars(builtins).values(), NoneType, Enum]
|
||||
|
||||
|
||||
def _pydantic_models_from_type_annotation(annotation) -> Iterator[type[BaseModel]]:
|
||||
# Peel Annotated parameters
|
||||
if (origin := get_origin(annotation)) and origin is Annotated:
|
||||
annotation = get_args(annotation)[0]
|
||||
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
|
||||
if origin in (
|
||||
Union,
|
||||
UnionType,
|
||||
list,
|
||||
List,
|
||||
tuple,
|
||||
Tuple,
|
||||
set,
|
||||
Set,
|
||||
frozenset,
|
||||
FrozenSet,
|
||||
):
|
||||
for arg in args:
|
||||
yield from _pydantic_models_from_type_annotation(arg)
|
||||
elif origin in (dict, Dict):
|
||||
key_type, value_type = args
|
||||
yield from _pydantic_models_from_type_annotation(key_type)
|
||||
yield from _pydantic_models_from_type_annotation(value_type)
|
||||
elif origin in (Awaitable, Coroutine):
|
||||
# For coroutines and awaitables, check the return type
|
||||
return_type = args[-1]
|
||||
yield from _pydantic_models_from_type_annotation(return_type)
|
||||
else:
|
||||
annotype = annotation if origin is None else origin
|
||||
|
||||
# Exclude generic types and aliases
|
||||
if (
|
||||
annotype is not None
|
||||
and not hasattr(typing, getattr(annotype, "__name__", ""))
|
||||
and isinstance(annotype, type)
|
||||
):
|
||||
if issubclass(annotype, BaseModel):
|
||||
yield annotype
|
||||
elif annotype not in builtin_types and not issubclass(annotype, Enum):
|
||||
raise TypeError(f"Unsupported type encountered: {annotype}")
|
||||
|
||||
|
||||
if config.use_http_based_rpc:
|
||||
close_service_client = fastapi_close_service_client
|
||||
get_service_client = fastapi_get_service_client
|
||||
else:
|
||||
close_service_client = pyro_close_service_client
|
||||
get_service_client = pyro_get_service_client
|
||||
|
||||
# ----- End Pyro Client Block ---- #
|
||||
|
||||
@@ -65,10 +65,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
le=1000,
|
||||
description="Maximum number of workers to use for node execution within a single graph.",
|
||||
)
|
||||
use_http_based_rpc: bool = Field(
|
||||
default=True,
|
||||
description="Whether to use HTTP-based RPC for communication between services.",
|
||||
)
|
||||
pyro_host: str = Field(
|
||||
default="localhost",
|
||||
description="The default hostname of the Pyro server.",
|
||||
|
||||
29
autogpt_platform/backend/poetry.lock
generated
29
autogpt_platform/backend/poetry.lock
generated
@@ -4098,21 +4098,6 @@ all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"]
|
||||
dev = ["twine (>=3.4.1)"]
|
||||
nodejs = ["nodejs-wheel-binaries"]
|
||||
|
||||
[[package]]
|
||||
name = "pyro5"
|
||||
version = "5.15"
|
||||
description = "Remote object communication library, fifth major version"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "Pyro5-5.15-py3-none-any.whl", hash = "sha256:4d85428ed75985e63f159d2486ad5680743ea76f766340fd30b65dd20f83d471"},
|
||||
{file = "Pyro5-5.15.tar.gz", hash = "sha256:82c3dfc9860b49f897b28ff24fe6716c841672c600af8fe40d0e3a7fac9a3f5e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
serpent = ">=1.41"
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.3.5"
|
||||
@@ -4963,18 +4948,6 @@ statsig = ["statsig (>=0.55.3)"]
|
||||
tornado = ["tornado (>=6)"]
|
||||
unleash = ["UnleashClient (>=6.0.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "serpent"
|
||||
version = "1.41"
|
||||
description = "Serialization based on ast.literal_eval"
|
||||
optional = false
|
||||
python-versions = ">=3.2"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "serpent-1.41-py3-none-any.whl", hash = "sha256:5fd776b3420441985bc10679564c2c9b4a19f77bea59f018e473441d98ae5dd7"},
|
||||
{file = "serpent-1.41.tar.gz", hash = "sha256:0407035fe3c6644387d48cff1467d5aa9feff814d07372b78677ed0ee3ed7095"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "setuptools"
|
||||
version = "75.8.0"
|
||||
@@ -6337,4 +6310,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "648b5a17a37ea000d19d74a65a4b9aaa8f5c6bf73b05f13fc202648cb2b287f0"
|
||||
content-hash = "781f77ec77cfce78b34fb57063dcc81df8e9c5a4be9a644033a0c197e0063730"
|
||||
|
||||
@@ -45,7 +45,6 @@ psutil = "^7.0.0"
|
||||
psycopg2-binary = "^2.9.10"
|
||||
pydantic = { extras = ["email"], version = "^2.11.1" }
|
||||
pydantic-settings = "^2.8.1"
|
||||
pyro5 = "^5.15"
|
||||
pytest = "^8.3.5"
|
||||
pytest-asyncio = "^0.26.0"
|
||||
python-dotenv = "^1.1.0"
|
||||
|
||||
@@ -9,6 +9,9 @@ class ServiceTest(AppService):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return TEST_SERVICE_PORT
|
||||
|
||||
Reference in New Issue
Block a user