Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-12 08:38:09 -05:00)
Compare commits: fix/flag-r...feat/execu (5 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 4f652cb978 | |
| | 279552a2a3 | |
| | fb6ac1d6ca | |
| | 9db15bff02 | |
| | db4b94e0dc | |
@@ -315,9 +315,10 @@ class NodeExecutionResult(BaseModel):
     input_data: BlockInput
     output_data: CompletedBlockOutput
     add_time: datetime
-    queue_time: datetime | None
-    start_time: datetime | None
-    end_time: datetime | None
+    queue_time: datetime | None = None
+    start_time: datetime | None = None
+    end_time: datetime | None = None
+    stats: NodeExecutionStats | None = None
 
     @staticmethod
     def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):

@@ -369,6 +370,7 @@ class NodeExecutionResult(BaseModel):
             queue_time=_node_exec.queuedTime,
             start_time=_node_exec.startedTime,
             end_time=_node_exec.endedTime,
+            stats=stats,
         )
 
     def to_node_execution_entry(
@@ -654,6 +656,42 @@ async def upsert_execution_input(
     )
 
 
+async def create_node_execution(
+    node_exec_id: str,
+    node_id: str,
+    graph_exec_id: str,
+    input_name: str,
+    input_data: Any,
+) -> None:
+    """Create a new node execution with the first input."""
+    json_input_data = SafeJson(input_data)
+    await AgentNodeExecution.prisma().create(
+        data=AgentNodeExecutionCreateInput(
+            id=node_exec_id,
+            agentNodeId=node_id,
+            agentGraphExecutionId=graph_exec_id,
+            executionStatus=ExecutionStatus.INCOMPLETE,
+            Input={"create": {"name": input_name, "data": json_input_data}},
+        )
+    )
+
+
+async def add_input_to_node_execution(
+    node_exec_id: str,
+    input_name: str,
+    input_data: Any,
+) -> None:
+    """Add an input to an existing node execution."""
+    json_input_data = SafeJson(input_data)
+    await AgentNodeExecutionInputOutput.prisma().create(
+        data=AgentNodeExecutionInputOutputCreateInput(
+            name=input_name,
+            data=json_input_data,
+            referencedByInputExecId=node_exec_id,
+        )
+    )
+
+
 async def upsert_execution_output(
     node_exec_id: str,
     output_name: str,
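The two new helpers, `create_node_execution` and `add_input_to_node_execution`, split what used to be a single upsert into an explicit create-then-append protocol: the caller supplies `node_exec_id` up front, so the executor can reference the execution (for instance in an in-memory cache) before the row lands in the database. A minimal usage sketch, assuming the AutoGPT backend package context; `record_inputs` is a hypothetical helper, not part of the diff:

```python
import uuid

from backend.data.execution import (
    add_input_to_node_execution,
    create_node_execution,
)


async def record_inputs(node_id: str, graph_exec_id: str, inputs: dict) -> str:
    """Create the execution on the first input, append the rest.

    Assumes `inputs` is non-empty; the ID is generated by the caller,
    not by the database, which is what makes cache-first writes possible.
    """
    node_exec_id = str(uuid.uuid4())
    items = iter(inputs.items())
    first_name, first_data = next(items)
    await create_node_execution(
        node_exec_id=node_exec_id,
        node_id=node_id,
        graph_exec_id=graph_exec_id,
        input_name=first_name,
        input_data=first_data,
    )
    for name, data in items:
        await add_input_to_node_execution(
            node_exec_id=node_exec_id, input_name=name, input_data=data
        )
    return node_exec_id
```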
@@ -115,7 +115,7 @@ async def generate_activity_status_for_execution(
 
     # Get all node executions for this graph execution
     node_executions = await db_client.get_node_executions(
-        graph_exec_id, include_exec_data=True
+        graph_exec_id=graph_exec_id, include_exec_data=True
     )
 
     # Get graph metadata and full graph structure for name, description, and links
@@ -4,12 +4,13 @@ from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
    add_input_to_node_execution,
    create_graph_execution,
    create_node_execution,
    get_block_error_stats,
    get_execution_kv_data,
    get_graph_execution_meta,
    get_graph_executions,
    get_latest_node_execution,
    get_node_execution,
    get_node_executions,
    set_execution_kv_data,

@@ -17,7 +18,6 @@ from backend.data.execution import (
    update_graph_execution_stats,
    update_node_execution_status,
    update_node_execution_status_batch,
    upsert_execution_input,
    upsert_execution_output,
)
from backend.data.generate_data import get_user_execution_summary_data

@@ -105,13 +105,13 @@ class DatabaseManager(AppService):
    create_graph_execution = _(create_graph_execution)
    get_node_execution = _(get_node_execution)
    get_node_executions = _(get_node_executions)
    get_latest_node_execution = _(get_latest_node_execution)
    update_node_execution_status = _(update_node_execution_status)
    update_node_execution_status_batch = _(update_node_execution_status_batch)
    update_graph_execution_start_time = _(update_graph_execution_start_time)
    update_graph_execution_stats = _(update_graph_execution_stats)
    upsert_execution_input = _(upsert_execution_input)
    upsert_execution_output = _(upsert_execution_output)
    create_node_execution = _(create_node_execution)
    add_input_to_node_execution = _(add_input_to_node_execution)
    get_execution_kv_data = _(get_execution_kv_data)
    set_execution_kv_data = _(set_execution_kv_data)
    get_block_error_stats = _(get_block_error_stats)

@@ -171,10 +171,12 @@ class DatabaseManagerClient(AppServiceClient):
    get_graph_executions = _(d.get_graph_executions)
    get_graph_execution_meta = _(d.get_graph_execution_meta)
    get_node_executions = _(d.get_node_executions)
    create_node_execution = _(d.create_node_execution)
    update_node_execution_status = _(d.update_node_execution_status)
    update_graph_execution_start_time = _(d.update_graph_execution_start_time)
    update_graph_execution_stats = _(d.update_graph_execution_stats)
    upsert_execution_output = _(d.upsert_execution_output)
    add_input_to_node_execution = _(d.add_input_to_node_execution)

    # Graphs
    get_graph_metadata = _(d.get_graph_metadata)

@@ -189,14 +191,6 @@ class DatabaseManagerClient(AppServiceClient):
    # User Emails
    get_user_email_by_id = _(d.get_user_email_by_id)

    # Library
    list_library_agents = _(d.list_library_agents)
    add_store_agent_to_library = _(d.add_store_agent_to_library)

    # Store
    get_store_agents = _(d.get_store_agents)
    get_store_agent_details = _(d.get_store_agent_details)


class DatabaseManagerAsyncClient(AppServiceClient):
    d = DatabaseManager

@@ -207,16 +201,12 @@ class DatabaseManagerAsyncClient(AppServiceClient):

    create_graph_execution = d.create_graph_execution
    get_connected_output_nodes = d.get_connected_output_nodes
    get_latest_node_execution = d.get_latest_node_execution
    get_graph = d.get_graph
    get_graph_metadata = d.get_graph_metadata
    get_graph_execution_meta = d.get_graph_execution_meta
    get_node = d.get_node
    get_node_execution = d.get_node_execution
    get_node_executions = d.get_node_executions
    get_user_integrations = d.get_user_integrations
    upsert_execution_input = d.upsert_execution_input
    upsert_execution_output = d.upsert_execution_output
    update_graph_execution_stats = d.update_graph_execution_stats
    update_node_execution_status = d.update_node_execution_status
    update_node_execution_status_batch = d.update_node_execution_status_batch
autogpt_platform/backend/backend/executor/execution_cache.py (new file, 154 lines)
@@ -0,0 +1,154 @@
import logging
import threading
from collections import OrderedDict
from functools import wraps
from typing import TYPE_CHECKING, Any, Optional

from backend.data.execution import ExecutionStatus, NodeExecutionResult
from backend.data.model import GraphExecutionStats, NodeExecutionStats

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerClient

logger = logging.getLogger(__name__)


def with_lock(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with self._lock:
            return func(self, *args, **kwargs)

    return wrapper
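`with_lock` serializes every decorated method through the instance's `self._lock`. Because the lock is an `RLock` (see `__init__` below), a locked method may call another locked method on the same instance without deadlocking. A standalone toy demonstration of the same pattern, independent of the backend code:

```python
import threading
from functools import wraps


def with_lock(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with self._lock:
            return func(self, *args, **kwargs)

    return wrapper


class Counter:
    def __init__(self) -> None:
        self._lock = threading.RLock()  # re-entrant: nested acquire is allowed
        self.value = 0

    @with_lock
    def bump(self) -> int:
        self.value += 1
        return self.read()  # calls another @with_lock method; RLock re-enters

    @with_lock
    def read(self) -> int:
        return self.value


counter = Counter()
assert counter.bump() == 1  # no deadlock despite the nested lock acquisition
```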
class ExecutionCache:
    def __init__(self, graph_exec_id: str, db_client: "DatabaseManagerClient"):
        self._lock = threading.RLock()
        self._graph_exec_id = graph_exec_id
        self._graph_stats: GraphExecutionStats = GraphExecutionStats()
        self._node_executions: OrderedDict[str, NodeExecutionResult] = OrderedDict()

        for execution in db_client.get_node_executions(self._graph_exec_id):
            self._node_executions[execution.node_exec_id] = execution

    @with_lock
    def get_node_execution(self, node_exec_id: str) -> Optional[NodeExecutionResult]:
        execution = self._node_executions.get(node_exec_id)
        return execution.model_copy(deep=True) if execution else None

    @with_lock
    def get_latest_node_execution(self, node_id: str) -> Optional[NodeExecutionResult]:
        for execution in reversed(self._node_executions.values()):
            if (
                execution.node_id == node_id
                and execution.status != ExecutionStatus.INCOMPLETE
            ):
                return execution.model_copy(deep=True)
        return None

    @with_lock
    def get_node_executions(
        self,
        *,
        statuses: Optional[list] = None,
        block_ids: Optional[list] = None,
        node_id: Optional[str] = None,
    ):
        results = []
        for execution in self._node_executions.values():
            if statuses and execution.status not in statuses:
                continue
            if block_ids and execution.block_id not in block_ids:
                continue
            if node_id and execution.node_id != node_id:
                continue
            results.append(execution.model_copy(deep=True))
        return results

    @with_lock
    def update_node_execution_status(
        self,
        exec_id: str,
        status: ExecutionStatus,
        execution_data: Optional[dict] = None,
        stats: Optional[dict] = None,
    ):
        if exec_id not in self._node_executions:
            raise RuntimeError(f"Execution {exec_id} not found in cache")

        execution = self._node_executions[exec_id]
        execution.status = status

        if execution_data:
            execution.input_data.update(execution_data)

        if stats:
            execution.stats = execution.stats or NodeExecutionStats()
            current_stats = execution.stats.model_dump()
            current_stats.update(stats)
            execution.stats = NodeExecutionStats.model_validate(current_stats)

    @with_lock
    def upsert_execution_output(
        self, node_exec_id: str, output_name: str, output_data: Any
    ) -> NodeExecutionResult:
        if node_exec_id not in self._node_executions:
            raise RuntimeError(f"Execution {node_exec_id} not found in cache")

        execution = self._node_executions[node_exec_id]
        if output_name not in execution.output_data:
            execution.output_data[output_name] = []
        execution.output_data[output_name].append(output_data)

        return execution

    @with_lock
    def update_graph_stats(
        self, status: Optional[ExecutionStatus] = None, stats: Optional[dict] = None
    ):
        if status is not None:
            pass
        if stats is not None:
            current_stats = self._graph_stats.model_dump()
            current_stats.update(stats)
            self._graph_stats = GraphExecutionStats.model_validate(current_stats)

    @with_lock
    def update_graph_start_time(self):
        """Update graph start time (handled by database persistence)."""
        pass

    @with_lock
    def find_incomplete_execution_for_input(
        self, node_id: str, input_name: str
    ) -> tuple[str, NodeExecutionResult] | None:
        for exec_id, execution in self._node_executions.items():
            if (
                execution.node_id == node_id
                and execution.status == ExecutionStatus.INCOMPLETE
                and input_name not in execution.input_data
            ):
                return exec_id, execution
        return None

    @with_lock
    def add_node_execution(
        self, node_exec_id: str, execution: NodeExecutionResult
    ) -> None:
        self._node_executions[node_exec_id] = execution

    @with_lock
    def update_execution_input(
        self, exec_id: str, input_name: str, input_data: Any
    ) -> dict:
        if exec_id not in self._node_executions:
            raise RuntimeError(f"Execution {exec_id} not found in cache")
        execution = self._node_executions[exec_id]
        execution.input_data[input_name] = input_data
        return execution.input_data.copy()

    def finalize(self) -> None:
        with self._lock:
            self._node_executions.clear()
            self._graph_stats = GraphExecutionStats()
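Note that every read path returns `model_copy(deep=True)`: callers get a snapshot they can mutate freely without corrupting the cache, at the cost of one copy per read. A sketch of that contract, assuming the backend package is importable; the IDs are illustrative, and the field list mirrors the `NodeExecutionResult` construction used in the tests further down:

```python
from datetime import datetime, timezone
from unittest.mock import MagicMock

from backend.data.execution import ExecutionStatus, NodeExecutionResult
from backend.executor.execution_cache import ExecutionCache

stub_db = MagicMock()
stub_db.get_node_executions.return_value = []  # warm-up query finds nothing

cache = ExecutionCache("graph-exec-1", stub_db)
cache.add_node_execution(
    "ne-1",
    NodeExecutionResult(
        user_id="u-1",
        graph_id="g-1",
        graph_version=1,
        graph_exec_id="graph-exec-1",
        node_exec_id="ne-1",
        node_id="n-1",
        block_id="b-1",
        status=ExecutionStatus.INCOMPLETE,
        input_data={"a": 1},
        output_data={},
        add_time=datetime.now(timezone.utc),
    ),
)

snapshot = cache.get_node_execution("ne-1")
assert snapshot is not None
snapshot.input_data["a"] = 999  # mutate the returned copy...
assert cache.get_node_execution("ne-1").input_data["a"] == 1  # ...cache unchanged
```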
@@ -0,0 +1,355 @@
"""Test execution creation with proper ID generation and persistence."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.executor.execution_data import ExecutionDataClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def execution_client_with_mock_db(event_loop):
|
||||
"""Create an ExecutionDataClient with proper database records."""
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from prisma.models import AgentGraph, AgentGraphExecution, User
|
||||
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
|
||||
# Create test database records to satisfy foreign key constraints
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": "test_user_123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
)
|
||||
|
||||
await AgentGraph.prisma().create(
|
||||
data={
|
||||
"id": "test_graph_456",
|
||||
"version": 1,
|
||||
"userId": "test_user_123",
|
||||
"name": "Test Graph",
|
||||
"description": "Test graph for execution tests",
|
||||
}
|
||||
)
|
||||
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
|
||||
await AgentGraphExecution.prisma().create(
|
||||
data={
|
||||
"id": "test_graph_exec_id",
|
||||
"userId": "test_user_123",
|
||||
"agentGraphId": "test_graph_456",
|
||||
"agentGraphVersion": 1,
|
||||
"executionStatus": AgentExecutionStatus.RUNNING,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Records might already exist, that's fine
|
||||
pass
|
||||
|
||||
# Mock the graph execution metadata - align with assertions below
|
||||
mock_graph_meta = GraphExecutionMeta(
|
||||
id="test_graph_exec_id",
|
||||
user_id="test_user_123",
|
||||
graph_id="test_graph_456",
|
||||
graph_version=1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
# Create client with ThreadPoolExecutor and graph metadata (constructed inside patch)
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# Storage for tracking created executions
|
||||
created_executions = []
|
||||
|
||||
async def mock_create_node_execution(
|
||||
node_exec_id, node_id, graph_exec_id, input_name, input_data
|
||||
):
|
||||
"""Mock execution creation that records what was created."""
|
||||
created_executions.append(
|
||||
{
|
||||
"node_exec_id": node_exec_id,
|
||||
"node_id": node_id,
|
||||
"graph_exec_id": graph_exec_id,
|
||||
"input_name": input_name,
|
||||
"input_data": input_data,
|
||||
}
|
||||
)
|
||||
return node_exec_id
|
||||
|
||||
def sync_mock_create_node_execution(
|
||||
node_exec_id, node_id, graph_exec_id, input_name, input_data
|
||||
):
|
||||
"""Mock sync execution creation that records what was created."""
|
||||
created_executions.append(
|
||||
{
|
||||
"node_exec_id": node_exec_id,
|
||||
"node_id": node_id,
|
||||
"graph_exec_id": graph_exec_id,
|
||||
"input_name": input_name,
|
||||
"input_data": input_data,
|
||||
}
|
||||
)
|
||||
return node_exec_id
|
||||
|
||||
# Prepare mock async and sync DB clients
|
||||
async_mock_client = AsyncMock()
|
||||
async_mock_client.create_node_execution = mock_create_node_execution
|
||||
|
||||
sync_mock_client = MagicMock()
|
||||
sync_mock_client.create_node_execution = sync_mock_create_node_execution
|
||||
# Mock graph execution for return values
|
||||
from backend.data.execution import GraphExecutionMeta
|
||||
|
||||
mock_graph_update = GraphExecutionMeta(
|
||||
id="test_graph_exec_id",
|
||||
user_id="test_user_123",
|
||||
graph_id="test_graph_456",
|
||||
graph_version=1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
# No-ops for other sync methods used by the client during tests
|
||||
sync_mock_client.add_input_to_node_execution.side_effect = lambda **kwargs: None
|
||||
sync_mock_client.update_node_execution_status.side_effect = (
|
||||
lambda *args, **kwargs: None
|
||||
)
|
||||
sync_mock_client.upsert_execution_output.side_effect = lambda **kwargs: None
|
||||
sync_mock_client.update_graph_execution_stats.side_effect = (
|
||||
lambda *args, **kwargs: mock_graph_update
|
||||
)
|
||||
sync_mock_client.update_graph_execution_start_time.side_effect = (
|
||||
lambda *args, **kwargs: mock_graph_update
|
||||
)
|
||||
|
||||
thread = threading.Thread(target=event_loop.run_forever, daemon=True)
|
||||
thread.start()
|
||||
with patch(
|
||||
"backend.executor.execution_data.get_database_manager_async_client",
|
||||
return_value=async_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_database_manager_client",
|
||||
return_value=sync_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_execution_event_bus"
|
||||
), patch(
|
||||
"backend.executor.execution_data.non_blocking_persist", lambda func: func
|
||||
):
|
||||
# Now construct the client under the patch so it captures the mocked clients
|
||||
client = ExecutionDataClient(executor, "test_graph_exec_id", mock_graph_meta)
|
||||
# Store the mocks for the test to access if needed
|
||||
setattr(client, "_test_async_client", async_mock_client)
|
||||
setattr(client, "_test_sync_client", sync_mock_client)
|
||||
setattr(client, "_created_executions", created_executions)
|
||||
yield client
|
||||
|
||||
# Cleanup test database records
|
||||
try:
|
||||
await AgentGraphExecution.prisma().delete_many(
|
||||
where={"id": "test_graph_exec_id"}
|
||||
)
|
||||
await AgentGraph.prisma().delete_many(where={"id": "test_graph_456"})
|
||||
await User.prisma().delete_many(where={"id": "test_user_123"})
|
||||
except Exception:
|
||||
# Cleanup may fail if records don't exist
|
||||
pass
|
||||
|
||||
# Cleanup
|
||||
event_loop.call_soon_threadsafe(event_loop.stop)
|
||||
thread.join(timeout=1)
|
||||
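The fixture drives an asyncio loop on a daemon thread and tears it down with `call_soon_threadsafe(loop.stop)`. The same pattern in isolation, including how synchronous code can schedule a coroutine onto such a loop (a generic sketch, not taken from the diff):

```python
import asyncio
import threading

loop = asyncio.new_event_loop()
thread = threading.Thread(target=loop.run_forever, daemon=True)
thread.start()


async def probe() -> str:
    return "ran on the background loop"


# Submit from the calling (sync) thread and block on the result.
future = asyncio.run_coroutine_threadsafe(probe(), loop)
assert future.result(timeout=1) == "ran on the background loop"

# Shut down exactly as the fixture does.
loop.call_soon_threadsafe(loop.stop)
thread.join(timeout=1)
```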
class TestExecutionCreation:
    """Test execution creation with proper ID generation and persistence."""

    async def test_execution_creation_with_valid_ids(
        self, execution_client_with_mock_db
    ):
        """Test that execution creation generates and persists valid IDs."""
        client = execution_client_with_mock_db

        node_id = "test_node_789"
        input_name = "test_input"
        input_data = "test_value"
        block_id = "test_block_abc"

        # This should trigger execution creation since cache is empty
        exec_id, input_dict = client.upsert_execution_input(
            node_id=node_id,
            input_name=input_name,
            input_data=input_data,
            block_id=block_id,
        )

        # Verify execution ID is valid UUID
        try:
            uuid.UUID(exec_id)
        except ValueError:
            pytest.fail(f"Generated execution ID '{exec_id}' is not a valid UUID")

        # Verify execution was created in cache with complete data
        assert exec_id in client._cache._node_executions
        cached_execution = client._cache._node_executions[exec_id]

        # Check all required fields have valid values
        assert cached_execution.user_id == "test_user_123"
        assert cached_execution.graph_id == "test_graph_456"
        assert cached_execution.graph_version == 1
        assert cached_execution.graph_exec_id == "test_graph_exec_id"
        assert cached_execution.node_exec_id == exec_id
        assert cached_execution.node_id == node_id
        assert cached_execution.block_id == block_id
        assert cached_execution.status == ExecutionStatus.INCOMPLETE
        assert cached_execution.input_data == {input_name: input_data}
        assert isinstance(cached_execution.add_time, datetime)

        # Verify execution was persisted to database with our generated ID
        created_executions = getattr(client, "_created_executions", [])
        assert len(created_executions) == 1
        created = created_executions[0]
        assert created["node_exec_id"] == exec_id  # Our generated ID was used
        assert created["node_id"] == node_id
        assert created["graph_exec_id"] == "test_graph_exec_id"
        assert created["input_name"] == input_name
        assert created["input_data"] == input_data

        # Verify input dict returned correctly
        assert input_dict == {input_name: input_data}

    async def test_execution_reuse_vs_creation(self, execution_client_with_mock_db):
        """Test that execution reuse works and creation only happens when needed."""
        client = execution_client_with_mock_db

        node_id = "reuse_test_node"
        block_id = "reuse_test_block"

        # Create first execution
        exec_id_1, input_dict_1 = client.upsert_execution_input(
            node_id=node_id,
            input_name="input_1",
            input_data="value_1",
            block_id=block_id,
        )

        # This should reuse the existing INCOMPLETE execution
        exec_id_2, input_dict_2 = client.upsert_execution_input(
            node_id=node_id,
            input_name="input_2",
            input_data="value_2",
            block_id=block_id,
        )

        # Should reuse the same execution
        assert exec_id_1 == exec_id_2
        assert input_dict_2 == {"input_1": "value_1", "input_2": "value_2"}

        # Only one execution should be created in database
        created_executions = getattr(client, "_created_executions", [])
        assert len(created_executions) == 1

        # Verify cache has the merged inputs
        cached_execution = client._cache._node_executions[exec_id_1]
        assert cached_execution.input_data == {
            "input_1": "value_1",
            "input_2": "value_2",
        }

        # Now complete the execution and try to add another input
        client.update_node_status_and_publish(
            exec_id=exec_id_1, status=ExecutionStatus.COMPLETED
        )

        # Verify the execution status was actually updated in the cache
        updated_execution = client._cache._node_executions[exec_id_1]
        assert (
            updated_execution.status == ExecutionStatus.COMPLETED
        ), f"Expected COMPLETED but got {updated_execution.status}"

        # This should create a NEW execution since the first is no longer INCOMPLETE
        exec_id_3, input_dict_3 = client.upsert_execution_input(
            node_id=node_id,
            input_name="input_3",
            input_data="value_3",
            block_id=block_id,
        )

        # Should be a different execution
        assert exec_id_3 != exec_id_1
        assert input_dict_3 == {"input_3": "value_3"}

        # Verify cache behavior: should have two different executions in cache now
        cached_executions = client._cache._node_executions
        assert len(cached_executions) == 2
        assert exec_id_1 in cached_executions
        assert exec_id_3 in cached_executions

        # First execution should be COMPLETED
        assert cached_executions[exec_id_1].status == ExecutionStatus.COMPLETED
        # Third execution should be INCOMPLETE (newly created)
        assert cached_executions[exec_id_3].status == ExecutionStatus.INCOMPLETE

    async def test_multiple_nodes_get_different_execution_ids(
        self, execution_client_with_mock_db
    ):
        """Test that different nodes get different execution IDs."""
        client = execution_client_with_mock_db

        # Create executions for different nodes
        exec_id_a, _ = client.upsert_execution_input(
            node_id="node_a",
            input_name="test_input",
            input_data="test_value",
            block_id="block_a",
        )

        exec_id_b, _ = client.upsert_execution_input(
            node_id="node_b",
            input_name="test_input",
            input_data="test_value",
            block_id="block_b",
        )

        # Should be different executions with different IDs
        assert exec_id_a != exec_id_b

        # Both should be valid UUIDs
        uuid.UUID(exec_id_a)
        uuid.UUID(exec_id_b)

        # Both should be in cache
        cached_executions = client._cache._node_executions
        assert len(cached_executions) == 2
        assert exec_id_a in cached_executions
        assert exec_id_b in cached_executions

        # Both should have correct node IDs
        assert cached_executions[exec_id_a].node_id == "node_a"
        assert cached_executions[exec_id_b].node_id == "node_b"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
autogpt_platform/backend/backend/executor/execution_data.py (new file, 338 lines)
@@ -0,0 +1,338 @@
import logging
import threading
import uuid
from concurrent.futures import Executor
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar, cast

from backend.data import redis_client as redis
from backend.data.credit import UsageTransactionMetadata
from backend.data.execution import (
    ExecutionStatus,
    GraphExecutionMeta,
    NodeExecutionResult,
)
from backend.data.graph import Node
from backend.data.model import GraphExecutionStats
from backend.executor.execution_cache import ExecutionCache
from backend.util.clients import (
    get_database_manager_async_client,
    get_database_manager_client,
    get_execution_event_bus,
)
from backend.util.settings import Settings

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient

settings = Settings()
logger = logging.getLogger(__name__)

P = ParamSpec("P")
T = TypeVar("T")


def non_blocking_persist(func: Callable[P, T]) -> Callable[P, None]:
    from functools import wraps

    @wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        # First argument is always self for methods - access through cast for typing
        self = cast("ExecutionDataClient", args[0])
        future = self._executor.submit(func, *args, **kwargs)
        self._pending_tasks.add(future)

    return wrapper
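`non_blocking_persist` turns a blocking DB write into fire-and-forget: the wrapped call returns immediately, and the `Future` is parked in `self._pending_tasks` so `finalize_execution()` can later drain the set and surface failures. The same mechanics in a standalone sketch (names here are illustrative, not the backend's):

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=2)
pending: set[Future] = set()


def persist(record: dict) -> None:
    time.sleep(0.01)  # stand-in for a blocking database write
    print("persisted", record)


def persist_non_blocking(record: dict) -> None:
    pending.add(executor.submit(persist, record))  # returns immediately


persist_non_blocking({"id": 1})
persist_non_blocking({"id": 2})

# Drain phase, as finalize_execution() does: wait and surface exceptions.
for future in list(pending):
    future.result(timeout=5)
    pending.discard(future)
```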
class ExecutionDataClient:
    def __init__(
        self, executor: Executor, graph_exec_id: str, graph_metadata: GraphExecutionMeta
    ):
        self._executor = executor
        self._graph_exec_id = graph_exec_id
        self._cache = ExecutionCache(graph_exec_id, self.db_client_sync)
        self._pending_tasks = set()
        self._graph_metadata = graph_metadata
        self.graph_lock = threading.RLock()

    def finalize_execution(self, timeout: float = 30.0):
        logger.info(f"Flushing db writes for execution {self._graph_exec_id}")
        exceptions = []

        # Wait for all pending database operations to complete
        logger.debug(
            f"Waiting for {len(self._pending_tasks)} pending database operations"
        )
        for future in list(self._pending_tasks):
            try:
                future.result(timeout=timeout)
            except Exception as e:
                logger.error(f"Background database operation failed: {e}")
                exceptions.append(e)
            finally:
                self._pending_tasks.discard(future)

        self._cache.finalize()

        if exceptions:
            logger.error(f"Background persistence failed with {len(exceptions)} errors")
            raise RuntimeError(
                f"Background persistence failed with {len(exceptions)} errors: {exceptions}"
            )

    @property
    def db_client_async(self) -> "DatabaseManagerAsyncClient":
        return get_database_manager_async_client()

    @property
    def db_client_sync(self) -> "DatabaseManagerClient":
        return get_database_manager_client()

    @property
    def event_bus(self):
        return get_execution_event_bus()

    async def get_node(self, node_id: str) -> Node:
        return await self.db_client_async.get_node(node_id)

    def spend_credits(
        self,
        user_id: str,
        cost: int,
        metadata: UsageTransactionMetadata,
    ) -> int:
        return self.db_client_sync.spend_credits(
            user_id=user_id, cost=cost, metadata=metadata
        )

    def get_graph_execution_meta(
        self, user_id: str, execution_id: str
    ) -> GraphExecutionMeta | None:
        return self.db_client_sync.get_graph_execution_meta(
            user_id=user_id, execution_id=execution_id
        )

    def get_graph_metadata(
        self, graph_id: str, graph_version: int | None = None
    ) -> Any:
        return self.db_client_sync.get_graph_metadata(graph_id, graph_version)

    def get_credits(self, user_id: str) -> int:
        return self.db_client_sync.get_credits(user_id)

    def get_user_email_by_id(self, user_id: str) -> str | None:
        return self.db_client_sync.get_user_email_by_id(user_id)

    def get_latest_node_execution(self, node_id: str) -> NodeExecutionResult | None:
        return self._cache.get_latest_node_execution(node_id)

    def get_node_execution(self, node_exec_id: str) -> NodeExecutionResult | None:
        return self._cache.get_node_execution(node_exec_id)

    def get_node_executions(
        self,
        *,
        node_id: str | None = None,
        statuses: list[ExecutionStatus] | None = None,
        block_ids: list[str] | None = None,
    ) -> list[NodeExecutionResult]:
        return self._cache.get_node_executions(
            statuses=statuses, block_ids=block_ids, node_id=node_id
        )

    def update_node_status_and_publish(
        self,
        exec_id: str,
        status: ExecutionStatus,
        execution_data: dict | None = None,
        stats: dict[str, Any] | None = None,
    ):
        self._cache.update_node_execution_status(exec_id, status, execution_data, stats)
        self._persist_node_status_to_db(exec_id, status, execution_data, stats)

    def upsert_execution_input(
        self, node_id: str, input_name: str, input_data: Any, block_id: str
    ) -> tuple[str, dict]:
        # Validate input parameters to prevent foreign key constraint errors
        if not node_id or not isinstance(node_id, str):
            raise ValueError(f"Invalid node_id: {node_id}")
        if not self._graph_exec_id or not isinstance(self._graph_exec_id, str):
            raise ValueError(f"Invalid graph_exec_id: {self._graph_exec_id}")
        if not block_id or not isinstance(block_id, str):
            raise ValueError(f"Invalid block_id: {block_id}")

        # UPDATE: Try to find an existing incomplete execution for this node and input
        if result := self._cache.find_incomplete_execution_for_input(
            node_id, input_name
        ):
            exec_id, _ = result
            updated_input_data = self._cache.update_execution_input(
                exec_id, input_name, input_data
            )
            self._persist_add_input_to_db(exec_id, input_name, input_data)
            return exec_id, updated_input_data

        # CREATE: No suitable execution found, create new one
        node_exec_id = str(uuid.uuid4())
        logger.debug(
            f"Creating new execution {node_exec_id} for node {node_id} "
            f"in graph execution {self._graph_exec_id}"
        )

        new_execution = NodeExecutionResult(
            user_id=self._graph_metadata.user_id,
            graph_id=self._graph_metadata.graph_id,
            graph_version=self._graph_metadata.graph_version,
            graph_exec_id=self._graph_exec_id,
            node_exec_id=node_exec_id,
            node_id=node_id,
            block_id=block_id,
            status=ExecutionStatus.INCOMPLETE,
            input_data={input_name: input_data},
            output_data={},
            add_time=datetime.now(timezone.utc),
        )
        self._cache.add_node_execution(node_exec_id, new_execution)
        self._persist_new_node_execution_to_db(
            node_exec_id, node_id, input_name, input_data
        )

        return node_exec_id, {input_name: input_data}

    def upsert_execution_output(
        self, node_exec_id: str, output_name: str, output_data: Any
    ):
        self._cache.upsert_execution_output(node_exec_id, output_name, output_data)
        self._persist_execution_output_to_db(node_exec_id, output_name, output_data)

    def update_graph_stats_and_publish(
        self,
        status: ExecutionStatus | None = None,
        stats: GraphExecutionStats | None = None,
    ) -> None:
        stats_dict = stats.model_dump() if stats else None
        self._cache.update_graph_stats(status=status, stats=stats_dict)
        self._persist_graph_stats_to_db(status=status, stats=stats)

    def update_graph_start_time_and_publish(self) -> None:
        self._cache.update_graph_start_time()
        self._persist_graph_start_time_to_db()

    @non_blocking_persist
    def _persist_node_status_to_db(
        self,
        exec_id: str,
        status: ExecutionStatus,
        execution_data: dict | None = None,
        stats: dict[str, Any] | None = None,
    ):
        exec_update = self.db_client_sync.update_node_execution_status(
            exec_id, status, execution_data, stats
        )
        self.event_bus.publish(exec_update)

    @non_blocking_persist
    def _persist_add_input_to_db(
        self, node_exec_id: str, input_name: str, input_data: Any
    ):
        self.db_client_sync.add_input_to_node_execution(
            node_exec_id=node_exec_id,
            input_name=input_name,
            input_data=input_data,
        )

    @non_blocking_persist
    def _persist_execution_output_to_db(
        self, node_exec_id: str, output_name: str, output_data: Any
    ):
        self.db_client_sync.upsert_execution_output(
            node_exec_id=node_exec_id,
            output_name=output_name,
            output_data=output_data,
        )
        if exec_update := self.get_node_execution(node_exec_id):
            self.event_bus.publish(exec_update)

    @non_blocking_persist
    def _persist_graph_stats_to_db(
        self,
        status: ExecutionStatus | None = None,
        stats: GraphExecutionStats | None = None,
    ):
        graph_update = self.db_client_sync.update_graph_execution_stats(
            self._graph_exec_id, status, stats
        )
        if not graph_update:
            raise RuntimeError(
                f"Failed to update graph execution stats for {self._graph_exec_id}"
            )
        self.event_bus.publish(graph_update)

    @non_blocking_persist
    def _persist_graph_start_time_to_db(self):
        graph_update = self.db_client_sync.update_graph_execution_start_time(
            self._graph_exec_id
        )
        if not graph_update:
            raise RuntimeError(
                f"Failed to update graph execution start time for {self._graph_exec_id}"
            )
        self.event_bus.publish(graph_update)

    async def generate_activity_status(
        self,
        graph_id: str,
        graph_version: int,
        execution_stats: GraphExecutionStats,
        user_id: str,
        execution_status: ExecutionStatus,
    ) -> str | None:
        from backend.executor.activity_status_generator import (
            generate_activity_status_for_execution,
        )

        return await generate_activity_status_for_execution(
            graph_exec_id=self._graph_exec_id,
            graph_id=graph_id,
            graph_version=graph_version,
            execution_stats=execution_stats,
            db_client=self.db_client_async,
            user_id=user_id,
            execution_status=execution_status,
        )

    @non_blocking_persist
    def _send_execution_update(self, execution: NodeExecutionResult):
        """Send execution update to event bus."""
        try:
            self.event_bus.publish(execution)
        except Exception as e:
            logger.warning(f"Failed to send execution update: {e}")

    @non_blocking_persist
    def _persist_new_node_execution_to_db(
        self, node_exec_id: str, node_id: str, input_name: str, input_data: Any
    ):
        try:
            self.db_client_sync.create_node_execution(
                node_exec_id=node_exec_id,
                node_id=node_id,
                graph_exec_id=self._graph_exec_id,
                input_name=input_name,
                input_data=input_data,
            )
        except Exception as e:
            logger.error(
                f"Failed to create node execution {node_exec_id} for node {node_id} "
                f"in graph execution {self._graph_exec_id}: {e}"
            )
            raise

    def increment_execution_count(self, user_id: str) -> int:
        r = redis.get_redis()
        k = f"uec:{user_id}"
        counter = cast(int, r.incr(k))
        if counter == 1:
            r.expire(k, settings.config.execution_counter_expiration_time)
        return counter
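`increment_execution_count` is the standard Redis INCR-plus-first-write-EXPIRE counter: the first increment in a window creates the key and starts its TTL; later increments just bump the value until the key expires. A standalone redis-py sketch of the same pattern (assumes a reachable Redis; the key name and TTL here are illustrative, not the backend's settings):

```python
import redis

r = redis.Redis()
key = "uec:user-123"

count = r.incr(key)  # atomic; creates the key at 1 if absent
if count == 1:
    r.expire(key, 3600)  # start the window only on the first increment
print(f"executions in current window: {count}")
```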
autogpt_platform/backend/backend/executor/execution_data_test.py (new file, 668 lines)
@@ -0,0 +1,668 @@
|
||||
"""Test suite for ExecutionDataClient."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.executor.execution_data import ExecutionDataClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def execution_client(event_loop):
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
|
||||
mock_graph_meta = GraphExecutionMeta(
|
||||
id="test_graph_exec_id",
|
||||
user_id="test_user_id",
|
||||
graph_id="test_graph_id",
|
||||
graph_version=1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# Mock all database operations to prevent connection attempts
|
||||
async_mock_client = AsyncMock()
|
||||
sync_mock_client = MagicMock()
|
||||
|
||||
# Mock all database methods to return None or empty results
|
||||
sync_mock_client.get_node_executions.return_value = []
|
||||
sync_mock_client.create_node_execution.return_value = None
|
||||
sync_mock_client.add_input_to_node_execution.return_value = None
|
||||
sync_mock_client.update_node_execution_status.return_value = None
|
||||
sync_mock_client.upsert_execution_output.return_value = None
|
||||
sync_mock_client.update_graph_execution_stats.return_value = mock_graph_meta
|
||||
sync_mock_client.update_graph_execution_start_time.return_value = mock_graph_meta
|
||||
|
||||
# Mock event bus to prevent connection attempts
|
||||
mock_event_bus = MagicMock()
|
||||
mock_event_bus.publish.return_value = None
|
||||
|
||||
thread = threading.Thread(target=event_loop.run_forever, daemon=True)
|
||||
thread.start()
|
||||
|
||||
with patch(
|
||||
"backend.executor.execution_data.get_database_manager_async_client",
|
||||
return_value=async_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_database_manager_client",
|
||||
return_value=sync_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_execution_event_bus",
|
||||
return_value=mock_event_bus,
|
||||
), patch(
|
||||
"backend.executor.execution_data.non_blocking_persist", lambda func: func
|
||||
):
|
||||
|
||||
client = ExecutionDataClient(executor, "test_graph_exec_id", mock_graph_meta)
|
||||
yield client
|
||||
|
||||
event_loop.call_soon_threadsafe(event_loop.stop)
|
||||
thread.join(timeout=1)
|
||||
|
||||
|
||||
class TestExecutionDataClient:
|
||||
|
||||
async def test_update_node_status_writes_to_cache_immediately(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that node status updates are immediately visible in cache."""
|
||||
# First create an execution to update
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
status = ExecutionStatus.RUNNING
|
||||
execution_data = {"step": "processing"}
|
||||
stats = {"duration": 5.2}
|
||||
|
||||
# Update status of existing execution
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=status,
|
||||
execution_data=execution_data,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
# Verify immediate visibility in cache
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert cached_exec.status == status
|
||||
# execution_data should be merged with existing input_data, not replace it
|
||||
expected_input_data = {"test_input": "test_value", "step": "processing"}
|
||||
assert cached_exec.input_data == expected_input_data
|
||||
|
||||
def test_update_node_status_execution_not_found_raises_error(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that updating non-existent execution raises error instead of creating it."""
|
||||
non_existent_id = "does-not-exist"
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError, match="Execution does-not-exist not found in cache"
|
||||
):
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=non_existent_id, status=ExecutionStatus.COMPLETED
|
||||
)
|
||||
|
||||
async def test_upsert_execution_output_writes_to_cache_immediately(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that output updates are immediately visible in cache."""
|
||||
# First create an execution
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
output_name = "result"
|
||||
output_data = {"answer": 42, "confidence": 0.95}
|
||||
|
||||
# Update to RUNNING status first
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "test"},
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id, output_name=output_name, output_data=output_data
|
||||
)
|
||||
# Check output through the node execution
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert output_name in cached_exec.output_data
|
||||
assert cached_exec.output_data[output_name] == [output_data]
|
||||
|
||||
async def test_get_node_execution_reads_from_cache(self, execution_client):
|
||||
"""Test that get_node_execution returns cached data immediately."""
|
||||
# First create an execution to work with
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Then update its status
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
execution_data={"result": "success"},
|
||||
)
|
||||
|
||||
result = execution_client.get_node_execution(node_exec_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == ExecutionStatus.COMPLETED
|
||||
# execution_data gets merged with existing input_data
|
||||
expected_input_data = {"test_input": "test_value", "result": "success"}
|
||||
assert result.input_data == expected_input_data
|
||||
|
||||
async def test_get_latest_node_execution_reads_from_cache(self, execution_client):
|
||||
"""Test that get_latest_node_execution returns cached data."""
|
||||
node_id = "node-1"
|
||||
|
||||
# First create an execution for this node
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Then update its status to make it non-INCOMPLETE (so it's returned by get_latest)
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"from": "cache"},
|
||||
)
|
||||
|
||||
result = execution_client.get_latest_node_execution(node_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == ExecutionStatus.RUNNING
|
||||
# execution_data gets merged with existing input_data
|
||||
expected_input_data = {"test_input": "test_value", "from": "cache"}
|
||||
assert result.input_data == expected_input_data
|
||||
|
||||
async def test_get_node_executions_sync_filters_correctly(self, execution_client):
|
||||
# Create executions with different statuses
|
||||
executions = [
|
||||
(ExecutionStatus.RUNNING, "block-a"),
|
||||
(ExecutionStatus.COMPLETED, "block-a"),
|
||||
(ExecutionStatus.FAILED, "block-b"),
|
||||
(ExecutionStatus.RUNNING, "block-b"),
|
||||
]
|
||||
|
||||
exec_ids = []
|
||||
for i, (status, block_id) in enumerate(executions):
|
||||
# First create the execution
|
||||
exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id=f"node-{i}",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id=block_id,
|
||||
)
|
||||
exec_ids.append(exec_id)
|
||||
|
||||
# Then update its status and metadata
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=exec_id, status=status, execution_data={"block": block_id}
|
||||
)
|
||||
# Update cached execution with graph_exec_id and block_id for filtering
|
||||
# Note: In real implementation, these would be set during creation
|
||||
# For test purposes, we'll skip this manual update since the filtering
|
||||
# logic should work with the data as created
|
||||
|
||||
# Test status filtering
|
||||
running_execs = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.RUNNING]
|
||||
)
|
||||
assert len(running_execs) == 2
|
||||
assert all(e.status == ExecutionStatus.RUNNING for e in running_execs)
|
||||
|
||||
# Test block_id filtering
|
||||
block_a_execs = execution_client.get_node_executions(block_ids=["block-a"])
|
||||
assert len(block_a_execs) == 2
|
||||
assert all(e.block_id == "block-a" for e in block_a_execs)
|
||||
|
||||
# Test combined filtering
|
||||
running_block_b = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.RUNNING], block_ids=["block-b"]
|
||||
)
|
||||
assert len(running_block_b) == 1
|
||||
assert running_block_b[0].status == ExecutionStatus.RUNNING
|
||||
assert running_block_b[0].block_id == "block-b"
|
||||
|
||||
async def test_write_then_read_consistency(self, execution_client):
|
||||
"""Test critical race condition scenario: immediate read after write."""
|
||||
# First create an execution to work with
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="consistency-test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Write status
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"step": 1},
|
||||
)
|
||||
|
||||
# Write output
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="intermediate",
|
||||
output_data={"progress": 50},
|
||||
)
|
||||
|
||||
# Update status again
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
execution_data={"step": 2},
|
||||
)
|
||||
|
||||
# All changes should be immediately visible
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert cached_exec.status == ExecutionStatus.COMPLETED
|
||||
# execution_data gets merged with existing input_data - step 2 overwrites step 1
|
||||
expected_input_data = {"test_input": "test_value", "step": 2}
|
||||
assert cached_exec.input_data == expected_input_data
|
||||
|
||||
# Output should be visible in execution record
|
||||
assert cached_exec.output_data["intermediate"] == [{"progress": 50}]
|
||||
|
||||
async def test_concurrent_operations_are_thread_safe(self, execution_client):
|
||||
"""Test that concurrent operations don't corrupt cache."""
|
||||
num_threads = 3 # Reduced for simpler test
|
||||
operations_per_thread = 5 # Reduced for simpler test
|
||||
|
||||
# Create all executions upfront
|
||||
created_exec_ids = []
|
||||
for thread_id in range(num_threads):
|
||||
for i in range(operations_per_thread):
|
||||
exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id=f"node-{thread_id}-{i}",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id=f"block-{thread_id}-{i}",
|
||||
)
|
||||
created_exec_ids.append((exec_id, thread_id, i))
|
||||
|
||||
def worker(thread_data):
|
||||
"""Perform multiple operations from a thread."""
|
||||
thread_id, ops = thread_data
|
||||
for i, (exec_id, _, _) in enumerate(ops):
|
||||
# Status updates
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"thread": thread_id, "op": i},
|
||||
)
|
||||
|
||||
# Output updates (use just one exec_id per thread for outputs)
|
||||
if i == 0: # Only add outputs to first execution of each thread
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=exec_id,
|
||||
output_name=f"output_{i}",
|
||||
output_data={"thread": thread_id, "value": i},
|
||||
)
|
||||
|
||||
# Organize executions by thread
|
||||
thread_data = []
|
||||
for tid in range(num_threads):
|
||||
thread_ops = [
|
||||
exec_data for exec_data in created_exec_ids if exec_data[1] == tid
|
||||
]
|
||||
thread_data.append((tid, thread_ops))
|
||||
|
||||
# Start multiple threads
|
||||
threads = []
|
||||
for data in thread_data:
|
||||
thread = threading.Thread(target=worker, args=(data,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for completion
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Verify data integrity
|
||||
expected_executions = num_threads * operations_per_thread
|
||||
all_executions = execution_client.get_node_executions()
|
||||
assert len(all_executions) == expected_executions
|
||||
|
||||
# Verify outputs - only first execution of each thread should have outputs
|
||||
output_count = 0
|
||||
for execution in all_executions:
|
||||
if execution.output_data:
|
||||
output_count += 1
|
||||
assert output_count == num_threads # One output per thread
|
||||
|
||||
async def test_sync_and_async_versions_consistent(self, execution_client):
|
||||
"""Test that sync and async versions of output operations behave the same."""
|
||||
# First create the execution
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="sync-async-test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "test"},
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="sync_result",
|
||||
output_data={"method": "sync"},
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="async_result",
|
||||
output_data={"method": "async"},
|
||||
)
|
||||
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert "sync_result" in cached_exec.output_data
|
||||
assert "async_result" in cached_exec.output_data
|
||||
assert cached_exec.output_data["sync_result"] == [{"method": "sync"}]
|
||||
assert cached_exec.output_data["async_result"] == [{"method": "async"}]
|
||||
|
||||
async def test_finalize_execution_completes_and_clears_cache(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that finalize_execution waits for background tasks and clears cache."""
|
||||
# First create the execution
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="pending-test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Trigger some background operations
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id, status=ExecutionStatus.RUNNING
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id, output_name="test", output_data={"value": 1}
|
||||
)
|
||||
|
||||
# Wait for background tasks - may fail in test environment due to DB issues
|
||||
try:
|
||||
execution_client.finalize_execution(timeout=5.0)
|
||||
except RuntimeError as e:
|
||||
# In test environment, background DB operations may fail, but cache should still be cleared
|
||||
assert "Background persistence failed" in str(e)
|
||||
|
||||
# Cache should be cleared regardless of background task failures
|
||||
all_executions = execution_client.get_node_executions()
|
||||
assert len(all_executions) == 0 # Cache should be cleared
|
||||
|
||||
async def test_manager_usage_pattern(self, execution_client):
|
||||
# Create executions first
|
||||
node_exec_id_1, _ = execution_client.upsert_execution_input(
|
||||
node_id="node-1",
|
||||
input_name="input1",
|
||||
input_data="data1",
|
||||
block_id="block-1",
|
||||
)
|
||||
|
||||
node_exec_id_2, _ = execution_client.upsert_execution_input(
|
||||
node_id="node-2",
|
||||
input_name="input_from_node1",
|
||||
input_data="value1",
|
||||
block_id="block-2",
|
||||
)
|
||||
|
||||
# Simulate manager.py workflow
|
||||
|
||||
# 1. Start execution
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id_1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "data1"},
|
||||
)
|
||||
|
||||
# 2. Node produces output
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id_1,
|
||||
output_name="result",
|
||||
output_data={"output": "value1"},
|
||||
)
|
||||
|
||||
# 3. Complete first node
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id_1, status=ExecutionStatus.COMPLETED
|
||||
)
|
||||
|
||||
# 4. Start second node (would read output from first)
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id_2,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input_from_node1": "value1"},
|
||||
)
|
||||
|
||||
# 5. Manager queries for executions
|
||||
|
||||
all_executions = execution_client.get_node_executions()
|
||||
running_executions = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.RUNNING]
|
||||
)
|
||||
completed_executions = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.COMPLETED]
|
||||
)
|
||||
|
||||
# Verify manager can see all data immediately
|
||||
assert len(all_executions) == 2
|
||||
assert len(running_executions) == 1
|
||||
assert len(completed_executions) == 1
|
||||
|
||||
# Verify output is accessible
|
||||
exec_1 = execution_client.get_node_execution(node_exec_id_1)
|
||||
assert exec_1 is not None
|
||||
assert exec_1.output_data["result"] == [{"output": "value1"}]
|
||||
|
||||
    def test_stats_handling_in_update_node_status(self, execution_client):
        """Test that stats parameter is properly handled in update_node_status_and_publish."""
        # Create a fake execution directly in cache to avoid database issues
        from datetime import datetime, timezone

        from backend.data.execution import NodeExecutionResult

        node_exec_id = "test-stats-exec-id"
        fake_execution = NodeExecutionResult(
            user_id="test-user",
            graph_id="test-graph",
            graph_version=1,
            graph_exec_id="test-graph-exec",
            node_exec_id=node_exec_id,
            node_id="stats-test-node",
            block_id="test-block",
            status=ExecutionStatus.INCOMPLETE,
            input_data={"test_input": "test_value"},
            output_data={},
            add_time=datetime.now(timezone.utc),
            queue_time=None,
            start_time=None,
            end_time=None,
            stats=None,
        )

        # Add directly to cache
        execution_client._cache.add_node_execution(node_exec_id, fake_execution)

        stats = {"token_count": 150, "processing_time": 2.5}

        # Update status with stats
        execution_client.update_node_status_and_publish(
            exec_id=node_exec_id,
            status=ExecutionStatus.RUNNING,
            execution_data={"input": "test"},
            stats=stats,
        )

        # Verify execution was updated and stats are stored
        execution = execution_client.get_node_execution(node_exec_id)
        assert execution is not None
        assert execution.status == ExecutionStatus.RUNNING

        # Stats should be stored in proper stats field
        assert execution.stats is not None
        stats_dict = execution.stats.model_dump()
        # Only check the fields we set, ignore defaults
        assert stats_dict["token_count"] == 150
        assert stats_dict["processing_time"] == 2.5

        # Update with additional stats
        additional_stats = {"error_count": 0}
        execution_client.update_node_status_and_publish(
            exec_id=node_exec_id,
            status=ExecutionStatus.COMPLETED,
            stats=additional_stats,
        )

        # Stats should be merged
        execution = execution_client.get_node_execution(node_exec_id)
        assert execution is not None
        assert execution.status == ExecutionStatus.COMPLETED
        stats_dict = execution.stats.model_dump()
        # Check the merged stats
        assert stats_dict["token_count"] == 150
        assert stats_dict["processing_time"] == 2.5
        assert stats_dict["error_count"] == 0

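The merge semantics asserted above amount to a shallow dict overlay — later updates win per key while earlier keys survive. A one-line sketch, independent of how `ExecutionDataClient` stores stats internally:

existing = {"token_count": 150, "processing_time": 2.5}   # from the RUNNING update
additional = {"error_count": 0}                           # from the COMPLETED update
merged = {**existing, **additional}
assert merged == {"token_count": 150, "processing_time": 2.5, "error_count": 0}
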
    async def test_upsert_execution_input_scenarios(self, execution_client):
        """Test different scenarios of upsert_execution_input - create vs update."""
        node_id = "test-node"
        graph_exec_id = (
            "test_graph_exec_id"  # Must match the ExecutionDataClient's scope
        )

        # Scenario 1: Create new execution when none exists
        exec_id_1, input_data_1 = execution_client.upsert_execution_input(
            node_id=node_id,
            input_name="first_input",
            input_data="value1",
            block_id="test-block",
        )

        # Should create new execution
        execution = execution_client.get_node_execution(exec_id_1)
        assert execution is not None
        assert execution.status == ExecutionStatus.INCOMPLETE
        assert execution.node_id == node_id
        assert execution.graph_exec_id == graph_exec_id
        assert input_data_1 == {"first_input": "value1"}

        # Scenario 2: Add input to existing INCOMPLETE execution
        exec_id_2, input_data_2 = execution_client.upsert_execution_input(
            node_id=node_id,
            input_name="second_input",
            input_data="value2",
            block_id="test-block",
        )

        # Should use same execution
        assert exec_id_2 == exec_id_1
        assert input_data_2 == {"first_input": "value1", "second_input": "value2"}

        # Verify execution has both inputs
        execution = execution_client.get_node_execution(exec_id_1)
        assert execution is not None
        assert execution.input_data == {
            "first_input": "value1",
            "second_input": "value2",
        }

        # Scenario 3: Create new execution when existing is not INCOMPLETE
        execution_client.update_node_status_and_publish(
            exec_id=exec_id_1, status=ExecutionStatus.RUNNING
        )

        exec_id_3, input_data_3 = execution_client.upsert_execution_input(
            node_id=node_id,
            input_name="third_input",
            input_data="value3",
            block_id="test-block",
        )

        # Should create new execution
        assert exec_id_3 != exec_id_1
        execution_3 = execution_client.get_node_execution(exec_id_3)
        assert execution_3 is not None
        assert input_data_3 == {"third_input": "value3"}

        # Verify we now have 2 executions
        all_executions = execution_client.get_node_executions()
        assert len(all_executions) == 2

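The three scenarios pin down a create-or-reuse rule. A pseudocode sketch built on the cache methods checked in `test_public_methods_accessible` below — `upsert_sketch` itself is illustrative, not the real implementation:

def upsert_sketch(cache, node_id, input_name, input_data):
    # Scenario 2: reuse the earliest INCOMPLETE execution still missing this input
    existing = cache.find_incomplete_execution_for_input(node_id, input_name)
    if existing is not None:
        return cache.update_execution_input(existing, input_name, input_data)
    # Scenarios 1 and 3: nothing reusable (none exists, or it already left
    # INCOMPLETE), so a fresh execution is created
    return create_execution(cache, node_id, input_name, input_data)  # hypothetical helper
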
    def test_graph_stats_operations(self, execution_client):
        """Test graph-level stats and start time operations."""

        # Test update_graph_stats_and_publish
        from backend.data.model import GraphExecutionStats

        stats = GraphExecutionStats(
            walltime=10.5, cputime=8.2, node_count=5, node_error_count=1
        )

        execution_client.update_graph_stats_and_publish(
            status=ExecutionStatus.RUNNING, stats=stats
        )

        # Verify stats are stored in cache
        cached_stats = execution_client._cache._graph_stats
        assert cached_stats.walltime == 10.5

        execution_client.update_graph_start_time_and_publish()
        cached_stats = execution_client._cache._graph_stats
        assert cached_stats.walltime == 10.5

    def test_public_methods_accessible(self, execution_client):
        """Test that public methods are accessible."""
        assert hasattr(execution_client._cache, "update_node_execution_status")
        assert hasattr(execution_client._cache, "upsert_execution_output")
        assert hasattr(execution_client._cache, "add_node_execution")
        assert hasattr(execution_client._cache, "find_incomplete_execution_for_input")
        assert hasattr(execution_client._cache, "update_execution_input")
        assert hasattr(execution_client, "upsert_execution_input")
        assert hasattr(execution_client, "update_node_status_and_publish")


if __name__ == "__main__":
    pytest.main([__file__, "-v"])

@@ -5,38 +5,15 @@ import threading
import time
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
from typing import Any, Optional, TypeVar, cast

from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from pydantic import JsonValue
from redis.asyncio.lock import Lock as RedisLock

from backend.blocks.io import AgentOutputBlock
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
    AgentRunData,
    LowBalanceData,
    NotificationEventModel,
    NotificationType,
    ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.activity_status_generator import (
    generate_activity_status_for_execution,
)
from backend.executor.utils import LogMetadata
from backend.notifications.notifications import queue_notification
from backend.util.exceptions import InsufficientBalanceError, ModerationError

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerClient, DatabaseManagerAsyncClient

from prometheus_client import Gauge, start_http_server
from pydantic import JsonValue

from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis_client as redis
from backend.blocks.io import AgentOutputBlock
from backend.data.block import (
    BlockData,
    BlockInput,
@@ -48,19 +25,28 @@ from backend.data.credit import UsageTransactionMetadata
from backend.data.execution import (
    ExecutionQueue,
    ExecutionStatus,
    GraphExecution,
    GraphExecutionEntry,
    NodeExecutionEntry,
    NodeExecutionResult,
    UserContext,
)
from backend.data.graph import Link, Node
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
    AgentRunData,
    LowBalanceData,
    NotificationEventModel,
    NotificationType,
    ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.execution_data import ExecutionDataClient
from backend.executor.utils import (
    GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
    GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
    GRAPH_EXECUTION_QUEUE_NAME,
    CancelExecutionEvent,
    ExecutionOutputEntry,
    LogMetadata,
    NodeExecutionProgress,
    block_usage_cost,
    create_execution_queue_config,
@@ -69,21 +55,17 @@ from backend.executor.utils import (
    validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.server.v2.AutoMod.manager import automod_manager
from backend.util import json
from backend.util.clients import (
    get_async_execution_event_bus,
    get_database_manager_async_client,
    get_database_manager_client,
    get_execution_event_bus,
    get_notification_manager_client,
)
from backend.util.clients import get_notification_manager_client
from backend.util.decorator import (
    async_error_logged,
    async_time_measured,
    error_logged,
    time_measured,
)
from backend.util.exceptions import InsufficientBalanceError, ModerationError
from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.metrics import DiscordChannel
@@ -138,7 +120,6 @@ async def execute_node(
    persist the execution result, and return the subsequent node to be executed.

    Args:
        db_client: The client to send execution updates to the server.
        creds_manager: The manager to acquire and release credentials.
        data: The execution data for executing the current node.
        execution_stats: The execution statistics to be updated.
@@ -235,7 +216,7 @@ async def execute_node(


async def _enqueue_next_nodes(
    db_client: "DatabaseManagerAsyncClient",
    execution_data_client: ExecutionDataClient,
    node: Node,
    output: BlockData,
    user_id: str,
@@ -248,8 +229,7 @@ async def _enqueue_next_nodes(
    async def add_enqueued_execution(
        node_exec_id: str, node_id: str, block_id: str, data: BlockInput
    ) -> NodeExecutionEntry:
        await async_update_node_execution_status(
            db_client=db_client,
        execution_data_client.update_node_status_and_publish(
            exec_id=node_exec_id,
            status=ExecutionStatus.QUEUED,
            execution_data=data,
@@ -282,21 +262,22 @@ async def _enqueue_next_nodes(
        next_data = parse_execution_output(output, next_output_name)
        if next_data is None and output_name != next_output_name:
            return enqueued_executions
        next_node = await db_client.get_node(next_node_id)
        next_node = await execution_data_client.get_node(next_node_id)

        # Multiple nodes can register the same next node, so this needs to be atomic
        # to avoid the same execution being enqueued multiple times,
        # or the same input being consumed multiple times.
        async with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
        with execution_data_client.graph_lock:
            # Add output data to the earliest incomplete execution, or create a new one.
            next_node_exec_id, next_node_input = await db_client.upsert_execution_input(
                node_id=next_node_id,
                graph_exec_id=graph_exec_id,
                input_name=next_input_name,
                input_data=next_data,
            next_node_exec_id, next_node_input = (
                execution_data_client.upsert_execution_input(
                    node_id=next_node_id,
                    input_name=next_input_name,
                    input_data=next_data,
                    block_id=next_node.block_id,
                )
            )
            await async_update_node_execution_status(
                db_client=db_client,
            execution_data_client.update_node_status_and_publish(
                exec_id=next_node_exec_id,
                status=ExecutionStatus.INCOMPLETE,
            )
@@ -308,8 +289,8 @@ async def _enqueue_next_nodes(
                if link.is_static and link.sink_name not in next_node_input
            }
            if static_link_names and (
                latest_execution := await db_client.get_latest_node_execution(
                    next_node_id, graph_exec_id
                latest_execution := execution_data_client.get_latest_node_execution(
                    next_node_id
                )
            ):
                for name in static_link_names:
@@ -348,9 +329,8 @@ async def _enqueue_next_nodes(

            # If the link is static, there could be some incomplete executions waiting for it.
            # Load them, complete the missing input data, and try to re-enqueue them.
            for iexec in await db_client.get_node_executions(
            for iexec in execution_data_client.get_node_executions(
                node_id=next_node_id,
                graph_exec_id=graph_exec_id,
                statuses=[ExecutionStatus.INCOMPLETE],
            ):
                idata = iexec.input_data
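
The lock swap above is the crux of this hunk: the old code needed a cross-process Redis lock because any executor could upsert inputs for the same node, whereas the scoped ExecutionDataClient makes one process the sole writer for a graph execution, so an in-process lock suffices. A sketch of the contrast (the details here are assumed, not taken from the diff):

import threading

# Old: cross-process mutual exclusion via Redis, keyed per node + graph execution
#   async with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"): ...

# New: one process owns all writes for a graph execution, so a plain lock suffices
graph_lock = threading.Lock()
with graph_lock:
    pass  # input upsert + status update happen atomically within this process
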
@@ -414,6 +394,9 @@ class ExecutionProcessor:
    9. Node executor enqueues the next nodes to be executed to the node execution queue.
    """

    # Current execution data client (scoped to current graph execution)
    execution_data: ExecutionDataClient

    @async_error_logged(swallow=True)
    async def on_node_execution(
        self,
@@ -431,8 +414,7 @@ class ExecutionProcessor:
            node_id=node_exec.node_id,
            block_name="-",
        )
        db_client = get_db_async_client()
        node = await db_client.get_node(node_exec.node_id)
        node = await self.execution_data.get_node(node_exec.node_id)
        execution_stats = NodeExecutionStats()

        timing_info, status = await self._on_node_execution(
@@ -440,7 +422,6 @@ class ExecutionProcessor:
            node_exec=node_exec,
            node_exec_progress=node_exec_progress,
            stats=execution_stats,
            db_client=db_client,
            log_metadata=log_metadata,
            nodes_input_masks=nodes_input_masks,
        )
@@ -464,15 +445,12 @@ class ExecutionProcessor:
        if node_error and not isinstance(node_error, str):
            node_stats["error"] = str(node_error) or node_error.__class__.__name__

        await async_update_node_execution_status(
            db_client=db_client,
        self.execution_data.update_node_status_and_publish(
            exec_id=node_exec.node_exec_id,
            status=status,
            stats=node_stats,
        )
        await async_update_graph_execution_state(
            db_client=db_client,
            graph_exec_id=node_exec.graph_exec_id,
        self.execution_data.update_graph_stats_and_publish(
            stats=graph_stats,
        )

@@ -485,22 +463,17 @@ class ExecutionProcessor:
        node_exec: NodeExecutionEntry,
        node_exec_progress: NodeExecutionProgress,
        stats: NodeExecutionStats,
        db_client: "DatabaseManagerAsyncClient",
        log_metadata: LogMetadata,
        nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
    ) -> ExecutionStatus:
        status = ExecutionStatus.RUNNING

        async def persist_output(output_name: str, output_data: Any) -> None:
            await db_client.upsert_execution_output(
            self.execution_data.upsert_execution_output(
                node_exec_id=node_exec.node_exec_id,
                output_name=output_name,
                output_data=output_data,
            )
            if exec_update := await db_client.get_node_execution(
                node_exec.node_exec_id
            ):
                await send_async_execution_update(exec_update)

            node_exec_progress.add_output(
                ExecutionOutputEntry(
@@ -512,8 +485,7 @@ class ExecutionProcessor:

        try:
            log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
            await async_update_node_execution_status(
                db_client=db_client,
            self.execution_data.update_node_status_and_publish(
                exec_id=node_exec.node_exec_id,
                status=ExecutionStatus.RUNNING,
            )
@@ -574,6 +546,8 @@ class ExecutionProcessor:
        self.node_evaluation_thread = threading.Thread(
            target=self.node_evaluation_loop.run_forever, daemon=True
        )
        # single-thread executor
        self.execution_data_executor = ThreadPoolExecutor(max_workers=1)
        self.node_execution_thread.start()
        self.node_evaluation_thread.start()
        logger.info(f"[GraphExecutor] {self.tid} started")
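
The `max_workers=1` choice above is load-bearing: a single-worker `ThreadPoolExecutor` runs submitted tasks strictly in submission order, so queued DB writes for a graph execution cannot reorder. A runnable illustration of that property:

from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=1)
order: list[int] = []
for i in range(3):
    pool.submit(order.append, i)
pool.shutdown(wait=True)
assert order == [0, 1, 2]  # single worker => FIFO, no interleaving
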
@@ -593,9 +567,13 @@ class ExecutionProcessor:
            node_eid="*",
            block_name="-",
        )
        db_client = get_db_client()

        exec_meta = db_client.get_graph_execution_meta(
        # Get graph execution metadata first via sync client
        from backend.util.clients import get_database_manager_client

        db_client_sync = get_database_manager_client()

        exec_meta = db_client_sync.get_graph_execution_meta(
            user_id=graph_exec.user_id,
            execution_id=graph_exec.graph_exec_id,
        )
@@ -605,12 +583,15 @@ class ExecutionProcessor:
            )
            return

        # Create scoped ExecutionDataClient for this graph execution with metadata
        self.execution_data = ExecutionDataClient(
            self.execution_data_executor, graph_exec.graph_exec_id, exec_meta
        )

        if exec_meta.status == ExecutionStatus.QUEUED:
            log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}")
            exec_meta.status = ExecutionStatus.RUNNING
            send_execution_update(
                db_client.update_graph_execution_start_time(graph_exec.graph_exec_id)
            )
            self.execution_data.update_graph_start_time_and_publish()
        elif exec_meta.status == ExecutionStatus.RUNNING:
            log_metadata.info(
                f"⚙️ Graph execution #{graph_exec.graph_exec_id} is already running, continuing where it left off."
@@ -620,9 +601,7 @@ class ExecutionProcessor:
            log_metadata.info(
                f"⚙️ Graph execution #{graph_exec.graph_exec_id} was disturbed, continuing where it left off."
            )
            update_graph_execution_state(
                db_client=db_client,
                graph_exec_id=graph_exec.graph_exec_id,
            self.execution_data.update_graph_stats_and_publish(
                status=ExecutionStatus.RUNNING,
            )
        else:
@@ -653,12 +632,10 @@ class ExecutionProcessor:

            # Activity status handling
            activity_status = asyncio.run_coroutine_threadsafe(
                generate_activity_status_for_execution(
                    graph_exec_id=graph_exec.graph_exec_id,
                self.execution_data.generate_activity_status(
                    graph_id=graph_exec.graph_id,
                    graph_version=graph_exec.graph_version,
                    execution_stats=exec_stats,
                    db_client=get_db_async_client(),
                    user_id=graph_exec.user_id,
                    execution_status=status,
                ),
@@ -673,15 +650,14 @@ class ExecutionProcessor:
            )

            # Communication handling
            self._handle_agent_run_notif(db_client, graph_exec, exec_stats)
            self._handle_agent_run_notif(graph_exec, exec_stats)

        finally:
            update_graph_execution_state(
                db_client=db_client,
                graph_exec_id=graph_exec.graph_exec_id,
            self.execution_data.update_graph_stats_and_publish(
                status=exec_meta.status,
                stats=exec_stats,
            )
            self.execution_data.finalize_execution()

    def _charge_usage(
        self,
@@ -690,7 +666,6 @@ class ExecutionProcessor:
    ) -> tuple[int, int]:
        total_cost = 0
        remaining_balance = 0
        db_client = get_db_client()
        block = get_block(node_exec.block_id)
        if not block:
            logger.error(f"Block {node_exec.block_id} not found.")
@@ -700,7 +675,7 @@ class ExecutionProcessor:
            block=block, input_data=node_exec.inputs
        )
        if cost > 0:
            remaining_balance = db_client.spend_credits(
            remaining_balance = self.execution_data.spend_credits(
                user_id=node_exec.user_id,
                cost=cost,
                metadata=UsageTransactionMetadata(
@@ -718,7 +693,7 @@ class ExecutionProcessor:

        cost, usage_count = execution_usage_cost(execution_count)
        if cost > 0:
            remaining_balance = db_client.spend_credits(
            remaining_balance = self.execution_data.spend_credits(
                user_id=node_exec.user_id,
                cost=cost,
                metadata=UsageTransactionMetadata(
@@ -751,7 +726,6 @@ class ExecutionProcessor:
        """
        execution_status: ExecutionStatus = ExecutionStatus.RUNNING
        error: Exception | None = None
        db_client = get_db_client()
        execution_stats_lock = threading.Lock()

        # State holders ----------------------------------------------------
@@ -762,7 +736,7 @@ class ExecutionProcessor:
        execution_queue = ExecutionQueue[NodeExecutionEntry]()

        try:
            if db_client.get_credits(graph_exec.user_id) <= 0:
            if self.execution_data.get_credits(graph_exec.user_id) <= 0:
                raise InsufficientBalanceError(
                    user_id=graph_exec.user_id,
                    message="You have no credits left to run an agent.",
@@ -774,7 +748,7 @@ class ExecutionProcessor:
            try:
                if moderation_error := asyncio.run_coroutine_threadsafe(
                    automod_manager.moderate_graph_execution_inputs(
                        db_client=get_db_async_client(),
                        db_client=self.execution_data.db_client_async,
                        graph_exec=graph_exec,
                    ),
                    self.node_evaluation_loop,
@@ -789,16 +763,34 @@ class ExecutionProcessor:
            # ------------------------------------------------------------
            # Pre-populate queue ---------------------------------------
            # ------------------------------------------------------------
            for node_exec in db_client.get_node_executions(
                graph_exec.graph_exec_id,

            queued_executions = self.execution_data.get_node_executions(
                statuses=[
                    ExecutionStatus.RUNNING,
                    ExecutionStatus.QUEUED,
                    ExecutionStatus.TERMINATED,
                ],
            ):
                node_entry = node_exec.to_node_execution_entry(graph_exec.user_context)
                execution_queue.add(node_entry)
            )
            log_metadata.info(
                f"Pre-populating queue with {len(queued_executions)} executions from cache"
            )

            for i, node_exec in enumerate(queued_executions):
                log_metadata.info(
                    f"  [{i}] {node_exec.node_exec_id}: status={node_exec.status}, node={node_exec.node_id}"
                )
                try:
                    node_entry = node_exec.to_node_execution_entry(
                        graph_exec.user_context
                    )
                    execution_queue.add(node_entry)
                    log_metadata.info("  Added to execution queue successfully")
                except Exception as e:
                    log_metadata.error(f"  Failed to add to execution queue: {e}")

            log_metadata.info(
                f"Execution queue populated with {len(queued_executions)} executions"
            )

            # ------------------------------------------------------------
            # Main dispatch / polling loop -----------------------------
@@ -818,13 +810,14 @@ class ExecutionProcessor:
                try:
                    cost, remaining_balance = self._charge_usage(
                        node_exec=queued_node_exec,
                        execution_count=increment_execution_count(graph_exec.user_id),
                        execution_count=self.execution_data.increment_execution_count(
                            graph_exec.user_id
                        ),
                    )
                    with execution_stats_lock:
                        execution_stats.cost += cost
                    # Check if we crossed the low balance threshold
                    self._handle_low_balance(
                        db_client=db_client,
                        user_id=graph_exec.user_id,
                        current_balance=remaining_balance,
                        transaction_cost=cost,
@@ -832,19 +825,17 @@ class ExecutionProcessor:
                except InsufficientBalanceError as balance_error:
                    error = balance_error  # Set error to trigger FAILED status
                    node_exec_id = queued_node_exec.node_exec_id
                    db_client.upsert_execution_output(
                    self.execution_data.upsert_execution_output(
                        node_exec_id=node_exec_id,
                        output_name="error",
                        output_data=str(error),
                    )
                    update_node_execution_status(
                        db_client=db_client,
                    self.execution_data.update_node_status_and_publish(
                        exec_id=node_exec_id,
                        status=ExecutionStatus.FAILED,
                    )

                    self._handle_insufficient_funds_notif(
                        db_client,
                        graph_exec.user_id,
                        graph_exec.graph_id,
                        error,
@@ -931,12 +922,13 @@ class ExecutionProcessor:
                    time.sleep(0.1)

            # loop done --------------------------------------------------
            # Background task finalization moved to finally block

            # Output moderation
            try:
                if moderation_error := asyncio.run_coroutine_threadsafe(
                    automod_manager.moderate_graph_execution_outputs(
                        db_client=get_db_async_client(),
                        db_client=self.execution_data.db_client_async,
                        graph_exec_id=graph_exec.graph_exec_id,
                        user_id=graph_exec.user_id,
                        graph_id=graph_exec.graph_id,
@@ -990,7 +982,6 @@ class ExecutionProcessor:
                error=error,
                graph_exec_id=graph_exec.graph_exec_id,
                log_metadata=log_metadata,
                db_client=db_client,
            )

    @error_logged(swallow=True)
@@ -1003,7 +994,6 @@ class ExecutionProcessor:
        error: Exception | None,
        graph_exec_id: str,
        log_metadata: LogMetadata,
        db_client: "DatabaseManagerClient",
    ) -> None:
        """
        Clean up running node executions and evaluations when graph execution ends.
@@ -1037,8 +1027,7 @@ class ExecutionProcessor:
        )

        while queued_execution := execution_queue.get_or_none():
            update_node_execution_status(
                db_client=db_client,
            self.execution_data.update_node_status_and_publish(
                exec_id=queued_execution.node_exec_id,
                status=execution_status,
                stats={"error": str(error)} if error else None,
@@ -1066,12 +1055,10 @@ class ExecutionProcessor:
            nodes_input_masks: Optional map of node input overrides
            execution_queue: Queue to add next executions to
        """
        db_client = get_db_async_client()

        log_metadata.debug(f"Enqueue nodes for {node_id}: {output}")

        for next_execution in await _enqueue_next_nodes(
            db_client=db_client,
            execution_data_client=self.execution_data,
            node=output.node,
            output=output.data,
            user_id=graph_exec.user_id,
@@ -1085,15 +1072,13 @@ class ExecutionProcessor:

    def _handle_agent_run_notif(
        self,
        db_client: "DatabaseManagerClient",
        graph_exec: GraphExecutionEntry,
        exec_stats: GraphExecutionStats,
    ):
        metadata = db_client.get_graph_metadata(
        metadata = self.execution_data.get_graph_metadata(
            graph_exec.graph_id, graph_exec.graph_version
        )
        outputs = db_client.get_node_executions(
            graph_exec.graph_exec_id,
        outputs = self.execution_data.get_node_executions(
            block_ids=[AgentOutputBlock().id],
        )

@@ -1122,13 +1107,12 @@ class ExecutionProcessor:

    def _handle_insufficient_funds_notif(
        self,
        db_client: "DatabaseManagerClient",
        user_id: str,
        graph_id: str,
        e: InsufficientBalanceError,
    ):
        shortfall = abs(e.amount) - e.balance
        metadata = db_client.get_graph_metadata(graph_id)
        metadata = self.execution_data.get_graph_metadata(graph_id)
        base_url = (
            settings.config.frontend_base_url or settings.config.platform_base_url
        )
@@ -1147,7 +1131,7 @@ class ExecutionProcessor:
        )

        try:
            user_email = db_client.get_user_email_by_id(user_id)
            user_email = self.execution_data.get_user_email_by_id(user_id)

            alert_message = (
                f"❌ **Insufficient Funds Alert**\n"
@@ -1169,7 +1153,6 @@ class ExecutionProcessor:

    def _handle_low_balance(
        self,
        db_client: "DatabaseManagerClient",
        user_id: str,
        current_balance: int,
        transaction_cost: int,
@@ -1198,7 +1181,7 @@ class ExecutionProcessor:
        )

        try:
            user_email = db_client.get_user_email_by_id(user_id)
            user_email = self.execution_data.get_user_email_by_id(user_id)
            alert_message = (
                f"⚠️ **Low Balance Alert**\n"
                f"User: {user_email or user_id}\n"
@@ -1576,117 +1559,3 @@ class ExecutionManager(AppProcess):
        )

        logger.info(f"{prefix} ✅ Finished GraphExec cleanup")


# ------- UTILITIES ------- #


def get_db_client() -> "DatabaseManagerClient":
    return get_database_manager_client()


def get_db_async_client() -> "DatabaseManagerAsyncClient":
    return get_database_manager_async_client()


@func_retry
async def send_async_execution_update(
    entry: GraphExecution | NodeExecutionResult | None,
) -> None:
    if entry is None:
        return
    await get_async_execution_event_bus().publish(entry)


@func_retry
def send_execution_update(entry: GraphExecution | NodeExecutionResult | None):
    if entry is None:
        return
    return get_execution_event_bus().publish(entry)


async def async_update_node_execution_status(
    db_client: "DatabaseManagerAsyncClient",
    exec_id: str,
    status: ExecutionStatus,
    execution_data: BlockInput | None = None,
    stats: dict[str, Any] | None = None,
) -> NodeExecutionResult:
    """Sets status and fetches+broadcasts the latest state of the node execution"""
    exec_update = await db_client.update_node_execution_status(
        exec_id, status, execution_data, stats
    )
    await send_async_execution_update(exec_update)
    return exec_update


def update_node_execution_status(
    db_client: "DatabaseManagerClient",
    exec_id: str,
    status: ExecutionStatus,
    execution_data: BlockInput | None = None,
    stats: dict[str, Any] | None = None,
) -> NodeExecutionResult:
    """Sets status and fetches+broadcasts the latest state of the node execution"""
    exec_update = db_client.update_node_execution_status(
        exec_id, status, execution_data, stats
    )
    send_execution_update(exec_update)
    return exec_update


async def async_update_graph_execution_state(
    db_client: "DatabaseManagerAsyncClient",
    graph_exec_id: str,
    status: ExecutionStatus | None = None,
    stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
    """Sets status and fetches+broadcasts the latest state of the graph execution"""
    graph_update = await db_client.update_graph_execution_stats(
        graph_exec_id, status, stats
    )
    if graph_update:
        await send_async_execution_update(graph_update)
    else:
        logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
    return graph_update


def update_graph_execution_state(
    db_client: "DatabaseManagerClient",
    graph_exec_id: str,
    status: ExecutionStatus | None = None,
    stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
    """Sets status and fetches+broadcasts the latest state of the graph execution"""
    graph_update = db_client.update_graph_execution_stats(graph_exec_id, status, stats)
    if graph_update:
        send_execution_update(graph_update)
    else:
        logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
    return graph_update


@asynccontextmanager
async def synchronized(key: str, timeout: int = 60):
    r = await redis.get_redis_async()
    lock: RedisLock = r.lock(f"lock:{key}", timeout=timeout)
    try:
        await lock.acquire()
        yield
    finally:
        if await lock.locked() and await lock.owned():
            await lock.release()


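Usage of the `synchronized` helper above, as it appeared in the old `_enqueue_next_nodes` path — the key scopes the critical section across all executor processes, and the Redis lock auto-expires after `timeout` seconds if a holder dies:

async def guarded_upsert(next_node_id: str, graph_exec_id: str) -> None:
    async with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
        ...  # at most one holder per key, cluster-wide, for up to 60 seconds

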
def increment_execution_count(user_id: str) -> int:
    """
    Increment the execution count for a given user;
    this will be used to charge the user for the execution cost.
    """
    r = redis.get_redis()
    k = f"uec:{user_id}"  # User Execution Count global key
    counter = cast(int, r.incr(k))
    if counter == 1:
        r.expire(k, settings.config.execution_counter_expiration_time)
    return counter

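Note the windowing behavior above: the TTL is set only when `INCR` returns 1, so the counter resets `execution_counter_expiration_time` seconds after the first execution of each window. Usage as seen in `_charge_usage`:

user_id = "user-123"  # example value
count = increment_execution_count(user_id)       # 1 on a fresh window; TTL starts
cost, usage_count = execution_usage_cost(count)  # charge based on the in-window count
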
@@ -32,13 +32,17 @@ async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
        mock_settings.config.low_balance_threshold = 500  # $5 threshold
        mock_settings.config.frontend_base_url = "https://test.com"

        # Create mock database client
        mock_db_client = MagicMock()
        mock_db_client.get_user_email_by_id.return_value = "test@example.com"
        # Initialize the execution processor and mock its execution_data
        execution_processor.on_graph_executor_start()

        # Mock the execution_data attribute since it's created in on_graph_execution
        mock_execution_data = MagicMock()
        execution_processor.execution_data = mock_execution_data

        mock_execution_data.get_user_email_by_id.return_value = "test@example.com"

        # Test the low balance handler
        execution_processor._handle_low_balance(
            db_client=mock_db_client,
            user_id=user_id,
            current_balance=current_balance,
            transaction_cost=transaction_cost,
@@ -62,6 +66,19 @@ async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
        assert "$4.00" in discord_message
        assert "$6.00" in discord_message

        # Cleanup execution processor threads
        try:
            execution_processor.node_execution_loop.call_soon_threadsafe(
                execution_processor.node_execution_loop.stop
            )
            execution_processor.node_evaluation_loop.call_soon_threadsafe(
                execution_processor.node_evaluation_loop.stop
            )
            execution_processor.node_execution_thread.join(timeout=1)
            execution_processor.node_evaluation_thread.join(timeout=1)
        except Exception:
            pass  # Ignore cleanup errors


@pytest.mark.asyncio(loop_scope="session")
async def test_handle_low_balance_no_notification_when_not_crossing(
@@ -90,12 +107,17 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
        mock_get_client.return_value = mock_client
        mock_settings.config.low_balance_threshold = 500  # $5 threshold

        # Create mock database client
        mock_db_client = MagicMock()
        # Initialize the execution processor and mock its execution_data
        execution_processor.on_graph_executor_start()

        # Mock the execution_data attribute since it's created in on_graph_execution
        mock_execution_data = MagicMock()
        execution_processor.execution_data = mock_execution_data

        mock_execution_data.get_user_email_by_id.return_value = "test@example.com"

        # Test the low balance handler
        execution_processor._handle_low_balance(
            db_client=mock_db_client,
            user_id=user_id,
            current_balance=current_balance,
            transaction_cost=transaction_cost,
@@ -105,6 +127,19 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
        mock_queue_notif.assert_not_called()
        mock_client.discord_system_alert.assert_not_called()

        # Cleanup execution processor threads
        try:
            execution_processor.node_execution_loop.call_soon_threadsafe(
                execution_processor.node_execution_loop.stop
            )
            execution_processor.node_evaluation_loop.call_soon_threadsafe(
                execution_processor.node_evaluation_loop.stop
            )
            execution_processor.node_execution_thread.join(timeout=1)
            execution_processor.node_evaluation_thread.join(timeout=1)
        except Exception:
            pass  # Ignore cleanup errors


@pytest.mark.asyncio(loop_scope="session")
async def test_handle_low_balance_no_duplicate_when_already_below(
@@ -133,12 +168,17 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
        mock_get_client.return_value = mock_client
        mock_settings.config.low_balance_threshold = 500  # $5 threshold

        # Create mock database client
        mock_db_client = MagicMock()
        # Initialize the execution processor and mock its execution_data
        execution_processor.on_graph_executor_start()

        # Mock the execution_data attribute since it's created in on_graph_execution
        mock_execution_data = MagicMock()
        execution_processor.execution_data = mock_execution_data

        mock_execution_data.get_user_email_by_id.return_value = "test@example.com"

        # Test the low balance handler
        execution_processor._handle_low_balance(
            db_client=mock_db_client,
            user_id=user_id,
            current_balance=current_balance,
            transaction_cost=transaction_cost,
@@ -147,3 +187,16 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
        # Verify no notification was sent (user was already below threshold)
        mock_queue_notif.assert_not_called()
        mock_client.discord_system_alert.assert_not_called()

        # Cleanup execution processor threads
        try:
            execution_processor.node_execution_loop.call_soon_threadsafe(
                execution_processor.node_execution_loop.stop
            )
            execution_processor.node_evaluation_loop.call_soon_threadsafe(
                execution_processor.node_evaluation_loop.stop
            )
            execution_processor.node_execution_thread.join(timeout=1)
            execution_processor.node_evaluation_thread.join(timeout=1)
        except Exception:
            pass  # Ignore cleanup errors

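Taken together, the three tests fix the notification predicate: alert exactly once, on the transaction whose cost moves the balance from above the threshold to at-or-below it. An equivalent standalone check (the real `_handle_low_balance` implementation is not shown in this diff):

def crossed_low_balance(current_balance: int, transaction_cost: int, threshold: int) -> bool:
    previous_balance = current_balance + transaction_cost
    return previous_balance > threshold >= current_balance


assert crossed_low_balance(400, 200, 500)      # $6.00 -> $4.00 crosses the $5 threshold
assert not crossed_low_balance(700, 100, 500)  # still above: no notification
assert not crossed_low_balance(300, 100, 500)  # already below: no duplicate
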
@@ -147,8 +147,10 @@ class AutoModManager:
            return None

        # Get completed executions and collect outputs
        completed_executions = await db_client.get_node_executions(
            graph_exec_id, statuses=[ExecutionStatus.COMPLETED], include_exec_data=True
        completed_executions = await db_client.get_node_executions(  # type: ignore
            graph_exec_id=graph_exec_id,
            statuses=[ExecutionStatus.COMPLETED],
            include_exec_data=True,
        )

        if not completed_executions:
@@ -218,7 +220,7 @@ class AutoModManager:
    ):
        """Update node execution statuses for frontend display when moderation fails"""
        # Import here to avoid circular imports
        from backend.executor.manager import send_async_execution_update
        from backend.util.clients import get_async_execution_event_bus

        if moderation_type == "input":
            # For input moderation, mark queued/running/incomplete nodes as failed
@@ -232,8 +234,10 @@ class AutoModManager:
            target_statuses = [ExecutionStatus.COMPLETED]

        # Get the executions that need to be updated
        executions_to_update = await db_client.get_node_executions(
            graph_exec_id, statuses=target_statuses, include_exec_data=True
        executions_to_update = await db_client.get_node_executions(  # type: ignore
            graph_exec_id=graph_exec_id,
            statuses=target_statuses,
            include_exec_data=True,
        )

        if not executions_to_update:
@@ -276,10 +280,12 @@ class AutoModManager:
        updated_execs = await asyncio.gather(*exec_updates)

        # Send all websocket updates in parallel
        event_bus = get_async_execution_event_bus()
        await asyncio.gather(
            *[
                send_async_execution_update(updated_exec)
                event_bus.publish(updated_exec)
                for updated_exec in updated_execs
                if updated_exec is not None
            ]
        )


@@ -9,6 +9,7 @@ import {
import { OnboardingText } from "@/components/onboarding/OnboardingText";
import StarRating from "@/components/onboarding/StarRating";
import SchemaTooltip from "@/components/SchemaTooltip";
import { TypeBasedInput } from "@/components/type-based-input";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { GraphMeta, StoreAgentDetails } from "@/lib/autogpt-server-api";
@@ -17,7 +18,6 @@ import { cn } from "@/lib/utils";
import { Play } from "lucide-react";
import { useRouter } from "next/navigation";
import { useCallback, useEffect, useState } from "react";
import { RunAgentInputs } from "@/app/(platform)/library/agents/[id]/components/AgentRunsView/components/RunAgentInputs/RunAgentInputs";

export default function Page() {
  const { state, updateState, setStep } = useOnboarding(
@@ -233,7 +233,7 @@ export default function Page() {
              description={inputSubSchema.description}
            />
          </label>
          <RunAgentInputs
          <TypeBasedInput
            schema={inputSubSchema}
            value={state?.agentInput?.[key]}
            placeholder={inputSubSchema.description}

@@ -1,4 +1,4 @@
import { OAuthPopupResultMessage } from "@/app/(platform)/library/agents/[id]/components/AgentRunsView/components/CredentialsInputs/CredentialsInputs";
import { OAuthPopupResultMessage } from "@/components/integrations/credentials-input";
import { NextResponse } from "next/server";

// This route is intended to be used as the callback for integration OAuth flows,

@@ -1,82 +0,0 @@
import { z } from "zod";
import { useForm, type UseFormReturn } from "react-hook-form";
import { zodResolver } from "@hookform/resolvers/zod";
import useCredentials from "@/hooks/useCredentials";
import {
  BlockIOCredentialsSubSchema,
  CredentialsMetaInput,
} from "@/lib/autogpt-server-api/types";

export type APIKeyFormValues = {
  apiKey: string;
  title: string;
  expiresAt?: string;
};

type Args = {
  schema: BlockIOCredentialsSubSchema;
  siblingInputs?: Record<string, any>;
  onCredentialsCreate: (creds: CredentialsMetaInput) => void;
};

export function useAPIKeyCredentialsModal({
  schema,
  siblingInputs,
  onCredentialsCreate,
}: Args): {
  form: UseFormReturn<APIKeyFormValues>;
  isLoading: boolean;
  supportsApiKey: boolean;
  provider?: string;
  providerName?: string;
  schemaDescription?: string;
  onSubmit: (values: APIKeyFormValues) => Promise<void>;
} {
  const credentials = useCredentials(schema, siblingInputs);

  const formSchema = z.object({
    apiKey: z.string().min(1, "API Key is required"),
    title: z.string().min(1, "Name is required"),
    expiresAt: z.string().optional(),
  });

  const form = useForm<APIKeyFormValues>({
    resolver: zodResolver(formSchema),
    defaultValues: {
      apiKey: "",
      title: "",
      expiresAt: "",
    },
  });

  async function onSubmit(values: APIKeyFormValues) {
    if (!credentials || credentials.isLoading) return;
    const expiresAt = values.expiresAt
      ? new Date(values.expiresAt).getTime() / 1000
      : undefined;
    const newCredentials = await credentials.createAPIKeyCredentials({
      api_key: values.apiKey,
      title: values.title,
      expires_at: expiresAt,
    });
    onCredentialsCreate({
      provider: credentials.provider,
      id: newCredentials.id,
      type: "api_key",
      title: newCredentials.title,
    });
  }

  return {
    form,
    isLoading: !credentials || credentials.isLoading,
    supportsApiKey: !!credentials?.supportsApiKey,
    provider: credentials?.provider,
    providerName:
      !credentials || credentials.isLoading
        ? undefined
        : credentials.providerName,
    schemaDescription: schema.description,
    onSubmit,
  };
}

@@ -1,30 +0,0 @@
import { Dialog } from "@/components/molecules/Dialog/Dialog";

type Props = {
  open: boolean;
  onClose: () => void;
  providerName: string;
};

export function OAuthFlowWaitingModal({ open, onClose, providerName }: Props) {
  return (
    <Dialog
      title={`Waiting on ${providerName} sign-in process...`}
      controlled={{
        isOpen: open,
        set: (isOpen) => {
          if (!isOpen) onClose();
        },
      }}
      onClose={onClose}
    >
      <Dialog.Content>
        <p className="text-sm text-zinc-600">
          Complete the sign-in process in the pop-up window.
          <br />
          Closing this dialog will cancel the sign-in process.
        </p>
      </Dialog.Content>
    </Dialog>
  );
}

@@ -1,225 +0,0 @@
import React from "react";
import { format } from "date-fns";

import { Input as DSInput } from "@/components/atoms/Input/Input";
import { Select as DSSelect } from "@/components/atoms/Select/Select";
import { MultiToggle } from "@/components/molecules/MultiToggle/MultiToggle";
// Removed shadcn Select usage in favor of DS Select for time picker
import {
  BlockIOObjectSubSchema,
  BlockIOSubSchema,
  DataType,
  determineDataType,
} from "@/lib/autogpt-server-api/types";
import { TimePicker } from "@/components/molecules/TimePicker/TimePicker";
import { FileInput } from "@/components/atoms/FileInput/FileInput";
import { useRunAgentInputs } from "./useRunAgentInputs";
import { Switch } from "@/components/atoms/Switch/Switch";

/**
 * A generic prop structure for the TypeBasedInput.
 *
 * onChange expects an event-like object with e.target.value so the parent
 * can do something like setInputValues(e.target.value).
 */
interface Props {
  schema: BlockIOSubSchema;
  value?: any;
  placeholder?: string;
  onChange: (value: any) => void;
}

/**
 * A generic, data-type-based input component that uses Shadcn UI.
 * It inspects the schema via `determineDataType` and renders
 * the correct UI component.
 */
export function RunAgentInputs({
  schema,
  value,
  placeholder,
  onChange,
  ...props
}: Props & React.HTMLAttributes<HTMLElement>) {
  const { handleUploadFile, uploadProgress } = useRunAgentInputs();

  const dataType = determineDataType(schema);
  const baseId = String(schema.title ?? "input")
    .replace(/\s+/g, "-")
    .toLowerCase();

  let innerInputElement: React.ReactNode = null;
  switch (dataType) {
    case DataType.NUMBER:
      innerInputElement = (
        <DSInput
          id={`${baseId}-number`}
          label={schema.title ?? placeholder ?? "Number"}
          hideLabel
          type="number"
          value={value ?? ""}
          placeholder={placeholder || "Enter number"}
          onChange={(e) =>
            onChange(Number((e.target as HTMLInputElement).value))
          }
          {...props}
        />
      );
      break;

    case DataType.LONG_TEXT:
      innerInputElement = (
        <DSInput
          id={`${baseId}-textarea`}
          label={schema.title ?? placeholder ?? "Text"}
          hideLabel
          type="textarea"
          rows={3}
          value={value ?? ""}
          placeholder={placeholder || "Enter text"}
          onChange={(e) => onChange((e.target as HTMLInputElement).value)}
          {...props}
        />
      );
      break;

    case DataType.BOOLEAN:
      innerInputElement = (
        <>
          <span className="text-sm text-gray-500">
            {placeholder || (value ? "Enabled" : "Disabled")}
          </span>
          <Switch
            className="ml-auto"
            checked={!!value}
            onCheckedChange={(checked: boolean) => onChange(checked)}
            {...props}
          />
        </>
      );
      break;

    case DataType.DATE:
      innerInputElement = (
        <DSInput
          id={`${baseId}-date`}
          label={schema.title ?? placeholder ?? "Date"}
          hideLabel
          type="date"
          value={value ? format(value as Date, "yyyy-MM-dd") : ""}
          onChange={(e) => {
            const v = (e.target as HTMLInputElement).value;
            if (!v) onChange(undefined);
            else {
              const [y, m, d] = v.split("-").map(Number);
              onChange(new Date(y, m - 1, d));
            }
          }}
          placeholder={placeholder || "Pick a date"}
          {...props}
        />
      );
      break;

    case DataType.TIME:
      innerInputElement = (
        <TimePicker value={value?.toString()} onChange={onChange} />
      );
      break;

    case DataType.DATE_TIME:
      innerInputElement = (
        <DSInput
          id={`${baseId}-datetime`}
          label={schema.title ?? placeholder ?? "Date time"}
          hideLabel
          type="datetime-local"
          value={value ?? ""}
          onChange={(e) => onChange((e.target as HTMLInputElement).value)}
          placeholder={placeholder || "Enter date and time"}
          {...props}
        />
      );
      break;

    case DataType.FILE:
      innerInputElement = (
        <FileInput
          value={value}
          placeholder={placeholder}
          onChange={onChange}
          onUploadFile={handleUploadFile}
          uploadProgress={uploadProgress}
          {...props}
        />
      );
      break;

    case DataType.SELECT:
      if (
        "enum" in schema &&
        Array.isArray(schema.enum) &&
        schema.enum.length > 0
      ) {
        innerInputElement = (
          <DSSelect
            id={`${baseId}-select`}
            label={schema.title ?? placeholder ?? "Select"}
            hideLabel
            value={value ?? ""}
            onValueChange={(val: string) => onChange(val)}
            placeholder={placeholder || "Select an option"}
            options={schema.enum
              .filter((opt) => opt)
              .map((opt) => ({ value: opt, label: String(opt) }))}
          />
        );
        break;
      }

    case DataType.MULTI_SELECT: {
      const _schema = schema as BlockIOObjectSubSchema;
      const allKeys = Object.keys(_schema.properties);
      const selectedValues = Object.entries(value || {})
        .filter(([_, v]) => v)
        .map(([k]) => k);

      innerInputElement = (
        <MultiToggle
          items={allKeys.map((key) => ({
            value: key,
            label: _schema.properties[key]?.title ?? key,
          }))}
          selectedValues={selectedValues}
          onChange={(values: string[]) =>
            onChange(
              Object.fromEntries(
                allKeys.map((opt) => [opt, values.includes(opt)]),
              ),
            )
          }
          className="nodrag"
          aria-label={schema.title}
        />
      );
      break;
    }

    case DataType.SHORT_TEXT:
    default:
      innerInputElement = (
        <DSInput
          id={`${baseId}-text`}
          label={schema.title ?? placeholder ?? "Text"}
          hideLabel
          type="text"
          value={value ?? ""}
          onChange={(e) => onChange((e.target as HTMLInputElement).value)}
          placeholder={placeholder || "Enter text"}
          {...props}
        />
      );
  }

  return <div className="no-drag relative flex">{innerInputElement}</div>;
}

@@ -1,19 +0,0 @@
import BackendAPI from "@/lib/autogpt-server-api";
import { useState } from "react";

export function useRunAgentInputs() {
  const api = new BackendAPI();
  const [uploadProgress, setUploadProgress] = useState(0);

  async function handleUploadFile(file: File) {
    const result = await api.uploadFile(file, "gcs", 24, (progress) =>
      setUploadProgress(progress),
    );
    return result;
  }

  return {
    uploadProgress,
    handleUploadFile,
  };
}
@@ -1,7 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { useState } from "react";
|
||||
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { useAgentRunModal } from "./useAgentRunModal";
|
||||
@@ -9,13 +8,10 @@ import { ModalHeader } from "./components/ModalHeader/ModalHeader";
|
||||
import { AgentCostSection } from "./components/AgentCostSection/AgentCostSection";
|
||||
import { AgentSectionHeader } from "./components/AgentSectionHeader/AgentSectionHeader";
|
||||
import { DefaultRunView } from "./components/DefaultRunView/DefaultRunView";
|
||||
import { RunAgentModalContextProvider } from "./context";
|
||||
import { ScheduleView } from "./components/ScheduleView/ScheduleView";
|
||||
import { AgentDetails } from "./components/AgentDetails/AgentDetails";
|
||||
import { RunActions } from "./components/RunActions/RunActions";
|
||||
import { ScheduleActions } from "./components/ScheduleActions/ScheduleActions";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { AlarmIcon, TrashIcon } from "@phosphor-icons/react";
|
||||
|
||||
interface Props {
|
||||
triggerSlot: React.ReactNode;
|
||||
@@ -32,18 +28,10 @@ export function RunAgentModal({ triggerSlot, agent }: Props) {
|
||||
defaultRunType,
|
||||
inputValues,
|
||||
setInputValues,
|
||||
inputCredentials,
|
||||
setInputCredentials,
|
||||
presetName,
|
||||
presetDescription,
|
||||
setPresetName,
|
||||
setPresetDescription,
|
||||
scheduleName,
|
||||
cronExpression,
|
||||
allRequiredInputsAreSet,
|
||||
// agentInputFields, // Available if needed for future use
|
||||
agentInputFields,
|
||||
agentCredentialsInputFields,
|
||||
hasInputFields,
|
||||
isExecuting,
|
||||
isCreatingSchedule,
|
||||
@@ -65,158 +53,104 @@ export function RunAgentModal({ triggerSlot, agent }: Props) {
|
||||
}));
|
||||
}
|
||||
|
||||
function handleCredentialsChange(key: string, value: any | undefined) {
|
||||
setInputCredentials((prev) => {
|
||||
const next = { ...prev } as Record<string, any>;
|
||||
if (value === undefined) {
|
||||
delete next[key];
|
||||
return next;
|
||||
}
|
||||
next[key] = value;
|
||||
return next;
|
||||
});
|
||||
}
|
||||
|
||||
function handleSetOpen(open: boolean) {
|
||||
setIsOpen(open);
|
||||
// Always reset to Run view when opening/closing
|
||||
if (open || !open) handleGoBack();
|
||||
}
|
||||

function handleRemoveSchedule() {
  handleGoBack();
  handleSetScheduleName("");
  handleSetCronExpression("");
}

return (
  <>
    <Dialog
      controlled={{ isOpen, set: handleSetOpen }}
      styling={{ maxWidth: "600px", maxHeight: "90vh" }}
    >
      <Dialog.Trigger>{triggerSlot}</Dialog.Trigger>
      <Dialog.Content>
        <div className="flex h-full flex-col">
          {/* Header */}
          <div className="flex-shrink-0">
            <ModalHeader agent={agent} />
            <AgentCostSection flowId={agent.graph_id} />
          </div>
    <Dialog
      controlled={{ isOpen, set: handleSetOpen }}
      styling={{ maxWidth: "600px", maxHeight: "90vh" }}
    >
      <Dialog.Trigger>{triggerSlot}</Dialog.Trigger>
      <Dialog.Content>
        <div className="flex h-full flex-col">
          {/* Header */}
          <div className="flex-shrink-0">
            <ModalHeader agent={agent} />
            <AgentCostSection flowId={agent.graph_id} />
          </div>

          {/* Scrollable content */}
          <div
            className="flex-1 overflow-y-auto overflow-x-hidden pr-1"
            style={{ scrollbarGutter: "stable" }}
          >
            {/* Setup Section */}
            <div className="mt-10">
              {hasInputFields ? (
                <RunAgentModalContextProvider
                  value={{
                    agent,
                    defaultRunType,
                    presetName,
                    setPresetName,
                    presetDescription,
                    setPresetDescription,
                    inputValues,
                    setInputValue: handleInputChange,
                    agentInputFields,
                    inputCredentials,
                    setInputCredentialsValue: handleCredentialsChange,
                    agentCredentialsInputFields,
                  }}
                >
                  <>
                    <AgentSectionHeader
                      title={
                        defaultRunType === "automatic-trigger"
                          ? "Trigger Setup"
                          : "Agent Setup"
                      }
                    />
                    <div>
                      <DefaultRunView />
                    </div>
                  </>
                </RunAgentModalContextProvider>
              ) : null}
            </div>

            {/* Schedule Section - always visible */}
            <div className="mt-8">
              <AgentSectionHeader title="Schedule Setup" />
              {showScheduleView ? (
                <>
                  <div className="mb-3 flex justify-start">
                    <Button
                      variant="secondary"
                      size="small"
                      onClick={handleRemoveSchedule}
                    >
                      <TrashIcon size={16} />
                      Remove schedule
                    </Button>
                  </div>
          {/* Scrollable content */}
          <div
            className="flex-1 overflow-y-auto overflow-x-hidden pr-1"
            style={{ scrollbarGutter: "stable" }}
          >
            {/* Setup Section */}
            <div className="mt-10">
              {showScheduleView ? (
                <>
                  <AgentSectionHeader title="Schedule Setup" />
                  <div>
                    <ScheduleView
                      agent={agent}
                      scheduleName={scheduleName}
                      cronExpression={cronExpression}
                      inputValues={inputValues}
                      onScheduleNameChange={handleSetScheduleName}
                      onCronExpressionChange={handleSetCronExpression}
                      onInputChange={handleInputChange}
                      onValidityChange={setIsScheduleFormValid}
                    />
                </>
              ) : (
                <div className="flex flex-col items-start gap-2">
                  <Text variant="body" className="mb-3 !text-zinc-500">
                    No schedule configured. Create a schedule to run this
                    agent automatically at a specific time.{" "}
                  </Text>
                  <Button
                    variant="secondary"
                    size="small"
                    onClick={handleShowSchedule}
                  >
                    <AlarmIcon size={16} />
                    Create schedule
                  </Button>
                </div>
              )}
            </div>

            {/* Agent Details Section */}
            <div className="mt-8">
              <AgentSectionHeader title="Agent Details" />
              <AgentDetails agent={agent} />
            </div>
          </>
        ) : hasInputFields ? (
          <>
            <AgentSectionHeader
              title={
                defaultRunType === "automatic-trigger"
                  ? "Trigger Setup"
                  : "Agent Setup"
              }
            />
            <div>
              <DefaultRunView
                agent={agent}
                defaultRunType={defaultRunType}
                inputValues={inputValues}
                onInputChange={handleInputChange}
              />
            </div>
          </>
        ) : null}
      </div>

      {/* Fixed Actions - sticky inside dialog scroll */}
      <Dialog.Footer className="sticky bottom-0 z-10 bg-white">
        {showScheduleView ? (
          <ScheduleActions
            onSchedule={handleSchedule}
            isCreatingSchedule={isCreatingSchedule}
            allRequiredInputsAreSet={
              allRequiredInputsAreSet &&
              !!scheduleName.trim() &&
              isScheduleFormValid
            }
          />
        ) : (
          <RunActions
            defaultRunType={defaultRunType}
            onRun={handleRun}
            isExecuting={isExecuting}
            isSettingUpTrigger={isSettingUpTrigger}
            allRequiredInputsAreSet={allRequiredInputsAreSet}
          />
        )}
      </Dialog.Footer>
            {/* Agent Details Section */}
            <div className="mt-8">
              <AgentSectionHeader title="Agent Details" />
              <AgentDetails agent={agent} />
            </div>
          </div>
        </Dialog.Content>
      </Dialog>
    </>

      {/* Fixed Actions - sticky inside dialog scroll */}
      <Dialog.Footer className="sticky bottom-0 z-10 bg-white">
        {!showScheduleView ? (
          <RunActions
            hasExternalTrigger={agent.has_external_trigger}
            defaultRunType={defaultRunType}
            onShowSchedule={handleShowSchedule}
            onRun={handleRun}
            isExecuting={isExecuting}
            isSettingUpTrigger={isSettingUpTrigger}
            allRequiredInputsAreSet={allRequiredInputsAreSet}
          />
        ) : (
          <ScheduleActions
            onGoBack={handleGoBack}
            onSchedule={handleSchedule}
            isCreatingSchedule={isCreatingSchedule}
            allRequiredInputsAreSet={
              allRequiredInputsAreSet &&
              !!scheduleName.trim() &&
              isScheduleFormValid
            }
          />
        )}
      </Dialog.Footer>
    </div>
  </Dialog.Content>
</Dialog>
  );
}

@@ -1,6 +1,7 @@
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { Text } from "@/components/atoms/Text/Text";
import { Badge } from "@/components/atoms/Badge/Badge";
import { formatAgentStatus, getStatusColor } from "./helpers";
import { formatDate } from "@/lib/utils/time";

interface Props {
@@ -10,6 +11,20 @@ interface Props {
export function AgentDetails({ agent }: Props) {
  return (
    <div className="mt-4 flex flex-col gap-5">
      <div>
        <Text variant="body-medium" className="mb-1 !text-black">
          Current Status
        </Text>
        <div className="flex items-center gap-2">
          <div
            className={`h-2 w-2 rounded-full ${getStatusColor(agent.status)}`}
          />
          <Text variant="body" className="!text-zinc-700">
            {formatAgentStatus(agent.status)}
          </Text>
        </div>
      </div>

      <div>
        <Text variant="body-medium" className="mb-1 !text-black">
          Version

@@ -0,0 +1,23 @@
import { LibraryAgentStatus } from "@/app/api/__generated__/models/libraryAgentStatus";

export function formatAgentStatus(status: LibraryAgentStatus) {
  const statusMap: Record<string, string> = {
    COMPLETED: "Ready",
    HEALTHY: "Running",
    WAITING: "Run Queued",
    ERROR: "Failed Run",
  };

  // Fall back to the raw status so unmapped values still render,
  // mirroring the fallback in getStatusColor below.
  return statusMap[status] ?? status;
}

export function getStatusColor(status: LibraryAgentStatus): string {
  const colorMap: Record<LibraryAgentStatus, string> = {
    COMPLETED: "bg-blue-300",
    HEALTHY: "bg-green-300",
    WAITING: "bg-amber-300",
    ERROR: "bg-red-300",
  };

  return colorMap[status] || "bg-gray-300";
}
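A minimal usage sketch for these helpers (hypothetical StatusDot component, not part of the commits above; it assumes the generated LibraryAgentStatus union covers the four keys mapped here):

import { LibraryAgentStatus } from "@/app/api/__generated__/models/libraryAgentStatus";
import { formatAgentStatus, getStatusColor } from "./helpers";

// Colored dot plus human-readable label, mirroring the AgentDetails markup.
export function StatusDot({ status }: { status: LibraryAgentStatus }) {
  return (
    <div className="flex items-center gap-2">
      <div className={`h-2 w-2 rounded-full ${getStatusColor(status)}`} />
      <span>{formatAgentStatus(status)}</span>
    </div>
  );
}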
@@ -1,100 +1,30 @@
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { RunVariant } from "../../useAgentRunModal";
import { WebhookTriggerBanner } from "../WebhookTriggerBanner/WebhookTriggerBanner";
import { Input } from "@/components/atoms/Input/Input";
import SchemaTooltip from "@/components/SchemaTooltip";
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/AgentRunsView/components/CredentialsInputs/CredentialsInputs";
import { useRunAgentModalContext } from "../../context";
import { RunAgentInputs } from "../../../RunAgentInputs/RunAgentInputs";
import { AgentInputFields } from "../AgentInputFields/AgentInputFields";

export function DefaultRunView() {
  const {
    agent,
    defaultRunType,
    presetName,
    setPresetName,
    presetDescription,
    setPresetDescription,
    inputValues,
    setInputValue,
    agentInputFields,
    inputCredentials,
    setInputCredentialsValue,
    agentCredentialsInputFields,
  } = useRunAgentModalContext();
interface Props {
  agent: LibraryAgent;
  defaultRunType: RunVariant;
  inputValues: Record<string, any>;
  onInputChange: (key: string, value: string) => void;
}

export function DefaultRunView({
  agent,
  defaultRunType,
  inputValues,
  onInputChange,
}: Props) {
  return (
    <div className="mb-12 mt-6">
    <div className="mt-6">
      {defaultRunType === "automatic-trigger" && <WebhookTriggerBanner />}

      {/* Preset/Trigger fields */}
      {defaultRunType === "automatic-trigger" && (
        <div className="flex flex-col gap-4">
          <div className="flex flex-col space-y-2">
            <label className="flex items-center gap-1 text-sm font-medium">
              Trigger Name
              <SchemaTooltip description="Name of the trigger you are setting up" />
            </label>
            <Input
              id="trigger_name"
              label="Trigger Name"
              hideLabel
              value={presetName}
              placeholder="Enter trigger name"
              onChange={(e) => setPresetName(e.target.value)}
            />
          </div>
          <div className="flex flex-col space-y-2">
            <label className="flex items-center gap-1 text-sm font-medium">
              Trigger Description
              <SchemaTooltip description="Description of the trigger you are setting up" />
            </label>
            <Input
              id="trigger_description"
              label="Trigger Description"
              hideLabel
              value={presetDescription}
              placeholder="Enter trigger description"
              onChange={(e) => setPresetDescription(e.target.value)}
            />
          </div>
        </div>
      )}

      {/* Credentials inputs */}
      {Object.entries(agentCredentialsInputFields || {}).map(
        ([key, inputSubSchema]) => (
          <CredentialsInput
            key={key}
            schema={{ ...inputSubSchema, discriminator: undefined } as any}
            selectedCredentials={
              (inputCredentials && inputCredentials[key]) ??
              inputSubSchema.default
            }
            onSelectCredentials={(value) =>
              setInputCredentialsValue(key, value)
            }
            siblingInputs={inputValues}
            hideIfSingleCredentialAvailable={!agent.has_external_trigger}
          />
        ),
      )}

      {/* Regular inputs */}
      {Object.entries(agentInputFields || {}).map(([key, inputSubSchema]) => (
        <div key={key} className="flex flex-col gap-0 space-y-2">
          <label className="flex items-center gap-1 text-sm font-medium">
            {inputSubSchema.title || key}
            <SchemaTooltip description={inputSubSchema.description} />
          </label>

          <RunAgentInputs
            schema={inputSubSchema}
            value={inputValues[key] ?? inputSubSchema.default}
            placeholder={inputSubSchema.description}
            onChange={(value) => setInputValue(key, value)}
            data-testid={`agent-input-${key}`}
          />
        </div>
      ))}
      <AgentInputFields
        agent={agent}
        inputValues={inputValues}
        onInputChange={onInputChange}
      />
    </div>
  );
}

@@ -8,7 +8,6 @@ interface ModalHeaderProps {
}

export function ModalHeader({ agent }: ModalHeaderProps) {
  const isUnknownCreator = agent.creator_name === "Unknown";
  return (
    <div className="space-y-4">
      <div className="flex items-center gap-3">
@@ -16,9 +15,9 @@ export function ModalHeader({ agent }: ModalHeaderProps) {
        </div>
        <div>
          <Text variant="h3">{agent.name}</Text>
          {!isUnknownCreator ? (
            <Text variant="body-medium">by {agent.creator_name}</Text>
          ) : null}
          <Text variant="body-medium">
            by {agent.creator_name === "Unknown" ? "–" : agent.creator_name}
          </Text>
          <ShowMoreText
            previewLimit={80}
            variant="small"

@@ -2,7 +2,9 @@ import { Button } from "@/components/atoms/Button/Button";
import { RunVariant } from "../../useAgentRunModal";

interface Props {
  hasExternalTrigger: boolean;
  defaultRunType: RunVariant;
  onShowSchedule: () => void;
  onRun: () => void;
  isExecuting?: boolean;
  isSettingUpTrigger?: boolean;
@@ -10,7 +12,9 @@ interface Props {
}

export function RunActions({
  hasExternalTrigger,
  defaultRunType,
  onShowSchedule,
  onRun,
  isExecuting = false,
  isSettingUpTrigger = false,
@@ -18,6 +22,11 @@ export function RunActions({
}: Props) {
  return (
    <div className="flex justify-end gap-3">
      {!hasExternalTrigger && (
        <Button variant="secondary" onClick={onShowSchedule}>
          Schedule Run
        </Button>
      )}
      <Button
        variant="primary"
        onClick={onRun}

@@ -1,25 +1,30 @@
import { Button } from "@/components/atoms/Button/Button";

interface Props {
  onGoBack: () => void;
  onSchedule: () => void;
  isCreatingSchedule?: boolean;
  allRequiredInputsAreSet?: boolean;
}

export function ScheduleActions({
  onGoBack,
  onSchedule,
  isCreatingSchedule = false,
  allRequiredInputsAreSet = true,
}: Props) {
  return (
    <div className="flex justify-end gap-3">
      <Button variant="ghost" onClick={onGoBack}>
        Go Back
      </Button>
      <Button
        variant="primary"
        onClick={onSchedule}
        disabled={!allRequiredInputsAreSet || isCreatingSchedule}
        loading={isCreatingSchedule}
      >
        Schedule Agent
        Create Schedule
      </Button>
    </div>
  );

@@ -1,5 +1,7 @@
import { Input } from "@/components/atoms/Input/Input";
import { MultiToggle } from "@/components/molecules/MultiToggle/MultiToggle";
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { AgentInputFields } from "../AgentInputFields/AgentInputFields";
import { Text } from "@/components/atoms/Text/Text";
import { Select } from "@/components/atoms/Select/Select";
import { useScheduleView } from "./useScheduleView";
@@ -7,18 +9,24 @@ import { useCallback, useState } from "react";
import { validateSchedule } from "./helpers";

interface Props {
  agent: LibraryAgent;
  scheduleName: string;
  cronExpression: string;
  inputValues: Record<string, any>;
  onScheduleNameChange: (name: string) => void;
  onCronExpressionChange: (expression: string) => void;
  onInputChange: (key: string, value: string) => void;
  onValidityChange?: (valid: boolean) => void;
}

export function ScheduleView({
  agent,
  scheduleName,
  cronExpression: _cronExpression,
  inputValues,
  onScheduleNameChange,
  onCronExpressionChange,
  onInputChange,
  onValidityChange,
}: Props) {
  const {
@@ -131,7 +139,12 @@ export function ScheduleView({
        error={errors.time}
      />

      {/** Agent inputs are rendered in the main modal; none here. */}
      <AgentInputFields
        agent={agent}
        inputValues={inputValues}
        onInputChange={onInputChange}
        variant="schedule"
      />
    </div>
  );
}

@@ -1,49 +0,0 @@
"use client";

import React, { createContext, useContext } from "react";
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { RunVariant } from "./useAgentRunModal";

export interface RunAgentModalContextValue {
  agent: LibraryAgent;
  defaultRunType: RunVariant;
  // Preset / Trigger
  presetName: string;
  setPresetName: (value: string) => void;
  presetDescription: string;
  setPresetDescription: (value: string) => void;
  // Inputs
  inputValues: Record<string, any>;
  setInputValue: (key: string, value: any) => void;
  agentInputFields: Record<string, any>;
  // Credentials
  inputCredentials: Record<string, any>;
  setInputCredentialsValue: (key: string, value: any | undefined) => void;
  agentCredentialsInputFields: Record<string, any>;
}

const RunAgentModalContext = createContext<RunAgentModalContextValue | null>(
  null,
);

export function useRunAgentModalContext(): RunAgentModalContextValue {
  const ctx = useContext(RunAgentModalContext);
  if (!ctx) throw new Error("RunAgentModalContext missing provider");
  return ctx;
}

interface ProviderProps {
  value: RunAgentModalContextValue;
  children: React.ReactNode;
}

export function RunAgentModalContextProvider({
  value,
  children,
}: ProviderProps) {
  return (
    <RunAgentModalContext.Provider value={value}>
      {children}
    </RunAgentModalContext.Provider>
  );
}
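A minimal consumer sketch for this context (hypothetical PresetNameField, not part of the commits above): children read modal state through the hook instead of prop drilling, and calling the hook outside the provider throws rather than silently rendering with defaults.

// Hypothetical consumer; must be rendered inside RunAgentModalContextProvider.
function PresetNameField() {
  const { presetName, setPresetName } = useRunAgentModalContext();
  return (
    <input value={presetName} onChange={(e) => setPresetName(e.target.value)} />
  );
}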
@@ -65,125 +65,3 @@ export function parseCronDescription(cron: string): string {

  return cron; // Fallback to showing the raw cron
}

export function getMissingRequiredInputs(
  inputSchema: any,
  values: Record<string, any>,
): string[] {
  if (!inputSchema || typeof inputSchema !== "object") return [];
  const required: string[] = Array.isArray(inputSchema.required)
    ? inputSchema.required
    : [];
  const properties: Record<string, any> = inputSchema.properties || {};
  return required.filter((key) => {
    const field = properties[key];
    if (field?.hidden) return false;
    return isEmpty(values[key]);
  });
}

export function getMissingCredentials(
  credentialsProperties: Record<string, any> | undefined,
  values: Record<string, any>,
): string[] {
  const props = credentialsProperties || {};
  return Object.keys(props).filter((key) => !(key in values));
}

type DeriveReadinessParams = {
  inputSchema: any;
  credentialsProperties?: Record<string, any>;
  values: Record<string, any>;
  credentialsValues: Record<string, any>;
};

export function deriveReadiness(params: DeriveReadinessParams): {
  missingInputs: string[];
  missingCredentials: string[];
  credentialsRequired: boolean;
  allRequiredInputsAreSet: boolean;
} {
  const missingInputs = getMissingRequiredInputs(
    params.inputSchema,
    params.values,
  );
  const credentialsRequired =
    Object.keys(params.credentialsProperties || {}).length > 0;
  const missingCredentials = getMissingCredentials(
    params.credentialsProperties,
    params.credentialsValues,
  );
  const allRequiredInputsAreSet =
    missingInputs.length === 0 &&
    (!credentialsRequired || missingCredentials.length === 0);
  return {
    missingInputs,
    missingCredentials,
    credentialsRequired,
    allRequiredInputsAreSet,
  };
}

export function getVisibleInputFields(inputSchema: any): Record<string, any> {
  if (
    !inputSchema ||
    typeof inputSchema !== "object" ||
    !("properties" in inputSchema) ||
    !inputSchema.properties
  ) {
    return {} as Record<string, any>;
  }
  const properties = inputSchema.properties as Record<string, any>;
  return Object.fromEntries(
    Object.entries(properties).filter(([, subSchema]) => !subSchema?.hidden),
  );
}

export function getCredentialFields(
  credentialsInputSchema: any,
): Record<string, any> {
  if (
    !credentialsInputSchema ||
    typeof credentialsInputSchema !== "object" ||
    !("properties" in credentialsInputSchema) ||
    !credentialsInputSchema.properties
  ) {
    return {} as Record<string, any>;
  }
  return credentialsInputSchema.properties as Record<string, any>;
}

type CollectMissingFieldsOptions = {
  needScheduleName?: boolean;
  scheduleName: string;
  missingInputs: string[];
  credentialsRequired: boolean;
  allCredentialsAreSet: boolean;
  missingCredentials: string[];
};

export function collectMissingFields(
  options: CollectMissingFieldsOptions,
): string[] {
  const scheduleMissing =
    options.needScheduleName && !options.scheduleName ? ["schedule_name"] : [];

  const missingCreds =
    options.credentialsRequired && !options.allCredentialsAreSet
      ? options.missingCredentials.map((k) => `credentials:${k}`)
      : [];

  return ([] as string[])
    .concat(scheduleMissing)
    .concat(options.missingInputs)
    .concat(missingCreds);
}

export function getErrorMessage(error: unknown): string {
  if (typeof error === "string") return error;
  if (error && typeof error === "object" && "message" in error) {
    const msg = (error as any).message;
    if (typeof msg === "string" && msg.trim().length > 0) return msg;
  }
  return "An unexpected error occurred.";
}

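A worked example of the readiness derivation above (the schema values are hypothetical, not part of the commits):

const readiness = deriveReadiness({
  inputSchema: {
    required: ["topic"],
    properties: { topic: { type: "string" } },
  },
  credentialsProperties: { openai: { title: "OpenAI" } },
  values: { topic: "daily news" }, // required input is set
  credentialsValues: {}, // no credential selected yet
});
// readiness.missingInputs            -> []
// readiness.missingCredentials       -> ["openai"]
// readiness.credentialsRequired      -> true
// readiness.allRequiredInputsAreSet  -> false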
@@ -1,19 +1,13 @@
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { useState, useCallback, useMemo } from "react";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { isEmpty } from "@/lib/utils";
import { usePostV1ExecuteGraphAgent } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { usePostV1CreateExecutionSchedule as useCreateSchedule } from "@/app/api/__generated__/endpoints/schedules/schedules";
import { usePostV2SetupTrigger } from "@/app/api/__generated__/endpoints/presets/presets";
import { ExecuteGraphResponse } from "@/app/api/__generated__/models/executeGraphResponse";
import { GraphExecutionJobInfo } from "@/app/api/__generated__/models/graphExecutionJobInfo";
import { LibraryAgentPreset } from "@/app/api/__generated__/models/libraryAgentPreset";
import {
  collectMissingFields,
  getErrorMessage,
  deriveReadiness,
  getVisibleInputFields,
  getCredentialFields,
} from "./helpers";

export type RunVariant =
  | "manual"
@@ -35,11 +29,6 @@ export function useAgentRunModal(
  const [isOpen, setIsOpen] = useState(false);
  const [showScheduleView, setShowScheduleView] = useState(false);
  const [inputValues, setInputValues] = useState<Record<string, any>>({});
  const [inputCredentials, setInputCredentials] = useState<Record<string, any>>(
    {},
  );
  const [presetName, setPresetName] = useState<string>("");
  const [presetDescription, setPresetDescription] = useState<string>("");
  const defaultScheduleName = useMemo(() => `Run ${agent.name}`, [agent.name]);
  const [scheduleName, setScheduleName] = useState(defaultScheduleName);
  const [cronExpression, setCronExpression] = useState("0 9 * * 1");
@@ -50,9 +39,71 @@ export function useAgentRunModal(
    : "manual";

  // API mutations
  const executeGraphMutation = usePostV1ExecuteGraphAgent();
  const createScheduleMutation = useCreateSchedule();
  const setupTriggerMutation = usePostV2SetupTrigger();
  const executeGraphMutation = usePostV1ExecuteGraphAgent({
    mutation: {
      onSuccess: (response) => {
        if (response.status === 200) {
          toast({
            title: "✅ Agent execution started",
            description: "Your agent is now running.",
          });
          callbacks?.onRun?.(response.data);
          setIsOpen(false);
        }
      },
      onError: (error: any) => {
        toast({
          title: "❌ Failed to execute agent",
          description: error.message || "An unexpected error occurred.",
          variant: "destructive",
        });
      },
    },
  });

  const createScheduleMutation = useCreateSchedule({
    mutation: {
      onSuccess: (response) => {
        if (response.status === 200) {
          toast({
            title: "✅ Schedule created",
            description: `Agent scheduled to run: ${scheduleName}`,
          });
          callbacks?.onCreateSchedule?.(response.data);
          setIsOpen(false);
        }
      },
      onError: (error: any) => {
        toast({
          title: "❌ Failed to create schedule",
          description: error.message || "An unexpected error occurred.",
          variant: "destructive",
        });
      },
    },
  });

  const setupTriggerMutation = usePostV2SetupTrigger({
    mutation: {
      onSuccess: (response: any) => {
        if (response.status === 200) {
          toast({
            title: "✅ Trigger setup complete",
            description: "Your webhook trigger is now active.",
          });
          callbacks?.onSetupTrigger?.(response.data);
          setIsOpen(false);
        }
      },
      onError: (error: any) => {
        toast({
          title: "❌ Failed to setup trigger",
          description: error.message || "An unexpected error occurred.",
          variant: "destructive",
        });
      },
    },
  });

  // Input schema validation
  const agentInputSchema = useMemo(
@@ -60,48 +111,42 @@ export function useAgentRunModal(
    [agent.input_schema],
  );

  const agentInputFields = useMemo(
    () => getVisibleInputFields(agentInputSchema),
    [agentInputSchema],
  );
  const agentInputFields = useMemo(() => {
    if (
      !agentInputSchema ||
      typeof agentInputSchema !== "object" ||
      !("properties" in agentInputSchema) ||
      !agentInputSchema.properties
    ) {
      return {};
    }
    const properties = agentInputSchema.properties as Record<string, any>;
    return Object.fromEntries(
      Object.entries(properties).filter(
        ([_, subSchema]: [string, any]) => !subSchema.hidden,
      ),
    );
  }, [agentInputSchema]);

  const agentCredentialsInputFields = useMemo(
    () => getCredentialFields(agent.credentials_input_schema),
    [agent.credentials_input_schema],
  );
  // Validation logic
  const [allRequiredInputsAreSet, missingInputs] = useMemo(() => {
    const nonEmptyInputs = new Set(
      Object.keys(inputValues).filter((k) => !isEmpty(inputValues[k])),
    );
    const requiredInputs = new Set(
      (agentInputSchema.required as string[]) || [],
    );
    const missing = [...requiredInputs].filter(
      (input) => !nonEmptyInputs.has(input),
    );
    return [missing.length === 0, missing];
  }, [agentInputSchema.required, inputValues]);

  // Validation logic (presence checks derived from schemas)
  const {
    missingInputs,
    missingCredentials,
    credentialsRequired,
    allRequiredInputsAreSet,
  } = useMemo(
    () =>
      deriveReadiness({
        inputSchema: agentInputSchema,
        credentialsProperties: agentCredentialsInputFields,
        values: inputValues,
        credentialsValues: inputCredentials,
      }),
    [
      agentInputSchema,
      agentCredentialsInputFields,
      inputValues,
      inputCredentials,
    ],
  );

  const notifyMissingRequirements = useCallback(
  const notifyMissingInputs = useCallback(
    (needScheduleName: boolean = false) => {
      const allMissingFields = collectMissingFields({
        needScheduleName,
        scheduleName,
        missingInputs,
        credentialsRequired,
        allCredentialsAreSet: missingCredentials.length === 0,
        missingCredentials,
      });
      const allMissingFields = (
        needScheduleName && !scheduleName ? ["schedule_name"] : []
      ).concat(missingInputs);

      toast({
        title: "⚠️ Missing required inputs",
@@ -109,35 +154,19 @@ export function useAgentRunModal(
        variant: "destructive",
      });
    },
    [
      missingInputs,
      scheduleName,
      toast,
      credentialsRequired,
      missingCredentials,
    ],
    [missingInputs, scheduleName, toast],
  );

  function showError(title: string, error: unknown) {
    toast({
      title,
      description: getErrorMessage(error),
      variant: "destructive",
    });
  }

  async function handleRun() {
  // Action handlers
  const handleRun = useCallback(() => {
    if (!allRequiredInputsAreSet) {
      notifyMissingRequirements();
      notifyMissingInputs();
      return;
    }

    const shouldUseTrigger = defaultRunType === "automatic-trigger";

    if (shouldUseTrigger) {
    if (defaultRunType === "automatic-trigger") {
      // Setup trigger
      const hasScheduleName = scheduleName.trim().length > 0;
      if (!hasScheduleName) {
      if (!scheduleName.trim()) {
        toast({
          title: "⚠️ Trigger name required",
          description: "Please provide a name for your trigger.",
@@ -145,63 +174,47 @@ export function useAgentRunModal(
        });
        return;
      }
      try {
        const nameToUse = presetName || scheduleName;
        const descriptionToUse =
          presetDescription || `Trigger for ${agent.name}`;
        const response = await setupTriggerMutation.mutateAsync({
          data: {
            name: nameToUse,
            description: descriptionToUse,
            graph_id: agent.graph_id,
            graph_version: agent.graph_version,
            trigger_config: inputValues,
            agent_credentials: inputCredentials,
          },
        });
        if (response.status === 200) {
          toast({ title: "Trigger setup complete" });
          callbacks?.onSetupTrigger?.(response.data);
          setIsOpen(false);
        } else {
          throw new Error(JSON.stringify(response?.data?.detail));
        }
      } catch (error: any) {
        showError("❌ Failed to setup trigger", error);
      }

      setupTriggerMutation.mutate({
        data: {
          name: scheduleName,
          description: `Trigger for ${agent.name}`,
          graph_id: agent.graph_id,
          graph_version: agent.graph_version,
          trigger_config: inputValues,
          agent_credentials: {}, // TODO: Add credentials handling if needed
        },
      });
    } else {
      // Manual execution
      try {
        const response = await executeGraphMutation.mutateAsync({
          graphId: agent.graph_id,
          graphVersion: agent.graph_version,
          data: {
            inputs: inputValues,
            credentials_inputs: inputCredentials,
          },
        });

        if (response.status === 200) {
          toast({ title: "Agent execution started" });
          callbacks?.onRun?.(response.data);
          setIsOpen(false);
        } else {
          throw new Error(JSON.stringify(response?.data?.detail));
        }
      } catch (error: any) {
        showError("Failed to execute agent", error);
      }
      executeGraphMutation.mutate({
        graphId: agent.graph_id,
        graphVersion: agent.graph_version,
        data: {
          inputs: inputValues,
          credentials_inputs: {}, // TODO: Add credentials handling if needed
        },
      });
    }
  }
  }, [
    allRequiredInputsAreSet,
    defaultRunType,
    scheduleName,
    inputValues,
    agent,
    notifyMissingInputs,
    setupTriggerMutation,
    executeGraphMutation,
    toast,
  ]);

  async function handleSchedule() {
  const handleSchedule = useCallback(() => {
    if (!allRequiredInputsAreSet) {
      notifyMissingRequirements(true);
      notifyMissingInputs(true);
      return;
    }

    const hasScheduleName = scheduleName.trim().length > 0;
    if (!hasScheduleName) {
    if (!scheduleName.trim()) {
      toast({
        title: "⚠️ Schedule name required",
        description: "Please provide a name for your schedule.",
@@ -209,27 +222,27 @@ export function useAgentRunModal(
      });
      return;
    }
    try {
      const nameToUse = presetName || scheduleName;
      const response = await createScheduleMutation.mutateAsync({
        graphId: agent.graph_id,
        data: {
          name: nameToUse,
          cron: cronExpression,
          inputs: inputValues,
          graph_version: agent.graph_version,
          credentials: inputCredentials,
        },
      });
      if (response.status === 200) {
        toast({ title: "Schedule created" });
        callbacks?.onCreateSchedule?.(response.data);
        setIsOpen(false);
      }
    } catch (error: any) {
      showError("❌ Failed to create schedule", error);
    }
  }

    createScheduleMutation.mutate({
      graphId: agent.graph_id,
      data: {
        name: scheduleName,
        cron: cronExpression,
        inputs: inputValues,
        graph_version: agent.graph_version,
        credentials: {}, // TODO: Add credentials handling if needed
      },
    });
  }, [
    allRequiredInputsAreSet,
    scheduleName,
    cronExpression,
    inputValues,
    agent,
    notifyMissingInputs,
    createScheduleMutation,
    toast,
  ]);

  function handleShowSchedule() {
    // Initialize with sensible defaults when entering schedule view
@@ -264,18 +277,11 @@ export function useAgentRunModal(
    defaultRunType,
    inputValues,
    setInputValues,
    inputCredentials,
    setInputCredentials,
    presetName,
    presetDescription,
    setPresetName,
    setPresetDescription,
    scheduleName,
    cronExpression,
    allRequiredInputsAreSet,
    missingInputs,
    agentInputFields,
    agentCredentialsInputFields,
    hasInputFields,
    isExecuting: executeGraphMutation.isPending,
    isCreatingSchedule: createScheduleMutation.isPending,

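A minimal consumer sketch for this hook (hypothetical; the (agent, callbacks) signature and the callback payload are assumed from the imports and the callbacks?.onRun?.(...) calls above):

function NewRunButton({ agent }: { agent: LibraryAgent }) {
  const { handleRun, isExecuting, allRequiredInputsAreSet } = useAgentRunModal(
    agent,
    { onRun: (execution) => console.log("execution started", execution) },
  );
  // Disable the button until every required input/credential is present.
  return (
    <button
      disabled={!allRequiredInputsAreSet || isExecuting}
      onClick={handleRun}
    >
      Run
    </button>
  );
}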
@@ -41,7 +41,7 @@ import LoadingBox, { LoadingSpinner } from "@/components/ui/loading";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { AgentRunDetailsView } from "./components/agent-run-details-view";
import { AgentRunDraftView } from "./components/agent-run-draft-view";
import { useAgentRunsInfinite } from "./use-agent-runs";
import { useAgentRunsInfinite } from "../use-agent-runs";
import { AgentRunsSelectorList } from "./components/agent-runs-selector-list";
import { AgentScheduleDetailsView } from "./components/agent-schedule-details-view";

@@ -19,8 +19,8 @@ import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { IconCross, IconPlay, IconSave } from "@/components/ui/icons";
import { CalendarClockIcon, Trash2Icon } from "lucide-react";
import { CronSchedulerDialog } from "@/components/cron-scheduler-dialog";
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/AgentRunsView/components/CredentialsInputs/CredentialsInputs";
import { RunAgentInputs } from "@/app/(platform)/library/agents/[id]/components/AgentRunsView/components/RunAgentInputs/RunAgentInputs";
import { CredentialsInput } from "@/components/integrations/credentials-input";
import { TypeBasedInput } from "@/components/type-based-input";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import { cn, isEmpty } from "@/lib/utils";
import SchemaTooltip from "@/components/SchemaTooltip";
@@ -596,7 +596,7 @@ export function AgentRunDraftView({
          <SchemaTooltip description={inputSubSchema.description} />
        </label>

        <RunAgentInputs
        <TypeBasedInput
          schema={inputSubSchema}
          value={inputValues[key] ?? inputSubSchema.default}
          placeholder={inputSubSchema.description}

@@ -21,12 +21,8 @@ import { Separator } from "@/components/ui/separator";

import { agentRunStatusMap } from "@/components/agents/agent-run-status-chip";
import AgentRunSummaryCard from "@/components/agents/agent-run-summary-card";
import { AgentRunsQuery } from "../use-agent-runs";
import { AgentRunsQuery } from "../../use-agent-runs";
import { ScrollArea } from "@/components/ui/scroll-area";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { RunAgentModal } from "../../AgentRunsView/components/RunAgentModal/RunAgentModal";
import { PlusIcon } from "@phosphor-icons/react";
import { LibraryAgent as GeneratedLibraryAgent } from "@/app/api/__generated__/models/libraryAgent";

interface AgentRunsSelectorListProps {
  agent: LibraryAgent;
@@ -71,8 +67,6 @@ export function AgentRunsSelectorList({
    "runs",
  );

  const isNewAgentRunsEnabled = useGetFlag(Flag.NEW_AGENT_RUNS);

  useEffect(() => {
    if (selectedView.type === "schedule") {
      setActiveListTab("scheduled");
@@ -85,17 +79,7 @@ export function AgentRunsSelectorList({

  return (
    <aside className={cn("flex flex-col gap-4", className)}>
      {isNewAgentRunsEnabled ? (
        <RunAgentModal
          triggerSlot={
            <Button variant="primary" size="large" className="w-full">
              <PlusIcon size={20} /> New Run
            </Button>
          }
          agent={agent as unknown as GeneratedLibraryAgent}
          agentId={agent.id.toString()}
        />
      ) : allowDraftNewRun ? (
      {allowDraftNewRun && (
        <Button
          className={"mb-4 hidden lg:flex"}
          onClick={onSelectDraftNewRun}
@@ -103,7 +87,7 @@ export function AgentRunsSelectorList({
        >
          New {agent.has_external_trigger ? "trigger" : "run"}
        </Button>
      ) : null}
      )}

      <div className="flex gap-2">
        <Badge

@@ -1,7 +1,16 @@
"use client";

import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";

import { OldAgentLibraryView } from "./components/OldAgentLibraryView/OldAgentLibraryView";
import { AgentRunsView } from "./components/AgentRunsView/AgentRunsView";

export default function AgentLibraryPage() {
  const isNewAgentRunsEnabled = useGetFlag(Flag.NEW_AGENT_RUNS);

  if (isNewAgentRunsEnabled) {
    return <AgentRunsView />;
  }

  return <OldAgentLibraryView />;
}

@@ -6,7 +6,7 @@ import { useToast } from "@/components/molecules/Toast/use-toast";
import { IconKey, IconUser } from "@/components/ui/icons";
import { Trash2Icon } from "lucide-react";
import { KeyIcon } from "@phosphor-icons/react/dist/ssr";
import { providerIcons } from "@/app/(platform)/library/agents/[id]/components/AgentRunsView/components/CredentialsInputs/CredentialsInputs";
import { providerIcons } from "@/components/integrations/credentials-input";
import { CredentialsProvidersContext } from "@/components/integrations/credentials-provider";
import {
  Table,

@@ -1,12 +1,12 @@
"use client";

import { Form, FormControl, FormField, FormItem } from "@/components/ui/form";
import { Switch } from "@/components/ui/switch";
import { Text } from "@/components/atoms/Text/Text";
import { Button } from "@/components/atoms/Button/Button";
import { NotificationPreference } from "@/app/api/__generated__/models/notificationPreference";
import { User } from "@supabase/supabase-js";
import { useNotificationForm } from "./useNotificationForm";
import { Switch } from "@/components/atoms/Switch/Switch";

type NotificationFormProps = {
  preferences: NotificationPreference;

@@ -32,6 +32,7 @@ import {
  setNestedProperty,
} from "@/lib/utils";
import { Button } from "@/components/atoms/Button/Button";
import { Switch } from "@/components/ui/switch";
import { TextRenderer } from "@/components/ui/render";
import { history } from "./history";
import NodeHandle from "./NodeHandle";
@@ -52,20 +53,11 @@ import {
  TrashIcon,
  CopyIcon,
  ExitIcon,
  Pencil1Icon,
} from "@radix-ui/react-icons";
import { Key } from "@phosphor-icons/react";
import useCredits from "@/hooks/useCredits";
import { getV1GetAyrshareSsoUrl } from "@/app/api/__generated__/endpoints/integrations/integrations";
import { toast } from "@/components/molecules/Toast/use-toast";
import { Input } from "@/components/ui/input";
import { Switch } from "./atoms/Switch/Switch";
import {
  Tooltip,
  TooltipContent,
  TooltipProvider,
  TooltipTrigger,
} from "@/components/ui/tooltip";

export type ConnectionData = Array<{
  edge_id: string;
@@ -100,7 +92,6 @@ export type CustomNodeData = {
  errors?: { [key: string]: string };
  isOutputStatic?: boolean;
  uiType: BlockUIType;
  metadata?: { [key: string]: any };
};

export type CustomNode = XYNode<CustomNodeData, "custom">;
@@ -115,12 +106,6 @@ export const CustomNode = React.memo(
    const [activeKey, setActiveKey] = useState<string | null>(null);
    const [inputModalValue, setInputModalValue] = useState<string>("");
    const [isOutputModalOpen, setIsOutputModalOpen] = useState(false);
    const [isEditingTitle, setIsEditingTitle] = useState(false);
    const [customTitle, setCustomTitle] = useState(
      data.metadata?.customized_name || "",
    );
    const [isTitleHovered, setIsTitleHovered] = useState(false);
    const titleInputRef = useRef<HTMLInputElement>(null);
    const { updateNodeData, deleteElements, addNodes, getNode } = useReactFlow<
      CustomNode,
      Edge
@@ -200,39 +185,6 @@ export const CustomNode = React.memo(
      [id, updateNodeData],
    );

    const handleTitleEdit = useCallback(() => {
      setIsEditingTitle(true);
      setTimeout(() => {
        titleInputRef.current?.focus();
        titleInputRef.current?.select();
      }, 0);
    }, []);

    const handleTitleSave = useCallback(() => {
      setIsEditingTitle(false);
      const newMetadata = {
        ...data.metadata,
        customized_name: customTitle.trim() || undefined,
      };
      updateNodeData(id, { metadata: newMetadata });
    }, [customTitle, data.metadata, id, updateNodeData]);

    const handleTitleKeyDown = useCallback(
      (e: React.KeyboardEvent<HTMLInputElement>) => {
        if (e.key === "Enter") {
          handleTitleSave();
        } else if (e.key === "Escape") {
          setCustomTitle(data.metadata?.customized_name || "");
          setIsEditingTitle(false);
        }
      },
      [handleTitleSave, data.metadata],
    );

    const displayTitle =
      customTitle ||
      beautifyString(data.blockType?.replace(/Block$/, "") || data.title);

    useEffect(() => {
      isInitialSetup.current = false;
      if (data.uiType === BlockUIType.AGENT) {
@@ -597,10 +549,6 @@ export const CustomNode = React.memo(
          block_id: data.block_id,
          connections: [],
          isOutputOpen: false,
          metadata: {
            ...data.metadata,
            customized_name: undefined, // Don't copy the custom name
          },
        },
      };

@@ -867,58 +815,14 @@ export const CustomNode = React.memo(

        <div className="flex w-full flex-col justify-start space-y-2.5 px-4 pt-4">
          <div className="flex flex-row items-center space-x-2 font-semibold">
            <div
              className="group flex items-center gap-1"
              onMouseEnter={() => setIsTitleHovered(true)}
              onMouseLeave={() => setIsTitleHovered(false)}
            >
              {isEditingTitle ? (
                <Input
                  ref={titleInputRef}
                  value={customTitle}
                  onChange={(e) => setCustomTitle(e.target.value)}
                  onBlur={handleTitleSave}
                  onKeyDown={handleTitleKeyDown}
                  className="h-7 w-auto min-w-[100px] max-w-[200px] px-2 py-1 text-lg font-semibold"
                  placeholder={beautifyString(
                    data.blockType?.replace(/Block$/, "") || data.title,
                  )}
                />
              ) : (
                <TooltipProvider>
                  <Tooltip>
                    <TooltipTrigger asChild>
                      <h3 className="font-roboto cursor-default text-lg">
                        <TextRenderer
                          value={displayTitle}
                          truncateLengthLimit={80}
                        />
                      </h3>
                    </TooltipTrigger>
                    {customTitle && (
                      <TooltipContent>
                        <p>
                          Type:{" "}
                          {beautifyString(
                            data.blockType?.replace(/Block$/, "") ||
                              data.title,
                          )}
                        </p>
                      </TooltipContent>
                    )}
                  </Tooltip>
                </TooltipProvider>
              )}
              {isTitleHovered && !isEditingTitle && (
                <button
                  onClick={handleTitleEdit}
                  className="cursor-pointer rounded p-1 opacity-0 transition-opacity hover:bg-gray-100 group-hover:opacity-100"
                  aria-label="Edit title"
                >
                  <Pencil1Icon className="h-4 w-4" />
                </button>
              )}
            </div>
            <h3 className="font-roboto text-lg">
              <TextRenderer
                value={beautifyString(
                  data.blockType?.replace(/Block$/, "") || data.title,
                )}
                truncateLengthLimit={80}
              />
            </h3>
            <span className="text-xs text-gray-500">#{id.split("-")[0]}</span>

            <div className="w-auto grow" />

@@ -37,11 +37,7 @@ import {
} from "@/lib/autogpt-server-api";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { Key, storage } from "@/services/storage/local-storage";
import {
  getTypeColor,
  findNewlyAddedBlockCoordinates,
  beautifyString,
} from "@/lib/utils";
import { getTypeColor, findNewlyAddedBlockCoordinates } from "@/lib/utils";
import { history } from "./history";
import { CustomEdge } from "./CustomEdge";
import ConnectionLine from "./ConnectionLine";
@@ -98,7 +94,6 @@ const FlowEditor: React.FC<{
    updateNode,
    getViewport,
    setViewport,
    screenToFlowPosition,
  } = useReactFlow<CustomNode, CustomEdge>();
  const [nodeId, setNodeId] = useState<number>(1);
  const [isAnyModalOpen, setIsAnyModalOpen] = useState(false);
@@ -683,85 +678,6 @@ const FlowEditor: React.FC<{

  const isNewBlockEnabled = useGetFlag(Flag.NEW_BLOCK_MENU);

  const onDragOver = useCallback((event: React.DragEvent) => {
    event.preventDefault();
    event.dataTransfer.dropEffect = "copy";
  }, []);

  const onDrop = useCallback(
    (event: React.DragEvent) => {
      event.preventDefault();

      const blockData = event.dataTransfer.getData("application/reactflow");
      if (!blockData) return;

      try {
        const { blockId, blockName, hardcodedValues } = JSON.parse(blockData);

        // Convert screen coordinates to flow coordinates
        const position = screenToFlowPosition({
          x: event.clientX,
          y: event.clientY,
        });

        // Find the block schema
        const nodeSchema = availableBlocks.find((node) => node.id === blockId);
        if (!nodeSchema) {
          console.error(`Schema not found for block ID: ${blockId}`);
          return;
        }

        // Create the new node at the drop position
        const newNode: CustomNode = {
          id: nodeId.toString(),
          type: "custom",
          position,
          data: {
            blockType: blockName,
            blockCosts: nodeSchema.costs || [],
            title: `${beautifyString(blockName)} ${nodeId}`,
            description: nodeSchema.description,
            categories: nodeSchema.categories,
            inputSchema: nodeSchema.inputSchema,
            outputSchema: nodeSchema.outputSchema,
            hardcodedValues: hardcodedValues,
            connections: [],
            isOutputOpen: false,
            block_id: blockId,
            uiType: nodeSchema.uiType,
          },
        };

        history.push({
          type: "ADD_NODE",
          payload: { node: { ...newNode, ...newNode.data } },
          undo: () => {
            deleteElements({ nodes: [{ id: newNode.id } as any], edges: [] });
          },
          redo: () => {
            addNodes([newNode]);
          },
        });
        addNodes([newNode]);
        clearNodesStatusAndOutput();

        setNodeId((prevId) => prevId + 1);
      } catch (error) {
        console.error("Failed to drop block:", error);
      }
    },
    [
      nodeId,
      availableBlocks,
      nodes,
      edges,
      addNodes,
      screenToFlowPosition,
      deleteElements,
      clearNodesStatusAndOutput,
    ],
  );

  return (
    <FlowContext.Provider
      value={{ visualizeBeads, setIsAnyModalOpen, getNextNodeId }}
@@ -779,8 +695,6 @@ const FlowEditor: React.FC<{
      onEdgesChange={onEdgesChange}
      onNodeDragStop={onNodeDragEnd}
      onNodeDragStart={onNodeDragStart}
      onDrop={onDrop}
      onDragOver={onDragOver}
      deleteKeyCode={["Backspace", "Delete"]}
      minZoom={0.1}
      maxZoom={2}

@@ -1,92 +0,0 @@
import type { Meta, StoryObj } from "@storybook/nextjs";
import { useState } from "react";
import { FileInput } from "./FileInput";

const meta: Meta<typeof FileInput> = {
  title: "Atoms/FileInput",
  component: FileInput,
  tags: ["autodocs"],
  parameters: {
    layout: "centered",
    docs: {
      description: {
        component:
          "File upload input with progress and removable preview.\n\nProps:\n- accept: optional MIME/extensions filter (e.g. ['image/*', '.pdf']).\n- maxFileSize: optional maximum size in bytes; larger files are rejected with an inline error.",
      },
    },
  },
  argTypes: {
    onUploadFile: { action: "upload" },
    accept: {
      control: "object",
      description:
        "Optional accept filter. Supports MIME types (image/*) and extensions (.pdf).",
    },
    maxFileSize: {
      control: "number",
      description: "Optional maximum file size in bytes.",
    },
  },
};

export default meta;
type Story = StoryObj<typeof meta>;

function mockUpload(file: File): Promise<{
  file_name: string;
  size: number;
  content_type: string;
  file_uri: string;
}> {
  return new Promise((resolve) =>
    setTimeout(
      () =>
        resolve({
          file_name: file.name,
          size: file.size,
          content_type: file.type || "application/octet-stream",
          file_uri: URL.createObjectURL(file),
        }),
      400,
    ),
  );
}

export const Basic: Story = {
  parameters: {
    docs: {
      description: {
        story:
          "This example accepts images or PDFs only and limits size to 5MB. Oversized or disallowed file types show an inline error and do not upload.",
      },
    },
  },
  render: function BasicStory() {
    const [value, setValue] = useState<string>("");
    const [progress, setProgress] = useState<number>(0);

    async function onUploadFile(file: File) {
      setProgress(0);
      const interval = setInterval(() => {
        setProgress((p) => (p >= 100 ? 100 : p + 20));
      }, 80);
      const result = await mockUpload(file);
      clearInterval(interval);
      setProgress(100);
      return result;
    }

    return (
      <div className="w-[560px]">
        <FileInput
          value={value}
          onChange={setValue}
          onUploadFile={onUploadFile}
          uploadProgress={progress}
          accept={["image/*", ".pdf"]}
          maxFileSize={5 * 1024 * 1024}
        />
      </div>
    );
  },
};
@@ -1,213 +0,0 @@
import { FileTextIcon, TrashIcon, UploadIcon } from "@phosphor-icons/react";
import { useRef, useState } from "react";
import { Button } from "../Button/Button";
import { formatFileSize, getFileLabel } from "./helpers";
import { cn } from "@/lib/utils";
import { Progress } from "../Progress/Progress";

type UploadFileResult = {
  file_name: string;
  size: number;
  content_type: string;
  file_uri: string;
};

interface Props {
  onUploadFile: (file: File) => Promise<UploadFileResult>;
  uploadProgress: number;
  value?: string; // file URI or empty
  placeholder?: string; // e.g. "Resume", "Document", etc.
  onChange: (value: string) => void;
  className?: string;
  maxFileSize?: number; // bytes (optional)
  accept?: string | string[]; // input accept filter (optional)
}

export function FileInput({
  onUploadFile,
  uploadProgress,
  value,
  onChange,
  className,
  maxFileSize,
  accept,
}: Props) {
  const [isUploading, setIsUploading] = useState(false);
  const [uploadError, setUploadError] = useState<string | null>(null);
  const [fileInfo, setFileInfo] = useState<{
    name: string;
    size: number;
    content_type: string;
  } | null>(null);

  const uploadFile = async (file: File) => {
    setIsUploading(true);
    setUploadError(null);

    try {
      const result = await onUploadFile(file);

      setFileInfo({
        name: result.file_name,
        size: result.size,
        content_type: result.content_type,
      });

      // Set the file URI as the value
      onChange(result.file_uri);
    } catch (error) {
      console.error("Upload failed:", error);
      setUploadError(error instanceof Error ? error.message : "Upload failed");
    } finally {
      setIsUploading(false);
    }
  };

  const handleFileChange = (event: React.ChangeEvent<HTMLInputElement>) => {
    const file = event.target.files?.[0];
    if (!file) return;
    // Validate max size
    if (typeof maxFileSize === "number" && file.size > maxFileSize) {
      setUploadError(
        `File exceeds maximum size of ${formatFileSize(maxFileSize)} (selected ${formatFileSize(file.size)})`,
      );
      return;
    }
    // Validate accept types
    if (!isAcceptedType(file, accept)) {
      setUploadError("Selected file type is not allowed");
      return;
    }
    uploadFile(file);
  };

  const handleFileDrop = (event: React.DragEvent<HTMLDivElement>) => {
    event.preventDefault();
    const file = event.dataTransfer.files[0];
    if (file) uploadFile(file);
  };

  const inputRef = useRef<HTMLInputElement>(null);

  const storageNote =
    "Files are stored securely and will be automatically deleted at most 24 hours after upload.";

  function acceptToString(a?: string | string[]) {
    if (!a) return "*/*";
    return Array.isArray(a) ? a.join(",") : a;
  }

  function isAcceptedType(file: File, a?: string | string[]) {
    if (!a) return true;
    const list = Array.isArray(a) ? a : a.split(",").map((s) => s.trim());
    const fileType = file.type; // e.g. image/png
    const fileExt = file.name.includes(".")
      ? `.${file.name.split(".").pop()}`.toLowerCase()
      : "";

    for (const entry of list) {
      if (!entry) continue;
      const e = entry.toLowerCase();
      if (e.includes("/")) {
        // MIME type, support wildcards like image/*
        const [main, sub] = e.split("/");
        const [fMain, fSub] = fileType.toLowerCase().split("/");
        if (!fMain || !fSub) continue;
        if (sub === "*") {
          if (main === fMain) return true;
        } else {
          if (e === fileType.toLowerCase()) return true;
        }
      } else if (e.startsWith(".")) {
        // Extension match
        if (fileExt === e) return true;
      }
    }
    return false;
  }

return (
|
||||
<div className={cn("w-full", className)}>
|
||||
{isUploading ? (
|
||||
<div className="space-y-2">
|
||||
<div className="flex min-h-14 items-center gap-4">
|
||||
<div className="agpt-border-input flex min-h-14 w-full flex-col justify-center rounded-xl bg-zinc-50 p-4 text-sm">
|
||||
<div className="mb-2 flex items-center gap-2">
|
||||
<UploadIcon className="h-5 w-5 text-blue-600" />
|
||||
<span className="text-gray-700">Uploading...</span>
|
||||
<span className="text-gray-500">
|
||||
{Math.round(uploadProgress)}%
|
||||
</span>
|
||||
</div>
|
||||
<Progress value={uploadProgress} className="w-full" />
|
||||
</div>
|
||||
</div>
|
||||
<p className="text-xs text-gray-500">{storageNote}</p>
|
||||
</div>
|
||||
) : value ? (
|
||||
<div className="space-y-2">
|
||||
<div className="flex min-h-14 items-center gap-4">
|
||||
<div className="agpt-border-input flex min-h-14 w-full items-center justify-between rounded-xl bg-zinc-50 p-4 text-sm text-gray-500">
|
||||
<div className="flex items-center gap-2">
|
||||
<FileTextIcon className="h-7 w-7 text-black" />
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<span className="font-normal text-black">
|
||||
{fileInfo
|
||||
? getFileLabel(fileInfo.name, fileInfo.content_type)
|
||||
: "File"}
|
||||
</span>
|
||||
<span>{fileInfo ? formatFileSize(fileInfo.size) : ""}</span>
|
||||
</div>
|
||||
</div>
|
||||
<TrashIcon
|
||||
className="h-5 w-5 cursor-pointer text-black"
|
||||
onClick={() => {
|
||||
if (inputRef.current) {
|
||||
inputRef.current.value = "";
|
||||
}
|
||||
onChange("");
|
||||
setFileInfo(null);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<p className="text-xs text-gray-500">{storageNote}</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-2">
|
||||
<div className="flex min-h-14 items-center gap-4">
|
||||
<div
|
||||
onDrop={handleFileDrop}
|
||||
onDragOver={(e) => e.preventDefault()}
|
||||
className="agpt-border-input flex min-h-14 w-full items-center justify-center rounded-xl border-dashed bg-zinc-50 text-sm text-gray-500"
|
||||
>
|
||||
Choose a file or drag and drop it here
|
||||
</div>
|
||||
|
||||
<Button
|
||||
onClick={() => inputRef.current?.click()}
|
||||
className="min-w-40"
|
||||
>
|
||||
Browse File
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{uploadError && (
|
||||
<div className="text-sm text-red-600">Error: {uploadError}</div>
|
||||
)}
|
||||
|
||||
<p className="text-xs text-gray-500">{storageNote}</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<input
|
||||
ref={inputRef}
|
||||
type="file"
|
||||
accept={acceptToString(accept)}
|
||||
className="hidden"
|
||||
onChange={handleFileChange}
|
||||
disabled={isUploading}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
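// Editor's sketch (illustrative, not part of the change set): how the accept
// helpers above behave, assuming a File with name "photo.png" / type "image/png":
//   acceptToString(["image/*", ".pdf"])        -> "image/*,.pdf"
//   isAcceptedType(pngFile, "image/*")         -> true  (wildcard subtype match)
//   isAcceptedType(pngFile, [".png"])          -> true  (extension match)
//   isAcceptedType(pngFile, "application/pdf") -> false
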
@@ -1,26 +0,0 @@
export function getFileLabel(filename: string, contentType?: string) {
  if (contentType) {
    const mimeParts = contentType.split("/");
    if (mimeParts.length > 1) {
      return `${mimeParts[1].toUpperCase()} file`;
    }
    return `${contentType} file`;
  }

  const pathParts = filename.split(".");
  if (pathParts.length > 1) {
    const ext = pathParts.pop();
    if (ext) return `${ext.toUpperCase()} file`;
  }
  return "File";
}

export function formatFileSize(bytes: number): string {
  if (bytes >= 1024 * 1024) {
    return `${(bytes / (1024 * 1024)).toFixed(2)} MB`;
  } else if (bytes >= 1024) {
    return `${(bytes / 1024).toFixed(2)} KB`;
  } else {
    return `${bytes} B`;
  }
}

@@ -26,8 +26,6 @@ const meta: Meta<typeof Input> = {
        "tel",
        "url",
        "textarea",
-       "date",
-       "datetime-local",
      ],
      description: "Input type",
    },
@@ -112,38 +110,6 @@ export const WithError: Story = {
  },
};

- export const DateInput: Story = {
-   args: {
-     label: "Date",
-     type: "date",
-     placeholder: "Select a date",
-   },
-   parameters: {
-     docs: {
-       description: {
-         story:
-           "Native HTML date input integrated in the design system Input. Value format is yyyy-MM-dd.",
-       },
-     },
-   },
- };
-
- export const DateTimeLocalInput: Story = {
-   args: {
-     label: "Date & Time",
-     type: "datetime-local",
-     placeholder: "Select date and time",
-   },
-   parameters: {
-     docs: {
-       description: {
-         story:
-           "Native datetime-local input. Value is a local time string (e.g. 2025-08-28T14:30).",
-       },
-     },
-   },
- };
-
export const TextareaInput: Story = {
  args: {
    label: "Description",
@@ -246,21 +212,6 @@ function renderInputTypes() {
        rows={4}
      />
    </div>
-   <div className="flex flex-col gap-4">
-     <p className="font-mono text-sm">Native date input.</p>
-     <Input label="Date" type="date" placeholder="Select a date" id="date" />
-   </div>
-   <div className="flex flex-col gap-4">
-     <p className="font-mono text-sm">
-       Native datetime-local input (local time, no timezone).
-     </p>
-     <Input
-       label="Date & Time"
-       type="datetime-local"
-       placeholder="Select date and time"
-       id="datetime"
-     />
-   </div>
  </div>
);
}

@@ -22,9 +22,7 @@ export interface TextFieldProps extends Omit<InputProps, "size"> {
    | "amount"
    | "tel"
    | "url"
-   | "textarea"
-   | "date"
-   | "datetime-local";
+   | "textarea";
  // Textarea-specific props
  rows?: number;
}

@@ -17,9 +17,7 @@ interface ExtendedInputProps extends InputProps {
    | "amount"
    | "tel"
    | "url"
-   | "textarea"
-   | "date"
-   | "datetime-local";
+   | "textarea";
}

export function useInput(args: ExtendedInputProps) {

@@ -1,100 +0,0 @@
import type { Meta, StoryObj } from "@storybook/nextjs";
import { useEffect, useState } from "react";
import { Progress } from "./Progress";

const meta: Meta<typeof Progress> = {
  title: "Atoms/Progress",
  component: Progress,
  tags: ["autodocs"],
  parameters: {
    layout: "centered",
    docs: {
      description: {
        component:
          "Simple progress bar with value and optional max (default 100).",
      },
    },
  },
  argTypes: {
    value: {
      control: { type: "number", min: 0, max: 100 },
      description: "Current value.",
    },
    max: {
      control: { type: "number", min: 1 },
      description: "Maximum value (default 100).",
    },
    className: {
      control: "text",
      description: "Optional className for container (e.g. height).",
    },
  },
};

export default meta;
type Story = StoryObj<typeof meta>;

export const Basic: Story = {
  args: { value: 50 },
  render: function BasicStory(args) {
    return (
      <div className="w-80">
        <Progress {...args} />
      </div>
    );
  },
};

export const CustomMax: Story = {
  args: { value: 30, max: 60 },
  render: function CustomMaxStory(args) {
    return (
      <div className="w-80">
        <Progress {...args} />
      </div>
    );
  },
  parameters: {
    docs: {
      description: { story: "With max=60, value=30 renders as 50%." },
    },
  },
};

export const Sizes: Story = {
  render: function SizesStory() {
    return (
      <div className="w-80 space-y-4">
        <Progress value={40} className="h-1" />
        <Progress value={60} className="h-2" />
        <Progress value={80} className="h-3" />
      </div>
    );
  },
  parameters: {
    docs: {
      description: {
        story: "Adjust height via className (e.g., h-1, h-2, h-3).",
      },
    },
  },
};

export const Live: Story = {
  render: function LiveStory() {
    const [value, setValue] = useState<number>(0);
    useEffect(() => {
      const id = setInterval(
        () => setValue((v) => (v >= 100 ? 0 : v + 10)),
        400,
      );
      return () => clearInterval(id);
    }, []);
    return <Progress value={value} className="w-80" />;
  },
  parameters: {
    docs: {
      description: { story: "Animated example updating value on an interval." },
    },
  },
};

@@ -10,7 +10,7 @@ const meta: Meta<typeof Select> = {
    docs: {
      description: {
        component:
-         "Select component based on our design system. Built on shadcn/ui with styling that matches our Input. Supports size variants (small | medium) and optional hidden label.",
+         "Select component based on our design system. Built on top of shadcn/ui select with custom styling matching Figma designs and consistent with the Input component.",
      },
    },
  },
@@ -44,12 +44,6 @@ const meta: Meta<typeof Select> = {
      control: "object",
      description: "Array of options with value and label properties",
    },
-   size: {
-     control: { type: "radio" },
-     options: ["small", "medium"],
-     description:
-       "Visual size variant. small = compact trigger (22px line-height), medium = default (46px height).",
-   },
  },
  args: {
    placeholder: "Select an option...",
@@ -97,119 +91,6 @@ export const WithValue: Story = {
  },
};

- export const Small: Story = {
-   args: {
-     id: "select-small",
-     label: "Compact",
-     hideLabel: true,
-     size: "small",
-     placeholder: "Choose option",
-     options: [
-       { value: "opt1", label: "Option 1" },
-       { value: "opt2", label: "Option 2" },
-       { value: "opt3", label: "Option 3" },
-     ],
-   },
-   parameters: {
-     docs: {
-       description: {
-         story:
-           "Small size is ideal for dense UIs (e.g., inline controls like TimePicker).",
-       },
-     },
-   },
- };
-
- export const Medium: Story = {
-   args: {
-     id: "select-medium",
-     label: "Medium",
-     size: "medium",
-     placeholder: "Choose option",
-     options: [
-       { value: "opt1", label: "Option 1" },
-       { value: "opt2", label: "Option 2" },
-       { value: "opt3", label: "Option 3" },
-     ],
-   },
- };
-
- export const WithIconsAndSeparators: Story = {
-   render: function IconsStory() {
-     const opts = [
-       { value: "oauth", label: "Your Google account", icon: <span>✅</span> },
-       { separator: true, value: "sep1", label: "" } as any,
-       {
-         value: "signin",
-         label: "Sign in with Google",
-         icon: <span>🔐</span>,
-         onSelect: () => alert("Sign in"),
-       },
-       {
-         value: "add-key",
-         label: "Add API key",
-         onSelect: () => alert("Add key"),
-       },
-     ];
-     return (
-       <div className="w-[320px]">
-         <Select
-           id="rich"
-           label="Rich"
-           hideLabel
-           options={opts as any}
-           placeholder="Choose"
-         />
-       </div>
-     );
-   },
-   parameters: {
-     docs: {
-       description: {
-         story:
-           "Demonstrates icons, separators, and actionable rows via onSelect. onSelect prevents value change and triggers the action.",
-       },
-     },
-   },
- };
-
- export const WithRenderItem: Story = {
-   render: function RenderItemStory() {
-     const opts = [
-       { value: "opt1", label: "Option 1" },
-       { value: "opt2", label: "Option 2", disabled: true },
-       { value: "opt3", label: "Option 3" },
-     ];
-     return (
-       <div className="w-[320px]">
-         <Select
-           id="render"
-           label="Custom"
-           hideLabel
-           options={opts}
-           placeholder="Pick one"
-           renderItem={(o) => (
-             <div className="flex items-center gap-2">
-               <span className="font-medium">{o.label}</span>
-               {o.disabled && (
-                 <span className="text-xs text-zinc-400">(disabled)</span>
-               )}
-             </div>
-           )}
-         />
-       </div>
-     );
-   },
-   parameters: {
-     docs: {
-       description: {
-         story:
-           "Custom rendering for options via renderItem prop; disabled items are styled and non-selectable.",
-       },
-     },
-   },
- };
-
export const WithoutLabel: Story = {
  args: {
    label: "Country",

@@ -7,7 +7,6 @@ import {
  SelectItem,
  SelectTrigger,
  SelectValue,
- SelectSeparator,
} from "@/components/ui/select";
import { cn } from "@/lib/utils";
import { ReactNode } from "react";
@@ -16,10 +15,6 @@ import { Text } from "../Text/Text";
export interface SelectOption {
  value: string;
  label: string;
- icon?: ReactNode;
- disabled?: boolean;
- separator?: boolean;
- onSelect?: () => void; // optional action handler
}

export interface SelectFieldProps {
@@ -34,8 +29,6 @@ export interface SelectFieldProps {
  value?: string;
  onValueChange?: (value: string) => void;
  options: SelectOption[];
- size?: "small" | "medium";
- renderItem?: (option: SelectOption) => React.ReactNode;
}

export function Select({
@@ -50,24 +43,14 @@ export function Select({
  value,
  onValueChange,
  options,
- size = "medium",
- renderItem,
}: SelectFieldProps) {
  const triggerStyles = cn(
-   // Base styles matching Input
-   "rounded-3xl border border-zinc-200 bg-white px-4 shadow-none",
-   "font-normal text-black w-full",
+   // Override the default select styles with Figma design matching Input
+   "h-[2.875rem] rounded-3xl border border-zinc-200 bg-white px-4 py-2.5 shadow-none",
+   "font-normal text-black text-sm w-full",
    "placeholder:font-normal !placeholder:text-zinc-400",
    // Focus and hover states
    "focus:border-zinc-400 focus:shadow-none focus:outline-none focus:ring-1 focus:ring-zinc-400 focus:ring-offset-0",
-   // Size variants
-   size === "small" && [
-     "h-[2.25rem]",
-     "py-2",
-     "text-sm leading-[22px]",
-     "placeholder:text-sm placeholder:leading-[22px]",
-   ],
-   size === "medium" && ["h-[2.875rem]", "py-2.5", "text-sm"],
    // Error state
    error &&
      "border-1.5 border-red-500 focus:border-red-500 focus:ring-red-500",
@@ -86,32 +69,11 @@ export function Select({
      <SelectValue placeholder={placeholder || label} />
    </SelectTrigger>
    <SelectContent>
-     {options.map((option, idx) => {
-       if (option.separator) return <SelectSeparator key={`sep-${idx}`} />;
-       const content = renderItem ? (
-         renderItem(option)
-       ) : (
-         <div className="flex items-center gap-2">
-           {option.icon}
-           <span>{option.label}</span>
-         </div>
-       );
-       return (
-         <SelectItem
-           key={option.value}
-           value={option.value}
-           disabled={option.disabled}
-           onMouseDown={(e) => {
-             if (option.onSelect) {
-               e.preventDefault();
-               option.onSelect();
-             }
-           }}
-         >
-           {content}
-         </SelectItem>
-       );
-     })}
+     {options.map((option) => (
+       <SelectItem key={option.value} value={option.value}>
+         {option.label}
+       </SelectItem>
+     ))}
    </SelectContent>
  </BaseSelect>
);

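// Editor's usage sketch (illustrative): after this change, Select only needs
// plain value/label options:
//   <Select
//     id="country"
//     label="Country"
//     placeholder="Pick one"
//     options={[
//       { value: "us", label: "United States" },
//       { value: "de", label: "Germany" },
//     ]}
//   />
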
@@ -1,66 +0,0 @@
import type { Meta, StoryObj } from "@storybook/nextjs";
import { useState } from "react";
import { Switch } from "./Switch";

const meta: Meta<typeof Switch> = {
  title: "Atoms/Switch",
  component: Switch,
  tags: ["autodocs"],
  parameters: {
    layout: "centered",
    docs: {
      description: {
        component:
          "Shadcn-based toggle switch. Controlled via checked and onCheckedChange.",
      },
    },
  },
  argTypes: {
    checked: { control: "boolean", description: "Checked state (controlled)." },
    disabled: { control: "boolean", description: "Disable the switch." },
    onCheckedChange: { action: "change", description: "Change handler." },
    className: { control: "text", description: "Optional className." },
  },
};

export default meta;
type Story = StoryObj<typeof meta>;

export const Basic: Story = {
  render: function BasicStory(args) {
    const [on, setOn] = useState<boolean>(true);
    return (
      <div className="flex items-center gap-3">
        <Switch
          aria-label="Toggle"
          checked={on}
          onCheckedChange={(v) => {
            setOn(v);
            if (args.onCheckedChange) args.onCheckedChange(v);
          }}
        />
        <span className="text-sm">{on ? "On" : "Off"}</span>
      </div>
    );
  },
};

export const Disabled: Story = {
  args: { disabled: true },
  render: function DisabledStory(args) {
    return <Switch aria-label="Disabled switch" disabled {...args} />;
  },
};

export const WithLabel: Story = {
  render: function WithLabelStory() {
    const [on, setOn] = useState<boolean>(false);
    const id = "ds-switch-label";
    return (
      <label htmlFor={id} className="flex items-center gap-3">
        <Switch id={id} checked={on} onCheckedChange={setOn} />
        <span className="text-sm">Enable notifications</span>
      </label>
    );
  },
};

@@ -278,22 +278,9 @@ export function BlocksControl({
  className={`m-2 my-4 flex h-20 shadow-none dark:border-slate-700 dark:bg-slate-800 dark:text-slate-100 dark:hover:bg-slate-700 ${
    block.notAvailable
      ? "cursor-not-allowed opacity-50"
-     : "cursor-move hover:shadow-lg"
+     : "cursor-pointer hover:shadow-lg"
  }`}
  data-id={`block-card-${block.id}`}
- draggable={!block.notAvailable}
- onDragStart={(e) => {
-   if (block.notAvailable) return;
-   e.dataTransfer.effectAllowed = "copy";
-   e.dataTransfer.setData(
-     "application/reactflow",
-     JSON.stringify({
-       blockId: block.id,
-       blockName: block.name,
-       hardcodedValues: block?.hardcodedValues || {},
-     }),
-   );
- }}
  onClick={() =>
    !block.notAvailable &&
    addBlock(block.id, block.name, block?.hardcodedValues || {})

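// Editor's note (illustrative): the removed onDragStart serialized the block
// into the drag payload; a canvas drop handler would have read it back with:
//   const data = JSON.parse(e.dataTransfer.getData("application/reactflow"));
//   addBlock(data.blockId, data.blockName, data.hardcodedValues);
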
@@ -1,6 +1,16 @@
- import { Input } from "@/components/atoms/Input/Input";
- import { Button } from "@/components/atoms/Button/Button";
- import { Dialog } from "@/components/molecules/Dialog/Dialog";
+ import { FC } from "react";
+ import { z } from "zod";
+ import { useForm } from "react-hook-form";
+ import { zodResolver } from "@hookform/resolvers/zod";
+ import { Input } from "@/components/ui/input";
+ import { Button } from "@/components/ui/button";
+ import {
+   Dialog,
+   DialogContent,
+   DialogDescription,
+   DialogHeader,
+   DialogTitle,
+ } from "@/components/ui/dialog";
import {
  Form,
  FormControl,
@@ -10,55 +20,73 @@ import {
  FormLabel,
  FormMessage,
} from "@/components/ui/form";
import useCredentials from "@/hooks/useCredentials";
import {
  BlockIOCredentialsSubSchema,
  CredentialsMetaInput,
} from "@/lib/autogpt-server-api/types";
- import { useAPIKeyCredentialsModal } from "./useAPIKeyCredentialsModal";

- type Props = {
+ export const APIKeyCredentialsModal: FC<{
    schema: BlockIOCredentialsSubSchema;
    open: boolean;
    onClose: () => void;
    onCredentialsCreate: (creds: CredentialsMetaInput) => void;
    siblingInputs?: Record<string, any>;
- };
+ }> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
+   const credentials = useCredentials(schema, siblingInputs);

- export function APIKeyCredentialsModal({
-   schema,
-   open,
-   onClose,
-   onCredentialsCreate,
-   siblingInputs,
- }: Props) {
-   const {
-     form,
-     isLoading,
-     supportsApiKey,
-     providerName,
-     schemaDescription,
-     onSubmit,
-   } = useAPIKeyCredentialsModal({ schema, siblingInputs, onCredentialsCreate });
+   const formSchema = z.object({
+     apiKey: z.string().min(1, "API Key is required"),
+     title: z.string().min(1, "Name is required"),
+     expiresAt: z.string().optional(),
+   });

-   if (isLoading || !supportsApiKey) {
+   const form = useForm<z.infer<typeof formSchema>>({
+     resolver: zodResolver(formSchema),
+     defaultValues: {
+       apiKey: "",
+       title: "",
+       expiresAt: "",
+     },
+   });

+   if (!credentials || credentials.isLoading || !credentials.supportsApiKey) {
      return null;
    }

+   const { provider, providerName, createAPIKeyCredentials } = credentials;

+   async function onSubmit(values: z.infer<typeof formSchema>) {
+     const expiresAt = values.expiresAt
+       ? new Date(values.expiresAt).getTime() / 1000
+       : undefined;
+     const newCredentials = await createAPIKeyCredentials({
+       api_key: values.apiKey,
+       title: values.title,
+       expires_at: expiresAt,
+     });
+     onCredentialsCreate({
+       provider,
+       id: newCredentials.id,
+       type: "api_key",
+       title: newCredentials.title,
+     });
+   }

    return (
      <Dialog
-       title={`Add new API key for ${providerName ?? ""}`}
-       controlled={{
-         isOpen: open,
-         set: (isOpen) => {
-           if (!isOpen) onClose();
-         },
+       open={open}
+       onOpenChange={(open) => {
+         if (!open) onClose();
        }}
-       onClose={onClose}
      >
-       <Dialog.Content>
-         {schemaDescription && (
-           <p className="mb-4 text-sm text-zinc-600">{schemaDescription}</p>
-         )}
+       <DialogContent>
+         <DialogHeader>
+           <DialogTitle>Add new API key for {providerName}</DialogTitle>
+           {schema.description && (
+             <DialogDescription>{schema.description}</DialogDescription>
+           )}
+         </DialogHeader>

        <Form {...form}>
          <form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
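// Editor's note (illustrative): values.expiresAt comes from a datetime-local
// input as a local-time string (e.g. "2025-08-28T14:30"); the
// `new Date(...).getTime() / 1000` above converts it to Unix epoch seconds
// in the browser's timezone before it is sent as expires_at.
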
@@ -81,9 +109,6 @@ export function APIKeyCredentialsModal({
  )}
  <FormControl>
    <Input
-     id="apiKey"
-     label="API Key"
-     hideLabel
      type="password"
      placeholder="Enter API key..."
      {...field}
@@ -101,9 +126,6 @@ export function APIKeyCredentialsModal({
  <FormLabel>Name</FormLabel>
  <FormControl>
    <Input
-     id="title"
-     label="Name"
-     hideLabel
      type="text"
      placeholder="Enter a name for this API key..."
      {...field}
@@ -121,9 +143,6 @@ export function APIKeyCredentialsModal({
  <FormLabel>Expiration Date (Optional)</FormLabel>
  <FormControl>
    <Input
-     id="expiresAt"
-     label="Expiration Date"
-     hideLabel
      type="datetime-local"
      placeholder="Select expiration date..."
      {...field}
@@ -138,7 +157,7 @@ export function APIKeyCredentialsModal({
        </Button>
      </form>
    </Form>
- </Dialog.Content>
+ </DialogContent>
</Dialog>
);
- }
+ };

@@ -1,5 +1,5 @@
import SchemaTooltip from "@/components/SchemaTooltip";
- import { Button } from "@/components/atoms/Button/Button";
+ import { Button } from "@/components/ui/button";
import { IconKey, IconKeyPlus, IconUserPlus } from "@/components/ui/icons";
import {
  Select,
@@ -28,10 +28,10 @@ import {
  FaMedium,
  FaTwitter,
} from "react-icons/fa";
- import { APIKeyCredentialsModal } from "../APIKeyCredentialsModal/APIKeyCredentialsModal";
- import { HostScopedCredentialsModal } from "../HotScopedCredentialsModal/HotScopedCredentialsModal";
- import { OAuthFlowWaitingModal } from "../OAuthWaitingModal/OAuthWaitingModal";
- import { PasswordCredentialsModal } from "../PasswordCredentialsModal/PasswordCredentialsModal";
+ import { APIKeyCredentialsModal } from "./api-key-credentials-modal";
+ import { HostScopedCredentialsModal } from "./host-scoped-credentials-modal";
+ import { OAuth2FlowWaitingModal } from "./oauth2-flow-waiting-modal";
+ import { UserPasswordCredentialsModal } from "./user-password-credentials-modal";

const fallbackIcon = FaKey;

@@ -290,14 +290,14 @@ export const CredentialsInput: FC<{
  />
)}
{supportsOAuth2 && (
- <OAuthFlowWaitingModal
+ <OAuth2FlowWaitingModal
    open={isOAuth2FlowInProgress}
    onClose={() => oAuthPopupController?.abort("canceled")}
    providerName={providerName}
  />
)}
{supportsUserPassword && (
- <PasswordCredentialsModal
+ <UserPasswordCredentialsModal
    schema={schema}
    open={isUserPasswordCredentialsModalOpen}
    onClose={() => setUserPasswordCredentialsModalOpen(false)}
@@ -1,10 +1,16 @@
- import { useEffect, useState } from "react";
+ import { FC, useEffect, useState } from "react";
import { z } from "zod";
import { useForm } from "react-hook-form";
import { zodResolver } from "@hookform/resolvers/zod";
- import { Input } from "@/components/atoms/Input/Input";
- import { Button } from "@/components/atoms/Button/Button";
- import { Dialog } from "@/components/molecules/Dialog/Dialog";
+ import { Input } from "@/components/ui/input";
+ import { Button } from "@/components/ui/button";
+ import {
+   Dialog,
+   DialogContent,
+   DialogDescription,
+   DialogHeader,
+   DialogTitle,
+ } from "@/components/ui/dialog";
import {
  Form,
  FormControl,
@@ -21,21 +27,13 @@ import {
} from "@/lib/autogpt-server-api/types";
import { getHostFromUrl } from "@/lib/utils/url";

- type Props = {
+ export const HostScopedCredentialsModal: FC<{
    schema: BlockIOCredentialsSubSchema;
    open: boolean;
    onClose: () => void;
    onCredentialsCreate: (creds: CredentialsMetaInput) => void;
    siblingInputs?: Record<string, any>;
- };
-
- export function HostScopedCredentialsModal({
-   schema,
-   open,
-   onClose,
-   onCredentialsCreate,
-   siblingInputs,
- }: Props) {
+ }> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
  const credentials = useCredentials(schema, siblingInputs);

  // Get current host from siblingInputs or discriminator_values
@@ -131,19 +129,18 @@ export function HostScopedCredentialsModal({

  return (
    <Dialog
-     title={`Add sensitive headers for ${providerName}`}
-     controlled={{
-       isOpen: open,
-       set: (isOpen) => {
-         if (!isOpen) onClose();
-       },
+     open={open}
+     onOpenChange={(open) => {
+       if (!open) onClose();
      }}
-     onClose={onClose}
    >
-     <Dialog.Content>
-       {schema.description && (
-         <p className="mb-4 text-sm text-zinc-600">{schema.description}</p>
-       )}
+     <DialogContent className="max-h-[90vh] max-w-2xl overflow-y-auto">
+       <DialogHeader>
+         <DialogTitle>Add sensitive headers for {providerName}</DialogTitle>
+         {schema.description && (
+           <DialogDescription>{schema.description}</DialogDescription>
+         )}
+       </DialogHeader>

      <Form {...form}>
        <form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
@@ -160,9 +157,6 @@ export function HostScopedCredentialsModal({
  </FormDescription>
  <FormControl>
    <Input
-     id="host"
-     label="Host Pattern"
-     hideLabel
      type="text"
      readOnly={!!currentHost}
      placeholder={
@@ -190,9 +184,6 @@ export function HostScopedCredentialsModal({
  <div key={index} className="flex items-end gap-2">
    <div className="flex-1">
      <Input
-       id={`header-${index}-key`}
-       label="Header Name"
-       hideLabel
        placeholder="Header name (e.g., Authorization)"
        value={pair.key}
        onChange={(e) =>
@@ -202,9 +193,6 @@ export function HostScopedCredentialsModal({
    </div>
    <div className="flex-1">
      <Input
-       id={`header-${index}-value`}
-       label="Header Value"
-       hideLabel
        type="password"
        placeholder="Header value (e.g., Bearer token123)"
        value={pair.value}
@@ -216,7 +204,7 @@ export function HostScopedCredentialsModal({
  <Button
    type="button"
    variant="outline"
-   size="small"
+   size="sm"
    onClick={() => removeHeaderPair(index)}
    disabled={headerPairs.length === 1}
  >
@@ -228,7 +216,7 @@ export function HostScopedCredentialsModal({
  <Button
    type="button"
    variant="outline"
-   size="small"
+   size="sm"
    onClick={addHeaderPair}
    className="w-full"
  >
@@ -241,7 +229,7 @@ export function HostScopedCredentialsModal({
        </Button>
      </form>
    </Form>
- </Dialog.Content>
+ </DialogContent>
</Dialog>
);
- }
+ };

@@ -0,0 +1,36 @@
import { FC } from "react";
import {
  Dialog,
  DialogContent,
  DialogDescription,
  DialogHeader,
  DialogTitle,
} from "@/components/ui/dialog";

export const OAuth2FlowWaitingModal: FC<{
  open: boolean;
  onClose: () => void;
  providerName: string;
}> = ({ open, onClose, providerName }) => {
  return (
    <Dialog
      open={open}
      onOpenChange={(open) => {
        if (!open) onClose();
      }}
    >
      <DialogContent>
        <DialogHeader>
          <DialogTitle>
            Waiting on {providerName} sign-in process...
          </DialogTitle>
          <DialogDescription>
            Complete the sign-in process in the pop-up window.
            <br />
            Closing this dialog will cancel the sign-in process.
          </DialogDescription>
        </DialogHeader>
      </DialogContent>
    </Dialog>
  );
};

@@ -1,3 +1,4 @@
+ import { FC } from "react";
import { z } from "zod";
import { useForm } from "react-hook-form";
import { zodResolver } from "@hookform/resolvers/zod";
@@ -23,21 +24,13 @@ import {
  CredentialsMetaInput,
} from "@/lib/autogpt-server-api/types";

- type Props = {
+ export const UserPasswordCredentialsModal: FC<{
    schema: BlockIOCredentialsSubSchema;
    open: boolean;
    onClose: () => void;
    onCredentialsCreate: (creds: CredentialsMetaInput) => void;
    siblingInputs?: Record<string, any>;
- };
-
- export function PasswordCredentialsModal({
-   schema,
-   open,
-   onClose,
-   onCredentialsCreate,
-   siblingInputs,
- }: Props) {
+ }> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
  const credentials = useCredentials(schema, siblingInputs);

  const formSchema = z.object({
@@ -153,4 +146,4 @@ export function PasswordCredentialsModal({
    </DialogContent>
  </Dialog>
);
- }
+ };

@@ -69,10 +69,6 @@ export const CustomStyling: Story = {
  render: renderCustomStyledDialog,
};

- export const ModalOverModal: Story = {
-   render: renderModalOverModal,
- };
-
function renderBasicDialog() {
  return (
    <Dialog title="Basic Dialog">
@@ -199,33 +195,3 @@ function renderCustomStyledDialog() {
    </Dialog>
  );
}
-
- function renderModalOverModal() {
-   return (
-     <Dialog title="Parent Dialog">
-       <Dialog.Trigger>
-         <Button variant="primary">Open Parent</Button>
-       </Dialog.Trigger>
-       <Dialog.Content>
-         <div className="space-y-4">
-           <p>
-             This is the parent dialog. You can open another modal on top of it
-             using a nested Dialog.
-           </p>
-
-           <Dialog title="Child Dialog">
-             <Dialog.Trigger>
-               <Button size="small">Open Child Modal</Button>
-             </Dialog.Trigger>
-             <Dialog.Content>
-               <p>
-                 This child dialog is rendered above the parent. Close it first
-                 to interact with the parent again.
-               </p>
-             </Dialog.Content>
-           </Dialog>
-         </div>
-       </Dialog.Content>
-     </Dialog>
-   );
- }

@@ -1,28 +0,0 @@
import type { Meta, StoryObj } from "@storybook/nextjs";
import { useState } from "react";
import { TimePicker } from "./TimePicker";

const meta: Meta<typeof TimePicker> = {
  title: "Molecules/TimePicker",
  component: TimePicker,
  tags: ["autodocs"],
  parameters: {
    layout: "centered",
    docs: {
      description: {
        component:
          "Compact time selector using three small Selects (hour, minute, AM/PM).",
      },
    },
  },
};

export default meta;
type Story = StoryObj<typeof meta>;

export const Basic: Story = {
  render: function BasicStory() {
    const [value, setValue] = useState<string>("12:00");
    return <TimePicker value={value} onChange={setValue} />;
  },
};

@@ -1,76 +0,0 @@
import { Select } from "@/components/atoms/Select/Select";

interface Props {
  value?: string;
  onChange: (time: string) => void;
  className?: string;
}

export function TimePicker({ value, onChange }: Props) {
  const pad = (n: number) => n.toString().padStart(2, "0");
  const [hourNum, minuteNum] = value ? value.split(":").map(Number) : [0, 0];

  const meridiem = hourNum >= 12 ? "PM" : "AM";
  const hour = pad(hourNum % 12 || 12);
  const minute = pad(minuteNum);

  const changeTime = (hour: string, minute: string, meridiem: string) => {
    const hour24 = (Number(hour) % 12) + (meridiem === "PM" ? 12 : 0);
    onChange(`${pad(hour24)}:${minute}`);
  };

  return (
    <div className="flex items-center space-x-3">
      <div className="flex flex-col items-center">
        <Select
          id={`time-hour`}
          label="Hour"
          hideLabel
          size="small"
          value={hour}
          onValueChange={(val: string) => changeTime(val, minute, meridiem)}
          options={Array.from({ length: 12 }, (_, i) => pad(i + 1)).map(
            (h) => ({
              value: h,
              label: h,
            }),
          )}
        />
      </div>

      <div className="mb-6 flex flex-col items-center">
        <span className="m-auto text-xl font-bold">:</span>
      </div>

      <div className="flex flex-col items-center">
        <Select
          id={`time-minute`}
          label="Minute"
          hideLabel
          size="small"
          value={minute}
          onValueChange={(val: string) => changeTime(hour, val, meridiem)}
          options={Array.from({ length: 60 }, (_, i) => pad(i)).map((m) => ({
            value: m.toString(),
            label: m,
          }))}
        />
      </div>

      <div className="flex flex-col items-center">
        <Select
          id={`time-meridiem`}
          label="AM/PM"
          hideLabel
          size="small"
          value={meridiem}
          onValueChange={(val: string) => changeTime(hour, minute, val)}
          options={[
            { value: "AM", label: "AM" },
            { value: "PM", label: "PM" },
          ]}
        />
      </div>
    </div>
  );
}

@@ -34,6 +34,7 @@ import React, {
  useRef,
} from "react";
import { Button } from "./ui/button";
+ import { Switch } from "./ui/switch";
import {
  Select,
  SelectContent,
@@ -51,8 +52,7 @@ import {
} from "./ui/multiselect";
import { LocalValuedInput } from "./ui/input";
import NodeHandle from "./NodeHandle";
- import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/AgentRunsView/components/CredentialsInputs/CredentialsInputs";
- import { Switch } from "./atoms/Switch/Switch";
+ import { CredentialsInput } from "@/components/integrations/credentials-input";

type NodeObjectInputTreeProps = {
  nodeId: string;

autogpt_platform/frontend/src/components/type-based-input.tsx (new file, 547 lines)
@@ -0,0 +1,547 @@
import React, { FC, useState } from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { format } from "date-fns";
|
||||
import { CalendarIcon, UploadIcon } from "lucide-react";
|
||||
import { Cross2Icon, FileTextIcon } from "@radix-ui/react-icons";
|
||||
|
||||
import { Input as BaseInput } from "@/components/ui/input";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Progress } from "@/components/ui/progress";
|
||||
import {
|
||||
Popover,
|
||||
PopoverTrigger,
|
||||
PopoverContent,
|
||||
} from "@/components/ui/popover";
|
||||
import { Calendar } from "@/components/ui/calendar";
|
||||
import {
|
||||
Select,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
} from "@/components/ui/select";
|
||||
import {
|
||||
MultiSelector,
|
||||
MultiSelectorContent,
|
||||
MultiSelectorInput,
|
||||
MultiSelectorItem,
|
||||
MultiSelectorList,
|
||||
MultiSelectorTrigger,
|
||||
} from "@/components/ui/multiselect";
|
||||
import {
|
||||
BlockIOObjectSubSchema,
|
||||
BlockIOSubSchema,
|
||||
DataType,
|
||||
determineDataType,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import BackendAPI from "@/lib/autogpt-server-api/client";
|
||||
|
||||
/**
|
||||
* A generic prop structure for the TypeBasedInput.
|
||||
*
|
||||
* onChange expects an event-like object with e.target.value so the parent
|
||||
* can do something like setInputValues(e.target.value).
|
||||
*/
|
||||
export interface TypeBasedInputProps {
|
||||
schema: BlockIOSubSchema;
|
||||
value?: any;
|
||||
placeholder?: string;
|
||||
onChange: (value: any) => void;
|
||||
}
|
||||
|
||||
const inputClasses = "min-h-11 rounded-[1.375rem] border px-4 py-2.5 bg-text";

function Input({
  className,
  ...props
}: React.InputHTMLAttributes<HTMLInputElement>) {
  return <BaseInput {...props} className={cn(inputClasses, className)} />;
}

/**
 * A generic, data-type-based input component that uses Shadcn UI.
 * It inspects the schema via `determineDataType` and renders
 * the correct UI component.
 */
export const TypeBasedInput: FC<
  TypeBasedInputProps & React.HTMLAttributes<HTMLElement>
> = ({ schema, value, placeholder, onChange, ...props }) => {
  const dataType = determineDataType(schema);

  let innerInputElement: React.ReactNode = null;
  switch (dataType) {
    case DataType.NUMBER:
      innerInputElement = (
        <Input
          type="number"
          value={value ?? ""}
          placeholder={placeholder || "Enter number"}
          onChange={(e) => onChange(Number(e.target.value))}
          {...props}
        />
      );
      break;

    case DataType.LONG_TEXT:
      innerInputElement = (
        <Textarea
          className="rounded-xl px-3 py-2"
          value={value ?? ""}
          placeholder={placeholder || "Enter text"}
          onChange={(e) => onChange(e.target.value)}
          {...props}
        />
      );
      break;

    case DataType.BOOLEAN:
      innerInputElement = (
        <>
          <span className="text-sm text-gray-500">
            {placeholder || (value ? "Enabled" : "Disabled")}
          </span>
          <Switch
            className="ml-auto"
            checked={!!value}
            onCheckedChange={(checked: boolean) => onChange(checked)}
            {...props}
          />
        </>
      );
      break;

    case DataType.DATE:
      innerInputElement = (
        <DatePicker
          value={value}
          placeholder={placeholder}
          onChange={onChange}
          className={cn(inputClasses)}
        />
      );
      break;

    case DataType.TIME:
      innerInputElement = (
        <TimePicker value={value?.toString()} onChange={onChange} />
      );
      break;

    case DataType.DATE_TIME:
      innerInputElement = (
        <Input
          type="datetime-local"
          value={value ?? ""}
          onChange={(e) => onChange(e.target.value)}
          placeholder={placeholder || "Enter date and time"}
          {...props}
        />
      );
      break;

    case DataType.FILE:
      innerInputElement = (
        <FileInput
          value={value}
          placeholder={placeholder}
          onChange={onChange}
          {...props}
        />
      );
      break;

    case DataType.SELECT:
      if (
        "enum" in schema &&
        Array.isArray(schema.enum) &&
        schema.enum.length > 0
      ) {
        innerInputElement = (
          <Select
            value={value ?? ""}
            onValueChange={(val: string) => onChange(val)}
          >
            <SelectTrigger
              className={cn(
                inputClasses,
                "agpt-border-input text-sm text-gray-500",
              )}
            >
              <SelectValue placeholder={placeholder || "Select an option"} />
            </SelectTrigger>
            <SelectContent className="rounded-xl">
              {schema.enum
                .filter((opt) => opt)
                .map((opt) => (
                  <SelectItem key={opt} value={opt}>
                    {String(opt)}
                  </SelectItem>
                ))}
            </SelectContent>
          </Select>
        );
        break;
      }
    // No enum values: falls through to the multi-select branch below

    case DataType.MULTI_SELECT:
      const _schema = schema as BlockIOObjectSubSchema;

      innerInputElement = (
        <MultiSelector
          className="nodrag"
          values={Object.entries(value || {})
            .filter(([_, v]) => v)
            .map(([k, _]) => k)}
          onValuesChange={(values: string[]) => {
            const allKeys = Object.keys(_schema.properties);
            onChange(
              Object.fromEntries(
                allKeys.map((opt) => [opt, values.includes(opt)]),
              ),
            );
          }}
        >
          <MultiSelectorTrigger className={inputClasses}>
            <MultiSelectorInput
              placeholder={schema.placeholder ?? `Select ${schema.title}...`}
            />
          </MultiSelectorTrigger>
          <MultiSelectorContent className="nowheel">
            <MultiSelectorList
              className={cn(inputClasses, "agpt-border-input bg-white")}
            >
              {Object.keys(_schema.properties)
                .map((key) => ({ ..._schema.properties[key], key }))
                .map(({ key, title, description }) => (
                  <MultiSelectorItem key={key} value={key} title={description}>
                    {title ?? key}
                  </MultiSelectorItem>
                ))}
            </MultiSelectorList>
          </MultiSelectorContent>
        </MultiSelector>
      );
      break;

    case DataType.SHORT_TEXT:
    default:
      innerInputElement = (
        <Input
          type="text"
          value={value ?? ""}
          onChange={(e) => onChange(e.target.value)}
          placeholder={placeholder || "Enter text"}
          {...props}
        />
      );
  }

  return <div className="no-drag relative flex">{innerInputElement}</div>;
};

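// Editor's usage sketch (illustrative; the field name is hypothetical):
//   <TypeBasedInput
//     schema={block.inputSchema.properties["topic"]}
//     value={inputValues["topic"]}
//     onChange={(v) => setInputValues((s) => ({ ...s, topic: v }))}
//   />
// determineDataType(schema) picks the widget; onChange passes the raw value
// back (not a DOM event).
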
interface DatePickerProps {
  value?: Date;
  placeholder?: string;
  onChange: (date: Date | undefined) => void;
  className?: string;
}

export function DatePicker({
  value,
  placeholder,
  onChange,
  className,
}: DatePickerProps) {
  return (
    <Popover>
      <PopoverTrigger asChild>
        <Button
          variant="outline"
          className={cn(
            "agpt-border-input w-full justify-start font-normal",
            !value && "text-muted-foreground",
            className,
          )}
        >
          <CalendarIcon className="mr-2 h-5 w-5" />
          {value ? (
            format(value, "PPP")
          ) : (
            <span>{placeholder || "Pick a date"}</span>
          )}
        </Button>
      </PopoverTrigger>

      <PopoverContent className="flex min-h-[340px] w-auto p-0">
        <Calendar
          mode="single"
          selected={value}
          onSelect={(selected) => onChange(selected)}
          autoFocus
        />
      </PopoverContent>
    </Popover>
  );
}

interface TimePickerProps {
  value?: string;
  onChange: (time: string) => void;
  className?: string;
}

export function TimePicker({ value, onChange }: TimePickerProps) {
  const pad = (n: number) => n.toString().padStart(2, "0");
  const [hourNum, minuteNum] = value ? value.split(":").map(Number) : [0, 0];

  const meridiem = hourNum >= 12 ? "PM" : "AM";
  const hour = pad(hourNum % 12 || 12);
  const minute = pad(minuteNum);

  const changeTime = (hour: string, minute: string, meridiem: string) => {
    const hour24 = (Number(hour) % 12) + (meridiem === "PM" ? 12 : 0);
    onChange(`${pad(hour24)}:${minute}`);
  };

  return (
    <div className="flex items-center space-x-3">
      <div className="flex flex-col items-center">
        <Select
          value={hour}
          onValueChange={(val: string) => changeTime(val, minute, meridiem)}
        >
          <SelectTrigger
            className={cn("agpt-border-input ml-1 text-center", inputClasses)}
          >
            <SelectValue />
          </SelectTrigger>
          <SelectContent>
            {Array.from({ length: 12 }, (_, i) => pad(i + 1)).map((h) => (
              <SelectItem key={h} value={h}>
                {h}
              </SelectItem>
            ))}
          </SelectContent>
        </Select>
      </div>

      <div className="flex flex-col items-center">
        <span className="m-auto text-xl font-bold">:</span>
      </div>

      <div className="flex flex-col items-center">
        <Select
          value={minute}
          onValueChange={(val: string) => changeTime(hour, val, meridiem)}
        >
          <SelectTrigger
            className={cn("agpt-border-input text-center", inputClasses)}
          >
            <SelectValue />
          </SelectTrigger>
          <SelectContent>
            {Array.from({ length: 60 }, (_, i) => pad(i)).map((m) => (
              <SelectItem key={m} value={m.toString()}>
                {m}
              </SelectItem>
            ))}
          </SelectContent>
        </Select>
      </div>

      <div className="flex flex-col items-center">
        <Select
          value={meridiem}
          onValueChange={(val: string) => changeTime(hour, minute, val)}
        >
          <SelectTrigger
            className={cn("agpt-border-input text-center", inputClasses)}
          >
            <SelectValue />
          </SelectTrigger>
          <SelectContent>
            <SelectItem value="AM">AM</SelectItem>
            <SelectItem value="PM">PM</SelectItem>
          </SelectContent>
        </Select>
      </div>
    </div>
  );
}

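// Editor's note (illustrative): changeTime round-trips the 24-hour "HH:mm"
// string value:
//   value "19:30"  -> hour "07", minute "30", meridiem "PM"
//   changeTime("07", "30", "PM") -> onChange("19:30")
//   changeTime("12", "05", "AM") -> onChange("00:05")  (12 AM maps to 00)
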
function getFileLabel(filename: string, contentType?: string) {
  if (contentType) {
    const mimeParts = contentType.split("/");
    if (mimeParts.length > 1) {
      return `${mimeParts[1].toUpperCase()} file`;
    }
    return `${contentType} file`;
  }

  const pathParts = filename.split(".");
  if (pathParts.length > 1) {
    const ext = pathParts.pop();
    if (ext) return `${ext.toUpperCase()} file`;
  }
  return "File";
}

function formatFileSize(bytes: number): string {
  if (bytes >= 1024 * 1024) {
    return `${(bytes / (1024 * 1024)).toFixed(2)} MB`;
  } else if (bytes >= 1024) {
    return `${(bytes / 1024).toFixed(2)} KB`;
  } else {
    return `${bytes} B`;
  }
}

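// Editor's examples (illustrative):
//   getFileLabel("report.pdf")         -> "PDF file"  (extension fallback)
//   getFileLabel("photo", "image/png") -> "PNG file"  (MIME type wins)
//   formatFileSize(512)     -> "512 B"
//   formatFileSize(2048)    -> "2.00 KB"
//   formatFileSize(5242880) -> "5.00 MB"
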
interface FileInputProps {
  value?: string; // file URI or empty
  placeholder?: string; // e.g. "Resume", "Document", etc.
  onChange: (value: string) => void;
  className?: string;
}

const FileInput: FC<FileInputProps> = ({ value, onChange, className }) => {
  const [isUploading, setIsUploading] = useState(false);
  const [uploadProgress, setUploadProgress] = useState(0);
  const [uploadError, setUploadError] = useState<string | null>(null);
  const [fileInfo, setFileInfo] = useState<{
    name: string;
    size: number;
    content_type: string;
  } | null>(null);

  const api = new BackendAPI();

  const uploadFile = async (file: File) => {
    setIsUploading(true);
    setUploadProgress(0);
    setUploadError(null);

    try {
      const result = await api.uploadFile(
        file,
        "gcs",
        24, // 24 hours expiration
        (progress) => setUploadProgress(progress),
      );

      setFileInfo({
        name: result.file_name,
        size: result.size,
        content_type: result.content_type,
      });

      // Set the file URI as the value
      onChange(result.file_uri);
    } catch (error) {
      console.error("Upload failed:", error);
      setUploadError(error instanceof Error ? error.message : "Upload failed");
    } finally {
      setIsUploading(false);
      setUploadProgress(0);
    }
  };

  const handleFileChange = (event: React.ChangeEvent<HTMLInputElement>) => {
    const file = event.target.files?.[0];
    if (file) uploadFile(file);
  };

  const handleFileDrop = (event: React.DragEvent<HTMLDivElement>) => {
    event.preventDefault();
    const file = event.dataTransfer.files[0];
    if (file) uploadFile(file);
  };

  const inputRef = React.useRef<HTMLInputElement>(null);

  const storageNote =
    "Files are stored securely and will be automatically deleted at most 24 hours after upload.";

  return (
    <div className={cn("w-full", className)}>
      {isUploading ? (
        <div className="space-y-2">
          <div className="flex min-h-14 items-center gap-4">
            <div className="agpt-border-input flex min-h-14 w-full flex-col justify-center rounded-xl bg-zinc-50 p-4 text-sm">
              <div className="mb-2 flex items-center gap-2">
                <UploadIcon className="h-5 w-5 text-blue-600" />
                <span className="text-gray-700">Uploading...</span>
                <span className="text-gray-500">
                  {Math.round(uploadProgress)}%
                </span>
              </div>
              <Progress value={uploadProgress} className="w-full" />
            </div>
          </div>
          <p className="text-xs text-gray-500">{storageNote}</p>
        </div>
      ) : value ? (
        <div className="space-y-2">
          <div className="flex min-h-14 items-center gap-4">
            <div className="agpt-border-input flex min-h-14 w-full items-center justify-between rounded-xl bg-zinc-50 p-4 text-sm text-gray-500">
              <div className="flex items-center gap-2">
                <FileTextIcon className="h-7 w-7 text-black" />
                <div className="flex flex-col gap-0.5">
                  <span className="font-normal text-black">
                    {fileInfo
                      ? getFileLabel(fileInfo.name, fileInfo.content_type)
                      : "File"}
                  </span>
                  <span>{fileInfo ? formatFileSize(fileInfo.size) : ""}</span>
                </div>
              </div>
              <Cross2Icon
                className="h-5 w-5 cursor-pointer text-black"
                onClick={() => {
                  if (inputRef.current) {
                    inputRef.current.value = "";
                  }
                  onChange("");
                  setFileInfo(null);
                }}
              />
            </div>
          </div>
          <p className="text-xs text-gray-500">{storageNote}</p>
        </div>
      ) : (
        <div className="space-y-2">
          <div className="flex min-h-14 items-center gap-4">
            <div
              onDrop={handleFileDrop}
              onDragOver={(e) => e.preventDefault()}
              className="agpt-border-input flex min-h-14 w-full items-center justify-center rounded-xl border-dashed bg-zinc-50 text-sm text-gray-500"
            >
              Choose a file or drag and drop it here
            </div>

            <Button variant="default" onClick={() => inputRef.current?.click()}>
              Browse File
            </Button>
          </div>

          {uploadError && (
            <div className="text-sm text-red-600">Error: {uploadError}</div>
          )}

          <p className="text-xs text-gray-500">{storageNote}</p>
        </div>
      )}

      <input
        ref={inputRef}
        type="file"
        accept="*/*"
        className="hidden"
        onChange={handleFileChange}
        disabled={isUploading}
      />
    </div>
  );
};

@@ -20,7 +20,7 @@ const Progress = React.forwardRef<HTMLDivElement, ProgressProps>(
  {...props}
>
  <div
-   className="h-full bg-zinc-800 transition-all duration-300 ease-in-out"
+   className="h-full bg-blue-600 transition-all duration-300 ease-in-out"
    style={{ width: `${percentage}%` }}
  />
</div>
@@ -196,7 +196,6 @@ export default function useAgentGraph(
  hardcodedValues: node.input_default,
  webhook: node.webhook,
  uiType: block.uiType,
- metadata: node.metadata,
  connections: graph.links
    .filter((l) => [l.source_id, l.sink_id].includes(node.id))
    .map((link) => ({
@@ -602,10 +601,7 @@ export default function useAgentGraph(
  id: node.id,
  block_id: node.data.block_id,
  input_default: prepareNodeInputData(node),
- metadata: {
-   position: node.position,
-   ...(node.data.metadata || {}),
- },
+ metadata: { position: node.position },
}),
),
links: links,
@@ -684,7 +680,6 @@ export default function useAgentGraph(
  backend_id: backendNode.id,
  webhook: backendNode.webhook,
  executionResults: [],
- metadata: backendNode.metadata,
},
}
: null;

@@ -395,9 +395,7 @@ export function isEmpty(value: any): boolean {
return (
  value === undefined ||
  value === "" ||
- (typeof value === "object" &&
-   (value instanceof Date ? isNaN(value.getTime()) : _isEmpty(value))) ||
- (typeof value === "number" && isNaN(value))
+ (typeof value === "object" && _isEmpty(value))
);
}

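// Editor's note (illustrative, assuming _isEmpty is lodash's isEmpty): the
// removed clauses special-cased Dates and NaN, so with this change:
//   isEmpty(NaN)              // was true, now false (typeof NaN === "number")
//   isEmpty(new Date("nope")) // was caught by isNaN(value.getTime()); now
//                             // handled by the generic object check instead
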
@@ -11,16 +11,9 @@ const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";

export function LaunchDarklyProvider({ children }: { children: ReactNode }) {
  const { user, isUserLoading } = useSupabase();
- const isCloud = true;
+ const isCloud = getBehaveAs() === BehaveAs.CLOUD;
  const isLaunchDarklyConfigured = isCloud && envEnabled && clientId;

- console.log({
-   clientId,
-   envEnabled,
-   isCloud,
-   isLaunchDarklyConfigured,
- });
-
  const context = useMemo(() => {
    if (isUserLoading || !user) {
      return {

@@ -206,7 +206,6 @@ export class LibraryPage extends BasePage {
  async clickAgent(agent: Agent): Promise<void> {
    await this.page
      .getByRole("heading", { name: agent.name, level: 3 })
-     .first()
      .click();
  }

@@ -3,7 +3,7 @@
This guide will walk you through the process of creating and testing a new block for the AutoGPT Agent Server, using the WikipediaSummaryBlock as an example.

!!! tip "New SDK-Based Approach"
    For a more comprehensive guide using the new SDK pattern with ProviderBuilder and advanced features like OAuth and webhooks, see the [Block SDK Guide](block-sdk-guide.md).

## Understanding Blocks and Testing

@@ -17,74 +17,74 @@ Follow these steps to create and test a new block:

2. **Import necessary modules and create a class that inherits from `Block`**. Make sure to include all necessary imports for your block.

    Every block should contain the following:

    ```python
    from backend.data.block import Block, BlockSchema, BlockOutput
    ```

    Example for the Wikipedia summary block:

    ```python
    from backend.data.block import Block, BlockSchema, BlockOutput
    from backend.utils.get_request import GetRequest
    import requests

    class WikipediaSummaryBlock(Block, GetRequest):
        # Block implementation will go here
    ```

3. **Define the input and output schemas** using `BlockSchema`. These schemas specify the data structure that the block expects to receive (input) and produce (output).

    - The input schema defines the structure of the data the block will process. Each field in the schema represents a required piece of input data.
    - The output schema defines the structure of the data the block will return after processing. Each field in the schema represents a piece of output data.

    Example:

    ```python
    class Input(BlockSchema):
        topic: str  # The topic to get the Wikipedia summary for

    class Output(BlockSchema):
        summary: str  # The summary of the topic from Wikipedia
        error: str  # Any error message if the request fails; this field must be named `error`
    ```

4. **Implement the `__init__` method, including test data and mocks:**

    !!! important
        Use a UUID generator (e.g. https://www.uuidgenerator.net/) for every new block `id` and *do not* make one up yourself. Alternatively, you can run this Python code to generate a UUID: `print(__import__('uuid').uuid4())`

    ```python
    def __init__(self):
        super().__init__(
            # Unique ID for the block, used across users for templates
            # If you are an AI, leave it as is or change it to "generate-proper-uuid"
            id="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
            input_schema=WikipediaSummaryBlock.Input,  # Assign input schema
            output_schema=WikipediaSummaryBlock.Output,  # Assign output schema

            # Provide sample input, output and a test mock for testing the block
            test_input={"topic": "Artificial Intelligence"},
            test_output=("summary", "summary content"),
            test_mock={"get_request": lambda url, json: {"extract": "summary content"}},
        )
    ```

    - `id`: A unique identifier for the block.
    - `input_schema` and `output_schema`: Define the structure of the input and output data.

    Let's break down the testing components:

    - `test_input`: A sample input that will be used to test the block. It should be a valid input according to your Input schema.
    - `test_output`: The expected output when running the block with the `test_input`. It should match your Output schema. For non-deterministic outputs, or when you only want to assert the type, you can use Python types instead of specific values. In this example, `("summary", str)` asserts that the output key is "summary" and its value is a string.
    - `test_mock`: Crucial for blocks that make network calls. It provides a mock function that replaces the actual network call during testing.

    In this case, we're mocking the `get_request` method to always return a dictionary with an 'extract' key, simulating a successful API response. This allows us to test the block's logic without making actual network requests, which could be slow, unreliable, or rate-limited.

5. **Implement the `run` method with error handling.** This should contain the main logic of the block:

@@ -106,21 +106,19 @@ Follow these steps to create and test a new block:

- **Error handling**: Handle the exceptions that might occur during the API request and data processing. You don't need to catch every exception, only the ones you expect and can handle; uncaught exceptions are automatically yielded as `error` in the output. Any block that raises an exception (or yields an `error` output) will be marked as failed. Prefer raising exceptions over yielding `error`, as it stops the execution immediately.
- **Yield**: Use `yield` to output results. Prefer to output one result object at a time. If you are calling a function that returns a list, you can yield each item separately; you may also yield the whole list as an additional single result object, but do both rather than yielding only the list. For example, if you were writing a block that outputs emails, you'd yield each email as a separate result object, and you could also yield the whole list as one extra result object. Yielding an output named `error` will break the execution right away and mark the block execution as failed.
- **kwargs**: The `kwargs` parameter is used to pass additional arguments to the block. It is not used in the example above, but it is available. You can also declare individual kwargs as keyword-only parameters in the run method's signature, e.g. `def run(self, input_data: Input, *, user_id: str, **kwargs) -> BlockOutput:`.

    Available kwargs are:

    - `user_id`: The ID of the user running the block.
    - `graph_id`: The ID of the agent that is executing the block. This is the same for every version of the agent.
    - `graph_exec_id`: The ID of the execution of the agent. This changes every time the agent has a new "run".
    - `node_exec_id`: The ID of the execution of the node. This changes every time the node is executed.
    - `node_id`: The ID of the node that is being executed. It changes with every version of the graph, but not with every execution of the node.

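To tie these pieces together, here is a minimal sketch of what the `run` method for the Wikipedia summary block could look like. It is consistent with the `test_mock` above, but the endpoint URL and the exact signature of the `get_request` helper are assumptions for illustration:

```python
def run(self, input_data: Input, **kwargs) -> BlockOutput:
    try:
        # Illustrative endpoint; adapt to the API your block calls
        url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{input_data.topic}"
        response = self.get_request(url, json=True)
        yield "summary", response["extract"]
    except requests.exceptions.HTTPError as e:
        # Raising stops execution immediately and marks the block as failed
        raise RuntimeError(f"HTTP error occurred: {e}") from e
    except KeyError as e:
        raise RuntimeError("Unexpected response format: missing 'extract' field") from e
```
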
### Field Types

#### oneOf fields

`oneOf` allows you to specify that a field must be exactly one of several possible options. This is useful when you want your block to accept different types of inputs that are mutually exclusive.

Example:

```python
attachment: Union[Media, DeepLink, Poll, Place, Quote] = SchemaField(
    discriminator='discriminator',
@@ -131,7 +129,6 @@ attachment: Union[Media, DeepLink, Poll, Place, Quote] = SchemaField(

The `discriminator` parameter tells AutoGPT which field to look at in the input to determine which type it is.

In each model, you need to define the discriminator value:

```python
class Media(BaseModel):
    discriminator: Literal['media']
@@ -143,11 +140,9 @@ class DeepLink(BaseModel):
```

#### OptionalOneOf fields

`OptionalOneOf` is similar to `oneOf` but allows the field to be optional (i.e. `None`). This means the field can be either one of the specified types or None.

Example:

```python
attachment: Union[Media, DeepLink, Poll, Place, Quote] | None = SchemaField(
    discriminator='discriminator',

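The diff elides the rest of these snippets, so here is a self-contained sketch of the whole discriminated-union pattern. The payload fields and description are illustrative, not the actual models:

```python
from typing import Literal, Union

from pydantic import BaseModel

from backend.data.block import BlockSchema
from backend.data.model import SchemaField

class Media(BaseModel):
    discriminator: Literal['media']
    media_ids: list[str]  # illustrative payload field

class DeepLink(BaseModel):
    discriminator: Literal['deep_link']
    url: str  # illustrative payload field

class Input(BlockSchema):
    # AutoGPT inspects the `discriminator` field of the incoming value
    # to decide whether to parse it as Media or DeepLink.
    attachment: Union[Media, DeepLink] = SchemaField(
        discriminator='discriminator',
        description="The attachment to include",  # illustrative
    )
```
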
@@ -284,20 +279,16 @@ response = requests.post(

The `ProviderName` enum is the single source of truth for which providers exist in our system.
Naturally, to add an authenticated block for a new provider, you'll have to add it here too.

<details>
<summary><code>ProviderName</code> definition</summary>

```python title="backend/integrations/providers.py"
--8<-- "autogpt_platform/backend/backend/integrations/providers.py:ProviderName"
```

</details>

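Adding a provider is then typically a one-line change to that enum. As a sketch (the member shown is hypothetical, and the enum's exact base classes may differ):

```python
from enum import Enum

class ProviderName(str, Enum):
    # ...existing providers...
    MY_PROVIDER = "my_provider"  # hypothetical new provider
```
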
#### Multiple credentials inputs

Multiple credentials inputs are supported, under the following conditions:

- The name of each of the credentials input fields must end with `_credentials`.
- The names of the credentials input fields must match the names of the corresponding
  parameters on the `run(..)` method of the block (see the sketch below).
@@ -305,6 +296,7 @@ Multiple credentials inputs are supported, under the following conditions:
  is a `dict[str, Credentials]`, with the parameter name as the key and suitable
  test credentials as the value for each required credentials input.

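A sketch of these rules, assuming two hypothetical providers; the import paths and the exact `CredentialsMetaInput`/`CredentialsField` usage are assumptions and are more specific in the actual codebase:

```python
from backend.data.block import Block, BlockOutput, BlockSchema
from backend.data.model import Credentials, CredentialsField, CredentialsMetaInput

class CrossPostBlock(Block):  # hypothetical block
    class Input(BlockSchema):
        # Both field names end with `_credentials`...
        github_credentials: CredentialsMetaInput = CredentialsField(
            description="GitHub credentials",  # illustrative
        )
        slack_credentials: CredentialsMetaInput = CredentialsField(
            description="Slack credentials",  # illustrative
        )

    def run(
        self,
        input_data: Input,
        *,
        # ...and match the run(..) parameter names exactly
        github_credentials: Credentials,
        slack_credentials: Credentials,
        **kwargs,
    ) -> BlockOutput:
        yield "status", "ok"  # illustrative output
```
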
#### Adding an OAuth2 service integration

To add support for a new OAuth2-authenticated service, you'll need to add an `OAuthHandler`.
@@ -342,25 +334,22 @@ Aside from implementing the `OAuthHandler` itself, adding a handler into the sys

#### Adding to the frontend

-You will need to add the provider (api or oauth) to the `CredentialsInput` component in [`/frontend/src/app/(platform)/library/agents/[id]/components/AgentRunsView/components/CredentialsInputs/CredentialsInputs.tsx`](<https://github.com/Significant-Gravitas/AutoGPT/blob/dev/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/AgentRunsView/components/CredentialsInputs/CredentialsInputs.tsx>).
+You will need to add the provider (api or oauth) to the `CredentialsInput` component in [`frontend/src/components/integrations/credentials-input.tsx`](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform/frontend/src/components/integrations/credentials-input.tsx).

```ts title="frontend/src/components/integrations/credentials-input.tsx"
---8 <
---"autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/AgentRunsView/components/CredentialsInputs/CredentialsInputs.tsx:ProviderIconsEmbed";
+--8<-- "autogpt_platform/frontend/src/components/integrations/credentials-input.tsx:ProviderIconsEmbed"
```

You will also need to add the provider to the credentials provider list in [`frontend/src/components/integrations/helper.ts`](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform/frontend/src/components/integrations/helper.ts).

```ts title="frontend/src/components/integrations/helper.ts"
---8 <
---"autogpt_platform/frontend/src/components/integrations/helper.ts:CredentialsProviderNames";
+--8<-- "autogpt_platform/frontend/src/components/integrations/helper.ts:CredentialsProviderNames"
```

Finally you will need to add the provider to the `CredentialsType` enum in [`frontend/src/lib/autogpt-server-api/types.ts`](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts).

```ts title="frontend/src/lib/autogpt-server-api/types.ts"
---8 <
---"autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts:BlockIOCredentialsSubSchema";
+--8<-- "autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts:BlockIOCredentialsSubSchema"
```

#### Example: GitHub integration

@@ -402,12 +391,12 @@ rather than being executed manually.

Creating and running a webhook-triggered block involves three main components:

- The block itself, which specifies:
    - Inputs for the user to select a resource and events to subscribe to
    - A `credentials` input with the scopes needed to manage webhooks
    - Logic to turn the webhook payload into outputs for the webhook block
- The `WebhooksManager` for the corresponding webhook service provider, which handles:
    - (De)registering webhooks with the provider
    - Parsing and validating incoming webhook payloads
- The credentials system for the corresponding service provider, which may include an `OAuthHandler`

There is more going on under the hood, e.g. to store and retrieve webhooks and their

@@ -420,72 +409,67 @@ To create a webhook-triggered block, follow these additional steps on top of the

1. **Define `webhook_config`** in your block's `__init__` method.

    <details>
    <summary>Example: <code>GitHubPullRequestTriggerBlock</code></summary>

    ```python title="backend/blocks/github/triggers.py"
    --8<-- "autogpt_platform/backend/backend/blocks/github/triggers.py:example-webhook_config"
    ```
    </details>

    <details>
    <summary><code>BlockWebhookConfig</code> definition</summary>

    ```python title="backend/data/block.py"
    --8<-- "autogpt_platform/backend/backend/data/block.py:BlockWebhookConfig"
    ```
    </details>

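    The snippet above is an include directive, so as orientation, here is a rough sketch of what a `webhook_config` can look like. The field names and values are assumptions modeled on the GitHub trigger blocks; verify them against the `BlockWebhookConfig` definition:

    ```python
    from backend.data.block import BlockWebhookConfig

    webhook_config = BlockWebhookConfig(
        provider="github",                    # the webhook service provider
        webhook_type="repo",                  # provider-specific webhook type (assumed value)
        resource_format="{repo}",             # input field(s) identifying the watched resource
        event_filter_input="events",          # name of the event filter input field (step 2)
        event_format="pull_request.{event}",  # how incoming events are named (assumed format)
    )
    ```
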
2. **Define event filter input** in your block's Input schema.
    This allows the user to select which specific types of events will trigger the block in their agent.

    <details>
    <summary>Example: <code>GitHubPullRequestTriggerBlock</code></summary>

    ```python title="backend/blocks/github/triggers.py"
    --8<-- "autogpt_platform/backend/backend/blocks/github/triggers.py:example-event-filter"
    ```
    </details>

    - The name of the input field (`events` in this case) must match `webhook_config.event_filter_input`.
    - The event filter itself must be a Pydantic model with only boolean fields, as sketched below.

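    A minimal sketch of such an event filter model (the event names are illustrative, not the actual GitHub filter):

    ```python
    from pydantic import BaseModel

    class PullRequestEventsFilter(BaseModel):
        # One boolean per subscribable event type; illustrative names
        opened: bool = False
        closed: bool = False
        reopened: bool = False
    ```
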
-4. **Include payload field** in your block's Input schema.
+3. **Include payload field** in your block's Input schema.

    <details>
    <summary>Example: <code>GitHubTriggerBase</code></summary>

    ```python title="backend/blocks/github/triggers.py"
    --8<-- "autogpt_platform/backend/backend/blocks/github/triggers.py:example-payload-field"
    ```
    </details>

-5. **Define `credentials` input** in your block's Input schema.
+4. **Define `credentials` input** in your block's Input schema.

    - Its scopes must be sufficient to manage a user's webhooks through the provider's API
    - See [Blocks with authentication](#blocks-with-authentication) for further details

-6. **Process webhook payload** and output relevant parts of it in your block's `run` method.
+5. **Process webhook payload** and output relevant parts of it in your block's `run` method.

    <details>
    <summary>Example: <code>GitHubPullRequestTriggerBlock</code></summary>

    ```python
    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "payload", input_data.payload
        yield "sender", input_data.payload["sender"]
        yield "event", input_data.payload["action"]
        yield "number", input_data.payload["number"]
        yield "pull_request", input_data.payload["pull_request"]
    ```

    Note that the `credentials` parameter can be omitted if the credentials
    aren't used at block runtime, like in the example.
    </details>

#### Adding a Webhooks Manager

@@ -516,7 +500,6 @@ GitHub Webhook triggers: <a href="https://github.com/Significant-Gravitas/AutoGP

```python title="backend/blocks/github/triggers.py"
--8<-- "autogpt_platform/backend/backend/blocks/github/triggers.py:GithubTriggerExample"
```

</details>

<details>
@@ -527,7 +510,6 @@ GitHub Webhooks Manager: <a href="https://github.com/Significant-Gravitas/AutoGP

```python title="backend/integrations/webhooks/github.py"
--8<-- "autogpt_platform/backend/backend/integrations/webhooks/github.py:GithubWebhooksManager"
```

</details>

## Key Points to Remember

@@ -581,24 +563,22 @@ class MyNetworkBlock(Block):

The `Requests` wrapper provides these security features:

1. **URL Validation**:

    - Blocks requests to private IP ranges (RFC 1918)
    - Validates URL format and protocol
    - Resolves DNS and checks IP addresses
    - Supports whitelisting trusted origins

2. **Secure Defaults**:

    - Disables redirects by default
    - Raises exceptions for non-200 status codes
    - Supports custom headers and validators

3. **Protected IP Ranges**:
    The wrapper denies requests to these networks:

    ```python title="backend/util/request.py"
    --8<-- "autogpt_platform/backend/backend/util/request.py:BLOCKED_IP_NETWORKS"
    ```

### Custom Request Configuration

@@ -621,9 +601,9 @@ custom_requests = Requests(

2. **Define appropriate test_output**:

    - For deterministic outputs, use specific expected values.
    - For non-deterministic outputs or when only the type matters, use Python types (e.g., `str`, `int`, `dict`).
    - You can mix specific values and types, e.g., `("key1", str), ("key2", 42)`; see the sketch after this list.

3. **Use test_mock for network calls**: This prevents tests from failing due to network issues or API changes.

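For instance, a block's `__init__` could declare a mixed expectation like this (a sketch; the output names and values are illustrative):

```python
super().__init__(
    # ...
    test_input={"topic": "Artificial Intelligence"},
    test_output=[
        ("summary", str),    # assert only the type
        ("word_count", 42),  # assert the exact value
    ],
)
```
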