feat(backend): Make Redis connection Sync + Use Redis as Distributed Lock (#8197)

This commit is contained in:
Zamil Majdy
2024-10-07 12:39:32 +04:00
committed by GitHub
parent fe98abf875
commit daa054c79c
14 changed files with 190 additions and 179 deletions

View File

@@ -32,6 +32,14 @@ jobs:
python-version: ["3.10"]
runs-on: ubuntu-latest
services:
redis:
image: bitnami/redis:6.2
env:
REDIS_PASSWORD: testpassword
ports:
- 6379:6379
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -96,9 +104,9 @@ jobs:
- name: Run pytest with coverage
run: |
if [[ "${{ runner.debug }}" == "1" ]]; then
poetry run pytest -vv -o log_cli=true -o log_cli_level=DEBUG test
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
else
poetry run pytest -vv test
poetry run pytest -s -vv test
fi
if: success() || (failure() && steps.lint.outcome == 'failure')
env:
@@ -107,6 +115,10 @@ jobs:
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
REDIS_HOST: 'localhost'
REDIS_PORT: '6379'
REDIS_PASSWORD: 'testpassword'
env:
CI: true
PLAIN_OUTPUT: True

View File

@@ -1,4 +1,3 @@
import asyncio
import logging
import os
from contextlib import asynccontextmanager
@@ -8,40 +7,30 @@ from dotenv import load_dotenv
from prisma import Prisma
from pydantic import BaseModel, Field, field_validator
from backend.util.retry import conn_retry
load_dotenv()
PRISMA_SCHEMA = os.getenv("PRISMA_SCHEMA", "schema.prisma")
os.environ["PRISMA_SCHEMA_PATH"] = PRISMA_SCHEMA
prisma, conn_id = Prisma(auto_register=True), ""
prisma = Prisma(auto_register=True)
logger = logging.getLogger(__name__)
async def connect(call_count=0):
global conn_id
if not conn_id:
conn_id = str(uuid4())
try:
logger.info(f"[Prisma-{conn_id}] Acquiring connection..")
if not prisma.is_connected():
await prisma.connect()
logger.info(f"[Prisma-{conn_id}] Connection acquired!")
except Exception as e:
if call_count <= 5:
logger.info(f"[Prisma-{conn_id}] Connection failed: {e}. Retrying now..")
await asyncio.sleep(2**call_count)
await connect(call_count + 1)
else:
raise e
async def disconnect():
@conn_retry("Prisma", "Acquiring connection")
async def connect():
if prisma.is_connected():
logger.info(f"[Prisma-{conn_id}] Releasing connection.")
await prisma.disconnect()
logger.info(f"[Prisma-{conn_id}] Connection released.")
return
await prisma.connect()
@conn_retry("Prisma", "Releasing connection")
async def disconnect():
if not prisma.is_connected():
return
await prisma.disconnect()
@asynccontextmanager

View File

@@ -1,11 +1,9 @@
import json
import logging
import os
from abc import ABC, abstractmethod
from datetime import datetime
from redis.asyncio import Redis
from backend.data import redis
from backend.data.execution import ExecutionResult
logger = logging.getLogger(__name__)
@@ -18,60 +16,46 @@ class DateTimeEncoder(json.JSONEncoder):
return super().default(o)
class AsyncEventQueue(ABC):
class AbstractEventQueue(ABC):
@abstractmethod
async def connect(self):
def connect(self):
pass
@abstractmethod
async def close(self):
def close(self):
pass
@abstractmethod
async def put(self, execution_result: ExecutionResult):
def put(self, execution_result: ExecutionResult):
pass
@abstractmethod
async def get(self) -> ExecutionResult | None:
def get(self) -> ExecutionResult | None:
pass
class AsyncRedisEventQueue(AsyncEventQueue):
class RedisEventQueue(AbstractEventQueue):
def __init__(self):
self.host = os.getenv("REDIS_HOST", "localhost")
self.port = int(os.getenv("REDIS_PORT", "6379"))
self.password = os.getenv("REDIS_PASSWORD", "password")
self.queue_name = os.getenv("REDIS_QUEUE", "execution_events")
self.connection = None
self.queue_name = redis.QUEUE_NAME
async def connect(self):
if not self.connection:
self.connection = Redis(
host=self.host,
port=self.port,
password=self.password,
decode_responses=True,
)
await self.connection.ping()
logger.info(f"Connected to Redis on {self.host}:{self.port}")
def connect(self):
self.connection = redis.connect()
async def put(self, execution_result: ExecutionResult):
def put(self, execution_result: ExecutionResult):
if self.connection:
message = json.dumps(execution_result.model_dump(), cls=DateTimeEncoder)
logger.info(f"Putting execution result to Redis {message}")
await self.connection.lpush(self.queue_name, message) # type: ignore
self.connection.lpush(self.queue_name, message)
async def get(self) -> ExecutionResult | None:
def get(self) -> ExecutionResult | None:
if self.connection:
message = await self.connection.rpop(self.queue_name) # type: ignore
message = self.connection.rpop(self.queue_name)
if message is not None and isinstance(message, (str, bytes, bytearray)):
data = json.loads(message)
logger.info(f"Getting execution result from Redis {data}")
return ExecutionResult(**data)
return None
async def close(self):
if self.connection:
await self.connection.close()
self.connection = None
logger.info("Closed connection to Redis")
def close(self):
redis.disconnect()

View File

@@ -0,0 +1,48 @@
import logging
import os
from dotenv import load_dotenv
from redis import Redis
from backend.util.retry import conn_retry
load_dotenv()
HOST = os.getenv("REDIS_HOST", "localhost")
PORT = int(os.getenv("REDIS_PORT", "6379"))
PASSWORD = os.getenv("REDIS_PASSWORD", "password")
QUEUE_NAME = os.getenv("REDIS_QUEUE", "execution_events")
logger = logging.getLogger(__name__)
connection: Redis | None = None
@conn_retry("Redis", "Acquiring connection")
def connect() -> Redis:
    """Establish (or reuse) the process-wide Redis connection.

    Idempotent: repeated calls return the same client. The initial
    ``ping()`` fails fast so the ``conn_retry`` decorator can retry
    while the server is still coming up.
    """
    global connection
    if not connection:
        client = Redis(
            host=HOST,
            port=PORT,
            password=PASSWORD,
            decode_responses=True,
        )
        client.ping()  # raise now, not on first real command
        connection = client
    return connection
@conn_retry("Redis", "Releasing connection")
def disconnect():
    """Close the shared Redis connection if one is open; otherwise no-op."""
    global connection
    if not connection:
        return
    connection.close()
    connection = None
def get_redis() -> Redis:
    """Return the already-established Redis connection.

    Raises:
        RuntimeError: if ``connect()`` has not been called in this process.
    """
    conn = connection
    if conn:
        return conn
    raise RuntimeError("Redis connection is not established")

View File

@@ -17,7 +17,7 @@ from pydantic import BaseModel
if TYPE_CHECKING:
from backend.server.rest_api import AgentServer
from backend.data import db
from backend.data import db, redis
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
from backend.data.credit import get_user_credit_model
from backend.data.execution import (
@@ -216,12 +216,13 @@ def execute_node(
@contextmanager
def synchronized(api_client: "AgentServer", key: Any):
api_client.acquire_lock(key)
def synchronized(key: str, timeout: int = 60):
lock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)
try:
lock.acquire()
yield
finally:
api_client.release_lock(key)
lock.release()
def _enqueue_next_nodes(
@@ -268,7 +269,7 @@ def _enqueue_next_nodes(
# Multiple nodes can register the same next node; this must be atomic
# to prevent the same execution from being enqueued multiple times,
# or the same input from being consumed multiple times.
with synchronized(api_client, ("upsert_input", next_node_id, graph_exec_id)):
with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
# Add output data to the earliest incomplete execution, or create a new one.
next_node_exec_id, next_node_input = wait(
upsert_execution_input(
@@ -437,6 +438,7 @@ class Executor:
cls.loop = asyncio.new_event_loop()
cls.pid = os.getpid()
redis.connect()
cls.loop.run_until_complete(db.connect())
cls.agent_server_client = get_agent_server_client()
@@ -454,6 +456,8 @@ class Executor:
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB...")
cls.loop.run_until_complete(db.disconnect())
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
redis.disconnect()
logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
@classmethod

View File

@@ -24,7 +24,6 @@ class ExecutionScheduler(AppService):
self.use_db = True
self.last_check = datetime.min
self.refresh_interval = refresh_interval
self.use_redis = False
@property
def execution_manager_client(self) -> ExecutionManager:

View File

@@ -17,11 +17,10 @@ from backend.data import graph as graph_db
from backend.data import user as user_db
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.credit import get_block_costs, get_user_credit_model
from backend.data.queue import AsyncEventQueue, AsyncRedisEventQueue
from backend.data.queue import RedisEventQueue
from backend.data.user import get_or_create_user
from backend.executor import ExecutionManager, ExecutionScheduler
from backend.server.model import CreateGraph, SetGraphActiveVersion
from backend.util.lock import KeyedMutex
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Config, Settings
@@ -32,24 +31,23 @@ logger = logging.getLogger(__name__)
class AgentServer(AppService):
mutex = KeyedMutex()
use_redis = True
use_queue = True
_test_dependency_overrides = {}
_user_credit_model = get_user_credit_model()
def __init__(self, event_queue: AsyncEventQueue | None = None):
def __init__(self):
super().__init__(port=Config().agent_server_port)
self.event_queue = event_queue or AsyncRedisEventQueue()
self.event_queue = RedisEventQueue()
@asynccontextmanager
async def lifespan(self, _: FastAPI):
await db.connect()
self.run_and_wait(self.event_queue.connect())
self.event_queue.connect()
await block.initialize_blocks()
if await user_db.create_default_user(settings.config.enable_auth):
await graph_db.import_packaged_templates()
yield
await self.event_queue.close()
self.event_queue.close()
await db.disconnect()
def run_service(self):
@@ -616,15 +614,7 @@ class AgentServer(AppService):
@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
execution_result = execution_db.ExecutionResult(**execution_result_dict)
self.run_and_wait(self.event_queue.put(execution_result))
@expose
def acquire_lock(self, key: Any):
self.mutex.lock(key)
@expose
def release_lock(self, key: Any):
self.mutex.unlock(key)
self.event_queue.put(execution_result)
@classmethod
def update_configuration(

View File

@@ -7,7 +7,7 @@ from autogpt_libs.auth import parse_jwt_token
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from backend.data.queue import AsyncRedisEventQueue
from backend.data.queue import RedisEventQueue
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.model import ExecutionSubscription, Methods, WsMessage
@@ -20,15 +20,16 @@ settings = Settings()
@asynccontextmanager
async def lifespan(app: FastAPI):
await event_queue.connect()
event_queue.connect()
manager = get_connection_manager()
asyncio.create_task(event_broadcaster(manager))
fut = asyncio.create_task(event_broadcaster(manager))
fut.add_done_callback(lambda _: logger.info("Event broadcaster stopped"))
yield
await event_queue.close()
event_queue.close()
app = FastAPI(lifespan=lifespan)
event_queue = AsyncRedisEventQueue()
event_queue = RedisEventQueue()
_connection_manager = None
logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}")
@@ -50,9 +51,11 @@ def get_connection_manager():
async def event_broadcaster(manager: ConnectionManager):
while True:
event = await event_queue.get()
event = event_queue.get()
if event is not None:
await manager.send_execution_result(event)
else:
await asyncio.sleep(0.1)
async def authenticate_websocket(websocket: WebSocket) -> str:

View File

@@ -1,31 +0,0 @@
from threading import Lock
from typing import Any
from expiringdict import ExpiringDict
class KeyedMutex:
    """
    Mutex that can be locked and unlocked by an arbitrary hashable key.

    Entries live in an ExpiringDict so a key that is never unlocked is
    evicted automatically after the timeout, preventing memory leaks.
    """

    def __init__(self):
        # key -> (lock, number of outstanding lock() calls for that key)
        self.locks: dict[Any, tuple[Lock, int]] = ExpiringDict(
            max_len=6000, max_age_seconds=60
        )
        # Guards mutation of self.locks itself (not the per-key locks).
        self.locks_lock = Lock()

    def lock(self, key: Any):
        """Acquire the mutex for *key*, blocking until it is available."""
        with self.locks_lock:
            lock, request_count = self.locks.get(key, (Lock(), 0))
            self.locks[key] = (lock, request_count + 1)
        # Acquire outside locks_lock so a blocked caller does not stall
        # lock/unlock operations on unrelated keys.
        lock.acquire()

    def unlock(self, key: Any):
        """Release the mutex for *key*.

        No-op if the entry was already evicted by the ExpiringDict timeout;
        the previous implementation raised KeyError in that case, crashing
        the caller on long-held locks.
        """
        with self.locks_lock:
            entry = self.locks.pop(key, None)
            if entry is None:
                # Entry expired before unlock was called; nothing to release.
                return
            lock, request_count = entry
            if request_count > 1:
                self.locks[key] = (lock, request_count - 1)
        lock.release()

View File

@@ -10,6 +10,11 @@ from backend.util.logging import configure_logging
from backend.util.metrics import sentry_init
logger = logging.getLogger(__name__)
_SERVICE_NAME = "MainProcess"
def get_service_name():
return _SERVICE_NAME
class AppProcess(ABC):
@@ -32,6 +37,11 @@ class AppProcess(ABC):
"""
pass
@classmethod
@property
def service_name(cls) -> str:
return cls.__name__
def cleanup(self):
"""
Implement this method on a subclass to do post-execution cleanup,
@@ -52,10 +62,14 @@ class AppProcess(ABC):
if silent:
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
logger.info(f"[{self.__class__.__name__}] Starting...")
global _SERVICE_NAME
_SERVICE_NAME = self.service_name
logger.info(f"[{self.service_name}] Starting...")
self.run()
except (KeyboardInterrupt, SystemExit) as e:
logger.warning(f"[{self.__class__.__name__}] Terminated: {e}; quitting...")
logger.warning(f"[{self.service_name}] Terminated: {e}; quitting...")
def _self_terminate(self, signum: int, frame):
self.cleanup()

View File

@@ -1,7 +1,48 @@
import logging
import os
from uuid import uuid4
from tenacity import retry, stop_after_attempt, wait_exponential
conn_retry = retry(
stop=stop_after_attempt(30),
wait=wait_exponential(multiplier=1, min=1, max=30),
reraise=True,
)
from backend.util.process import get_service_name
logger = logging.getLogger(__name__)
def _log_prefix(resource_name: str, conn_id: str):
    """
    Build the log prefix for a connection-retry message.

    Evaluated on every call so the PID and service name reflect the
    current process rather than the parent that created the decorator.
    """
    pid = os.getpid()
    service = get_service_name()
    return f"[PID-{pid}|{service}|{resource_name}-{conn_id}]"
def conn_retry(resource_name: str, action_name: str, max_retry: int = 5):
    """
    Build a tenacity retry decorator that logs the lifecycle of
    *action_name* on *resource_name*: start, success, and each failure.

    Retries up to ``max_retry`` times (``max_retry + 1`` attempts total)
    with exponential backoff, re-raising the last exception on exhaustion.
    """
    conn_id = str(uuid4())  # one id per decorated callable, for log correlation

    def _prefix() -> str:
        return _log_prefix(resource_name, conn_id)

    def _started(retry_state):
        logger.info(f"{_prefix()} {action_name} started...")

    def _finished(retry_state):
        if not retry_state.outcome.failed:
            logger.info(f"{_prefix()} {action_name} completed!")

    def _failed(retry_state):
        error = retry_state.outcome.exception()
        logger.info(f"{_prefix()} {action_name} failed: {error}. Retrying now...")

    return retry(
        stop=stop_after_attempt(max_retry + 1),
        wait=wait_exponential(multiplier=1, min=1, max=30),
        before=_started,
        after=_finished,
        before_sleep=_failed,
        reraise=True,
    )

View File

@@ -10,7 +10,7 @@ import Pyro5.api
from Pyro5 import api as pyro
from backend.data import db
from backend.data.queue import AsyncEventQueue, AsyncRedisEventQueue
from backend.data.queue import AbstractEventQueue, RedisEventQueue
from backend.util.process import AppProcess
from backend.util.retry import conn_retry
from backend.util.settings import Config, Secrets
@@ -45,20 +45,15 @@ def expose(func: C) -> C:
class AppService(AppProcess):
shared_event_loop: asyncio.AbstractEventLoop
event_queue: AsyncEventQueue = AsyncRedisEventQueue()
event_queue: AbstractEventQueue = RedisEventQueue()
use_db: bool = False
use_redis: bool = False
use_queue: bool = False
use_supabase: bool = False
def __init__(self, port):
self.port = port
self.uri = None
@classmethod
@property
def service_name(cls) -> str:
return cls.__name__
@abstractmethod
def run_service(self):
while True:
@@ -75,8 +70,8 @@ class AppService(AppProcess):
self.shared_event_loop = asyncio.get_event_loop()
if self.use_db:
self.shared_event_loop.run_until_complete(db.connect())
if self.use_redis:
self.shared_event_loop.run_until_complete(self.event_queue.connect())
if self.use_queue:
self.event_queue.connect()
if self.use_supabase:
from supabase import create_client
@@ -102,11 +97,11 @@ class AppService(AppProcess):
if self.use_db:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting DB...")
self.run_and_wait(db.disconnect())
if self.use_redis:
if self.use_queue:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting Redis...")
self.run_and_wait(self.event_queue.close())
self.event_queue.close()
@conn_retry
@conn_retry("Pyro", "Starting Pyro Service")
def __start_pyro(self):
host = Config().pyro_host
daemon = Pyro5.api.Daemon(host=host, port=self.port)
@@ -125,7 +120,7 @@ def get_service_client(service_type: Type[AS], port: int) -> AS:
service_name = service_type.service_name
class DynamicClient:
@conn_retry
@conn_retry("Pyro", f"Connecting to [{service_name}]")
def __init__(self):
host = os.environ.get(f"{service_name.upper()}_HOST", "localhost")
uri = f"PYRO:{service_type.service_name}@{host}:{port}"

View File

@@ -1,11 +1,9 @@
import asyncio
import time
from backend.data import db
from backend.data.block import Block, initialize_blocks
from backend.data.execution import ExecutionResult, ExecutionStatus
from backend.data.execution import ExecutionStatus
from backend.data.model import CREDENTIALS_FIELD_NAME
from backend.data.queue import AsyncEventQueue
from backend.data.user import create_default_user
from backend.executor import ExecutionManager, ExecutionScheduler
from backend.server import AgentServer
@@ -14,44 +12,10 @@ from backend.server.rest_api import get_user_id
log = print
class InMemoryAsyncEventQueue(AsyncEventQueue):
def __init__(self):
self.queue = asyncio.Queue()
self.connected = False
self.closed = False
async def connect(self):
if not self.connected:
self.connected = True
return
async def close(self):
self.closed = True
self.connected = False
return
async def put(self, execution_result: ExecutionResult):
if not self.connected:
raise RuntimeError("Queue is not connected")
await self.queue.put(execution_result)
async def get(self):
if self.closed:
return None
if not self.connected:
raise RuntimeError("Queue is not connected")
try:
item = await asyncio.wait_for(self.queue.get(), timeout=0.1)
return item
except asyncio.TimeoutError:
return None
class SpinTestServer:
def __init__(self):
self.exec_manager = ExecutionManager()
self.in_memory_queue = InMemoryAsyncEventQueue()
self.agent_server = AgentServer(event_queue=self.in_memory_queue)
self.agent_server = AgentServer()
self.scheduler = ExecutionScheduler()
@staticmethod

View File

@@ -6,7 +6,6 @@ from backend.util.service import AppService, expose, get_service_client
class TestService(AppService):
def __init__(self):
super().__init__(port=8005)
self.use_redis = False
def run_service(self):
super().run_service()