Compare commits

...

1 commit

Author SHA1 Message Date
Aarushi
9334eee41d update execution manager with redis 2024-10-21 14:02:43 +01:00
3 changed files with 28 additions and 30 deletions

View File

@@ -38,28 +38,6 @@ class NodeExecution(BaseModel):
ExecutionStatus = AgentExecutionStatus ExecutionStatus = AgentExecutionStatus
T = TypeVar("T")
class ExecutionQueue(Generic[T]):
    """
    Queue for managing the execution of agents.
    This will be shared between different processes
    """

    def __init__(self):
        # A Manager-backed queue proxy, so every worker process talks to the
        # same underlying queue through IPC.
        self.queue = Manager().Queue()

    def add(self, execution: T) -> T:
        """Enqueue *execution* and hand it back for call chaining."""
        self.queue.put(execution)
        return execution

    def get(self) -> T:
        """Remove and return the oldest queued item (blocks while empty)."""
        return self.queue.get()

    def empty(self) -> bool:
        """Return True when nothing is currently queued."""
        return self.queue.empty()
class ExecutionResult(BaseModel): class ExecutionResult(BaseModel):
graph_id: str graph_id: str

View File

@@ -2,12 +2,14 @@ import json
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from typing import Any, Generic, TypeVar
from backend.data import redis from backend.data import redis
from backend.data.execution import ExecutionResult from backend.data.execution import ExecutionResult
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
T = TypeVar("T")
class DateTimeEncoder(json.JSONEncoder): class DateTimeEncoder(json.JSONEncoder):
def default(self, o): def default(self, o):
@@ -48,3 +50,21 @@ class RedisEventQueue(AbstractEventQueue):
elif message is not None: elif message is not None:
logger.error(f"Failed to get execution result from Redis {message}") logger.error(f"Failed to get execution result from Redis {message}")
return None return None
class ExecutionQueue(Generic[T]):
    """Redis-backed FIFO queue of pydantic models, shared across processes.

    Items are serialized via ``model_dump`` + ``json`` and LPUSHed onto a
    Redis list; ``get`` BRPOPs, so producers push left / consumers pop right
    and ordering is first-in, first-out.
    """

    def __init__(self, queue_name: str, model: type[T] | None = None):
        # `model` is the concrete pydantic class used to re-hydrate messages
        # in get(). When omitted, it is recovered lazily from the generic
        # subscription used at construction time: ExecutionQueue[MyModel](...).
        self.redis = redis.get_redis()
        self.queue_name = queue_name
        self.model = model

    def add(self, item: T) -> T:
        """Serialize *item* and push it onto the queue; returns the item."""
        # default=str stringifies datetimes (and anything else non-JSON) on
        # the way out; the pydantic model re-parses them on the way back in.
        message = json.dumps(item.model_dump(), default=str)
        self.redis.lpush(self.queue_name, message)
        return item

    def get(self) -> T:
        """Block until an item is available, then return it deserialized."""
        # BUG FIX: the original called T.model_validate(...) — but T is a
        # typing.TypeVar, not the concrete model class, so get() raised
        # AttributeError on every call. The pointless `while True:` that
        # returned on its first iteration is gone too (brpop already blocks).
        _, message = self.redis.brpop(self.queue_name)
        return self._model_class().model_validate(json.loads(message))

    def empty(self) -> bool:
        """Return True when the queue currently holds no items."""
        return self.redis.llen(self.queue_name) == 0

    def _model_class(self) -> type[T]:
        # Resolve (and cache) the pydantic class to deserialize with.
        if self.model is None:
            # Local import: the module only imports specific names from typing.
            from typing import get_args

            args = get_args(getattr(self, "__orig_class__", None))
            if not args:
                raise TypeError(
                    "ExecutionQueue needs an explicit `model` argument or a "
                    "generic subscription, e.g. ExecutionQueue[MyModel](name)"
                )
            self.model = args[0]
        return self.model

View File

