feat(backend): RedisEventQueue into Pub/Sub (#8387)

This commit is contained in:
Zamil Majdy
2024-10-27 23:49:33 +07:00
committed by GitHub
parent 8a68516f0b
commit 1e620fdb13
5 changed files with 172 additions and 41 deletions

View File

@@ -2,11 +2,18 @@ import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, AsyncGenerator, Generator, Generic, TypeVar
from pydantic import BaseModel
from redis.asyncio.client import PubSub as AsyncPubSub
from redis.client import PubSub
from backend.data import redis
from backend.data.execution import ExecutionResult
from backend.util.settings import Config
logger = logging.getLogger(__name__)
config = Config()
class DateTimeEncoder(json.JSONEncoder):
@@ -16,35 +23,122 @@ class DateTimeEncoder(json.JSONEncoder):
return super().default(o)
class AbstractEventQueue(ABC):
@abstractmethod
def put(self, execution_result: ExecutionResult):
pass
@abstractmethod
def get(self) -> ExecutionResult | None:
pass
M = TypeVar("M", bound=BaseModel)
class RedisEventQueue(AbstractEventQueue):
def __init__(self):
self.queue_name = redis.QUEUE_NAME
class BaseRedisEventBus(Generic[M], ABC):
Model: type[M]
@property
def connection(self):
@abstractmethod
def event_bus_name(self) -> str:
pass
def _serialize_message(self, item: M, channel_key: str) -> tuple[str, str]:
message = json.dumps(item.model_dump(), cls=DateTimeEncoder)
channel_name = f"{self.event_bus_name}-{channel_key}"
logger.info(f"[{channel_name}] Publishing an event to Redis {message}")
return message, channel_name
def _deserialize_message(self, msg: Any, channel_key: str) -> M | None:
message_type = "pmessage" if "*" in channel_key else "message"
if msg["type"] != message_type:
return None
try:
data = json.loads(msg["data"])
logger.info(f"Consuming an event from Redis {data}")
return self.Model(**data)
except Exception as e:
logger.error(f"Failed to parse event result from Redis {msg} {e}")
def _subscribe(
self, connection: redis.Redis | redis.AsyncRedis, channel_key: str
) -> tuple[PubSub | AsyncPubSub, str]:
channel_name = f"{self.event_bus_name}-{channel_key}"
pubsub = connection.pubsub()
return pubsub, channel_name
class RedisEventBus(BaseRedisEventBus[M], ABC):
Model: type[M]
@property
def connection(self) -> redis.Redis:
return redis.get_redis()
def put(self, execution_result: ExecutionResult):
message = json.dumps(execution_result.model_dump(), cls=DateTimeEncoder)
logger.info(f"Putting execution result to Redis {message}")
self.connection.lpush(self.queue_name, message)
def publish_event(self, event: M, channel_key: str):
message, channel_name = self._serialize_message(event, channel_key)
self.connection.publish(channel_name, message)
def get(self) -> ExecutionResult | None:
message = self.connection.rpop(self.queue_name)
if message is not None and isinstance(message, (str, bytes, bytearray)):
data = json.loads(message)
logger.info(f"Getting execution result from Redis {data}")
return ExecutionResult(**data)
elif message is not None:
logger.error(f"Failed to get execution result from Redis {message}")
return None
def listen_events(self, channel_key: str) -> Generator[M, None, None]:
pubsub, channel_name = self._subscribe(self.connection, channel_key)
assert isinstance(pubsub, PubSub)
if "*" in channel_key:
pubsub.psubscribe(channel_name)
else:
pubsub.subscribe(channel_name)
for message in pubsub.listen():
if event := self._deserialize_message(message, channel_key):
yield event
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
Model: type[M]
@property
async def connection(self) -> redis.AsyncRedis:
return await redis.get_redis_async()
async def publish_event(self, event: M, channel_key: str):
message, channel_name = self._serialize_message(event, channel_key)
connection = await self.connection
await connection.publish(channel_name, message)
async def listen_events(self, channel_key: str) -> AsyncGenerator[M, None]:
pubsub, channel_name = self._subscribe(await self.connection, channel_key)
assert isinstance(pubsub, AsyncPubSub)
if "*" in channel_key:
await pubsub.psubscribe(channel_name)
else:
await pubsub.subscribe(channel_name)
async for message in pubsub.listen():
if event := self._deserialize_message(message, channel_key):
yield event
class RedisExecutionEventBus(RedisEventBus[ExecutionResult]):
Model = ExecutionResult
@property
def event_bus_name(self) -> str:
return config.execution_event_bus_name
def publish(self, res: ExecutionResult):
self.publish_event(res, f"{res.graph_id}-{res.graph_exec_id}")
def listen(
self, graph_id: str = "*", graph_exec_id: str = "*"
) -> Generator[ExecutionResult, None, None]:
for execution_result in self.listen_events(f"{graph_id}-{graph_exec_id}"):
yield execution_result
class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionResult]):
Model = ExecutionResult
@property
def event_bus_name(self) -> str:
return config.execution_event_bus_name
async def publish(self, res: ExecutionResult):
await self.publish_event(res, f"{res.graph_id}-{res.graph_exec_id}")
async def listen(
self, graph_id: str = "*", graph_exec_id: str = "*"
) -> AsyncGenerator[ExecutionResult, None]:
async for execution_result in self.listen_events(f"{graph_id}-{graph_exec_id}"):
yield execution_result

