mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(backend): Make Redis connection Sync + Use Redis as Distributed Lock (#8197)
This commit is contained in:
16
.github/workflows/platform-backend-ci.yml
vendored
16
.github/workflows/platform-backend-ci.yml
vendored
@@ -32,6 +32,14 @@ jobs:
|
||||
python-version: ["3.10"]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: bitnami/redis:6.2
|
||||
env:
|
||||
REDIS_PASSWORD: testpassword
|
||||
ports:
|
||||
- 6379:6379
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -96,9 +104,9 @@ jobs:
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -vv -o log_cli=true -o log_cli_level=DEBUG test
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
|
||||
else
|
||||
poetry run pytest -vv test
|
||||
poetry run pytest -s -vv test
|
||||
fi
|
||||
if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
env:
|
||||
@@ -107,6 +115,10 @@ jobs:
|
||||
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
|
||||
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
|
||||
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
REDIS_HOST: 'localhost'
|
||||
REDIS_PORT: '6379'
|
||||
REDIS_PASSWORD: 'testpassword'
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -8,40 +7,30 @@ from dotenv import load_dotenv
|
||||
from prisma import Prisma
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.util.retry import conn_retry
|
||||
|
||||
load_dotenv()
|
||||
|
||||
PRISMA_SCHEMA = os.getenv("PRISMA_SCHEMA", "schema.prisma")
|
||||
os.environ["PRISMA_SCHEMA_PATH"] = PRISMA_SCHEMA
|
||||
|
||||
prisma, conn_id = Prisma(auto_register=True), ""
|
||||
prisma = Prisma(auto_register=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def connect(call_count=0):
|
||||
global conn_id
|
||||
if not conn_id:
|
||||
conn_id = str(uuid4())
|
||||
|
||||
try:
|
||||
logger.info(f"[Prisma-{conn_id}] Acquiring connection..")
|
||||
if not prisma.is_connected():
|
||||
await prisma.connect()
|
||||
logger.info(f"[Prisma-{conn_id}] Connection acquired!")
|
||||
except Exception as e:
|
||||
if call_count <= 5:
|
||||
logger.info(f"[Prisma-{conn_id}] Connection failed: {e}. Retrying now..")
|
||||
await asyncio.sleep(2**call_count)
|
||||
await connect(call_count + 1)
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
async def disconnect():
|
||||
@conn_retry("Prisma", "Acquiring connection")
|
||||
async def connect():
|
||||
if prisma.is_connected():
|
||||
logger.info(f"[Prisma-{conn_id}] Releasing connection.")
|
||||
await prisma.disconnect()
|
||||
logger.info(f"[Prisma-{conn_id}] Connection released.")
|
||||
return
|
||||
await prisma.connect()
|
||||
|
||||
|
||||
@conn_retry("Prisma", "Releasing connection")
|
||||
async def disconnect():
|
||||
if not prisma.is_connected():
|
||||
return
|
||||
await prisma.disconnect()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from backend.data import redis
|
||||
from backend.data.execution import ExecutionResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -18,60 +16,46 @@ class DateTimeEncoder(json.JSONEncoder):
|
||||
return super().default(o)
|
||||
|
||||
|
||||
class AsyncEventQueue(ABC):
|
||||
class AbstractEventQueue(ABC):
|
||||
@abstractmethod
|
||||
async def connect(self):
|
||||
def connect(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close(self):
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def put(self, execution_result: ExecutionResult):
|
||||
def put(self, execution_result: ExecutionResult):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get(self) -> ExecutionResult | None:
|
||||
def get(self) -> ExecutionResult | None:
|
||||
pass
|
||||
|
||||
|
||||
class AsyncRedisEventQueue(AsyncEventQueue):
|
||||
class RedisEventQueue(AbstractEventQueue):
|
||||
def __init__(self):
|
||||
self.host = os.getenv("REDIS_HOST", "localhost")
|
||||
self.port = int(os.getenv("REDIS_PORT", "6379"))
|
||||
self.password = os.getenv("REDIS_PASSWORD", "password")
|
||||
self.queue_name = os.getenv("REDIS_QUEUE", "execution_events")
|
||||
self.connection = None
|
||||
self.queue_name = redis.QUEUE_NAME
|
||||
|
||||
async def connect(self):
|
||||
if not self.connection:
|
||||
self.connection = Redis(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
password=self.password,
|
||||
decode_responses=True,
|
||||
)
|
||||
await self.connection.ping()
|
||||
logger.info(f"Connected to Redis on {self.host}:{self.port}")
|
||||
def connect(self):
|
||||
self.connection = redis.connect()
|
||||
|
||||
async def put(self, execution_result: ExecutionResult):
|
||||
def put(self, execution_result: ExecutionResult):
|
||||
if self.connection:
|
||||
message = json.dumps(execution_result.model_dump(), cls=DateTimeEncoder)
|
||||
logger.info(f"Putting execution result to Redis {message}")
|
||||
await self.connection.lpush(self.queue_name, message) # type: ignore
|
||||
self.connection.lpush(self.queue_name, message)
|
||||
|
||||
async def get(self) -> ExecutionResult | None:
|
||||
def get(self) -> ExecutionResult | None:
|
||||
if self.connection:
|
||||
message = await self.connection.rpop(self.queue_name) # type: ignore
|
||||
message = self.connection.rpop(self.queue_name)
|
||||
if message is not None and isinstance(message, (str, bytes, bytearray)):
|
||||
data = json.loads(message)
|
||||
logger.info(f"Getting execution result from Redis {data}")
|
||||
return ExecutionResult(**data)
|
||||
return None
|
||||
|
||||
async def close(self):
|
||||
if self.connection:
|
||||
await self.connection.close()
|
||||
self.connection = None
|
||||
logger.info("Closed connection to Redis")
|
||||
def close(self):
|
||||
redis.disconnect()
|
||||
|
||||
48
autogpt_platform/backend/backend/data/redis.py
Normal file
48
autogpt_platform/backend/backend/data/redis.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from redis import Redis
|
||||
|
||||
from backend.util.retry import conn_retry
|
||||
|
||||
load_dotenv()
|
||||
|
||||
HOST = os.getenv("REDIS_HOST", "localhost")
|
||||
PORT = int(os.getenv("REDIS_PORT", "6379"))
|
||||
PASSWORD = os.getenv("REDIS_PASSWORD", "password")
|
||||
QUEUE_NAME = os.getenv("REDIS_QUEUE", "execution_events")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
connection: Redis | None = None
|
||||
|
||||
|
||||
@conn_retry("Redis", "Acquiring connection")
|
||||
def connect() -> Redis:
|
||||
global connection
|
||||
if connection:
|
||||
return connection
|
||||
|
||||
c = Redis(
|
||||
host=HOST,
|
||||
port=PORT,
|
||||
password=PASSWORD,
|
||||
decode_responses=True,
|
||||
)
|
||||
c.ping()
|
||||
connection = c
|
||||
return connection
|
||||
|
||||
|
||||
@conn_retry("Redis", "Releasing connection")
|
||||
def disconnect():
|
||||
global connection
|
||||
if connection:
|
||||
connection.close()
|
||||
connection = None
|
||||
|
||||
|
||||
def get_redis() -> Redis:
|
||||
if not connection:
|
||||
raise RuntimeError("Redis connection is not established")
|
||||
return connection
|
||||
@@ -17,7 +17,7 @@ from pydantic import BaseModel
|
||||
if TYPE_CHECKING:
|
||||
from backend.server.rest_api import AgentServer
|
||||
|
||||
from backend.data import db
|
||||
from backend.data import db, redis
|
||||
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
@@ -216,12 +216,13 @@ def execute_node(
|
||||
|
||||
|
||||
@contextmanager
|
||||
def synchronized(api_client: "AgentServer", key: Any):
|
||||
api_client.acquire_lock(key)
|
||||
def synchronized(key: str, timeout: int = 60):
|
||||
lock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)
|
||||
try:
|
||||
lock.acquire()
|
||||
yield
|
||||
finally:
|
||||
api_client.release_lock(key)
|
||||
lock.release()
|
||||
|
||||
|
||||
def _enqueue_next_nodes(
|
||||
@@ -268,7 +269,7 @@ def _enqueue_next_nodes(
|
||||
# Multiple node can register the same next node, we need this to be atomic
|
||||
# To avoid same execution to be enqueued multiple times,
|
||||
# Or the same input to be consumed multiple times.
|
||||
with synchronized(api_client, ("upsert_input", next_node_id, graph_exec_id)):
|
||||
with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
|
||||
# Add output data to the earliest incomplete execution, or create a new one.
|
||||
next_node_exec_id, next_node_input = wait(
|
||||
upsert_execution_input(
|
||||
@@ -437,6 +438,7 @@ class Executor:
|
||||
cls.loop = asyncio.new_event_loop()
|
||||
cls.pid = os.getpid()
|
||||
|
||||
redis.connect()
|
||||
cls.loop.run_until_complete(db.connect())
|
||||
cls.agent_server_client = get_agent_server_client()
|
||||
|
||||
@@ -454,6 +456,8 @@ class Executor:
|
||||
|
||||
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB...")
|
||||
cls.loop.run_until_complete(db.disconnect())
|
||||
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
|
||||
redis.disconnect()
|
||||
logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -24,7 +24,6 @@ class ExecutionScheduler(AppService):
|
||||
self.use_db = True
|
||||
self.last_check = datetime.min
|
||||
self.refresh_interval = refresh_interval
|
||||
self.use_redis = False
|
||||
|
||||
@property
|
||||
def execution_manager_client(self) -> ExecutionManager:
|
||||
|
||||
@@ -17,11 +17,10 @@ from backend.data import graph as graph_db
|
||||
from backend.data import user as user_db
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.data.credit import get_block_costs, get_user_credit_model
|
||||
from backend.data.queue import AsyncEventQueue, AsyncRedisEventQueue
|
||||
from backend.data.queue import RedisEventQueue
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.executor import ExecutionManager, ExecutionScheduler
|
||||
from backend.server.model import CreateGraph, SetGraphActiveVersion
|
||||
from backend.util.lock import KeyedMutex
|
||||
from backend.util.service import AppService, expose, get_service_client
|
||||
from backend.util.settings import Config, Settings
|
||||
|
||||
@@ -32,24 +31,23 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentServer(AppService):
|
||||
mutex = KeyedMutex()
|
||||
use_redis = True
|
||||
use_queue = True
|
||||
_test_dependency_overrides = {}
|
||||
_user_credit_model = get_user_credit_model()
|
||||
|
||||
def __init__(self, event_queue: AsyncEventQueue | None = None):
|
||||
def __init__(self):
|
||||
super().__init__(port=Config().agent_server_port)
|
||||
self.event_queue = event_queue or AsyncRedisEventQueue()
|
||||
self.event_queue = RedisEventQueue()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(self, _: FastAPI):
|
||||
await db.connect()
|
||||
self.run_and_wait(self.event_queue.connect())
|
||||
self.event_queue.connect()
|
||||
await block.initialize_blocks()
|
||||
if await user_db.create_default_user(settings.config.enable_auth):
|
||||
await graph_db.import_packaged_templates()
|
||||
yield
|
||||
await self.event_queue.close()
|
||||
self.event_queue.close()
|
||||
await db.disconnect()
|
||||
|
||||
def run_service(self):
|
||||
@@ -616,15 +614,7 @@ class AgentServer(AppService):
|
||||
@expose
|
||||
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
|
||||
execution_result = execution_db.ExecutionResult(**execution_result_dict)
|
||||
self.run_and_wait(self.event_queue.put(execution_result))
|
||||
|
||||
@expose
|
||||
def acquire_lock(self, key: Any):
|
||||
self.mutex.lock(key)
|
||||
|
||||
@expose
|
||||
def release_lock(self, key: Any):
|
||||
self.mutex.unlock(key)
|
||||
self.event_queue.put(execution_result)
|
||||
|
||||
@classmethod
|
||||
def update_configuration(
|
||||
|
||||
@@ -7,7 +7,7 @@ from autogpt_libs.auth import parse_jwt_token
|
||||
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from backend.data.queue import AsyncRedisEventQueue
|
||||
from backend.data.queue import RedisEventQueue
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.model import ExecutionSubscription, Methods, WsMessage
|
||||
@@ -20,15 +20,16 @@ settings = Settings()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await event_queue.connect()
|
||||
event_queue.connect()
|
||||
manager = get_connection_manager()
|
||||
asyncio.create_task(event_broadcaster(manager))
|
||||
fut = asyncio.create_task(event_broadcaster(manager))
|
||||
fut.add_done_callback(lambda _: logger.info("Event broadcaster stopped"))
|
||||
yield
|
||||
await event_queue.close()
|
||||
event_queue.close()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
event_queue = AsyncRedisEventQueue()
|
||||
event_queue = RedisEventQueue()
|
||||
_connection_manager = None
|
||||
|
||||
logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}")
|
||||
@@ -50,9 +51,11 @@ def get_connection_manager():
|
||||
|
||||
async def event_broadcaster(manager: ConnectionManager):
|
||||
while True:
|
||||
event = await event_queue.get()
|
||||
event = event_queue.get()
|
||||
if event is not None:
|
||||
await manager.send_execution_result(event)
|
||||
else:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
async def authenticate_websocket(websocket: WebSocket) -> str:
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
from expiringdict import ExpiringDict
|
||||
|
||||
|
||||
class KeyedMutex:
|
||||
"""
|
||||
This class provides a mutex that can be locked and unlocked by a specific key.
|
||||
It uses an ExpiringDict to automatically clear the mutex after a specified timeout,
|
||||
in case the key is not unlocked for a specified duration, to prevent memory leaks.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.locks: dict[Any, tuple[Lock, int]] = ExpiringDict(
|
||||
max_len=6000, max_age_seconds=60
|
||||
)
|
||||
self.locks_lock = Lock()
|
||||
|
||||
def lock(self, key: Any):
|
||||
with self.locks_lock:
|
||||
lock, request_count = self.locks.get(key, (Lock(), 0))
|
||||
self.locks[key] = (lock, request_count + 1)
|
||||
lock.acquire()
|
||||
|
||||
def unlock(self, key: Any):
|
||||
with self.locks_lock:
|
||||
lock, request_count = self.locks.pop(key)
|
||||
if request_count > 1:
|
||||
self.locks[key] = (lock, request_count - 1)
|
||||
lock.release()
|
||||
@@ -10,6 +10,11 @@ from backend.util.logging import configure_logging
|
||||
from backend.util.metrics import sentry_init
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_SERVICE_NAME = "MainProcess"
|
||||
|
||||
|
||||
def get_service_name():
|
||||
return _SERVICE_NAME
|
||||
|
||||
|
||||
class AppProcess(ABC):
|
||||
@@ -32,6 +37,11 @@ class AppProcess(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def service_name(cls) -> str:
|
||||
return cls.__name__
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
Implement this method on a subclass to do post-execution cleanup,
|
||||
@@ -52,10 +62,14 @@ class AppProcess(ABC):
|
||||
if silent:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
logger.info(f"[{self.__class__.__name__}] Starting...")
|
||||
|
||||
global _SERVICE_NAME
|
||||
_SERVICE_NAME = self.service_name
|
||||
|
||||
logger.info(f"[{self.service_name}] Starting...")
|
||||
self.run()
|
||||
except (KeyboardInterrupt, SystemExit) as e:
|
||||
logger.warning(f"[{self.__class__.__name__}] Terminated: {e}; quitting...")
|
||||
logger.warning(f"[{self.service_name}] Terminated: {e}; quitting...")
|
||||
|
||||
def _self_terminate(self, signum: int, frame):
|
||||
self.cleanup()
|
||||
|
||||
@@ -1,7 +1,48 @@
|
||||
import logging
|
||||
import os
|
||||
from uuid import uuid4
|
||||
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
conn_retry = retry(
|
||||
stop=stop_after_attempt(30),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=30),
|
||||
reraise=True,
|
||||
)
|
||||
from backend.util.process import get_service_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _log_prefix(resource_name: str, conn_id: str):
|
||||
"""
|
||||
Returns a prefix string for logging purposes.
|
||||
This needs to be called on the fly to get the current process ID & service name,
|
||||
not the parent process ID & service name.
|
||||
"""
|
||||
return f"[PID-{os.getpid()}|{get_service_name()}|{resource_name}-{conn_id}]"
|
||||
|
||||
|
||||
def conn_retry(resource_name: str, action_name: str, max_retry: int = 5):
|
||||
conn_id = str(uuid4())
|
||||
|
||||
def before_call(retry_state):
|
||||
prefix = _log_prefix(resource_name, conn_id)
|
||||
logger.info(f"{prefix} {action_name} started...")
|
||||
|
||||
def after_call(retry_state):
|
||||
prefix = _log_prefix(resource_name, conn_id)
|
||||
if retry_state.outcome.failed:
|
||||
# Optionally, you can log something here if needed
|
||||
pass
|
||||
else:
|
||||
logger.info(f"{prefix} {action_name} completed!")
|
||||
|
||||
def on_retry(retry_state):
|
||||
prefix = _log_prefix(resource_name, conn_id)
|
||||
exception = retry_state.outcome.exception()
|
||||
logger.info(f"{prefix} {action_name} failed: {exception}. Retrying now...")
|
||||
|
||||
return retry(
|
||||
stop=stop_after_attempt(max_retry + 1),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=30),
|
||||
before=before_call,
|
||||
after=after_call,
|
||||
before_sleep=on_retry,
|
||||
reraise=True,
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ import Pyro5.api
|
||||
from Pyro5 import api as pyro
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.queue import AsyncEventQueue, AsyncRedisEventQueue
|
||||
from backend.data.queue import AbstractEventQueue, RedisEventQueue
|
||||
from backend.util.process import AppProcess
|
||||
from backend.util.retry import conn_retry
|
||||
from backend.util.settings import Config, Secrets
|
||||
@@ -45,20 +45,15 @@ def expose(func: C) -> C:
|
||||
|
||||
class AppService(AppProcess):
|
||||
shared_event_loop: asyncio.AbstractEventLoop
|
||||
event_queue: AsyncEventQueue = AsyncRedisEventQueue()
|
||||
event_queue: AbstractEventQueue = RedisEventQueue()
|
||||
use_db: bool = False
|
||||
use_redis: bool = False
|
||||
use_queue: bool = False
|
||||
use_supabase: bool = False
|
||||
|
||||
def __init__(self, port):
|
||||
self.port = port
|
||||
self.uri = None
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def service_name(cls) -> str:
|
||||
return cls.__name__
|
||||
|
||||
@abstractmethod
|
||||
def run_service(self):
|
||||
while True:
|
||||
@@ -75,8 +70,8 @@ class AppService(AppProcess):
|
||||
self.shared_event_loop = asyncio.get_event_loop()
|
||||
if self.use_db:
|
||||
self.shared_event_loop.run_until_complete(db.connect())
|
||||
if self.use_redis:
|
||||
self.shared_event_loop.run_until_complete(self.event_queue.connect())
|
||||
if self.use_queue:
|
||||
self.event_queue.connect()
|
||||
if self.use_supabase:
|
||||
from supabase import create_client
|
||||
|
||||
@@ -102,11 +97,11 @@ class AppService(AppProcess):
|
||||
if self.use_db:
|
||||
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting DB...")
|
||||
self.run_and_wait(db.disconnect())
|
||||
if self.use_redis:
|
||||
if self.use_queue:
|
||||
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting Redis...")
|
||||
self.run_and_wait(self.event_queue.close())
|
||||
self.event_queue.close()
|
||||
|
||||
@conn_retry
|
||||
@conn_retry("Pyro", "Starting Pyro Service")
|
||||
def __start_pyro(self):
|
||||
host = Config().pyro_host
|
||||
daemon = Pyro5.api.Daemon(host=host, port=self.port)
|
||||
@@ -125,7 +120,7 @@ def get_service_client(service_type: Type[AS], port: int) -> AS:
|
||||
service_name = service_type.service_name
|
||||
|
||||
class DynamicClient:
|
||||
@conn_retry
|
||||
@conn_retry("Pyro", f"Connecting to [{service_name}]")
|
||||
def __init__(self):
|
||||
host = os.environ.get(f"{service_name.upper()}_HOST", "localhost")
|
||||
uri = f"PYRO:{service_type.service_name}@{host}:{port}"
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.block import Block, initialize_blocks
|
||||
from backend.data.execution import ExecutionResult, ExecutionStatus
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import CREDENTIALS_FIELD_NAME
|
||||
from backend.data.queue import AsyncEventQueue
|
||||
from backend.data.user import create_default_user
|
||||
from backend.executor import ExecutionManager, ExecutionScheduler
|
||||
from backend.server import AgentServer
|
||||
@@ -14,44 +12,10 @@ from backend.server.rest_api import get_user_id
|
||||
log = print
|
||||
|
||||
|
||||
class InMemoryAsyncEventQueue(AsyncEventQueue):
|
||||
def __init__(self):
|
||||
self.queue = asyncio.Queue()
|
||||
self.connected = False
|
||||
self.closed = False
|
||||
|
||||
async def connect(self):
|
||||
if not self.connected:
|
||||
self.connected = True
|
||||
return
|
||||
|
||||
async def close(self):
|
||||
self.closed = True
|
||||
self.connected = False
|
||||
return
|
||||
|
||||
async def put(self, execution_result: ExecutionResult):
|
||||
if not self.connected:
|
||||
raise RuntimeError("Queue is not connected")
|
||||
await self.queue.put(execution_result)
|
||||
|
||||
async def get(self):
|
||||
if self.closed:
|
||||
return None
|
||||
if not self.connected:
|
||||
raise RuntimeError("Queue is not connected")
|
||||
try:
|
||||
item = await asyncio.wait_for(self.queue.get(), timeout=0.1)
|
||||
return item
|
||||
except asyncio.TimeoutError:
|
||||
return None
|
||||
|
||||
|
||||
class SpinTestServer:
|
||||
def __init__(self):
|
||||
self.exec_manager = ExecutionManager()
|
||||
self.in_memory_queue = InMemoryAsyncEventQueue()
|
||||
self.agent_server = AgentServer(event_queue=self.in_memory_queue)
|
||||
self.agent_server = AgentServer()
|
||||
self.scheduler = ExecutionScheduler()
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -6,7 +6,6 @@ from backend.util.service import AppService, expose, get_service_client
|
||||
class TestService(AppService):
|
||||
def __init__(self):
|
||||
super().__init__(port=8005)
|
||||
self.use_redis = False
|
||||
|
||||
def run_service(self):
|
||||
super().run_service()
|
||||
|
||||
Reference in New Issue
Block a user