@@ -13,13 +13,14 @@ from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast
from pydantic import BaseModel from pydantic import BaseModel
from redis.lock import Lock as RedisLock from redis.lock import Lock as RedisLock
from backend.data.queue import ExecutionQueue
if TYPE_CHECKING: if TYPE_CHECKING:
from backend.executor import DatabaseManager from backend.executor import DatabaseManager
from backend.data import redis from backend.data import redis
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
from backend.data.execution import ( from backend.data.execution import (
ExecutionQueue,
ExecutionResult, ExecutionResult,
ExecutionStatus, ExecutionStatus,
GraphExecution, GraphExecution,
@@ -415,6 +416,7 @@ class Executor:
configure_logging() configure_logging()
set_service_name("NodeExecutor") set_service_name("NodeExecutor")
redis.connect() redis.connect()
cls.node_queue = ExecutionQueue[NodeExecution]("node_execution_queue")
cls.pid = os.getpid() cls.pid = os.getpid()
cls.db_client = get_db_client() cls.db_client = get_db_client()
cls.creds_manager = IntegrationCredentialsManager() cls.creds_manager = IntegrationCredentialsManager()
@@ -454,7 +456,6 @@ class Executor:
@error_logged @error_logged
def on_node_execution( def on_node_execution(
cls, cls,
q: ExecutionQueue[NodeExecution],
node_exec: NodeExecution, node_exec: NodeExecution,
): ):
log_metadata = LogMetadata( log_metadata = LogMetadata(
@@ -465,7 +466,7 @@ class Executor:
node_id=node_exec.node_id, node_id=node_exec.node_id,
block_name="-", block_name="-",
) )
q = cls.node_queue
execution_stats = {} execution_stats = {}
timing_info, _ = cls._on_node_execution( timing_info, _ = cls._on_node_execution(
q, node_exec, log_metadata, execution_stats q, node_exec, log_metadata, execution_stats
@@ -481,7 +482,6 @@ class Executor:
@time_measured @time_measured
def _on_node_execution( def _on_node_execution(
cls, cls,
q: ExecutionQueue[NodeExecution],
node_exec: NodeExecution, node_exec: NodeExecution,
log_metadata: LogMetadata, log_metadata: LogMetadata,
stats: dict[str, Any] | None = None, stats: dict[str, Any] | None = None,
@@ -491,7 +491,7 @@ class Executor:
for execution in execute_node( for execution in execute_node(
cls.db_client, cls.creds_manager, node_exec, stats cls.db_client, cls.creds_manager, node_exec, stats
): ):
q.add(execution) cls.node_queue.add(execution)
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}") log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
except Exception as e: except Exception as e:
log_metadata.exception( log_metadata.exception(
@@ -582,7 +582,7 @@ class Executor:
cancel_thread.start() cancel_thread.start()
try: try:
queue = ExecutionQueue[NodeExecution]() queue = ExecutionQueue[NodeExecution]("node_execution_queue")
for node_exec in graph_exec.start_node_execs: for node_exec in graph_exec.start_node_execs:
queue.add(node_exec) queue.add(node_exec)
@@ -620,7 +620,7 @@ class Executor:
) )
running_executions[exec_data.node_id] = cls.executor.apply_async( running_executions[exec_data.node_id] = cls.executor.apply_async(
cls.on_node_execution, cls.on_node_execution,
(queue, exec_data), (exec_data,),
callback=make_exec_callback(exec_data), callback=make_exec_callback(exec_data),
) )
@@ -661,7 +661,7 @@ class ExecutionManager(AppService):
self.use_redis = True self.use_redis = True
self.use_supabase = True self.use_supabase = True
self.pool_size = settings.config.num_graph_workers self.pool_size = settings.config.num_graph_workers
self.queue = ExecutionQueue[GraphExecution]() self.queue = ExecutionQueue[GraphExecution]("graph_execution_queue")
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {} self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
def run_service(self): def run_service(self):