View File

@@ -3,6 +3,7 @@ import os
from dotenv import load_dotenv
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
from backend.util.retry import conn_retry
@@ -11,10 +12,10 @@ load_dotenv()
HOST = os.getenv("REDIS_HOST", "localhost")
PORT = int(os.getenv("REDIS_PORT", "6379"))
PASSWORD = os.getenv("REDIS_PASSWORD", "password")
QUEUE_NAME = os.getenv("REDIS_QUEUE", "execution_events")
logger = logging.getLogger(__name__)
connection: Redis | None = None
connection_async: AsyncRedis | None = None
@conn_retry("Redis", "Acquiring connection")
@@ -42,7 +43,42 @@ def disconnect():
connection = None
def get_redis() -> Redis:
if not connection:
raise RuntimeError("Redis connection is not established")
return connection
def get_redis(auto_connect: bool = True) -> Redis:
if connection:
return connection
if auto_connect:
return connect()
raise RuntimeError("Redis connection is not established")
@conn_retry("AsyncRedis", "Acquiring connection")
async def connect_async() -> AsyncRedis:
global connection_async
if connection_async:
return connection_async
c = AsyncRedis(
host=HOST,
port=PORT,
password=PASSWORD,
decode_responses=True,
)
await c.ping()
connection_async = c
return connection_async
@conn_retry("AsyncRedis", "Releasing connection")
async def disconnect_async():
global connection_async
if connection_async:
await connection_async.close()
connection_async = None
async def get_redis_async(auto_connect: bool = True) -> AsyncRedis:
if connection_async:
return connection_async
if auto_connect:
return await connect_async()
raise RuntimeError("AsyncRedis connection is not established")

View File

@@ -15,7 +15,7 @@ from backend.data.execution import (
upsert_execution_output,
)
from backend.data.graph import get_graph, get_node
from backend.data.queue import RedisEventQueue
from backend.data.queue import RedisExecutionEventBus
from backend.data.user import get_user_metadata, update_user_metadata
from backend.util.service import AppService, expose
from backend.util.settings import Config
@@ -30,7 +30,7 @@ class DatabaseManager(AppService):
super().__init__()
self.use_db = True
self.use_redis = True
self.event_queue = RedisEventQueue()
self.event_queue = RedisExecutionEventBus()
@classmethod
def get_port(cls) -> int:
@@ -38,7 +38,7 @@ class DatabaseManager(AppService):
@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
self.event_queue.put(ExecutionResult(**execution_result_dict))
self.event_queue.publish(ExecutionResult(**execution_result_dict))
@staticmethod
def exposed_run_and_wait(

View File

@@ -8,7 +8,7 @@ from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from backend.data import redis
from backend.data.queue import RedisEventQueue
from backend.data.queue import AsyncRedisExecutionEventBus
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.model import ExecutionSubscription, Methods, WsMessage
@@ -51,13 +51,9 @@ def get_connection_manager():
async def event_broadcaster(manager: ConnectionManager):
try:
redis.connect()
event_queue = RedisEventQueue()
while True:
event = event_queue.get()
if event:
await manager.send_execution_result(event)
else:
await asyncio.sleep(0.1)
event_queue = AsyncRedisExecutionEventBus()
async for event in event_queue.listen():
await manager.send_execution_result(event)
except Exception as e:
logger.exception(f"Event broadcaster error: {e}")
raise

View File

@@ -148,6 +148,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="What environment to behave as: local or cloud",
)
execution_event_bus_name: str = Field(
default="execution_event",
description="Name of the event bus",
)
backend_cors_allow_origins: List[str] = Field(default_factory=list)
@field_validator("backend_cors_allow_origins")