Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-11 16:18:07 -05:00)

Compare commits: master ... feat/execu (5 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 4f652cb978 | |
| | 279552a2a3 | |
| | fb6ac1d6ca | |
| | 9db15bff02 | |
| | db4b94e0dc | |
@@ -315,9 +315,10 @@ class NodeExecutionResult(BaseModel):
    input_data: BlockInput
    output_data: CompletedBlockOutput
    add_time: datetime
    queue_time: datetime | None
    start_time: datetime | None
    end_time: datetime | None
    queue_time: datetime | None = None
    start_time: datetime | None = None
    end_time: datetime | None = None
    stats: NodeExecutionStats | None = None

    @staticmethod
    def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
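The switch to `= None` defaults (plus the new `stats` field) appears to be what lets a `NodeExecutionResult` be built in memory before any timing information exists, as the new in-process execution client does later in this PR. A minimal sketch of such a construction, with placeholder IDs:

```python
from datetime import datetime, timezone

from backend.data.execution import ExecutionStatus, NodeExecutionResult

# queue_time, start_time, end_time and stats can now be omitted entirely.
pending = NodeExecutionResult(
    user_id="user-123",             # placeholder IDs throughout
    graph_id="graph-456",
    graph_version=1,
    graph_exec_id="graph-exec-789",
    node_exec_id="node-exec-abc",
    node_id="node-def",
    block_id="block-ghi",
    status=ExecutionStatus.INCOMPLETE,
    input_data={"prompt": "hello"},
    output_data={},
    add_time=datetime.now(timezone.utc),
)
```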
@@ -369,6 +370,7 @@ class NodeExecutionResult(BaseModel):
            queue_time=_node_exec.queuedTime,
            start_time=_node_exec.startedTime,
            end_time=_node_exec.endedTime,
            stats=stats,
        )

    def to_node_execution_entry(
@@ -654,6 +656,42 @@ async def upsert_execution_input(
    )


async def create_node_execution(
    node_exec_id: str,
    node_id: str,
    graph_exec_id: str,
    input_name: str,
    input_data: Any,
) -> None:
    """Create a new node execution with the first input."""
    json_input_data = SafeJson(input_data)
    await AgentNodeExecution.prisma().create(
        data=AgentNodeExecutionCreateInput(
            id=node_exec_id,
            agentNodeId=node_id,
            agentGraphExecutionId=graph_exec_id,
            executionStatus=ExecutionStatus.INCOMPLETE,
            Input={"create": {"name": input_name, "data": json_input_data}},
        )
    )


async def add_input_to_node_execution(
    node_exec_id: str,
    input_name: str,
    input_data: Any,
) -> None:
    """Add an input to an existing node execution."""
    json_input_data = SafeJson(input_data)
    await AgentNodeExecutionInputOutput.prisma().create(
        data=AgentNodeExecutionInputOutputCreateInput(
            name=input_name,
            data=json_input_data,
            referencedByInputExecId=node_exec_id,
        )
    )


async def upsert_execution_output(
    node_exec_id: str,
    output_name: str,
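A minimal sketch of how these two new helpers are meant to pair up, assuming an already-connected Prisma client and existing graph-execution and node rows (the IDs and input names below are placeholders): `create_node_execution` writes the execution row together with its first input, and `add_input_to_node_execution` attaches any later inputs to that same row.

```python
import uuid

from backend.data.execution import (
    add_input_to_node_execution,
    create_node_execution,
)


async def seed_node_execution(node_id: str, graph_exec_id: str) -> str:
    # Create the execution row in INCOMPLETE state with its first input...
    node_exec_id = str(uuid.uuid4())
    await create_node_execution(
        node_exec_id=node_exec_id,
        node_id=node_id,
        graph_exec_id=graph_exec_id,
        input_name="prompt",               # illustrative input name
        input_data={"text": "hello"},
    )
    # ...then attach any later-arriving inputs to the same execution.
    await add_input_to_node_execution(
        node_exec_id=node_exec_id,
        input_name="temperature",          # illustrative second input
        input_data=0.7,
    )
    return node_exec_id
```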
@@ -115,7 +115,7 @@ async def generate_activity_status_for_execution(

    # Get all node executions for this graph execution
    node_executions = await db_client.get_node_executions(
        graph_exec_id, include_exec_data=True
        graph_exec_id=graph_exec_id, include_exec_data=True
    )

    # Get graph metadata and full graph structure for name, description, and links
@@ -4,12 +4,13 @@ from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
    add_input_to_node_execution,
    create_graph_execution,
    create_node_execution,
    get_block_error_stats,
    get_execution_kv_data,
    get_graph_execution_meta,
    get_graph_executions,
    get_latest_node_execution,
    get_node_execution,
    get_node_executions,
    set_execution_kv_data,
@@ -17,7 +18,6 @@ from backend.data.execution import (
    update_graph_execution_stats,
    update_node_execution_status,
    update_node_execution_status_batch,
    upsert_execution_input,
    upsert_execution_output,
)
from backend.data.generate_data import get_user_execution_summary_data
@@ -105,13 +105,13 @@ class DatabaseManager(AppService):
    create_graph_execution = _(create_graph_execution)
    get_node_execution = _(get_node_execution)
    get_node_executions = _(get_node_executions)
    get_latest_node_execution = _(get_latest_node_execution)
    update_node_execution_status = _(update_node_execution_status)
    update_node_execution_status_batch = _(update_node_execution_status_batch)
    update_graph_execution_start_time = _(update_graph_execution_start_time)
    update_graph_execution_stats = _(update_graph_execution_stats)
    upsert_execution_input = _(upsert_execution_input)
    upsert_execution_output = _(upsert_execution_output)
    create_node_execution = _(create_node_execution)
    add_input_to_node_execution = _(add_input_to_node_execution)
    get_execution_kv_data = _(get_execution_kv_data)
    set_execution_kv_data = _(set_execution_kv_data)
    get_block_error_stats = _(get_block_error_stats)
@@ -171,10 +171,12 @@ class DatabaseManagerClient(AppServiceClient):
    get_graph_executions = _(d.get_graph_executions)
    get_graph_execution_meta = _(d.get_graph_execution_meta)
    get_node_executions = _(d.get_node_executions)
    create_node_execution = _(d.create_node_execution)
    update_node_execution_status = _(d.update_node_execution_status)
    update_graph_execution_start_time = _(d.update_graph_execution_start_time)
    update_graph_execution_stats = _(d.update_graph_execution_stats)
    upsert_execution_output = _(d.upsert_execution_output)
    add_input_to_node_execution = _(d.add_input_to_node_execution)

    # Graphs
    get_graph_metadata = _(d.get_graph_metadata)
@@ -189,14 +191,6 @@ class DatabaseManagerClient(AppServiceClient):
    # User Emails
    get_user_email_by_id = _(d.get_user_email_by_id)

    # Library
    list_library_agents = _(d.list_library_agents)
    add_store_agent_to_library = _(d.add_store_agent_to_library)

    # Store
    get_store_agents = _(d.get_store_agents)
    get_store_agent_details = _(d.get_store_agent_details)


class DatabaseManagerAsyncClient(AppServiceClient):
    d = DatabaseManager
@@ -207,16 +201,12 @@ class DatabaseManagerAsyncClient(AppServiceClient):

    create_graph_execution = d.create_graph_execution
    get_connected_output_nodes = d.get_connected_output_nodes
    get_latest_node_execution = d.get_latest_node_execution
    get_graph = d.get_graph
    get_graph_metadata = d.get_graph_metadata
    get_graph_execution_meta = d.get_graph_execution_meta
    get_node = d.get_node
    get_node_execution = d.get_node_execution
    get_node_executions = d.get_node_executions
    get_user_integrations = d.get_user_integrations
    upsert_execution_input = d.upsert_execution_input
    upsert_execution_output = d.upsert_execution_output
    update_graph_execution_stats = d.update_graph_execution_stats
    update_node_execution_status = d.update_node_execution_status
    update_node_execution_status_batch = d.update_node_execution_status_batch
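With `create_node_execution` and `add_input_to_node_execution` registered on `DatabaseManager` and exposed on the sync client, executor-side code can call them across the service boundary. A rough sketch of that call path, reusing the `get_database_manager_client()` helper the new executor module imports (IDs and input names are placeholders):

```python
import uuid

from backend.util.clients import get_database_manager_client


def persist_first_input(node_id: str, graph_exec_id: str) -> str:
    """Create a node execution row through the DatabaseManager sync client."""
    db_client = get_database_manager_client()
    node_exec_id = str(uuid.uuid4())
    db_client.create_node_execution(
        node_exec_id=node_exec_id,
        node_id=node_id,
        graph_exec_id=graph_exec_id,
        input_name="prompt",               # illustrative input name
        input_data={"text": "hello"},
    )
    return node_exec_id
```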
autogpt_platform/backend/backend/executor/execution_cache.py (new file, 154 lines)
@@ -0,0 +1,154 @@
import logging
import threading
from collections import OrderedDict
from functools import wraps
from typing import TYPE_CHECKING, Any, Optional

from backend.data.execution import ExecutionStatus, NodeExecutionResult
from backend.data.model import GraphExecutionStats, NodeExecutionStats

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerClient

logger = logging.getLogger(__name__)


def with_lock(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with self._lock:
            return func(self, *args, **kwargs)

    return wrapper


class ExecutionCache:
    def __init__(self, graph_exec_id: str, db_client: "DatabaseManagerClient"):
        self._lock = threading.RLock()
        self._graph_exec_id = graph_exec_id
        self._graph_stats: GraphExecutionStats = GraphExecutionStats()
        self._node_executions: OrderedDict[str, NodeExecutionResult] = OrderedDict()

        for execution in db_client.get_node_executions(self._graph_exec_id):
            self._node_executions[execution.node_exec_id] = execution

    @with_lock
    def get_node_execution(self, node_exec_id: str) -> Optional[NodeExecutionResult]:
        execution = self._node_executions.get(node_exec_id)
        return execution.model_copy(deep=True) if execution else None

    @with_lock
    def get_latest_node_execution(self, node_id: str) -> Optional[NodeExecutionResult]:
        for execution in reversed(self._node_executions.values()):
            if (
                execution.node_id == node_id
                and execution.status != ExecutionStatus.INCOMPLETE
            ):
                return execution.model_copy(deep=True)
        return None

    @with_lock
    def get_node_executions(
        self,
        *,
        statuses: Optional[list] = None,
        block_ids: Optional[list] = None,
        node_id: Optional[str] = None,
    ):
        results = []
        for execution in self._node_executions.values():
            if statuses and execution.status not in statuses:
                continue
            if block_ids and execution.block_id not in block_ids:
                continue
            if node_id and execution.node_id != node_id:
                continue
            results.append(execution.model_copy(deep=True))
        return results

    @with_lock
    def update_node_execution_status(
        self,
        exec_id: str,
        status: ExecutionStatus,
        execution_data: Optional[dict] = None,
        stats: Optional[dict] = None,
    ):
        if exec_id not in self._node_executions:
            raise RuntimeError(f"Execution {exec_id} not found in cache")

        execution = self._node_executions[exec_id]
        execution.status = status

        if execution_data:
            execution.input_data.update(execution_data)

        if stats:
            execution.stats = execution.stats or NodeExecutionStats()
            current_stats = execution.stats.model_dump()
            current_stats.update(stats)
            execution.stats = NodeExecutionStats.model_validate(current_stats)

    @with_lock
    def upsert_execution_output(
        self, node_exec_id: str, output_name: str, output_data: Any
    ) -> NodeExecutionResult:
        if node_exec_id not in self._node_executions:
            raise RuntimeError(f"Execution {node_exec_id} not found in cache")

        execution = self._node_executions[node_exec_id]
        if output_name not in execution.output_data:
            execution.output_data[output_name] = []
        execution.output_data[output_name].append(output_data)

        return execution

    @with_lock
    def update_graph_stats(
        self, status: Optional[ExecutionStatus] = None, stats: Optional[dict] = None
    ):
        if status is not None:
            pass
        if stats is not None:
            current_stats = self._graph_stats.model_dump()
            current_stats.update(stats)
            self._graph_stats = GraphExecutionStats.model_validate(current_stats)

    @with_lock
    def update_graph_start_time(self):
        """Update graph start time (handled by database persistence)."""
        pass

    @with_lock
    def find_incomplete_execution_for_input(
        self, node_id: str, input_name: str
    ) -> tuple[str, NodeExecutionResult] | None:
        for exec_id, execution in self._node_executions.items():
            if (
                execution.node_id == node_id
                and execution.status == ExecutionStatus.INCOMPLETE
                and input_name not in execution.input_data
            ):
                return exec_id, execution
        return None

    @with_lock
    def add_node_execution(
        self, node_exec_id: str, execution: NodeExecutionResult
    ) -> None:
        self._node_executions[node_exec_id] = execution

    @with_lock
    def update_execution_input(
        self, exec_id: str, input_name: str, input_data: Any
    ) -> dict:
        if exec_id not in self._node_executions:
            raise RuntimeError(f"Execution {exec_id} not found in cache")
        execution = self._node_executions[exec_id]
        execution.input_data[input_name] = input_data
        return execution.input_data.copy()

    def finalize(self) -> None:
        with self._lock:
            self._node_executions.clear()
            self._graph_stats = GraphExecutionStats()
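A short sketch of how this cache is meant to be exercised for a single graph execution, assuming a `DatabaseManagerClient` can be obtained the same way the executor does; the IDs are placeholders. Every method is serialized on the internal `RLock` by `with_lock`, reads return deep copies, and repeated outputs under one name accumulate in a list:

```python
from backend.data.execution import ExecutionStatus
from backend.executor.execution_cache import ExecutionCache
from backend.util.clients import get_database_manager_client

# Warm the cache with whatever node executions the DB already has for this run.
cache = ExecutionCache("graph-exec-456", get_database_manager_client())

if execution := cache.get_node_execution("node-exec-abc"):
    # Mark it running and merge extra input data into the cached copy.
    cache.update_node_execution_status(
        "node-exec-abc",
        ExecutionStatus.RUNNING,
        execution_data={"prompt": "hello"},
    )
    # Append an output value for this execution.
    cache.upsert_execution_output("node-exec-abc", "result", {"answer": 42})

# Filtered reads are served entirely from memory.
running = cache.get_node_executions(statuses=[ExecutionStatus.RUNNING])
```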
@@ -0,0 +1,355 @@
|
||||
"""Test execution creation with proper ID generation and persistence."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.executor.execution_data import ExecutionDataClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def execution_client_with_mock_db(event_loop):
|
||||
"""Create an ExecutionDataClient with proper database records."""
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from prisma.models import AgentGraph, AgentGraphExecution, User
|
||||
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
|
||||
# Create test database records to satisfy foreign key constraints
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": "test_user_123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
)
|
||||
|
||||
await AgentGraph.prisma().create(
|
||||
data={
|
||||
"id": "test_graph_456",
|
||||
"version": 1,
|
||||
"userId": "test_user_123",
|
||||
"name": "Test Graph",
|
||||
"description": "Test graph for execution tests",
|
||||
}
|
||||
)
|
||||
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
|
||||
await AgentGraphExecution.prisma().create(
|
||||
data={
|
||||
"id": "test_graph_exec_id",
|
||||
"userId": "test_user_123",
|
||||
"agentGraphId": "test_graph_456",
|
||||
"agentGraphVersion": 1,
|
||||
"executionStatus": AgentExecutionStatus.RUNNING,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Records might already exist, that's fine
|
||||
pass
|
||||
|
||||
# Mock the graph execution metadata - align with assertions below
|
||||
mock_graph_meta = GraphExecutionMeta(
|
||||
id="test_graph_exec_id",
|
||||
user_id="test_user_123",
|
||||
graph_id="test_graph_456",
|
||||
graph_version=1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
# Create client with ThreadPoolExecutor and graph metadata (constructed inside patch)
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# Storage for tracking created executions
|
||||
created_executions = []
|
||||
|
||||
async def mock_create_node_execution(
|
||||
node_exec_id, node_id, graph_exec_id, input_name, input_data
|
||||
):
|
||||
"""Mock execution creation that records what was created."""
|
||||
created_executions.append(
|
||||
{
|
||||
"node_exec_id": node_exec_id,
|
||||
"node_id": node_id,
|
||||
"graph_exec_id": graph_exec_id,
|
||||
"input_name": input_name,
|
||||
"input_data": input_data,
|
||||
}
|
||||
)
|
||||
return node_exec_id
|
||||
|
||||
def sync_mock_create_node_execution(
|
||||
node_exec_id, node_id, graph_exec_id, input_name, input_data
|
||||
):
|
||||
"""Mock sync execution creation that records what was created."""
|
||||
created_executions.append(
|
||||
{
|
||||
"node_exec_id": node_exec_id,
|
||||
"node_id": node_id,
|
||||
"graph_exec_id": graph_exec_id,
|
||||
"input_name": input_name,
|
||||
"input_data": input_data,
|
||||
}
|
||||
)
|
||||
return node_exec_id
|
||||
|
||||
# Prepare mock async and sync DB clients
|
||||
async_mock_client = AsyncMock()
|
||||
async_mock_client.create_node_execution = mock_create_node_execution
|
||||
|
||||
sync_mock_client = MagicMock()
|
||||
sync_mock_client.create_node_execution = sync_mock_create_node_execution
|
||||
# Mock graph execution for return values
|
||||
from backend.data.execution import GraphExecutionMeta
|
||||
|
||||
mock_graph_update = GraphExecutionMeta(
|
||||
id="test_graph_exec_id",
|
||||
user_id="test_user_123",
|
||||
graph_id="test_graph_456",
|
||||
graph_version=1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
# No-ops for other sync methods used by the client during tests
|
||||
sync_mock_client.add_input_to_node_execution.side_effect = lambda **kwargs: None
|
||||
sync_mock_client.update_node_execution_status.side_effect = (
|
||||
lambda *args, **kwargs: None
|
||||
)
|
||||
sync_mock_client.upsert_execution_output.side_effect = lambda **kwargs: None
|
||||
sync_mock_client.update_graph_execution_stats.side_effect = (
|
||||
lambda *args, **kwargs: mock_graph_update
|
||||
)
|
||||
sync_mock_client.update_graph_execution_start_time.side_effect = (
|
||||
lambda *args, **kwargs: mock_graph_update
|
||||
)
|
||||
|
||||
thread = threading.Thread(target=event_loop.run_forever, daemon=True)
|
||||
thread.start()
|
||||
with patch(
|
||||
"backend.executor.execution_data.get_database_manager_async_client",
|
||||
return_value=async_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_database_manager_client",
|
||||
return_value=sync_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_execution_event_bus"
|
||||
), patch(
|
||||
"backend.executor.execution_data.non_blocking_persist", lambda func: func
|
||||
):
|
||||
# Now construct the client under the patch so it captures the mocked clients
|
||||
client = ExecutionDataClient(executor, "test_graph_exec_id", mock_graph_meta)
|
||||
# Store the mocks for the test to access if needed
|
||||
setattr(client, "_test_async_client", async_mock_client)
|
||||
setattr(client, "_test_sync_client", sync_mock_client)
|
||||
setattr(client, "_created_executions", created_executions)
|
||||
yield client
|
||||
|
||||
# Cleanup test database records
|
||||
try:
|
||||
await AgentGraphExecution.prisma().delete_many(
|
||||
where={"id": "test_graph_exec_id"}
|
||||
)
|
||||
await AgentGraph.prisma().delete_many(where={"id": "test_graph_456"})
|
||||
await User.prisma().delete_many(where={"id": "test_user_123"})
|
||||
except Exception:
|
||||
# Cleanup may fail if records don't exist
|
||||
pass
|
||||
|
||||
# Cleanup
|
||||
event_loop.call_soon_threadsafe(event_loop.stop)
|
||||
thread.join(timeout=1)
|
||||
|
||||
|
||||
class TestExecutionCreation:
|
||||
"""Test execution creation with proper ID generation and persistence."""
|
||||
|
||||
async def test_execution_creation_with_valid_ids(
|
||||
self, execution_client_with_mock_db
|
||||
):
|
||||
"""Test that execution creation generates and persists valid IDs."""
|
||||
client = execution_client_with_mock_db
|
||||
|
||||
node_id = "test_node_789"
|
||||
input_name = "test_input"
|
||||
input_data = "test_value"
|
||||
block_id = "test_block_abc"
|
||||
|
||||
# This should trigger execution creation since cache is empty
|
||||
exec_id, input_dict = client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name=input_name,
|
||||
input_data=input_data,
|
||||
block_id=block_id,
|
||||
)
|
||||
|
||||
# Verify execution ID is valid UUID
|
||||
try:
|
||||
uuid.UUID(exec_id)
|
||||
except ValueError:
|
||||
pytest.fail(f"Generated execution ID '{exec_id}' is not a valid UUID")
|
||||
|
||||
# Verify execution was created in cache with complete data
|
||||
assert exec_id in client._cache._node_executions
|
||||
cached_execution = client._cache._node_executions[exec_id]
|
||||
|
||||
# Check all required fields have valid values
|
||||
assert cached_execution.user_id == "test_user_123"
|
||||
assert cached_execution.graph_id == "test_graph_456"
|
||||
assert cached_execution.graph_version == 1
|
||||
assert cached_execution.graph_exec_id == "test_graph_exec_id"
|
||||
assert cached_execution.node_exec_id == exec_id
|
||||
assert cached_execution.node_id == node_id
|
||||
assert cached_execution.block_id == block_id
|
||||
assert cached_execution.status == ExecutionStatus.INCOMPLETE
|
||||
assert cached_execution.input_data == {input_name: input_data}
|
||||
assert isinstance(cached_execution.add_time, datetime)
|
||||
|
||||
# Verify execution was persisted to database with our generated ID
|
||||
created_executions = getattr(client, "_created_executions", [])
|
||||
assert len(created_executions) == 1
|
||||
created = created_executions[0]
|
||||
assert created["node_exec_id"] == exec_id # Our generated ID was used
|
||||
assert created["node_id"] == node_id
|
||||
assert created["graph_exec_id"] == "test_graph_exec_id"
|
||||
assert created["input_name"] == input_name
|
||||
assert created["input_data"] == input_data
|
||||
|
||||
# Verify input dict returned correctly
|
||||
assert input_dict == {input_name: input_data}
|
||||
|
||||
async def test_execution_reuse_vs_creation(self, execution_client_with_mock_db):
|
||||
"""Test that execution reuse works and creation only happens when needed."""
|
||||
client = execution_client_with_mock_db
|
||||
|
||||
node_id = "reuse_test_node"
|
||||
block_id = "reuse_test_block"
|
||||
|
||||
# Create first execution
|
||||
exec_id_1, input_dict_1 = client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="input_1",
|
||||
input_data="value_1",
|
||||
block_id=block_id,
|
||||
)
|
||||
|
||||
# This should reuse the existing INCOMPLETE execution
|
||||
exec_id_2, input_dict_2 = client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="input_2",
|
||||
input_data="value_2",
|
||||
block_id=block_id,
|
||||
)
|
||||
|
||||
# Should reuse the same execution
|
||||
assert exec_id_1 == exec_id_2
|
||||
assert input_dict_2 == {"input_1": "value_1", "input_2": "value_2"}
|
||||
|
||||
# Only one execution should be created in database
|
||||
created_executions = getattr(client, "_created_executions", [])
|
||||
assert len(created_executions) == 1
|
||||
|
||||
# Verify cache has the merged inputs
|
||||
cached_execution = client._cache._node_executions[exec_id_1]
|
||||
assert cached_execution.input_data == {
|
||||
"input_1": "value_1",
|
||||
"input_2": "value_2",
|
||||
}
|
||||
|
||||
# Now complete the execution and try to add another input
|
||||
client.update_node_status_and_publish(
|
||||
exec_id=exec_id_1, status=ExecutionStatus.COMPLETED
|
||||
)
|
||||
|
||||
# Verify the execution status was actually updated in the cache
|
||||
updated_execution = client._cache._node_executions[exec_id_1]
|
||||
assert (
|
||||
updated_execution.status == ExecutionStatus.COMPLETED
|
||||
), f"Expected COMPLETED but got {updated_execution.status}"
|
||||
|
||||
# This should create a NEW execution since the first is no longer INCOMPLETE
|
||||
exec_id_3, input_dict_3 = client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="input_3",
|
||||
input_data="value_3",
|
||||
block_id=block_id,
|
||||
)
|
||||
|
||||
# Should be a different execution
|
||||
assert exec_id_3 != exec_id_1
|
||||
assert input_dict_3 == {"input_3": "value_3"}
|
||||
|
||||
# Verify cache behavior: should have two different executions in cache now
|
||||
cached_executions = client._cache._node_executions
|
||||
assert len(cached_executions) == 2
|
||||
assert exec_id_1 in cached_executions
|
||||
assert exec_id_3 in cached_executions
|
||||
|
||||
# First execution should be COMPLETED
|
||||
assert cached_executions[exec_id_1].status == ExecutionStatus.COMPLETED
|
||||
# Third execution should be INCOMPLETE (newly created)
|
||||
assert cached_executions[exec_id_3].status == ExecutionStatus.INCOMPLETE
|
||||
|
||||
async def test_multiple_nodes_get_different_execution_ids(
|
||||
self, execution_client_with_mock_db
|
||||
):
|
||||
"""Test that different nodes get different execution IDs."""
|
||||
client = execution_client_with_mock_db
|
||||
|
||||
# Create executions for different nodes
|
||||
exec_id_a, _ = client.upsert_execution_input(
|
||||
node_id="node_a",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="block_a",
|
||||
)
|
||||
|
||||
exec_id_b, _ = client.upsert_execution_input(
|
||||
node_id="node_b",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="block_b",
|
||||
)
|
||||
|
||||
# Should be different executions with different IDs
|
||||
assert exec_id_a != exec_id_b
|
||||
|
||||
# Both should be valid UUIDs
|
||||
uuid.UUID(exec_id_a)
|
||||
uuid.UUID(exec_id_b)
|
||||
|
||||
# Both should be in cache
|
||||
cached_executions = client._cache._node_executions
|
||||
assert len(cached_executions) == 2
|
||||
assert exec_id_a in cached_executions
|
||||
assert exec_id_b in cached_executions
|
||||
|
||||
# Both should have correct node IDs
|
||||
assert cached_executions[exec_id_a].node_id == "node_a"
|
||||
assert cached_executions[exec_id_b].node_id == "node_b"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
autogpt_platform/backend/backend/executor/execution_data.py (new file, 338 lines)
@@ -0,0 +1,338 @@
import logging
import threading
import uuid
from concurrent.futures import Executor
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar, cast

from backend.data import redis_client as redis
from backend.data.credit import UsageTransactionMetadata
from backend.data.execution import (
    ExecutionStatus,
    GraphExecutionMeta,
    NodeExecutionResult,
)
from backend.data.graph import Node
from backend.data.model import GraphExecutionStats
from backend.executor.execution_cache import ExecutionCache
from backend.util.clients import (
    get_database_manager_async_client,
    get_database_manager_client,
    get_execution_event_bus,
)
from backend.util.settings import Settings

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient

settings = Settings()
logger = logging.getLogger(__name__)

P = ParamSpec("P")
T = TypeVar("T")


def non_blocking_persist(func: Callable[P, T]) -> Callable[P, None]:
    from functools import wraps

    @wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        # First argument is always self for methods - access through cast for typing
        self = cast("ExecutionDataClient", args[0])
        future = self._executor.submit(func, *args, **kwargs)
        self._pending_tasks.add(future)

    return wrapper


class ExecutionDataClient:
    def __init__(
        self, executor: Executor, graph_exec_id: str, graph_metadata: GraphExecutionMeta
    ):
        self._executor = executor
        self._graph_exec_id = graph_exec_id
        self._cache = ExecutionCache(graph_exec_id, self.db_client_sync)
        self._pending_tasks = set()
        self._graph_metadata = graph_metadata
        self.graph_lock = threading.RLock()

    def finalize_execution(self, timeout: float = 30.0):
        logger.info(f"Flushing db writes for execution {self._graph_exec_id}")
        exceptions = []

        # Wait for all pending database operations to complete
        logger.debug(
            f"Waiting for {len(self._pending_tasks)} pending database operations"
        )
        for future in list(self._pending_tasks):
            try:
                future.result(timeout=timeout)
            except Exception as e:
                logger.error(f"Background database operation failed: {e}")
                exceptions.append(e)
            finally:
                self._pending_tasks.discard(future)

        self._cache.finalize()

        if exceptions:
            logger.error(f"Background persistence failed with {len(exceptions)} errors")
            raise RuntimeError(
                f"Background persistence failed with {len(exceptions)} errors: {exceptions}"
            )

    @property
    def db_client_async(self) -> "DatabaseManagerAsyncClient":
        return get_database_manager_async_client()

    @property
    def db_client_sync(self) -> "DatabaseManagerClient":
        return get_database_manager_client()

    @property
    def event_bus(self):
        return get_execution_event_bus()

    async def get_node(self, node_id: str) -> Node:
        return await self.db_client_async.get_node(node_id)

    def spend_credits(
        self,
        user_id: str,
        cost: int,
        metadata: UsageTransactionMetadata,
    ) -> int:
        return self.db_client_sync.spend_credits(
            user_id=user_id, cost=cost, metadata=metadata
        )

    def get_graph_execution_meta(
        self, user_id: str, execution_id: str
    ) -> GraphExecutionMeta | None:
        return self.db_client_sync.get_graph_execution_meta(
            user_id=user_id, execution_id=execution_id
        )

    def get_graph_metadata(
        self, graph_id: str, graph_version: int | None = None
    ) -> Any:
        return self.db_client_sync.get_graph_metadata(graph_id, graph_version)

    def get_credits(self, user_id: str) -> int:
        return self.db_client_sync.get_credits(user_id)

    def get_user_email_by_id(self, user_id: str) -> str | None:
        return self.db_client_sync.get_user_email_by_id(user_id)

    def get_latest_node_execution(self, node_id: str) -> NodeExecutionResult | None:
        return self._cache.get_latest_node_execution(node_id)

    def get_node_execution(self, node_exec_id: str) -> NodeExecutionResult | None:
        return self._cache.get_node_execution(node_exec_id)

    def get_node_executions(
        self,
        *,
        node_id: str | None = None,
        statuses: list[ExecutionStatus] | None = None,
        block_ids: list[str] | None = None,
    ) -> list[NodeExecutionResult]:
        return self._cache.get_node_executions(
            statuses=statuses, block_ids=block_ids, node_id=node_id
        )

    def update_node_status_and_publish(
        self,
        exec_id: str,
        status: ExecutionStatus,
        execution_data: dict | None = None,
        stats: dict[str, Any] | None = None,
    ):
        self._cache.update_node_execution_status(exec_id, status, execution_data, stats)
        self._persist_node_status_to_db(exec_id, status, execution_data, stats)

    def upsert_execution_input(
        self, node_id: str, input_name: str, input_data: Any, block_id: str
    ) -> tuple[str, dict]:
        # Validate input parameters to prevent foreign key constraint errors
        if not node_id or not isinstance(node_id, str):
            raise ValueError(f"Invalid node_id: {node_id}")
        if not self._graph_exec_id or not isinstance(self._graph_exec_id, str):
            raise ValueError(f"Invalid graph_exec_id: {self._graph_exec_id}")
        if not block_id or not isinstance(block_id, str):
            raise ValueError(f"Invalid block_id: {block_id}")

        # UPDATE: Try to find an existing incomplete execution for this node and input
        if result := self._cache.find_incomplete_execution_for_input(
            node_id, input_name
        ):
            exec_id, _ = result
            updated_input_data = self._cache.update_execution_input(
                exec_id, input_name, input_data
            )
            self._persist_add_input_to_db(exec_id, input_name, input_data)
            return exec_id, updated_input_data

        # CREATE: No suitable execution found, create new one
        node_exec_id = str(uuid.uuid4())
        logger.debug(
            f"Creating new execution {node_exec_id} for node {node_id} "
            f"in graph execution {self._graph_exec_id}"
        )

        new_execution = NodeExecutionResult(
            user_id=self._graph_metadata.user_id,
            graph_id=self._graph_metadata.graph_id,
            graph_version=self._graph_metadata.graph_version,
            graph_exec_id=self._graph_exec_id,
            node_exec_id=node_exec_id,
            node_id=node_id,
            block_id=block_id,
            status=ExecutionStatus.INCOMPLETE,
            input_data={input_name: input_data},
            output_data={},
            add_time=datetime.now(timezone.utc),
        )
        self._cache.add_node_execution(node_exec_id, new_execution)
        self._persist_new_node_execution_to_db(
            node_exec_id, node_id, input_name, input_data
        )

        return node_exec_id, {input_name: input_data}

    def upsert_execution_output(
        self, node_exec_id: str, output_name: str, output_data: Any
    ):
        self._cache.upsert_execution_output(node_exec_id, output_name, output_data)
        self._persist_execution_output_to_db(node_exec_id, output_name, output_data)

    def update_graph_stats_and_publish(
        self,
        status: ExecutionStatus | None = None,
        stats: GraphExecutionStats | None = None,
    ) -> None:
        stats_dict = stats.model_dump() if stats else None
        self._cache.update_graph_stats(status=status, stats=stats_dict)
        self._persist_graph_stats_to_db(status=status, stats=stats)

    def update_graph_start_time_and_publish(self) -> None:
        self._cache.update_graph_start_time()
        self._persist_graph_start_time_to_db()

    @non_blocking_persist
    def _persist_node_status_to_db(
        self,
        exec_id: str,
        status: ExecutionStatus,
        execution_data: dict | None = None,
        stats: dict[str, Any] | None = None,
    ):
        exec_update = self.db_client_sync.update_node_execution_status(
            exec_id, status, execution_data, stats
        )
        self.event_bus.publish(exec_update)

    @non_blocking_persist
    def _persist_add_input_to_db(
        self, node_exec_id: str, input_name: str, input_data: Any
    ):
        self.db_client_sync.add_input_to_node_execution(
            node_exec_id=node_exec_id,
            input_name=input_name,
            input_data=input_data,
        )

    @non_blocking_persist
    def _persist_execution_output_to_db(
        self, node_exec_id: str, output_name: str, output_data: Any
    ):
        self.db_client_sync.upsert_execution_output(
            node_exec_id=node_exec_id,
            output_name=output_name,
            output_data=output_data,
        )
        if exec_update := self.get_node_execution(node_exec_id):
            self.event_bus.publish(exec_update)

    @non_blocking_persist
    def _persist_graph_stats_to_db(
        self,
        status: ExecutionStatus | None = None,
        stats: GraphExecutionStats | None = None,
    ):
        graph_update = self.db_client_sync.update_graph_execution_stats(
            self._graph_exec_id, status, stats
        )
        if not graph_update:
            raise RuntimeError(
                f"Failed to update graph execution stats for {self._graph_exec_id}"
            )
        self.event_bus.publish(graph_update)

    @non_blocking_persist
    def _persist_graph_start_time_to_db(self):
        graph_update = self.db_client_sync.update_graph_execution_start_time(
            self._graph_exec_id
        )
        if not graph_update:
            raise RuntimeError(
                f"Failed to update graph execution start time for {self._graph_exec_id}"
            )
        self.event_bus.publish(graph_update)

    async def generate_activity_status(
        self,
        graph_id: str,
        graph_version: int,
        execution_stats: GraphExecutionStats,
        user_id: str,
        execution_status: ExecutionStatus,
    ) -> str | None:
        from backend.executor.activity_status_generator import (
            generate_activity_status_for_execution,
        )

        return await generate_activity_status_for_execution(
            graph_exec_id=self._graph_exec_id,
            graph_id=graph_id,
            graph_version=graph_version,
            execution_stats=execution_stats,
            db_client=self.db_client_async,
            user_id=user_id,
            execution_status=execution_status,
        )

    @non_blocking_persist
    def _send_execution_update(self, execution: NodeExecutionResult):
        """Send execution update to event bus."""
        try:
            self.event_bus.publish(execution)
        except Exception as e:
            logger.warning(f"Failed to send execution update: {e}")

    @non_blocking_persist
    def _persist_new_node_execution_to_db(
        self, node_exec_id: str, node_id: str, input_name: str, input_data: Any
    ):
        try:
            self.db_client_sync.create_node_execution(
                node_exec_id=node_exec_id,
                node_id=node_id,
                graph_exec_id=self._graph_exec_id,
                input_name=input_name,
                input_data=input_data,
            )
        except Exception as e:
            logger.error(
                f"Failed to create node execution {node_exec_id} for node {node_id} "
                f"in graph execution {self._graph_exec_id}: {e}"
            )
            raise

    def increment_execution_count(self, user_id: str) -> int:
        r = redis.get_redis()
        k = f"uec:{user_id}"
        counter = cast(int, r.incr(k))
        if counter == 1:
            r.expire(k, settings.config.execution_counter_expiration_time)
        return counter
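Putting the module together, a hedged end-to-end sketch of the write-through pattern `ExecutionDataClient` provides: cache updates are synchronous and immediately readable, while each `_persist_*` call is submitted to the executor pool by `non_blocking_persist` and drained in `finalize_execution`. The IDs are placeholders, the metadata lookup shown is only one way to obtain a `GraphExecutionMeta`, and the database manager services and event bus are assumed to be reachable:

```python
from concurrent.futures import ThreadPoolExecutor

from backend.data.execution import ExecutionStatus
from backend.executor.execution_data import ExecutionDataClient
from backend.util.clients import get_database_manager_client

graph_exec_id = "graph-exec-456"  # placeholder
graph_meta = get_database_manager_client().get_graph_execution_meta(
    user_id="user-123", execution_id=graph_exec_id
)
assert graph_meta is not None

with ThreadPoolExecutor(max_workers=4) as pool:
    data = ExecutionDataClient(pool, graph_exec_id, graph_meta)

    # Cache write happens inline; the DB insert runs on the pool.
    exec_id, inputs = data.upsert_execution_input(
        node_id="node-123",
        input_name="prompt",
        input_data="hello",
        block_id="block-abc",
    )
    data.update_node_status_and_publish(exec_id, ExecutionStatus.RUNNING)
    data.upsert_execution_output(exec_id, "result", {"answer": 42})
    data.update_node_status_and_publish(exec_id, ExecutionStatus.COMPLETED)

    # Block until every queued persistence task has finished, then clear the cache.
    data.finalize_execution(timeout=30.0)
```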
autogpt_platform/backend/backend/executor/execution_data_test.py (new file, 668 lines)
@@ -0,0 +1,668 @@
|
||||
"""Test suite for ExecutionDataClient."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.executor.execution_data import ExecutionDataClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def execution_client(event_loop):
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
|
||||
mock_graph_meta = GraphExecutionMeta(
|
||||
id="test_graph_exec_id",
|
||||
user_id="test_user_id",
|
||||
graph_id="test_graph_id",
|
||||
graph_version=1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# Mock all database operations to prevent connection attempts
|
||||
async_mock_client = AsyncMock()
|
||||
sync_mock_client = MagicMock()
|
||||
|
||||
# Mock all database methods to return None or empty results
|
||||
sync_mock_client.get_node_executions.return_value = []
|
||||
sync_mock_client.create_node_execution.return_value = None
|
||||
sync_mock_client.add_input_to_node_execution.return_value = None
|
||||
sync_mock_client.update_node_execution_status.return_value = None
|
||||
sync_mock_client.upsert_execution_output.return_value = None
|
||||
sync_mock_client.update_graph_execution_stats.return_value = mock_graph_meta
|
||||
sync_mock_client.update_graph_execution_start_time.return_value = mock_graph_meta
|
||||
|
||||
# Mock event bus to prevent connection attempts
|
||||
mock_event_bus = MagicMock()
|
||||
mock_event_bus.publish.return_value = None
|
||||
|
||||
thread = threading.Thread(target=event_loop.run_forever, daemon=True)
|
||||
thread.start()
|
||||
|
||||
with patch(
|
||||
"backend.executor.execution_data.get_database_manager_async_client",
|
||||
return_value=async_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_database_manager_client",
|
||||
return_value=sync_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_execution_event_bus",
|
||||
return_value=mock_event_bus,
|
||||
), patch(
|
||||
"backend.executor.execution_data.non_blocking_persist", lambda func: func
|
||||
):
|
||||
|
||||
client = ExecutionDataClient(executor, "test_graph_exec_id", mock_graph_meta)
|
||||
yield client
|
||||
|
||||
event_loop.call_soon_threadsafe(event_loop.stop)
|
||||
thread.join(timeout=1)
|
||||
|
||||
|
||||
class TestExecutionDataClient:
|
||||
|
||||
async def test_update_node_status_writes_to_cache_immediately(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that node status updates are immediately visible in cache."""
|
||||
# First create an execution to update
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
status = ExecutionStatus.RUNNING
|
||||
execution_data = {"step": "processing"}
|
||||
stats = {"duration": 5.2}
|
||||
|
||||
# Update status of existing execution
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=status,
|
||||
execution_data=execution_data,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
# Verify immediate visibility in cache
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert cached_exec.status == status
|
||||
# execution_data should be merged with existing input_data, not replace it
|
||||
expected_input_data = {"test_input": "test_value", "step": "processing"}
|
||||
assert cached_exec.input_data == expected_input_data
|
||||
|
||||
def test_update_node_status_execution_not_found_raises_error(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that updating non-existent execution raises error instead of creating it."""
|
||||
non_existent_id = "does-not-exist"
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError, match="Execution does-not-exist not found in cache"
|
||||
):
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=non_existent_id, status=ExecutionStatus.COMPLETED
|
||||
)
|
||||
|
||||
async def test_upsert_execution_output_writes_to_cache_immediately(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that output updates are immediately visible in cache."""
|
||||
# First create an execution
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
output_name = "result"
|
||||
output_data = {"answer": 42, "confidence": 0.95}
|
||||
|
||||
# Update to RUNNING status first
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "test"},
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id, output_name=output_name, output_data=output_data
|
||||
)
|
||||
# Check output through the node execution
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert output_name in cached_exec.output_data
|
||||
assert cached_exec.output_data[output_name] == [output_data]
|
||||
|
||||
async def test_get_node_execution_reads_from_cache(self, execution_client):
|
||||
"""Test that get_node_execution returns cached data immediately."""
|
||||
# First create an execution to work with
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Then update its status
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
execution_data={"result": "success"},
|
||||
)
|
||||
|
||||
result = execution_client.get_node_execution(node_exec_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == ExecutionStatus.COMPLETED
|
||||
# execution_data gets merged with existing input_data
|
||||
expected_input_data = {"test_input": "test_value", "result": "success"}
|
||||
assert result.input_data == expected_input_data
|
||||
|
||||
async def test_get_latest_node_execution_reads_from_cache(self, execution_client):
|
||||
"""Test that get_latest_node_execution returns cached data."""
|
||||
node_id = "node-1"
|
||||
|
||||
# First create an execution for this node
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Then update its status to make it non-INCOMPLETE (so it's returned by get_latest)
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"from": "cache"},
|
||||
)
|
||||
|
||||
result = execution_client.get_latest_node_execution(node_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == ExecutionStatus.RUNNING
|
||||
# execution_data gets merged with existing input_data
|
||||
expected_input_data = {"test_input": "test_value", "from": "cache"}
|
||||
assert result.input_data == expected_input_data
|
||||
|
||||
async def test_get_node_executions_sync_filters_correctly(self, execution_client):
|
||||
# Create executions with different statuses
|
||||
executions = [
|
||||
(ExecutionStatus.RUNNING, "block-a"),
|
||||
(ExecutionStatus.COMPLETED, "block-a"),
|
||||
(ExecutionStatus.FAILED, "block-b"),
|
||||
(ExecutionStatus.RUNNING, "block-b"),
|
||||
]
|
||||
|
||||
exec_ids = []
|
||||
for i, (status, block_id) in enumerate(executions):
|
||||
# First create the execution
|
||||
exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id=f"node-{i}",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id=block_id,
|
||||
)
|
||||
exec_ids.append(exec_id)
|
||||
|
||||
# Then update its status and metadata
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=exec_id, status=status, execution_data={"block": block_id}
|
||||
)
|
||||
# Update cached execution with graph_exec_id and block_id for filtering
|
||||
# Note: In real implementation, these would be set during creation
|
||||
# For test purposes, we'll skip this manual update since the filtering
|
||||
# logic should work with the data as created
|
||||
|
||||
# Test status filtering
|
||||
running_execs = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.RUNNING]
|
||||
)
|
||||
assert len(running_execs) == 2
|
||||
assert all(e.status == ExecutionStatus.RUNNING for e in running_execs)
|
||||
|
||||
# Test block_id filtering
|
||||
block_a_execs = execution_client.get_node_executions(block_ids=["block-a"])
|
||||
assert len(block_a_execs) == 2
|
||||
assert all(e.block_id == "block-a" for e in block_a_execs)
|
||||
|
||||
# Test combined filtering
|
||||
running_block_b = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.RUNNING], block_ids=["block-b"]
|
||||
)
|
||||
assert len(running_block_b) == 1
|
||||
assert running_block_b[0].status == ExecutionStatus.RUNNING
|
||||
assert running_block_b[0].block_id == "block-b"
|
||||
|
||||
async def test_write_then_read_consistency(self, execution_client):
|
||||
"""Test critical race condition scenario: immediate read after write."""
|
||||
# First create an execution to work with
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="consistency-test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Write status
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"step": 1},
|
||||
)
|
||||
|
||||
# Write output
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="intermediate",
|
||||
output_data={"progress": 50},
|
||||
)
|
||||
|
||||
# Update status again
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
execution_data={"step": 2},
|
||||
)
|
||||
|
||||
# All changes should be immediately visible
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert cached_exec.status == ExecutionStatus.COMPLETED
|
||||
# execution_data gets merged with existing input_data - step 2 overwrites step 1
|
||||
expected_input_data = {"test_input": "test_value", "step": 2}
|
||||
assert cached_exec.input_data == expected_input_data
|
||||
|
||||
# Output should be visible in execution record
|
||||
assert cached_exec.output_data["intermediate"] == [{"progress": 50}]
|
||||
|
||||
async def test_concurrent_operations_are_thread_safe(self, execution_client):
|
||||
"""Test that concurrent operations don't corrupt cache."""
|
||||
num_threads = 3 # Reduced for simpler test
|
||||
operations_per_thread = 5 # Reduced for simpler test
|
||||
|
||||
# Create all executions upfront
|
||||
created_exec_ids = []
|
||||
for thread_id in range(num_threads):
|
||||
for i in range(operations_per_thread):
|
||||
exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id=f"node-{thread_id}-{i}",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id=f"block-{thread_id}-{i}",
|
||||
)
|
||||
created_exec_ids.append((exec_id, thread_id, i))
|
||||
|
||||
def worker(thread_data):
|
||||
"""Perform multiple operations from a thread."""
|
||||
thread_id, ops = thread_data
|
||||
for i, (exec_id, _, _) in enumerate(ops):
|
||||
# Status updates
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"thread": thread_id, "op": i},
|
||||
)
|
||||
|
||||
# Output updates (use just one exec_id per thread for outputs)
|
||||
if i == 0: # Only add outputs to first execution of each thread
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=exec_id,
|
||||
output_name=f"output_{i}",
|
||||
output_data={"thread": thread_id, "value": i},
|
||||
)
|
||||
|
||||
# Organize executions by thread
|
||||
thread_data = []
|
||||
for tid in range(num_threads):
|
||||
thread_ops = [
|
||||
exec_data for exec_data in created_exec_ids if exec_data[1] == tid
|
||||
]
|
||||
thread_data.append((tid, thread_ops))
|
||||
|
||||
# Start multiple threads
|
||||
threads = []
|
||||
for data in thread_data:
|
||||
thread = threading.Thread(target=worker, args=(data,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for completion
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Verify data integrity
|
||||
expected_executions = num_threads * operations_per_thread
|
||||
all_executions = execution_client.get_node_executions()
|
||||
assert len(all_executions) == expected_executions
|
||||
|
||||
# Verify outputs - only first execution of each thread should have outputs
|
||||
output_count = 0
|
||||
for execution in all_executions:
|
||||
if execution.output_data:
|
||||
output_count += 1
|
||||
assert output_count == num_threads # One output per thread
|
||||
|
||||
async def test_sync_and_async_versions_consistent(self, execution_client):
|
||||
"""Test that sync and async versions of output operations behave the same."""
|
||||
# First create the execution
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="sync-async-test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "test"},
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="sync_result",
|
||||
output_data={"method": "sync"},
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="async_result",
|
||||
output_data={"method": "async"},
|
||||
)
|
||||
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert "sync_result" in cached_exec.output_data
|
||||
assert "async_result" in cached_exec.output_data
|
||||
assert cached_exec.output_data["sync_result"] == [{"method": "sync"}]
|
||||
assert cached_exec.output_data["async_result"] == [{"method": "async"}]
|
||||
|
||||
async def test_finalize_execution_completes_and_clears_cache(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that finalize_execution waits for background tasks and clears cache."""
|
||||
# First create the execution
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="pending-test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Trigger some background operations
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id, status=ExecutionStatus.RUNNING
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id, output_name="test", output_data={"value": 1}
|
||||
)
|
||||
|
||||
# Wait for background tasks - may fail in test environment due to DB issues
|
||||
try:
|
||||
execution_client.finalize_execution(timeout=5.0)
|
||||
except RuntimeError as e:
|
||||
# In test environment, background DB operations may fail, but cache should still be cleared
|
||||
assert "Background persistence failed" in str(e)
|
||||
|
||||
# Cache should be cleared regardless of background task failures
|
||||
all_executions = execution_client.get_node_executions()
|
||||
assert len(all_executions) == 0 # Cache should be cleared
|
||||
|
||||
async def test_manager_usage_pattern(self, execution_client):
|
||||
# Create executions first
|
||||
node_exec_id_1, _ = execution_client.upsert_execution_input(
|
||||
node_id="node-1",
|
||||
input_name="input1",
|
||||
input_data="data1",
|
||||
block_id="block-1",
|
||||
)
|
||||
|
||||
node_exec_id_2, _ = execution_client.upsert_execution_input(
|
||||
node_id="node-2",
|
||||
input_name="input_from_node1",
|
||||
input_data="value1",
|
||||
block_id="block-2",
|
||||
)
|
||||
|
||||
# Simulate manager.py workflow
|
||||
|
||||
# 1. Start execution
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id_1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "data1"},
|
||||
)
|
||||
|
||||
# 2. Node produces output
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id_1,
|
||||
output_name="result",
|
||||
output_data={"output": "value1"},
|
||||
)
|
||||
|
||||
# 3. Complete first node
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id_1, status=ExecutionStatus.COMPLETED
|
||||
)
|
||||
|
||||
# 4. Start second node (would read output from first)
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id_2,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input_from_node1": "value1"},
|
||||
)
|
||||
|
||||
# 5. Manager queries for executions
|
||||
|
||||
all_executions = execution_client.get_node_executions()
|
||||
running_executions = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.RUNNING]
|
||||
)
|
||||
completed_executions = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.COMPLETED]
|
||||
)
|
||||
|
||||
# Verify manager can see all data immediately
|
||||
assert len(all_executions) == 2
|
||||
assert len(running_executions) == 1
|
||||
assert len(completed_executions) == 1
|
||||
|
||||
# Verify output is accessible
|
||||
exec_1 = execution_client.get_node_execution(node_exec_id_1)
|
||||
assert exec_1 is not None
|
||||
assert exec_1.output_data["result"] == [{"output": "value1"}]
|
||||
|
||||
def test_stats_handling_in_update_node_status(self, execution_client):
|
||||
"""Test that stats parameter is properly handled in update_node_status_and_publish."""
|
||||
# Create a fake execution directly in cache to avoid database issues
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from backend.data.execution import NodeExecutionResult
|
||||
|
||||
node_exec_id = "test-stats-exec-id"
|
||||
fake_execution = NodeExecutionResult(
|
||||
user_id="test-user",
|
||||
graph_id="test-graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test-graph-exec",
|
||||
node_exec_id=node_exec_id,
|
||||
node_id="stats-test-node",
|
||||
block_id="test-block",
|
||||
status=ExecutionStatus.INCOMPLETE,
|
||||
input_data={"test_input": "test_value"},
|
||||
output_data={},
|
||||
add_time=datetime.now(timezone.utc),
|
||||
queue_time=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
stats=None,
|
||||
)
|
||||
|
||||
# Add directly to cache
|
||||
execution_client._cache.add_node_execution(node_exec_id, fake_execution)
|
||||
|
||||
stats = {"token_count": 150, "processing_time": 2.5}
|
||||
|
||||
# Update status with stats
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "test"},
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
# Verify execution was updated and stats are stored
|
||||
execution = execution_client.get_node_execution(node_exec_id)
|
||||
assert execution is not None
|
||||
assert execution.status == ExecutionStatus.RUNNING
|
||||
|
||||
# Stats should be stored in proper stats field
|
||||
assert execution.stats is not None
|
||||
stats_dict = execution.stats.model_dump()
|
||||
# Only check the fields we set, ignore defaults
|
||||
assert stats_dict["token_count"] == 150
|
||||
assert stats_dict["processing_time"] == 2.5
|
||||
|
||||
# Update with additional stats
|
||||
additional_stats = {"error_count": 0}
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
stats=additional_stats,
|
||||
)
|
||||
|
||||
# Stats should be merged
|
||||
execution = execution_client.get_node_execution(node_exec_id)
|
||||
assert execution is not None
|
||||
assert execution.status == ExecutionStatus.COMPLETED
|
||||
stats_dict = execution.stats.model_dump()
|
||||
# Check the merged stats
|
||||
assert stats_dict["token_count"] == 150
|
||||
assert stats_dict["processing_time"] == 2.5
|
||||
assert stats_dict["error_count"] == 0
|
||||
|
||||
async def test_upsert_execution_input_scenarios(self, execution_client):
|
||||
"""Test different scenarios of upsert_execution_input - create vs update."""
|
||||
node_id = "test-node"
|
||||
graph_exec_id = (
|
||||
"test_graph_exec_id" # Must match the ExecutionDataClient's scope
|
||||
)
|
||||
|
||||
# Scenario 1: Create new execution when none exists
|
||||
exec_id_1, input_data_1 = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="first_input",
|
||||
input_data="value1",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Should create new execution
|
||||
execution = execution_client.get_node_execution(exec_id_1)
|
||||
assert execution is not None
|
||||
assert execution.status == ExecutionStatus.INCOMPLETE
|
||||
assert execution.node_id == node_id
|
||||
assert execution.graph_exec_id == graph_exec_id
|
||||
assert input_data_1 == {"first_input": "value1"}
|
||||
|
||||
# Scenario 2: Add input to existing INCOMPLETE execution
|
||||
exec_id_2, input_data_2 = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="second_input",
|
||||
input_data="value2",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Should use same execution
|
||||
assert exec_id_2 == exec_id_1
|
||||
assert input_data_2 == {"first_input": "value1", "second_input": "value2"}
|
||||
|
||||
# Verify execution has both inputs
|
||||
execution = execution_client.get_node_execution(exec_id_1)
|
||||
assert execution is not None
|
||||
assert execution.input_data == {
|
||||
"first_input": "value1",
|
||||
"second_input": "value2",
|
||||
}
|
||||
|
||||
# Scenario 3: Create new execution when existing is not INCOMPLETE
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=exec_id_1, status=ExecutionStatus.RUNNING
|
||||
)
|
||||
|
||||
exec_id_3, input_data_3 = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="third_input",
|
||||
input_data="value3",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Should create new execution
|
||||
assert exec_id_3 != exec_id_1
|
||||
execution_3 = execution_client.get_node_execution(exec_id_3)
|
||||
assert execution_3 is not None
|
||||
assert input_data_3 == {"third_input": "value3"}
|
||||
|
||||
# Verify we now have 2 executions
|
||||
all_executions = execution_client.get_node_executions()
|
||||
assert len(all_executions) == 2
|
||||
|
||||
def test_graph_stats_operations(self, execution_client):
|
||||
"""Test graph-level stats and start time operations."""
|
||||
|
||||
# Test update_graph_stats_and_publish
|
||||
from backend.data.model import GraphExecutionStats
|
||||
|
||||
stats = GraphExecutionStats(
|
||||
walltime=10.5, cputime=8.2, node_count=5, node_error_count=1
|
||||
)
|
||||
|
||||
execution_client.update_graph_stats_and_publish(
|
||||
status=ExecutionStatus.RUNNING, stats=stats
|
||||
)
|
||||
|
||||
# Verify stats are stored in cache
|
||||
cached_stats = execution_client._cache._graph_stats
|
||||
assert cached_stats.walltime == 10.5
|
||||
|
||||
execution_client.update_graph_start_time_and_publish()
|
||||
cached_stats = execution_client._cache._graph_stats
|
||||
assert cached_stats.walltime == 10.5
|
||||
|
||||
def test_public_methods_accessible(self, execution_client):
|
||||
"""Test that public methods are accessible."""
|
||||
assert hasattr(execution_client._cache, "update_node_execution_status")
|
||||
assert hasattr(execution_client._cache, "upsert_execution_output")
|
||||
assert hasattr(execution_client._cache, "add_node_execution")
|
||||
assert hasattr(execution_client._cache, "find_incomplete_execution_for_input")
|
||||
assert hasattr(execution_client._cache, "update_execution_input")
|
||||
assert hasattr(execution_client, "upsert_execution_input")
|
||||
assert hasattr(execution_client, "update_node_status_and_publish")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
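
The `execution_client` fixture these tests rely on is not shown in this excerpt. A minimal sketch of what such a fixture could look like, assuming the `ExecutionDataClient(executor, graph_exec_id, exec_meta)` constructor used later in this changeset and a mocked metadata object; the names and setup here are illustrative only, not the project's actual fixture:

# Hypothetical fixture sketch (not part of this diff): wires up an
# ExecutionDataClient the way on_graph_execution does, but against a mocked
# GraphExecutionMeta so no database is needed.
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import MagicMock

import pytest

from backend.executor.execution_data import ExecutionDataClient


@pytest.fixture
def execution_client():
    executor = ThreadPoolExecutor(max_workers=1)  # mirrors execution_data_executor
    exec_meta = MagicMock()  # stand-in for the graph execution metadata
    client = ExecutionDataClient(executor, "test_graph_exec_id", exec_meta)
    try:
        yield client
    finally:
        executor.shutdown(wait=False)
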
@@ -5,38 +5,15 @@ import threading
import time
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
from typing import Any, Optional, TypeVar, cast

from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from pydantic import JsonValue
from redis.asyncio.lock import Lock as RedisLock

from backend.blocks.io import AgentOutputBlock
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
    AgentRunData,
    LowBalanceData,
    NotificationEventModel,
    NotificationType,
    ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.activity_status_generator import (
    generate_activity_status_for_execution,
)
from backend.executor.utils import LogMetadata
from backend.notifications.notifications import queue_notification
from backend.util.exceptions import InsufficientBalanceError, ModerationError

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerClient, DatabaseManagerAsyncClient

from prometheus_client import Gauge, start_http_server
from pydantic import JsonValue

from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis_client as redis
from backend.blocks.io import AgentOutputBlock
from backend.data.block import (
    BlockData,
    BlockInput,
@@ -48,19 +25,28 @@ from backend.data.credit import UsageTransactionMetadata
from backend.data.execution import (
    ExecutionQueue,
    ExecutionStatus,
    GraphExecution,
    GraphExecutionEntry,
    NodeExecutionEntry,
    NodeExecutionResult,
    UserContext,
)
from backend.data.graph import Link, Node
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
    AgentRunData,
    LowBalanceData,
    NotificationEventModel,
    NotificationType,
    ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.execution_data import ExecutionDataClient
from backend.executor.utils import (
    GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
    GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
    GRAPH_EXECUTION_QUEUE_NAME,
    CancelExecutionEvent,
    ExecutionOutputEntry,
    LogMetadata,
    NodeExecutionProgress,
    block_usage_cost,
    create_execution_queue_config,
@@ -69,21 +55,17 @@ from backend.executor.utils import (
    validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.server.v2.AutoMod.manager import automod_manager
from backend.util import json
from backend.util.clients import (
    get_async_execution_event_bus,
    get_database_manager_async_client,
    get_database_manager_client,
    get_execution_event_bus,
    get_notification_manager_client,
)
from backend.util.clients import get_notification_manager_client
from backend.util.decorator import (
    async_error_logged,
    async_time_measured,
    error_logged,
    time_measured,
)
from backend.util.exceptions import InsufficientBalanceError, ModerationError
from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.metrics import DiscordChannel
@@ -138,7 +120,6 @@ async def execute_node(
    persist the execution result, and return the subsequent node to be executed.

    Args:
        db_client: The client to send execution updates to the server.
        creds_manager: The manager to acquire and release credentials.
        data: The execution data for executing the current node.
        execution_stats: The execution statistics to be updated.
@@ -235,7 +216,7 @@ async def execute_node(


async def _enqueue_next_nodes(
    db_client: "DatabaseManagerAsyncClient",
    execution_data_client: ExecutionDataClient,
    node: Node,
    output: BlockData,
    user_id: str,
@@ -248,8 +229,7 @@ async def _enqueue_next_nodes(
    async def add_enqueued_execution(
        node_exec_id: str, node_id: str, block_id: str, data: BlockInput
    ) -> NodeExecutionEntry:
        await async_update_node_execution_status(
            db_client=db_client,
        execution_data_client.update_node_status_and_publish(
            exec_id=node_exec_id,
            status=ExecutionStatus.QUEUED,
            execution_data=data,
@@ -282,21 +262,22 @@ async def _enqueue_next_nodes(
        next_data = parse_execution_output(output, next_output_name)
        if next_data is None and output_name != next_output_name:
            return enqueued_executions
        next_node = await db_client.get_node(next_node_id)
        next_node = await execution_data_client.get_node(next_node_id)

        # Multiple nodes can register the same next node, so this section must be atomic
        # to avoid the same execution being enqueued multiple times,
        # or the same input being consumed multiple times.
        async with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
        with execution_data_client.graph_lock:
            # Add output data to the earliest incomplete execution, or create a new one.
            next_node_exec_id, next_node_input = await db_client.upsert_execution_input(
                node_id=next_node_id,
                graph_exec_id=graph_exec_id,
                input_name=next_input_name,
                input_data=next_data,
            next_node_exec_id, next_node_input = (
                execution_data_client.upsert_execution_input(
                    node_id=next_node_id,
                    input_name=next_input_name,
                    input_data=next_data,
                    block_id=next_node.block_id,
                )
            )
            await async_update_node_execution_status(
                db_client=db_client,
            execution_data_client.update_node_status_and_publish(
                exec_id=next_node_exec_id,
                status=ExecutionStatus.INCOMPLETE,
            )
@@ -308,8 +289,8 @@ async def _enqueue_next_nodes(
                if link.is_static and link.sink_name not in next_node_input
            }
            if static_link_names and (
                latest_execution := await db_client.get_latest_node_execution(
                    next_node_id, graph_exec_id
                latest_execution := execution_data_client.get_latest_node_execution(
                    next_node_id
                )
            ):
                for name in static_link_names:
@@ -348,9 +329,8 @@ async def _enqueue_next_nodes(

            # If the link is static, there could be incomplete executions waiting for it.
            # Load and complete their missing input data, and try to re-enqueue them.
            for iexec in await db_client.get_node_executions(
            for iexec in execution_data_client.get_node_executions(
                node_id=next_node_id,
                graph_exec_id=graph_exec_id,
                statuses=[ExecutionStatus.INCOMPLETE],
            ):
                idata = iexec.input_data
@@ -414,6 +394,9 @@ class ExecutionProcessor:
    9. Node executor enqueues the next executed nodes to the node execution queue.
    """

    # Current execution data client (scoped to current graph execution)
    execution_data: ExecutionDataClient

    @async_error_logged(swallow=True)
    async def on_node_execution(
        self,
@@ -431,8 +414,7 @@ class ExecutionProcessor:
            node_id=node_exec.node_id,
            block_name="-",
        )
        db_client = get_db_async_client()
        node = await db_client.get_node(node_exec.node_id)
        node = await self.execution_data.get_node(node_exec.node_id)
        execution_stats = NodeExecutionStats()

        timing_info, status = await self._on_node_execution(
@@ -440,7 +422,6 @@ class ExecutionProcessor:
            node_exec=node_exec,
            node_exec_progress=node_exec_progress,
            stats=execution_stats,
            db_client=db_client,
            log_metadata=log_metadata,
            nodes_input_masks=nodes_input_masks,
        )
@@ -464,15 +445,12 @@ class ExecutionProcessor:
        if node_error and not isinstance(node_error, str):
            node_stats["error"] = str(node_error) or node_stats.__class__.__name__

        await async_update_node_execution_status(
            db_client=db_client,
        self.execution_data.update_node_status_and_publish(
            exec_id=node_exec.node_exec_id,
            status=status,
            stats=node_stats,
        )
        await async_update_graph_execution_state(
            db_client=db_client,
            graph_exec_id=node_exec.graph_exec_id,
        self.execution_data.update_graph_stats_and_publish(
            stats=graph_stats,
        )

@@ -485,22 +463,17 @@ class ExecutionProcessor:
        node_exec: NodeExecutionEntry,
        node_exec_progress: NodeExecutionProgress,
        stats: NodeExecutionStats,
        db_client: "DatabaseManagerAsyncClient",
        log_metadata: LogMetadata,
        nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
    ) -> ExecutionStatus:
        status = ExecutionStatus.RUNNING

        async def persist_output(output_name: str, output_data: Any) -> None:
            await db_client.upsert_execution_output(
            self.execution_data.upsert_execution_output(
                node_exec_id=node_exec.node_exec_id,
                output_name=output_name,
                output_data=output_data,
            )
            if exec_update := await db_client.get_node_execution(
                node_exec.node_exec_id
            ):
                await send_async_execution_update(exec_update)

            node_exec_progress.add_output(
                ExecutionOutputEntry(
@@ -512,8 +485,7 @@ class ExecutionProcessor:

        try:
            log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
            await async_update_node_execution_status(
                db_client=db_client,
            self.execution_data.update_node_status_and_publish(
                exec_id=node_exec.node_exec_id,
                status=ExecutionStatus.RUNNING,
            )
@@ -574,6 +546,8 @@ class ExecutionProcessor:
        self.node_evaluation_thread = threading.Thread(
            target=self.node_evaluation_loop.run_forever, daemon=True
        )
        # single thread executor
        self.execution_data_executor = ThreadPoolExecutor(max_workers=1)
        self.node_execution_thread.start()
        self.node_evaluation_thread.start()
        logger.info(f"[GraphExecutor] {self.tid} started")
@@ -593,9 +567,13 @@ class ExecutionProcessor:
            node_eid="*",
            block_name="-",
        )
        db_client = get_db_client()

        exec_meta = db_client.get_graph_execution_meta(
        # Get graph execution metadata first via sync client
        from backend.util.clients import get_database_manager_client

        db_client_sync = get_database_manager_client()

        exec_meta = db_client_sync.get_graph_execution_meta(
            user_id=graph_exec.user_id,
            execution_id=graph_exec.graph_exec_id,
        )
@@ -605,12 +583,15 @@ class ExecutionProcessor:
            )
            return

        # Create scoped ExecutionDataClient for this graph execution with metadata
        self.execution_data = ExecutionDataClient(
            self.execution_data_executor, graph_exec.graph_exec_id, exec_meta
        )

        if exec_meta.status == ExecutionStatus.QUEUED:
            log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}")
            exec_meta.status = ExecutionStatus.RUNNING
            send_execution_update(
                db_client.update_graph_execution_start_time(graph_exec.graph_exec_id)
            )
            self.execution_data.update_graph_start_time_and_publish()
        elif exec_meta.status == ExecutionStatus.RUNNING:
            log_metadata.info(
                f"⚙️ Graph execution #{graph_exec.graph_exec_id} is already running, continuing where it left off."
@@ -620,9 +601,7 @@ class ExecutionProcessor:
            log_metadata.info(
                f"⚙️ Graph execution #{graph_exec.graph_exec_id} was disturbed, continuing where it left off."
            )
            update_graph_execution_state(
                db_client=db_client,
                graph_exec_id=graph_exec.graph_exec_id,
            self.execution_data.update_graph_stats_and_publish(
                status=ExecutionStatus.RUNNING,
            )
        else:
@@ -653,12 +632,10 @@ class ExecutionProcessor:

            # Activity status handling
            activity_status = asyncio.run_coroutine_threadsafe(
                generate_activity_status_for_execution(
                    graph_exec_id=graph_exec.graph_exec_id,
                self.execution_data.generate_activity_status(
                    graph_id=graph_exec.graph_id,
                    graph_version=graph_exec.graph_version,
                    execution_stats=exec_stats,
                    db_client=get_db_async_client(),
                    user_id=graph_exec.user_id,
                    execution_status=status,
                ),
@@ -673,15 +650,14 @@ class ExecutionProcessor:
            )

            # Communication handling
            self._handle_agent_run_notif(db_client, graph_exec, exec_stats)
            self._handle_agent_run_notif(graph_exec, exec_stats)

        finally:
            update_graph_execution_state(
                db_client=db_client,
                graph_exec_id=graph_exec.graph_exec_id,
            self.execution_data.update_graph_stats_and_publish(
                status=exec_meta.status,
                stats=exec_stats,
            )
            self.execution_data.finalize_execution()

    def _charge_usage(
        self,
@@ -690,7 +666,6 @@ class ExecutionProcessor:
    ) -> tuple[int, int]:
        total_cost = 0
        remaining_balance = 0
        db_client = get_db_client()
        block = get_block(node_exec.block_id)
        if not block:
            logger.error(f"Block {node_exec.block_id} not found.")
@@ -700,7 +675,7 @@ class ExecutionProcessor:
            block=block, input_data=node_exec.inputs
        )
        if cost > 0:
            remaining_balance = db_client.spend_credits(
            remaining_balance = self.execution_data.spend_credits(
                user_id=node_exec.user_id,
                cost=cost,
                metadata=UsageTransactionMetadata(
@@ -718,7 +693,7 @@ class ExecutionProcessor:

        cost, usage_count = execution_usage_cost(execution_count)
        if cost > 0:
            remaining_balance = db_client.spend_credits(
            remaining_balance = self.execution_data.spend_credits(
                user_id=node_exec.user_id,
                cost=cost,
                metadata=UsageTransactionMetadata(
@@ -751,7 +726,6 @@ class ExecutionProcessor:
        """
        execution_status: ExecutionStatus = ExecutionStatus.RUNNING
        error: Exception | None = None
        db_client = get_db_client()
        execution_stats_lock = threading.Lock()

        # State holders ----------------------------------------------------
@@ -762,7 +736,7 @@ class ExecutionProcessor:
        execution_queue = ExecutionQueue[NodeExecutionEntry]()

        try:
            if db_client.get_credits(graph_exec.user_id) <= 0:
            if self.execution_data.get_credits(graph_exec.user_id) <= 0:
                raise InsufficientBalanceError(
                    user_id=graph_exec.user_id,
                    message="You have no credits left to run an agent.",
@@ -774,7 +748,7 @@ class ExecutionProcessor:
            try:
                if moderation_error := asyncio.run_coroutine_threadsafe(
                    automod_manager.moderate_graph_execution_inputs(
                        db_client=get_db_async_client(),
                        db_client=self.execution_data.db_client_async,
                        graph_exec=graph_exec,
                    ),
                    self.node_evaluation_loop,
@@ -789,16 +763,34 @@ class ExecutionProcessor:
            # ------------------------------------------------------------
            # Pre‑populate queue ---------------------------------------
            # ------------------------------------------------------------
            for node_exec in db_client.get_node_executions(
                graph_exec.graph_exec_id,

            queued_executions = self.execution_data.get_node_executions(
                statuses=[
                    ExecutionStatus.RUNNING,
                    ExecutionStatus.QUEUED,
                    ExecutionStatus.TERMINATED,
                ],
            ):
                node_entry = node_exec.to_node_execution_entry(graph_exec.user_context)
                execution_queue.add(node_entry)
            )
            log_metadata.info(
                f"Pre-populating queue with {len(queued_executions)} executions from cache"
            )

            for i, node_exec in enumerate(queued_executions):
                log_metadata.info(
                    f" [{i}] {node_exec.node_exec_id}: status={node_exec.status}, node={node_exec.node_id}"
                )
                try:
                    node_entry = node_exec.to_node_execution_entry(
                        graph_exec.user_context
                    )
                    execution_queue.add(node_entry)
                    log_metadata.info(" Added to execution queue successfully")
                except Exception as e:
                    log_metadata.error(f" Failed to add to execution queue: {e}")

            log_metadata.info(
                f"Execution queue populated with {len(queued_executions)} executions"
            )

            # ------------------------------------------------------------
            # Main dispatch / polling loop -----------------------------
@@ -818,13 +810,14 @@ class ExecutionProcessor:
                try:
                    cost, remaining_balance = self._charge_usage(
                        node_exec=queued_node_exec,
                        execution_count=increment_execution_count(graph_exec.user_id),
                        execution_count=self.execution_data.increment_execution_count(
                            graph_exec.user_id
                        ),
                    )
                    with execution_stats_lock:
                        execution_stats.cost += cost
                    # Check if we crossed the low balance threshold
                    self._handle_low_balance(
                        db_client=db_client,
                        user_id=graph_exec.user_id,
                        current_balance=remaining_balance,
                        transaction_cost=cost,
@@ -832,19 +825,17 @@ class ExecutionProcessor:
                except InsufficientBalanceError as balance_error:
                    error = balance_error  # Set error to trigger FAILED status
                    node_exec_id = queued_node_exec.node_exec_id
                    db_client.upsert_execution_output(
                    self.execution_data.upsert_execution_output(
                        node_exec_id=node_exec_id,
                        output_name="error",
                        output_data=str(error),
                    )
                    update_node_execution_status(
                        db_client=db_client,
                    self.execution_data.update_node_status_and_publish(
                        exec_id=node_exec_id,
                        status=ExecutionStatus.FAILED,
                    )

                    self._handle_insufficient_funds_notif(
                        db_client,
                        graph_exec.user_id,
                        graph_exec.graph_id,
                        error,
@@ -931,12 +922,13 @@ class ExecutionProcessor:
                time.sleep(0.1)

            # loop done --------------------------------------------------
            # Background task finalization moved to finally block

            # Output moderation
            try:
                if moderation_error := asyncio.run_coroutine_threadsafe(
                    automod_manager.moderate_graph_execution_outputs(
                        db_client=get_db_async_client(),
                        db_client=self.execution_data.db_client_async,
                        graph_exec_id=graph_exec.graph_exec_id,
                        user_id=graph_exec.user_id,
                        graph_id=graph_exec.graph_id,
@@ -990,7 +982,6 @@ class ExecutionProcessor:
                error=error,
                graph_exec_id=graph_exec.graph_exec_id,
                log_metadata=log_metadata,
                db_client=db_client,
            )

    @error_logged(swallow=True)
@@ -1003,7 +994,6 @@ class ExecutionProcessor:
        error: Exception | None,
        graph_exec_id: str,
        log_metadata: LogMetadata,
        db_client: "DatabaseManagerClient",
    ) -> None:
        """
        Clean up running node executions and evaluations when graph execution ends.
@@ -1037,8 +1027,7 @@ class ExecutionProcessor:
        )

        while queued_execution := execution_queue.get_or_none():
            update_node_execution_status(
                db_client=db_client,
            self.execution_data.update_node_status_and_publish(
                exec_id=queued_execution.node_exec_id,
                status=execution_status,
                stats={"error": str(error)} if error else None,
@@ -1066,12 +1055,10 @@ class ExecutionProcessor:
            nodes_input_masks: Optional map of node input overrides
            execution_queue: Queue to add next executions to
        """
        db_client = get_db_async_client()

        log_metadata.debug(f"Enqueue nodes for {node_id}: {output}")

        for next_execution in await _enqueue_next_nodes(
            db_client=db_client,
            execution_data_client=self.execution_data,
            node=output.node,
            output=output.data,
            user_id=graph_exec.user_id,
@@ -1085,15 +1072,13 @@ class ExecutionProcessor:

    def _handle_agent_run_notif(
        self,
        db_client: "DatabaseManagerClient",
        graph_exec: GraphExecutionEntry,
        exec_stats: GraphExecutionStats,
    ):
        metadata = db_client.get_graph_metadata(
        metadata = self.execution_data.get_graph_metadata(
            graph_exec.graph_id, graph_exec.graph_version
        )
        outputs = db_client.get_node_executions(
            graph_exec.graph_exec_id,
        outputs = self.execution_data.get_node_executions(
            block_ids=[AgentOutputBlock().id],
        )

@@ -1122,13 +1107,12 @@ class ExecutionProcessor:

    def _handle_insufficient_funds_notif(
        self,
        db_client: "DatabaseManagerClient",
        user_id: str,
        graph_id: str,
        e: InsufficientBalanceError,
    ):
        shortfall = abs(e.amount) - e.balance
        metadata = db_client.get_graph_metadata(graph_id)
        metadata = self.execution_data.get_graph_metadata(graph_id)
        base_url = (
            settings.config.frontend_base_url or settings.config.platform_base_url
        )
@@ -1147,7 +1131,7 @@ class ExecutionProcessor:
        )

        try:
            user_email = db_client.get_user_email_by_id(user_id)
            user_email = self.execution_data.get_user_email_by_id(user_id)

            alert_message = (
                f"❌ **Insufficient Funds Alert**\n"
@@ -1169,7 +1153,6 @@ class ExecutionProcessor:

    def _handle_low_balance(
        self,
        db_client: "DatabaseManagerClient",
        user_id: str,
        current_balance: int,
        transaction_cost: int,
@@ -1198,7 +1181,7 @@ class ExecutionProcessor:
        )

        try:
            user_email = db_client.get_user_email_by_id(user_id)
            user_email = self.execution_data.get_user_email_by_id(user_id)
            alert_message = (
                f"⚠️ **Low Balance Alert**\n"
                f"User: {user_email or user_id}\n"
@@ -1576,117 +1559,3 @@ class ExecutionManager(AppProcess):
        )

        logger.info(f"{prefix} ✅ Finished GraphExec cleanup")


# ------- UTILITIES ------- #


def get_db_client() -> "DatabaseManagerClient":
    return get_database_manager_client()


def get_db_async_client() -> "DatabaseManagerAsyncClient":
    return get_database_manager_async_client()


@func_retry
async def send_async_execution_update(
    entry: GraphExecution | NodeExecutionResult | None,
) -> None:
    if entry is None:
        return
    await get_async_execution_event_bus().publish(entry)


@func_retry
def send_execution_update(entry: GraphExecution | NodeExecutionResult | None):
    if entry is None:
        return
    return get_execution_event_bus().publish(entry)


async def async_update_node_execution_status(
    db_client: "DatabaseManagerAsyncClient",
    exec_id: str,
    status: ExecutionStatus,
    execution_data: BlockInput | None = None,
    stats: dict[str, Any] | None = None,
) -> NodeExecutionResult:
    """Sets status and fetches+broadcasts the latest state of the node execution"""
    exec_update = await db_client.update_node_execution_status(
        exec_id, status, execution_data, stats
    )
    await send_async_execution_update(exec_update)
    return exec_update


def update_node_execution_status(
    db_client: "DatabaseManagerClient",
    exec_id: str,
    status: ExecutionStatus,
    execution_data: BlockInput | None = None,
    stats: dict[str, Any] | None = None,
) -> NodeExecutionResult:
    """Sets status and fetches+broadcasts the latest state of the node execution"""
    exec_update = db_client.update_node_execution_status(
        exec_id, status, execution_data, stats
    )
    send_execution_update(exec_update)
    return exec_update


async def async_update_graph_execution_state(
    db_client: "DatabaseManagerAsyncClient",
    graph_exec_id: str,
    status: ExecutionStatus | None = None,
    stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
    """Sets status and fetches+broadcasts the latest state of the graph execution"""
    graph_update = await db_client.update_graph_execution_stats(
        graph_exec_id, status, stats
    )
    if graph_update:
        await send_async_execution_update(graph_update)
    else:
        logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
    return graph_update


def update_graph_execution_state(
    db_client: "DatabaseManagerClient",
    graph_exec_id: str,
    status: ExecutionStatus | None = None,
    stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
    """Sets status and fetches+broadcasts the latest state of the graph execution"""
    graph_update = db_client.update_graph_execution_stats(graph_exec_id, status, stats)
    if graph_update:
        send_execution_update(graph_update)
    else:
        logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
    return graph_update


@asynccontextmanager
async def synchronized(key: str, timeout: int = 60):
    r = await redis.get_redis_async()
    lock: RedisLock = r.lock(f"lock:{key}", timeout=timeout)
    try:
        await lock.acquire()
        yield
    finally:
        if await lock.locked() and await lock.owned():
            await lock.release()
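
For context: `synchronized` is a Redis-backed cross-process lock, and the `_enqueue_next_nodes` hunk earlier swaps its use (`async with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}")`) for the in-process `execution_data_client.graph_lock`, presumably because all writes for a graph execution now flow through one scoped client. A minimal self-contained sketch of the same Redis-lock pattern, assuming redis-py >= 5 and a local Redis URL (both assumptions, not part of this changeset):

# Hypothetical standalone sketch of the locking pattern `synchronized` implements,
# using redis.asyncio directly; the connection URL and key names are illustrative only.
import asyncio
from contextlib import asynccontextmanager

import redis.asyncio as aioredis


@asynccontextmanager
async def synchronized_demo(r: aioredis.Redis, key: str, timeout: int = 60):
    lock = r.lock(f"lock:{key}", timeout=timeout)
    try:
        await lock.acquire()
        yield
    finally:
        if await lock.locked() and await lock.owned():
            await lock.release()


async def main() -> None:
    r = aioredis.Redis.from_url("redis://localhost:6379/0")  # assumed local Redis
    async with synchronized_demo(r, "upsert_input-some-node-some-graph-exec"):
        # Critical section: only one process at a time may upsert this node's input.
        await asyncio.sleep(0.1)
    await r.aclose()  # redis-py >= 5


if __name__ == "__main__":
    asyncio.run(main())
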

def increment_execution_count(user_id: str) -> int:
    """
    Increment the execution count for a given user;
    this is used to charge the user for the execution cost.
    """
    r = redis.get_redis()
    k = f"uec:{user_id}"  # User Execution Count global key
    counter = cast(int, r.incr(k))
    if counter == 1:
        r.expire(k, settings.config.execution_counter_expiration_time)
    return counter

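The new ExecutionDataClient only appears in this changeset through its call sites. A rough interface sketch inferred from those call sites, useful when reading the hunks above; the real definition lives in backend/executor/execution_data.py and its actual names, signatures, and types may differ:

# Inferred interface sketch only; argument and return types are best guesses
# from how the client is called in this diff, not the authoritative definition.
from contextlib import AbstractContextManager
from typing import Any, Optional, Protocol


class ExecutionDataClientLike(Protocol):
    graph_lock: AbstractContextManager[Any]  # in-process lock replacing `synchronized`
    db_client_async: Any  # async DB client handed to the moderation calls

    def upsert_execution_input(
        self, node_id: str, input_name: str, input_data: Any, block_id: str
    ) -> tuple[str, dict[str, Any]]: ...

    def update_node_status_and_publish(
        self,
        exec_id: str,
        status: Any,
        execution_data: Optional[dict[str, Any]] = None,
        stats: Optional[dict[str, Any]] = None,
    ) -> None: ...

    def upsert_execution_output(
        self, node_exec_id: str, output_name: str, output_data: Any
    ) -> None: ...

    def get_node_executions(
        self,
        node_id: Optional[str] = None,
        statuses: Optional[list[Any]] = None,
        block_ids: Optional[list[str]] = None,
    ) -> list[Any]: ...

    def get_node_execution(self, node_exec_id: str) -> Optional[Any]: ...

    def get_latest_node_execution(self, node_id: str) -> Optional[Any]: ...

    async def get_node(self, node_id: str) -> Any: ...

    def update_graph_stats_and_publish(self, status: Any = None, stats: Any = None) -> None: ...

    def update_graph_start_time_and_publish(self) -> None: ...

    def finalize_execution(self, timeout: float = 5.0) -> None: ...

    def spend_credits(self, user_id: str, cost: int, metadata: Any) -> int: ...

    def get_credits(self, user_id: str) -> int: ...

    def increment_execution_count(self, user_id: str) -> int: ...

    def get_graph_metadata(self, graph_id: str, version: Any = None) -> Any: ...

    def get_user_email_by_id(self, user_id: str) -> Optional[str]: ...

    # generate_activity_status and other members used in the diff are omitted here.
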
@@ -32,13 +32,17 @@ async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
        mock_settings.config.low_balance_threshold = 500  # $5 threshold
        mock_settings.config.frontend_base_url = "https://test.com"

        # Create mock database client
        mock_db_client = MagicMock()
        mock_db_client.get_user_email_by_id.return_value = "test@example.com"
        # Initialize the execution processor and mock its execution_data
        execution_processor.on_graph_executor_start()

        # Mock the execution_data attribute since it's created in on_graph_execution
        mock_execution_data = MagicMock()
        execution_processor.execution_data = mock_execution_data

        mock_execution_data.get_user_email_by_id.return_value = "test@example.com"

        # Test the low balance handler
        execution_processor._handle_low_balance(
            db_client=mock_db_client,
            user_id=user_id,
            current_balance=current_balance,
            transaction_cost=transaction_cost,
@@ -62,6 +66,19 @@ async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
        assert "$4.00" in discord_message
        assert "$6.00" in discord_message

        # Cleanup execution processor threads
        try:
            execution_processor.node_execution_loop.call_soon_threadsafe(
                execution_processor.node_execution_loop.stop
            )
            execution_processor.node_evaluation_loop.call_soon_threadsafe(
                execution_processor.node_evaluation_loop.stop
            )
            execution_processor.node_execution_thread.join(timeout=1)
            execution_processor.node_evaluation_thread.join(timeout=1)
        except Exception:
            pass  # Ignore cleanup errors


@pytest.mark.asyncio(loop_scope="session")
async def test_handle_low_balance_no_notification_when_not_crossing(
@@ -90,12 +107,17 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
        mock_get_client.return_value = mock_client
        mock_settings.config.low_balance_threshold = 500  # $5 threshold

        # Create mock database client
        mock_db_client = MagicMock()
        # Initialize the execution processor and mock its execution_data
        execution_processor.on_graph_executor_start()

        # Mock the execution_data attribute since it's created in on_graph_execution
        mock_execution_data = MagicMock()
        execution_processor.execution_data = mock_execution_data

        mock_execution_data.get_user_email_by_id.return_value = "test@example.com"

        # Test the low balance handler
        execution_processor._handle_low_balance(
            db_client=mock_db_client,
            user_id=user_id,
            current_balance=current_balance,
            transaction_cost=transaction_cost,
@@ -105,6 +127,19 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
        mock_queue_notif.assert_not_called()
        mock_client.discord_system_alert.assert_not_called()

        # Cleanup execution processor threads
        try:
            execution_processor.node_execution_loop.call_soon_threadsafe(
                execution_processor.node_execution_loop.stop
            )
            execution_processor.node_evaluation_loop.call_soon_threadsafe(
                execution_processor.node_evaluation_loop.stop
            )
            execution_processor.node_execution_thread.join(timeout=1)
            execution_processor.node_evaluation_thread.join(timeout=1)
        except Exception:
            pass  # Ignore cleanup errors


@pytest.mark.asyncio(loop_scope="session")
async def test_handle_low_balance_no_duplicate_when_already_below(
@@ -133,12 +168,17 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
        mock_get_client.return_value = mock_client
        mock_settings.config.low_balance_threshold = 500  # $5 threshold

        # Create mock database client
        mock_db_client = MagicMock()
        # Initialize the execution processor and mock its execution_data
        execution_processor.on_graph_executor_start()

        # Mock the execution_data attribute since it's created in on_graph_execution
        mock_execution_data = MagicMock()
        execution_processor.execution_data = mock_execution_data

        mock_execution_data.get_user_email_by_id.return_value = "test@example.com"

        # Test the low balance handler
        execution_processor._handle_low_balance(
            db_client=mock_db_client,
            user_id=user_id,
            current_balance=current_balance,
            transaction_cost=transaction_cost,
@@ -147,3 +187,16 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
        # Verify no notification was sent (user was already below threshold)
        mock_queue_notif.assert_not_called()
        mock_client.discord_system_alert.assert_not_called()

        # Cleanup execution processor threads
        try:
            execution_processor.node_execution_loop.call_soon_threadsafe(
                execution_processor.node_execution_loop.stop
            )
            execution_processor.node_evaluation_loop.call_soon_threadsafe(
                execution_processor.node_evaluation_loop.stop
            )
            execution_processor.node_execution_thread.join(timeout=1)
            execution_processor.node_evaluation_thread.join(timeout=1)
        except Exception:
            pass  # Ignore cleanup errors

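The three tests above pin down the threshold-crossing rule: a notification goes out only when the current transaction pushed the balance from at-or-above the threshold to below it. A minimal sketch of that check, assuming cent-denominated integers and a hypothetical helper name (the real logic sits inside _handle_low_balance and may differ in detail):

# Hypothetical helper illustrating the crossing rule the tests encode;
# not the actual implementation of _handle_low_balance.
def crossed_low_balance_threshold(
    current_balance: int, transaction_cost: int, threshold: int = 500
) -> bool:
    previous_balance = current_balance + transaction_cost
    # Notify only when the balance was at/above the threshold before this charge
    # and is below it afterwards; staying above (not crossing) or staying below
    # (already below) sends nothing.
    return previous_balance >= threshold > current_balance


assert crossed_low_balance_threshold(current_balance=400, transaction_cost=200)      # crossed the $5 line
assert not crossed_low_balance_threshold(current_balance=600, transaction_cost=100)  # still above
assert not crossed_low_balance_threshold(current_balance=300, transaction_cost=100)  # already below
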
@@ -147,8 +147,10 @@ class AutoModManager:
            return None

        # Get completed executions and collect outputs
        completed_executions = await db_client.get_node_executions(
            graph_exec_id, statuses=[ExecutionStatus.COMPLETED], include_exec_data=True
        completed_executions = await db_client.get_node_executions(  # type: ignore
            graph_exec_id=graph_exec_id,
            statuses=[ExecutionStatus.COMPLETED],
            include_exec_data=True,
        )

        if not completed_executions:
@@ -218,7 +220,7 @@ class AutoModManager:
    ):
        """Update node execution statuses for frontend display when moderation fails"""
        # Import here to avoid circular imports
        from backend.executor.manager import send_async_execution_update
        from backend.util.clients import get_async_execution_event_bus

        if moderation_type == "input":
            # For input moderation, mark queued/running/incomplete nodes as failed
@@ -232,8 +234,10 @@ class AutoModManager:
            target_statuses = [ExecutionStatus.COMPLETED]

        # Get the executions that need to be updated
        executions_to_update = await db_client.get_node_executions(
            graph_exec_id, statuses=target_statuses, include_exec_data=True
        executions_to_update = await db_client.get_node_executions(  # type: ignore
            graph_exec_id=graph_exec_id,
            statuses=target_statuses,
            include_exec_data=True,
        )

        if not executions_to_update:
@@ -276,10 +280,12 @@ class AutoModManager:
        updated_execs = await asyncio.gather(*exec_updates)

        # Send all websocket updates in parallel
        event_bus = get_async_execution_event_bus()
        await asyncio.gather(
            *[
                send_async_execution_update(updated_exec)
                event_bus.publish(updated_exec)
                for updated_exec in updated_execs
                if updated_exec is not None
            ]
        )