Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-07 22:33:57 -05:00)

Merge branch 'dev' into copilot/fix-d17b24f0-80fd-4065-b701-fc1cb15e1958
@@ -10,7 +10,7 @@ from .jwt_utils import get_jwt_payload, verify_user
 from .models import User


-def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
+async def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
     """
     FastAPI dependency that requires a valid authenticated user.

@@ -20,7 +20,9 @@ def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User
     return verify_user(jwt_payload, admin_only=False)


-def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
+async def requires_admin_user(
+    jwt_payload: dict = fastapi.Security(get_jwt_payload),
+) -> User:
     """
     FastAPI dependency that requires a valid admin user.

@@ -30,7 +32,7 @@ def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -
     return verify_user(jwt_payload, admin_only=True)


-def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
+async def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
     """
     FastAPI dependency that returns the ID of the authenticated user.
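FastAPI resolves both sync and async dependencies transparently, so route code does not change; only direct callers (such as the unit tests below) must now await these functions. A minimal sketch of consuming the async dependencies in a route (the app and endpoints are illustrative, not part of this diff):

    import fastapi
    from autogpt_libs.auth.dependencies import get_user_id, requires_admin_user

    app = fastapi.FastAPI()

    @app.get("/me")
    async def me(user_id: str = fastapi.Security(get_user_id)):
        # FastAPI awaits the async dependency before calling the endpoint
        return {"user_id": user_id}

    @app.get("/admin/stats")
    async def admin_stats(admin=fastapi.Security(requires_admin_user)):
        return {"role": admin.role}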
@@ -45,7 +45,7 @@ class TestAuthDependencies:
         """Create a test client."""
         return TestClient(app)

-    def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
+    async def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
         """Test requires_user with valid JWT payload."""
         jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}

@@ -53,12 +53,12 @@ class TestAuthDependencies:
         mocker.patch(
             "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
         )
-        user = requires_user(jwt_payload)
+        user = await requires_user(jwt_payload)
         assert isinstance(user, User)
         assert user.user_id == "user-123"
         assert user.role == "user"

-    def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
+    async def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
         """Test requires_user accepts admin users."""
         jwt_payload = {
             "sub": "admin-456",

@@ -69,28 +69,28 @@ class TestAuthDependencies:
         mocker.patch(
             "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
         )
-        user = requires_user(jwt_payload)
+        user = await requires_user(jwt_payload)
         assert user.user_id == "admin-456"
         assert user.role == "admin"

-    def test_requires_user_missing_sub(self):
+    async def test_requires_user_missing_sub(self):
         """Test requires_user with missing user ID."""
         jwt_payload = {"role": "user", "email": "user@example.com"}

         with pytest.raises(HTTPException) as exc_info:
-            requires_user(jwt_payload)
+            await requires_user(jwt_payload)
         assert exc_info.value.status_code == 401
         assert "User ID not found" in exc_info.value.detail

-    def test_requires_user_empty_sub(self):
+    async def test_requires_user_empty_sub(self):
         """Test requires_user with empty user ID."""
         jwt_payload = {"sub": "", "role": "user"}

         with pytest.raises(HTTPException) as exc_info:
-            requires_user(jwt_payload)
+            await requires_user(jwt_payload)
         assert exc_info.value.status_code == 401

-    def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
+    async def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
         """Test requires_admin_user with admin role."""
         jwt_payload = {
             "sub": "admin-789",

@@ -101,51 +101,51 @@ class TestAuthDependencies:
         mocker.patch(
             "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
         )
-        user = requires_admin_user(jwt_payload)
+        user = await requires_admin_user(jwt_payload)
         assert user.user_id == "admin-789"
         assert user.role == "admin"

-    def test_requires_admin_user_with_regular_user(self):
+    async def test_requires_admin_user_with_regular_user(self):
         """Test requires_admin_user rejects regular users."""
         jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}

         with pytest.raises(HTTPException) as exc_info:
-            requires_admin_user(jwt_payload)
+            await requires_admin_user(jwt_payload)
         assert exc_info.value.status_code == 403
         assert "Admin access required" in exc_info.value.detail

-    def test_requires_admin_user_missing_role(self):
+    async def test_requires_admin_user_missing_role(self):
         """Test requires_admin_user with missing role."""
         jwt_payload = {"sub": "user-123", "email": "user@example.com"}

         with pytest.raises(KeyError):
-            requires_admin_user(jwt_payload)
+            await requires_admin_user(jwt_payload)

-    def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
+    async def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
         """Test get_user_id extracts user ID correctly."""
         jwt_payload = {"sub": "user-id-xyz", "role": "user"}

         mocker.patch(
             "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
         )
-        user_id = get_user_id(jwt_payload)
+        user_id = await get_user_id(jwt_payload)
         assert user_id == "user-id-xyz"

-    def test_get_user_id_missing_sub(self):
+    async def test_get_user_id_missing_sub(self):
         """Test get_user_id with missing user ID."""
         jwt_payload = {"role": "user"}

         with pytest.raises(HTTPException) as exc_info:
-            get_user_id(jwt_payload)
+            await get_user_id(jwt_payload)
         assert exc_info.value.status_code == 401
         assert "User ID not found" in exc_info.value.detail

-    def test_get_user_id_none_sub(self):
+    async def test_get_user_id_none_sub(self):
         """Test get_user_id with None user ID."""
         jwt_payload = {"sub": None, "role": "user"}

         with pytest.raises(HTTPException) as exc_info:
-            get_user_id(jwt_payload)
+            await get_user_id(jwt_payload)
         assert exc_info.value.status_code == 401


@@ -170,7 +170,7 @@ class TestAuthDependenciesIntegration:

         return _create_token

-    def test_endpoint_auth_enabled_no_token(self):
+    async def test_endpoint_auth_enabled_no_token(self):
         """Test endpoints require token when auth is enabled."""
         app = FastAPI()

@@ -184,7 +184,7 @@ class TestAuthDependenciesIntegration:
         response = client.get("/test")
         assert response.status_code == 401

-    def test_endpoint_with_valid_token(self, create_token):
+    async def test_endpoint_with_valid_token(self, create_token):
         """Test endpoint with valid JWT token."""
         app = FastAPI()

@@ -203,7 +203,7 @@ class TestAuthDependenciesIntegration:
         assert response.status_code == 200
         assert response.json()["user_id"] == "test-user"

-    def test_admin_endpoint_requires_admin_role(self, create_token):
+    async def test_admin_endpoint_requires_admin_role(self, create_token):
         """Test admin endpoint rejects non-admin users."""
         app = FastAPI()

@@ -240,7 +240,7 @@ class TestAuthDependenciesIntegration:
 class TestAuthDependenciesEdgeCases:
     """Edge case tests for authentication dependencies."""

-    def test_dependency_with_complex_payload(self):
+    async def test_dependency_with_complex_payload(self):
         """Test dependencies handle complex JWT payloads."""
         complex_payload = {
             "sub": "user-123",

@@ -256,14 +256,14 @@ class TestAuthDependenciesEdgeCases:
             "exp": 9999999999,
         }

-        user = requires_user(complex_payload)
+        user = await requires_user(complex_payload)
         assert user.user_id == "user-123"
         assert user.email == "test@example.com"

-        admin = requires_admin_user(complex_payload)
+        admin = await requires_admin_user(complex_payload)
         assert admin.role == "admin"

-    def test_dependency_with_unicode_in_payload(self):
+    async def test_dependency_with_unicode_in_payload(self):
         """Test dependencies handle unicode in JWT payloads."""
         unicode_payload = {
             "sub": "user-😀-123",

@@ -272,11 +272,11 @@ class TestAuthDependenciesEdgeCases:
             "name": "日本語",
         }

-        user = requires_user(unicode_payload)
+        user = await requires_user(unicode_payload)
         assert "😀" in user.user_id
         assert user.email == "测试@example.com"

-    def test_dependency_with_null_values(self):
+    async def test_dependency_with_null_values(self):
         """Test dependencies handle null values in payload."""
         null_payload = {
             "sub": "user-123",

@@ -286,18 +286,18 @@ class TestAuthDependenciesEdgeCases:
             "metadata": None,
         }

-        user = requires_user(null_payload)
+        user = await requires_user(null_payload)
         assert user.user_id == "user-123"
         assert user.email is None

-    def test_concurrent_requests_isolation(self):
+    async def test_concurrent_requests_isolation(self):
         """Test that concurrent requests don't interfere with each other."""
         payload1 = {"sub": "user-1", "role": "user"}
         payload2 = {"sub": "user-2", "role": "admin"}

         # Simulate concurrent processing
-        user1 = requires_user(payload1)
-        user2 = requires_admin_user(payload2)
+        user1 = await requires_user(payload1)
+        user2 = await requires_admin_user(payload2)

         assert user1.user_id == "user-1"
         assert user2.user_id == "user-2"

@@ -314,7 +314,7 @@ class TestAuthDependenciesEdgeCases:
             ({"sub": "user", "role": "user"}, "Admin access required", True),
         ],
     )
-    def test_dependency_error_cases(
+    async def test_dependency_error_cases(
         self, payload, expected_error: str, admin_only: bool
     ):
         """Test that errors propagate correctly through dependencies."""

@@ -325,7 +325,7 @@ class TestAuthDependenciesEdgeCases:
             verify_user(payload, admin_only=admin_only)
         assert expected_error in exc_info.value.detail

-    def test_dependency_valid_user(self):
+    async def test_dependency_valid_user(self):
         """Test valid user case for dependency."""
         # Import verify_user to test it directly since dependencies use FastAPI Security
         from autogpt_libs.auth.jwt_utils import verify_user
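Because each of these tests is now a coroutine, the suite needs an async-aware test runner. A minimal sketch assuming pytest-asyncio (whether the project enables it per-marker or globally via asyncio_mode is not shown in this diff):

    # pyproject.toml (assumption):
    # [tool.pytest.ini_options]
    # asyncio_mode = "auto"   # or mark each test explicitly, as below

    import pytest

    @pytest.mark.asyncio
    async def test_example_async():
        user = await requires_user({"sub": "u1", "role": "user"})
        assert user.user_id == "u1"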
@@ -16,7 +16,7 @@ bearer_jwt_auth = HTTPBearer(
 )


-def get_jwt_payload(
+async def get_jwt_payload(
     credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
 ) -> dict[str, Any]:
     """

@@ -116,32 +116,32 @@ def test_parse_jwt_token_missing_audience():
     assert "Invalid token" in str(exc_info.value)


-def test_get_jwt_payload_with_valid_token():
+async def test_get_jwt_payload_with_valid_token():
     """Test extracting JWT payload with valid bearer token."""
     token = create_token(TEST_USER_PAYLOAD)
     credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)

-    result = jwt_utils.get_jwt_payload(credentials)
+    result = await jwt_utils.get_jwt_payload(credentials)
     assert result["sub"] == "test-user-id"
     assert result["role"] == "user"


-def test_get_jwt_payload_no_credentials():
+async def test_get_jwt_payload_no_credentials():
     """Test JWT payload when no credentials provided."""
     with pytest.raises(HTTPException) as exc_info:
-        jwt_utils.get_jwt_payload(None)
+        await jwt_utils.get_jwt_payload(None)
     assert exc_info.value.status_code == 401
     assert "Authorization header is missing" in exc_info.value.detail


-def test_get_jwt_payload_invalid_token():
+async def test_get_jwt_payload_invalid_token():
     """Test JWT payload extraction with invalid token."""
     credentials = HTTPAuthorizationCredentials(
         scheme="Bearer", credentials="invalid.token.here"
     )

     with pytest.raises(HTTPException) as exc_info:
-        jwt_utils.get_jwt_payload(credentials)
+        await jwt_utils.get_jwt_payload(credentials)
     assert exc_info.value.status_code == 401
     assert "Invalid token" in exc_info.value.detail
@@ -115,12 +115,23 @@ def cached(
     """

     def decorator(target_func):
-        # Cache storage and locks
+        # Cache storage and per-event-loop locks
         cache_storage = {}
+        _event_loop_locks = {}  # Maps event loop to its asyncio.Lock

         if inspect.iscoroutinefunction(target_func):
-            # Async function with asyncio.Lock
-            cache_lock = asyncio.Lock()
+            def _get_cache_lock():
+                """Get or create an asyncio.Lock for the current event loop."""
+                try:
+                    loop = asyncio.get_running_loop()
+                except RuntimeError:
+                    # No event loop, use None as default key
+                    loop = None
+
+                if loop not in _event_loop_locks:
+                    return _event_loop_locks.setdefault(loop, asyncio.Lock())
+                return _event_loop_locks[loop]

             @wraps(target_func)
             async def async_wrapper(*args: P.args, **kwargs: P.kwargs):

@@ -141,7 +152,7 @@ def cached(
                     return result

                 # Slow path: acquire lock for cache miss/expiry
-                async with cache_lock:
+                async with _get_cache_lock():
                     # Double-check: another coroutine might have populated cache
                     if key in cache_storage:
                         if ttl_seconds is None:
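The per-event-loop map exists because an asyncio.Lock binds to the event loop it is first used on; a single decorator-level lock breaks when the cached coroutine runs under several loops (for example, separate asyncio.run calls across tests). A minimal sketch of the failure mode the change avoids (exact behavior varies slightly by Python version):

    import asyncio

    lock = asyncio.Lock()

    async def critical():
        async with lock:  # the lock binds to the running loop on first use
            return 42

    asyncio.run(critical())  # OK
    asyncio.run(critical())  # can raise RuntimeError: ... bound to a different event loop

Keyed by get_running_loop(), each loop gets its own Lock, so repeated runs succeed.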
@@ -90,6 +90,7 @@ class DataForSeoKeywordSuggestionsBlock(Block):
         seed_keyword: str = SchemaField(
             description="The seed keyword used for the query"
         )
+        error: str = SchemaField(description="Error message if the API call failed")

     def __init__(self):
         super().__init__(

@@ -161,43 +162,52 @@ class DataForSeoKeywordSuggestionsBlock(Block):
         **kwargs,
     ) -> BlockOutput:
         """Execute the keyword suggestions query."""
-        client = DataForSeoClient(credentials)
+        try:
+            client = DataForSeoClient(credentials)

-        results = await self._fetch_keyword_suggestions(client, input_data)
+            results = await self._fetch_keyword_suggestions(client, input_data)

-        # Process and format the results
-        suggestions = []
-        if results and len(results) > 0:
-            # results is a list, get the first element
-            first_result = results[0] if isinstance(results, list) else results
-            items = (
-                first_result.get("items", []) if isinstance(first_result, dict) else []
-            )
-            for item in items:
-                # Create the KeywordSuggestion object
-                suggestion = KeywordSuggestion(
-                    keyword=item.get("keyword", ""),
-                    search_volume=item.get("keyword_info", {}).get("search_volume"),
-                    competition=item.get("keyword_info", {}).get("competition"),
-                    cpc=item.get("keyword_info", {}).get("cpc"),
-                    keyword_difficulty=item.get("keyword_properties", {}).get(
-                        "keyword_difficulty"
-                    ),
-                    serp_info=(
-                        item.get("serp_info") if input_data.include_serp_info else None
-                    ),
-                    clickstream_data=(
-                        item.get("clickstream_keyword_info")
-                        if input_data.include_clickstream_data
-                        else None
-                    ),
-                )
+            # Process and format the results
+            suggestions = []
+            if results and len(results) > 0:
+                # results is a list, get the first element
+                first_result = results[0] if isinstance(results, list) else results
+                items = (
+                    first_result.get("items", [])
+                    if isinstance(first_result, dict)
+                    else []
+                )
-                yield "suggestion", suggestion
-                suggestions.append(suggestion)
+                if items is None:
+                    items = []
+                for item in items:
+                    # Create the KeywordSuggestion object
+                    suggestion = KeywordSuggestion(
+                        keyword=item.get("keyword", ""),
+                        search_volume=item.get("keyword_info", {}).get("search_volume"),
+                        competition=item.get("keyword_info", {}).get("competition"),
+                        cpc=item.get("keyword_info", {}).get("cpc"),
+                        keyword_difficulty=item.get("keyword_properties", {}).get(
+                            "keyword_difficulty"
+                        ),
+                        serp_info=(
+                            item.get("serp_info")
+                            if input_data.include_serp_info
+                            else None
+                        ),
+                        clickstream_data=(
+                            item.get("clickstream_keyword_info")
+                            if input_data.include_clickstream_data
+                            else None
+                        ),
+                    )
+                    yield "suggestion", suggestion
+                    suggestions.append(suggestion)

-        yield "suggestions", suggestions
-        yield "total_count", len(suggestions)
-        yield "seed_keyword", input_data.keyword
+            yield "suggestions", suggestions
+            yield "total_count", len(suggestions)
+            yield "seed_keyword", input_data.keyword
+        except Exception as e:
+            yield "error", f"Failed to fetch keyword suggestions: {str(e)}"


 class KeywordSuggestionExtractorBlock(Block):
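The structural change: the whole body now runs inside try/except, and failures surface as a yielded "error" output instead of an exception that would abort the graph run. A standalone sketch of the same async-generator pattern (names are illustrative, not project code):

    from typing import Any, AsyncGenerator

    async def run_block(query: str) -> AsyncGenerator[tuple[str, Any], None]:
        try:
            results = await fetch_results(query)  # hypothetical fetch helper
            for r in results:
                yield "result", r
            yield "total_count", len(results)
        except Exception as e:
            # Surface the failure as a normal output pin for downstream handling
            yield "error", f"Failed to fetch results: {e}"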
@@ -98,6 +98,7 @@ class DataForSeoRelatedKeywordsBlock(Block):
         seed_keyword: str = SchemaField(
             description="The seed keyword used for the query"
         )
+        error: str = SchemaField(description="Error message if the API call failed")

     def __init__(self):
         super().__init__(

@@ -171,50 +172,60 @@ class DataForSeoRelatedKeywordsBlock(Block):
         **kwargs,
     ) -> BlockOutput:
         """Execute the related keywords query."""
-        client = DataForSeoClient(credentials)
+        try:
+            client = DataForSeoClient(credentials)

-        results = await self._fetch_related_keywords(client, input_data)
+            results = await self._fetch_related_keywords(client, input_data)

-        # Process and format the results
-        related_keywords = []
-        if results and len(results) > 0:
-            # results is a list, get the first element
-            first_result = results[0] if isinstance(results, list) else results
-            items = (
-                first_result.get("items", []) if isinstance(first_result, dict) else []
-            )
-            for item in items:
-                # Extract keyword_data from the item
-                keyword_data = item.get("keyword_data", {})
-
-                # Create the RelatedKeyword object
-                keyword = RelatedKeyword(
-                    keyword=keyword_data.get("keyword", ""),
-                    search_volume=keyword_data.get("keyword_info", {}).get(
-                        "search_volume"
-                    ),
-                    competition=keyword_data.get("keyword_info", {}).get("competition"),
-                    cpc=keyword_data.get("keyword_info", {}).get("cpc"),
-                    keyword_difficulty=keyword_data.get("keyword_properties", {}).get(
-                        "keyword_difficulty"
-                    ),
-                    serp_info=(
-                        keyword_data.get("serp_info")
-                        if input_data.include_serp_info
-                        else None
-                    ),
-                    clickstream_data=(
-                        keyword_data.get("clickstream_keyword_info")
-                        if input_data.include_clickstream_data
-                        else None
-                    ),
-                )
+            # Process and format the results
+            related_keywords = []
+            if results and len(results) > 0:
+                # results is a list, get the first element
+                first_result = results[0] if isinstance(results, list) else results
+                items = (
+                    first_result.get("items", [])
+                    if isinstance(first_result, dict)
+                    else []
+                )
-                yield "related_keyword", keyword
-                related_keywords.append(keyword)
+                # Ensure items is never None
+                if items is None:
+                    items = []
+                for item in items:
+                    # Extract keyword_data from the item
+                    keyword_data = item.get("keyword_data", {})

-        yield "related_keywords", related_keywords
-        yield "total_count", len(related_keywords)
-        yield "seed_keyword", input_data.keyword
+                    # Create the RelatedKeyword object
+                    keyword = RelatedKeyword(
+                        keyword=keyword_data.get("keyword", ""),
+                        search_volume=keyword_data.get("keyword_info", {}).get(
+                            "search_volume"
+                        ),
+                        competition=keyword_data.get("keyword_info", {}).get(
+                            "competition"
+                        ),
+                        cpc=keyword_data.get("keyword_info", {}).get("cpc"),
+                        keyword_difficulty=keyword_data.get(
+                            "keyword_properties", {}
+                        ).get("keyword_difficulty"),
+                        serp_info=(
+                            keyword_data.get("serp_info")
+                            if input_data.include_serp_info
+                            else None
+                        ),
+                        clickstream_data=(
+                            keyword_data.get("clickstream_keyword_info")
+                            if input_data.include_clickstream_data
+                            else None
+                        ),
+                    )
+                    yield "related_keyword", keyword
+                    related_keywords.append(keyword)
+
+            yield "related_keywords", related_keywords
+            yield "total_count", len(related_keywords)
+            yield "seed_keyword", input_data.keyword
+        except Exception as e:
+            yield "error", f"Failed to fetch related keywords: {str(e)}"


 class RelatedKeywordExtractorBlock(Block):
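Both blocks also gained an explicit `if items is None` guard: dict.get's default only applies when the key is absent, not when the API returns an explicit null. A two-line illustration:

    item = {"items": None}
    item.get("items", [])      # -> None: the default is NOT used, the key exists
    (item.get("items") or [])  # -> []: also covers the explicit-null case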
@@ -101,6 +101,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
     CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
     CLAUDE_4_OPUS = "claude-opus-4-20250514"
     CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
+    CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
     CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
     CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
     CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"

@@ -213,6 +214,9 @@ MODEL_METADATA = {
     LlmModel.CLAUDE_4_SONNET: ModelMetadata(
         "anthropic", 200000, 64000
     ),  # claude-4-sonnet-20250514
+    LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
+        "anthropic", 200000, 64000
+    ),  # claude-sonnet-4-5-20250929
     LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
         "anthropic", 200000, 64000
     ),  # claude-3-7-sonnet-20250219
@@ -519,34 +519,121 @@ class SmartDecisionMakerBlock(Block):
         ):
             prompt.append({"role": "user", "content": prefix + input_data.prompt})

-        response = await llm.llm_call(
-            credentials=credentials,
-            llm_model=input_data.model,
-            prompt=prompt,
-            max_tokens=input_data.max_tokens,
-            tools=tool_functions,
-            ollama_host=input_data.ollama_host,
-            parallel_tool_calls=input_data.multiple_tool_calls,
-        )
-
-        # Track LLM usage stats
-        self.merge_stats(
-            NodeExecutionStats(
-                input_token_count=response.prompt_tokens,
-                output_token_count=response.completion_tokens,
-                llm_call_count=1,
-            )
-        )
+        # Use retry decorator for LLM calls with validation
+        from backend.util.retry import create_retry_decorator
+
+        # Create retry decorator that excludes ValueError from retry (for non-LLM errors)
+        llm_retry = create_retry_decorator(
+            max_attempts=input_data.retry,
+            exclude_exceptions=(),  # Don't exclude ValueError - we want to retry validation failures
+            context="SmartDecisionMaker LLM call",
+        )
+
+        @llm_retry
+        async def call_llm_with_validation():
+            response = await llm.llm_call(
+                credentials=credentials,
+                llm_model=input_data.model,
+                prompt=prompt,
+                max_tokens=input_data.max_tokens,
+                tools=tool_functions,
+                ollama_host=input_data.ollama_host,
+                parallel_tool_calls=input_data.multiple_tool_calls,
+            )
+
+            # Track LLM usage stats
+            self.merge_stats(
+                NodeExecutionStats(
+                    input_token_count=response.prompt_tokens,
+                    output_token_count=response.completion_tokens,
+                    llm_call_count=1,
+                )
+            )
+
+            if not response.tool_calls:
+                return response, None  # No tool calls, return response
+
+            # Validate all tool calls before proceeding
+            validation_errors = []
+            for tool_call in response.tool_calls:
+                tool_name = tool_call.function.name
+                tool_args = json.loads(tool_call.function.arguments)
+
+                # Find the tool definition to get the expected arguments
+                tool_def = next(
+                    (
+                        tool
+                        for tool in tool_functions
+                        if tool["function"]["name"] == tool_name
+                    ),
+                    None,
+                )
+
+                # Get parameters schema from tool definition
+                if (
+                    tool_def
+                    and "function" in tool_def
+                    and "parameters" in tool_def["function"]
+                ):
+                    parameters = tool_def["function"]["parameters"]
+                    expected_args = parameters.get("properties", {})
+                    required_params = set(parameters.get("required", []))
+                else:
+                    expected_args = {arg: {} for arg in tool_args.keys()}
+                    required_params = set()
+
+                # Validate tool call arguments
+                provided_args = set(tool_args.keys())
+                expected_args_set = set(expected_args.keys())
+
+                # Check for unexpected arguments (typos)
+                unexpected_args = provided_args - expected_args_set
+                # Only check for missing REQUIRED parameters
+                missing_required_args = required_params - provided_args
+
+                if unexpected_args or missing_required_args:
+                    error_msg = f"Tool call '{tool_name}' has parameter errors:"
+                    if unexpected_args:
+                        error_msg += f" Unknown parameters: {sorted(unexpected_args)}."
+                    if missing_required_args:
+                        error_msg += f" Missing required parameters: {sorted(missing_required_args)}."
+                    error_msg += f" Expected parameters: {sorted(expected_args_set)}."
+                    if required_params:
+                        error_msg += f" Required parameters: {sorted(required_params)}."
+                    validation_errors.append(error_msg)
+
+            # If validation failed, add feedback and raise for retry
+            if validation_errors:
+                # Add the failed response to conversation
+                prompt.append(response.raw_response)
+
+                # Add error feedback for retry
+                error_feedback = (
+                    "Your tool call had parameter errors. Please fix the following issues and try again:\n"
+                    + "\n".join(f"- {error}" for error in validation_errors)
+                    + "\n\nPlease make sure to use the exact parameter names as specified in the function schema."
+                )
+                prompt.append({"role": "user", "content": error_feedback})
+
+                raise ValueError(
+                    f"Tool call validation failed: {'; '.join(validation_errors)}"
+                )
+
+            return response, validation_errors
+
+        # Call the LLM with retry logic
+        response, validation_errors = await call_llm_with_validation()

         if not response.tool_calls:
             yield "finished", response.response
             return

+        # If we get here, validation passed - yield tool outputs
         for tool_call in response.tool_calls:
             tool_name = tool_call.function.name
             tool_args = json.loads(tool_call.function.arguments)

-            # Find the tool definition to get the expected arguments
+            # Get expected arguments (already validated above)
             tool_def = next(
                 (
                     tool

@@ -555,7 +642,6 @@ class SmartDecisionMakerBlock(Block):
                 ),
                 None,
             )
-
             if (
                 tool_def
                 and "function" in tool_def

@@ -563,14 +649,11 @@ class SmartDecisionMakerBlock(Block):
             ):
                 expected_args = tool_def["function"]["parameters"].get("properties", {})
             else:
-                expected_args = tool_args.keys()
+                expected_args = {arg: {} for arg in tool_args.keys()}

-            # Yield provided arguments and None for missing ones
+            # Yield provided arguments, use .get() for optional parameters
             for arg_name in expected_args:
-                if arg_name in tool_args:
-                    yield f"tools_^_{tool_name}_~_{arg_name}", tool_args[arg_name]
-                else:
-                    yield f"tools_^_{tool_name}_~_{arg_name}", None
+                yield f"tools_^_{tool_name}_~_{arg_name}", tool_args.get(arg_name)

         # Add reasoning to conversation history if available
         if response.reasoning:
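The retry loop works because a validation failure both mutates `prompt` (appending the model's raw response plus corrective feedback) and raises ValueError, which the decorator is configured not to exclude; the next attempt therefore sees its own mistake. A minimal sketch of the same feedback-retry shape using tenacity as a stand-in (only create_retry_decorator and its arguments above are the project's actual API; fake_llm and validate are hypothetical):

    import tenacity

    @tenacity.retry(stop=tenacity.stop_after_attempt(3), reraise=True)
    async def call_with_feedback(prompt: list[dict]) -> str:
        reply = await fake_llm(prompt)  # hypothetical LLM call
        errors = validate(reply)        # hypothetical parameter validator
        if errors:
            # Mutating the shared prompt means the retry carries the feedback
            prompt.append({"role": "user", "content": f"Fix these issues: {errors}"})
            raise ValueError(errors)    # raising triggers the next attempt
        return reply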
@@ -249,3 +249,232 @@ async def test_smart_decision_maker_tracks_llm_stats():
     # Verify outputs
     assert "finished" in outputs  # Should have finished since no tool calls
     assert outputs["finished"] == "I need to think about this."
+
+
+@pytest.mark.asyncio
+async def test_smart_decision_maker_parameter_validation():
+    """Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
+    from unittest.mock import MagicMock, patch
+
+    import backend.blocks.llm as llm_module
+    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
+
+    block = SmartDecisionMakerBlock()
+
+    # Mock tool functions with specific parameter schema
+    mock_tool_functions = [
+        {
+            "type": "function",
+            "function": {
+                "name": "search_keywords",
+                "description": "Search for keywords with difficulty filtering",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "query": {"type": "string", "description": "Search query"},
+                        "max_keyword_difficulty": {
+                            "type": "integer",
+                            "description": "Maximum keyword difficulty (required)",
+                        },
+                        "optional_param": {
+                            "type": "string",
+                            "description": "Optional parameter with default",
+                            "default": "default_value",
+                        },
+                    },
+                    "required": ["query", "max_keyword_difficulty"],
+                },
+            },
+        }
+    ]
+
+    # Test case 1: Tool call with TYPO in parameter name (should retry and eventually fail)
+    mock_tool_call_with_typo = MagicMock()
+    mock_tool_call_with_typo.function.name = "search_keywords"
+    mock_tool_call_with_typo.function.arguments = '{"query": "test", "maximum_keyword_difficulty": 50}'  # TYPO: maximum instead of max
+
+    mock_response_with_typo = MagicMock()
+    mock_response_with_typo.response = None
+    mock_response_with_typo.tool_calls = [mock_tool_call_with_typo]
+    mock_response_with_typo.prompt_tokens = 50
+    mock_response_with_typo.completion_tokens = 25
+    mock_response_with_typo.reasoning = None
+    mock_response_with_typo.raw_response = {"role": "assistant", "content": None}
+
+    with patch(
+        "backend.blocks.llm.llm_call", return_value=mock_response_with_typo
+    ) as mock_llm_call, patch.object(
+        SmartDecisionMakerBlock,
+        "_create_function_signature",
+        return_value=mock_tool_functions,
+    ):
+
+        input_data = SmartDecisionMakerBlock.Input(
+            prompt="Search for keywords",
+            model=llm_module.LlmModel.GPT4O,
+            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+            retry=2,  # Set retry to 2 for testing
+        )
+
+        # Should raise ValueError after retries due to typo'd parameter name
+        with pytest.raises(ValueError) as exc_info:
+            outputs = {}
+            async for output_name, output_data in block.run(
+                input_data,
+                credentials=llm_module.TEST_CREDENTIALS,
+                graph_id="test-graph-id",
+                node_id="test-node-id",
+                graph_exec_id="test-exec-id",
+                node_exec_id="test-node-exec-id",
+                user_id="test-user-id",
+            ):
+                outputs[output_name] = output_data
+
+        # Verify error message contains details about the typo
+        error_msg = str(exc_info.value)
+        assert "Tool call validation failed" in error_msg
+        assert "Unknown parameters: ['maximum_keyword_difficulty']" in error_msg
+
+        # Verify that LLM was called the expected number of times (retries)
+        assert mock_llm_call.call_count == 2  # Should retry based on input_data.retry
+
+    # Test case 2: Tool call missing REQUIRED parameter (should raise ValueError)
+    mock_tool_call_missing_required = MagicMock()
+    mock_tool_call_missing_required.function.name = "search_keywords"
+    mock_tool_call_missing_required.function.arguments = (
+        '{"query": "test"}'  # Missing required max_keyword_difficulty
+    )
+
+    mock_response_missing_required = MagicMock()
+    mock_response_missing_required.response = None
+    mock_response_missing_required.tool_calls = [mock_tool_call_missing_required]
+    mock_response_missing_required.prompt_tokens = 50
+    mock_response_missing_required.completion_tokens = 25
+    mock_response_missing_required.reasoning = None
+    mock_response_missing_required.raw_response = {"role": "assistant", "content": None}
+
+    with patch(
+        "backend.blocks.llm.llm_call", return_value=mock_response_missing_required
+    ), patch.object(
+        SmartDecisionMakerBlock,
+        "_create_function_signature",
+        return_value=mock_tool_functions,
+    ):
+
+        input_data = SmartDecisionMakerBlock.Input(
+            prompt="Search for keywords",
+            model=llm_module.LlmModel.GPT4O,
+            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+        )
+
+        # Should raise ValueError due to missing required parameter
+        with pytest.raises(ValueError) as exc_info:
+            outputs = {}
+            async for output_name, output_data in block.run(
+                input_data,
+                credentials=llm_module.TEST_CREDENTIALS,
+                graph_id="test-graph-id",
+                node_id="test-node-id",
+                graph_exec_id="test-exec-id",
+                node_exec_id="test-node-exec-id",
+                user_id="test-user-id",
+            ):
+                outputs[output_name] = output_data
+
+        error_msg = str(exc_info.value)
+        assert "Tool call 'search_keywords' has parameter errors" in error_msg
+        assert "Missing required parameters: ['max_keyword_difficulty']" in error_msg
+
+    # Test case 3: Valid tool call with OPTIONAL parameter missing (should succeed)
+    mock_tool_call_valid = MagicMock()
+    mock_tool_call_valid.function.name = "search_keywords"
+    mock_tool_call_valid.function.arguments = '{"query": "test", "max_keyword_difficulty": 50}'  # optional_param missing, but that's OK
+
+    mock_response_valid = MagicMock()
+    mock_response_valid.response = None
+    mock_response_valid.tool_calls = [mock_tool_call_valid]
+    mock_response_valid.prompt_tokens = 50
+    mock_response_valid.completion_tokens = 25
+    mock_response_valid.reasoning = None
+    mock_response_valid.raw_response = {"role": "assistant", "content": None}
+
+    with patch(
+        "backend.blocks.llm.llm_call", return_value=mock_response_valid
+    ), patch.object(
+        SmartDecisionMakerBlock,
+        "_create_function_signature",
+        return_value=mock_tool_functions,
+    ):
+
+        input_data = SmartDecisionMakerBlock.Input(
+            prompt="Search for keywords",
+            model=llm_module.LlmModel.GPT4O,
+            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+        )
+
+        # Should succeed - optional parameter missing is OK
+        outputs = {}
+        async for output_name, output_data in block.run(
+            input_data,
+            credentials=llm_module.TEST_CREDENTIALS,
+            graph_id="test-graph-id",
+            node_id="test-node-id",
+            graph_exec_id="test-exec-id",
+            node_exec_id="test-node-exec-id",
+            user_id="test-user-id",
+        ):
+            outputs[output_name] = output_data
+
+        # Verify tool outputs were generated correctly
+        assert "tools_^_search_keywords_~_query" in outputs
+        assert outputs["tools_^_search_keywords_~_query"] == "test"
+        assert "tools_^_search_keywords_~_max_keyword_difficulty" in outputs
+        assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
+        # Optional parameter should be None when not provided
+        assert "tools_^_search_keywords_~_optional_param" in outputs
+        assert outputs["tools_^_search_keywords_~_optional_param"] is None
+
+    # Test case 4: Valid tool call with ALL parameters (should succeed)
+    mock_tool_call_all_params = MagicMock()
+    mock_tool_call_all_params.function.name = "search_keywords"
+    mock_tool_call_all_params.function.arguments = '{"query": "test", "max_keyword_difficulty": 50, "optional_param": "custom_value"}'
+
+    mock_response_all_params = MagicMock()
+    mock_response_all_params.response = None
+    mock_response_all_params.tool_calls = [mock_tool_call_all_params]
+    mock_response_all_params.prompt_tokens = 50
+    mock_response_all_params.completion_tokens = 25
+    mock_response_all_params.reasoning = None
+    mock_response_all_params.raw_response = {"role": "assistant", "content": None}
+
+    with patch(
+        "backend.blocks.llm.llm_call", return_value=mock_response_all_params
+    ), patch.object(
+        SmartDecisionMakerBlock,
+        "_create_function_signature",
+        return_value=mock_tool_functions,
+    ):
+
+        input_data = SmartDecisionMakerBlock.Input(
+            prompt="Search for keywords",
+            model=llm_module.LlmModel.GPT4O,
+            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+        )
+
+        # Should succeed with all parameters
+        outputs = {}
+        async for output_name, output_data in block.run(
+            input_data,
+            credentials=llm_module.TEST_CREDENTIALS,
+            graph_id="test-graph-id",
+            node_id="test-node-id",
+            graph_exec_id="test-exec-id",
+            node_exec_id="test-node-exec-id",
+            user_id="test-user-id",
+        ):
+            outputs[output_name] = output_data
+
+        # Verify all tool outputs were generated correctly
+        assert outputs["tools_^_search_keywords_~_query"] == "test"
+        assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
+        assert outputs["tools_^_search_keywords_~_optional_param"] == "custom_value"
@@ -9,6 +9,7 @@ from prisma.models import APIKey as PrismaAPIKey
 from prisma.types import APIKeyWhereUniqueInput
 from pydantic import BaseModel, Field

+from backend.data.includes import MAX_USER_API_KEYS_FETCH
 from backend.util.exceptions import NotAuthorizedError, NotFoundError

 logger = logging.getLogger(__name__)

@@ -178,9 +179,13 @@ async def revoke_api_key(key_id: str, user_id: str) -> APIKeyInfo:
     return APIKeyInfo.from_db(updated_api_key)


-async def list_user_api_keys(user_id: str) -> list[APIKeyInfo]:
+async def list_user_api_keys(
+    user_id: str, limit: int = MAX_USER_API_KEYS_FETCH
+) -> list[APIKeyInfo]:
     api_keys = await PrismaAPIKey.prisma().find_many(
-        where={"userId": user_id}, order={"createdAt": "desc"}
+        where={"userId": user_id},
+        order={"createdAt": "desc"},
+        take=limit,
     )

     return [APIKeyInfo.from_db(key) for key in api_keys]
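The same bounded-fetch pattern recurs throughout this merge (refund requests, graph versions, webhooks): each previously unbounded find_many gains an ordering plus take=limit, with the cap defined centrally in backend.data.includes. A generic sketch of the pattern (Widget and its cap are hypothetical placeholders):

    MAX_WIDGETS_FETCH = 100  # hypothetical cap, mirroring the MAX_*_FETCH constants

    async def list_widgets(user_id: str, limit: int = MAX_WIDGETS_FETCH) -> list:
        # take= bounds the result set so one heavy user cannot time out the query
        return await Widget.prisma().find_many(
            where={"userId": user_id},
            order={"createdAt": "desc"},
            take=limit,
        )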
@@ -69,6 +69,7 @@ MODEL_COST: dict[LlmModel, int] = {
     LlmModel.CLAUDE_4_1_OPUS: 21,
     LlmModel.CLAUDE_4_OPUS: 21,
     LlmModel.CLAUDE_4_SONNET: 5,
+    LlmModel.CLAUDE_4_5_SONNET: 9,
     LlmModel.CLAUDE_3_7_SONNET: 5,
     LlmModel.CLAUDE_3_5_SONNET: 4,
     LlmModel.CLAUDE_3_5_HAIKU: 1,  # $0.80 / $4.00
@@ -23,6 +23,7 @@ from pydantic import BaseModel

 from backend.data import db
 from backend.data.block_cost_config import BLOCK_COSTS
+from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
 from backend.data.model import (
     AutoTopUpConfig,
     RefundRequest,

@@ -905,7 +906,9 @@ class UserCredit(UserCreditBase):
             ),
         )

-    async def get_refund_requests(self, user_id: str) -> list[RefundRequest]:
+    async def get_refund_requests(
+        self, user_id: str, limit: int = MAX_CREDIT_REFUND_REQUESTS_FETCH
+    ) -> list[RefundRequest]:
         return [
             RefundRequest(
                 id=r.id,

@@ -921,6 +924,7 @@ class UserCredit(UserCreditBase):
             for r in await CreditRefundRequest.prisma().find_many(
                 where={"userId": user_id},
                 order={"createdAt": "desc"},
+                take=limit,
             )
         ]
@@ -20,6 +20,7 @@ from backend.blocks.agent import AgentExecutorBlock
 from backend.blocks.io import AgentInputBlock, AgentOutputBlock
 from backend.blocks.llm import LlmModel
 from backend.data.db import prisma as db
+from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
 from backend.data.model import (
     CredentialsField,
     CredentialsFieldInfo,

@@ -1059,11 +1060,14 @@ async def set_graph_active_version(graph_id: str, version: int, user_id: str) ->
     )


-async def get_graph_all_versions(graph_id: str, user_id: str) -> list[GraphModel]:
+async def get_graph_all_versions(
+    graph_id: str, user_id: str, limit: int = MAX_GRAPH_VERSIONS_FETCH
+) -> list[GraphModel]:
     graph_versions = await AgentGraph.prisma().find_many(
         where={"id": graph_id, "userId": user_id},
         order={"version": "desc"},
         include=AGENT_GRAPH_INCLUDE,
+        take=limit,
     )

     if not graph_versions:
@@ -14,6 +14,7 @@ AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
     "Nodes": {"include": AGENT_NODE_INCLUDE}
 }

+
 EXECUTION_RESULT_ORDER: list[prisma.types.AgentNodeExecutionOrderByInput] = [
     {"queuedTime": "desc"},
     # Fallback: Incomplete execs has no queuedTime.

@@ -28,6 +29,13 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
 }

 MAX_NODE_EXECUTIONS_FETCH = 1000
+MAX_LIBRARY_AGENT_EXECUTIONS_FETCH = 10
+
+# Default limits for potentially large result sets
+MAX_CREDIT_REFUND_REQUESTS_FETCH = 100
+MAX_INTEGRATION_WEBHOOKS_FETCH = 100
+MAX_USER_API_KEYS_FETCH = 500
+MAX_GRAPH_VERSIONS_FETCH = 50

 GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
     "NodeExecutions": {

@@ -71,13 +79,56 @@ INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
 }


-def library_agent_include(user_id: str) -> prisma.types.LibraryAgentInclude:
-    return {
-        "AgentGraph": {
-            "include": {
-                **AGENT_GRAPH_INCLUDE,
-                "Executions": {"where": {"userId": user_id}},
-            }
-        },
-        "Creator": True,
-    }
+def library_agent_include(
+    user_id: str,
+    include_nodes: bool = True,
+    include_executions: bool = True,
+    execution_limit: int = MAX_LIBRARY_AGENT_EXECUTIONS_FETCH,
+) -> prisma.types.LibraryAgentInclude:
+    """
+    Fully configurable includes for library agent queries with performance optimization.
+
+    Args:
+        user_id: User ID for filtering user-specific data
+        include_nodes: Whether to include graph nodes (default: True, needed for get_sub_graphs)
+        include_executions: Whether to include executions (default: True, safe with execution_limit)
+        execution_limit: Limit on executions to fetch (default: MAX_LIBRARY_AGENT_EXECUTIONS_FETCH)
+
+    Defaults maintain backward compatibility and safety - includes everything needed for all functionality.
+    For performance optimization, explicitly set include_nodes=False and include_executions=False
+    for listing views where the frontend fetches data separately.
+
+    Performance impact:
+    - Default (full nodes + limited executions): original performance, works everywhere
+    - Listing optimization (no nodes/executions): ~2s for 15 agents vs potential timeouts
+    - Unlimited executions: varies by user (thousands of executions = timeouts)
+    """
+    result: prisma.types.LibraryAgentInclude = {
+        "Creator": True,  # Always needed for creator info
+    }
+
+    # Build AgentGraph include based on requested options
+    if include_nodes or include_executions:
+        agent_graph_include = {}
+
+        # Add nodes if requested (always full nodes)
+        if include_nodes:
+            agent_graph_include.update(AGENT_GRAPH_INCLUDE)  # Full nodes
+
+        # Add executions if requested
+        if include_executions:
+            agent_graph_include["Executions"] = {
+                "where": {"userId": user_id},
+                "order_by": {"createdAt": "desc"},
+                "take": execution_limit,
+            }
+
+        result["AgentGraph"] = cast(
+            prisma.types.AgentGraphArgsFromLibraryAgent,
+            {"include": agent_graph_include},
+        )
+    else:
+        # Default: Basic metadata only (fast - recommended for most use cases)
+        result["AgentGraph"] = True  # Basic graph metadata (name, description, id)
+
+    return result
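A sketch of how call sites might choose between the heavy and light includes (the query itself is illustrative; only library_agent_include and its flags come from this diff):

    # Detail view: full nodes plus the most recent executions (the safe defaults)
    detail_include = library_agent_include(user_id)

    # Listing view: basic graph metadata only, skipping nodes and executions
    listing_include = library_agent_include(
        user_id, include_nodes=False, include_executions=False
    )
    agents = await prisma.models.LibraryAgent.prisma().find_many(
        where={"userId": user_id}, include=listing_include
    )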
@@ -11,7 +11,10 @@ from prisma.types import (
 from pydantic import Field, computed_field

 from backend.data.event_bus import AsyncRedisEventBus
-from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
+from backend.data.includes import (
+    INTEGRATION_WEBHOOK_INCLUDE,
+    MAX_INTEGRATION_WEBHOOKS_FETCH,
+)
 from backend.integrations.providers import ProviderName
 from backend.integrations.webhooks.utils import webhook_ingress_url
 from backend.server.v2.library.model import LibraryAgentPreset

@@ -128,22 +131,36 @@ async def get_webhook(

 @overload
 async def get_all_webhooks_by_creds(
-    user_id: str, credentials_id: str, *, include_relations: Literal[True]
+    user_id: str,
+    credentials_id: str,
+    *,
+    include_relations: Literal[True],
+    limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
 ) -> list[WebhookWithRelations]: ...
 @overload
 async def get_all_webhooks_by_creds(
-    user_id: str, credentials_id: str, *, include_relations: Literal[False] = False
+    user_id: str,
+    credentials_id: str,
+    *,
+    include_relations: Literal[False] = False,
+    limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
 ) -> list[Webhook]: ...


 async def get_all_webhooks_by_creds(
-    user_id: str, credentials_id: str, *, include_relations: bool = False
+    user_id: str,
+    credentials_id: str,
+    *,
+    include_relations: bool = False,
+    limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
 ) -> list[Webhook] | list[WebhookWithRelations]:
     if not credentials_id:
         raise ValueError("credentials_id must not be empty")
     webhooks = await IntegrationWebhook.prisma().find_many(
         where={"userId": user_id, "credentialsId": credentials_id},
         include=INTEGRATION_WEBHOOK_INCLUDE if include_relations else None,
+        order={"createdAt": "desc"},
+        take=limit,
     )
     return [
         (WebhookWithRelations if include_relations else Webhook).from_db(webhook)
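The paired @overload stubs let type checkers narrow the return type from the include_relations literal, so callers passing include_relations=True get list[WebhookWithRelations] without casting. A self-contained sketch of the same trick (generic example, not project code):

    from typing import Literal, overload

    @overload
    def load(raw: Literal[True]) -> bytes: ...
    @overload
    def load(raw: Literal[False] = False) -> str: ...

    def load(raw: bool = False) -> bytes | str:
        data = b"payload"
        return data if raw else data.decode()

    text: str = load()             # checker infers str
    blob: bytes = load(raw=True)   # checker infers bytes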
autogpt_platform/backend/backend/executor/cluster_lock.py (new file, 115 lines)
@@ -0,0 +1,115 @@
"""Redis-based distributed locking for cluster coordination."""

import logging
import time
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from redis import Redis

logger = logging.getLogger(__name__)


class ClusterLock:
    """Simple Redis-based distributed lock for preventing duplicate execution."""

    def __init__(self, redis: "Redis", key: str, owner_id: str, timeout: int = 300):
        self.redis = redis
        self.key = key
        self.owner_id = owner_id
        self.timeout = timeout
        self._last_refresh = 0.0

    def try_acquire(self) -> str | None:
        """Try to acquire the lock.

        Returns:
            - owner_id (self.owner_id) if successfully acquired
            - different owner_id if someone else holds the lock
            - None if Redis is unavailable or other error
        """
        try:
            success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout)
            if success:
                self._last_refresh = time.time()
                return self.owner_id  # Successfully acquired

            # Failed to acquire, get current owner
            current_value = self.redis.get(self.key)
            if current_value:
                current_owner = (
                    current_value.decode("utf-8")
                    if isinstance(current_value, bytes)
                    else str(current_value)
                )
                return current_owner

            # Key doesn't exist but we failed to set it - race condition or Redis issue
            return None

        except Exception as e:
            logger.error(f"ClusterLock.try_acquire failed for key {self.key}: {e}")
            return None

    def refresh(self) -> bool:
        """Refresh lock TTL if we still own it.

        Rate limited to at most once every timeout/10 seconds (minimum 1 second).
        During rate limiting, still verifies lock existence but skips TTL extension.
        Setting _last_refresh to 0 bypasses rate limiting for testing.
        """
        # Calculate refresh interval: max(timeout // 10, 1)
        refresh_interval = max(self.timeout // 10, 1)
        current_time = time.time()

        # Check if we're within the rate limit period
        # _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
        is_rate_limited = (
            self._last_refresh > 0
            and (current_time - self._last_refresh) < refresh_interval
        )

        try:
            # Always verify lock existence, even during rate limiting
            current_value = self.redis.get(self.key)
            if not current_value:
                self._last_refresh = 0
                return False

            stored_owner = (
                current_value.decode("utf-8")
                if isinstance(current_value, bytes)
                else str(current_value)
            )
            if stored_owner != self.owner_id:
                self._last_refresh = 0
                return False

            # If rate limited, return True but don't update TTL or timestamp
            if is_rate_limited:
                return True

            # Perform actual refresh
            if self.redis.expire(self.key, self.timeout):
                self._last_refresh = current_time
                return True

            self._last_refresh = 0
            return False

        except Exception as e:
            logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}")
            self._last_refresh = 0
            return False

    def release(self):
        """Release the lock."""
        if self._last_refresh == 0:
            return

        try:
            self.redis.delete(self.key)
        except Exception:
            pass

        self._last_refresh = 0.0
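Intended lifecycle: acquire once, refresh periodically while working (refresh doubles as an ownership check), release when done. A sketch of a worker loop (work_remaining and do_one_step are hypothetical; only the ClusterLock API comes from the file above):

    import uuid
    from redis import Redis

    from backend.executor.cluster_lock import ClusterLock

    r = Redis(decode_responses=False)  # raw bytes, matching the test fixture below
    lock = ClusterLock(r, "exec:graph-123", owner_id=str(uuid.uuid4()), timeout=300)

    if lock.try_acquire() == lock.owner_id:
        try:
            while work_remaining():      # hypothetical
                do_one_step()            # hypothetical
                if not lock.refresh():   # ownership lost: stop to avoid duplicate work
                    break
        finally:
            lock.release()
    else:
        pass  # another node is already executing this graph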
autogpt_platform/backend/backend/executor/cluster_lock_test.py (new file, 507 lines; truncated below)
@@ -0,0 +1,507 @@
|
||||
"""
|
||||
Integration tests for ClusterLock - Redis-based distributed locking.
|
||||
|
||||
Tests the complete lock lifecycle without mocking Redis to ensure
|
||||
real-world behavior is correct. Covers acquisition, refresh, expiry,
|
||||
contention, and error scenarios.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from threading import Thread
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
|
||||
from .cluster_lock import ClusterLock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redis_client():
|
||||
"""Get Redis client for testing using same config as backend."""
|
||||
from backend.data.redis_client import HOST, PASSWORD, PORT
|
||||
|
||||
# Use same config as backend but without decode_responses since ClusterLock needs raw bytes
|
||||
client = redis.Redis(
|
||||
host=HOST,
|
||||
port=PORT,
|
||||
password=PASSWORD,
|
||||
decode_responses=False, # ClusterLock needs raw bytes for ownership verification
|
||||
)
|
||||
|
||||
# Clean up any existing test keys
|
||||
try:
|
||||
for key in client.scan_iter(match="test_lock:*"):
|
||||
client.delete(key)
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lock_key():
|
||||
"""Generate unique lock key for each test."""
|
||||
return f"test_lock:{uuid.uuid4()}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def owner_id():
|
||||
"""Generate unique owner ID for each test."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class TestClusterLockBasic:
|
||||
"""Basic lock acquisition and release functionality."""
|
||||
|
||||
def test_lock_acquisition_success(self, redis_client, lock_key, owner_id):
|
||||
"""Test basic lock acquisition succeeds."""
|
||||
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
|
||||
|
||||
# Lock should be acquired successfully
|
||||
result = lock.try_acquire()
|
||||
assert result == owner_id # Returns our owner_id when successfully acquired
|
||||
assert lock._last_refresh > 0
|
||||
|
||||
# Lock key should exist in Redis
|
||||
assert redis_client.exists(lock_key) == 1
|
||||
assert redis_client.get(lock_key).decode("utf-8") == owner_id
|
||||
|
||||
def test_lock_acquisition_contention(self, redis_client, lock_key):
|
||||
"""Test second acquisition fails when lock is held."""
|
||||
owner1 = str(uuid.uuid4())
|
||||
owner2 = str(uuid.uuid4())
|
||||
|
||||
lock1 = ClusterLock(redis_client, lock_key, owner1, timeout=60)
|
||||
lock2 = ClusterLock(redis_client, lock_key, owner2, timeout=60)
|
||||
|
||||
# First lock should succeed
|
||||
result1 = lock1.try_acquire()
|
||||
assert result1 == owner1 # Successfully acquired, returns our owner_id
|
||||
|
||||
# Second lock should fail and return the first owner
|
||||
result2 = lock2.try_acquire()
|
||||
assert result2 == owner1 # Returns the current owner (first owner)
|
||||
assert lock2._last_refresh == 0
|
||||
|
||||
def test_lock_release_deletes_redis_key(self, redis_client, lock_key, owner_id):
|
||||
"""Test lock release deletes Redis key and marks locally as released."""
|
||||
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
|
||||
|
||||
lock.try_acquire()
|
||||
assert lock._last_refresh > 0
|
||||
assert redis_client.exists(lock_key) == 1
|
||||
|
||||
# Release should delete Redis key and mark locally as released
|
||||
lock.release()
|
||||
assert lock._last_refresh == 0
|
||||
assert lock._last_refresh == 0.0
|
||||
|
||||

        # Redis key should be deleted for immediate release
        assert redis_client.exists(lock_key) == 0

        # Another lock should be able to acquire immediately
        new_owner_id = str(uuid.uuid4())
        new_lock = ClusterLock(redis_client, lock_key, new_owner_id, timeout=60)
        assert new_lock.try_acquire() == new_owner_id


class TestClusterLockRefresh:
    """Lock refresh and TTL management."""

    def test_lock_refresh_success(self, redis_client, lock_key, owner_id):
        """Test lock refresh extends TTL."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        lock.try_acquire()
        original_ttl = redis_client.ttl(lock_key)

        # Wait a bit then refresh
        time.sleep(1)
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is True

        # TTL should be reset to full timeout (allow for small timing differences)
        new_ttl = redis_client.ttl(lock_key)
        assert new_ttl >= original_ttl or new_ttl >= 58  # Allow for timing variance

    def test_lock_refresh_rate_limiting(self, redis_client, lock_key, owner_id):
        """Test refresh is rate-limited to timeout/10."""
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=100
        )  # 100s timeout

        lock.try_acquire()

        # First refresh should work
        assert lock.refresh() is True
        first_refresh_time = lock._last_refresh

        # Immediate second refresh should be skipped (rate limited) but verify key exists
        assert lock.refresh() is True  # Returns True but skips actual refresh
        assert lock._last_refresh == first_refresh_time  # Time unchanged

    def test_lock_refresh_verifies_existence_during_rate_limit(
        self, redis_client, lock_key, owner_id
    ):
        """Test refresh verifies lock existence even during rate limiting."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=100)

        lock.try_acquire()

        # Manually delete the key (simulates expiry or external deletion)
        redis_client.delete(lock_key)

        # Refresh should detect missing key even during rate limit period
        assert lock.refresh() is False
        assert lock._last_refresh == 0

    def test_lock_refresh_ownership_lost(self, redis_client, lock_key, owner_id):
        """Test refresh fails when ownership is lost."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        lock.try_acquire()

        # Simulate another process taking the lock
        different_owner = str(uuid.uuid4())
        redis_client.set(lock_key, different_owner, ex=60)

        # Force refresh past rate limit and verify it fails
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is False
        assert lock._last_refresh == 0

    def test_lock_refresh_when_not_acquired(self, redis_client, lock_key, owner_id):
        """Test refresh fails when lock was never acquired."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        # Refresh without acquiring should fail
        assert lock.refresh() is False


class TestClusterLockExpiry:
    """Lock expiry and timeout behavior."""

    def test_lock_natural_expiry(self, redis_client, lock_key, owner_id):
        """Test lock expires naturally via Redis TTL."""
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=2
        )  # 2 second timeout

        lock.try_acquire()
        assert redis_client.exists(lock_key) == 1

        # Wait for expiry
        time.sleep(3)
        assert redis_client.exists(lock_key) == 0

        # New lock with same key should succeed
        new_lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
        assert new_lock.try_acquire() == owner_id

    def test_lock_refresh_prevents_expiry(self, redis_client, lock_key, owner_id):
        """Test refreshing prevents lock from expiring."""
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=3
        )  # 3 second timeout

        lock.try_acquire()

        # Wait and refresh before expiry
        time.sleep(1)
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is True

        # Wait beyond original timeout
        time.sleep(2.5)
        assert redis_client.exists(lock_key) == 1  # Should still exist


class TestClusterLockConcurrency:
    """Concurrent access patterns."""

    def test_multiple_threads_contention(self, redis_client, lock_key):
        """Test multiple threads competing for same lock."""
        num_threads = 5
        successful_acquisitions = []

        def try_acquire_lock(thread_id):
            owner_id = f"thread_{thread_id}"
            lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
            if lock.try_acquire() == owner_id:
                successful_acquisitions.append(thread_id)
                time.sleep(0.1)  # Hold lock briefly
                lock.release()

        threads = []
        for i in range(num_threads):
            thread = Thread(target=try_acquire_lock, args=(i,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # Only one thread should have acquired the lock
        assert len(successful_acquisitions) == 1

    def test_sequential_lock_reuse(self, redis_client, lock_key):
        """Test lock can be reused after natural expiry."""
        owners = [str(uuid.uuid4()) for _ in range(3)]

        for i, owner_id in enumerate(owners):
            lock = ClusterLock(redis_client, lock_key, owner_id, timeout=1)  # 1 second

            assert lock.try_acquire() == owner_id
            time.sleep(1.5)  # Wait for expiry

            # Verify lock expired
            assert redis_client.exists(lock_key) == 0

    def test_refresh_during_concurrent_access(self, redis_client, lock_key):
        """Test lock refresh works correctly during concurrent access attempts."""
        owner1 = str(uuid.uuid4())
        owner2 = str(uuid.uuid4())

        lock1 = ClusterLock(redis_client, lock_key, owner1, timeout=5)
        lock2 = ClusterLock(redis_client, lock_key, owner2, timeout=5)

        # Thread 1 holds lock and refreshes
        assert lock1.try_acquire() == owner1

        def refresh_continuously():
            for _ in range(10):
                lock1._last_refresh = 0  # Force refresh
                lock1.refresh()
                time.sleep(0.1)

        def try_acquire_continuously():
            attempts = 0
            while attempts < 20:
                if lock2.try_acquire() == owner2:
                    return True
                time.sleep(0.1)
                attempts += 1
            return False

        refresh_thread = Thread(target=refresh_continuously)
        acquire_thread = Thread(target=try_acquire_continuously)

        refresh_thread.start()
        acquire_thread.start()

        refresh_thread.join()
        acquire_thread.join()

        # Lock1 should still own the lock due to refreshes
        assert lock1._last_refresh > 0
        assert lock2._last_refresh == 0


class TestClusterLockErrorHandling:
    """Error handling and edge cases."""

    def test_redis_connection_failure_on_acquire(self, lock_key, owner_id):
        """Test graceful handling when Redis is unavailable during acquisition."""
        # Use invalid Redis connection
        bad_redis = redis.Redis(
            host="invalid_host", port=1234, socket_connect_timeout=1
        )
        lock = ClusterLock(bad_redis, lock_key, owner_id, timeout=60)

        # Should return None for Redis connection failures
        result = lock.try_acquire()
        assert result is None  # Returns None when Redis fails
        assert lock._last_refresh == 0

    def test_redis_connection_failure_on_refresh(
        self, redis_client, lock_key, owner_id
    ):
        """Test graceful handling when Redis fails during refresh."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        # Acquire normally
        assert lock.try_acquire() == owner_id

        # Replace Redis client with failing one
        lock.redis = redis.Redis(
            host="invalid_host", port=1234, socket_connect_timeout=1
        )

        # Refresh should fail gracefully
        lock._last_refresh = 0  # Force refresh
        assert lock.refresh() is False
        assert lock._last_refresh == 0

    def test_invalid_lock_parameters(self, redis_client):
        """Test validation of lock parameters."""
        owner_id = str(uuid.uuid4())

        # All parameters are now simple - no validation needed
        # Just test basic construction works
        lock = ClusterLock(redis_client, "test_key", owner_id, timeout=60)
        assert lock.key == "test_key"
        assert lock.owner_id == owner_id
        assert lock.timeout == 60

    def test_refresh_after_redis_key_deleted(self, redis_client, lock_key, owner_id):
        """Test refresh behavior when Redis key is manually deleted."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        lock.try_acquire()

        # Manually delete the key (simulates external deletion)
        redis_client.delete(lock_key)

        # Refresh should fail and mark as not acquired
        lock._last_refresh = 0  # Force refresh
        assert lock.refresh() is False
        assert lock._last_refresh == 0


class TestClusterLockDynamicRefreshInterval:
    """Dynamic refresh interval based on timeout."""

    def test_refresh_interval_calculation(self, redis_client, lock_key, owner_id):
        """Test refresh interval is calculated as max(timeout/10, 1)."""
        test_cases = [
            (5, 1),  # 5/10 = 0, but minimum is 1
            (10, 1),  # 10/10 = 1
            (30, 3),  # 30/10 = 3
            (100, 10),  # 100/10 = 10
            (200, 20),  # 200/10 = 20
            (1000, 100),  # 1000/10 = 100
        ]

        for timeout, expected_interval in test_cases:
            lock = ClusterLock(
                redis_client, f"{lock_key}_{timeout}", owner_id, timeout=timeout
            )
            lock.try_acquire()

            # Calculate expected interval using same logic as implementation
            refresh_interval = max(timeout // 10, 1)
            assert refresh_interval == expected_interval

            # Test rate limiting works with calculated interval
            assert lock.refresh() is True
            first_refresh_time = lock._last_refresh

            # Sleep less than interval - should be rate limited
            time.sleep(0.1)
            assert lock.refresh() is True
            assert lock._last_refresh == first_refresh_time  # No actual refresh


class TestClusterLockRealWorldScenarios:
    """Real-world usage patterns."""

    def test_execution_coordination_simulation(self, redis_client):
        """Simulate graph execution coordination across multiple pods."""
        graph_exec_id = str(uuid.uuid4())
        lock_key = f"execution:{graph_exec_id}"

        # Simulate 3 pods trying to execute same graph
        pods = [f"pod_{i}" for i in range(3)]
        execution_results = {}

        def execute_graph(pod_id):
            """Simulate graph execution with cluster lock."""
            lock = ClusterLock(redis_client, lock_key, pod_id, timeout=300)

            if lock.try_acquire() == pod_id:
                # Simulate execution work
                execution_results[pod_id] = "executed"
                time.sleep(0.1)
                lock.release()
            else:
                execution_results[pod_id] = "rejected"

        threads = []
        for pod_id in pods:
            thread = Thread(target=execute_graph, args=(pod_id,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # Only one pod should have executed
        executed_count = sum(
            1 for result in execution_results.values() if result == "executed"
        )
        rejected_count = sum(
            1 for result in execution_results.values() if result == "rejected"
        )

        assert executed_count == 1
        assert rejected_count == 2

    def test_long_running_execution_with_refresh(
        self, redis_client, lock_key, owner_id
    ):
        """Test lock maintains ownership during long execution with periodic refresh."""
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=30
        )  # 30 second timeout, refresh interval = max(30//10, 1) = 3 seconds

        def long_execution_with_refresh():
            """Simulate long-running execution with periodic refresh."""
            assert lock.try_acquire() == owner_id

            # Simulate 10 seconds of work with refreshes every 2 seconds
            # This respects rate limiting - actual refreshes will happen at 0s, 3s, 6s, 9s
            try:
                for i in range(5):  # 5 iterations * 2 seconds = 10 seconds total
                    time.sleep(2)
                    refresh_success = lock.refresh()
                    assert refresh_success is True, f"Refresh failed at iteration {i}"
                return "completed"
            finally:
                lock.release()

        # Should complete successfully without losing lock
        result = long_execution_with_refresh()
        assert result == "completed"

    def test_graceful_degradation_pattern(self, redis_client, lock_key):
        """Test graceful degradation when Redis becomes unavailable."""
        owner_id = str(uuid.uuid4())
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=3
        )  # Use shorter timeout

        # Normal operation
        assert lock.try_acquire() == owner_id
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is True

        # Simulate Redis becoming unavailable
        original_redis = lock.redis
        lock.redis = redis.Redis(
            host="invalid_host",
            port=1234,
            socket_connect_timeout=1,
            decode_responses=False,
        )

        # Should degrade gracefully
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is False
        assert lock._last_refresh == 0

        # Restore Redis and verify can acquire again
        lock.redis = original_redis
        # Wait for original lock to expire (use longer wait for 3s timeout)
        time.sleep(4)

        new_lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
        assert new_lock.try_acquire() == owner_id


if __name__ == "__main__":
    # Run specific test for quick validation
    pytest.main([__file__, "-v"])
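
Taken together, these tests pin down the `ClusterLock` contract: `try_acquire()` returns our owner ID on success, the current owner's ID on contention, and `None` when Redis is unreachable; `refresh()` rate-limits TTL extension to once per `max(timeout // 10, 1)` seconds but always verifies existence and ownership; `release()` deletes the key. A minimal sketch that would satisfy these tests — illustrative only, the real class lives in the `cluster_lock` module imported by the executor diff below and its internals may differ:

```python
import time


class ClusterLock:
    """Sketch of a Redis-backed, single-owner lock with TTL refresh."""

    def __init__(self, redis, key: str, owner_id: str, timeout: int):
        self.redis = redis
        self.key = key
        self.owner_id = owner_id
        self.timeout = timeout
        self._last_refresh = 0.0  # 0 means "not currently held"

    def try_acquire(self) -> str | None:
        """Return our owner_id on success, the current owner on contention,
        or None if Redis is unreachable (graceful degradation)."""
        try:
            if self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout):
                self._last_refresh = time.time()
                return self.owner_id
            current = self.redis.get(self.key)
            return current.decode("utf-8") if current else None
        except Exception:
            return None

    def refresh(self) -> bool:
        """Always verify the key still exists and is ours; only issue the
        actual EXPIRE at most once per max(timeout // 10, 1) seconds."""
        try:
            current = self.redis.get(self.key)
            if current is None or current.decode("utf-8") != self.owner_id:
                self._last_refresh = 0
                return False
            if time.time() - self._last_refresh >= max(self.timeout // 10, 1):
                self.redis.expire(self.key, self.timeout)
                self._last_refresh = time.time()
            return True
        except Exception:
            self._last_refresh = 0
            return False

    def release(self) -> None:
        """Delete the key so another owner can acquire immediately.
        A production version would likely compare-and-delete atomically
        (e.g. via a Lua script) to avoid removing a successor's lock."""
        try:
            self.redis.delete(self.key)
        finally:
            self._last_refresh = 0
```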
@@ -3,6 +3,7 @@ import logging
import os
import threading
import time
import uuid
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import asynccontextmanager
@@ -10,31 +11,11 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast

from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from redis.asyncio.lock import Lock as RedisLock

from backend.blocks.io import AgentOutputBlock
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
    AgentRunData,
    LowBalanceData,
    NotificationEventModel,
    NotificationType,
    ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.activity_status_generator import (
    generate_activity_status_for_execution,
)
from backend.executor.utils import LogMetadata
from backend.notifications.notifications import queue_notification
from backend.util.exceptions import InsufficientBalanceError, ModerationError

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerClient, DatabaseManagerAsyncClient

from prometheus_client import Gauge, start_http_server
from redis.asyncio.lock import Lock as AsyncRedisLock

from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentOutputBlock
from backend.data import redis_client as redis
from backend.data.block import (
    BlockInput,
@@ -55,12 +36,25 @@ from backend.data.execution import (
    UserContext,
)
from backend.data.graph import Link, Node
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
    AgentRunData,
    LowBalanceData,
    NotificationEventModel,
    NotificationType,
    ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.activity_status_generator import (
    generate_activity_status_for_execution,
)
from backend.executor.utils import (
    GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
    GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
    GRAPH_EXECUTION_QUEUE_NAME,
    CancelExecutionEvent,
    ExecutionOutputEntry,
    LogMetadata,
    NodeExecutionProgress,
    block_usage_cost,
    create_execution_queue_config,
@@ -69,6 +63,7 @@ from backend.executor.utils import (
    validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.server.v2.AutoMod.manager import automod_manager
from backend.util import json
from backend.util.clients import (
@@ -84,6 +79,7 @@ from backend.util.decorator import (
    error_logged,
    time_measured,
)
from backend.util.exceptions import InsufficientBalanceError, ModerationError
from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.metrics import DiscordChannel
@@ -91,6 +87,12 @@ from backend.util.process import AppProcess, set_service_name
from backend.util.retry import continuous_retry, func_retry
from backend.util.settings import Settings

from .cluster_lock import ClusterLock

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient


_logger = logging.getLogger(__name__)
logger = TruncatedLogger(_logger, prefix="[GraphExecutor]")
settings = Settings()
@@ -106,6 +108,7 @@ utilization_gauge = Gauge(
    "Ratio of active graph runs to max graph workers",
)


# Thread-local storage for ExecutionProcessor instances
_tls = threading.local()

@@ -117,10 +120,14 @@ def init_worker():


def execute_graph(
    graph_exec_entry: "GraphExecutionEntry", cancel_event: threading.Event
    graph_exec_entry: "GraphExecutionEntry",
    cancel_event: threading.Event,
    cluster_lock: ClusterLock,
):
    """Execute graph using thread-local ExecutionProcessor instance"""
    return _tls.processor.on_graph_execution(graph_exec_entry, cancel_event)
    return _tls.processor.on_graph_execution(
        graph_exec_entry, cancel_event, cluster_lock
    )


T = TypeVar("T")
@@ -429,7 +436,7 @@ class ExecutionProcessor:
            graph_id=node_exec.graph_id,
            node_eid=node_exec.node_exec_id,
            node_id=node_exec.node_id,
            block_name="-",
            block_name=b.name if (b := get_block(node_exec.block_id)) else "-",
        )
        db_client = get_db_async_client()
        node = await db_client.get_node(node_exec.node_id)
@@ -583,6 +590,7 @@ class ExecutionProcessor:
        self,
        graph_exec: GraphExecutionEntry,
        cancel: threading.Event,
        cluster_lock: ClusterLock,
    ):
        log_metadata = LogMetadata(
            logger=_logger,
@@ -641,6 +649,7 @@ class ExecutionProcessor:
            cancel=cancel,
            log_metadata=log_metadata,
            execution_stats=exec_stats,
            cluster_lock=cluster_lock,
        )
        exec_stats.walltime += timing_info.wall_time
        exec_stats.cputime += timing_info.cpu_time
@@ -742,6 +751,7 @@ class ExecutionProcessor:
        cancel: threading.Event,
        log_metadata: LogMetadata,
        execution_stats: GraphExecutionStats,
        cluster_lock: ClusterLock,
    ) -> ExecutionStatus:
        """
        Returns:
@@ -927,7 +937,7 @@ class ExecutionProcessor:
                and execution_queue.empty()
                and (running_node_execution or running_node_evaluation)
            ):
                # There is nothing to execute, and no output to process, let's relax for a while.
                cluster_lock.refresh()
                time.sleep(0.1)

            # loop done --------------------------------------------------
@@ -1219,6 +1229,7 @@ class ExecutionManager(AppProcess):
        super().__init__()
        self.pool_size = settings.config.num_graph_workers
        self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
        self.executor_id = str(uuid.uuid4())

        self._executor = None
        self._stop_consuming = None
@@ -1228,6 +1239,8 @@ class ExecutionManager(AppProcess):
        self._run_thread = None
        self._run_client = None

        self._execution_locks = {}

    @property
    def cancel_thread(self) -> threading.Thread:
        if self._cancel_thread is None:
@@ -1435,17 +1448,46 @@ class ExecutionManager(AppProcess):
        logger.info(
            f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
        )

        # Check for local duplicate execution first
        if graph_exec_id in self.active_graph_runs:
            # TODO: Make this check cluster-wide, prevent duplicate runs across executor pods.
            logger.error(
                f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
            logger.warning(
                f"[{self.service_name}] Graph {graph_exec_id} already running locally; rejecting duplicate."
            )
            _ack_message(reject=True, requeue=False)
            _ack_message(reject=True, requeue=True)
            return

        # Try to acquire cluster-wide execution lock
        cluster_lock = ClusterLock(
            redis=redis.get_redis(),
            key=f"exec_lock:{graph_exec_id}",
            owner_id=self.executor_id,
            timeout=settings.config.cluster_lock_timeout,
        )
        current_owner = cluster_lock.try_acquire()
        if current_owner != self.executor_id:
            # Either someone else has it or Redis is unavailable
            if current_owner is not None:
                logger.warning(
                    f"[{self.service_name}] Graph {graph_exec_id} already running on pod {current_owner}"
                )
            else:
                logger.warning(
                    f"[{self.service_name}] Could not acquire lock for {graph_exec_id} - Redis unavailable"
                )
            _ack_message(reject=True, requeue=True)
            return
        self._execution_locks[graph_exec_id] = cluster_lock

        logger.info(
            f"[{self.service_name}] Acquired cluster lock for {graph_exec_id} with executor {self.executor_id}"
        )

        cancel_event = threading.Event()

        future = self.executor.submit(execute_graph, graph_exec_entry, cancel_event)
        future = self.executor.submit(
            execute_graph, graph_exec_entry, cancel_event, cluster_lock
        )
        self.active_graph_runs[graph_exec_id] = (future, cancel_event)
        self._update_prompt_metrics()

@@ -1464,6 +1506,10 @@ class ExecutionManager(AppProcess):
                    f"[{self.service_name}] Error in run completion callback: {e}"
                )
            finally:
                # Release the cluster-wide execution lock
                if graph_exec_id in self._execution_locks:
                    self._execution_locks[graph_exec_id].release()
                    del self._execution_locks[graph_exec_id]
                self._cleanup_completed_runs()

        future.add_done_callback(_on_run_done)
@@ -1546,6 +1592,10 @@ class ExecutionManager(AppProcess):
                f"{prefix} ⏳ Still waiting for {len(self.active_graph_runs)} executions: {ids}"
            )

            for graph_exec_id in self.active_graph_runs:
                if lock := self._execution_locks.get(graph_exec_id):
                    lock.refresh()

            time.sleep(wait_interval)
            waited += wait_interval

@@ -1563,6 +1613,15 @@ class ExecutionManager(AppProcess):
        except Exception as e:
            logger.error(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}")

        # Release remaining execution locks
        try:
            for lock in self._execution_locks.values():
                lock.release()
            self._execution_locks.clear()
            logger.info(f"{prefix} ✅ Released execution locks")
        except Exception as e:
            logger.warning(f"{prefix} ⚠️ Failed to release all locks: {e}")

        # Disconnect the run execution consumer
        self._stop_message_consumers(
            self.run_thread,
@@ -1668,15 +1727,18 @@ def update_graph_execution_state(


@asynccontextmanager
async def synchronized(key: str, timeout: int = 60):
async def synchronized(key: str, timeout: int = settings.config.cluster_lock_timeout):
    r = await redis.get_redis_async()
    lock: RedisLock = r.lock(f"lock:{key}", timeout=timeout)
    lock: AsyncRedisLock = r.lock(f"lock:{key}", timeout=timeout)
    try:
        await lock.acquire()
        yield
    finally:
        if await lock.locked() and await lock.owned():
            await lock.release()
        try:
            await lock.release()
        except Exception as e:
            logger.warning(f"Failed to release lock for key {key}: {e}")


def increment_execution_count(user_id: str) -> int:
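
For context, the rewritten `synchronized` helper is used with `async with`; a minimal usage sketch (the key string and function are illustrative, not from the diff):

```python
# Hypothetical example: serialize a read-modify-write per user across
# coroutines and pods sharing the same Redis.
async def add_credits(user_id: str, amount: int):
    async with synchronized(f"credits:{user_id}", timeout=60):
        ...  # critical section that must not interleave
```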
@@ -151,7 +151,10 @@ class IntegrationCredentialsManager:
            fresh_credentials = await oauth_handler.refresh_tokens(credentials)
            await self.store.update_creds(user_id, fresh_credentials)
            if _lock and (await _lock.locked()) and (await _lock.owned()):
                await _lock.release()
                try:
                    await _lock.release()
                except Exception as e:
                    logger.warning(f"Failed to release OAuth refresh lock: {e}")

            credentials = fresh_credentials
        return credentials
@@ -184,7 +187,10 @@ class IntegrationCredentialsManager:
            yield
        finally:
            if (await lock.locked()) and (await lock.owned()):
                await lock.release()
                try:
                    await lock.release()
                except Exception as e:
                    logger.warning(f"Failed to release credentials lock: {e}")

    async def release_all_locks(self):
        """Call this on process termination to ensure all locks are released"""

@@ -1,12 +1,10 @@
import re
from typing import Set

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from starlette.types import ASGIApp, Message, Receive, Scope, Send


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
class SecurityHeadersMiddleware:
    """
    Middleware to add security headers to responses, with cache control
    disabled by default for all endpoints except those explicitly allowed.
@@ -25,6 +23,8 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
        "/api/health",
        "/api/v1/health",
        "/api/status",
        "/api/blocks",
        "/api/v1/blocks",
        # Public store/marketplace pages (read-only)
        "/api/store/agents",
        "/api/v1/store/agents",
@@ -49,7 +49,7 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    }

    def __init__(self, app: ASGIApp):
        super().__init__(app)
        self.app = app
        # Compile regex patterns for wildcard matching
        self.cacheable_patterns = [
            re.compile(pattern.replace("*", "[^/]+"))
@@ -72,26 +72,42 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):

        return False

    async def dispatch(self, request: Request, call_next):
        response: Response = await call_next(request)
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Pure ASGI middleware implementation for better performance than BaseHTTPMiddleware."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Add general security headers
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        response.headers["X-XSS-Protection"] = "1; mode=block"
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
        # Extract path from scope
        path = scope["path"]

        # Add noindex header for shared execution pages
        if "/public/shared" in request.url.path:
            response.headers["X-Robots-Tag"] = "noindex, nofollow"
        async def send_wrapper(message: Message) -> None:
            if message["type"] == "http.response.start":
                # Add security headers to the response
                headers = dict(message.get("headers", []))

        # Default: Disable caching for all endpoints
        # Only allow caching for explicitly permitted paths
        if not self.is_cacheable_path(request.url.path):
            response.headers["Cache-Control"] = (
                "no-store, no-cache, must-revalidate, private"
            )
            response.headers["Pragma"] = "no-cache"
            response.headers["Expires"] = "0"
                # Add general security headers (HTTP spec requires proper capitalization)
                headers[b"X-Content-Type-Options"] = b"nosniff"
                headers[b"X-Frame-Options"] = b"DENY"
                headers[b"X-XSS-Protection"] = b"1; mode=block"
                headers[b"Referrer-Policy"] = b"strict-origin-when-cross-origin"

        return response
                # Add noindex header for shared execution pages
                if "/public/shared" in path:
                    headers[b"X-Robots-Tag"] = b"noindex, nofollow"

                # Default: Disable caching for all endpoints
                # Only allow caching for explicitly permitted paths
                if not self.is_cacheable_path(path):
                    headers[b"Cache-Control"] = (
                        b"no-store, no-cache, must-revalidate, private"
                    )
                    headers[b"Pragma"] = b"no-cache"
                    headers[b"Expires"] = b"0"

                # Convert headers back to list format
                message["headers"] = list(headers.items())

            await send(message)

        await self.app(scope, receive, send_wrapper)
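
A quick way to sanity-check the pure-ASGI version is with Starlette's TestClient; this sketch assumes only FastAPI plus the middleware class above, and the `/ping` route is made up (it is not in the cacheable allowlist, so the no-store headers should apply):

```python
import fastapi
from fastapi.testclient import TestClient

app = fastapi.FastAPI()
app.add_middleware(SecurityHeadersMiddleware)


@app.get("/ping")  # hypothetical route for illustration
async def ping():
    return {"ok": True}


client = TestClient(app)
resp = client.get("/ping")
assert resp.headers["X-Frame-Options"] == "DENY"
assert "no-store" in resp.headers["Cache-Control"]
```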
@@ -1,5 +1,6 @@
import contextlib
import logging
import platform
from enum import Enum
from typing import Any, Optional

@@ -11,6 +12,7 @@ import uvicorn
from autogpt_libs.auth import add_auth_responses_to_openapi
from autogpt_libs.auth import verify_settings as verify_auth_settings
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.routing import APIRoute
from prisma.errors import PrismaError

@@ -70,6 +72,26 @@ async def lifespan_context(app: fastapi.FastAPI):

    await backend.data.db.connect()

    # Configure thread pool for FastAPI sync operation performance
    # CRITICAL: FastAPI automatically runs ALL sync functions in this thread pool:
    # - Any endpoint defined with 'def' (not async def)
    # - Any dependency function defined with 'def' (not async def)
    # - Manual run_in_threadpool() calls (like JWT decoding)
    # Default pool size is only 40 threads, causing bottlenecks under high concurrency
    config = backend.util.settings.Config()
    try:
        import anyio.to_thread

        anyio.to_thread.current_default_thread_limiter().total_tokens = (
            config.fastapi_thread_pool_size
        )
        logger.info(
            f"Thread pool size set to {config.fastapi_thread_pool_size} for sync endpoint/dependency performance"
        )
    except (ImportError, AttributeError) as e:
        logger.warning(f"Could not configure thread pool size: {e}")
        # Continue without thread pool configuration

    # Ensure SDK auto-registration is patched before initializing blocks
    from backend.sdk.registry import AutoRegistry

@@ -140,6 +162,9 @@ app = fastapi.FastAPI(

app.add_middleware(SecurityHeadersMiddleware)

# Add GZip compression middleware for large responses (like /api/blocks)
app.add_middleware(GZipMiddleware, minimum_size=50_000)  # 50KB threshold

# Add 401 responses to authenticated endpoints in OpenAPI spec
add_auth_responses_to_openapi(app)

@@ -273,12 +298,28 @@ class AgentServer(backend.util.service.AppProcess):
            allow_methods=["*"],  # Allows all methods
            allow_headers=["*"],  # Allows all headers
        )
        uvicorn.run(
            server_app,
            host=backend.util.settings.Config().agent_api_host,
            port=backend.util.settings.Config().agent_api_port,
            log_config=None,
        )
        config = backend.util.settings.Config()

        # Configure uvicorn with performance optimizations from Kludex FastAPI tips
        uvicorn_config = {
            "app": server_app,
            "host": config.agent_api_host,
            "port": config.agent_api_port,
            "log_config": None,
            # Use httptools for HTTP parsing (if available)
            "http": "httptools",
            # Only use uvloop on Unix-like systems (not supported on Windows)
            "loop": "uvloop" if platform.system() != "Windows" else "auto",
        }

        # Only add debug in local environment (not supported in all uvicorn versions)
        if config.app_env == backend.util.settings.AppEnvironment.LOCAL:
            import os

            # Enable asyncio debug mode via environment variable
            os.environ["PYTHONASYNCIODEBUG"] = "1"

        uvicorn.run(**uvicorn_config)

    def cleanup(self):
        super().cleanup()

@@ -24,6 +24,8 @@ from fastapi import (
    Security,
    UploadFile,
)
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict

@@ -85,6 +87,7 @@ from backend.server.model import (
from backend.util.clients import get_scheduler_client
from backend.util.cloud_storage import get_cloud_storage_handler
from backend.util.exceptions import GraphValidationError, NotFoundError
from backend.util.json import dumps
from backend.util.settings import Settings
from backend.util.timezone_utils import (
    convert_utc_time_to_user_timezone,
@@ -263,13 +266,10 @@ async def is_onboarding_enabled():
########################################################


@cached()
def _get_cached_blocks() -> Sequence[dict[Any, Any]]:
def _compute_blocks_sync() -> str:
    """
    Get cached blocks with thundering herd protection.

    Uses sync_cache decorator to prevent multiple concurrent requests
    from all executing the expensive block loading operation.
    Synchronous function to compute blocks data.
    This does the heavy lifting: instantiate 226+ blocks, compute costs, serialize.
    """
    from backend.data.credit import get_block_cost

@@ -279,11 +279,27 @@ def _get_cached_blocks() -> Sequence[dict[Any, Any]]:
    for block_class in block_classes.values():
        block_instance = block_class()
        if not block_instance.disabled:
            # Get costs for this specific block class without creating another instance
            costs = get_block_cost(block_instance)
            result.append({**block_instance.to_dict(), "costs": costs})
            # Convert BlockCost BaseModel objects to dictionaries for JSON serialization
            costs_dict = [
                cost.model_dump() if isinstance(cost, BaseModel) else cost
                for cost in costs
            ]
            result.append({**block_instance.to_dict(), "costs": costs_dict})

    return result
    # Use our JSON utility which properly handles complex types through to_dict conversion
    return dumps(result)


@cached()
async def _get_cached_blocks() -> str:
    """
    Async cached function with thundering herd protection.
    On cache miss: runs heavy work in thread pool
    On cache hit: returns cached string immediately (no thread pool needed)
    """
    # Only run in thread pool on cache miss - cache hits return immediately
    return await run_in_threadpool(_compute_blocks_sync)


@v1_router.get(
@@ -291,9 +307,28 @@ def _get_cached_blocks() -> Sequence[dict[Any, Any]]:
    summary="List available blocks",
    tags=["blocks"],
    dependencies=[Security(requires_user)],
    responses={
        200: {
            "description": "Successful Response",
            "content": {
                "application/json": {
                    "schema": {
                        "items": {"additionalProperties": True, "type": "object"},
                        "type": "array",
                        "title": "Response Getv1List Available Blocks",
                    }
                }
            },
        }
    },
)
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
    return _get_cached_blocks()
async def get_graph_blocks() -> Response:
    # Cache hit: returns immediately, Cache miss: runs in thread pool
    content = await _get_cached_blocks()
    return Response(
        content=content,
        media_type="application/json",
    )


@v1_router.post(

@@ -101,7 +101,9 @@ async def list_library_agents(
    try:
        library_agents = await prisma.models.LibraryAgent.prisma().find_many(
            where=where_clause,
            include=library_agent_include(user_id),
            include=library_agent_include(
                user_id, include_nodes=False, include_executions=False
            ),
            order=order_by,
            skip=(page - 1) * page_size,
            take=page_size,
@@ -185,7 +187,9 @@ async def list_favorite_library_agents(
    try:
        library_agents = await prisma.models.LibraryAgent.prisma().find_many(
            where=where_clause,
            include=library_agent_include(user_id),
            include=library_agent_include(
                user_id, include_nodes=False, include_executions=False
            ),
            order=order_by,
            skip=(page - 1) * page_size,
            take=page_size,
@@ -417,7 +421,9 @@ async def create_library_agent(
                }
            },
        ),
        include=library_agent_include(user_id),
        include=library_agent_include(
            user_id, include_nodes=False, include_executions=False
        ),
    )
    for graph_entry in graph_entries
)
@@ -642,7 +648,9 @@ async def add_store_agent_to_library(
            },
            "isCreatedByUser": False,
        },
        include=library_agent_include(user_id),
        include=library_agent_include(
            user_id, include_nodes=False, include_executions=False
        ),
    )
    logger.debug(
        f"Added graph #{graph.id} v{graph.version}"

@@ -177,7 +177,9 @@ async def test_add_agent_to_library(mocker):
            },
            "isCreatedByUser": False,
        },
        include=library_agent_include("test-user"),
        include=library_agent_include(
            "test-user", include_nodes=False, include_executions=False
        ),
    )

@@ -1,13 +1,19 @@
import json
import re
from typing import Any, Type, TypeGuard, TypeVar, overload

import jsonschema
import orjson
from fastapi.encoders import jsonable_encoder
from prisma import Json
from pydantic import BaseModel

from .type import type_match

# Precompiled regex to remove PostgreSQL-incompatible control characters
# Removes \u0000-\u0008, \u000B-\u000C, \u000E-\u001F, \u007F (keeps tab \u0009, newline \u000A, carriage return \u000D)
POSTGRES_CONTROL_CHARS = re.compile(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]")


def to_dict(data) -> dict:
    if isinstance(data, BaseModel):
@@ -15,7 +21,9 @@ def to_dict(data) -> dict:
    return jsonable_encoder(data)


def dumps(data: Any, *args: Any, **kwargs: Any) -> str:
def dumps(
    data: Any, *args: Any, indent: int | None = None, option: int = 0, **kwargs: Any
) -> str:
    """
    Serialize data to JSON string with automatic conversion of Pydantic models and complex types.

@@ -28,9 +36,13 @@ def dumps(data: Any, *args: Any, **kwargs: Any) -> str:
    data : Any
        The data to serialize. Can be any type including Pydantic models, dicts, lists, etc.
    *args : Any
        Additional positional arguments passed to json.dumps()
        Additional positional arguments
    indent : int | None
        If not None, pretty-print with indentation
    option : int
        orjson option flags (default: 0)
    **kwargs : Any
        Additional keyword arguments passed to json.dumps() (e.g., indent, separators)
        Additional keyword arguments. Supported: default, ensure_ascii, separators, indent

    Returns
    -------
@@ -45,7 +57,21 @@ def dumps(data: Any, *args: Any, **kwargs: Any) -> str:
    >>> dumps(pydantic_model_instance, indent=2)
    '{\n "field1": "value1",\n "field2": "value2"\n}'
    """
    return json.dumps(to_dict(data), *args, **kwargs)
    serializable_data = to_dict(data)

    # Handle indent parameter
    if indent is not None or kwargs.get("indent") is not None:
        option |= orjson.OPT_INDENT_2

    # orjson only accepts specific parameters, filter out stdlib json params
    # ensure_ascii: orjson always produces UTF-8 (better than ASCII)
    # separators: orjson uses compact separators by default
    supported_orjson_params = {"default"}
    orjson_kwargs = {k: v for k, v in kwargs.items() if k in supported_orjson_params}

    return orjson.dumps(serializable_data, option=option, **orjson_kwargs).decode(
        "utf-8"
    )


T = TypeVar("T")
@@ -62,9 +88,8 @@ def loads(data: str | bytes, *args, **kwargs) -> Any: ...
def loads(
    data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
) -> Any:
    if isinstance(data, bytes):
        data = data.decode("utf-8")
    parsed = json.loads(data, *args, **kwargs)
    parsed = orjson.loads(data)

    if target_type:
        return type_match(parsed, target_type)
    return parsed
@@ -99,16 +124,19 @@ def convert_pydantic_to_json(output_data: Any) -> Any:


def SafeJson(data: Any) -> Json:
    """Safely serialize data and return Prisma's Json type."""
    """
    Safely serialize data and return Prisma's Json type.
    Sanitizes null bytes to prevent PostgreSQL 22P05 errors.
    """
    if isinstance(data, BaseModel):
        return Json(
            data.model_dump(
                mode="json",
                warnings="error",
                exclude_none=True,
                fallback=lambda v: None,
            )
        json_string = data.model_dump_json(
            warnings="error",
            exclude_none=True,
            fallback=lambda v: None,
        )
        # Round-trip through JSON to ensure proper serialization with fallback for non-serializable values
        json_string = dumps(data, default=lambda v: None)
        return Json(json.loads(json_string))
    else:
        json_string = dumps(data, default=lambda v: None)

    # Remove PostgreSQL-incompatible control characters in single regex operation
    sanitized_json = POSTGRES_CONTROL_CHARS.sub("", json_string)
    return Json(json.loads(sanitized_json))
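
For reference, a tiny standalone sketch (not part of the diff) of the orjson behavior the new `dumps()` relies on: `orjson.dumps` returns UTF-8 bytes rather than a `str`, and `OPT_INDENT_2` is its only indentation option, which is why any requested `indent` collapses to two spaces:

```python
import orjson

compact = orjson.dumps({"a": 1, "b": [2, 3]}).decode("utf-8")
pretty = orjson.dumps({"a": 1}, option=orjson.OPT_INDENT_2).decode("utf-8")
print(compact)  # {"a":1,"b":[2,3]}
print(pretty)   # {
                #   "a": 1
                # }
```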
@@ -17,6 +17,37 @@ from backend.util.process import get_service_name

logger = logging.getLogger(__name__)

# Alert threshold for excessive retries
EXCESSIVE_RETRY_THRESHOLD = 50


def _send_critical_retry_alert(
    func_name: str, attempt_number: int, exception: Exception, context: str = ""
):
    """Send alert when a function is approaching the retry failure threshold."""
    try:
        # Import here to avoid circular imports
        from backend.util.clients import get_notification_manager_client

        notification_client = get_notification_manager_client()

        prefix = f"{context}: " if context else ""
        alert_msg = (
            f"🚨 CRITICAL: Operation Approaching Failure Threshold: {prefix}'{func_name}'\n\n"
            f"Current attempt: {attempt_number}/{EXCESSIVE_RETRY_THRESHOLD}\n"
            f"Error: {type(exception).__name__}: {exception}\n\n"
            f"This operation is about to fail permanently. Investigate immediately."
        )

        notification_client.discord_system_alert(alert_msg)
        logger.critical(
            f"CRITICAL ALERT SENT: Operation {func_name} at attempt {attempt_number}"
        )

    except Exception as alert_error:
        logger.error(f"Failed to send critical retry alert: {alert_error}")
        # Don't let alerting failures break the main flow


def _create_retry_callback(context: str = ""):
    """Create a retry callback with optional context."""
@@ -29,17 +60,22 @@ def _create_retry_callback(context: str = ""):
        prefix = f"{context}: " if context else ""

        if retry_state.outcome.failed and retry_state.next_action is None:
            # Final failure
            # Final failure - just log the error (alert was already sent at excessive threshold)
            logger.error(
                f"{prefix}Giving up after {attempt_number} attempts for '{func_name}': "
                f"{type(exception).__name__}: {exception}"
            )
        else:
            # Retry attempt
            logger.warning(
                f"{prefix}Retry attempt {attempt_number} for '{func_name}': "
                f"{type(exception).__name__}: {exception}"
            )
            # Retry attempt - send critical alert only once at threshold
            if attempt_number == EXCESSIVE_RETRY_THRESHOLD:
                _send_critical_retry_alert(
                    func_name, attempt_number, exception, context
                )
            else:
                logger.warning(
                    f"{prefix}Retry attempt {attempt_number} for '{func_name}': "
                    f"{type(exception).__name__}: {exception}"
                )

    return callback

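
The callback inspects a tenacity `retry_state` (`outcome`, `next_action`), so one plausible wiring — purely hypothetical; the decorator arguments and function name below are not from the diff — would pass it to tenacity's hook parameters:

```python
from tenacity import retry, stop_after_attempt, wait_exponential

@retry(
    stop=stop_after_attempt(EXCESSIVE_RETRY_THRESHOLD + 10),
    wait=wait_exponential(multiplier=1, max=30),
    after=_create_retry_callback("RabbitMQ consumer"),  # invoked with retry_state
)
def consume_messages():
    ...
```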
@@ -7,7 +7,7 @@ import os
import threading
import time
from abc import ABC, abstractmethod
from functools import cached_property, update_wrapper
from functools import update_wrapper
from typing import (
    Any,
    Awaitable,
@@ -375,6 +375,8 @@ def get_service_client(
        self.base_url = f"http://{host}:{port}".rstrip("/")
        self._connection_failure_count = 0
        self._last_client_reset = 0
        self._async_clients = {}  # None key for default async client
        self._sync_clients = {}  # For sync clients (no event loop concept)

    def _create_sync_client(self) -> httpx.Client:
        return httpx.Client(
@@ -398,13 +400,33 @@ def get_service_client(
            ),
        )

    @cached_property
    @property
    def sync_client(self) -> httpx.Client:
        return self._create_sync_client()
        """Get the sync client (thread-safe singleton)."""
        # Use service name as key for better identification
        service_name = service_client_type.get_service_type().__name__
        if client := self._sync_clients.get(service_name):
            return client
        return self._sync_clients.setdefault(
            service_name, self._create_sync_client()
        )

    @cached_property
    @property
    def async_client(self) -> httpx.AsyncClient:
        return self._create_async_client()
        """Get the appropriate async client for the current context.

        Returns per-event-loop client when in async context,
        falls back to default client otherwise.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No event loop, use None as default key
            loop = None

        if client := self._async_clients.get(loop):
            return client
        return self._async_clients.setdefault(loop, self._create_async_client())

    def _handle_connection_error(self, error: Exception) -> None:
        """Handle connection errors and implement self-healing"""
@@ -423,10 +445,8 @@ def get_service_client(

        # Clear cached clients to force recreation on next access
        # Only recreate when there's actually a problem
        if hasattr(self, "sync_client"):
            delattr(self, "sync_client")
        if hasattr(self, "async_client"):
            delattr(self, "async_client")
        self._sync_clients.clear()
        self._async_clients.clear()

        # Reset counters
        self._connection_failure_count = 0
@@ -492,28 +512,37 @@ def get_service_client(
            raise

    async def aclose(self) -> None:
        if hasattr(self, "sync_client"):
            self.sync_client.close()
        if hasattr(self, "async_client"):
            await self.async_client.aclose()
        # Close all sync clients
        for client in self._sync_clients.values():
            client.close()
        self._sync_clients.clear()

        # Close all async clients (including default with None key)
        for client in self._async_clients.values():
            await client.aclose()
        self._async_clients.clear()

    def close(self) -> None:
        if hasattr(self, "sync_client"):
            self.sync_client.close()
        # Note: Cannot close async client synchronously
        # Close all sync clients
        for client in self._sync_clients.values():
            client.close()
        self._sync_clients.clear()
        # Note: Cannot close async clients synchronously
        # They will be cleaned up by garbage collection

    def __del__(self):
        """Cleanup HTTP clients on garbage collection to prevent resource leaks."""
        try:
            if hasattr(self, "sync_client"):
                self.sync_client.close()
            if hasattr(self, "async_client"):
                # Note: Can't await in __del__, so we just close sync
                # The async client will be cleaned up by garbage collection
            # Close any remaining sync clients
            for client in self._sync_clients.values():
                client.close()

            # Warn if async clients weren't properly closed
            if self._async_clients:
                import warnings

                warnings.warn(
                    "DynamicClient async client not explicitly closed. "
                    "DynamicClient async clients not explicitly closed. "
                    "Call aclose() before destroying the client.",
                    ResourceWarning,
                    stacklevel=2,
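
Distilled to its essentials, the per-event-loop caching introduced above looks roughly like the following standalone sketch (illustrative, not the repo's code; the names are made up). Keying the cache on `asyncio.get_running_loop()` guarantees each event loop gets its own `httpx.AsyncClient`, since sharing one client across loops raises runtime errors:

```python
import asyncio
import httpx

_clients: dict[object, httpx.AsyncClient] = {}


def get_async_client() -> httpx.AsyncClient:
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None  # no running loop: fall back to a shared default client
    if (client := _clients.get(loop)) is not None:
        return client
    return _clients.setdefault(loop, httpx.AsyncClient())
```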
@@ -59,6 +59,19 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
        le=1000,
        description="Maximum number of workers to use for graph execution.",
    )

    # FastAPI Thread Pool Configuration
    # IMPORTANT: FastAPI automatically offloads ALL sync functions to a thread pool:
    # - Sync endpoint functions (def instead of async def)
    # - Sync dependency functions (def instead of async def)
    # - Manually called run_in_threadpool() operations
    # Default thread pool size is only 40, which becomes a bottleneck under high concurrency
    fastapi_thread_pool_size: int = Field(
        default=60,
        ge=40,
        le=500,
        description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
    )
    pyro_host: str = Field(
        default="localhost",
        description="The default hostname of the Pyro server.",
@@ -127,6 +140,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
        default=5 * 60,
        description="Time in seconds after which the execution stuck on QUEUED status is considered late.",
    )
    cluster_lock_timeout: int = Field(
        default=300,
        description="Cluster lock timeout in seconds for graph execution coordination.",
    )
    execution_late_notification_checkrange_secs: int = Field(
        default=60 * 60,
        description="Time in seconds for how far back to check for the late executions.",
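
Since `Config` extends pydantic's `BaseSettings`, these fields should be overridable through the environment; the variable names below assume pydantic's default field-name mapping and are not confirmed by the diff:

```bash
# Hypothetical .env overrides for the new settings
FASTAPI_THREAD_POOL_SIZE=120   # must stay within ge=40 / le=500
CLUSTER_LOCK_TIMEOUT=600       # seconds; default is 300
```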
@@ -215,3 +215,29 @@ class TestSafeJson:
        }
        result = SafeJson(data)
        assert isinstance(result, Json)

    def test_control_character_sanitization(self):
        """Test that PostgreSQL-incompatible control characters are sanitized by SafeJson."""
        # Test data with problematic control characters that would cause PostgreSQL errors
        problematic_data = {
            "null_byte": "data with \x00 null",
            "bell_char": "data with \x07 bell",
            "form_feed": "data with \x0C feed",
            "escape_char": "data with \x1B escape",
            "delete_char": "data with \x7F delete",
        }

        # SafeJson should successfully process data with control characters
        result = SafeJson(problematic_data)
        assert isinstance(result, Json)

        # Test that safe whitespace characters are preserved
        safe_data = {
            "with_tab": "text with \t tab",
            "with_newline": "text with \n newline",
            "with_carriage_return": "text with \r carriage return",
            "normal_text": "completely normal text",
        }

        safe_result = SafeJson(safe_data)
        assert isinstance(safe_result, Json)
@@ -2,9 +2,16 @@ import asyncio
import io
import logging
import time
import warnings
from typing import Optional, Tuple

import aioclamd
# Suppress the specific pkg_resources deprecation warning from aioclamd
with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore", message="pkg_resources is deprecated", category=UserWarning
    )
    import aioclamd

from pydantic import BaseModel
from pydantic_settings import BaseSettings

@@ -1,101 +0,0 @@
#!/usr/bin/env python3
"""
Clean the test database by removing all data while preserving the schema.

Usage:
    poetry run python clean_test_db.py [--yes]

Options:
    --yes    Skip confirmation prompt
"""

import asyncio
import sys

from prisma import Prisma


async def main():
    db = Prisma()
    await db.connect()

    print("=" * 60)
    print("Cleaning Test Database")
    print("=" * 60)
    print()

    # Get initial counts
    user_count = await db.user.count()
    agent_count = await db.agentgraph.count()

    print(f"Current data: {user_count} users, {agent_count} agent graphs")

    if user_count == 0 and agent_count == 0:
        print("Database is already clean!")
        await db.disconnect()
        return

    # Check for --yes flag
    skip_confirm = "--yes" in sys.argv

    if not skip_confirm:
        response = input("\nDo you want to clean all data? (yes/no): ")
        if response.lower() != "yes":
            print("Aborted.")
            await db.disconnect()
            return

    print("\nCleaning database...")

    # Delete in reverse order of dependencies
    tables = [
        ("UserNotificationBatch", db.usernotificationbatch),
        ("NotificationEvent", db.notificationevent),
        ("CreditRefundRequest", db.creditrefundrequest),
        ("StoreListingReview", db.storelistingreview),
        ("StoreListingVersion", db.storelistingversion),
        ("StoreListing", db.storelisting),
        ("AgentNodeExecutionInputOutput", db.agentnodeexecutioninputoutput),
        ("AgentNodeExecution", db.agentnodeexecution),
        ("AgentGraphExecution", db.agentgraphexecution),
        ("AgentNodeLink", db.agentnodelink),
        ("LibraryAgent", db.libraryagent),
        ("AgentPreset", db.agentpreset),
        ("IntegrationWebhook", db.integrationwebhook),
        ("AgentNode", db.agentnode),
        ("AgentGraph", db.agentgraph),
        ("AgentBlock", db.agentblock),
        ("APIKey", db.apikey),
        ("CreditTransaction", db.credittransaction),
        ("AnalyticsMetrics", db.analyticsmetrics),
        ("AnalyticsDetails", db.analyticsdetails),
        ("Profile", db.profile),
        ("UserOnboarding", db.useronboarding),
        ("User", db.user),
    ]

    for table_name, table in tables:
        try:
            count = await table.count()
            if count > 0:
                await table.delete_many()
                print(f"✓ Deleted {count} records from {table_name}")
        except Exception as e:
            print(f"⚠ Error cleaning {table_name}: {e}")

    # Refresh materialized views (they should be empty now)
    try:
        await db.execute_raw("SELECT refresh_store_materialized_views();")
        print("\n✓ Refreshed materialized views")
    except Exception as e:
        print(f"\n⚠ Could not refresh materialized views: {e}")

    await db.disconnect()

    print("\n" + "=" * 60)
    print("Database cleaned successfully!")
    print("=" * 60)


if __name__ == "__main__":
    asyncio.run(main())
@@ -8,26 +8,23 @@ Clean, streamlined load testing infrastructure for the AutoGPT Platform using k6

# 1. Set up Supabase service key (required for token generation)
export SUPABASE_SERVICE_KEY="your-supabase-service-key"

# 2. Generate pre-authenticated tokens (first time setup - creates 150+ tokens with 24-hour expiry)
node generate-tokens.js
# 2. Generate pre-authenticated tokens (first time setup - creates 160+ tokens with 24-hour expiry)
node generate-tokens.js --count=160

# 3. Set up k6 cloud credentials (for cloud testing)
export K6_CLOUD_TOKEN="your-k6-cloud-token"
# 3. Set up k6 cloud credentials (for cloud testing - see Credential Setup section below)
export K6_CLOUD_TOKEN="your-k6-cloud-token"
export K6_CLOUD_PROJECT_ID="4254406"

# 4. Verify setup and run quick test
node run-tests.js verify
# 4. Run orchestrated load tests locally
node orchestrator/orchestrator.js DEV local

# 5. Run tests locally (development/debugging)
node run-tests.js run all DEV

# 6. Run tests in k6 cloud (performance testing)
node run-tests.js cloud all DEV
# 5. Run orchestrated load tests in k6 cloud (recommended)
node orchestrator/orchestrator.js DEV cloud
```

## 📋 Unified Test Runner
## 📋 Load Test Orchestrator

The AutoGPT Platform uses a single unified test runner (`run-tests.js`) for both local and cloud execution:
The AutoGPT Platform uses a comprehensive load test orchestrator (`orchestrator/orchestrator.js`) that runs 12 optimized tests with maximum VU counts:

### Available Tests

@@ -53,45 +50,33 @@ The AutoGPT Platform uses a single unified test runner (`run-tests.js`) for both
### Test Modes

- **Local Mode**: 5 VUs × 30s - Quick validation and debugging
- **Cloud Mode**: 80-150 VUs × 3-5m - Real performance testing
- **Cloud Mode**: 80-160 VUs × 3-6m - Real performance testing

## 🛠️ Usage

### Basic Commands

```bash
# List available tests and show cloud credentials status
node run-tests.js list
# Run 12 optimized tests locally (for debugging)
node orchestrator/orchestrator.js DEV local

# Quick setup verification
node run-tests.js verify
# Run 12 optimized tests in k6 cloud (recommended for performance testing)
node orchestrator/orchestrator.js DEV cloud

# Run specific test locally
node run-tests.js run core-api-test DEV
# Run against production (coordinate with team!)
node orchestrator/orchestrator.js PROD cloud

# Run multiple tests sequentially (comma-separated)
node run-tests.js run connectivity-test,core-api-test,marketplace-public-test DEV

# Run all tests locally
node run-tests.js run all DEV

# Run specific test in k6 cloud
node run-tests.js cloud core-api-test DEV

# Run all tests in k6 cloud
node run-tests.js cloud all DEV
# Run individual test directly with k6
K6_ENVIRONMENT=DEV VUS=100 DURATION=3m k6 run tests/api/core-api-test.js
```

### NPM Scripts

```bash
# Quick verification
npm run verify
# Run orchestrator locally
npm run local

# Run all tests locally
npm test

# Run all tests in k6 cloud
# Run orchestrator in k6 cloud
npm run cloud
```

@@ -99,12 +84,12 @@ npm run cloud

### Pre-Authenticated Tokens

- **Generation**: Run `node generate-tokens.js` to create tokens
- **File**: `configs/pre-authenticated-tokens.js` (gitignored for security)
- **Capacity**: 150+ tokens supporting high-concurrency testing
- **Generation**: Run `node generate-tokens.js --count=160` to create tokens
- **File**: `configs/pre-authenticated-tokens.js` (gitignored for security)
- **Capacity**: 160+ tokens supporting high-concurrency testing
- **Expiry**: 24 hours (86400 seconds) - extended for long-duration testing
- **Benefit**: Eliminates Supabase auth rate limiting at scale
- **Regeneration**: Run `node generate-tokens.js` when tokens expire after 24 hours
- **Regeneration**: Run `node generate-tokens.js --count=160` when tokens expire after 24 hours (see the sketch below)

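Putting those pieces together, a typical token lifecycle looks roughly like this (a sketch using only the commands documented above; `configs/pre-authenticated-tokens.js` is the generated, gitignored token store):

```bash
# One-time setup: mint 160 tokens (requires SUPABASE_SERVICE_KEY in the environment)
node generate-tokens.js --count=160

# Confirm the gitignored token file was written
ls configs/pre-authenticated-tokens.js

# Run load tests within the 24-hour token window
node orchestrator/orchestrator.js DEV cloud

# After 24 hours the tokens expire - regenerate before the next run
node generate-tokens.js --count=160
```
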
### Environment Configuration

@@ -124,8 +109,8 @@ npm run cloud

### Load Testing Capabilities

- **High Concurrency**: Up to 150+ virtual users per test
- **Authentication Scaling**: Pre-auth tokens support 150+ concurrent users (10 tokens generated by default)
- **High Concurrency**: Up to 160+ virtual users per test
- **Authentication Scaling**: Pre-auth tokens support 160+ concurrent users
- **Sequential Execution**: Multiple tests run one after another with proper delays
- **Cloud Infrastructure**: Tests run on k6 cloud servers for consistent results
- **ES Module Support**: Full ES module compatibility with modern JavaScript features
@@ -134,10 +119,11 @@ npm run cloud

### Validated Performance Limits

- **Core API**: 100 VUs successfully handling `/api/credits`, `/api/graphs`, `/api/blocks`, `/api/executions`
- **Graph Execution**: 80 VUs for complete workflow pipeline
- **Marketplace Browsing**: 150 VUs for public marketplace access
- **Authentication**: 150+ concurrent users with pre-authenticated tokens
- **Core API**: 100+ VUs successfully handling `/api/credits`, `/api/graphs`, `/api/blocks`, `/api/executions`
- **Graph Execution**: 80+ VUs for complete workflow pipeline
- **Marketplace Browsing**: 160 VUs for public marketplace access (verified)
- **Marketplace Library**: 160 VUs for authenticated library operations (verified)
- **Authentication**: 160+ concurrent users with pre-authenticated tokens

### Target Metrics

@@ -146,6 +132,14 @@ npm run cloud
- **Success Rate**: Target > 95% under normal load
- **Error Rate**: Target < 5% for all endpoints

### Recent Performance Results (160 VU Test - Verified)

- **Marketplace Library Operations**: 500-1000ms response times at 160 VUs
- **Authentication**: 100% success rate with pre-authenticated tokens
- **Library Journeys**: 5 operations per journey completing successfully
- **Test Duration**: 6+ minutes of sustained load without degradation
- **k6 Cloud Execution**: Stable performance on Amazon US Columbus infrastructure

## 🔍 Troubleshooting

### Common Issues
@@ -157,8 +151,8 @@ npm run cloud
❌ Token has expired
```

- **Solution**: Run `node generate-tokens.js` to create fresh 24-hour tokens
- **Note**: Default generates 10 tokens (increase with `--count=50` for higher concurrency)
- **Solution**: Run `node generate-tokens.js --count=160` to create fresh 24-hour tokens
- **Note**: Use the `--count` parameter to generate an appropriate number of tokens for your test scale

**2. Cloud Credentials Missing**

@@ -168,7 +162,17 @@ npm run cloud

- **Solution**: Set `K6_CLOUD_TOKEN` and `K6_CLOUD_PROJECT_ID=4254406`

**3. Setup Verification Failed**
**3. k6 Cloud VU Scaling Issue**

```
❌ Test shows only 5 VUs instead of requested 100+ VUs
```

- **Problem**: Using `K6_ENVIRONMENT=DEV VUS=160 k6 cloud run test.js` (incorrect)
- **Solution**: Use `k6 cloud run --env K6_ENVIRONMENT=DEV --env VUS=160 test.js` (correct)
- **Note**: The orchestrator (`orchestrator/orchestrator.js`) already uses the correct syntax (full example below)

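As a concrete illustration of the two invocations (a sketch; the script path and values are taken from the examples in this README):

```bash
# Wrong: shell-level variables are not forwarded to the k6 cloud runners,
# so the test silently falls back to the script's default of 5 VUs
K6_ENVIRONMENT=DEV VUS=160 k6 cloud run tests/marketplace/library-access-test.js

# Right: pass every setting with --env so the cloud runners receive it
k6 cloud run --env K6_ENVIRONMENT=DEV --env VUS=160 --env DURATION=5m \
  tests/marketplace/library-access-test.js
```
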
**4. Setup Verification Failed**

```
❌ Verification failed
@@ -181,28 +185,38 @@ npm run cloud
**1. Supabase Service Key (Required for all testing):**

```bash
# Get service key from environment or Kubernetes
# Option 1: From your local environment (if available)
export SUPABASE_SERVICE_KEY="your-supabase-service-key"

# Option 2: From Kubernetes secret (for platform developers)
kubectl get secret supabase-service-key -o jsonpath='{.data.service-key}' | base64 -d

# Option 3: From Supabase dashboard
# Go to Project Settings > API > service_role key (never commit this!)
```

**2. Generate Pre-Authenticated Tokens (Required):**

```bash
# Creates 10 tokens with 24-hour expiry - prevents auth rate limiting
node generate-tokens.js
# Creates 160 tokens with 24-hour expiry - prevents auth rate limiting
node generate-tokens.js --count=160

# Generate more tokens for higher concurrency
# Generate fewer tokens for smaller tests (minimum 10)
node generate-tokens.js --count=50

# Regenerate when tokens expire (every 24 hours)
node generate-tokens.js
node generate-tokens.js --count=160
```

**3. k6 Cloud Credentials (Required for cloud testing):**

```bash
# Get from k6 cloud dashboard: https://app.k6.io/account/api-token
export K6_CLOUD_TOKEN="your-k6-cloud-token"
export K6_CLOUD_PROJECT_ID="4254406"  # AutoGPT Platform project ID

# Verify credentials work by running the orchestrator
node orchestrator/orchestrator.js DEV cloud
```

## 📂 File Structure

@@ -210,9 +224,10 @@ export K6_CLOUD_PROJECT_ID="4254406" # AutoGPT Platform project ID

```
load-tests/
├── README.md                        # This documentation
├── run-tests.js                     # Unified test runner (MAIN ENTRY POINT)
├── generate-tokens.js               # Generate pre-auth tokens
├── generate-tokens.js               # Generate pre-auth tokens (MAIN TOKEN SETUP)
├── package.json                     # Node.js dependencies and scripts
├── orchestrator/
│   └── orchestrator.js              # Main test orchestrator (MAIN ENTRY POINT)
├── configs/
│   ├── environment.js               # Environment URLs and configuration
│   └── pre-authenticated-tokens.js  # Generated tokens (gitignored)
@@ -228,21 +243,19 @@ load-tests/
│   │   └── library-access-test.js   # Authenticated marketplace/library
│   └── comprehensive/
│       └── platform-journey-test.js # Complete user journey simulation
├── orchestrator/
│   └── comprehensive-orchestrator.js # Full 25-test orchestration suite
├── results/                         # Local test results (auto-created)
├── k6-cloud-results.txt             # Cloud test URLs (auto-created)
└── *.json                           # Test output files (auto-created)
├── unified-results-*.json           # Orchestrator results (auto-created)
└── *.log                            # Test execution logs (auto-created)
```

## 🎯 Best Practices

1. **Start with Verification**: Always run `node run-tests.js verify` first
2. **Local for Development**: Use the `run` command for debugging and development
3. **Cloud for Performance**: Use the `cloud` command for actual performance testing
1. **Generate Tokens First**: Always run `node generate-tokens.js --count=160` before testing
2. **Local for Development**: Use `DEV local` for debugging and development
3. **Cloud for Performance**: Use `DEV cloud` for actual performance testing
4. **Monitor Real-Time**: Check k6 cloud dashboards during test execution
5. **Regenerate Tokens**: Refresh tokens every 24 hours when they expire
6. **Sequential Testing**: Use comma-separated tests for organized execution
6. **Unified Testing**: The orchestrator runs 12 optimized tests automatically

## 🚀 Advanced Usage

@@ -252,8 +265,8 @@ For granular control over individual test scripts:

```bash
# k6 Cloud execution (recommended for performance testing)
K6_ENVIRONMENT=DEV VUS=100 DURATION=5m \
k6 cloud run --env K6_ENVIRONMENT=DEV --env VUS=100 --env DURATION=5m tests/api/core-api-test.js
# IMPORTANT: Use --env syntax for k6 cloud to ensure proper VU scaling
k6 cloud run --env K6_ENVIRONMENT=DEV --env VUS=160 --env DURATION=5m --env RAMP_UP=30s --env RAMP_DOWN=30s tests/marketplace/library-access-test.js

# Local execution with cloud output (debugging)
K6_ENVIRONMENT=DEV VUS=10 DURATION=1m \

@@ -1,611 +0,0 @@
#!/usr/bin/env node

// AutoGPT Platform Load Test Orchestrator
// Runs comprehensive test suite locally or in k6 cloud
// Collects URLs, statistics, and generates reports

const { spawn } = require("child_process");
const fs = require("fs");
const path = require("path");

console.log("🎯 AUTOGPT PLATFORM LOAD TEST ORCHESTRATOR\n");
console.log("===========================================\n");

// Parse command line arguments
const args = process.argv.slice(2);
const environment = args[0] || "DEV"; // LOCAL, DEV, PROD
const executionMode = args[1] || "cloud"; // local, cloud
const testScale = args[2] || "full"; // small, full

console.log(`🌍 Target Environment: ${environment}`);
console.log(`🚀 Execution Mode: ${executionMode}`);
console.log(`📏 Test Scale: ${testScale}`);

// Test scenario definitions
const testScenarios = {
  // Small scale for validation (3 tests, ~5 minutes)
  small: [
    {
      name: "Basic_Connectivity_Test",
      file: "tests/basic/connectivity-test.js",
      vus: 5,
      duration: "30s",
    },
    {
      name: "Core_API_Quick_Test",
      file: "tests/api/core-api-test.js",
      vus: 10,
      duration: "1m",
    },
    {
      name: "Marketplace_Quick_Test",
      file: "tests/marketplace/public-access-test.js",
      vus: 15,
      duration: "1m",
    },
  ],

  // Full comprehensive test suite (25 tests, ~2 hours)
  full: [
    // Marketplace Viewing Tests
    {
      name: "Viewing_Marketplace_Logged_Out_Day1",
      file: "tests/marketplace/public-access-test.js",
      vus: 106,
      duration: "3m",
    },
    {
      name: "Viewing_Marketplace_Logged_Out_VeryHigh",
      file: "tests/marketplace/public-access-test.js",
      vus: 314,
      duration: "3m",
    },
    {
      name: "Viewing_Marketplace_Logged_In_Day1",
      file: "tests/marketplace/library-access-test.js",
      vus: 53,
      duration: "3m",
    },
    {
      name: "Viewing_Marketplace_Logged_In_VeryHigh",
      file: "tests/marketplace/library-access-test.js",
      vus: 157,
      duration: "3m",
    },

    // Library Management Tests
    {
      name: "Adding_Agent_to_Library_Day1",
      file: "tests/marketplace/library-access-test.js",
      vus: 32,
      duration: "3m",
    },
    {
      name: "Adding_Agent_to_Library_VeryHigh",
      file: "tests/marketplace/library-access-test.js",
      vus: 95,
      duration: "3m",
    },
    {
      name: "Viewing_Library_Home_0_Agents_Day1",
      file: "tests/marketplace/library-access-test.js",
      vus: 53,
      duration: "3m",
    },
    {
      name: "Viewing_Library_Home_0_Agents_VeryHigh",
      file: "tests/marketplace/library-access-test.js",
      vus: 157,
      duration: "3m",
    },

    // Core API Tests
    {
      name: "Core_API_Load_Test",
      file: "tests/api/core-api-test.js",
      vus: 100,
      duration: "3m",
    },
    {
      name: "Graph_Execution_Load_Test",
      file: "tests/api/graph-execution-test.js",
      vus: 100,
      duration: "3m",
    },

    // Single API Endpoint Tests
    {
      name: "Credits_API_Single_Endpoint",
      file: "tests/basic/single-endpoint-test.js",
      vus: 50,
      duration: "3m",
      env: { ENDPOINT: "credits", CONCURRENT_REQUESTS: 10 },
    },
    {
      name: "Graphs_API_Single_Endpoint",
      file: "tests/basic/single-endpoint-test.js",
      vus: 50,
      duration: "3m",
      env: { ENDPOINT: "graphs", CONCURRENT_REQUESTS: 10 },
    },
    {
      name: "Blocks_API_Single_Endpoint",
      file: "tests/basic/single-endpoint-test.js",
      vus: 50,
      duration: "3m",
      env: { ENDPOINT: "blocks", CONCURRENT_REQUESTS: 10 },
    },
    {
      name: "Executions_API_Single_Endpoint",
      file: "tests/basic/single-endpoint-test.js",
      vus: 50,
      duration: "3m",
      env: { ENDPOINT: "executions", CONCURRENT_REQUESTS: 10 },
    },

    // Comprehensive Platform Tests
    {
      name: "Comprehensive_Platform_Low",
      file: "tests/comprehensive/platform-journey-test.js",
      vus: 25,
      duration: "3m",
    },
    {
      name: "Comprehensive_Platform_Medium",
      file: "tests/comprehensive/platform-journey-test.js",
      vus: 50,
      duration: "3m",
    },
    {
      name: "Comprehensive_Platform_High",
      file: "tests/comprehensive/platform-journey-test.js",
      vus: 100,
      duration: "3m",
    },

    // User Authentication Workflows
    {
      name: "User_Auth_Workflows_Day1",
      file: "tests/basic/connectivity-test.js",
      vus: 50,
      duration: "3m",
    },
    {
      name: "User_Auth_Workflows_VeryHigh",
      file: "tests/basic/connectivity-test.js",
      vus: 100,
      duration: "3m",
    },

    // Mixed Load Tests
    {
      name: "Mixed_Load_Light",
      file: "tests/api/core-api-test.js",
      vus: 75,
      duration: "5m",
    },
    {
      name: "Mixed_Load_Heavy",
      file: "tests/marketplace/public-access-test.js",
      vus: 200,
      duration: "5m",
    },

    // Stress Tests
    {
      name: "Marketplace_Stress_Test",
      file: "tests/marketplace/public-access-test.js",
      vus: 500,
      duration: "3m",
    },
    {
      name: "Core_API_Stress_Test",
      file: "tests/api/core-api-test.js",
      vus: 300,
      duration: "3m",
    },

    // Extended Duration Tests
    {
      name: "Long_Duration_Marketplace",
      file: "tests/marketplace/library-access-test.js",
      vus: 100,
      duration: "10m",
    },
    {
      name: "Long_Duration_Core_API",
      file: "tests/api/core-api-test.js",
      vus: 100,
      duration: "10m",
    },
  ],
};

const scenarios = testScenarios[testScale];
console.log(`📊 Running ${scenarios.length} test scenarios`);

// Results collection
const results = [];
const cloudUrls = [];
const detailedMetrics = [];

// Create results directory
const timestamp = new Date()
  .toISOString()
  .replace(/[:.]/g, "-")
  .substring(0, 16);
const resultsDir = `results-${environment.toLowerCase()}-${executionMode}-${testScale}-${timestamp}`;
if (!fs.existsSync(resultsDir)) {
  fs.mkdirSync(resultsDir);
}

// Function to run a single test
function runTest(scenario, testIndex) {
  return new Promise((resolve, reject) => {
    console.log(`\n🚀 Test ${testIndex}/${scenarios.length}: ${scenario.name}`);
    console.log(
      `📊 Config: ${scenario.vus} VUs × ${scenario.duration} (${executionMode} mode)`,
    );
    console.log(`📁 Script: ${scenario.file}`);

    // Build k6 command
    let k6Command, k6Args;

    // Determine k6 binary location
    const isInPod = fs.existsSync("/app/k6-v0.54.0-linux-amd64/k6");
    const k6Binary = isInPod ? "/app/k6-v0.54.0-linux-amd64/k6" : "k6";

    // Build environment variables
    const envVars = [
      `K6_ENVIRONMENT=${environment}`,
      `VUS=${scenario.vus}`,
      `DURATION=${scenario.duration}`,
      `RAMP_UP=30s`,
      `RAMP_DOWN=30s`,
      `THRESHOLD_P95=60000`,
      `THRESHOLD_P99=60000`,
    ];

    // Add scenario-specific environment variables
    if (scenario.env) {
      Object.keys(scenario.env).forEach((key) => {
        envVars.push(`${key}=${scenario.env[key]}`);
      });
    }

    // Configure command based on execution mode
    if (executionMode === "cloud") {
      k6Command = k6Binary;
      k6Args = ["cloud", "run", scenario.file];
      // Add environment variables as --env flags
      envVars.forEach((env) => {
        k6Args.push("--env", env);
      });
    } else {
      k6Command = k6Binary;
      k6Args = ["run", scenario.file];

      // Add local output files
      const outputFile = path.join(resultsDir, `${scenario.name}.json`);
      const summaryFile = path.join(
        resultsDir,
        `${scenario.name}_summary.json`,
      );
      k6Args.push("--out", `json=${outputFile}`);
      k6Args.push("--summary-export", summaryFile);
    }

    const startTime = Date.now();
    let testUrl = "";
    let stdout = "";
    let stderr = "";

    console.log(`⏱️ Test started: ${new Date().toISOString()}`);

    // Set environment variables for spawned process
    const processEnv = { ...process.env };
    envVars.forEach((env) => {
      const [key, value] = env.split("=");
      processEnv[key] = value;
    });

    const childProcess = spawn(k6Command, k6Args, {
      env: processEnv,
      stdio: ["ignore", "pipe", "pipe"],
    });

    // Handle stdout
    childProcess.stdout.on("data", (data) => {
      const output = data.toString();
      stdout += output;

      // Extract k6 cloud URL
      if (executionMode === "cloud") {
        const urlMatch = output.match(/output:\s*(https:\/\/[^\s]+)/);
        if (urlMatch) {
          testUrl = urlMatch[1];
          console.log(`🔗 Test URL: ${testUrl}`);
        }
      }

      // Show progress indicators
      if (output.includes("Run [")) {
        const progressMatch = output.match(/Run\s+\[\s*(\d+)%\s*\]/);
        if (progressMatch) {
          process.stdout.write(`\r⏳ Progress: ${progressMatch[1]}%`);
        }
      }
    });

    // Handle stderr
    childProcess.stderr.on("data", (data) => {
      stderr += data.toString();
    });

    // Handle process completion
    childProcess.on("close", (code) => {
      const endTime = Date.now();
      const duration = Math.round((endTime - startTime) / 1000);

      console.log(`\n⏱️ Completed in ${duration}s`);

      if (code === 0) {
        console.log(`✅ ${scenario.name} SUCCESS`);

        const result = {
          test: scenario.name,
          status: "SUCCESS",
          duration: `${duration}s`,
          vus: scenario.vus,
          target_duration: scenario.duration,
          url: testUrl || "N/A",
          execution_mode: executionMode,
          environment: environment,
          completed_at: new Date().toISOString(),
        };

        results.push(result);

        if (testUrl) {
          cloudUrls.push(`${scenario.name}: ${testUrl}`);
        }

        // Store detailed output for analysis
        detailedMetrics.push({
          test: scenario.name,
          stdout_lines: stdout.split("\n").length,
          stderr_lines: stderr.split("\n").length,
          has_url: !!testUrl,
        });

        resolve(result);
      } else {
        console.error(`❌ ${scenario.name} FAILED (exit code ${code})`);

        const result = {
          test: scenario.name,
          status: "FAILED",
          error: `Exit code ${code}`,
          duration: `${duration}s`,
          vus: scenario.vus,
          execution_mode: executionMode,
          environment: environment,
          completed_at: new Date().toISOString(),
        };

        results.push(result);
        reject(new Error(`Test failed with exit code ${code}`));
      }
    });

    // Handle spawn errors
    childProcess.on("error", (error) => {
      console.error(`❌ ${scenario.name} ERROR:`, error.message);

      results.push({
        test: scenario.name,
        status: "ERROR",
        error: error.message,
        execution_mode: executionMode,
        environment: environment,
      });

      reject(error);
    });
  });
}

// Main orchestration function
async function runOrchestrator() {
  const estimatedMinutes = scenarios.length * (testScale === "small" ? 2 : 5);
  console.log(`\n🎯 Starting ${testScale} test suite on ${environment}`);
  console.log(`📈 Estimated time: ~${estimatedMinutes} minutes`);
  console.log(`🌩️ Execution: ${executionMode} mode\n`);

  const startTime = Date.now();
  let successCount = 0;
  let failureCount = 0;

  // Run tests sequentially
  for (let i = 0; i < scenarios.length; i++) {
    try {
      await runTest(scenarios[i], i + 1);
      successCount++;

      // Pause between tests (avoid overwhelming k6 cloud API)
      if (i < scenarios.length - 1) {
        const pauseSeconds = testScale === "small" ? 10 : 30;
        console.log(`\n⏸️ Pausing ${pauseSeconds}s before next test...\n`);
        await new Promise((resolve) =>
          setTimeout(resolve, pauseSeconds * 1000),
        );
      }
    } catch (error) {
      failureCount++;
      console.log(`💥 Continuing after failure...\n`);

      // Brief pause before continuing
      if (i < scenarios.length - 1) {
        await new Promise((resolve) => setTimeout(resolve, 15000));
      }
    }
  }

  const totalTime = Math.round((Date.now() - startTime) / 1000);
  await generateReports(successCount, failureCount, totalTime);
}

// Generate comprehensive reports
async function generateReports(successCount, failureCount, totalTime) {
  console.log("\n🎉 LOAD TEST ORCHESTRATOR COMPLETE\n");
  console.log("===================================\n");

  // Summary statistics
  const successRate = Math.round((successCount / scenarios.length) * 100);
  console.log("📊 EXECUTION SUMMARY:");
  console.log(
    `✅ Successful tests: ${successCount}/${scenarios.length} (${successRate}%)`,
  );
  console.log(`❌ Failed tests: ${failureCount}/${scenarios.length}`);
  console.log(`⏱️ Total execution time: ${Math.round(totalTime / 60)} minutes`);
  console.log(`🌍 Environment: ${environment}`);
  console.log(`🚀 Mode: ${executionMode}`);

  // Generate CSV report
  const csvHeaders =
    "Test Name,Status,VUs,Target Duration,Actual Duration,Environment,Mode,Test URL,Error,Completed At";
  const csvRows = results.map(
    (r) =>
      `"${r.test}","${r.status}",${r.vus},"${r.target_duration || "N/A"}","${r.duration || "N/A"}","${r.environment}","${r.execution_mode}","${r.url || "N/A"}","${r.error || "None"}","${r.completed_at || "N/A"}"`,
  );

  const csvContent = [csvHeaders, ...csvRows].join("\n");
  const csvFile = path.join(resultsDir, "orchestrator_results.csv");
  fs.writeFileSync(csvFile, csvContent);
  console.log(`\n📁 CSV Report: ${csvFile}`);

  // Generate cloud URLs file
  if (executionMode === "cloud" && cloudUrls.length > 0) {
    const urlsContent = [
      `# AutoGPT Platform Load Test URLs`,
      `# Environment: ${environment}`,
      `# Generated: ${new Date().toISOString()}`,
      `# Dashboard: https://significantgravitas.grafana.net/a/k6-app/`,
      "",
      ...cloudUrls,
      "",
      "# Direct Dashboard Access:",
      "https://significantgravitas.grafana.net/a/k6-app/",
    ].join("\n");

    const urlsFile = path.join(resultsDir, "cloud_test_urls.txt");
    fs.writeFileSync(urlsFile, urlsContent);
    console.log(`📁 Cloud URLs: ${urlsFile}`);
  }

  // Generate detailed JSON report
  const jsonReport = {
    meta: {
      orchestrator_version: "1.0",
      environment: environment,
      execution_mode: executionMode,
      test_scale: testScale,
      total_scenarios: scenarios.length,
      generated_at: new Date().toISOString(),
      results_directory: resultsDir,
    },
    summary: {
      successful_tests: successCount,
      failed_tests: failureCount,
      success_rate: `${successRate}%`,
      total_execution_time_seconds: totalTime,
      total_execution_time_minutes: Math.round(totalTime / 60),
    },
    test_results: results,
    detailed_metrics: detailedMetrics,
    cloud_urls: cloudUrls,
  };

  const jsonFile = path.join(resultsDir, "orchestrator_results.json");
  fs.writeFileSync(jsonFile, JSON.stringify(jsonReport, null, 2));
  console.log(`📁 JSON Report: ${jsonFile}`);

  // Display immediate results
  if (executionMode === "cloud" && cloudUrls.length > 0) {
    console.log("\n🔗 K6 CLOUD TEST DASHBOARD URLS:");
    console.log("================================");
    cloudUrls.slice(0, 5).forEach((url) => console.log(url));
    if (cloudUrls.length > 5) {
      console.log(`... and ${cloudUrls.length - 5} more URLs in ${urlsFile}`);
    }
    console.log(
      "\n📈 Main Dashboard: https://significantgravitas.grafana.net/a/k6-app/",
    );
  }

  console.log(`\n📂 All results saved in: ${resultsDir}/`);
  console.log("🏁 Load Test Orchestrator finished successfully!");
}

// Show usage help
function showUsage() {
  console.log("🎯 AutoGPT Platform Load Test Orchestrator\n");
  console.log(
    "Usage: node load-test-orchestrator.js [ENVIRONMENT] [MODE] [SCALE]\n",
  );
  console.log("ENVIRONMENT:");
  console.log("  LOCAL - http://localhost:8006 (local development)");
  console.log("  DEV   - https://dev-api.agpt.co (development server)");
  console.log(
    "  PROD  - https://api.agpt.co (production - coordinate with team!)\n",
  );
  console.log("MODE:");
  console.log("  local - Run locally with JSON output files");
  console.log("  cloud - Run in k6 cloud with dashboard monitoring\n");
  console.log("SCALE:");
  console.log("  small - 3 validation tests (~5 minutes)");
  console.log("  full  - 25 comprehensive tests (~2 hours)\n");
  console.log("Examples:");
  console.log("  node load-test-orchestrator.js DEV cloud small");
  console.log("  node load-test-orchestrator.js LOCAL local small");
  console.log("  node load-test-orchestrator.js DEV cloud full");
  console.log(
    "  node load-test-orchestrator.js PROD cloud full  # Coordinate with team!\n",
  );
  console.log("Requirements:");
  console.log(
    "  - Pre-authenticated tokens generated (node generate-tokens.js)",
  );
  console.log("  - k6 installed locally or run from Kubernetes pod");
  console.log("  - For cloud mode: K6_CLOUD_TOKEN and K6_CLOUD_PROJECT_ID set");
}

// Handle command line help
if (args.includes("--help") || args.includes("-h")) {
  showUsage();
  process.exit(0);
}

// Handle graceful shutdown
process.on("SIGINT", () => {
  console.log("\n🛑 Orchestrator interrupted by user");
  console.log("📊 Generating partial results...");
  generateReports(
    results.filter((r) => r.status === "SUCCESS").length,
    results.filter((r) => r.status === "FAILED").length,
    0,
  ).then(() => {
    console.log("🏃♂️ Partial results saved");
    process.exit(0);
  });
});

// Start orchestrator
if (require.main === module) {
  runOrchestrator().catch((error) => {
    console.error("💥 Orchestrator failed:", error);
    process.exit(1);
  });
}

module.exports = { runOrchestrator, testScenarios };
autogpt_platform/backend/load-tests/orchestrator/orchestrator.js (new file, 362 lines)
@@ -0,0 +1,362 @@
#!/usr/bin/env node

/**
 * AutoGPT Platform Load Test Orchestrator
 *
 * Optimized test suite with only the highest VU count for each unique test type.
 * Eliminates duplicate tests and focuses on maximum load testing.
 */

import { spawn } from 'child_process';
import fs from 'fs';

console.log("🎯 AUTOGPT PLATFORM LOAD TEST ORCHESTRATOR\n");
console.log("===========================================\n");

// Parse command line arguments
const args = process.argv.slice(2);
const environment = args[0] || "DEV"; // LOCAL, DEV, PROD
const executionMode = args[1] || "cloud"; // local, cloud

console.log(`🌍 Target Environment: ${environment}`);
console.log(`🚀 Execution Mode: ${executionMode}`);

// Unified test scenarios - only highest VUs for each unique test
const unifiedTestScenarios = [
  // 1. Marketplace Public Access (highest VUs: 314)
  {
    name: "Marketplace_Public_Access_Max_Load",
    file: "tests/marketplace/public-access-test.js",
    vus: 314,
    duration: "3m",
    rampUp: "30s",
    rampDown: "30s",
    description: "Public marketplace browsing at maximum load"
  },

  // 2. Marketplace Authenticated Access (highest VUs: 157)
  {
    name: "Marketplace_Authenticated_Access_Max_Load",
    file: "tests/marketplace/library-access-test.js",
    vus: 157,
    duration: "3m",
    rampUp: "30s",
    rampDown: "30s",
    description: "Authenticated marketplace/library operations at maximum load"
  },

  // 3. Core API Load Test (highest VUs: 100)
  {
    name: "Core_API_Max_Load",
    file: "tests/api/core-api-test.js",
    vus: 100,
    duration: "5m",
    rampUp: "1m",
    rampDown: "1m",
    description: "Core authenticated API endpoints at maximum load"
  },

  // 4. Graph Execution Load Test (highest VUs: 100)
  {
    name: "Graph_Execution_Max_Load",
    file: "tests/api/graph-execution-test.js",
    vus: 100,
    duration: "5m",
    rampUp: "1m",
    rampDown: "1m",
    description: "Graph workflow execution pipeline at maximum load"
  },

  // 5. Credits API Single Endpoint (upgraded to 100 VUs)
  {
    name: "Credits_API_Max_Load",
    file: "tests/basic/single-endpoint-test.js",
    vus: 100,
    duration: "3m",
    rampUp: "30s",
    rampDown: "30s",
    env: { ENDPOINT: "credits", CONCURRENT_REQUESTS: "1" },
    description: "Credits API endpoint at maximum load"
  },

  // 6. Graphs API Single Endpoint (upgraded to 100 VUs)
  {
    name: "Graphs_API_Max_Load",
    file: "tests/basic/single-endpoint-test.js",
    vus: 100,
    duration: "3m",
    rampUp: "30s",
    rampDown: "30s",
    env: { ENDPOINT: "graphs", CONCURRENT_REQUESTS: "1" },
    description: "Graphs API endpoint at maximum load"
  },

  // 7. Blocks API Single Endpoint (upgraded to 100 VUs)
  {
    name: "Blocks_API_Max_Load",
    file: "tests/basic/single-endpoint-test.js",
    vus: 100,
    duration: "3m",
    rampUp: "30s",
    rampDown: "30s",
    env: { ENDPOINT: "blocks", CONCURRENT_REQUESTS: "1" },
    description: "Blocks API endpoint at maximum load"
  },

  // 8. Executions API Single Endpoint (upgraded to 100 VUs)
  {
    name: "Executions_API_Max_Load",
    file: "tests/basic/single-endpoint-test.js",
    vus: 100,
    duration: "3m",
    rampUp: "30s",
    rampDown: "30s",
    env: { ENDPOINT: "executions", CONCURRENT_REQUESTS: "1" },
    description: "Executions API endpoint at maximum load"
  },

  // 9. Comprehensive Platform Journey (highest VUs: 100)
  {
    name: "Comprehensive_Platform_Max_Load",
    file: "tests/comprehensive/platform-journey-test.js",
    vus: 100,
    duration: "3m",
    rampUp: "30s",
    rampDown: "30s",
    description: "End-to-end user journey simulation at maximum load"
  },

  // 10. Marketplace Stress Test (highest VUs: 500)
  {
    name: "Marketplace_Stress_Test",
    file: "tests/marketplace/public-access-test.js",
    vus: 500,
    duration: "2m",
    rampUp: "1m",
    rampDown: "1m",
    description: "Ultimate marketplace stress test"
  },

  // 11. Core API Stress Test (highest VUs: 500)
  {
    name: "Core_API_Stress_Test",
    file: "tests/api/core-api-test.js",
    vus: 500,
    duration: "2m",
    rampUp: "1m",
    rampDown: "1m",
    description: "Ultimate core API stress test"
  },

  // 12. Long Duration Core API Test (highest VUs: 100, longest duration)
  {
    name: "Long_Duration_Core_API_Test",
    file: "tests/api/core-api-test.js",
    vus: 100,
    duration: "10m",
    rampUp: "1m",
    rampDown: "1m",
    description: "Extended duration core API endurance test"
  }
];

// Configuration
const K6_CLOUD_TOKEN = process.env.K6_CLOUD_TOKEN || '9347b8bd716cadc243e92f7d2f89107febfb81b49f2340d17da515d7b0513b51';
const K6_CLOUD_PROJECT_ID = process.env.K6_CLOUD_PROJECT_ID || '4254406';
const PAUSE_BETWEEN_TESTS = 30; // seconds

/**
 * Sleep for specified milliseconds
 */
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

/**
 * Run a single k6 test
 */
async function runTest(test, index) {
  return new Promise((resolve, reject) => {
    console.log(`\n🚀 Test ${index + 1}/${unifiedTestScenarios.length}: ${test.name}`);
    console.log(`📊 Config: ${test.vus} VUs × ${test.duration} (${executionMode} mode)`);
    console.log(`📁 Script: ${test.file}`);
    console.log(`📋 Description: ${test.description}`);
    console.log(`⏱️ Test started: ${new Date().toISOString()}`);

    const env = {
      K6_CLOUD_TOKEN,
      K6_CLOUD_PROJECT_ID,
      K6_ENVIRONMENT: environment,
      VUS: test.vus.toString(),
      DURATION: test.duration,
      RAMP_UP: test.rampUp,
      RAMP_DOWN: test.rampDown,
      ...test.env
    };

    let args;
    if (executionMode === 'cloud') {
      args = [
        'cloud', 'run',
        ...Object.entries(env).map(([key, value]) => ['--env', `${key}=${value}`]).flat(),
        test.file
      ];
    } else {
      args = [
        'run',
        ...Object.entries(env).map(([key, value]) => ['--env', `${key}=${value}`]).flat(),
        test.file
      ];
    }

    const k6Process = spawn('k6', args, {
      stdio: ['ignore', 'pipe', 'pipe'],
      env: { ...process.env, ...env }
    });

    let output = '';
    let testId = null;

    k6Process.stdout.on('data', (data) => {
      const str = data.toString();
      output += str;

      // Extract test ID from k6 cloud output
      const testIdMatch = str.match(/Test created: .*\/(\d+)/);
      if (testIdMatch) {
        testId = testIdMatch[1];
        console.log(`🔗 Test URL: https://significantgravitas.grafana.net/a/k6-app/runs/${testId}`);
      }

      // Show progress updates
      const progressMatch = str.match(/(\d+)%/);
      if (progressMatch) {
        process.stdout.write(`\r⏳ Progress: ${progressMatch[1]}%`);
      }
    });

    k6Process.stderr.on('data', (data) => {
      output += data.toString();
    });

    k6Process.on('close', (code) => {
      process.stdout.write('\n'); // Clear progress line

      if (code === 0) {
        console.log(`✅ ${test.name} SUCCESS`);
        resolve({
          success: true,
          testId,
          url: testId ? `https://significantgravitas.grafana.net/a/k6-app/runs/${testId}` : 'unknown',
          vus: test.vus,
          duration: test.duration
        });
      } else {
        console.log(`❌ ${test.name} FAILED (exit code ${code})`);
        resolve({
          success: false,
          testId,
          url: testId ? `https://significantgravitas.grafana.net/a/k6-app/runs/${testId}` : 'unknown',
          exitCode: code,
          vus: test.vus,
          duration: test.duration
        });
      }
    });

    k6Process.on('error', (error) => {
      console.log(`❌ ${test.name} ERROR: ${error.message}`);
      reject(error);
    });
  });
}

/**
 * Main execution
 */
async function main() {
  console.log(`\n📋 UNIFIED TEST PLAN`);
  console.log(`📊 Total tests: ${unifiedTestScenarios.length} (reduced from 25 original tests)`);
  console.log(`⏱️ Estimated duration: ~60 minutes\n`);

  console.log(`📋 Test Summary:`);
  unifiedTestScenarios.forEach((test, i) => {
    console.log(`  ${i + 1}. ${test.name} (${test.vus} VUs × ${test.duration})`);
  });
  console.log('');

  const results = [];

  for (let i = 0; i < unifiedTestScenarios.length; i++) {
    const test = unifiedTestScenarios[i];

    try {
      const result = await runTest(test, i);
      results.push({ ...test, ...result });

      // Pause between tests (except after the last one)
      if (i < unifiedTestScenarios.length - 1) {
        console.log(`\n⏸️ Pausing ${PAUSE_BETWEEN_TESTS}s before next test...`);
        await sleep(PAUSE_BETWEEN_TESTS * 1000);
      }
    } catch (error) {
      console.error(`💥 Fatal error running ${test.name}:`, error.message);
      results.push({ ...test, success: false, error: error.message });
    }
  }

  // Summary
  console.log('\n' + '='.repeat(60));
  console.log('🏁 UNIFIED LOAD TEST RESULTS SUMMARY');
  console.log('='.repeat(60));

  const successful = results.filter(r => r.success);
  const failed = results.filter(r => !r.success);

  console.log(`✅ Successful tests: ${successful.length}/${results.length} (${Math.round(successful.length / results.length * 100)}%)`);
  console.log(`❌ Failed tests: ${failed.length}/${results.length}`);

  if (successful.length > 0) {
    console.log('\n✅ SUCCESSFUL TESTS:');
    successful.forEach(test => {
      console.log(`  • ${test.name} (${test.vus} VUs) - ${test.url}`);
    });
  }

  if (failed.length > 0) {
    console.log('\n❌ FAILED TESTS:');
    failed.forEach(test => {
      console.log(`  • ${test.name} (${test.vus} VUs) - ${test.url || 'no URL'} (exit: ${test.exitCode || 'unknown'})`);
    });
  }

  // Calculate total VU-minutes tested
  const totalVuMinutes = results.reduce((sum, test) => {
    const minutes = parseFloat(test.duration.replace(/[ms]/g, ''));
    const multiplier = test.duration.includes('m') ? 1 : (1 / 60); // convert seconds to minutes
    return sum + (test.vus * minutes * multiplier);
  }, 0);

  console.log(`\n📊 LOAD TESTING SUMMARY:`);
  console.log(`  • Total VU-minutes tested: ${Math.round(totalVuMinutes)}`);
  console.log(`  • Peak concurrent VUs: ${Math.max(...results.map(r => r.vus))}`);
  console.log(`  • Average test duration: ${(results.reduce((sum, r) => sum + parseFloat(r.duration.replace(/[ms]/g, '')), 0) / results.length).toFixed(1)}${results[0].duration.includes('m') ? 'm' : 's'}`);

  // Write results to file
  const timestamp = Math.floor(Date.now() / 1000);
  const resultsFile = `unified-results-${timestamp}.json`;
  fs.writeFileSync(resultsFile, JSON.stringify(results, null, 2));
  console.log(`\n📄 Detailed results saved to: ${resultsFile}`);

  console.log(`\n🎉 UNIFIED LOAD TEST ORCHESTRATOR COMPLETE\n`);

  process.exit(failed.length === 0 ? 0 : 1);
}

// Run if called directly
if (process.argv[1] === new URL(import.meta.url).pathname) {
  main().catch(error => {
    console.error('💥 Fatal error:', error);
    process.exit(1);
  });
}
@@ -1,268 +0,0 @@
#!/usr/bin/env node
/**
 * Unified Load Test Runner
 *
 * Supports both local execution and k6 cloud execution with the same interface.
 * Automatically detects cloud credentials and provides seamless switching.
 *
 * Usage:
 *   node run-tests.js verify                  # Quick verification (1 VU, 10s)
 *   node run-tests.js run core-api-test DEV   # Run specific test locally
 *   node run-tests.js run all DEV             # Run all tests locally
 *   node run-tests.js cloud core-api DEV      # Run specific test in k6 cloud
 *   node run-tests.js cloud all DEV           # Run all tests in k6 cloud
 */

import { execSync } from "child_process";
import fs from "fs";

const TESTS = {
  "connectivity-test": {
    script: "tests/basic/connectivity-test.js",
    description: "Basic connectivity validation",
    cloudConfig: { vus: 10, duration: "2m" },
  },
  "single-endpoint-test": {
    script: "tests/basic/single-endpoint-test.js",
    description: "Individual API endpoint testing",
    cloudConfig: { vus: 25, duration: "3m" },
  },
  "core-api-test": {
    script: "tests/api/core-api-test.js",
    description: "Core API endpoints performance test",
    cloudConfig: { vus: 100, duration: "5m" },
  },
  "graph-execution-test": {
    script: "tests/api/graph-execution-test.js",
    description: "Graph creation and execution pipeline test",
    cloudConfig: { vus: 80, duration: "5m" },
  },
  "marketplace-public-test": {
    script: "tests/marketplace/public-access-test.js",
    description: "Public marketplace browsing test",
    cloudConfig: { vus: 150, duration: "3m" },
  },
  "marketplace-library-test": {
    script: "tests/marketplace/library-access-test.js",
    description: "Authenticated marketplace/library test",
    cloudConfig: { vus: 100, duration: "4m" },
  },
  "comprehensive-test": {
    script: "tests/comprehensive/platform-journey-test.js",
    description: "Complete user journey simulation",
    cloudConfig: { vus: 50, duration: "6m" },
  },
};

function checkCloudCredentials() {
  const token = process.env.K6_CLOUD_TOKEN;
  const projectId = process.env.K6_CLOUD_PROJECT_ID;

  if (!token || !projectId) {
    console.log("❌ Missing k6 cloud credentials");
    console.log("Set: K6_CLOUD_TOKEN and K6_CLOUD_PROJECT_ID");
    return false;
  }
  return true;
}

function verifySetup() {
  console.log("🔍 Quick Setup Verification");

  // Check tokens
  if (!fs.existsSync("configs/pre-authenticated-tokens.js")) {
    console.log("❌ No tokens found. Run: node generate-tokens.js");
    return false;
  }

  // Quick test
  try {
    execSync(
      "K6_ENVIRONMENT=DEV VUS=1 DURATION=10s k6 run tests/basic/connectivity-test.js --quiet",
      { stdio: "inherit", cwd: process.cwd() },
    );
    console.log("✅ Verification successful");
    return true;
  } catch (error) {
    console.log("❌ Verification failed");
    return false;
  }
}

function runLocalTest(testName, environment) {
  const test = TESTS[testName];
  if (!test) {
    console.log(`❌ Unknown test: ${testName}`);
    console.log("Available tests:", Object.keys(TESTS).join(", "));
    return;
  }

  console.log(`🚀 Running ${test.description} locally on ${environment}`);

  try {
    const cmd = `K6_ENVIRONMENT=${environment} VUS=5 DURATION=30s k6 run ${test.script}`;
    execSync(cmd, { stdio: "inherit", cwd: process.cwd() });
    console.log("✅ Test completed");
  } catch (error) {
    console.log("❌ Test failed");
  }
}

function runCloudTest(testName, environment) {
  const test = TESTS[testName];
  if (!test) {
    console.log(`❌ Unknown test: ${testName}`);
    console.log("Available tests:", Object.keys(TESTS).join(", "));
    return;
  }

  const { vus, duration } = test.cloudConfig;
  console.log(`☁️ Running ${test.description} in k6 cloud`);
  console.log(`   Environment: ${environment}`);
  console.log(`   Config: ${vus} VUs × ${duration}`);

  try {
    const cmd = `k6 cloud run --env K6_ENVIRONMENT=${environment} --env VUS=${vus} --env DURATION=${duration} --env RAMP_UP=30s --env RAMP_DOWN=30s ${test.script}`;
    const output = execSync(cmd, {
      stdio: "pipe",
      cwd: process.cwd(),
      encoding: "utf8",
    });

    // Extract and display URL
    const urlMatch = output.match(/https:\/\/[^\s]*grafana[^\s]*/);
    if (urlMatch) {
      const url = urlMatch[0];
      console.log(`🔗 Test URL: ${url}`);

      // Save to results file
      const timestamp = new Date().toISOString();
      const result = `${timestamp} - ${testName}: ${url}\n`;
      fs.appendFileSync("k6-cloud-results.txt", result);
    }

    console.log("✅ Cloud test started successfully");
  } catch (error) {
    console.log("❌ Cloud test failed to start");
    console.log(error.message);
  }
}

function runAllLocalTests(environment) {
  console.log(`🚀 Running all tests locally on ${environment}`);

  for (const [testName, test] of Object.entries(TESTS)) {
    console.log(`\n📊 ${test.description}`);
    runLocalTest(testName, environment);
  }
}

function runAllCloudTests(environment) {
  console.log(`☁️ Running all tests in k6 cloud on ${environment}`);

  const testNames = Object.keys(TESTS);
  for (let i = 0; i < testNames.length; i++) {
    const testName = testNames[i];
    console.log(`\n📊 Test ${i + 1}/${testNames.length}: ${testName}`);

    runCloudTest(testName, environment);

    // Brief pause between cloud tests (except last one)
    if (i < testNames.length - 1) {
      console.log("⏸️ Waiting 2 minutes before next cloud test...");
      execSync("sleep 120");
    }
  }
}

function listTests() {
  console.log("📋 Available Tests:");
  console.log("==================");

  Object.entries(TESTS).forEach(([name, test]) => {
    const { vus, duration } = test.cloudConfig;
    console.log(`  ${name.padEnd(20)} - ${test.description}`);
    console.log(`  ${" ".repeat(20)}   Cloud: ${vus} VUs × ${duration}`);
  });

  console.log("\n🌍 Available Environments: LOCAL, DEV, PROD");
  console.log("\n💡 Examples:");
  console.log("  # Local execution (5 VUs, 30s)");
  console.log("  node run-tests.js verify");
  console.log("  node run-tests.js run core-api-test DEV");
  console.log("  node run-tests.js run core-api-test,marketplace-test DEV");
  console.log("  node run-tests.js run all DEV");
  console.log("");
  console.log("  # Cloud execution (high VUs, longer duration)");
  console.log("  node run-tests.js cloud core-api DEV");
  console.log("  node run-tests.js cloud all DEV");

  const hasCloudCreds = checkCloudCredentials();
  console.log(
    `\n☁️ Cloud Status: ${hasCloudCreds ? "✅ Configured" : "❌ Missing credentials"}`,
  );
}

function runSequentialTests(testNames, environment, isCloud = false) {
  const tests = testNames.split(",").map((t) => t.trim());
  const mode = isCloud ? "cloud" : "local";
  console.log(
    `🚀 Running ${tests.length} tests sequentially in ${mode} mode on ${environment}`,
  );

  for (let i = 0; i < tests.length; i++) {
    const testName = tests[i];
    console.log(`\n📊 Test ${i + 1}/${tests.length}: ${testName}`);

    if (isCloud) {
      runCloudTest(testName, environment);
    } else {
      runLocalTest(testName, environment);
    }

    // Brief pause between tests (except last one)
    if (i < tests.length - 1) {
      const pauseTime = isCloud ? "2 minutes" : "10 seconds";
      const pauseCmd = isCloud ? "sleep 120" : "sleep 10";
      console.log(`⏸️ Waiting ${pauseTime} before next test...`);
      if (!isCloud) {
        // Note: In real implementation, would use setTimeout/sleep for local tests
      }
    }
  }
}

// Main CLI
const [, , command, testOrEnv, environment] = process.argv;

switch (command) {
  case "verify":
    verifySetup();
    break;
  case "list":
    listTests();
    break;
  case "run":
    if (testOrEnv === "all") {
      runAllLocalTests(environment || "DEV");
    } else if (testOrEnv?.includes(",")) {
      runSequentialTests(testOrEnv, environment || "DEV", false);
    } else {
      runLocalTest(testOrEnv, environment || "DEV");
    }
    break;
  case "cloud":
    if (!checkCloudCredentials()) {
      process.exit(1);
    }
    if (testOrEnv === "all") {
      runAllCloudTests(environment || "DEV");
    } else if (testOrEnv?.includes(",")) {
      runSequentialTests(testOrEnv, environment || "DEV", true);
    } else {
      runCloudTest(testOrEnv, environment || "DEV");
    }
    break;
  default:
    listTests();
}
autogpt_platform/backend/poetry.lock (generated, 95 lines)
@@ -3451,6 +3451,99 @@ files = [
importlib-metadata = ">=6.0,<8.8.0"
typing-extensions = ">=4.5.0"

[[package]]
name = "orjson"
version = "3.11.3"
description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
    {file = "orjson-3.11.3-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:29cb1f1b008d936803e2da3d7cba726fc47232c45df531b29edf0b232dd737e7"},
    {file = "orjson-3.11.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:97dceed87ed9139884a55db8722428e27bd8452817fbf1869c58b49fecab1120"},
    {file = "orjson-3.11.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:58533f9e8266cb0ac298e259ed7b4d42ed3fa0b78ce76860626164de49e0d467"},
    {file = "orjson-3.11.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c212cfdd90512fe722fa9bd620de4d46cda691415be86b2e02243242ae81873"},
    {file = "orjson-3.11.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5ff835b5d3e67d9207343effb03760c00335f8b5285bfceefd4dc967b0e48f6a"},
    {file = "orjson-3.11.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f5aa4682912a450c2db89cbd92d356fef47e115dffba07992555542f344d301b"},
    {file = "orjson-3.11.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7d18dd34ea2e860553a579df02041845dee0af8985dff7f8661306f95504ddf"},
    {file = "orjson-3.11.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d8b11701bc43be92ea42bd454910437b355dfb63696c06fe953ffb40b5f763b4"},
    {file = "orjson-3.11.3-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:90368277087d4af32d38bd55f9da2ff466d25325bf6167c8f382d8ee40cb2bbc"},
    {file = "orjson-3.11.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fd7ff459fb393358d3a155d25b275c60b07a2c83dcd7ea962b1923f5a1134569"},
    {file = "orjson-3.11.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f8d902867b699bcd09c176a280b1acdab57f924489033e53d0afe79817da37e6"},
    {file = "orjson-3.11.3-cp310-cp310-win32.whl", hash = "sha256:bb93562146120bb51e6b154962d3dadc678ed0fce96513fa6bc06599bb6f6edc"},
    {file = "orjson-3.11.3-cp310-cp310-win_amd64.whl", hash = "sha256:976c6f1975032cc327161c65d4194c549f2589d88b105a5e3499429a54479770"},
    {file = "orjson-3.11.3-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9d2ae0cc6aeb669633e0124531f342a17d8e97ea999e42f12a5ad4adaa304c5f"},
    {file = "orjson-3.11.3-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:ba21dbb2493e9c653eaffdc38819b004b7b1b246fb77bfc93dc016fe664eac91"},
    {file = "orjson-3.11.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00f1a271e56d511d1569937c0447d7dce5a99a33ea0dec76673706360a051904"},
    {file = "orjson-3.11.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b67e71e47caa6680d1b6f075a396d04fa6ca8ca09aafb428731da9b3ea32a5a6"},
    {file = "orjson-3.11.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d7d012ebddffcce8c85734a6d9e5f08180cd3857c5f5a3ac70185b43775d043d"},
    {file = "orjson-3.11.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dd759f75d6b8d1b62012b7f5ef9461d03c804f94d539a5515b454ba3a6588038"},
    {file = "orjson-3.11.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6890ace0809627b0dff19cfad92d69d0fa3f089d3e359a2a532507bb6ba34efb"},
    {file = "orjson-3.11.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9d4a5e041ae435b815e568537755773d05dac031fee6a57b4ba70897a44d9d2"},
    {file = "orjson-3.11.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2d68bf97a771836687107abfca089743885fb664b90138d8761cce61d5625d55"},
    {file = "orjson-3.11.3-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:bfc27516ec46f4520b18ef645864cee168d2a027dbf32c5537cb1f3e3c22dac1"},
    {file = "orjson-3.11.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f66b001332a017d7945e177e282a40b6997056394e3ed7ddb41fb1813b83e824"},
    {file = "orjson-3.11.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:212e67806525d2561efbfe9e799633b17eb668b8964abed6b5319b2f1cfbae1f"},
    {file = "orjson-3.11.3-cp311-cp311-win32.whl", hash = "sha256:6e8e0c3b85575a32f2ffa59de455f85ce002b8bdc0662d6b9c2ed6d80ab5d204"},
    {file = "orjson-3.11.3-cp311-cp311-win_amd64.whl", hash = "sha256:6be2f1b5d3dc99a5ce5ce162fc741c22ba9f3443d3dd586e6a1211b7bc87bc7b"},
    {file = "orjson-3.11.3-cp311-cp311-win_arm64.whl", hash = "sha256:fafb1a99d740523d964b15c8db4eabbfc86ff29f84898262bf6e3e4c9e97e43e"},
    {file = "orjson-3.11.3-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:8c752089db84333e36d754c4baf19c0e1437012242048439c7e80eb0e6426e3b"},
    {file = "orjson-3.11.3-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:9b8761b6cf04a856eb544acdd82fc594b978f12ac3602d6374a7edb9d86fd2c2"},
    {file = "orjson-3.11.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b13974dc8ac6ba22feaa867fc19135a3e01a134b4f7c9c28162fed4d615008a"},
    {file = "orjson-3.11.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f83abab5bacb76d9c821fd5c07728ff224ed0e52d7a71b7b3de822f3df04e15c"},
    {file = "orjson-3.11.3-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e6fbaf48a744b94091a56c62897b27c31ee2da93d826aa5b207131a1e13d4064"},
    {file = "orjson-3.11.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bc779b4f4bba2847d0d2940081a7b6f7b5877e05408ffbb74fa1faf4a136c424"},
    {file = "orjson-3.11.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bd4b909ce4c50faa2192da6bb684d9848d4510b736b0611b6ab4020ea6fd2d23"},
    {file = "orjson-3.11.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:524b765ad888dc5518bbce12c77c2e83dee1ed6b0992c1790cc5fb49bb4b6667"},
    {file = "orjson-3.11.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:84fd82870b97ae3cdcea9d8746e592b6d40e1e4d4527835fc520c588d2ded04f"},
    {file = "orjson-3.11.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:fbecb9709111be913ae6879b07bafd4b0785b44c1eb5cac8ac76da048b3885a1"},
    {file = "orjson-3.11.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9dba358d55aee552bd868de348f4736ca5a4086d9a62e2bfbbeeb5629fe8b0cc"},
    {file = "orjson-3.11.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eabcf2e84f1d7105f84580e03012270c7e97ecb1fb1618bda395061b2a84a049"},
    {file = "orjson-3.11.3-cp312-cp312-win32.whl", hash = "sha256:3782d2c60b8116772aea8d9b7905221437fdf53e7277282e8d8b07c220f96cca"},
    {file = "orjson-3.11.3-cp312-cp312-win_amd64.whl", hash = "sha256:79b44319268af2eaa3e315b92298de9a0067ade6e6003ddaef72f8e0bedb94f1"},
|
||||
{file = "orjson-3.11.3-cp312-cp312-win_arm64.whl", hash = "sha256:0e92a4e83341ef79d835ca21b8bd13e27c859e4e9e4d7b63defc6e58462a3710"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:af40c6612fd2a4b00de648aa26d18186cd1322330bd3a3cc52f87c699e995810"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:9f1587f26c235894c09e8b5b7636a38091a9e6e7fe4531937534749c04face43"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:61dcdad16da5bb486d7227a37a2e789c429397793a6955227cedbd7252eb5a27"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:11c6d71478e2cbea0a709e8a06365fa63da81da6498a53e4c4f065881d21ae8f"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ff94112e0098470b665cb0ed06efb187154b63649403b8d5e9aedeb482b4548c"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae8b756575aaa2a855a75192f356bbda11a89169830e1439cfb1a3e1a6dde7be"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c9416cc19a349c167ef76135b2fe40d03cea93680428efee8771f3e9fb66079d"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b822caf5b9752bc6f246eb08124c3d12bf2175b66ab74bac2ef3bbf9221ce1b2"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:414f71e3bdd5573893bf5ecdf35c32b213ed20aa15536fe2f588f946c318824f"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:828e3149ad8815dc14468f36ab2a4b819237c155ee1370341b91ea4c8672d2ee"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ac9e05f25627ffc714c21f8dfe3a579445a5c392a9c8ae7ba1d0e9fb5333f56e"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e44fbe4000bd321d9f3b648ae46e0196d21577cf66ae684a96ff90b1f7c93633"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-win32.whl", hash = "sha256:2039b7847ba3eec1f5886e75e6763a16e18c68a63efc4b029ddf994821e2e66b"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-win_amd64.whl", hash = "sha256:29be5ac4164aa8bdcba5fa0700a3c9c316b411d8ed9d39ef8a882541bd452fae"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-win_arm64.whl", hash = "sha256:18bd1435cb1f2857ceb59cfb7de6f92593ef7b831ccd1b9bfb28ca530e539dce"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:cf4b81227ec86935568c7edd78352a92e97af8da7bd70bdfdaa0d2e0011a1ab4"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:bc8bc85b81b6ac9fc4dae393a8c159b817f4c2c9dee5d12b773bddb3b95fc07e"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-manylinux_2_34_aarch64.whl", hash = "sha256:88dcfc514cfd1b0de038443c7b3e6a9797ffb1b3674ef1fd14f701a13397f82d"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-manylinux_2_34_x86_64.whl", hash = "sha256:d61cd543d69715d5fc0a690c7c6f8dcc307bc23abef9738957981885f5f38229"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:2b7b153ed90ababadbef5c3eb39549f9476890d339cf47af563aea7e07db2451"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:7909ae2460f5f494fecbcd10613beafe40381fd0316e35d6acb5f3a05bfda167"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:2030c01cbf77bc67bee7eef1e7e31ecf28649353987775e3583062c752da0077"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a0169ebd1cbd94b26c7a7ad282cf5c2744fce054133f959e02eb5265deae1872"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-win32.whl", hash = "sha256:0c6d7328c200c349e3a4c6d8c83e0a5ad029bdc2d417f234152bf34842d0fc8d"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-win_amd64.whl", hash = "sha256:317bbe2c069bbc757b1a2e4105b64aacd3bc78279b66a6b9e51e846e4809f804"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-win_arm64.whl", hash = "sha256:e8f6a7a27d7b7bec81bd5924163e9af03d49bbb63013f107b48eb5d16db711bc"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:56afaf1e9b02302ba636151cfc49929c1bb66b98794291afd0e5f20fecaf757c"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:913f629adef31d2d350d41c051ce7e33cf0fd06a5d1cb28d49b1899b23b903aa"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e0a23b41f8f98b4e61150a03f83e4f0d566880fe53519d445a962929a4d21045"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3d721fee37380a44f9d9ce6c701b3960239f4fb3d5ceea7f31cbd43882edaa2f"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:73b92a5b69f31b1a58c0c7e31080aeaec49c6e01b9522e71ff38d08f15aa56de"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d2489b241c19582b3f1430cc5d732caefc1aaf378d97e7fb95b9e56bed11725f"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5189a5dab8b0312eadaf9d58d3049b6a52c454256493a557405e77a3d67ab7f"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:9d8787bdfbb65a85ea76d0e96a3b1bed7bf0fbcb16d40408dc1172ad784a49d2"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:8e531abd745f51f8035e207e75e049553a86823d189a51809c078412cefb399a"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:8ab962931015f170b97a3dd7bd933399c1bae8ed8ad0fb2a7151a5654b6941c7"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:124d5ba71fee9c9902c4a7baa9425e663f7f0aecf73d31d54fe3dd357d62c1a7"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-win32.whl", hash = "sha256:22724d80ee5a815a44fc76274bb7ba2e7464f5564aacb6ecddaa9970a83e3225"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-win_amd64.whl", hash = "sha256:215c595c792a87d4407cb72dd5e0f6ee8e694ceeb7f9102b533c5a9bf2a916bb"},
|
||||
{file = "orjson-3.11.3.tar.gz", hash = "sha256:1c0603b1d2ffcd43a411d64797a19556ef76958aef1c182f22dc30860152a98a"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "24.2"
|
||||
@@ -7159,4 +7252,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "2c7e9370f500039b99868376021627c5a120e0ee31c5c5e6de39db2c3d82f414"
|
||||
content-hash = "b2363edeebb91f410039c8d4b563f683c1edb0cf4bda4f3e6c287040e93639bc"
|
||||
|
||||
@@ -38,6 +38,7 @@ mem0ai = "^0.1.115"
|
||||
moviepy = "^2.1.2"
|
||||
ollama = "^0.5.1"
|
||||
openai = "^1.97.1"
|
||||
orjson = "^3.10.0"
|
||||
pika = "^1.3.2"
|
||||
pinecone = "^7.3.0"
|
||||
poetry = "2.1.1" # CHECK DEPENDABOT SUPPORT BEFORE UPGRADING
|
||||
|
||||
@@ -1,110 +0,0 @@
#!/usr/bin/env python3
"""
Run test data creation and update scripts in sequence.

Usage:
    poetry run python run_test_data.py
"""

import asyncio
import subprocess
import sys
from pathlib import Path


def run_command(cmd: list[str], cwd: Path | None = None) -> bool:
    """Run a command and return True if successful."""
    try:
        result = subprocess.run(
            cmd, check=True, capture_output=True, text=True, cwd=cwd
        )
        if result.stdout:
            print(result.stdout)
        return True
    except subprocess.CalledProcessError as e:
        print(f"Error running command: {' '.join(cmd)}")
        print(f"Error: {e.stderr}")
        return False


async def main():
    """Main function to run test data scripts."""
    print("=" * 60)
    print("Running Test Data Scripts for AutoGPT Platform")
    print("=" * 60)
    print()

    # Get the backend directory
    backend_dir = Path(__file__).parent
    test_dir = backend_dir / "test"

    # Check if we're in the right directory
    if not (backend_dir / "pyproject.toml").exists():
        print("ERROR: This script must be run from the backend directory")
        sys.exit(1)

    print("1. Checking database connection...")
    print("-" * 40)

    # Import here to ensure proper environment setup
    try:
        from prisma import Prisma

        db = Prisma()
        await db.connect()
        print("✓ Database connection successful")
        await db.disconnect()
    except Exception as e:
        print(f"✗ Database connection failed: {e}")
        print("\nPlease ensure:")
        print("1. The database services are running (docker compose up -d)")
        print("2. The DATABASE_URL in .env is correct")
        print("3. Migrations have been run (poetry run prisma migrate deploy)")
        sys.exit(1)

    print()
    print("2. Running test data creator...")
    print("-" * 40)

    # Run test_data_creator.py
    if run_command(["poetry", "run", "python", "test_data_creator.py"], cwd=test_dir):
        print()
        print("✅ Test data created successfully!")

        print()
        print("3. Running test data updater...")
        print("-" * 40)

        # Run test_data_updater.py
        if run_command(
            ["poetry", "run", "python", "test_data_updater.py"], cwd=test_dir
        ):
            print()
            print("✅ Test data updated successfully!")
        else:
            print()
            print("❌ Test data updater failed!")
            sys.exit(1)
    else:
        print()
        print("❌ Test data creator failed!")
        sys.exit(1)

    print()
    print("=" * 60)
    print("Test data setup completed successfully!")
    print("=" * 60)
    print()
    print("The materialized views have been populated with test data:")
    print("- mv_agent_run_counts: Agent execution statistics")
    print("- mv_review_stats: Store listing review statistics")
    print()
    print("You can now:")
    print("1. Run tests: poetry run test")
    print("2. Start the backend: poetry run serve")
    print("3. View data in the database")
    print()


if __name__ == "__main__":
    asyncio.run(main())
@@ -48,7 +48,10 @@
"@radix-ui/react-tabs": "1.1.13",
"@radix-ui/react-toast": "1.2.15",
"@radix-ui/react-tooltip": "1.2.8",
"@sentry/nextjs": "9.42.0",
"@rjsf/core": "5.24.13",
"@rjsf/utils": "5.24.13",
"@rjsf/validator-ajv8": "5.24.13",
"@sentry/nextjs": "10.15.0",
"@supabase/ssr": "0.6.1",
"@supabase/supabase-js": "2.55.0",
"@tanstack/react-query": "5.85.3",
@@ -103,7 +106,8 @@
"tailwindcss-animate": "1.0.7",
"uuid": "11.1.0",
"vaul": "1.1.2",
"zod": "3.25.76"
"zod": "3.25.76",
"zustand": "5.0.8"
},
"devDependencies": {
"@chromatic-com/storybook": "4.1.1",

1142 autogpt_platform/frontend/pnpm-lock.yaml (generated)
File diff suppressed because it is too large
@@ -0,0 +1,31 @@
"use client";

import { Tabs, TabsList, TabsTrigger } from "@/components/__legacy__/ui/tabs";

export type BuilderView = "old" | "new";

export function BuilderViewTabs({
  value,
  onChange,
}: {
  value: BuilderView;
  onChange: (value: BuilderView) => void;
}) {
  return (
    <div className="pointer-events-auto fixed right-4 top-20 z-50">
      <Tabs
        value={value}
        onValueChange={(v: string) => onChange(v as BuilderView)}
      >
        <TabsList className="w-fit bg-zinc-900">
          <TabsTrigger value="old" className="text-gray-100">
            Old
          </TabsTrigger>
          <TabsTrigger value="new" className="text-gray-100">
            New
          </TabsTrigger>
        </TabsList>
      </Tabs>
    </div>
  );
}
@@ -0,0 +1,44 @@
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { usePathname, useRouter, useSearchParams } from "next/navigation";
import { useEffect, useMemo } from "react";
import { BuilderView } from "./BuilderViewTabs";

export function useBuilderView() {
  const isNewFlowEditorEnabled = useGetFlag(Flag.NEW_FLOW_EDITOR);
  const isBuilderViewSwitchEnabled = useGetFlag(Flag.BUILDER_VIEW_SWITCH);

  const router = useRouter();
  const pathname = usePathname();
  const searchParams = useSearchParams();

  const currentView = searchParams.get("view");
  const defaultView = "old";
  const selectedView = useMemo<BuilderView>(() => {
    if (currentView === "new" || currentView === "old") return currentView;
    return defaultView;
  }, [currentView, defaultView]);

  useEffect(() => {
    if (isBuilderViewSwitchEnabled === true) {
      if (currentView !== "new" && currentView !== "old") {
        const params = new URLSearchParams(searchParams);
        params.set("view", defaultView);
        router.replace(`${pathname}?${params.toString()}`);
      }
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [isBuilderViewSwitchEnabled, defaultView, pathname, router, searchParams]);

  const setSelectedView = (value: BuilderView) => {
    const params = new URLSearchParams(searchParams);
    params.set("view", value);
    router.push(`${pathname}?${params.toString()}`);
  };

  return {
    isSwitchEnabled: isBuilderViewSwitchEnabled === true,
    selectedView,
    setSelectedView,
    isNewFlowEditorEnabled: Boolean(isNewFlowEditorEnabled),
  } as const;
}
@@ -0,0 +1,44 @@
import { ReactFlow, Background, Controls } from "@xyflow/react";
import { useNodeStore } from "../../stores/nodeStore";

import NewControlPanel from "../NewBlockMenu/NewControlPanel/NewControlPanel";
import { useShallow } from "zustand/react/shallow";
import { useMemo } from "react";
import { CustomNode } from "./nodes/CustomNode";
import { useCustomEdge } from "./edges/useCustomEdge";
import CustomEdge from "./edges/CustomEdge";
import { RightSidebar } from "../RIghtSidebar";

export const Flow = () => {
  // Node state and change handling come from the node store
  const nodes = useNodeStore(useShallow((state) => state.nodes));
  const onNodesChange = useNodeStore(
    useShallow((state) => state.onNodesChange),
  );
  const nodeTypes = useMemo(() => ({ custom: CustomNode }), []);
  const { edges, onConnect, onEdgesChange } = useCustomEdge();

  return (
    <div className="flex h-full w-full dark:bg-slate-900">
      {/* Builder area - flexible width */}
      <div className="relative flex-1">
        <ReactFlow
          nodes={nodes}
          onNodesChange={onNodesChange}
          nodeTypes={nodeTypes}
          edges={edges}
          onConnect={onConnect}
          onEdgesChange={onEdgesChange}
          edgeTypes={{ custom: CustomEdge }}
        >
          <Background />
          <Controls />
          <NewControlPanel />
        </ReactFlow>
      </div>
      <div className="w-[30%]">
        <RightSidebar />
      </div>
    </div>
  );
};
@@ -0,0 +1,88 @@
import { ArrayFieldTemplateItemType, RJSFSchema } from "@rjsf/utils";
import { generateHandleId, HandleIdType } from "../../handlers/helpers";
import { ArrayEditorContext } from "./ArrayEditorContext";
import { Button } from "@/components/atoms/Button/Button";
import { PlusIcon, XIcon } from "@phosphor-icons/react";
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";

export interface ArrayEditorProps {
  items?: ArrayFieldTemplateItemType<any, RJSFSchema, any>[];
  nodeId: string;
  canAdd: boolean | undefined;
  onAddClick?: () => void;
  disabled: boolean | undefined;
  readonly: boolean | undefined;
  id: string;
}

export const ArrayEditor = ({
  items,
  nodeId,
  canAdd,
  onAddClick,
  disabled,
  readonly,
  id,
}: ArrayEditorProps) => {
  const { isInputConnected } = useEdgeStore();

  return (
    <div className="space-y-2">
      <div className="flex items-center gap-2">
        <div className="flex-1">
          {items?.map((element) => {
            const fieldKey = generateHandleId(
              id,
              [element.index.toString()],
              HandleIdType.ARRAY,
            );
            const isConnected = isInputConnected(nodeId, fieldKey);
            return (
              <div
                key={element.key}
                className="-ml-2 flex max-w-[400px] items-center gap-2"
              >
                <ArrayEditorContext.Provider
                  value={{
                    isArrayItem: true,
                    fieldKey,
                    isConnected,
                  }}
                >
                  {element.children}
                </ArrayEditorContext.Provider>

                {element.hasRemove &&
                  !readonly &&
                  !disabled &&
                  !isConnected && (
                    <Button
                      type="button"
                      variant="secondary"
                      className="relative top-5"
                      size="small"
                      onClick={element.onDropIndexClick(element.index)}
                    >
                      <XIcon className="h-4 w-4" />
                    </Button>
                  )}
              </div>
            );
          })}
        </div>
      </div>

      {canAdd && !readonly && !disabled && (
        <Button
          type="button"
          size="small"
          onClick={onAddClick}
          className="w-full"
        >
          <PlusIcon className="mr-2 h-4 w-4" />
          Add Item
        </Button>
      )}
    </div>
  );
};
@@ -0,0 +1,11 @@
import { createContext } from "react";

export const ArrayEditorContext = createContext<{
  isArrayItem: boolean;
  fieldKey: string;
  isConnected: boolean;
}>({
  isArrayItem: false,
  fieldKey: "",
  isConnected: false,
});
@@ -0,0 +1,166 @@
"use client";

import React from "react";
import { Plus, X } from "lucide-react";
import { Text } from "@/components/atoms/Text/Text";
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/atoms/Input/Input";
import NodeHandle from "../../handlers/NodeHandle";
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import { generateHandleId, HandleIdType } from "../../handlers/helpers";

export interface ObjectEditorProps {
  id: string;
  value?: Record<string, any>;
  onChange?: (value: Record<string, any>) => void;
  placeholder?: string;
  disabled?: boolean;
  className?: string;
  nodeId: string;
  fieldKey: string;
}

export const ObjectEditor = React.forwardRef<HTMLDivElement, ObjectEditorProps>(
  (
    {
      id,
      value = {},
      onChange,
      placeholder = "Enter value",
      disabled = false,
      className,
      nodeId,
      fieldKey,
    },
    ref,
  ) => {
    const setProperty = (key: string, propertyValue: any) => {
      if (!onChange) return;

      const newData: Record<string, any> = { ...value };
      if (propertyValue === undefined || propertyValue === "") {
        delete newData[key];
      } else {
        newData[key] = propertyValue;
      }
      onChange(newData);
    };

    const addProperty = () => {
      if (!onChange) return;
      onChange({ ...value, [""]: "" });
    };

    const removeProperty = (key: string) => {
      if (!onChange) return;
      const newData = { ...value };
      delete newData[key];
      onChange(newData);
    };

    const updateKey = (oldKey: string, newKey: string) => {
      if (!onChange || oldKey === newKey) return;

      const propertyValue = value[oldKey];
      const newData: Record<string, any> = { ...value };
      delete newData[oldKey];
      newData[newKey] = propertyValue;
      onChange(newData);
    };

    const hasEmptyKeys = Object.keys(value).some((key) => key.trim() === "");

    const { isInputConnected } = useEdgeStore();

    return (
      <div
        ref={ref}
        className={`flex flex-col gap-2 ${className || ""}`}
        id={id}
      >
        {Object.entries(value).map(([key, propertyValue], idx) => {
          const dynamicHandleId = generateHandleId(
            fieldKey,
            [key],
            HandleIdType.KEY_VALUE,
          );
          const isDynamicPropertyConnected = isInputConnected(
            nodeId,
            dynamicHandleId,
          );

          return (
            <div key={idx} className="flex flex-col gap-2">
              <div className="-ml-2 flex items-center gap-1">
                <NodeHandle
                  id={dynamicHandleId}
                  isConnected={isDynamicPropertyConnected}
                  side="left"
                />

                <Text variant="small" className="text-gray-600">
                  #{key.trim() === "" ? "" : key}
                </Text>
                <Text variant="small" className="!text-green-500">
                  (string)
                </Text>
              </div>
              {!isDynamicPropertyConnected && (
                <div className="flex items-center gap-2">
                  <Input
                    hideLabel={true}
                    label=""
                    id={`key-${idx}`}
                    size="small"
                    value={key}
                    onChange={(e) => updateKey(key, e.target.value)}
                    placeholder="Key"
                    wrapperClassName="mb-0"
                    disabled={disabled}
                  />
                  <Input
                    hideLabel={true}
                    label=""
                    id={`value-${idx}`}
                    size="small"
                    value={propertyValue as string}
                    onChange={(e) => setProperty(key, e.target.value)}
                    placeholder={placeholder}
                    wrapperClassName="mb-0"
                    disabled={disabled}
                  />
                  <Button
                    type="button"
                    variant="secondary"
                    size="small"
                    onClick={() => removeProperty(key)}
                    disabled={disabled}
                  >
                    <X className="h-4 w-4" />
                  </Button>
                </div>
              )}
            </div>
          );
        })}

        <Button
          type="button"
          size="small"
          onClick={addProperty}
          className="w-full"
          disabled={hasEmptyKeys || disabled}
        >
          <Plus className="mr-2 h-4 w-4" />
          Add Property
        </Button>
      </div>
    );
  },
);

ObjectEditor.displayName = "ObjectEditor";
@@ -0,0 +1,580 @@
# Form Creator System

The Form Creator is a dynamic form generation system built on React JSON Schema Form (RJSF) that automatically creates interactive forms based on JSON schemas. It is the core component that powers input handling in the FlowEditor.

## Table of Contents

- [Architecture Overview](#architecture-overview)
- [How It Works](#how-it-works)
- [Schema Processing](#schema-processing)
- [Widget System](#widget-system)
- [Field System](#field-system)
- [Template System](#template-system)
- [Customization Guide](#customization-guide)
- [Advanced Features](#advanced-features)

## Architecture Overview

The Form Creator system consists of several interconnected layers:

```
FormCreator
├── Schema Preprocessing
│   └── input-schema-pre-processor.ts
├── Widget System
│   ├── TextInputWidget
│   ├── SelectWidget
│   ├── SwitchWidget
│   └── ... (other widgets)
├── Field System
│   ├── AnyOfField
│   ├── ObjectField
│   └── CredentialsField
├── Template System
│   ├── FieldTemplate
│   └── ArrayFieldTemplate
└── UI Schema
    └── uiSchema.ts
```

## How It Works

### 1. **Schema Input**

The FormCreator receives a JSON schema that defines the structure of the form:

```typescript
const schema = {
  type: "object",
  properties: {
    message: {
      type: "string",
      title: "Message",
      description: "Enter your message",
    },
    count: {
      type: "number",
      title: "Count",
      minimum: 0,
    },
  },
};
```

### 2. **Schema Preprocessing**

The schema is preprocessed to ensure all properties have proper types:

```typescript
// Before preprocessing
{
  "properties": {
    "name": { "title": "Name" } // No type defined
  }
}

// After preprocessing
// If a property declares no type, it is treated as accepting any type
{
  "properties": {
    "name": {
      "title": "Name",
      "anyOf": [
        { "type": "string" },
        { "type": "number" },
        { "type": "boolean" },
        { "type": "array", "items": { "type": "string" } },
        { "type": "object" },
        { "type": "null" }
      ]
    }
  }
}
```

### 3. **Widget Mapping**

Schema types are mapped to appropriate input widgets:

```typescript
// Schema type -> Widget mapping
"string"  -> TextInputWidget
"number"  -> TextInputWidget (with number type)
"boolean" -> SwitchWidget
"array"   -> ArrayFieldTemplate
"object"  -> ObjectField
"enum"    -> SelectWidget
```
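The mapping lives in a small helper used by the widgets below. A minimal sketch of what `mapJsonSchemaTypeToInputType` might look like (the `InputType` names and branch order here are assumptions inferred from the widget code in this document, not the literal implementation):

```typescript
import { RJSFSchema } from "@rjsf/utils";

// Hypothetical sketch -- the real helper and enum live in the FlowEditor's
// helpers; names beyond those used by the widgets below are assumptions.
enum InputType {
  TEXT_AREA = "TEXT_AREA",
  PASSWORD = "PASSWORD",
  NUMBER = "NUMBER",
  SWITCH = "SWITCH",
  SELECT = "SELECT",
  MULTI_SELECT = "MULTI_SELECT",
}

function mapJsonSchemaTypeToInputType(schema: RJSFSchema): InputType {
  // Enum values render as a select; an array of enum items as a multi-select
  if (schema.enum) return InputType.SELECT;
  if (schema.type === "array" && (schema.items as RJSFSchema | undefined)?.enum)
    return InputType.MULTI_SELECT;
  if (schema.type === "boolean") return InputType.SWITCH;
  if (schema.type === "number" || schema.type === "integer")
    return InputType.NUMBER;
  // Secret strings are masked
  if (schema.format === "password") return InputType.PASSWORD;
  // Everything else falls back to a plain text input
  return InputType.TEXT_AREA;
}
```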
### 4. **Form Rendering**

RJSF renders the form using the mapped widgets and templates:

```typescript
<Form
  schema={preprocessedSchema}
  validator={validator}
  fields={fields}
  templates={templates}
  widgets={widgets}
  formContext={{ nodeId }}
  onChange={handleChange}
  uiSchema={uiSchema}
/>
```

## Schema Processing

### Input Schema Preprocessor

The `preprocessInputSchema` function ensures all properties have proper types:

```typescript
export function preprocessInputSchema(schema: RJSFSchema): RJSFSchema {
  const processedSchema = { ...schema };
  // Recursively processes properties
  if (processedSchema.properties) {
    for (const [key, property] of Object.entries(processedSchema.properties)) {
      const processedProperty = property as RJSFSchema;
      // Add a type if none exists
      if (
        !processedProperty.type &&
        !processedProperty.anyOf &&
        !processedProperty.oneOf &&
        !processedProperty.allOf
      ) {
        processedProperty.anyOf = [
          { type: "string" },
          { type: "number" },
          { type: "integer" },
          { type: "boolean" },
          { type: "array", items: { type: "string" } },
          { type: "object" },
          { type: "null" },
        ];
      }
    }
  }
}
```

### Key Features

1. **Type Safety**: Ensures all properties have types
2. **Recursive Processing**: Handles nested objects and arrays
3. **Array Item Processing**: Processes array item schemas (see the sketch below)
4. **Schema Cleanup**: Removes titles and descriptions from the root schema
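A minimal sketch of the recursive pass, assuming the same "untyped fields accept any type" rule as the excerpt above (the helper name `processProperty` is illustrative, not the actual source):

```typescript
import { RJSFSchema } from "@rjsf/utils";

// Illustrative sketch of the recursion -- not the literal implementation.
const ANY_TYPE = [
  { type: "string" },
  { type: "number" },
  { type: "integer" },
  { type: "boolean" },
  { type: "array", items: { type: "string" } },
  { type: "object" },
  { type: "null" },
] as RJSFSchema["anyOf"];

function processProperty(property: RJSFSchema): RJSFSchema {
  const processed: RJSFSchema = { ...property };

  // Untyped field with no composition keywords: accept any type
  if (!processed.type && !processed.anyOf && !processed.oneOf && !processed.allOf) {
    processed.anyOf = ANY_TYPE;
  }

  // Recurse into nested object properties
  if (processed.properties) {
    processed.properties = Object.fromEntries(
      Object.entries(processed.properties).map(([key, child]) => [
        key,
        typeof child === "object" ? processProperty(child) : child,
      ]),
    );
  }

  // Recurse into array item schemas
  if (
    processed.items &&
    typeof processed.items === "object" &&
    !Array.isArray(processed.items)
  ) {
    processed.items = processProperty(processed.items);
  }

  return processed;
}
```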
## Widget System

Widgets are the actual input components that users interact with.

### Available Widgets

#### TextInputWidget

Handles text, number, password, and textarea inputs:

```typescript
export const TextInputWidget = (props: WidgetProps) => {
  const { schema } = props;
  const type = mapJsonSchemaTypeToInputType(schema);

  const inputConfig = {
    [InputType.TEXT_AREA]: {
      htmlType: "textarea",
      placeholder: "Enter text...",
      handleChange: (v: string) => (v === "" ? undefined : v),
    },
    [InputType.PASSWORD]: {
      htmlType: "password",
      placeholder: "Enter secret text...",
      handleChange: (v: string) => (v === "" ? undefined : v),
    },
    [InputType.NUMBER]: {
      htmlType: "number",
      placeholder: "Enter number value...",
      handleChange: (v: string) => (v === "" ? undefined : Number(v)),
    },
  };

  const config = inputConfig[type];
  return <Input {...config} />;
};
```

#### SelectWidget

Handles dropdown and multi-select inputs:

```typescript
export const SelectWidget = (props: WidgetProps) => {
  const { options, value, onChange, schema } = props;
  const enumOptions = options.enumOptions || [];
  const type = mapJsonSchemaTypeToInputType(schema);

  if (type === InputType.MULTI_SELECT) {
    return <MultiSelector values={value} onValuesChange={onChange} />;
  }

  return <Select value={value} onValueChange={onChange} options={enumOptions} />;
};
```

#### SwitchWidget

Handles boolean toggles:

```typescript
export function SwitchWidget(props: WidgetProps) {
  const { value = false, onChange, disabled, readonly } = props;

  return (
    <Switch
      checked={Boolean(value)}
      onCheckedChange={(checked) => onChange(checked)}
      disabled={disabled || readonly}
    />
  );
}
```

### Widget Registration

Widgets are registered in the widgets registry:

```typescript
export const widgets: RegistryWidgetsType = {
  TextWidget: TextInputWidget,
  SelectWidget: SelectWidget,
  CheckboxWidget: SwitchWidget,
  FileWidget: FileWidget,
  DateWidget: DateInputWidget,
  TimeWidget: TimeInputWidget,
  DateTimeWidget: DateTimeInputWidget,
};
```
## Field System

Fields handle complex data structures and provide custom rendering logic.

### AnyOfField

Handles union types and nullable fields:

```typescript
export const AnyOfField = ({ schema, formData, onChange, ...props }: FieldProps) => {
  const { isNullableType, selectedType, handleTypeChange, currentTypeOption } =
    useAnyOfField(schema, formData, onChange);

  if (isNullableType) {
    return (
      <div>
        <NodeHandle id={fieldKey} isConnected={isConnected} side="left" />
        <Switch checked={isEnabled} onCheckedChange={handleNullableToggle} />
        {isEnabled && renderInput(nonNull)}
      </div>
    );
  }

  return (
    <div>
      <NodeHandle id={fieldKey} isConnected={isConnected} side="left" />
      <Select value={selectedType} onValueChange={handleTypeChange} />
      {renderInput(currentTypeOption)}
    </div>
  );
};
```

### ObjectField

Handles free-form object editing:

```typescript
export const ObjectField = (props: FieldProps) => {
  const { schema, formData = {}, onChange, name, idSchema, formContext } = props;

  // Use default field for fixed-schema objects
  if (idSchema?.$id === "root" || !isFreeForm) {
    return <DefaultObjectField {...props} />;
  }

  // Use custom ObjectEditor for free-form objects
  return (
    <ObjectEditor
      id={`${name}-input`}
      nodeId={nodeId}
      fieldKey={fieldKey}
      value={formData}
      onChange={onChange}
    />
  );
};
```

### Field Registration

Fields are registered in the fields registry:

```typescript
export const fields: RegistryFieldsType = {
  AnyOfField: AnyOfField,
  credentials: CredentialsField,
  ObjectField: ObjectField,
};
```
## Template System

Templates provide custom rendering for form structure elements.

### FieldTemplate

Custom field wrapper with connection handles:

```typescript
const FieldTemplate: React.FC<FieldTemplateProps> = ({
  id, label, required, description, children, schema, formContext, uiSchema
}) => {
  const { isInputConnected } = useEdgeStore();
  const { nodeId } = formContext;

  const fieldKey = generateHandleId(id);
  const isConnected = isInputConnected(nodeId, fieldKey);

  return (
    <div className="mt-4 w-[400px] space-y-1">
      {label && schema.type && (
        <label htmlFor={id} className="flex items-center gap-1">
          <NodeHandle id={fieldKey} isConnected={isConnected} side="left" />
          <Text variant="body">{label}</Text>
          <Text variant="small" className={colorClass}>({displayType})</Text>
          {required && <span style={{ color: "red" }}>*</span>}
        </label>
      )}
      {!isConnected && <div className="pl-2">{children}</div>}
    </div>
  );
};
```

### ArrayFieldTemplate

Custom array editing interface:

```typescript
function ArrayFieldTemplate(props: ArrayFieldTemplateProps) {
  const { items, canAdd, onAddClick, disabled, readonly, formContext, idSchema } = props;
  const { nodeId } = formContext;

  return (
    <ArrayEditor
      items={items}
      nodeId={nodeId}
      canAdd={canAdd}
      onAddClick={onAddClick}
      disabled={disabled}
      readonly={readonly}
      id={idSchema.$id}
    />
  );
}
```
## Customization Guide

### Adding a Custom Widget

1. **Create the Widget Component**:

```typescript
import { WidgetProps } from "@rjsf/utils";

export const MyCustomWidget = (props: WidgetProps) => {
  const { value, onChange, schema, disabled, readonly } = props;

  return (
    <div>
      <input
        value={value || ""}
        onChange={(e) => onChange(e.target.value)}
        disabled={disabled || readonly}
        placeholder={schema.placeholder}
      />
    </div>
  );
};
```

2. **Register the Widget**:

```typescript
// In widgets/index.ts
export const widgets: RegistryWidgetsType = {
  // ... existing widgets
  MyCustomWidget: MyCustomWidget,
};
```

3. **Use in Schema**:

```typescript
const schema = {
  type: "object",
  properties: {
    myField: {
      type: "string",
      "ui:widget": "MyCustomWidget",
    },
  },
};
```

### Adding a Custom Field

1. **Create the Field Component**:

```typescript
import { FieldProps } from "@rjsf/utils";

export const MyCustomField = (props: FieldProps) => {
  const { schema, formData, onChange, name, idSchema, formContext } = props;

  return (
    <div>
      {/* Custom field implementation */}
    </div>
  );
};
```

2. **Register the Field**:

```typescript
// In fields/index.ts
export const fields: RegistryFieldsType = {
  // ... existing fields
  MyCustomField: MyCustomField,
};
```

3. **Use in Schema**:

```typescript
const schema = {
  type: "object",
  properties: {
    myField: {
      type: "string",
      "ui:field": "MyCustomField",
    },
  },
};
```

### Customizing Templates

1. **Create Custom Template**:

```typescript
const MyCustomFieldTemplate: React.FC<FieldTemplateProps> = (props) => {
  return (
    <div className="my-custom-field">
      {/* Custom template implementation */}
    </div>
  );
};
```

2. **Register Template**:

```typescript
// In templates/index.ts
export const templates = {
  FieldTemplate: MyCustomFieldTemplate,
  // ... other templates
};
```
## Advanced Features

### Connection State Management

The Form Creator integrates with the edge store to show/hide input fields based on connection state:

```typescript
const FieldTemplate = ({ id, children, formContext }) => {
  const { isInputConnected } = useEdgeStore();
  const { nodeId } = formContext;

  const fieldKey = generateHandleId(id);
  const isConnected = isInputConnected(nodeId, fieldKey);

  return (
    <div>
      <NodeHandle id={fieldKey} isConnected={isConnected} side="left" />
      {!isConnected && children}
    </div>
  );
};
```

### Advanced Mode

Fields can be hidden/shown based on advanced mode:

```typescript
const FieldTemplate = ({ schema, formContext }) => {
  const { nodeId } = formContext;
  const showAdvanced = useNodeStore(
    (state) => state.nodeAdvancedStates[nodeId] || false
  );

  if (!showAdvanced && schema.advanced === true) {
    return null;
  }

  return <div>{/* field content */}</div>;
};
```
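A block input opts into this by carrying an `advanced` flag on its schema. A minimal example (shown here as a plain extra keyword; how the flag arrives from the backend block schema is assumed):

```typescript
const schema = {
  type: "object",
  properties: {
    prompt: { type: "string", title: "Prompt" },
    temperature: {
      type: "number",
      title: "Temperature",
      advanced: true, // hidden until the node's Advanced switch is on
    },
  },
};
```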
### Array Item Context

Array items have special context for connection handling:

```typescript
const ArrayEditor = ({ items, nodeId }) => {
  return (
    <div>
      {items?.map((element) => {
        const fieldKey = generateHandleId(
          id,
          [element.index.toString()],
          HandleIdType.ARRAY,
        );
        const isConnected = isInputConnected(nodeId, fieldKey);

        return (
          <ArrayEditorContext.Provider
            value={{ isArrayItem: true, fieldKey, isConnected }}
          >
            {element.children}
          </ArrayEditorContext.Provider>
        );
      })}
    </div>
  );
};
```

### Handle ID Generation

Handle IDs are generated based on field structure. Note that nested, array, and key-value IDs require an explicit `HandleIdType`; with the default `SIMPLE` type, nested values are ignored and only the main key is returned:

```typescript
// Simple field
generateHandleId("message"); // "message"

// Nested field
generateHandleId("config", ["api_key"], HandleIdType.NESTED); // "config.api_key"

// Array item
generateHandleId("items", ["0"], HandleIdType.ARRAY); // "items_$_0"

// Key-value pair
generateHandleId("headers", ["Authorization"], HandleIdType.KEY_VALUE); // "headers_#_Authorization"
```
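`parseHandleId` (defined next to `generateHandleId` in `handlers/helpers.ts`) inverts this encoding, so a handle ID arriving on a connection can be decomposed again:

```typescript
parseHandleId("headers_#_Authorization");
// => { mainKey: "headers", nestedValues: ["Authorization"], type: HandleIdType.KEY_VALUE }

parseHandleId("items_$_0");
// => { mainKey: "items", nestedValues: ["0"], type: HandleIdType.ARRAY }

parseHandleId("config.api_key");
// => { mainKey: "config", nestedValues: ["api_key"], type: HandleIdType.NESTED }

parseHandleId("message");
// => { mainKey: "message", nestedValues: [], type: HandleIdType.SIMPLE }
```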
@@ -0,0 +1,159 @@
# FlowEditor Component

The FlowEditor is a powerful visual flow builder component built on top of React Flow that allows users to create, connect, and manage nodes in a visual workflow. It provides a comprehensive form system with dynamic input handling, connection management, and advanced features.

## Table of Contents

- [Architecture Overview](#architecture-overview)
- [Store Management](#store-management)

## Architecture Overview

The FlowEditor follows a modular architecture with clear separation of concerns:

```
FlowEditor/
├── Flow.tsx              # Main component
├── nodes/                # Node-related components
│   ├── CustomNode.tsx    # Main node component
│   ├── FormCreator.tsx   # Dynamic form generator
│   ├── fields/           # Custom field components
│   ├── widgets/          # Custom input widgets
│   ├── templates/        # RJSF templates
│   └── helpers.ts        # Utility functions
├── edges/                # Edge-related components
│   ├── CustomEdge.tsx    # Custom edge component
│   ├── useCustomEdge.ts  # Edge management hook
│   └── helpers.ts        # Edge utilities
├── handlers/             # Connection handles
│   ├── NodeHandle.tsx    # Connection handle component
│   └── helpers.ts        # Handle utilities
├── components/           # Shared components
│   ├── ArrayEditor/      # Array editing components
│   └── ObjectEditor/     # Object editing components
└── processors/           # Data processors
    └── input-schema-pre-processor.ts
```

## Store Management

The FlowEditor uses Zustand for state management with two main stores:

### NodeStore (`useNodeStore`)

Manages all node-related state and operations.

**Key Features:**

- Node CRUD operations
- Advanced state management per node
- Form data persistence
- Node counter for unique IDs

**Usage:**

```typescript
import { useNodeStore } from "../stores/nodeStore";

// Get nodes
const nodes = useNodeStore(useShallow((state) => state.nodes));

// Add a new node
const addNode = useNodeStore((state) => state.addNode);

// Update node data
const updateNodeData = useNodeStore((state) => state.updateNodeData);

// Toggle advanced mode
const setShowAdvanced = useNodeStore((state) => state.setShowAdvanced);
```

**Store Methods:**

- `setNodes(nodes)` - Replace all nodes
- `addNode(node)` - Add a single node
- `addBlock(blockInfo)` - Add node from block info
- `updateNodeData(nodeId, data)` - Update node data
- `onNodesChange(changes)` - Handle node changes from React Flow
- `setShowAdvanced(nodeId, show)` - Toggle advanced mode
- `incrementNodeCounter()` - Get next node ID
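Putting the listed methods together, the store's state is roughly the following shape (a sketch inferred from this list and from usage elsewhere in this change, not the literal store definition):

```typescript
import { Node, NodeChange } from "@xyflow/react";

// Inferred sketch of the NodeStore -- field and parameter types are
// assumptions based on the documented methods above.
interface NodeStoreState {
  nodes: Node[];
  // Per-node "Advanced" toggle state, keyed by node ID
  nodeAdvancedStates: Record<string, boolean>;
  nodeCounter: number;

  setNodes: (nodes: Node[]) => void;
  addNode: (node: Node) => void;
  addBlock: (blockInfo: unknown) => void;
  updateNodeData: (nodeId: string, data: Record<string, unknown>) => void;
  onNodesChange: (changes: NodeChange[]) => void;
  setShowAdvanced: (nodeId: string, show: boolean) => void;
  incrementNodeCounter: () => number;
}
```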
### EdgeStore (`useEdgeStore`)

Manages all connection-related state and operations.

**Key Features:**

- Connection CRUD operations
- Connection validation
- Backend link conversion
- Connection state queries

**Usage:**

```typescript
import { useEdgeStore } from "../stores/edgeStore";

// Get connections
const connections = useEdgeStore((state) => state.connections);

// Add connection
const addConnection = useEdgeStore((state) => state.addConnection);

// Check if input is connected
const isInputConnected = useEdgeStore((state) => state.isInputConnected);
```

**Store Methods:**

- `setConnections(connections)` - Replace all connections
- `addConnection(conn)` - Add a new connection
- `removeConnection(edgeId)` - Remove connection by ID
- `upsertMany(conns)` - Bulk update connections
- `isInputConnected(nodeId, handle)` - Check input connection
- `isOutputConnected(nodeId, handle)` - Check output connection
- `getNodeConnections(nodeId)` - Get all connections for a node
- `getBackendLinks()` - Convert to backend format
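`getBackendLinks()` maps each visual connection onto the backend `Link` shape (see `edges/helpers.ts` in this change): the connection's source/target node IDs become `source_id`/`sink_id`, and its handle IDs become `source_name`/`sink_name`. For example, when saving a graph (the ID values shown are illustrative):

```typescript
// Outside React, a Zustand store can be read imperatively:
const links = useEdgeStore.getState().getBackendLinks();
// Each entry looks like:
// { source_id: "1", sink_id: "2", source_name: "response", sink_name: "prompt" }
```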
## Form Creator System

The FormCreator is a dynamic form generator built on React JSON Schema Form (RJSF) that automatically creates forms based on JSON schemas.

### How It Works

1. **Schema Processing**: Input schemas are preprocessed to ensure all properties have types
2. **Widget Mapping**: Schema types are mapped to appropriate input widgets
3. **Field Rendering**: Custom fields handle complex data structures
4. **State Management**: Form data is automatically synced with the node store

### Key Components

#### FormCreator

```typescript
<FormCreator
  jsonSchema={preprocessedSchema}
  nodeId={nodeId}
/>
```

#### Custom Widgets

- `TextInputWidget` - Text, number, password inputs
- `SelectWidget` - Dropdown and multi-select
- `SwitchWidget` - Boolean toggles
- `FileWidget` - File upload
- `DateInputWidget` - Date picker
- `TimeInputWidget` - Time picker
- `DateTimeInputWidget` - DateTime picker

#### Custom Fields

- `AnyOfField` - Union type handling
- `ObjectField` - Free-form object editing
- `CredentialsField` - API credential management

#### Templates

- `FieldTemplate` - Custom field wrapper with handles
- `ArrayFieldTemplate` - Array editing interface
@@ -0,0 +1,59 @@
import { Button } from "@/components/atoms/Button/Button";
import {
  BaseEdge,
  EdgeLabelRenderer,
  EdgeProps,
  getBezierPath,
} from "@xyflow/react";

import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import { XIcon } from "@phosphor-icons/react";

const CustomEdge = ({
  id,
  sourceX,
  sourceY,
  targetX,
  targetY,
  sourcePosition,
  targetPosition,
  markerEnd,
  selected,
}: EdgeProps) => {
  const removeConnection = useEdgeStore((state) => state.removeConnection);
  const [edgePath, labelX, labelY] = getBezierPath({
    sourceX,
    sourceY,
    targetX,
    targetY,
    sourcePosition,
    targetPosition,
  });

  return (
    <>
      <BaseEdge
        path={edgePath}
        markerEnd={markerEnd}
        className={
          selected ? "[stroke:#555]" : "[stroke:#555]80 hover:[stroke:#555]"
        }
      />
      <EdgeLabelRenderer>
        <Button
          onClick={() => removeConnection(id)}
          className="absolute z-10 min-w-0 p-1"
          variant="secondary"
          style={{
            transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
            pointerEvents: "all",
          }}
        >
          <XIcon className="h-3 w-3" weight="bold" />
        </Button>
      </EdgeLabelRenderer>
    </>
  );
};

export default CustomEdge;
@@ -0,0 +1,12 @@
import { Link } from "@/app/api/__generated__/models/link";
import { Connection } from "@xyflow/react";

export const convertConnectionsToBackendLinks = (
  connections: Connection[],
): Link[] =>
  connections.map((c) => ({
    source_id: c.source || "",
    sink_id: c.target || "",
    source_name: c.sourceHandle || "",
    sink_name: c.targetHandle || "",
  }));
@@ -0,0 +1,70 @@
import {
  Connection as RFConnection,
  Edge as RFEdge,
  MarkerType,
  EdgeChange,
} from "@xyflow/react";
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import { useCallback, useMemo } from "react";

export const useCustomEdge = () => {
  const connections = useEdgeStore((s) => s.connections);
  const addConnection = useEdgeStore((s) => s.addConnection);
  const removeConnection = useEdgeStore((s) => s.removeConnection);

  const edges: RFEdge[] = useMemo(
    () =>
      connections.map((c) => ({
        id: c.edge_id,
        type: "custom",
        source: c.source,
        target: c.target,
        sourceHandle: c.sourceHandle,
        targetHandle: c.targetHandle,
        markerEnd: {
          type: MarkerType.ArrowClosed,
          strokeWidth: 2,
          color: "#555",
        },
      })),
    [connections],
  );

  const onConnect = useCallback(
    (conn: RFConnection) => {
      if (
        !conn.source ||
        !conn.target ||
        !conn.sourceHandle ||
        !conn.targetHandle
      )
        return;
      const exists = connections.some(
        (c) =>
          c.source === conn.source &&
          c.target === conn.target &&
          c.sourceHandle === conn.sourceHandle &&
          c.targetHandle === conn.targetHandle,
      );
      if (exists) return;
      addConnection({
        source: conn.source,
        target: conn.target,
        sourceHandle: conn.sourceHandle,
        targetHandle: conn.targetHandle,
      });
    },
    [connections, addConnection],
  );

  const onEdgesChange = useCallback(
    (changes: EdgeChange[]) => {
      changes.forEach((ch) => {
        if (ch.type === "remove") removeConnection(ch.id);
      });
    },
    [removeConnection],
  );

  return { edges, onConnect, onEdgesChange };
};
@@ -0,0 +1,32 @@
import { CircleIcon } from "@phosphor-icons/react";
import { Handle, Position } from "@xyflow/react";

const NodeHandle = ({
  id,
  isConnected,
  side,
}: {
  id: string;
  isConnected: boolean;
  side: "left" | "right";
}) => {
  return (
    <Handle
      type={side === "left" ? "target" : "source"}
      position={side === "left" ? Position.Left : Position.Right}
      id={id}
      className={side === "left" ? "-ml-4 mr-2" : "-mr-2 ml-2"}
    >
      <div className="pointer-events-none">
        <CircleIcon
          size={16}
          weight={isConnected ? "fill" : "duotone"}
          className={"text-gray-400 opacity-100"}
        />
      </div>
    </Handle>
  );
};

export default NodeHandle;
@@ -0,0 +1,117 @@
/**
 * Handle ID Types for different input structures
 *
 * Examples:
 * SIMPLE: "message"
 * NESTED: "config.api_key"
 * ARRAY: "items_$_0", "items_$_1"
 * KEY_VALUE: "headers_#_Authorization", "params_#_limit"
 *
 * Note: All handle IDs are sanitized to remove spaces and special characters.
 * Spaces become underscores, and special characters are removed.
 * Example: "user name" becomes "user_name", "email@domain.com" becomes "emaildomaincom"
 */
export enum HandleIdType {
  SIMPLE = "SIMPLE",
  NESTED = "NESTED",
  ARRAY = "ARRAY",
  KEY_VALUE = "KEY_VALUE",
}

const fromRjsfId = (id: string): string => {
  if (!id) return "";
  const parts = id.split("_");
  const filtered = parts.filter(
    (p) => p !== "root" && p !== "properties" && p.length > 0,
  );
  return filtered.join("_") || "";
};

const sanitizeForHandleId = (str: string): string => {
  if (!str) return "";

  return str
    .trim()
    .replace(/\s+/g, "_") // Replace spaces with underscores
    .replace(/[^a-zA-Z0-9_-]/g, "") // Remove special characters except underscores and hyphens
    .replace(/_+/g, "_") // Replace multiple consecutive underscores with a single underscore
    .replace(/^_|_$/g, ""); // Remove leading/trailing underscores
};

export const generateHandleId = (
  mainKey: string,
  nestedValues: string[] = [],
  type: HandleIdType = HandleIdType.SIMPLE,
): string => {
  if (!mainKey) return "";

  mainKey = fromRjsfId(mainKey);
  mainKey = sanitizeForHandleId(mainKey);

  if (type === HandleIdType.SIMPLE || nestedValues.length === 0) {
    return mainKey;
  }

  const sanitizedNestedValues = nestedValues.map((value) =>
    sanitizeForHandleId(value),
  );

  switch (type) {
    case HandleIdType.NESTED:
      return [mainKey, ...sanitizedNestedValues].join(".");

    case HandleIdType.ARRAY:
      return [mainKey, ...sanitizedNestedValues].join("_$_");

    case HandleIdType.KEY_VALUE:
      return [mainKey, ...sanitizedNestedValues].join("_#_");

    default:
      return mainKey;
  }
};

export const parseHandleId = (
  handleId: string,
): {
  mainKey: string;
  nestedValues: string[];
  type: HandleIdType;
} => {
  if (!handleId) {
    return { mainKey: "", nestedValues: [], type: HandleIdType.SIMPLE };
  }

  if (handleId.includes("_#_")) {
    const parts = handleId.split("_#_");
    return {
      mainKey: parts[0],
      nestedValues: parts.slice(1),
      type: HandleIdType.KEY_VALUE,
    };
  }

  if (handleId.includes("_$_")) {
    const parts = handleId.split("_$_");
    return {
      mainKey: parts[0],
      nestedValues: parts.slice(1),
      type: HandleIdType.ARRAY,
    };
  }

  if (handleId.includes(".")) {
    const parts = handleId.split(".");
    return {
      mainKey: parts[0],
      nestedValues: parts.slice(1),
      type: HandleIdType.NESTED,
    };
  }

  return {
    mainKey: handleId,
    nestedValues: [],
    type: HandleIdType.SIMPLE,
  };
};
@@ -0,0 +1,69 @@
import React from "react";
import { Node as XYNode, NodeProps } from "@xyflow/react";
import { FormCreator } from "./FormCreator";
import { RJSFSchema } from "@rjsf/utils";
import { Text } from "@/components/atoms/Text/Text";

import { Switch } from "@/components/atoms/Switch/Switch";
import { preprocessInputSchema } from "../processors/input-schema-pre-processor";
import { OutputHandler } from "./OutputHandler";
import { useNodeStore } from "../../../stores/nodeStore";

export type CustomNodeData = {
  hardcodedValues: {
    [key: string]: any;
  };
  title: string;
  description: string;
  inputSchema: RJSFSchema;
  outputSchema: RJSFSchema;
};

export type CustomNode = XYNode<CustomNodeData, "custom">;

export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
  ({ data, id }) => {
    const showAdvanced = useNodeStore(
      (state) => state.nodeAdvancedStates[id] || false,
    );
    const setShowAdvanced = useNodeStore((state) => state.setShowAdvanced);

    return (
      <div className="rounded-xl border border-slate-200/60 bg-gradient-to-br from-white to-slate-50/30 shadow-lg shadow-slate-900/5 backdrop-blur-sm">
        {/* Header */}
        <div className="flex h-14 items-center justify-center rounded-xl border-b border-slate-200/50 bg-gradient-to-r from-slate-50/80 to-white/90">
          <Text
            variant="large-semibold"
            className="tracking-tight text-slate-800"
          >
            {data.title} #{id}
          </Text>
        </div>

        {/* Input Handles */}
        <div className="bg-white/40 pb-6 pr-6">
          <FormCreator
            jsonSchema={preprocessInputSchema(data.inputSchema)}
            nodeId={id}
          />
        </div>

        {/* Advanced Button */}
        <div className="flex items-center justify-between gap-2 rounded-b-xl border-t border-slate-200/50 bg-gradient-to-r from-slate-50/60 to-white/80 px-5 py-3.5">
          <Text variant="body" className="font-medium text-slate-700">
            Advanced
          </Text>
          <Switch
            onCheckedChange={(checked) => setShowAdvanced(id, checked)}
            checked={showAdvanced}
          />
        </div>

        {/* Output Handles */}
        <OutputHandler outputSchema={data.outputSchema} nodeId={id} />
      </div>
    );
  },
);

CustomNode.displayName = "CustomNode";
@@ -0,0 +1,33 @@
|
||||
import Form from "@rjsf/core";
|
||||
import validator from "@rjsf/validator-ajv8";
|
||||
import { RJSFSchema } from "@rjsf/utils";
|
||||
import React from "react";
|
||||
import { widgets } from "./widgets";
|
||||
import { fields } from "./fields";
|
||||
import { templates } from "./templates";
|
||||
import { uiSchema } from "./uiSchema";
|
||||
import { useNodeStore } from "../../../stores/nodeStore";
|
||||
|
||||
export const FormCreator = React.memo(
|
||||
({ jsonSchema, nodeId }: { jsonSchema: RJSFSchema; nodeId: string }) => {
|
||||
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
||||
const handleChange = ({ formData }: any) => {
|
||||
updateNodeData(nodeId, { hardcodedValues: formData });
|
||||
};
|
||||
|
||||
return (
|
||||
<Form
|
||||
schema={jsonSchema}
|
||||
validator={validator}
|
||||
fields={fields}
|
||||
templates={templates}
|
||||
widgets={widgets}
|
||||
formContext={{ nodeId: nodeId }}
|
||||
onChange={handleChange}
|
||||
uiSchema={uiSchema}
|
||||
/>
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
FormCreator.displayName = "FormCreator";
|
||||
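
For orientation, a minimal sketch of how FormCreator might be driven; the schema and nodeId here are hypothetical, not part of the commit:

// Hypothetical usage sketch. Every RJSF change is written back to the node
// store as hardcodedValues, which CustomNode above then renders between its
// input and output handles.
const exampleSchema: RJSFSchema = {
  type: "object",
  properties: {
    url: { type: "string", title: "URL" },
    limit: { type: "integer", title: "Limit" },
  },
};

const ExampleNodeForm = () => (
  <FormCreator jsonSchema={exampleSchema} nodeId="node-1" />
);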
@@ -0,0 +1,90 @@
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { CaretDownIcon, InfoIcon } from "@phosphor-icons/react";
import { RJSFSchema } from "@rjsf/utils";
import { useState } from "react";

import NodeHandle from "../handlers/NodeHandle";
import {
  Tooltip,
  TooltipContent,
  TooltipProvider,
  TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import { getTypeDisplayInfo } from "./helpers";

export const OutputHandler = ({
  outputSchema,
  nodeId,
}: {
  outputSchema: RJSFSchema;
  nodeId: string;
}) => {
  const { isOutputConnected } = useEdgeStore();
  const properties = outputSchema?.properties || {};
  const [isOutputVisible, setIsOutputVisible] = useState(false);

  return (
    <div className="flex flex-col items-end justify-between gap-2 rounded-b-xl border-t border-slate-200/50 bg-white py-3.5">
      <Button
        variant="ghost"
        className="mr-4 p-0"
        onClick={() => setIsOutputVisible(!isOutputVisible)}
      >
        <Text
          variant="body"
          className="flex items-center gap-2 font-medium text-slate-700"
        >
          Output{" "}
          <CaretDownIcon
            size={16}
            weight="bold"
            className={`transition-transform ${isOutputVisible ? "rotate-180" : ""}`}
          />
        </Text>
      </Button>

      {
        <div className="flex flex-col items-end gap-2">
          {Object.entries(properties).map(([key, property]: [string, any]) => {
            const isConnected = isOutputConnected(nodeId, key);
            const shouldShow = isConnected || isOutputVisible;
            const { displayType, colorClass } = getTypeDisplayInfo(property);

            return shouldShow ? (
              <div key={key} className="relative flex items-center gap-2">
                <Text
                  variant="body"
                  className="flex items-center gap-2 font-medium text-slate-700"
                >
                  {property?.description && (
                    <TooltipProvider>
                      <Tooltip>
                        <TooltipTrigger asChild>
                          <span
                            style={{ marginLeft: 6, cursor: "pointer" }}
                            aria-label="info"
                            tabIndex={0}
                          >
                            <InfoIcon />
                          </span>
                        </TooltipTrigger>
                        <TooltipContent>{property?.description}</TooltipContent>
                      </Tooltip>
                    </TooltipProvider>
                  )}
                  {property?.title || key}{" "}
                  <Text variant="small" as="span" className={colorClass}>
                    ({displayType})
                  </Text>
                </Text>
                <NodeHandle id={key} isConnected={isConnected} side="right" />
              </div>
            ) : null;
          })}
        </div>
      }
    </div>
  );
};
@@ -0,0 +1,195 @@
import React from "react";
import { FieldProps, RJSFSchema } from "@rjsf/utils";

import { Text } from "@/components/atoms/Text/Text";
import { Switch } from "@/components/atoms/Switch/Switch";
import { Select } from "@/components/atoms/Select/Select";
import { InputType, mapJsonSchemaTypeToInputType } from "../../helpers";

import { InfoIcon } from "@phosphor-icons/react";
import { useAnyOfField } from "./useAnyOfField";
import NodeHandle from "../../../handlers/NodeHandle";
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import { generateHandleId } from "../../../handlers/helpers";
import { getTypeDisplayInfo } from "../../helpers";
import merge from "lodash/merge";
import {
  Tooltip,
  TooltipContent,
  TooltipProvider,
  TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";

type TypeOption = {
  type: string;
  title: string;
  index: number;
  format?: string;
  enum?: any[];
  secret?: boolean;
  schema: RJSFSchema;
};

export const AnyOfField = ({
  schema,
  formData,
  onChange,
  name,
  idSchema,
  formContext,
  registry,
  uiSchema,
  disabled,
  onBlur,
  onFocus,
}: FieldProps) => {
  const fieldKey = generateHandleId(idSchema.$id ?? "");
  const updatedFormContext = { ...formContext, fromAnyOf: true };

  const { nodeId } = updatedFormContext;
  const { isInputConnected } = useEdgeStore();
  const isConnected = isInputConnected(nodeId, fieldKey);
  const {
    isNullableType,
    nonNull,
    selectedType,
    handleTypeChange,
    handleNullableToggle,
    handleValueChange,
    currentTypeOption,
    isEnabled,
    typeOptions,
  } = useAnyOfField(schema, formData, onChange);

  const renderInput = (typeOption: TypeOption) => {
    const optionSchema = (typeOption.schema || {
      type: typeOption.type,
      format: typeOption.format,
      secret: typeOption.secret,
      enum: typeOption.enum,
    }) as RJSFSchema;
    const inputType = mapJsonSchemaTypeToInputType(optionSchema);

    // Tell the field rendered under this anyOf field that it is part of an anyOf.
    // formContext alone can't carry this here, so we pass it via ui:options.
    // The Context API would also work, but this keeps things simple.
    const uiSchemaFromAnyOf = merge({}, uiSchema, {
      "ui:options": { fromAnyOf: true },
    });

    // SchemaField renders the selected option's schema recursively.
    if (inputType === InputType.ARRAY_EDITOR) {
      const SchemaField = registry.fields.SchemaField;
      return (
        <div className="-ml-2">
          <SchemaField
            schema={optionSchema}
            formData={formData}
            idSchema={idSchema}
            uiSchema={uiSchemaFromAnyOf}
            onChange={handleValueChange}
            onBlur={onBlur}
            onFocus={onFocus}
            name={name}
            registry={registry}
            disabled={disabled}
            formContext={updatedFormContext}
          />
        </div>
      );
    }

    const SchemaField = registry.fields.SchemaField;
    return (
      <div className="-ml-2">
        <SchemaField
          schema={optionSchema}
          formData={formData}
          idSchema={idSchema}
          uiSchema={uiSchemaFromAnyOf}
          onChange={handleValueChange}
          onBlur={onBlur}
          onFocus={onFocus}
          name={name}
          registry={registry}
          disabled={disabled}
          formContext={updatedFormContext}
        />
      </div>
    );
  };

  // Nullable (type | null) unions need a different UI, hence this branch.
  if (isNullableType && nonNull) {
    const { displayType, colorClass } = getTypeDisplayInfo(nonNull);

    return (
      <div className="flex flex-col">
        <div className="flex items-center justify-between gap-2">
          <div className="-ml-2 flex items-center gap-1">
            <NodeHandle id={fieldKey} isConnected={isConnected} side="left" />
            <Text variant="body">
              {name.charAt(0).toUpperCase() + name.slice(1)}
            </Text>
            <Text variant="small" className={colorClass}>
              ({displayType} | null)
            </Text>
          </div>
          {!isConnected && (
            <Switch
              className="z-10"
              checked={isEnabled}
              onCheckedChange={handleNullableToggle}
            />
          )}
        </div>
        {!isConnected && isEnabled && renderInput(nonNull)}
      </div>
    );
  }

  return (
    <div className="flex flex-col">
      <div className="-mb-3 -ml-2 flex items-center gap-1">
        <NodeHandle id={fieldKey} isConnected={isConnected} side="left" />
        <Text variant="body">
          {name.charAt(0).toUpperCase() + name.slice(1)}
        </Text>
        {!isConnected && (
          <Select
            label=""
            id={`${name}-type-select`}
            hideLabel={true}
            value={selectedType}
            onValueChange={handleTypeChange}
            options={typeOptions.map((o) => {
              const { displayType } = getTypeDisplayInfo(o);
              return { value: o.type, label: displayType };
            })}
            size="small"
            wrapperClassName="!mb-0 "
            className="h-6 w-fit gap-1 pl-3 pr-2"
          />
        )}

        {schema.description && (
          <TooltipProvider>
            <Tooltip>
              <TooltipTrigger asChild>
                <span
                  style={{ marginLeft: 6, cursor: "pointer" }}
                  aria-label="info"
                  tabIndex={0}
                >
                  <InfoIcon />
                </span>
              </TooltipTrigger>
              <TooltipContent>{schema.description}</TooltipContent>
            </Tooltip>
          </TooltipProvider>
        )}
      </div>
      {!isConnected && currentTypeOption && renderInput(currentTypeOption)}
    </div>
  );
};
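
Two illustrative schema shapes AnyOfField distinguishes (examples only, not from the commit):

// A nullable pair renders the toggle-based UI from the first branch above:
const nullableString: RJSFSchema = {
  anyOf: [{ type: "string" }, { type: "null" }],
};

// A wider union renders the type <Select> from the second branch instead:
const multiTypeUnion: RJSFSchema = {
  anyOf: [{ type: "string" }, { type: "number" }, { type: "boolean" }],
};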
@@ -0,0 +1,97 @@
import { useMemo, useState } from "react";
import { RJSFSchema } from "@rjsf/utils";

const getDefaultValueForType = (type?: string): any => {
  if (!type) return "";

  switch (type) {
    case "string":
      return "";
    case "number":
    case "integer":
      return 0;
    case "boolean":
      return false;
    case "array":
      return [];
    case "object":
      return {};
    default:
      return "";
  }
};

export const useAnyOfField = (
  schema: RJSFSchema,
  formData: any,
  onChange: (value: any) => void,
) => {
  const typeOptions: any[] = useMemo(
    () =>
      schema.anyOf?.map((opt: any, i: number) => ({
        type: opt.type || "string",
        title: opt.title || `Option ${i + 1}`,
        index: i,
        format: opt.format,
        enum: opt.enum,
        secret: opt.secret,
        schema: opt,
      })) || [],
    [schema.anyOf],
  );

  const isNullableType = useMemo(
    () =>
      typeOptions.length === 2 &&
      typeOptions.some((o) => o.type === "null") &&
      typeOptions.some((o) => o.type !== "null"),
    [typeOptions],
  );

  const nonNull = useMemo(
    () => (isNullableType ? typeOptions.find((o) => o.type !== "null") : null),
    [isNullableType, typeOptions],
  );

  const initialSelectedType = useMemo(() => {
    const def = schema.default;
    const first = typeOptions[0]?.type || "string";
    if (isNullableType) return nonNull?.type || "string";
    if (typeof def === "string" && typeOptions.some((o) => o.type === def))
      return def;
    return first;
  }, [schema.default, typeOptions, isNullableType, nonNull?.type]);

  const [selectedType, setSelectedType] = useState<string>(initialSelectedType);

  const isEnabled = formData !== null && formData !== undefined;

  const handleTypeChange = (t: string) => {
    setSelectedType(t);
    onChange(undefined); // clear current value when switching type
  };

  const handleNullableToggle = (checked: boolean) => {
    if (checked) {
      onChange(getDefaultValueForType(nonNull?.type));
    } else {
      onChange(null);
    }
  };

  const handleValueChange = (value: any) => onChange(value);

  const currentTypeOption = typeOptions.find((o) => o.type === selectedType);

  return {
    isNullableType,
    nonNull,
    selectedType,
    handleTypeChange,
    handleNullableToggle,
    handleValueChange,
    currentTypeOption,
    isEnabled,
    typeOptions,
  };
};
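
Tracing the hook on a nullable pair makes the toggle semantics concrete (illustrative trace, derived from the code above):

// schema = { anyOf: [{ type: "string" }, { type: "null" }] }
// => isNullableType === true, nonNull?.type === "string"
// => handleNullableToggle(true)  calls onChange("")   (string default)
// => handleNullableToggle(false) calls onChange(null) (field disabled)
// => handleTypeChange("number")  calls onChange(undefined) to clear the value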
@@ -0,0 +1,34 @@
import React from "react";
import { FieldProps } from "@rjsf/utils";
import { Input } from "@/components/atoms/Input/Input";

// We need to add all the logic for the credential fields here
export const CredentialsField = (props: FieldProps) => {
  const { formData = {}, onChange, required: _required, schema } = props;

  const _credentialProvider = schema.credentials_provider;
  const _credentialType = schema.credentials_types;
  const _description = schema.description;
  const _title = schema.title;

  // Helper to update one property
  const setField = (key: string, value: any) =>
    onChange({ ...formData, [key]: value });

  return (
    <div className="flex flex-col gap-2">
      <Input
        hideLabel={true}
        label={""}
        id="credentials-id"
        type="text"
        value={formData.id || ""}
        onChange={(e) => setField("id", e.target.value)}
        placeholder="Enter your API Key"
        required
        size="small"
        wrapperClassName="mb-0"
      />
    </div>
  );
};
@@ -0,0 +1,41 @@
import React from "react";
import { FieldProps } from "@rjsf/utils";
import { getDefaultRegistry } from "@rjsf/core";
import { generateHandleId } from "../../handlers/helpers";
import { ObjectEditor } from "../../components/ObjectEditor/ObjectEditor";

export const ObjectField = (props: FieldProps) => {
  const {
    schema,
    formData = {},
    onChange,
    name,
    idSchema,
    formContext,
  } = props;
  const DefaultObjectField = getDefaultRegistry().fields.ObjectField;

  // Let the default field render for root or fixed-schema objects
  const isFreeForm =
    !schema.properties ||
    Object.keys(schema.properties).length === 0 ||
    schema.additionalProperties === true;

  if (idSchema?.$id === "root" || !isFreeForm) {
    return <DefaultObjectField {...props} />;
  }

  const fieldKey = generateHandleId(idSchema.$id ?? "");
  const { nodeId } = formContext;

  return (
    <ObjectEditor
      id={`${name}-input`}
      nodeId={nodeId}
      fieldKey={fieldKey}
      value={formData}
      onChange={onChange}
      placeholder={`Enter ${name || "Contact Data"}`}
    />
  );
};
@@ -0,0 +1,10 @@
import { RegistryFieldsType } from "@rjsf/utils";
import { CredentialsField } from "./CredentialField";
import { AnyOfField } from "./AnyOfField/AnyOfField";
import { ObjectField } from "./ObjectField";

export const fields: RegistryFieldsType = {
  AnyOfField: AnyOfField,
  credentials: CredentialsField,
  ObjectField: ObjectField,
};
@@ -0,0 +1,141 @@
import { RJSFSchema } from "@rjsf/utils";

export enum InputType {
  SINGLE_LINE_TEXT = "single-line-text",
  TEXT_AREA = "text-area",
  PASSWORD = "password",
  FILE = "file",
  DATE = "date",
  TIME = "time",
  DATE_TIME = "datetime",
  NUMBER = "number",
  INTEGER = "integer",
  SWITCH = "switch",
  ARRAY_EDITOR = "array-editor",
  SELECT = "select",
  MULTI_SELECT = "multi-select",
  OBJECT_EDITOR = "object-editor",
  ENUM = "enum",
}

// Maps a JSON Schema definition to the InputType used to pick the matching input widget.
export function mapJsonSchemaTypeToInputType(
  schema: RJSFSchema,
): InputType | undefined {
  if (schema.type === "string") {
    if (schema.secret) return InputType.PASSWORD;
    if (schema.format === "date") return InputType.DATE;
    if (schema.format === "time") return InputType.TIME;
    if (schema.format === "date-time") return InputType.DATE_TIME;
    if (schema.format === "long-text") return InputType.TEXT_AREA;
    if (schema.format === "short-text") return InputType.SINGLE_LINE_TEXT;
    if (schema.format === "file") return InputType.FILE;
    return InputType.SINGLE_LINE_TEXT;
  }

  if (schema.type === "number") return InputType.NUMBER;
  if (schema.type === "integer") return InputType.INTEGER;
  if (schema.type === "boolean") return InputType.SWITCH;

  if (schema.type === "array") {
    if (
      schema.items &&
      typeof schema.items === "object" &&
      !Array.isArray(schema.items) &&
      schema.items.enum
    ) {
      return InputType.MULTI_SELECT;
    }
    return InputType.ARRAY_EDITOR;
  }

  if (schema.type === "object") {
    return InputType.OBJECT_EDITOR;
  }

  if (schema.enum) {
    return InputType.SELECT;
  }

  if (schema.type === "null") return;

  if (schema.anyOf || schema.oneOf) {
    return undefined;
  }

  return InputType.SINGLE_LINE_TEXT;
}
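
A few representative mappings, derived from the branches above:

// mapJsonSchemaTypeToInputType({ type: "string", secret: true });        // -> InputType.PASSWORD
// mapJsonSchemaTypeToInputType({ type: "string", format: "date" });      // -> InputType.DATE
// mapJsonSchemaTypeToInputType({ type: "array", items: { enum: [1] } }); // -> InputType.MULTI_SELECT
// mapJsonSchemaTypeToInputType({ type: "object" });                      // -> InputType.OBJECT_EDITOR
// mapJsonSchemaTypeToInputType({ anyOf: [{ type: "string" }] });         // -> undefined (handled by AnyOfField)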

// Helper to extract options from schema
export function extractOptions(
  schema: any,
): { value: string; label: string }[] {
  if (schema.enum) {
    return schema.enum.map((value: any) => ({
      value: String(value),
      label: String(value),
    }));
  }

  if (schema.type === "array" && schema.items?.enum) {
    return schema.items.enum.map((value: any) => ({
      value: String(value),
      label: String(value),
    }));
  }

  return [];
}

// Get the display type and color for a schema type (shown next to the field name)
export const getTypeDisplayInfo = (schema: any) => {
  if (schema?.type === "string" && schema?.format) {
    const formatMap: Record<
      string,
      { displayType: string; colorClass: string }
    > = {
      file: { displayType: "file", colorClass: "!text-green-500" },
      date: { displayType: "date", colorClass: "!text-blue-500" },
      time: { displayType: "time", colorClass: "!text-blue-500" },
      "date-time": { displayType: "datetime", colorClass: "!text-blue-500" },
      "long-text": { displayType: "text", colorClass: "!text-green-500" },
      "short-text": { displayType: "text", colorClass: "!text-green-500" },
    };

    const formatInfo = formatMap[schema.format];
    if (formatInfo) {
      return formatInfo;
    }
  }

  const typeMap: Record<string, string> = {
    string: "text",
    number: "number",
    integer: "integer",
    boolean: "true/false",
    object: "object",
    array: "list",
    null: "null",
  };

  const displayType = typeMap[schema?.type] || schema?.type || "any";

  const colorMap: Record<string, string> = {
    string: "!text-green-500",
    number: "!text-blue-500",
    integer: "!text-blue-500",
    boolean: "!text-yellow-500",
    object: "!text-purple-500",
    array: "!text-indigo-500",
    null: "!text-gray-500",
    any: "!text-gray-500",
  };

  const colorClass = colorMap[schema?.type] || "!text-gray-500";

  return {
    displayType,
    colorClass,
  };
};
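
Example outputs of getTypeDisplayInfo, traced from the maps above:

// getTypeDisplayInfo({ type: "string", format: "date" }); // -> { displayType: "date", colorClass: "!text-blue-500" }
// getTypeDisplayInfo({ type: "array" });                  // -> { displayType: "list", colorClass: "!text-indigo-500" }
// getTypeDisplayInfo(undefined);                          // -> { displayType: "any",  colorClass: "!text-gray-500" }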
@@ -0,0 +1,30 @@
import React from "react";
import { ArrayFieldTemplateProps } from "@rjsf/utils";
import { ArrayEditor } from "../../components/ArrayEditor/ArrayEditor";

function ArrayFieldTemplate(props: ArrayFieldTemplateProps) {
  const {
    items,
    canAdd,
    onAddClick,
    disabled,
    readonly,
    formContext,
    idSchema,
  } = props;
  const { nodeId } = formContext;

  return (
    <ArrayEditor
      items={items}
      nodeId={nodeId}
      canAdd={canAdd}
      onAddClick={onAddClick}
      disabled={disabled}
      readonly={readonly}
      id={idSchema.$id}
    />
  );
}

export default ArrayFieldTemplate;
@@ -0,0 +1,103 @@
import React, { useContext } from "react";
import { FieldTemplateProps } from "@rjsf/utils";
import { InfoIcon } from "@phosphor-icons/react";
import {
  Tooltip,
  TooltipContent,
  TooltipProvider,
  TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import { Text } from "@/components/atoms/Text/Text";

import NodeHandle from "../../handlers/NodeHandle";
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import { generateHandleId } from "../../handlers/helpers";
import { getTypeDisplayInfo } from "../helpers";
import { ArrayEditorContext } from "../../components/ArrayEditor/ArrayEditorContext";

const FieldTemplate: React.FC<FieldTemplateProps> = ({
  id,
  label,
  required,
  description,
  children,
  schema,
  formContext,
  uiSchema,
}) => {
  const { isInputConnected } = useEdgeStore();
  const { nodeId } = formContext;

  const showAdvanced = useNodeStore(
    (state) => state.nodeAdvancedStates[nodeId] ?? false,
  );

  const {
    isArrayItem,
    fieldKey: arrayFieldKey,
    isConnected: isArrayItemConnected,
  } = useContext(ArrayEditorContext);

  let fieldKey = generateHandleId(id);
  let isConnected = isInputConnected(nodeId, fieldKey);
  if (isArrayItem) {
    fieldKey = arrayFieldKey;
    isConnected = isArrayItemConnected;
  }
  const isAnyOf = Array.isArray((schema as any)?.anyOf);
  const isOneOf = Array.isArray((schema as any)?.oneOf);
  const suppressHandle = isAnyOf || isOneOf;

  if (!showAdvanced && schema.advanced === true && !isConnected) {
    return null;
  }

  const fromAnyOf =
    Boolean((uiSchema as any)?.["ui:options"]?.fromAnyOf) ||
    Boolean((formContext as any)?.fromAnyOf);

  const { displayType, colorClass } = getTypeDisplayInfo(schema);

  return (
    <div className="mt-4 w-[400px] space-y-1">
      {label && schema.type && (
        <label htmlFor={id} className="flex items-center gap-1">
          {!suppressHandle && !fromAnyOf && (
            <NodeHandle id={fieldKey} isConnected={isConnected} side="left" />
          )}
          {!fromAnyOf && (
            <Text variant="body" className="line-clamp-1">
              {label}
            </Text>
          )}
          {!fromAnyOf && (
            <Text variant="small" className={colorClass}>
              ({displayType})
            </Text>
          )}
          {required && <span style={{ color: "red" }}>*</span>}
          {description?.props?.description && (
            <TooltipProvider>
              <Tooltip>
                <TooltipTrigger asChild>
                  <span
                    style={{ marginLeft: 6, cursor: "pointer" }}
                    aria-label="info"
                    tabIndex={0}
                  >
                    <InfoIcon />
                  </span>
                </TooltipTrigger>
                <TooltipContent>{description}</TooltipContent>
              </Tooltip>
            </TooltipProvider>
          )}
        </label>
      )}
      {(isAnyOf || !isConnected) && <div className="pl-2">{children}</div>}{" "}
    </div>
  );
};

export default FieldTemplate;
@@ -0,0 +1,10 @@
import ArrayFieldTemplate from "./ArrayFieldTemplate";
import FieldTemplate from "./FieldTemplate";

const NoSubmitButton = () => null;

export const templates = {
  FieldTemplate,
  ButtonTemplates: { SubmitButton: NoSubmitButton },
  ArrayFieldTemplate,
};
@@ -0,0 +1,12 @@
export const uiSchema = {
  credentials: {
    "ui:field": "credentials",
    provider: { "ui:widget": "hidden" },
    type: { "ui:widget": "hidden" },
    id: { "ui:autofocus": true },
    title: { "ui:placeholder": "Optional title" },
  },
  properties: {
    "ui:field": "CustomObjectField",
  },
};
@@ -0,0 +1,23 @@
import * as React from "react";
import { WidgetProps } from "@rjsf/utils";
import { DateInput } from "@/components/atoms/DateInput/DateInput";

export const DateInputWidget = (props: WidgetProps) => {
  const { value, onChange, disabled, readonly, placeholder, autofocus, id } =
    props;

  return (
    <DateInput
      size="small"
      id={id}
      hideLabel={true}
      label={""}
      value={value}
      onChange={onChange}
      placeholder={placeholder}
      disabled={disabled}
      readonly={readonly}
      autoFocus={autofocus}
    />
  );
};
@@ -0,0 +1,21 @@
import { WidgetProps } from "@rjsf/utils";
import { DateTimeInput } from "@/components/atoms/DateTimeInput/DateTimeInput";

export const DateTimeInputWidget = (props: WidgetProps) => {
  const { value, onChange, disabled, readonly, placeholder, autofocus, id } =
    props;
  return (
    <DateTimeInput
      size="small"
      id={id}
      hideLabel={true}
      label={""}
      value={value}
      onChange={onChange}
      placeholder={placeholder}
      disabled={disabled}
      readonly={readonly}
      autoFocus={autofocus}
    />
  );
};
@@ -0,0 +1,33 @@
import { WidgetProps } from "@rjsf/utils";
import { Input } from "@/components/__legacy__/ui/input";

export const FileWidget = (props: WidgetProps) => {
  const { onChange, multiple = false, disabled, readonly, id } = props;

  // TODO: Temporary solution for file input; to be completed in follow-up PRs
  const handleChange = (event: React.ChangeEvent<HTMLInputElement>) => {
    const files = event.target.files;
    if (!files || files.length === 0) {
      onChange(undefined);
      return;
    }

    const file = files[0];
    const reader = new FileReader();
    reader.onload = (e) => {
      onChange(e.target?.result);
    };
    reader.readAsDataURL(file);
  };

  return (
    <Input
      id={id}
      type="file"
      multiple={multiple}
      disabled={disabled || readonly}
      onChange={handleChange}
      className="rounded-full"
    />
  );
};
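
One behavioral note on the widget above, as comments (the example value is illustrative):

// What onChange receives: readAsDataURL produces a base64 data URL string,
// e.g. "data:image/png;base64,iVBORw0KGgo..."
// The whole file is inlined into form state, so large uploads will bloat
// hardcodedValues; the follow-up PRs mentioned above could stream or upload
// server-side instead.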
@@ -0,0 +1,62 @@
import { WidgetProps } from "@rjsf/utils";
import { InputType, mapJsonSchemaTypeToInputType } from "../helpers";
import { Select } from "@/components/atoms/Select/Select";
import {
  MultiSelector,
  MultiSelectorContent,
  MultiSelectorInput,
  MultiSelectorItem,
  MultiSelectorList,
  MultiSelectorTrigger,
} from "@/components/__legacy__/ui/multiselect";

export const SelectWidget = (props: WidgetProps) => {
  const { options, value, onChange, disabled, readonly, id } = props;
  const enumOptions = options.enumOptions || [];
  const type = mapJsonSchemaTypeToInputType(props.schema);

  const renderInput = () => {
    if (type === InputType.MULTI_SELECT) {
      return (
        <MultiSelector
          values={Array.isArray(value) ? value : []}
          onValuesChange={onChange}
          className="w-full"
        >
          <MultiSelectorTrigger>
            <MultiSelectorInput placeholder="Select options..." />
          </MultiSelectorTrigger>
          <MultiSelectorContent>
            <MultiSelectorList>
              {enumOptions?.map((option) => (
                <MultiSelectorItem key={option.value} value={option.value}>
                  {option.label}
                </MultiSelectorItem>
              ))}
            </MultiSelectorList>
          </MultiSelectorContent>
        </MultiSelector>
      );
    }
    return (
      <Select
        label=""
        id={id}
        hideLabel={true}
        disabled={disabled || readonly}
        size="small"
        value={value ?? ""}
        onValueChange={onChange}
        options={
          enumOptions?.map((option) => ({
            value: option.value,
            label: option.label,
          })) || []
        }
        wrapperClassName="!mb-0 "
      />
    );
  };

  return renderInput();
};
@@ -0,0 +1,15 @@
import { WidgetProps } from "@rjsf/utils";
import { Switch } from "@/components/atoms/Switch/Switch";

export function SwitchWidget(props: WidgetProps) {
  const { value = false, onChange, disabled, readonly, autofocus, id } = props;
  return (
    <Switch
      id={id}
      checked={Boolean(value)}
      onCheckedChange={(checked) => onChange(checked)}
      disabled={disabled || readonly}
      autoFocus={autofocus}
    />
  );
}
@@ -0,0 +1,68 @@
import { WidgetProps } from "@rjsf/utils";
import { InputType, mapJsonSchemaTypeToInputType } from "../helpers";
import { Input } from "@/components/atoms/Input/Input";

export const TextInputWidget = (props: WidgetProps) => {
  const { schema } = props;
  const mapped = mapJsonSchemaTypeToInputType(schema);

  type InputConfig = {
    htmlType: string;
    placeholder: string;
    handleChange: (v: string) => any;
  };

  const inputConfig: Partial<Record<InputType, InputConfig>> = {
    [InputType.TEXT_AREA]: {
      htmlType: "textarea",
      placeholder: "Enter text...",
      handleChange: (v: string) => (v === "" ? undefined : v),
    },
    [InputType.PASSWORD]: {
      htmlType: "password",
      placeholder: "Enter secret text...",
      handleChange: (v: string) => (v === "" ? undefined : v),
    },
    [InputType.NUMBER]: {
      htmlType: "number",
      placeholder: "Enter number value...",
      handleChange: (v: string) => (v === "" ? undefined : Number(v)),
    },
    [InputType.INTEGER]: {
      htmlType: "number",
      placeholder: "Enter integer value...",
      handleChange: (v: string) => (v === "" ? undefined : Number(v)),
    },
  };

  const defaultConfig: InputConfig = {
    htmlType: "text",
    placeholder: "Enter string value...",
    handleChange: (v: string) => (v === "" ? undefined : v),
  };

  const config = (mapped && inputConfig[mapped]) || defaultConfig;

  const handleChange = (
    e: React.ChangeEvent<HTMLInputElement | HTMLTextAreaElement>,
  ) => {
    const v = e.target.value;
    return props.onChange(config.handleChange(v));
  };

  return (
    <Input
      id={props.id}
      hideLabel={true}
      type={config.htmlType as any}
      label={""}
      size="small"
      wrapperClassName="mb-0"
      value={props.value ?? ""}
      onChange={handleChange as any}
      placeholder={schema.placeholder || config.placeholder}
      required={props.required}
      disabled={props.disabled}
    />
  );
};
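
The per-type handleChange functions all coerce an empty input to undefined, so cleared fields drop out of formData; for example (trace derived from the configs above):

// With mapped === InputType.NUMBER:
//   handleChange("")   -> props.onChange(undefined)
//   handleChange("42") -> props.onChange(42)
// With the default text config:
//   handleChange("")   -> props.onChange(undefined)
//   handleChange("hi") -> props.onChange("hi")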
@@ -0,0 +1,20 @@
import { WidgetProps } from "@rjsf/utils";
import { TimeInput } from "@/components/atoms/TimeInput/TimeInput";

export const TimeInputWidget = (props: WidgetProps) => {
  const { value, onChange, disabled, readonly, placeholder, id } = props;
  return (
    <TimeInput
      value={value}
      onChange={onChange}
      className="w-full"
      label={""}
      id={id}
      hideLabel={true}
      size="small"
      wrapperClassName="!mb-0 "
      disabled={disabled || readonly}
      placeholder={placeholder}
    />
  );
};
@@ -0,0 +1,18 @@
import { RegistryWidgetsType } from "@rjsf/utils";
import { SelectWidget } from "./SelectWidget";
import { TextInputWidget } from "./TextInputWidget";
import { SwitchWidget } from "./SwitchWidget";
import { FileWidget } from "./FileWidget";
import { DateInputWidget } from "./DateInputWidget";
import { TimeInputWidget } from "./TimeInputWidget";
import { DateTimeInputWidget } from "./DateTimeInputWidget";

export const widgets: RegistryWidgetsType = {
  TextWidget: TextInputWidget,
  SelectWidget: SelectWidget,
  CheckboxWidget: SwitchWidget,
  FileWidget: FileWidget,
  DateWidget: DateInputWidget,
  TimeWidget: TimeInputWidget,
  DateTimeWidget: DateTimeInputWidget,
};
@@ -0,0 +1,112 @@
import { RJSFSchema } from "@rjsf/utils";

/**
 * Pre-processes the input schema to ensure all properties have a type defined.
 * If a property doesn't have a type, it assigns a union of all supported JSON Schema types.
 */
export function preprocessInputSchema(schema: RJSFSchema): RJSFSchema {
  if (!schema || typeof schema !== "object") {
    return schema;
  }

  const processedSchema = { ...schema };

  // Recursively process properties
  if (processedSchema.properties) {
    processedSchema.properties = { ...processedSchema.properties };

    for (const [key, property] of Object.entries(processedSchema.properties)) {
      if (property && typeof property === "object") {
        const processedProperty = { ...property };

        // Only add type if no type is defined AND no anyOf/oneOf/allOf is present
        if (
          !processedProperty.type &&
          !processedProperty.anyOf &&
          !processedProperty.oneOf &&
          !processedProperty.allOf
        ) {
          processedProperty.anyOf = [
            { type: "string" },
            { type: "number" },
            { type: "integer" },
            { type: "boolean" },
            { type: "array", items: { type: "string" } },
            { type: "object" },
            { type: "null" },
          ];
        }

        // when encountering an array with items missing type
        if (processedProperty.type === "array" && processedProperty.items) {
          const items = processedProperty.items as RJSFSchema;
          if (!items.type && !items.anyOf && !items.oneOf && !items.allOf) {
            processedProperty.items = {
              type: "string",
              title: items.title ?? "",
            };
          } else {
            processedProperty.items = preprocessInputSchema(items);
          }
        }

        // Recursively process nested objects
        if (
          processedProperty.type === "object" ||
          (Array.isArray(processedProperty.type) &&
            processedProperty.type.includes("object"))
        ) {
          processedProperty.properties = processProperties(
            processedProperty.properties,
          );
        }

        // Process array items
        if (
          processedProperty.type === "array" ||
          (Array.isArray(processedProperty.type) &&
            processedProperty.type.includes("array"))
        ) {
          if (processedProperty.items) {
            processedProperty.items = preprocessInputSchema(
              processedProperty.items as RJSFSchema,
            );
          }
        }

        processedSchema.properties[key] = processedProperty;
      }
    }
  }

  // Process array items at root level
  if (processedSchema.items) {
    processedSchema.items = preprocessInputSchema(
      processedSchema.items as RJSFSchema,
    );
  }

  processedSchema.title = ""; // Otherwise our form creator will show the title of the schema in the input field
  processedSchema.description = ""; // Otherwise our form creator will show the description of the schema in the input field

  return processedSchema;
}

/**
 * Helper function to process properties object
 */
function processProperties(properties: any): any {
  if (!properties || typeof properties !== "object") {
    return properties;
  }

  const processedProperties = { ...properties };

  for (const [key, property] of Object.entries(processedProperties)) {
    if (property && typeof property === "object") {
      processedProperties[key] = preprocessInputSchema(property as RJSFSchema);
    }
  }

  return processedProperties;
}
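
A before/after sketch of the pre-processor on an untyped property (illustrative input, not from the commit):

// preprocessInputSchema({ properties: { value: { title: "Value" } } })
// => value gains:
//    anyOf: [ { type: "string" }, { type: "number" }, { type: "integer" },
//             { type: "boolean" }, { type: "array", items: { type: "string" } },
//             { type: "object" }, { type: "null" } ]
// Typed properties pass through unchanged, apart from recursion into nested
// objects and array items, and the root title/description are blanked.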
@@ -6,6 +6,7 @@ import { beautifyString } from "@/lib/utils";
import { useAllBlockContent } from "./useAllBlockContent";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { blockMenuContainerStyle } from "../style";
import { useNodeStore } from "../../../stores/nodeStore";

export const AllBlocksContent = () => {
  const {
@@ -18,6 +19,8 @@ export const AllBlocksContent = () => {
    isErrorOnLoadingMore,
  } = useAllBlockContent();

  const addBlock = useNodeStore((state) => state.addBlock);

  if (isLoading) {
    return (
      <div className={blockMenuContainerStyle}>
@@ -71,6 +74,7 @@ export const AllBlocksContent = () => {
                  key={`${category.name}-${block.id}`}
                  title={block.name as string}
                  description={block.name as string}
                  onClick={() => addBlock(block)}
                />
              ))}

@@ -1,18 +1,11 @@
import React from "react";
import { Block } from "../Block";
import { blockMenuContainerStyle } from "../style";

export interface BlockType {
  id: string;
  name: string;
  description: string;
  category?: string;
  type?: string;
  provider?: string;
}
import { useNodeStore } from "../../../stores/nodeStore";
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";

interface BlocksListProps {
  blocks: BlockType[];
  blocks: BlockInfo[];
  loading?: boolean;
}

@@ -20,6 +13,7 @@ export const BlocksList: React.FC<BlocksListProps> = ({
  blocks,
  loading = false,
}) => {
  const { addBlock } = useNodeStore();
  if (loading) {
    return (
      <div className={blockMenuContainerStyle}>
@@ -30,6 +24,11 @@ export const BlocksList: React.FC<BlocksListProps> = ({
    );
  }
  return blocks.map((block) => (
    <Block key={block.id} title={block.name} description={block.description} />
    <Block
      key={block.id}
      title={block.name}
      description={block.description}
      onClick={() => addBlock(block)}
    />
  ));
};

@@ -11,7 +11,7 @@ import { BlockMenuStateProvider } from "../block-menu-provider";
import { LegoIcon } from "@phosphor-icons/react";

interface BlockMenuProps {
  pinBlocksPopover: boolean;
  // pinBlocksPopover: boolean;
  blockMenuSelected: "save" | "block" | "search" | "";
  setBlockMenuSelected: React.Dispatch<
    React.SetStateAction<"" | "save" | "block" | "search">
@@ -19,16 +19,17 @@ interface BlockMenuProps {
}

export const BlockMenu: React.FC<BlockMenuProps> = ({
  pinBlocksPopover,
  // pinBlocksPopover,
  blockMenuSelected,
  setBlockMenuSelected,
}) => {
  const { open, onOpen } = useBlockMenu({
    pinBlocksPopover,
  const { open: _open, onOpen } = useBlockMenu({
    // pinBlocksPopover,
    setBlockMenuSelected,
  });
  return (
    <Popover open={pinBlocksPopover ? true : open} onOpenChange={onOpen}>
      // pinBlocksPopover ? true : open
    <Popover onOpenChange={onOpen}>
      <PopoverTrigger className="hover:cursor-pointer">
        <ControlPanelButton
          data-id="blocks-control-popover-trigger"

@@ -1,22 +1,22 @@
import { useState } from "react";

interface useBlockMenuProps {
  pinBlocksPopover: boolean;
  // pinBlocksPopover: boolean;
  setBlockMenuSelected: React.Dispatch<
    React.SetStateAction<"" | "save" | "block" | "search">
  >;
}

export const useBlockMenu = ({
  pinBlocksPopover,
  // pinBlocksPopover,
  setBlockMenuSelected,
}: useBlockMenuProps) => {
  const [open, setOpen] = useState(false);
  const onOpen = (newOpen: boolean) => {
    if (!pinBlocksPopover) {
      setOpen(newOpen);
      setBlockMenuSelected(newOpen ? "block" : "");
    }
    // if (!pinBlocksPopover) {
    setOpen(newOpen);
    setBlockMenuSelected(newOpen ? "block" : "");
    // }
  };

  return {

@@ -6,6 +6,7 @@ import { Skeleton } from "@/components/__legacy__/ui/skeleton";
import { useIntegrationBlocks } from "./useIntegrationBlocks";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { InfiniteScroll } from "@/components/contextual/InfiniteScroll/InfiniteScroll";
import { useNodeStore } from "../../../stores/nodeStore";

export const IntegrationBlocks = () => {
  const { integration, setIntegration } = useBlockMenuContext();
@@ -20,6 +21,7 @@ export const IntegrationBlocks = () => {
    error,
    refetch,
  } = useIntegrationBlocks();
  const addBlock = useNodeStore((state) => state.addBlock);

  if (blocksLoading) {
    return (
@@ -92,6 +94,7 @@ export const IntegrationBlocks = () => {
              title={block.name}
              description={block.description}
              icon_url={`/integrations/${integration}.png`}
              onClick={() => addBlock(block)}
            />
          ))}
        </div>
@@ -1,15 +1,15 @@
import { Separator } from "@/components/__legacy__/ui/separator";
// import { Separator } from "@/components/__legacy__/ui/separator";
import { cn } from "@/lib/utils";
import React, { useMemo } from "react";
import { BlockMenu } from "../BlockMenu/BlockMenu";
import { useNewControlPanel } from "./useNewControlPanel";
import { NewSaveControl } from "../SaveControl/NewSaveControl";
// import { NewSaveControl } from "../SaveControl/NewSaveControl";
import { GraphExecutionID } from "@/lib/autogpt-server-api";
import { history } from "@/app/(platform)/build/components/legacy-builder/history";
import { ControlPanelButton } from "../ControlPanelButton";
// import { ControlPanelButton } from "../ControlPanelButton";
import { ArrowUUpLeftIcon, ArrowUUpRightIcon } from "@phosphor-icons/react";
import { GraphSearchMenu } from "../GraphMenu/GraphMenu";
import { CustomNode } from "@/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode";
// import { GraphSearchMenu } from "../GraphMenu/GraphMenu";
import { CustomNode } from "../../FlowEditor/nodes/CustomNode";
import { history } from "@/app/(platform)/build/components/legacy-builder/history";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";

export type Control = {
@@ -19,44 +19,41 @@ export type Control = {
  onClick: () => void;
};

interface ControlPanelProps {
  className?: string;
  flowExecutionID: GraphExecutionID | undefined;
  visualizeBeads: "no" | "static" | "animate";
  pinSavePopover: boolean;
  pinBlocksPopover: boolean;
  nodes: CustomNode[];
  onNodeSelect: (nodeId: string) => void;
  onNodeHover?: (nodeId: string | null) => void;
}

export type NewControlPanelProps = {
  flowExecutionID?: GraphExecutionID | undefined;
  visualizeBeads?: "no" | "static" | "animate";
  pinSavePopover?: boolean;
  pinBlocksPopover?: boolean;
  nodes?: CustomNode[];
  onNodeSelect?: (nodeId: string) => void;
  onNodeHover?: (nodeId: string) => void;
};
export const NewControlPanel = ({
  flowExecutionID,
  visualizeBeads,
  pinSavePopover,
  pinBlocksPopover,
  nodes,
  onNodeSelect,
  onNodeHover,
  className,
}: ControlPanelProps) => {
  const isGraphSearchEnabled = useGetFlag(Flag.GRAPH_SEARCH);
  flowExecutionID: _flowExecutionID,
  visualizeBeads: _visualizeBeads,
  pinSavePopover: _pinSavePopover,
  pinBlocksPopover: _pinBlocksPopover,
  nodes: _nodes,
  onNodeSelect: _onNodeSelect,
  onNodeHover: _onNodeHover,
}: NewControlPanelProps) => {
  const _isGraphSearchEnabled = useGetFlag(Flag.GRAPH_SEARCH);

  const {
    blockMenuSelected,
    setBlockMenuSelected,
    agentDescription,
    setAgentDescription,
    saveAgent,
    agentName,
    setAgentName,
    savedAgent,
    isSaving,
    isRunning,
    isStopping,
  } = useNewControlPanel({ flowExecutionID, visualizeBeads });
    // agentDescription,
    // setAgentDescription,
    // saveAgent,
    // agentName,
    // setAgentName,
    // savedAgent,
    // isSaving,
    // isRunning,
    // isStopping,
  } = useNewControlPanel({});

  const controls: Control[] = useMemo(
  const _controls: Control[] = useMemo(
    () => [
      {
        label: "Undo",
@@ -77,17 +74,16 @@ export const NewControlPanel = ({
  return (
    <section
      className={cn(
        "absolute left-4 top-24 z-10 w-[4.25rem] overflow-hidden rounded-[1rem] border-none bg-white p-0 shadow-[0_1px_5px_0_rgba(0,0,0,0.1)]",
        className,
"top- absolute left-4 z-10 w-[4.25rem] overflow-hidden rounded-[1rem] border-none bg-white p-0 shadow-[0_1px_5px_0_rgba(0,0,0,0.1)]",
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-col items-center justify-center rounded-[1rem] p-0">
|
||||
<BlockMenu
|
||||
pinBlocksPopover={pinBlocksPopover}
|
||||
// pinBlocksPopover={pinBlocksPopover}
|
||||
blockMenuSelected={blockMenuSelected}
|
||||
setBlockMenuSelected={setBlockMenuSelected}
|
||||
/>
|
||||
<Separator className="text-[#E1E1E1]" />
|
||||
{/* <Separator className="text-[#E1E1E1]" />
|
||||
{isGraphSearchEnabled && (
|
||||
<>
|
||||
<GraphSearchMenu
|
||||
@@ -124,7 +120,7 @@ export const NewControlPanel = ({
|
||||
pinSavePopover={pinSavePopover}
|
||||
blockMenuSelected={blockMenuSelected}
|
||||
setBlockMenuSelected={setBlockMenuSelected}
|
||||
/>
|
||||
/> */}
|
||||
</div>
|
||||
</section>
|
||||
);
|
||||
|
||||
@@ -1,53 +1,54 @@
import useAgentGraph from "@/hooks/useAgentGraph";
import { GraphExecutionID, GraphID } from "@/lib/autogpt-server-api";
import { GraphID } from "@/lib/autogpt-server-api";
import { useSearchParams } from "next/navigation";
import { useState } from "react";

export interface NewControlPanelProps {
  flowExecutionID: GraphExecutionID | undefined;
  visualizeBeads: "no" | "static" | "animate";
  // flowExecutionID: GraphExecutionID | undefined;
  visualizeBeads?: "no" | "static" | "animate";
}

export const useNewControlPanel = ({
  flowExecutionID,
  visualizeBeads,
  // flowExecutionID,
  visualizeBeads: _visualizeBeads,
}: NewControlPanelProps) => {
  const [blockMenuSelected, setBlockMenuSelected] = useState<
    "save" | "block" | "search" | ""
  >("");
  const query = useSearchParams();
  const _graphVersion = query.get("flowVersion");
  const graphVersion = _graphVersion ? parseInt(_graphVersion) : undefined;
  const _graphVersionParsed = _graphVersion
    ? parseInt(_graphVersion)
    : undefined;

  const flowID = (query.get("flowID") as GraphID | null) ?? undefined;
  const {
    agentDescription,
    setAgentDescription,
    saveAgent,
    agentName,
    setAgentName,
    savedAgent,
    isSaving,
    isRunning,
    isStopping,
  } = useAgentGraph(
    flowID,
    graphVersion,
    flowExecutionID,
    visualizeBeads !== "no",
  );
  const _flowID = (query.get("flowID") as GraphID | null) ?? undefined;
  // const {
  //   agentDescription,
  //   setAgentDescription,
  //   saveAgent,
  //   agentName,
  //   setAgentName,
  //   savedAgent,
  //   isSaving,
  //   isRunning,
  //   isStopping,
  // } = useAgentGraph(
  //   flowID,
  //   graphVersion,
  //   flowExecutionID,
  //   visualizeBeads !== "no",
  // );

  return {
    blockMenuSelected,
    setBlockMenuSelected,
    agentDescription,
    setAgentDescription,
    saveAgent,
    agentName,
    setAgentName,
    savedAgent,
    isSaving,
    isRunning,
    isStopping,
    // agentDescription,
    // setAgentDescription,
    // saveAgent,
    // agentName,
    // setAgentName,
    // savedAgent,
    // isSaving,
    // isRunning,
    // isStopping,
  };
};

@@ -5,10 +5,12 @@ import { DefaultStateType, useBlockMenuContext } from "../block-menu-provider";
import { useSuggestionContent } from "./useSuggestionContent";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { blockMenuContainerStyle } from "../style";
import { useNodeStore } from "../../../stores/nodeStore";

export const SuggestionContent = () => {
  const { setIntegration, setDefaultState } = useBlockMenuContext();
  const { data, isLoading, isError, error, refetch } = useSuggestionContent();
  const addBlock = useNodeStore((state) => state.addBlock);

  if (isError) {
    return (
@@ -73,6 +75,7 @@ export const SuggestionContent = () => {
                      key={`block-${index}`}
                      title={block.name}
                      description={block.description}
                      onClick={() => addBlock(block)}
                    />
                  ))
                : Array(3)
@@ -0,0 +1,88 @@
import { useMemo } from "react";

import { Link } from "@/app/api/__generated__/models/link";
import { useEdgeStore } from "../stores/edgeStore";
import { useNodeStore } from "../stores/nodeStore";
import { scrollbarStyles } from "@/components/styles/scrollbars";
import { cn } from "@/lib/utils";

export const RightSidebar = () => {
  const connections = useEdgeStore((s) => s.connections);
  const nodes = useNodeStore((s) => s.nodes);

  const backendLinks: Link[] = useMemo(
    () =>
      connections.map((c) => ({
        source_id: c.source,
        sink_id: c.target,
        source_name: c.sourceHandle,
        sink_name: c.targetHandle,
      })),
    [connections],
  );

  return (
    <div
      className={cn(
        "flex h-full w-full flex-col border-l border-slate-200 bg-white p-4 dark:border-slate-700 dark:bg-slate-900",
        scrollbarStyles,
      )}
    >
      <div className="mb-4">
        <h2 className="text-lg font-semibold text-slate-800 dark:text-slate-200">
          Flow Debug Panel
        </h2>
      </div>

      <div className="flex-1 overflow-y-auto">
        <h3 className="mb-2 text-sm font-semibold text-slate-700 dark:text-slate-200">
          Nodes ({nodes.length})
        </h3>
        <div className="mb-6 space-y-3">
          {nodes.map((n) => (
            <div
              key={n.id}
              className="rounded border p-2 text-xs dark:border-slate-700"
            >
              <div className="mb-1 font-medium">
                #{n.id} {n.data?.title ? `– ${n.data.title}` : ""}
              </div>
              <div className="text-slate-500 dark:text-slate-400">
                hardcodedValues
              </div>
              <pre className="mt-1 max-h-40 overflow-auto rounded bg-slate-50 p-2 dark:bg-slate-800">
                {JSON.stringify(n.data?.hardcodedValues ?? {}, null, 2)}
              </pre>
            </div>
          ))}
        </div>

        <h3 className="mb-2 text-sm font-semibold text-slate-700 dark:text-slate-200">
          Links ({backendLinks.length})
        </h3>
        <div className="mb-6 space-y-3">
          {connections.map((c) => (
            <div
              key={c.edge_id}
              className="rounded border p-2 text-xs dark:border-slate-700"
            >
              <div className="font-medium">
                {c.source}[{c.sourceHandle}] → {c.target}[{c.targetHandle}]
              </div>
              <div className="mt-1 text-slate-500 dark:text-slate-400">
                edge_id: {c.edge_id}
              </div>
            </div>
          ))}
        </div>

        <h4 className="mb-2 text-xs font-semibold text-slate-600 dark:text-slate-300">
          Backend Links JSON
        </h4>
        <pre className="max-h-64 overflow-auto rounded bg-slate-50 p-2 text-[11px] dark:bg-slate-800">
          {JSON.stringify(backendLinks, null, 2)}
        </pre>
      </div>
    </div>
  );
};
@@ -0,0 +1,13 @@
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
import { CustomNodeData } from "./FlowEditor/nodes/CustomNode";

export const convertBlockInfoIntoCustomNodeData = (block: BlockInfo) => {
  const customNodeData: CustomNodeData = {
    hardcodedValues: {},
    title: block.name,
    description: block.description,
    inputSchema: block.inputSchema,
    outputSchema: block.outputSchema,
  };
  return customNodeData;
};
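
How this converter relates to the addBlock calls in the diffs above, as a hedged sketch; the exact node-creation call lives in nodeStore and may differ:

// Hypothetical wiring (not part of the commit):
// const data = convertBlockInfoIntoCustomNodeData(block);
// data.hardcodedValues starts empty; FormCreator fills it as the user edits,
// and addBlock(block) in the block menu presumably uses this payload when
// creating the new canvas node.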
@@ -858,7 +858,7 @@ const FlowEditor: React.FC<{
|
||||
visualizeBeads={visualizeBeads}
|
||||
pinSavePopover={pinSavePopover}
|
||||
pinBlocksPopover={pinBlocksPopover}
|
||||
nodes={nodes}
|
||||
// nodes={nodes}
|
||||
onNodeSelect={navigateToNode}
|
||||
onNodeHover={highlightNode}
|
||||
/>
|
||||
|
||||
@@ -2,10 +2,13 @@

import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
import FlowEditor from "@/app/(platform)/build/components/legacy-builder/Flow/Flow";
import LoadingBox from "@/components/__legacy__/ui/loading";
// import LoadingBox from "@/components/__legacy__/ui/loading";
import { GraphID } from "@/lib/autogpt-server-api/types";
import { useSearchParams } from "next/navigation";
import { Suspense, useEffect } from "react";
import { useEffect } from "react";
import { Flow } from "./components/FlowEditor/Flow";
import { BuilderViewTabs } from "./components/BuilderViewTabs/BuilderViewTabs";
import { useBuilderView } from "./components/BuilderViewTabs/useBuilderViewTabs";

function BuilderContent() {
  const query = useSearchParams();
@@ -19,7 +22,7 @@ function BuilderContent() {
  const graphVersion = _graphVersion ? parseInt(_graphVersion) : undefined;
  return (
    <FlowEditor
      className="flow-container"
      className="flex h-full w-full"
      flowID={(query.get("flowID") as GraphID | null) ?? undefined}
      flowVersion={graphVersion}
    />
@@ -27,9 +30,22 @@
}

export default function BuilderPage() {
  return (
    <Suspense fallback={<LoadingBox className="h-[80vh]" />}>
      <BuilderContent />
    </Suspense>
  );
  const {
    isSwitchEnabled,
    selectedView,
    setSelectedView,
    isNewFlowEditorEnabled,
  } = useBuilderView();

  // Switch is temporary, we will remove it once our new flow editor is ready
  if (isSwitchEnabled) {
    return (
      <div className="relative h-full w-full">
        <BuilderViewTabs value={selectedView} onChange={setSelectedView} />
        {selectedView === "new" ? <Flow /> : <BuilderContent />}
      </div>
    );
  }

  return isNewFlowEditorEnabled ? <Flow /> : <BuilderContent />;
}

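For context, a hypothetical shape for the useBuilderView hook that is consistent with the destructuring above and with the NEW_FLOW_EDITOR / BUILDER_VIEW_SWITCH flags added further down. The real hook lives in useBuilderViewTabs and may differ; the flag reads here are placeholders:

import { useState } from "react";

type BuilderView = "new" | "old";

// Sketch only — not the repository's implementation.
export function useBuilderViewSketch() {
  const isNewFlowEditorEnabled = false; // e.g. resolved from Flag.NEW_FLOW_EDITOR
  const isSwitchEnabled = false; // e.g. resolved from Flag.BUILDER_VIEW_SWITCH
  const [selectedView, setSelectedView] = useState<BuilderView>("new");
  return { isSwitchEnabled, selectedView, setSelectedView, isNewFlowEditorEnabled };
}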
@@ -0,0 +1,82 @@
import { create } from "zustand";
import { convertConnectionsToBackendLinks } from "../components/FlowEditor/edges/helpers";

export type Connection = {
  edge_id: string;
  source: string;
  sourceHandle: string;
  target: string;
  targetHandle: string;
};

type EdgeStore = {
  connections: Connection[];

  setConnections: (connections: Connection[]) => void;
  addConnection: (
    conn: Omit<Connection, "edge_id"> & { edge_id?: string },
  ) => Connection;
  removeConnection: (edge_id: string) => void;
  upsertMany: (conns: Connection[]) => void;

  getNodeConnections: (nodeId: string) => Connection[];
  isInputConnected: (nodeId: string, handle: string) => boolean;
  isOutputConnected: (nodeId: string, handle: string) => boolean;
  getBackendLinks: () => ReturnType<typeof convertConnectionsToBackendLinks>;
};

function makeEdgeId(conn: Omit<Connection, "edge_id">) {
  return `${conn.source}:${conn.sourceHandle}->${conn.target}:${conn.targetHandle}`;
}

export const useEdgeStore = create<EdgeStore>((set, get) => ({
  connections: [],

  setConnections: (connections) => set({ connections }),

  addConnection: (conn) => {
    const edge_id = conn.edge_id || makeEdgeId(conn);
    // Spread first so an explicit-but-undefined conn.edge_id cannot clobber the generated id.
    const newConn: Connection = { ...conn, edge_id };

    set((state) => {
      const exists = state.connections.some(
        (c) =>
          c.source === newConn.source &&
          c.target === newConn.target &&
          c.sourceHandle === newConn.sourceHandle &&
          c.targetHandle === newConn.targetHandle,
      );
      if (exists) return state;
      return { connections: [...state.connections, newConn] };
    });

    return newConn;
  },

  removeConnection: (edge_id) =>
    set((state) => ({
      connections: state.connections.filter((c) => c.edge_id !== edge_id),
    })),

  upsertMany: (conns) =>
    set((state) => {
      const byKey = new Map(state.connections.map((c) => [c.edge_id, c]));
      conns.forEach((c) => {
        byKey.set(c.edge_id, c);
      });
      return { connections: Array.from(byKey.values()) };
    }),

  getNodeConnections: (nodeId) =>
    get().connections.filter((c) => c.source === nodeId || c.target === nodeId),

  isInputConnected: (nodeId, handle) =>
    get().connections.some(
      (c) => c.target === nodeId && c.targetHandle === handle,
    ),

  isOutputConnected: (nodeId, handle) =>
    get().connections.some(
      (c) => c.source === nodeId && c.sourceHandle === handle,
    ),

  getBackendLinks: () => convertConnectionsToBackendLinks(get().connections),
}));
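A brief usage sketch outside React — zustand stores expose getState() for non-hook access:

const { addConnection, isInputConnected, removeConnection } =
  useEdgeStore.getState();

const conn = addConnection({
  source: "1",
  sourceHandle: "output",
  target: "2",
  targetHandle: "input",
});
// conn.edge_id defaults to "1:output->2:input" via makeEdgeId.
isInputConnected("2", "input"); // true
removeConnection(conn.edge_id);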
@@ -0,0 +1,75 @@
import { create } from "zustand";
import { NodeChange, applyNodeChanges } from "@xyflow/react";
import { CustomNode } from "../components/FlowEditor/nodes/CustomNode";
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
import { convertBlockInfoIntoCustomNodeData } from "../components/helper";

type NodeStore = {
  nodes: CustomNode[];
  nodeCounter: number;
  nodeAdvancedStates: Record<string, boolean>;
  setNodes: (nodes: CustomNode[]) => void;
  onNodesChange: (changes: NodeChange<CustomNode>[]) => void;
  addNode: (node: CustomNode) => void;
  addBlock: (block: BlockInfo) => void;
  incrementNodeCounter: () => void;
  updateNodeData: (nodeId: string, data: Partial<CustomNode["data"]>) => void;
  toggleAdvanced: (nodeId: string) => void;
  setShowAdvanced: (nodeId: string, show: boolean) => void;
  getShowAdvanced: (nodeId: string) => boolean;
};

export const useNodeStore = create<NodeStore>((set, get) => ({
  nodes: [],
  setNodes: (nodes) => set({ nodes }),
  nodeCounter: 0,
  nodeAdvancedStates: {},
  incrementNodeCounter: () =>
    set((state) => ({
      nodeCounter: state.nodeCounter + 1,
    })),
  onNodesChange: (changes) =>
    set((state) => ({
      nodes: applyNodeChanges(changes, state.nodes),
    })),
  addNode: (node) =>
    set((state) => ({
      nodes: [...state.nodes, node],
    })),
  addBlock: (block: BlockInfo) => {
    const customNodeData = convertBlockInfoIntoCustomNodeData(block);
    get().incrementNodeCounter();
    const nodeNumber = get().nodeCounter;
    const customNode: CustomNode = {
      id: nodeNumber.toString(),
      data: customNodeData,
      type: "custom",
      position: { x: 0, y: 0 },
    };
    set((state) => ({
      nodes: [...state.nodes, customNode],
    }));
  },
  updateNodeData: (nodeId, data) =>
    set((state) => ({
      nodes: state.nodes.map((n) =>
        n.id === nodeId ? { ...n, data: { ...n.data, ...data } } : n,
      ),
    })),
  toggleAdvanced: (nodeId: string) =>
    set((state) => ({
      nodeAdvancedStates: {
        ...state.nodeAdvancedStates,
        [nodeId]: !state.nodeAdvancedStates[nodeId],
      },
    })),
  setShowAdvanced: (nodeId: string, show: boolean) =>
    set((state) => ({
      nodeAdvancedStates: {
        ...state.nodeAdvancedStates,
        [nodeId]: show,
      },
    })),
  getShowAdvanced: (nodeId: string) =>
    get().nodeAdvancedStates[nodeId] || false,
}));
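Usage follows the same getState() pattern; addBlock runs the BlockInfo conversion shown earlier and assigns a counter-based id:

useNodeStore.getState().addBlock(block); // `block` as in the helper sketch above
// Appends { id: "1", type: "custom", position: { x: 0, y: 0 }, data: { hardcodedValues: {}, ... } }.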
@@ -3,9 +3,9 @@ import { ReactNode } from "react";

export default function PlatformLayout({ children }: { children: ReactNode }) {
  return (
    <>
    <main className="flex h-screen w-screen flex-col">
      <Navbar />
      <main>{children}</main>
    </>
      <section className="flex-1">{children}</section>
    </main>
  );
}

@@ -0,0 +1,143 @@
"use client";

import * as React from "react";
import { Calendar as CalendarIcon } from "lucide-react";
import { Button } from "@/components/__legacy__/ui/button";
import { cn } from "@/lib/utils";
import {
  Popover,
  PopoverContent,
  PopoverTrigger,
} from "@/components/__legacy__/ui/popover";
import { Calendar } from "@/components/__legacy__/ui/calendar";

function toLocalISODateString(d: Date) {
  const year = d.getFullYear();
  const month = String(d.getMonth() + 1).padStart(2, "0");
  const day = String(d.getDate()).padStart(2, "0");
  return `${year}-${month}-${day}`;
}

function parseISODateString(s?: string): Date | undefined {
  if (!s) return undefined;
  // Expecting "YYYY-MM-DD"
  const m = /^(\d{4})-(\d{2})-(\d{2})$/.exec(s);
  if (!m) return undefined;
  const [, y, mo, d] = m;
  const date = new Date(Number(y), Number(mo) - 1, Number(d));
  return isNaN(date.getTime()) ? undefined : date;
}

export interface DateInputProps {
  value?: string;
  onChange?: (value?: string) => void;
  disabled?: boolean;
  readonly?: boolean;
  placeholder?: string;
  autoFocus?: boolean;
  className?: string;
  label?: string;
  hideLabel?: boolean;
  error?: string;
  id?: string;
  size?: "default" | "small";
}

export const DateInput = ({
  value,
  onChange,
  disabled,
  readonly,
  placeholder,
  autoFocus,
  className,
  label,
  hideLabel = false,
  error,
  id,
  size = "default",
}: DateInputProps) => {
  const selected = React.useMemo(() => parseISODateString(value), [value]);
  const [open, setOpen] = React.useState(false);

  const setDate = (d?: Date) => {
    onChange?.(d ? toLocalISODateString(d) : undefined);
    setOpen(false);
  };

  const buttonText =
    selected?.toLocaleDateString(undefined, {
      year: "numeric",
      month: "short",
      day: "numeric",
    }) ||
    placeholder ||
    "Pick a date";

  const isDisabled = disabled || readonly;

  const triggerStyles = cn(
    // Base styles matching other form components
    "rounded-3xl border border-zinc-200 bg-white px-4 shadow-none",
    "font-normal text-black w-full text-sm",
    "placeholder:font-normal !placeholder:text-zinc-400",
    // Focus and hover states
    "focus:border-zinc-400 focus:shadow-none focus:outline-none focus:ring-1 focus:ring-zinc-400 focus:ring-offset-0",
    // Error state
    error &&
      "border-1.5 border-red-500 focus:border-red-500 focus:ring-red-500",
    // Placeholder styling
    !selected && "text-zinc-400",
    "justify-start text-left",
    // Size variants
    size === "default" && "h-[2.875rem] py-2.5",
    size === "small" && [
      "min-h-[2.25rem]", // 36px minimum
      "py-2",
      "text-sm leading-[22px]",
      "placeholder:text-sm placeholder:leading-[22px]",
    ],
    // Caller overrides come last so they win over the size variants
    className,
  );

  return (
    <div className="flex flex-col gap-1">
      {label && !hideLabel && (
        <label htmlFor={id} className="text-sm font-medium text-gray-700">
          {label}
        </label>
      )}
      <Popover open={open} onOpenChange={setOpen}>
        <PopoverTrigger asChild>
          <Button
            type="button"
            variant="ghost"
            className={triggerStyles}
            disabled={isDisabled}
            autoFocus={autoFocus}
            id={id}
            {...(hideLabel && label ? { "aria-label": label } : {})}
          >
            <CalendarIcon
              className={cn("mr-2", size === "default" ? "h-4 w-4" : "h-3 w-3")}
            />
            {buttonText}
          </Button>
        </PopoverTrigger>
        <PopoverContent className="w-auto p-0" sideOffset={6}>
          <Calendar
            mode="single"
            selected={selected}
            onSelect={setDate}
            showOutsideDays
            // Style days carrying the "disabled" modifier as inert
            // (the popover trigger itself is disabled when disabled/readonly)
            modifiersClassNames={{
              disabled: "pointer-events-none opacity-50",
            }}
          />
        </PopoverContent>
      </Popover>
      {error && <span className="text-sm text-red-500">{error}</span>}
    </div>
  );
};
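Worth noting: toLocalISODateString serializes in local time on purpose. Going through Date.prototype.toISOString would convert to UTC first and can shift the calendar date, e.g. in a UTC+8 zone:

const d = new Date(2024, 2, 10); // local midnight, 10 March 2024 (months are 0-based)
toLocalISODateString(d);       // "2024-03-10" — the date the user picked
d.toISOString().slice(0, 10);  // "2024-03-09" — UTC conversion loses a day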
@@ -0,0 +1,253 @@
"use client";

import * as React from "react";
import { Calendar as CalendarIcon, Clock } from "lucide-react";
import { Button } from "@/components/atoms/Button/Button";
import { cn } from "@/lib/utils";

import { Text } from "../Text/Text";
import {
  Popover,
  PopoverContent,
  PopoverTrigger,
} from "@/components/__legacy__/ui/popover";
import { Calendar } from "@/components/__legacy__/ui/calendar";

function toLocalISODateTimeString(d: Date) {
  const year = d.getFullYear();
  const month = String(d.getMonth() + 1).padStart(2, "0");
  const day = String(d.getDate()).padStart(2, "0");
  const hours = String(d.getHours()).padStart(2, "0");
  const minutes = String(d.getMinutes()).padStart(2, "0");
  return `${year}-${month}-${day}T${hours}:${minutes}`;
}

function parseISODateTimeString(s?: string): Date | undefined {
  if (!s) return undefined;
  // Expecting "YYYY-MM-DDTHH:MM" or "YYYY-MM-DD HH:MM"
  const normalized = s.replace(" ", "T");
  const date = new Date(normalized);
  return isNaN(date.getTime()) ? undefined : date;
}

export interface DateTimeInputProps {
  value?: string;
  onChange?: (value?: string) => void;
  disabled?: boolean;
  readonly?: boolean;
  placeholder?: string;
  autoFocus?: boolean;
  className?: string;
  label?: string;
  hideLabel?: boolean;
  error?: string;
  hint?: React.ReactNode;
  id?: string;
  size?: "default" | "small";
  wrapperClassName?: string;
}

export const DateTimeInput = ({
  value,
  onChange,
  disabled = false,
  readonly = false,
  placeholder,
  autoFocus,
  className,
  label,
  hideLabel = false,
  error,
  hint,
  id,
  size = "default",
  wrapperClassName,
}: DateTimeInputProps) => {
  const selected = React.useMemo(() => parseISODateTimeString(value), [value]);
  const [open, setOpen] = React.useState(false);
  const [timeValue, setTimeValue] = React.useState("");

  // Update time value when selected date changes
  React.useEffect(() => {
    if (selected) {
      const hours = String(selected.getHours()).padStart(2, "0");
      const minutes = String(selected.getMinutes()).padStart(2, "0");
      setTimeValue(`${hours}:${minutes}`);
    } else {
      setTimeValue("");
    }
  }, [selected]);

  const setDate = (d?: Date) => {
    if (!d) {
      onChange?.(undefined);
      setOpen(false);
      return;
    }

    // Copy before mutating so the Date owned by the Calendar is left untouched
    const next = new Date(d);

    // If we have a time value, apply it to the selected date
    if (timeValue) {
      const [hours, minutes] = timeValue.split(":").map(Number);
      if (!isNaN(hours) && !isNaN(minutes)) {
        next.setHours(hours, minutes, 0, 0);
      }
    }

    onChange?.(toLocalISODateTimeString(next));
    setOpen(false);
  };

  const handleTimeChange = (time: string) => {
    setTimeValue(time);

    if (selected && time) {
      const [hours, minutes] = time.split(":").map(Number);
      if (!isNaN(hours) && !isNaN(minutes)) {
        const newDate = new Date(selected);
        newDate.setHours(hours, minutes, 0, 0);
        onChange?.(toLocalISODateTimeString(newDate));
      }
    }
  };

  const buttonText = selected
    ? selected.toLocaleDateString(undefined, {
        year: "numeric",
        month: "short",
        day: "numeric",
      }) +
      " " +
      selected.toLocaleTimeString(undefined, {
        hour: "2-digit",
        minute: "2-digit",
      })
    : placeholder || "Pick date and time";

  const isDisabled = disabled || readonly;

  const triggerStyles = cn(
    // Base styles matching other form components
    "rounded-3xl border border-zinc-200 bg-white px-4 shadow-none",
    "font-normal text-black w-full text-sm",
    "placeholder:font-normal !placeholder:text-zinc-400",
    // Focus and hover states
    "focus:border-zinc-400 focus:shadow-none focus:outline-none focus:ring-1 focus:ring-zinc-400 focus:ring-offset-0",
    // Error state
    error &&
      "border-1.5 border-red-500 focus:border-red-500 focus:ring-red-500",
    // Placeholder styling
    !selected && "text-zinc-400",
    "justify-start text-left",
    // Size variants
    size === "default" && "h-[2.875rem] py-2.5",
    size === "small" && [
      "min-h-[2.25rem]", // 36px minimum
      "py-2",
      "text-sm leading-[22px]",
      "placeholder:text-sm placeholder:leading-[22px]",
    ],
    className,
  );

  const timeInputStyles = cn(
    // Base styles
    "rounded-3xl border border-zinc-200 bg-white px-4 shadow-none",
    "font-normal text-black w-full",
    "placeholder:font-normal placeholder:text-zinc-400",
    // Focus and hover states
    "focus:border-zinc-400 focus:shadow-none focus:outline-none focus:ring-1 focus:ring-zinc-400 focus:ring-offset-0",
    // Size variants
    size === "small" && [
      "h-[2.25rem]", // 36px
      "py-2",
      "text-sm leading-[22px]", // 14px font, 22px line height
      "placeholder:text-sm placeholder:leading-[22px]",
    ],
    size === "default" && [
      "h-[2.875rem]", // 46px
      "py-2.5",
    ],
  );

  const inputWithError = (
    <div className={cn("relative", error ? "mb-6" : "", wrapperClassName)}>
      <Popover open={open} onOpenChange={setOpen}>
        <PopoverTrigger asChild>
          <Button
            type="button"
            variant="ghost"
            className={triggerStyles}
            disabled={isDisabled}
            autoFocus={autoFocus}
            id={id}
            {...(hideLabel && label ? { "aria-label": label } : {})}
          >
            <CalendarIcon
              className={cn("mr-2", size === "default" ? "h-4 w-4" : "h-3 w-3")}
            />
            <Clock
              className={cn("mr-2", size === "default" ? "h-4 w-4" : "h-3 w-3")}
            />
            {buttonText}
          </Button>
        </PopoverTrigger>
        <PopoverContent className="w-auto p-0" sideOffset={6}>
          <div className="p-3">
            <Calendar
              mode="single"
              selected={selected}
              onSelect={setDate}
              showOutsideDays
              modifiersClassNames={{
                disabled: "pointer-events-none opacity-50",
              }}
            />
            <div className="mt-3 border-t pt-3">
              <label className="mb-2 block text-sm font-medium text-gray-700">
                Time
              </label>
              <input
                type="time"
                value={timeValue}
                onChange={(e) => handleTimeChange(e.target.value)}
                className={timeInputStyles}
                disabled={isDisabled}
                placeholder="HH:MM"
              />
            </div>
          </div>
        </PopoverContent>
      </Popover>
      {error && (
        <Text
          variant="small-medium"
          as="span"
          className={cn(
            "absolute left-0 top-full mt-1 !text-red-500 transition-opacity duration-200",
            error ? "opacity-100" : "opacity-0",
          )}
        >
          {error || " "}
        </Text>
      )}
    </div>
  );

  return hideLabel || !label ? (
    inputWithError
  ) : (
    <label htmlFor={id} className="flex flex-col gap-2">
      <div className="flex items-center justify-between">
        <Text variant="body-medium" as="span" className="text-black">
          {label}
        </Text>
        {hint ? (
          <Text variant="small" as="span" className="!text-zinc-400">
            {hint}
          </Text>
        ) : null}
      </div>
      {inputWithError}
    </label>
  );
};
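The value format is local time with minute precision, and the parse/serialize pair round-trips cleanly:

parseISODateTimeString("2024-03-10 09:30"); // space form accepted, normalized to "T"
toLocalISODateTimeString(new Date(2024, 2, 10, 9, 30)); // "2024-03-10T09:30"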
@@ -36,6 +36,7 @@ export interface SelectFieldProps {
  options: SelectOption[];
  size?: "small" | "medium";
  renderItem?: (option: SelectOption) => React.ReactNode;
  wrapperClassName?: string;
}

export function Select({
@@ -52,6 +53,7 @@ export function Select({
  options,
  size = "medium",
  renderItem,
  wrapperClassName,
}: SelectFieldProps) {
  const triggerStyles = cn(
    // Base styles matching Input
@@ -117,7 +119,7 @@ export function Select({
  );

  const selectWithError = (
    <div className="relative mb-6">
    <div className={cn("relative mb-6", wrapperClassName)}>
      {select}
      <Text
        variant="small-medium"

@@ -0,0 +1,114 @@
import React, { ReactNode } from "react";
import { cn } from "@/lib/utils";
import { Text } from "../Text/Text";

interface TimeInputProps {
  value?: string;
  onChange?: (value: string) => void;
  className?: string;
  disabled?: boolean;
  placeholder?: string;
  label?: string;
  id?: string;
  hideLabel?: boolean;
  error?: string;
  hint?: ReactNode;
  size?: "small" | "medium";
  wrapperClassName?: string;
}

export const TimeInput: React.FC<TimeInputProps> = ({
  value = "",
  onChange,
  className,
  disabled = false,
  placeholder = "HH:MM",
  label,
  id,
  hideLabel = false,
  error,
  hint,
  size = "medium",
  wrapperClassName,
}) => {
  const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {
    onChange?.(e.target.value);
  };

  const baseStyles = cn(
    // Base styles
    "rounded-3xl border border-zinc-200 bg-white px-4 shadow-none",
    "font-normal text-black",
    "placeholder:font-normal placeholder:text-zinc-400",
    // Focus and hover states
    "focus:border-zinc-400 focus:shadow-none focus:outline-none focus:ring-1 focus:ring-zinc-400 focus:ring-offset-0",
    className,
  );

  const errorStyles =
    error && "!border !border-red-500 focus:border-red-500 focus:ring-red-500";

  const input = (
    <div className={cn("relative", wrapperClassName)}>
      <input
        type="time"
        value={value}
        onChange={handleChange}
        className={cn(
          baseStyles,
          errorStyles,
          // Size variants
          size === "small" && [
            "h-[2.25rem]", // 36px
            "py-2",
            "text-sm leading-[22px]", // 14px font, 22px line height
            "placeholder:text-sm placeholder:leading-[22px]",
          ],
          size === "medium" && [
            "h-[2.875rem]", // 46px (current default)
            "py-2.5",
          ],
        )}
        disabled={disabled}
        placeholder={placeholder || label}
        {...(hideLabel ? { "aria-label": label } : {})}
        id={id}
      />
    </div>
  );

  const inputWithError = (
    <div className={cn("relative mb-6", wrapperClassName)}>
      {input}
      <Text
        variant="small-medium"
        as="span"
        className={cn(
          "absolute left-0 top-full mt-1 !text-red-500 transition-opacity duration-200",
          error ? "opacity-100" : "opacity-0",
        )}
      >
        {error || " "}{" "}
        {/* Always render with space to maintain consistent height calculation */}
      </Text>
    </div>
  );

  return hideLabel || !label ? (
    inputWithError
  ) : (
    <label htmlFor={id} className="flex flex-col gap-2">
      <div className="flex items-center justify-between">
        <Text variant="body-medium" as="span" className="text-black">
          {label}
        </Text>
        {hint ? (
          <Text variant="small" as="span" className="!text-zinc-400">
            {hint}
          </Text>
        ) : null}
      </div>
      {inputWithError}
    </label>
  );
};
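A minimal controlled usage sketch:

const [time, setTime] = React.useState("09:30");

<TimeInput
  label="Start time"
  value={time}
  onChange={setTime}
  size="small"
/>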
@@ -27,7 +27,7 @@ export const NavbarView = ({ isLoggedIn }: NavbarViewProps) => {

  return (
    <>
      <nav className="sticky top-0 z-40 inline-flex h-16 items-center border border-white/50 bg-[#f3f4f6]/20 p-3 backdrop-blur-[26px]">
      <nav className="sticky top-0 z-40 inline-flex h-16 w-full items-center border border-white/50 bg-[#f3f4f6]/20 p-3 backdrop-blur-[26px]">
        {/* Left section */}
        <div className="hidden flex-1 items-center gap-3 md:flex md:gap-5">
          {isLoggedIn

@@ -10,6 +10,8 @@ export enum Flag {
  NEW_AGENT_RUNS = "new-agent-runs",
  GRAPH_SEARCH = "graph-search",
  ENABLE_ENHANCED_OUTPUT_HANDLING = "enable-enhanced-output-handling",
  NEW_FLOW_EDITOR = "new-flow-editor",
  BUILDER_VIEW_SWITCH = "builder-view-switch",
  SHARE_EXECUTION_RESULTS = "share-execution-results",
  AGENT_FAVORITING = "agent-favoriting",
}
@@ -21,6 +23,8 @@ export type FlagValues = {
  [Flag.NEW_AGENT_RUNS]: boolean;
  [Flag.GRAPH_SEARCH]: boolean;
  [Flag.ENABLE_ENHANCED_OUTPUT_HANDLING]: boolean;
  [Flag.NEW_FLOW_EDITOR]: boolean;
  [Flag.BUILDER_VIEW_SWITCH]: boolean;
  [Flag.SHARE_EXECUTION_RESULTS]: boolean;
  [Flag.AGENT_FAVORITING]: boolean;
};
@@ -34,6 +38,8 @@ const mockFlags = {
  [Flag.NEW_AGENT_RUNS]: false,
  [Flag.GRAPH_SEARCH]: true,
  [Flag.ENABLE_ENHANCED_OUTPUT_HANDLING]: false,
  [Flag.NEW_FLOW_EDITOR]: false,
  [Flag.BUILDER_VIEW_SWITCH]: false,
  [Flag.SHARE_EXECUTION_RESULTS]: false,
  [Flag.AGENT_FAVORITING]: false,
};