mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
48 Commits
seer/featu
...
fix/sentry
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3500d24c6a | ||
|
|
4d43570552 | ||
|
|
d3d78660da | ||
|
|
749be06599 | ||
|
|
27d886f05c | ||
|
|
be01a1316a | ||
|
|
32bb6705d1 | ||
|
|
536e2a5ec8 | ||
|
|
a3e5f7fce2 | ||
|
|
d674cb80e2 | ||
|
|
c3a6235cee | ||
|
|
43638defa2 | ||
|
|
9371528aab | ||
|
|
711c439642 | ||
|
|
33989f09d0 | ||
|
|
d6ee402483 | ||
|
|
08a4a6652d | ||
|
|
58928b516b | ||
|
|
18bb78d93e | ||
|
|
8058b9487b | ||
|
|
e68896a25a | ||
|
|
dfed092869 | ||
|
|
5559d978d7 | ||
|
|
dcecb17bd1 | ||
|
|
a056d9e71a | ||
|
|
42b643579f | ||
|
|
5b52ca9227 | ||
|
|
8e83586d13 | ||
|
|
df9850a141 | ||
|
|
cbe4086e79 | ||
|
|
fbd5f34a61 | ||
|
|
0b680d4990 | ||
|
|
6037f80502 | ||
|
|
37b3e4e82e | ||
|
|
de7c5b5c31 | ||
|
|
d68dceb9c1 | ||
|
|
193866232c | ||
|
|
979826f559 | ||
|
|
2f87e13d17 | ||
|
|
2ad5a88a5c | ||
|
|
e9cd40c0d4 | ||
|
|
4744675ef9 | ||
|
|
910fd2640d | ||
|
|
eae2616fb5 | ||
|
|
ad3ea59d90 | ||
|
|
69b6b732a2 | ||
|
|
b1a2d21892 | ||
|
|
a78b08f5e7 |
8
.github/workflows/platform-frontend-ci.yml
vendored
8
.github/workflows/platform-frontend-ci.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
@@ -62,7 +62,7 @@ jobs:
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
@@ -97,7 +97,7 @@ jobs:
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
@@ -138,7 +138,7 @@ jobs:
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
4
.github/workflows/platform-fullstack-ci.yml
vendored
4
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
@@ -66,7 +66,7 @@ jobs:
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -178,3 +178,4 @@ autogpt_platform/backend/settings.py
|
||||
*.ign.*
|
||||
.test-contents
|
||||
.claude/settings.local.json
|
||||
/autogpt_platform/backend/logs
|
||||
|
||||
@@ -192,6 +192,8 @@ Quick steps:
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||
|
||||
**Modifying the API:**
|
||||
|
||||
1. Update route in `/backend/backend/server/routers/`
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
from .config import verify_settings
|
||||
from .dependencies import get_user_id, requires_admin_user, requires_user
|
||||
from .dependencies import (
|
||||
get_optional_user_id,
|
||||
get_user_id,
|
||||
requires_admin_user,
|
||||
requires_user,
|
||||
)
|
||||
from .helpers import add_auth_responses_to_openapi
|
||||
from .models import User
|
||||
|
||||
@@ -8,6 +13,7 @@ __all__ = [
|
||||
"get_user_id",
|
||||
"requires_admin_user",
|
||||
"requires_user",
|
||||
"get_optional_user_id",
|
||||
"add_auth_responses_to_openapi",
|
||||
"User",
|
||||
]
|
||||
|
||||
@@ -4,11 +4,53 @@ FastAPI dependency functions for JWT-based authentication and authorization.
|
||||
These are the high-level dependency functions used in route definitions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import fastapi
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from .jwt_utils import get_jwt_payload, verify_user
|
||||
from .models import User
|
||||
|
||||
optional_bearer = HTTPBearer(auto_error=False)
|
||||
|
||||
# Header name for admin impersonation
|
||||
IMPERSONATION_HEADER_NAME = "X-Act-As-User-Id"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_optional_user_id(
|
||||
credentials: HTTPAuthorizationCredentials | None = fastapi.Security(
|
||||
optional_bearer
|
||||
),
|
||||
) -> str | None:
|
||||
"""
|
||||
Attempts to extract the user ID ("sub" claim) from a Bearer JWT if provided.
|
||||
|
||||
This dependency allows for both authenticated and anonymous access. If a valid bearer token is
|
||||
supplied, it parses the JWT and extracts the user ID. If the token is missing or invalid, it returns None,
|
||||
treating the request as anonymous.
|
||||
|
||||
Args:
|
||||
credentials: Optional HTTPAuthorizationCredentials object from FastAPI Security dependency.
|
||||
|
||||
Returns:
|
||||
The user ID (str) extracted from the JWT "sub" claim, or None if no valid token is present.
|
||||
"""
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Parse JWT token to get user ID
|
||||
from autogpt_libs.auth.jwt_utils import parse_jwt_token
|
||||
|
||||
payload = parse_jwt_token(credentials.credentials)
|
||||
return payload.get("sub")
|
||||
except Exception as e:
|
||||
logger.debug(f"Auth token validation failed (anonymous access): {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
|
||||
"""
|
||||
@@ -32,16 +74,44 @@ async def requires_admin_user(
|
||||
return verify_user(jwt_payload, admin_only=True)
|
||||
|
||||
|
||||
async def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
|
||||
async def get_user_id(
|
||||
request: fastapi.Request, jwt_payload: dict = fastapi.Security(get_jwt_payload)
|
||||
) -> str:
|
||||
"""
|
||||
FastAPI dependency that returns the ID of the authenticated user.
|
||||
|
||||
Supports admin impersonation via X-Act-As-User-Id header:
|
||||
- If the header is present and user is admin, returns the impersonated user ID
|
||||
- Otherwise returns the authenticated user's own ID
|
||||
- Logs all impersonation actions for audit trail
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 for authentication failures or missing user ID
|
||||
HTTPException: 403 if non-admin tries to use impersonation
|
||||
"""
|
||||
# Get the authenticated user's ID from JWT
|
||||
user_id = jwt_payload.get("sub")
|
||||
if not user_id:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401, detail="User ID not found in token"
|
||||
)
|
||||
|
||||
# Check for admin impersonation header
|
||||
impersonate_header = request.headers.get(IMPERSONATION_HEADER_NAME, "").strip()
|
||||
if impersonate_header:
|
||||
# Verify the authenticated user is an admin
|
||||
authenticated_user = verify_user(jwt_payload, admin_only=False)
|
||||
if authenticated_user.role != "admin":
|
||||
raise fastapi.HTTPException(
|
||||
status_code=403, detail="Only admin users can impersonate other users"
|
||||
)
|
||||
|
||||
# Log the impersonation for audit trail
|
||||
logger.info(
|
||||
f"Admin impersonation: {authenticated_user.user_id} ({authenticated_user.email}) "
|
||||
f"acting as user {impersonate_header} for requesting {request.method} {request.url}"
|
||||
)
|
||||
|
||||
return impersonate_header
|
||||
|
||||
return user_id
|
||||
|
||||
@@ -4,9 +4,10 @@ Tests the full authentication flow from HTTP requests to user validation.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException, Security
|
||||
from fastapi import FastAPI, HTTPException, Request, Security
|
||||
from fastapi.testclient import TestClient
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
@@ -45,6 +46,7 @@ class TestAuthDependencies:
|
||||
"""Create a test client."""
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
|
||||
"""Test requires_user with valid JWT payload."""
|
||||
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
|
||||
@@ -58,6 +60,7 @@ class TestAuthDependencies:
|
||||
assert user.user_id == "user-123"
|
||||
assert user.role == "user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
|
||||
"""Test requires_user accepts admin users."""
|
||||
jwt_payload = {
|
||||
@@ -73,6 +76,7 @@ class TestAuthDependencies:
|
||||
assert user.user_id == "admin-456"
|
||||
assert user.role == "admin"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_user_missing_sub(self):
|
||||
"""Test requires_user with missing user ID."""
|
||||
jwt_payload = {"role": "user", "email": "user@example.com"}
|
||||
@@ -82,6 +86,7 @@ class TestAuthDependencies:
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_user_empty_sub(self):
|
||||
"""Test requires_user with empty user ID."""
|
||||
jwt_payload = {"sub": "", "role": "user"}
|
||||
@@ -90,6 +95,7 @@ class TestAuthDependencies:
|
||||
await requires_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
|
||||
"""Test requires_admin_user with admin role."""
|
||||
jwt_payload = {
|
||||
@@ -105,6 +111,7 @@ class TestAuthDependencies:
|
||||
assert user.user_id == "admin-789"
|
||||
assert user.role == "admin"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_admin_user_with_regular_user(self):
|
||||
"""Test requires_admin_user rejects regular users."""
|
||||
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
|
||||
@@ -114,6 +121,7 @@ class TestAuthDependencies:
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Admin access required" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_admin_user_missing_role(self):
|
||||
"""Test requires_admin_user with missing role."""
|
||||
jwt_payload = {"sub": "user-123", "email": "user@example.com"}
|
||||
@@ -121,31 +129,40 @@ class TestAuthDependencies:
|
||||
with pytest.raises(KeyError):
|
||||
await requires_admin_user(jwt_payload)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
|
||||
"""Test get_user_id extracts user ID correctly."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
jwt_payload = {"sub": "user-id-xyz", "role": "user"}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user_id = await get_user_id(jwt_payload)
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
assert user_id == "user-id-xyz"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_id_missing_sub(self):
|
||||
"""Test get_user_id with missing user ID."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
jwt_payload = {"role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_user_id(jwt_payload)
|
||||
await get_user_id(request, jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_id_none_sub(self):
|
||||
"""Test get_user_id with None user ID."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
jwt_payload = {"sub": None, "role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_user_id(jwt_payload)
|
||||
await get_user_id(request, jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
@@ -170,6 +187,7 @@ class TestAuthDependenciesIntegration:
|
||||
|
||||
return _create_token
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoint_auth_enabled_no_token(self):
|
||||
"""Test endpoints require token when auth is enabled."""
|
||||
app = FastAPI()
|
||||
@@ -184,6 +202,7 @@ class TestAuthDependenciesIntegration:
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoint_with_valid_token(self, create_token):
|
||||
"""Test endpoint with valid JWT token."""
|
||||
app = FastAPI()
|
||||
@@ -203,6 +222,7 @@ class TestAuthDependenciesIntegration:
|
||||
assert response.status_code == 200
|
||||
assert response.json()["user_id"] == "test-user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_endpoint_requires_admin_role(self, create_token):
|
||||
"""Test admin endpoint rejects non-admin users."""
|
||||
app = FastAPI()
|
||||
@@ -240,6 +260,7 @@ class TestAuthDependenciesIntegration:
|
||||
class TestAuthDependenciesEdgeCases:
|
||||
"""Edge case tests for authentication dependencies."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_with_complex_payload(self):
|
||||
"""Test dependencies handle complex JWT payloads."""
|
||||
complex_payload = {
|
||||
@@ -263,6 +284,7 @@ class TestAuthDependenciesEdgeCases:
|
||||
admin = await requires_admin_user(complex_payload)
|
||||
assert admin.role == "admin"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_with_unicode_in_payload(self):
|
||||
"""Test dependencies handle unicode in JWT payloads."""
|
||||
unicode_payload = {
|
||||
@@ -276,6 +298,7 @@ class TestAuthDependenciesEdgeCases:
|
||||
assert "😀" in user.user_id
|
||||
assert user.email == "测试@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_with_null_values(self):
|
||||
"""Test dependencies handle null values in payload."""
|
||||
null_payload = {
|
||||
@@ -290,6 +313,7 @@ class TestAuthDependenciesEdgeCases:
|
||||
assert user.user_id == "user-123"
|
||||
assert user.email is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_requests_isolation(self):
|
||||
"""Test that concurrent requests don't interfere with each other."""
|
||||
payload1 = {"sub": "user-1", "role": "user"}
|
||||
@@ -314,6 +338,7 @@ class TestAuthDependenciesEdgeCases:
|
||||
({"sub": "user", "role": "user"}, "Admin access required", True),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_error_cases(
|
||||
self, payload, expected_error: str, admin_only: bool
|
||||
):
|
||||
@@ -325,6 +350,7 @@ class TestAuthDependenciesEdgeCases:
|
||||
verify_user(payload, admin_only=admin_only)
|
||||
assert expected_error in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_valid_user(self):
|
||||
"""Test valid user case for dependency."""
|
||||
# Import verify_user to test it directly since dependencies use FastAPI Security
|
||||
@@ -333,3 +359,196 @@ class TestAuthDependenciesEdgeCases:
|
||||
# Valid case
|
||||
user = verify_user({"sub": "user", "role": "user"}, admin_only=False)
|
||||
assert user.user_id == "user"
|
||||
|
||||
|
||||
class TestAdminImpersonation:
|
||||
"""Test suite for admin user impersonation functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_impersonation_success(self, mocker: MockerFixture):
|
||||
"""Test admin successfully impersonating another user."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"X-Act-As-User-Id": "target-user-123"}
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
# Mock verify_user to return admin user data
|
||||
mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
|
||||
mock_verify_user.return_value = Mock(
|
||||
user_id="admin-456", email="admin@example.com", role="admin"
|
||||
)
|
||||
|
||||
# Mock logger to verify audit logging
|
||||
mock_logger = mocker.patch("autogpt_libs.auth.dependencies.logger")
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Should return the impersonated user ID
|
||||
assert user_id == "target-user-123"
|
||||
|
||||
# Should log the impersonation attempt
|
||||
mock_logger.info.assert_called_once()
|
||||
log_call = mock_logger.info.call_args[0][0]
|
||||
assert "Admin impersonation:" in log_call
|
||||
assert "admin@example.com" in log_call
|
||||
assert "target-user-123" in log_call
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_admin_impersonation_attempt(self, mocker: MockerFixture):
|
||||
"""Test non-admin user attempting impersonation returns 403."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"X-Act-As-User-Id": "target-user-123"}
|
||||
jwt_payload = {
|
||||
"sub": "regular-user",
|
||||
"role": "user",
|
||||
"email": "user@example.com",
|
||||
}
|
||||
|
||||
# Mock verify_user to return regular user data
|
||||
mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
|
||||
mock_verify_user.return_value = Mock(
|
||||
user_id="regular-user", email="user@example.com", role="user"
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_user_id(request, jwt_payload)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Only admin users can impersonate other users" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_impersonation_empty_header(self, mocker: MockerFixture):
|
||||
"""Test impersonation with empty header falls back to regular user ID."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"X-Act-As-User-Id": ""}
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Should fall back to the admin's own user ID
|
||||
assert user_id == "admin-456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_impersonation_missing_header(self, mocker: MockerFixture):
|
||||
"""Test normal behavior when impersonation header is missing."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {} # No impersonation header
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Should return the admin's own user ID
|
||||
assert user_id == "admin-456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_impersonation_audit_logging_details(self, mocker: MockerFixture):
|
||||
"""Test that impersonation audit logging includes all required details."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"X-Act-As-User-Id": "victim-user-789"}
|
||||
jwt_payload = {
|
||||
"sub": "admin-999",
|
||||
"role": "admin",
|
||||
"email": "superadmin@company.com",
|
||||
}
|
||||
|
||||
# Mock verify_user to return admin user data
|
||||
mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
|
||||
mock_verify_user.return_value = Mock(
|
||||
user_id="admin-999", email="superadmin@company.com", role="admin"
|
||||
)
|
||||
|
||||
# Mock logger to capture audit trail
|
||||
mock_logger = mocker.patch("autogpt_libs.auth.dependencies.logger")
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Verify all audit details are logged
|
||||
assert user_id == "victim-user-789"
|
||||
mock_logger.info.assert_called_once()
|
||||
|
||||
log_message = mock_logger.info.call_args[0][0]
|
||||
assert "Admin impersonation:" in log_message
|
||||
assert "superadmin@company.com" in log_message
|
||||
assert "victim-user-789" in log_message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_impersonation_header_case_sensitivity(self, mocker: MockerFixture):
|
||||
"""Test that impersonation header is case-sensitive."""
|
||||
request = Mock(spec=Request)
|
||||
# Use wrong case - should not trigger impersonation
|
||||
request.headers = {"x-act-as-user-id": "target-user-123"}
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Should fall back to admin's own ID (header case mismatch)
|
||||
assert user_id == "admin-456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_impersonation_with_whitespace_header(self, mocker: MockerFixture):
|
||||
"""Test impersonation with whitespace in header value."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"X-Act-As-User-Id": " target-user-123 "}
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
# Mock verify_user to return admin user data
|
||||
mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
|
||||
mock_verify_user.return_value = Mock(
|
||||
user_id="admin-456", email="admin@example.com", role="admin"
|
||||
)
|
||||
|
||||
# Mock logger
|
||||
mock_logger = mocker.patch("autogpt_libs.auth.dependencies.logger")
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Should strip whitespace and impersonate successfully
|
||||
assert user_id == "target-user-123"
|
||||
mock_logger.info.assert_called_once()
|
||||
|
||||
@@ -84,6 +84,7 @@ class AgentExecutorBlock(Block):
|
||||
inputs=input_data.inputs,
|
||||
nodes_input_masks=input_data.nodes_input_masks,
|
||||
parent_graph_exec_id=graph_exec_id,
|
||||
is_sub_graph=True, # AgentExecutorBlock executions are always sub-graphs
|
||||
)
|
||||
|
||||
logger = execution_utils.LogMetadata(
|
||||
|
||||
22
autogpt_platform/backend/backend/blocks/exa/_test.py
Normal file
22
autogpt_platform/backend/backend/blocks/exa/_test.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Test credentials and helpers for Exa blocks.
|
||||
"""
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="exa",
|
||||
api_key=SecretStr("mock-exa-api-key"),
|
||||
title="Mock Exa API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
@@ -1,52 +1,55 @@
|
||||
from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
from exa_py.api import AnswerResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
BaseModel,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
MediaFileType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
|
||||
|
||||
class CostBreakdown(BaseModel):
|
||||
keywordSearch: float
|
||||
neuralSearch: float
|
||||
contentText: float
|
||||
contentHighlight: float
|
||||
contentSummary: float
|
||||
class AnswerCitation(BaseModel):
|
||||
"""Citation model for answer endpoint."""
|
||||
|
||||
id: str = SchemaField(description="The temporary ID for the document")
|
||||
url: str = SchemaField(description="The URL of the search result")
|
||||
title: Optional[str] = SchemaField(description="The title of the search result")
|
||||
author: Optional[str] = SchemaField(description="The author of the content")
|
||||
publishedDate: Optional[str] = SchemaField(
|
||||
description="An estimate of the creation date"
|
||||
)
|
||||
text: Optional[str] = SchemaField(description="The full text content of the source")
|
||||
image: Optional[MediaFileType] = SchemaField(
|
||||
description="The URL of the image associated with the result"
|
||||
)
|
||||
favicon: Optional[MediaFileType] = SchemaField(
|
||||
description="The URL of the favicon for the domain"
|
||||
)
|
||||
|
||||
class SearchBreakdown(BaseModel):
|
||||
search: float
|
||||
contents: float
|
||||
breakdown: CostBreakdown
|
||||
|
||||
|
||||
class PerRequestPrices(BaseModel):
|
||||
neuralSearch_1_25_results: float
|
||||
neuralSearch_26_100_results: float
|
||||
neuralSearch_100_plus_results: float
|
||||
keywordSearch_1_100_results: float
|
||||
keywordSearch_100_plus_results: float
|
||||
|
||||
|
||||
class PerPagePrices(BaseModel):
|
||||
contentText: float
|
||||
contentHighlight: float
|
||||
contentSummary: float
|
||||
|
||||
|
||||
class CostDollars(BaseModel):
|
||||
total: float
|
||||
breakDown: list[SearchBreakdown]
|
||||
perRequestPrices: PerRequestPrices
|
||||
perPagePrices: PerPagePrices
|
||||
@classmethod
|
||||
def from_sdk(cls, sdk_citation) -> "AnswerCitation":
|
||||
"""Convert SDK AnswerResult (dataclass) to our Pydantic model."""
|
||||
return cls(
|
||||
id=getattr(sdk_citation, "id", ""),
|
||||
url=getattr(sdk_citation, "url", ""),
|
||||
title=getattr(sdk_citation, "title", None),
|
||||
author=getattr(sdk_citation, "author", None),
|
||||
publishedDate=getattr(sdk_citation, "published_date", None),
|
||||
text=getattr(sdk_citation, "text", None),
|
||||
image=getattr(sdk_citation, "image", None),
|
||||
favicon=getattr(sdk_citation, "favicon", None),
|
||||
)
|
||||
|
||||
|
||||
class ExaAnswerBlock(Block):
|
||||
@@ -59,31 +62,21 @@ class ExaAnswerBlock(Block):
|
||||
placeholder="What is the latest valuation of SpaceX?",
|
||||
)
|
||||
text: bool = SchemaField(
|
||||
default=False,
|
||||
description="If true, the response includes full text content in the search results",
|
||||
advanced=True,
|
||||
)
|
||||
model: str = SchemaField(
|
||||
default="exa",
|
||||
description="The search model to use (exa or exa-pro)",
|
||||
placeholder="exa",
|
||||
advanced=True,
|
||||
description="Include full text content in the search results used for the answer",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
answer: str = SchemaField(
|
||||
description="The generated answer based on search results"
|
||||
)
|
||||
citations: list[dict] = SchemaField(
|
||||
description="Search results used to generate the answer",
|
||||
default_factory=list,
|
||||
citations: list[AnswerCitation] = SchemaField(
|
||||
description="Search results used to generate the answer"
|
||||
)
|
||||
cost_dollars: CostDollars = SchemaField(
|
||||
description="Cost breakdown of the request"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed", default=""
|
||||
citation: AnswerCitation = SchemaField(
|
||||
description="Individual citation from the answer"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -97,26 +90,24 @@ class ExaAnswerBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = "https://api.exa.ai/answer"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Build the payload
|
||||
payload = {
|
||||
"query": input_data.query,
|
||||
"text": input_data.text,
|
||||
"model": input_data.model,
|
||||
}
|
||||
# Get answer using SDK (stream=False for blocks) - this IS async, needs await
|
||||
response = await aexa.answer(
|
||||
query=input_data.query, text=input_data.text, stream=False
|
||||
)
|
||||
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
# this should remain true as long as they don't start defaulting to streaming only.
|
||||
# provides a bit of safety for sdk updates.
|
||||
assert type(response) is AnswerResponse
|
||||
|
||||
yield "answer", data.get("answer", "")
|
||||
yield "citations", data.get("citations", [])
|
||||
yield "cost_dollars", data.get("costDollars", {})
|
||||
yield "answer", response.answer
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
citations = [
|
||||
AnswerCitation.from_sdk(sdk_citation)
|
||||
for sdk_citation in response.citations or []
|
||||
]
|
||||
|
||||
yield "citations", citations
|
||||
for citation in citations:
|
||||
yield "citation", citation
|
||||
|
||||
118
autogpt_platform/backend/backend/blocks/exa/code_context.py
Normal file
118
autogpt_platform/backend/backend/blocks/exa/code_context.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
Exa Code Context Block
|
||||
|
||||
Provides code search capabilities to find relevant code snippets and examples
|
||||
from open source repositories, documentation, and Stack Overflow.
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
|
||||
|
||||
class CodeContextResponse(BaseModel):
|
||||
"""Stable output model for code context responses."""
|
||||
|
||||
request_id: str
|
||||
query: str
|
||||
response: str
|
||||
results_count: int
|
||||
cost_dollars: str
|
||||
search_time: float
|
||||
output_tokens: int
|
||||
|
||||
@classmethod
|
||||
def from_api(cls, data: dict) -> "CodeContextResponse":
|
||||
"""Convert API response to our stable model."""
|
||||
return cls(
|
||||
request_id=data.get("requestId", ""),
|
||||
query=data.get("query", ""),
|
||||
response=data.get("response", ""),
|
||||
results_count=data.get("resultsCount", 0),
|
||||
cost_dollars=data.get("costDollars", ""),
|
||||
search_time=data.get("searchTime", 0.0),
|
||||
output_tokens=data.get("outputTokens", 0),
|
||||
)
|
||||
|
||||
|
||||
class ExaCodeContextBlock(Block):
|
||||
"""Get relevant code snippets and examples from open source repositories."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
query: str = SchemaField(
|
||||
description="Search query to find relevant code snippets. Describe what you're trying to do or what code you're looking for.",
|
||||
placeholder="how to use React hooks for state management",
|
||||
)
|
||||
tokens_num: Union[str, int] = SchemaField(
|
||||
default="dynamic",
|
||||
description="Token limit for response. Use 'dynamic' for automatic sizing, 5000 for standard queries, or 10000 for comprehensive examples.",
|
||||
placeholder="dynamic",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
request_id: str = SchemaField(description="Unique identifier for this request")
|
||||
query: str = SchemaField(description="The search query used")
|
||||
response: str = SchemaField(
|
||||
description="Formatted code snippets and contextual examples with sources"
|
||||
)
|
||||
results_count: int = SchemaField(
|
||||
description="Number of code sources found and included"
|
||||
)
|
||||
cost_dollars: str = SchemaField(description="Cost of this request in dollars")
|
||||
search_time: float = SchemaField(
|
||||
description="Time taken to search in milliseconds"
|
||||
)
|
||||
output_tokens: int = SchemaField(description="Number of tokens in the response")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8f9e0d1c-2b3a-4567-8901-23456789abcd",
|
||||
description="Search billions of GitHub repos, docs, and Stack Overflow for relevant code examples",
|
||||
categories={BlockCategory.SEARCH, BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=ExaCodeContextBlock.Input,
|
||||
output_schema=ExaCodeContextBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = "https://api.exa.ai/context"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
payload = {
|
||||
"query": input_data.query,
|
||||
"tokensNum": input_data.tokens_num,
|
||||
}
|
||||
|
||||
response = await Requests().post(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
|
||||
context = CodeContextResponse.from_api(data)
|
||||
|
||||
yield "request_id", context.request_id
|
||||
yield "query", context.query
|
||||
yield "response", context.response
|
||||
yield "results_count", context.results_count
|
||||
yield "cost_dollars", context.cost_dollars
|
||||
yield "search_time", context.search_time
|
||||
yield "output_tokens", context.output_tokens
|
||||
@@ -1,3 +1,9 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -6,12 +12,45 @@ from backend.sdk import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import ContentSettings
|
||||
from .helpers import (
|
||||
CostDollars,
|
||||
ExaSearchResults,
|
||||
ExtrasSettings,
|
||||
HighlightSettings,
|
||||
LivecrawlTypes,
|
||||
SummarySettings,
|
||||
)
|
||||
|
||||
|
||||
class ContentStatusTag(str, Enum):
|
||||
CRAWL_NOT_FOUND = "CRAWL_NOT_FOUND"
|
||||
CRAWL_TIMEOUT = "CRAWL_TIMEOUT"
|
||||
CRAWL_LIVECRAWL_TIMEOUT = "CRAWL_LIVECRAWL_TIMEOUT"
|
||||
SOURCE_NOT_AVAILABLE = "SOURCE_NOT_AVAILABLE"
|
||||
CRAWL_UNKNOWN_ERROR = "CRAWL_UNKNOWN_ERROR"
|
||||
|
||||
|
||||
class ContentError(BaseModel):
|
||||
tag: Optional[ContentStatusTag] = SchemaField(
|
||||
default=None, description="Specific error type"
|
||||
)
|
||||
httpStatusCode: Optional[int] = SchemaField(
|
||||
default=None, description="The corresponding HTTP status code"
|
||||
)
|
||||
|
||||
|
||||
class ContentStatus(BaseModel):
|
||||
id: str = SchemaField(description="The URL that was requested")
|
||||
status: str = SchemaField(
|
||||
description="Status of the content fetch operation (success or error)"
|
||||
)
|
||||
error: Optional[ContentError] = SchemaField(
|
||||
default=None, description="Error details, only present when status is 'error'"
|
||||
)
|
||||
|
||||
|
||||
class ExaContentsBlock(Block):
|
||||
@@ -19,22 +58,70 @@ class ExaContentsBlock(Block):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
ids: list[str] = SchemaField(
|
||||
description="Array of document IDs obtained from searches"
|
||||
urls: list[str] = SchemaField(
|
||||
description="Array of URLs to crawl (preferred over 'ids')",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
contents: ContentSettings = SchemaField(
|
||||
description="Content retrieval settings",
|
||||
default=ContentSettings(),
|
||||
ids: list[str] = SchemaField(
|
||||
description="[DEPRECATED - use 'urls' instead] Array of document IDs obtained from searches",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
text: bool = SchemaField(
|
||||
description="Retrieve text content from pages",
|
||||
default=True,
|
||||
)
|
||||
highlights: HighlightSettings = SchemaField(
|
||||
description="Text snippets most relevant from each page",
|
||||
default=HighlightSettings(),
|
||||
)
|
||||
summary: SummarySettings = SchemaField(
|
||||
description="LLM-generated summary of the webpage",
|
||||
default=SummarySettings(),
|
||||
)
|
||||
livecrawl: Optional[LivecrawlTypes] = SchemaField(
|
||||
description="Livecrawling options: never, fallback (default), always, preferred",
|
||||
default=LivecrawlTypes.FALLBACK,
|
||||
advanced=True,
|
||||
)
|
||||
livecrawl_timeout: Optional[int] = SchemaField(
|
||||
description="Timeout for livecrawling in milliseconds",
|
||||
default=10000,
|
||||
advanced=True,
|
||||
)
|
||||
subpages: Optional[int] = SchemaField(
|
||||
description="Number of subpages to crawl", default=0, ge=0, advanced=True
|
||||
)
|
||||
subpage_target: Optional[str | list[str]] = SchemaField(
|
||||
description="Keyword(s) to find specific subpages of search results",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
extras: ExtrasSettings = SchemaField(
|
||||
description="Extra parameters for additional content",
|
||||
default=ExtrasSettings(),
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
results: list = SchemaField(
|
||||
description="List of document contents", default_factory=list
|
||||
results: list[ExaSearchResults] = SchemaField(
|
||||
description="List of document contents with metadata"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed", default=""
|
||||
result: ExaSearchResults = SchemaField(
|
||||
description="Single document content result"
|
||||
)
|
||||
context: str = SchemaField(
|
||||
description="A formatted string of the results ready for LLMs"
|
||||
)
|
||||
request_id: str = SchemaField(description="Unique identifier for the request")
|
||||
statuses: list[ContentStatus] = SchemaField(
|
||||
description="Status information for each requested URL"
|
||||
)
|
||||
cost_dollars: Optional[CostDollars] = SchemaField(
|
||||
description="Cost breakdown for the request"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -48,23 +135,91 @@ class ExaContentsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = "https://api.exa.ai/contents"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
if not input_data.urls and not input_data.ids:
|
||||
raise ValueError("Either 'urls' or 'ids' must be provided")
|
||||
|
||||
# Convert ContentSettings to API format
|
||||
payload = {
|
||||
"ids": input_data.ids,
|
||||
"text": input_data.contents.text,
|
||||
"highlights": input_data.contents.highlights,
|
||||
"summary": input_data.contents.summary,
|
||||
}
|
||||
sdk_kwargs = {}
|
||||
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
yield "results", data.get("results", [])
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
# Prefer urls over ids
|
||||
if input_data.urls:
|
||||
sdk_kwargs["urls"] = input_data.urls
|
||||
elif input_data.ids:
|
||||
sdk_kwargs["ids"] = input_data.ids
|
||||
|
||||
if input_data.text:
|
||||
sdk_kwargs["text"] = {"includeHtmlTags": True}
|
||||
|
||||
# Handle highlights - only include if modified from defaults
|
||||
if input_data.highlights and (
|
||||
input_data.highlights.num_sentences != 1
|
||||
or input_data.highlights.highlights_per_url != 1
|
||||
or input_data.highlights.query is not None
|
||||
):
|
||||
highlights_dict = {}
|
||||
highlights_dict["numSentences"] = input_data.highlights.num_sentences
|
||||
highlights_dict["highlightsPerUrl"] = (
|
||||
input_data.highlights.highlights_per_url
|
||||
)
|
||||
if input_data.highlights.query:
|
||||
highlights_dict["query"] = input_data.highlights.query
|
||||
sdk_kwargs["highlights"] = highlights_dict
|
||||
|
||||
# Handle summary - only include if modified from defaults
|
||||
if input_data.summary and (
|
||||
input_data.summary.query is not None
|
||||
or input_data.summary.schema is not None
|
||||
):
|
||||
summary_dict = {}
|
||||
if input_data.summary.query:
|
||||
summary_dict["query"] = input_data.summary.query
|
||||
if input_data.summary.schema:
|
||||
summary_dict["schema"] = input_data.summary.schema
|
||||
sdk_kwargs["summary"] = summary_dict
|
||||
|
||||
if input_data.livecrawl:
|
||||
sdk_kwargs["livecrawl"] = input_data.livecrawl.value
|
||||
|
||||
if input_data.livecrawl_timeout is not None:
|
||||
sdk_kwargs["livecrawl_timeout"] = input_data.livecrawl_timeout
|
||||
|
||||
if input_data.subpages is not None:
|
||||
sdk_kwargs["subpages"] = input_data.subpages
|
||||
|
||||
if input_data.subpage_target:
|
||||
sdk_kwargs["subpage_target"] = input_data.subpage_target
|
||||
|
||||
# Handle extras - only include if modified from defaults
|
||||
if input_data.extras and (
|
||||
input_data.extras.links > 0 or input_data.extras.image_links > 0
|
||||
):
|
||||
extras_dict = {}
|
||||
if input_data.extras.links:
|
||||
extras_dict["links"] = input_data.extras.links
|
||||
if input_data.extras.image_links:
|
||||
extras_dict["image_links"] = input_data.extras.image_links
|
||||
sdk_kwargs["extras"] = extras_dict
|
||||
|
||||
# Always enable context for LLM-ready output
|
||||
sdk_kwargs["context"] = True
|
||||
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
response = await aexa.get_contents(**sdk_kwargs)
|
||||
|
||||
converted_results = [
|
||||
ExaSearchResults.from_sdk(sdk_result)
|
||||
for sdk_result in response.results or []
|
||||
]
|
||||
|
||||
yield "results", converted_results
|
||||
|
||||
for result in converted_results:
|
||||
yield "result", result
|
||||
|
||||
if response.context:
|
||||
yield "context", response.context
|
||||
|
||||
if response.statuses:
|
||||
yield "statuses", response.statuses
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
|
||||
@@ -1,51 +1,150 @@
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
from backend.sdk import BaseModel, SchemaField
|
||||
from backend.sdk import BaseModel, MediaFileType, SchemaField
|
||||
|
||||
|
||||
class TextSettings(BaseModel):
|
||||
max_characters: int = SchemaField(
|
||||
default=1000,
|
||||
class LivecrawlTypes(str, Enum):
|
||||
NEVER = "never"
|
||||
FALLBACK = "fallback"
|
||||
ALWAYS = "always"
|
||||
PREFERRED = "preferred"
|
||||
|
||||
|
||||
class TextEnabled(BaseModel):
|
||||
discriminator: Literal["enabled"] = "enabled"
|
||||
|
||||
|
||||
class TextDisabled(BaseModel):
|
||||
discriminator: Literal["disabled"] = "disabled"
|
||||
|
||||
|
||||
class TextAdvanced(BaseModel):
|
||||
discriminator: Literal["advanced"] = "advanced"
|
||||
max_characters: Optional[int] = SchemaField(
|
||||
default=None,
|
||||
description="Maximum number of characters to return",
|
||||
placeholder="1000",
|
||||
)
|
||||
include_html_tags: bool = SchemaField(
|
||||
default=False,
|
||||
description="Whether to include HTML tags in the text",
|
||||
description="Include HTML tags in the response, helps LLMs understand text structure",
|
||||
placeholder="False",
|
||||
)
|
||||
|
||||
|
||||
class HighlightSettings(BaseModel):
|
||||
num_sentences: int = SchemaField(
|
||||
default=3,
|
||||
default=1,
|
||||
description="Number of sentences per highlight",
|
||||
placeholder="3",
|
||||
placeholder="1",
|
||||
ge=1,
|
||||
)
|
||||
highlights_per_url: int = SchemaField(
|
||||
default=3,
|
||||
default=1,
|
||||
description="Number of highlights per URL",
|
||||
placeholder="3",
|
||||
placeholder="1",
|
||||
ge=1,
|
||||
)
|
||||
query: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Custom query to direct the LLM's selection of highlights",
|
||||
placeholder="Key advancements",
|
||||
)
|
||||
|
||||
|
||||
class SummarySettings(BaseModel):
|
||||
query: Optional[str] = SchemaField(
|
||||
default="",
|
||||
description="Query string for summarization",
|
||||
placeholder="Enter query",
|
||||
default=None,
|
||||
description="Custom query for the LLM-generated summary",
|
||||
placeholder="Main developments",
|
||||
)
|
||||
schema: Optional[dict] = SchemaField( # type: ignore
|
||||
default=None,
|
||||
description="JSON schema for structured output from summary",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
|
||||
class ExtrasSettings(BaseModel):
|
||||
links: int = SchemaField(
|
||||
default=0,
|
||||
description="Number of URLs to return from each webpage",
|
||||
placeholder="1",
|
||||
ge=0,
|
||||
)
|
||||
image_links: int = SchemaField(
|
||||
default=0,
|
||||
description="Number of images to return for each result",
|
||||
placeholder="1",
|
||||
ge=0,
|
||||
)
|
||||
|
||||
|
||||
class ContextEnabled(BaseModel):
|
||||
discriminator: Literal["enabled"] = "enabled"
|
||||
|
||||
|
||||
class ContextDisabled(BaseModel):
|
||||
discriminator: Literal["disabled"] = "disabled"
|
||||
|
||||
|
||||
class ContextAdvanced(BaseModel):
|
||||
discriminator: Literal["advanced"] = "advanced"
|
||||
max_characters: Optional[int] = SchemaField(
|
||||
default=None,
|
||||
description="Maximum character limit for context string",
|
||||
placeholder="10000",
|
||||
)
|
||||
|
||||
|
||||
class ContentSettings(BaseModel):
|
||||
text: TextSettings = SchemaField(
|
||||
default=TextSettings(),
|
||||
text: Optional[Union[bool, TextEnabled, TextDisabled, TextAdvanced]] = SchemaField(
|
||||
default=None,
|
||||
description="Text content retrieval. Boolean for simple enable/disable or object for advanced settings",
|
||||
)
|
||||
highlights: HighlightSettings = SchemaField(
|
||||
default=HighlightSettings(),
|
||||
highlights: Optional[HighlightSettings] = SchemaField(
|
||||
default=None,
|
||||
description="Text snippets most relevant from each page",
|
||||
)
|
||||
summary: SummarySettings = SchemaField(
|
||||
default=SummarySettings(),
|
||||
summary: Optional[SummarySettings] = SchemaField(
|
||||
default=None,
|
||||
description="LLM-generated summary of the webpage",
|
||||
)
|
||||
livecrawl: Optional[LivecrawlTypes] = SchemaField(
|
||||
default=None,
|
||||
description="Livecrawling options: never, fallback, always, preferred",
|
||||
advanced=True,
|
||||
)
|
||||
livecrawl_timeout: Optional[int] = SchemaField(
|
||||
default=None,
|
||||
description="Timeout for livecrawling in milliseconds",
|
||||
placeholder="10000",
|
||||
advanced=True,
|
||||
)
|
||||
subpages: Optional[int] = SchemaField(
|
||||
default=None,
|
||||
description="Number of subpages to crawl",
|
||||
placeholder="0",
|
||||
ge=0,
|
||||
advanced=True,
|
||||
)
|
||||
subpage_target: Optional[Union[str, list[str]]] = SchemaField(
|
||||
default=None,
|
||||
description="Keyword(s) to find specific subpages of search results",
|
||||
advanced=True,
|
||||
)
|
||||
extras: Optional[ExtrasSettings] = SchemaField(
|
||||
default=None,
|
||||
description="Extra parameters for additional content",
|
||||
advanced=True,
|
||||
)
|
||||
context: Optional[Union[bool, ContextEnabled, ContextDisabled, ContextAdvanced]] = (
|
||||
SchemaField(
|
||||
default=None,
|
||||
description="Format search results into a context string for LLMs",
|
||||
advanced=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -127,3 +226,225 @@ class WebsetEnrichmentConfig(BaseModel):
|
||||
default=None,
|
||||
description="Options for the enrichment",
|
||||
)
|
||||
|
||||
|
||||
# Shared result models
|
||||
class ExaSearchExtras(BaseModel):
|
||||
links: list[str] = SchemaField(
|
||||
default_factory=list, description="Array of links from the search result"
|
||||
)
|
||||
imageLinks: list[str] = SchemaField(
|
||||
default_factory=list, description="Array of image links from the search result"
|
||||
)
|
||||
|
||||
|
||||
class ExaSearchResults(BaseModel):
|
||||
title: str | None = None
|
||||
url: str | None = None
|
||||
publishedDate: str | None = None
|
||||
author: str | None = None
|
||||
id: str
|
||||
image: MediaFileType | None = None
|
||||
favicon: MediaFileType | None = None
|
||||
text: str | None = None
|
||||
highlights: list[str] = SchemaField(default_factory=list)
|
||||
highlightScores: list[float] = SchemaField(default_factory=list)
|
||||
summary: str | None = None
|
||||
subpages: list[dict] = SchemaField(default_factory=list)
|
||||
extras: ExaSearchExtras | None = None
|
||||
|
||||
@classmethod
|
||||
def from_sdk(cls, sdk_result) -> "ExaSearchResults":
|
||||
"""Convert SDK Result (dataclass) to our Pydantic model."""
|
||||
return cls(
|
||||
id=getattr(sdk_result, "id", ""),
|
||||
url=getattr(sdk_result, "url", None),
|
||||
title=getattr(sdk_result, "title", None),
|
||||
author=getattr(sdk_result, "author", None),
|
||||
publishedDate=getattr(sdk_result, "published_date", None),
|
||||
text=getattr(sdk_result, "text", None),
|
||||
highlights=getattr(sdk_result, "highlights", None) or [],
|
||||
highlightScores=getattr(sdk_result, "highlight_scores", None) or [],
|
||||
summary=getattr(sdk_result, "summary", None),
|
||||
subpages=getattr(sdk_result, "subpages", None) or [],
|
||||
image=getattr(sdk_result, "image", None),
|
||||
favicon=getattr(sdk_result, "favicon", None),
|
||||
extras=getattr(sdk_result, "extras", None),
|
||||
)
|
||||
|
||||
|
||||
# Cost tracking models
|
||||
class CostBreakdown(BaseModel):
|
||||
keywordSearch: float = SchemaField(default=0.0)
|
||||
neuralSearch: float = SchemaField(default=0.0)
|
||||
contentText: float = SchemaField(default=0.0)
|
||||
contentHighlight: float = SchemaField(default=0.0)
|
||||
contentSummary: float = SchemaField(default=0.0)
|
||||
|
||||
|
||||
class CostBreakdownItem(BaseModel):
|
||||
search: float = SchemaField(default=0.0)
|
||||
contents: float = SchemaField(default=0.0)
|
||||
breakdown: CostBreakdown = SchemaField(default_factory=CostBreakdown)
|
||||
|
||||
|
||||
class PerRequestPrices(BaseModel):
|
||||
neuralSearch_1_25_results: float = SchemaField(default=0.005)
|
||||
neuralSearch_26_100_results: float = SchemaField(default=0.025)
|
||||
neuralSearch_100_plus_results: float = SchemaField(default=1.0)
|
||||
keywordSearch_1_100_results: float = SchemaField(default=0.0025)
|
||||
keywordSearch_100_plus_results: float = SchemaField(default=3.0)
|
||||
|
||||
|
||||
class PerPagePrices(BaseModel):
|
||||
contentText: float = SchemaField(default=0.001)
|
||||
contentHighlight: float = SchemaField(default=0.001)
|
||||
contentSummary: float = SchemaField(default=0.001)
|
||||
|
||||
|
||||
class CostDollars(BaseModel):
|
||||
total: float = SchemaField(description="Total dollar cost for your request")
|
||||
breakDown: list[CostBreakdownItem] = SchemaField(
|
||||
default_factory=list, description="Breakdown of costs by operation type"
|
||||
)
|
||||
perRequestPrices: PerRequestPrices = SchemaField(
|
||||
default_factory=PerRequestPrices,
|
||||
description="Standard price per request for different operations",
|
||||
)
|
||||
perPagePrices: PerPagePrices = SchemaField(
|
||||
default_factory=PerPagePrices,
|
||||
description="Standard price per page for different content operations",
|
||||
)
|
||||
|
||||
|
||||
# Helper functions for payload processing
|
||||
def process_text_field(
|
||||
text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None]
|
||||
) -> Optional[Union[bool, Dict[str, Any]]]:
|
||||
"""Process text field for API payload."""
|
||||
if text is None:
|
||||
return None
|
||||
|
||||
# Handle backward compatibility with boolean
|
||||
if isinstance(text, bool):
|
||||
return text
|
||||
elif isinstance(text, TextDisabled):
|
||||
return False
|
||||
elif isinstance(text, TextEnabled):
|
||||
return True
|
||||
elif isinstance(text, TextAdvanced):
|
||||
text_dict = {}
|
||||
if text.max_characters:
|
||||
text_dict["maxCharacters"] = text.max_characters
|
||||
if text.include_html_tags:
|
||||
text_dict["includeHtmlTags"] = text.include_html_tags
|
||||
return text_dict if text_dict else True
|
||||
return None
|
||||
|
||||
|
||||
def process_contents_settings(contents: Optional[ContentSettings]) -> Dict[str, Any]:
|
||||
"""Process ContentSettings into API payload format."""
|
||||
if not contents:
|
||||
return {}
|
||||
|
||||
content_settings = {}
|
||||
|
||||
# Handle text field (can be boolean or object)
|
||||
text_value = process_text_field(contents.text)
|
||||
if text_value is not None:
|
||||
content_settings["text"] = text_value
|
||||
|
||||
# Handle highlights
|
||||
if contents.highlights:
|
||||
highlights_dict: Dict[str, Any] = {
|
||||
"numSentences": contents.highlights.num_sentences,
|
||||
"highlightsPerUrl": contents.highlights.highlights_per_url,
|
||||
}
|
||||
if contents.highlights.query:
|
||||
highlights_dict["query"] = contents.highlights.query
|
||||
content_settings["highlights"] = highlights_dict
|
||||
|
||||
if contents.summary:
|
||||
summary_dict = {}
|
||||
if contents.summary.query:
|
||||
summary_dict["query"] = contents.summary.query
|
||||
if contents.summary.schema:
|
||||
summary_dict["schema"] = contents.summary.schema
|
||||
content_settings["summary"] = summary_dict
|
||||
|
||||
if contents.livecrawl:
|
||||
content_settings["livecrawl"] = contents.livecrawl.value
|
||||
|
||||
if contents.livecrawl_timeout is not None:
|
||||
content_settings["livecrawlTimeout"] = contents.livecrawl_timeout
|
||||
|
||||
if contents.subpages is not None:
|
||||
content_settings["subpages"] = contents.subpages
|
||||
|
||||
if contents.subpage_target:
|
||||
content_settings["subpageTarget"] = contents.subpage_target
|
||||
|
||||
if contents.extras:
|
||||
extras_dict = {}
|
||||
if contents.extras.links:
|
||||
extras_dict["links"] = contents.extras.links
|
||||
if contents.extras.image_links:
|
||||
extras_dict["imageLinks"] = contents.extras.image_links
|
||||
content_settings["extras"] = extras_dict
|
||||
|
||||
context_value = process_context_field(contents.context)
|
||||
if context_value is not None:
|
||||
content_settings["context"] = context_value
|
||||
|
||||
return content_settings
|
||||
|
||||
|
||||
def process_context_field(
|
||||
context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None]
|
||||
) -> Optional[Union[bool, Dict[str, int]]]:
|
||||
"""Process context field for API payload."""
|
||||
if context is None:
|
||||
return None
|
||||
|
||||
# Handle backward compatibility with boolean
|
||||
if isinstance(context, bool):
|
||||
return context if context else None
|
||||
elif isinstance(context, dict) and "maxCharacters" in context:
|
||||
return {"maxCharacters": context["maxCharacters"]}
|
||||
elif isinstance(context, ContextDisabled):
|
||||
return None # Don't send context field at all when disabled
|
||||
elif isinstance(context, ContextEnabled):
|
||||
return True
|
||||
elif isinstance(context, ContextAdvanced):
|
||||
if context.max_characters:
|
||||
return {"maxCharacters": context.max_characters}
|
||||
return True
|
||||
return None
|
||||
|
||||
|
||||
def format_date_fields(
|
||||
input_data: Any, date_field_mapping: Dict[str, str]
|
||||
) -> Dict[str, str]:
|
||||
"""Format datetime fields for API payload."""
|
||||
formatted_dates = {}
|
||||
for input_field, api_field in date_field_mapping.items():
|
||||
value = getattr(input_data, input_field, None)
|
||||
if value:
|
||||
formatted_dates[api_field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")
|
||||
return formatted_dates
|
||||
|
||||
|
||||
def add_optional_fields(
|
||||
input_data: Any,
|
||||
field_mapping: Dict[str, str],
|
||||
payload: Dict[str, Any],
|
||||
process_enums: bool = False,
|
||||
) -> None:
|
||||
"""Add optional fields to payload if they have values."""
|
||||
for input_field, api_field in field_mapping.items():
|
||||
value = getattr(input_data, input_field, None)
|
||||
if value: # Only add non-empty values
|
||||
if process_enums and hasattr(value, "value"):
|
||||
payload[api_field] = value.value
|
||||
else:
|
||||
payload[api_field] = value
|
||||
|
||||
@@ -1,247 +0,0 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# Enum definitions based on available options
|
||||
class WebsetStatus(str, Enum):
|
||||
IDLE = "idle"
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
PAUSED = "paused"
|
||||
|
||||
|
||||
class WebsetSearchStatus(str, Enum):
|
||||
CREATED = "created"
|
||||
# Add more if known, based on example it's "created"
|
||||
|
||||
|
||||
class ImportStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class ImportFormat(str, Enum):
|
||||
CSV = "csv"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class EnrichmentStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class EnrichmentFormat(str, Enum):
|
||||
TEXT = "text"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class MonitorStatus(str, Enum):
|
||||
ENABLED = "enabled"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class MonitorBehaviorType(str, Enum):
|
||||
SEARCH = "search"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class MonitorRunStatus(str, Enum):
|
||||
CREATED = "created"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class CanceledReason(str, Enum):
|
||||
WEBSET_DELETED = "webset_deleted"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class FailedReason(str, Enum):
|
||||
INVALID_FORMAT = "invalid_format"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class Confidence(str, Enum):
|
||||
HIGH = "high"
|
||||
# Add more if known
|
||||
|
||||
|
||||
# Nested models
|
||||
|
||||
|
||||
class Entity(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
class Criterion(BaseModel):
|
||||
description: str
|
||||
successRate: Optional[int] = None
|
||||
|
||||
|
||||
class ExcludeItem(BaseModel):
|
||||
source: str = Field(default="import")
|
||||
id: str
|
||||
|
||||
|
||||
class Relationship(BaseModel):
|
||||
definition: str
|
||||
limit: Optional[float] = None
|
||||
|
||||
|
||||
class ScopeItem(BaseModel):
|
||||
source: str = Field(default="import")
|
||||
id: str
|
||||
relationship: Optional[Relationship] = None
|
||||
|
||||
|
||||
class Progress(BaseModel):
|
||||
found: int
|
||||
analyzed: int
|
||||
completion: int
|
||||
timeLeft: int
|
||||
|
||||
|
||||
class Bounds(BaseModel):
|
||||
min: int
|
||||
max: int
|
||||
|
||||
|
||||
class Expected(BaseModel):
|
||||
total: int
|
||||
confidence: str = Field(default="high") # Use str or Confidence enum
|
||||
bounds: Bounds
|
||||
|
||||
|
||||
class Recall(BaseModel):
|
||||
expected: Expected
|
||||
reasoning: str
|
||||
|
||||
|
||||
class WebsetSearch(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="webset_search")
|
||||
status: str = Field(default="created") # Or use WebsetSearchStatus
|
||||
websetId: str
|
||||
query: str
|
||||
entity: Entity
|
||||
criteria: List[Criterion]
|
||||
count: int
|
||||
behavior: str = Field(default="override")
|
||||
exclude: List[ExcludeItem]
|
||||
scope: List[ScopeItem]
|
||||
progress: Progress
|
||||
recall: Recall
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
canceledAt: Optional[datetime] = None
|
||||
canceledReason: Optional[str] = Field(default=None) # Or use CanceledReason
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class ImportEntity(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
class Import(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="import")
|
||||
status: str = Field(default="pending") # Or use ImportStatus
|
||||
format: str = Field(default="csv") # Or use ImportFormat
|
||||
entity: ImportEntity
|
||||
title: str
|
||||
count: int
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
failedReason: Optional[str] = Field(default=None) # Or use FailedReason
|
||||
failedAt: Optional[datetime] = None
|
||||
failedMessage: Optional[str] = None
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class Option(BaseModel):
|
||||
label: str
|
||||
|
||||
|
||||
class WebsetEnrichment(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="webset_enrichment")
|
||||
status: str = Field(default="pending") # Or use EnrichmentStatus
|
||||
websetId: str
|
||||
title: str
|
||||
description: str
|
||||
format: str = Field(default="text") # Or use EnrichmentFormat
|
||||
options: List[Option]
|
||||
instructions: str
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class Cadence(BaseModel):
|
||||
cron: str
|
||||
timezone: str = Field(default="Etc/UTC")
|
||||
|
||||
|
||||
class BehaviorConfig(BaseModel):
|
||||
query: Optional[str] = None
|
||||
criteria: Optional[List[Criterion]] = None
|
||||
entity: Optional[Entity] = None
|
||||
count: Optional[int] = None
|
||||
behavior: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class Behavior(BaseModel):
|
||||
type: str = Field(default="search") # Or use MonitorBehaviorType
|
||||
config: BehaviorConfig
|
||||
|
||||
|
||||
class MonitorRun(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="monitor_run")
|
||||
status: str = Field(default="created") # Or use MonitorRunStatus
|
||||
monitorId: str
|
||||
type: str = Field(default="search")
|
||||
completedAt: Optional[datetime] = None
|
||||
failedAt: Optional[datetime] = None
|
||||
failedReason: Optional[str] = None
|
||||
canceledAt: Optional[datetime] = None
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class Monitor(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="monitor")
|
||||
status: str = Field(default="enabled") # Or use MonitorStatus
|
||||
websetId: str
|
||||
cadence: Cadence
|
||||
behavior: Behavior
|
||||
lastRun: Optional[MonitorRun] = None
|
||||
nextRunAt: Optional[datetime] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class Webset(BaseModel):
    """Top-level webset object aggregating its searches, imports,
    enrichments, and monitors.

    Field names use camelCase to match the Exa Websets API wire format.
    """

    id: str
    object: str = Field(default="webset")  # API object type discriminator
    status: WebsetStatus
    externalId: Optional[str] = None  # caller-supplied identifier, if any
    title: Optional[str] = None
    searches: List[WebsetSearch]
    imports: List[Import]
    enrichments: List[WebsetEnrichment]
    monitors: List[Monitor]
    streams: List[Any]  # shape not modeled here — passed through as-is
    createdAt: datetime
    updatedAt: datetime
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ListWebsets(BaseModel):
    """Paginated list response for websets."""

    data: List[Webset]
    hasMore: bool  # True when more pages can be fetched with nextCursor
    nextCursor: Optional[str] = None  # cursor for the next page, if hasMore
|
||||
518
autogpt_platform/backend/backend/blocks/exa/research.py
Normal file
518
autogpt_platform/backend/backend/blocks/exa/research.py
Normal file
@@ -0,0 +1,518 @@
|
||||
"""
|
||||
Exa Research Task Blocks
|
||||
|
||||
Provides asynchronous research capabilities that explore the web, gather sources,
|
||||
synthesize findings, and return structured results with citations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
|
||||
|
||||
class ResearchModel(str, Enum):
    """Available research models.

    Per the block's input description: 'fast' for quick results,
    'standard' for balanced quality, 'pro' for thorough analysis.
    """

    FAST = "exa-research-fast"
    STANDARD = "exa-research"
    PRO = "exa-research-pro"
|
||||
|
||||
|
||||
class ResearchStatus(str, Enum):
    """Research task status.

    COMPLETED, CANCELED, and FAILED are terminal states (the polling
    blocks stop when they see one of these).
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    CANCELED = "canceled"
    FAILED = "failed"
|
||||
|
||||
|
||||
class ResearchCostModel(BaseModel):
    """Cost breakdown for a research request."""

    total: float  # total cost in USD
    num_searches: int
    num_pages: int
    reasoning_tokens: int

    @classmethod
    def from_api(cls, data: dict) -> "ResearchCostModel":
        """Convert API response, rounding fractional counts to integers.

        The API reports counts under camelCase keys and may return
        fractional values; they are rounded to the nearest integer here.
        """
        return cls(
            total=data.get("total", 0.0),
            num_searches=int(round(data.get("numSearches", 0))),
            num_pages=int(round(data.get("numPages", 0))),
            reasoning_tokens=int(round(data.get("reasoningTokens", 0))),
        )
|
||||
|
||||
|
||||
class ResearchOutputModel(BaseModel):
    """Research output with content and optional structured data."""

    content: str  # synthesized findings as text
    # Structured JSON result; only present when an output schema was supplied.
    parsed: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ResearchTaskModel(BaseModel):
    """Stable output model for research tasks.

    Snake_case mirror of the Exa research API object; ``from_api``
    performs the camelCase -> snake_case conversion so downstream
    blocks are insulated from wire-format changes.
    """

    research_id: str
    created_at: int  # Unix timestamp (ms)
    model: str
    instructions: str
    status: str
    output_schema: Optional[Dict[str, Any]] = None
    output: Optional[ResearchOutputModel] = None
    cost_dollars: Optional[ResearchCostModel] = None
    finished_at: Optional[int] = None  # Unix timestamp (ms), if finished
    error: Optional[str] = None  # error message, if the task failed

    @classmethod
    def from_api(cls, data: dict) -> "ResearchTaskModel":
        """Convert API response to our stable model.

        Missing keys fall back to neutral defaults rather than raising,
        so partially-populated responses (e.g. a just-created task) parse.
        """
        # Nested output object is optional until the task completes.
        output_data = data.get("output")
        output = None
        if output_data:
            output = ResearchOutputModel(
                content=output_data.get("content", ""),
                parsed=output_data.get("parsed"),
            )

        # Cost breakdown is optional as well.
        cost_data = data.get("costDollars")
        cost = None
        if cost_data:
            cost = ResearchCostModel.from_api(cost_data)

        return cls(
            research_id=data.get("researchId", ""),
            created_at=data.get("createdAt", 0),
            model=data.get("model", "exa-research"),
            instructions=data.get("instructions", ""),
            status=data.get("status", "pending"),
            output_schema=data.get("outputSchema"),
            output=output,
            cost_dollars=cost,
            finished_at=data.get("finishedAt"),
            error=data.get("error"),
        )
|
||||
|
||||
|
||||
class ExaCreateResearchBlock(Block):
    """Create an asynchronous research task that explores the web and synthesizes findings."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        instructions: str = SchemaField(
            description="Research instructions - clearly define what information to find, how to conduct research, and desired output format.",
            placeholder="Research the top 5 AI coding assistants, their features, pricing, and user reviews",
        )
        model: ResearchModel = SchemaField(
            default=ResearchModel.STANDARD,
            description="Research model: 'fast' for quick results, 'standard' for balanced quality, 'pro' for thorough analysis",
        )
        output_schema: Optional[dict] = SchemaField(
            default=None,
            description="JSON Schema to enforce structured output. When provided, results are validated and returned as parsed JSON.",
            advanced=True,
        )
        wait_for_completion: bool = SchemaField(
            default=True,
            description="Wait for research to complete before returning. Ensures you get results immediately.",
        )
        polling_timeout: int = SchemaField(
            default=600,
            description="Maximum time to wait for completion in seconds (only if wait_for_completion is True)",
            advanced=True,
            ge=1,
            le=3600,
        )

    class Output(BlockSchemaOutput):
        research_id: str = SchemaField(
            description="Unique identifier for tracking this research request"
        )
        status: str = SchemaField(description="Final status of the research")
        model: str = SchemaField(description="The research model used")
        instructions: str = SchemaField(
            description="The research instructions provided"
        )
        created_at: int = SchemaField(
            description="When the research was created (Unix timestamp in ms)"
        )
        output_content: Optional[str] = SchemaField(
            description="Research output as text (only if wait_for_completion was True and completed)"
        )
        output_parsed: Optional[dict] = SchemaField(
            description="Structured JSON output (only if wait_for_completion and outputSchema were provided)"
        )
        cost_total: Optional[float] = SchemaField(
            description="Total cost in USD (only if wait_for_completion was True and completed)"
        )
        elapsed_time: Optional[float] = SchemaField(
            description="Time taken to complete in seconds (only if wait_for_completion was True)"
        )

    def __init__(self):
        super().__init__(
            id="a1f2e3d4-c5b6-4a78-9012-3456789abcde",
            description="Create research task with optional waiting - explores web and synthesizes findings with citations",
            categories={BlockCategory.SEARCH, BlockCategory.AI},
            input_schema=ExaCreateResearchBlock.Input,
            output_schema=ExaCreateResearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Create the research task via POST, then optionally poll GET
        until a terminal status or until polling_timeout elapses.

        Raises:
            ValueError: if wait_for_completion is True and the task does
                not reach a terminal state within polling_timeout seconds.
        """
        url = "https://api.exa.ai/research/v1"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        payload: Dict[str, Any] = {
            "model": input_data.model.value,
            "instructions": input_data.instructions,
        }

        # outputSchema is only sent when provided; its presence switches the
        # API to structured (parsed JSON) output.
        if input_data.output_schema:
            payload["outputSchema"] = input_data.output_schema

        response = await Requests().post(url, headers=headers, json=payload)
        data = response.json()

        research_id = data.get("researchId", "")

        if input_data.wait_for_completion:
            start_time = time.time()
            get_url = f"https://api.exa.ai/research/v1/{research_id}"
            get_headers = {"x-api-key": credentials.api_key.get_secret_value()}
            # Fixed 10s poll interval (not configurable on this block; see
            # ExaWaitForResearchBlock for a configurable interval).
            check_interval = 10

            while time.time() - start_time < input_data.polling_timeout:
                poll_response = await Requests().get(url=get_url, headers=get_headers)
                poll_data = poll_response.json()

                status = poll_data.get("status", "")

                # Terminal states — stop polling and emit results.
                if status in ["completed", "failed", "canceled"]:
                    elapsed = time.time() - start_time
                    research = ResearchTaskModel.from_api(poll_data)

                    yield "research_id", research.research_id
                    yield "status", research.status
                    yield "model", research.model
                    yield "instructions", research.instructions
                    yield "created_at", research.created_at
                    yield "elapsed_time", elapsed

                    if research.output:
                        yield "output_content", research.output.content
                        yield "output_parsed", research.output.parsed

                    if research.cost_dollars:
                        yield "cost_total", research.cost_dollars.total
                    return

                await asyncio.sleep(check_interval)

            # Timed out: the task may still be running server-side; its id is
            # not yielded here — the error message is the only signal.
            raise ValueError(
                f"Research did not complete within {input_data.polling_timeout} seconds"
            )
        else:
            # Fire-and-forget: return creation metadata only; use
            # ExaGetResearchBlock / ExaWaitForResearchBlock to fetch results.
            yield "research_id", research_id
            yield "status", data.get("status", "pending")
            yield "model", data.get("model", input_data.model.value)
            yield "instructions", data.get("instructions", input_data.instructions)
            yield "created_at", data.get("createdAt", 0)
|
||||
|
||||
|
||||
class ExaGetResearchBlock(Block):
    """Get the status and results of a research task."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        research_id: str = SchemaField(
            description="The ID of the research task to retrieve",
            placeholder="01jszdfs0052sg4jc552sg4jc5",
        )
        include_events: bool = SchemaField(
            default=False,
            description="Include detailed event log of research operations",
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        research_id: str = SchemaField(description="The research task identifier")
        status: str = SchemaField(
            description="Current status: pending, running, completed, canceled, or failed"
        )
        instructions: str = SchemaField(
            description="The original research instructions"
        )
        model: str = SchemaField(description="The research model used")
        created_at: int = SchemaField(
            description="When research was created (Unix timestamp in ms)"
        )
        finished_at: Optional[int] = SchemaField(
            description="When research finished (Unix timestamp in ms, if completed/canceled/failed)"
        )
        output_content: Optional[str] = SchemaField(
            description="Research output as text (if completed)"
        )
        output_parsed: Optional[dict] = SchemaField(
            description="Structured JSON output matching outputSchema (if provided and completed)"
        )
        cost_total: Optional[float] = SchemaField(
            description="Total cost in USD (if completed)"
        )
        cost_searches: Optional[int] = SchemaField(
            description="Number of searches performed (if completed)"
        )
        cost_pages: Optional[int] = SchemaField(
            description="Number of pages crawled (if completed)"
        )
        cost_reasoning_tokens: Optional[int] = SchemaField(
            description="AI tokens used for reasoning (if completed)"
        )
        error_message: Optional[str] = SchemaField(
            description="Error message if research failed"
        )
        events: Optional[List[dict]] = SchemaField(
            description="Detailed event log (if include_events was True)"
        )

    def __init__(self):
        super().__init__(
            id="b2e3f4a5-6789-4bcd-9012-3456789abcde",
            description="Get status and results of a research task",
            categories={BlockCategory.SEARCH},
            input_schema=ExaGetResearchBlock.Input,
            output_schema=ExaGetResearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Fetch the task once (no polling) and emit whatever is available."""
        url = f"https://api.exa.ai/research/v1/{input_data.research_id}"
        headers = {
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        # The events log is opt-in; the API expects the flag as a string.
        params = {}
        if input_data.include_events:
            params["events"] = "true"

        response = await Requests().get(url, headers=headers, params=params)
        data = response.json()

        research = ResearchTaskModel.from_api(data)

        yield "research_id", research.research_id
        yield "status", research.status
        yield "instructions", research.instructions
        yield "model", research.model
        yield "created_at", research.created_at
        yield "finished_at", research.finished_at

        # Output and cost are only present once the task has completed.
        if research.output:
            yield "output_content", research.output.content
            yield "output_parsed", research.output.parsed

        if research.cost_dollars:
            yield "cost_total", research.cost_dollars.total
            yield "cost_searches", research.cost_dollars.num_searches
            yield "cost_pages", research.cost_dollars.num_pages
            yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens

        # Always yielded (None when the task did not fail).
        yield "error_message", research.error

        if input_data.include_events:
            yield "events", data.get("events", [])
|
||||
|
||||
|
||||
class ExaWaitForResearchBlock(Block):
    """Wait for a research task to complete with progress tracking."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        research_id: str = SchemaField(
            description="The ID of the research task to wait for",
            placeholder="01jszdfs0052sg4jc552sg4jc5",
        )
        timeout: int = SchemaField(
            default=600,
            description="Maximum time to wait in seconds",
            ge=1,
            le=3600,
        )
        check_interval: int = SchemaField(
            default=10,
            description="Seconds between status checks",
            advanced=True,
            ge=1,
            le=60,
        )

    class Output(BlockSchemaOutput):
        research_id: str = SchemaField(description="The research task identifier")
        final_status: str = SchemaField(description="Final status when polling stopped")
        output_content: Optional[str] = SchemaField(
            description="Research output as text (if completed)"
        )
        output_parsed: Optional[dict] = SchemaField(
            description="Structured JSON output (if outputSchema was provided and completed)"
        )
        cost_total: Optional[float] = SchemaField(description="Total cost in USD")
        elapsed_time: float = SchemaField(description="Total time waited in seconds")
        timed_out: bool = SchemaField(
            description="Whether polling timed out before completion"
        )

    def __init__(self):
        super().__init__(
            id="c3d4e5f6-7890-4abc-9012-3456789abcde",
            description="Wait for a research task to complete with configurable timeout",
            categories={BlockCategory.SEARCH},
            input_schema=ExaWaitForResearchBlock.Input,
            output_schema=ExaWaitForResearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Poll the task until a terminal status or until timeout.

        Unlike ExaCreateResearchBlock, a timeout here does NOT raise: it
        performs one final status fetch and yields timed_out=True.
        """
        start_time = time.time()
        url = f"https://api.exa.ai/research/v1/{input_data.research_id}"
        headers = {
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        while time.time() - start_time < input_data.timeout:
            response = await Requests().get(url, headers=headers)
            data = response.json()

            status = data.get("status", "")

            # Terminal states — stop polling and emit results.
            if status in ["completed", "failed", "canceled"]:
                elapsed = time.time() - start_time
                research = ResearchTaskModel.from_api(data)

                yield "research_id", research.research_id
                yield "final_status", research.status
                yield "elapsed_time", elapsed
                yield "timed_out", False

                if research.output:
                    yield "output_content", research.output.content
                    yield "output_parsed", research.output.parsed

                if research.cost_dollars:
                    yield "cost_total", research.cost_dollars.total

                return

            await asyncio.sleep(input_data.check_interval)

        # Timed out: fetch the latest status once more so the caller sees
        # where the task stands, then report timed_out=True.
        elapsed = time.time() - start_time
        response = await Requests().get(url, headers=headers)
        data = response.json()

        yield "research_id", input_data.research_id
        yield "final_status", data.get("status", "unknown")
        yield "elapsed_time", elapsed
        yield "timed_out", True
|
||||
|
||||
|
||||
class ExaListResearchBlock(Block):
    """List all research tasks with pagination support."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        cursor: Optional[str] = SchemaField(
            default=None,
            description="Cursor for pagination through results",
            advanced=True,
        )
        limit: int = SchemaField(
            default=10,
            description="Number of research tasks to return (1-50)",
            ge=1,
            le=50,
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        research_tasks: List[ResearchTaskModel] = SchemaField(
            description="List of research tasks ordered by creation time (newest first)"
        )
        research_task: ResearchTaskModel = SchemaField(
            description="Individual research task (yielded for each task)"
        )
        has_more: bool = SchemaField(
            description="Whether there are more tasks to paginate through"
        )
        next_cursor: Optional[str] = SchemaField(
            description="Cursor for the next page of results"
        )

    def __init__(self):
        super().__init__(
            id="d4e5f6a7-8901-4bcd-9012-3456789abcde",
            description="List all research tasks with pagination support",
            categories={BlockCategory.SEARCH},
            input_schema=ExaListResearchBlock.Input,
            output_schema=ExaListResearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Fetch one page of tasks; emits the list, then each task singly."""
        url = "https://api.exa.ai/research/v1"
        headers = {
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        params: Dict[str, Any] = {
            "limit": input_data.limit,
        }
        # Cursor is omitted on the first page.
        if input_data.cursor:
            params["cursor"] = input_data.cursor

        response = await Requests().get(url, headers=headers, params=params)
        data = response.json()

        tasks = [ResearchTaskModel.from_api(task) for task in data.get("data", [])]

        # Full page as one output, then each task individually so downstream
        # blocks can fan out per-task.
        yield "research_tasks", tasks

        for task in tasks:
            yield "research_task", task

        yield "has_more", data.get("hasMore", False)
        yield "next_cursor", data.get("nextCursor")
|
||||
@@ -1,4 +1,8 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
@@ -8,12 +12,35 @@ from backend.sdk import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import ContentSettings
|
||||
from .helpers import (
|
||||
ContentSettings,
|
||||
CostDollars,
|
||||
ExaSearchResults,
|
||||
process_contents_settings,
|
||||
)
|
||||
|
||||
|
||||
class ExaSearchTypes(Enum):
|
||||
KEYWORD = "keyword"
|
||||
NEURAL = "neural"
|
||||
FAST = "fast"
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class ExaSearchCategories(Enum):
|
||||
COMPANY = "company"
|
||||
RESEARCH_PAPER = "research paper"
|
||||
NEWS = "news"
|
||||
PDF = "pdf"
|
||||
GITHUB = "github"
|
||||
TWEET = "tweet"
|
||||
PERSONAL_SITE = "personal site"
|
||||
LINKEDIN_PROFILE = "linkedin profile"
|
||||
FINANCIAL_REPORT = "financial report"
|
||||
|
||||
|
||||
class ExaSearchBlock(Block):
|
||||
@@ -22,12 +49,18 @@ class ExaSearchBlock(Block):
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
query: str = SchemaField(description="The search query")
|
||||
use_auto_prompt: bool = SchemaField(
|
||||
description="Whether to use autoprompt", default=True, advanced=True
|
||||
type: ExaSearchTypes = SchemaField(
|
||||
description="Type of search", default=ExaSearchTypes.AUTO, advanced=True
|
||||
)
|
||||
type: str = SchemaField(description="Type of search", default="", advanced=True)
|
||||
category: str = SchemaField(
|
||||
description="Category to search within", default="", advanced=True
|
||||
category: ExaSearchCategories | None = SchemaField(
|
||||
description="Category to search within: company, research paper, news, pdf, github, tweet, personal site, linkedin profile, financial report",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
user_location: str | None = SchemaField(
|
||||
description="The two-letter ISO country code of the user (e.g., 'US')",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
number_of_results: int = SchemaField(
|
||||
description="Number of results to return", default=10, advanced=True
|
||||
@@ -40,17 +73,17 @@ class ExaSearchBlock(Block):
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
start_crawl_date: datetime = SchemaField(
|
||||
description="Start date for crawled content"
|
||||
start_crawl_date: datetime | None = SchemaField(
|
||||
description="Start date for crawled content", advanced=True, default=None
|
||||
)
|
||||
end_crawl_date: datetime = SchemaField(
|
||||
description="End date for crawled content"
|
||||
end_crawl_date: datetime | None = SchemaField(
|
||||
description="End date for crawled content", advanced=True, default=None
|
||||
)
|
||||
start_published_date: datetime = SchemaField(
|
||||
description="Start date for published content"
|
||||
start_published_date: datetime | None = SchemaField(
|
||||
description="Start date for published content", advanced=True, default=None
|
||||
)
|
||||
end_published_date: datetime = SchemaField(
|
||||
description="End date for published content"
|
||||
end_published_date: datetime | None = SchemaField(
|
||||
description="End date for published content", advanced=True, default=None
|
||||
)
|
||||
include_text: list[str] = SchemaField(
|
||||
description="Text patterns to include", default_factory=list, advanced=True
|
||||
@@ -63,14 +96,30 @@ class ExaSearchBlock(Block):
|
||||
default=ContentSettings(),
|
||||
advanced=True,
|
||||
)
|
||||
moderation: bool = SchemaField(
|
||||
description="Enable content moderation to filter unsafe content from search results",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
results: list = SchemaField(
|
||||
description="List of search results", default_factory=list
|
||||
results: list[ExaSearchResults] = SchemaField(
|
||||
description="List of search results"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed",
|
||||
result: ExaSearchResults = SchemaField(description="Single search result")
|
||||
context: str = SchemaField(
|
||||
description="A formatted string of the search results ready for LLMs."
|
||||
)
|
||||
search_type: str = SchemaField(
|
||||
description="For auto searches, indicates which search type was selected."
|
||||
)
|
||||
resolved_search_type: str = SchemaField(
|
||||
description="The search type that was actually used for this request (neural or keyword)"
|
||||
)
|
||||
cost_dollars: Optional[CostDollars] = SchemaField(
|
||||
description="Cost breakdown for the request"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -84,51 +133,76 @@ class ExaSearchBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = "https://api.exa.ai/search"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
payload = {
|
||||
sdk_kwargs = {
|
||||
"query": input_data.query,
|
||||
"useAutoprompt": input_data.use_auto_prompt,
|
||||
"numResults": input_data.number_of_results,
|
||||
"contents": input_data.contents.model_dump(),
|
||||
"num_results": input_data.number_of_results,
|
||||
}
|
||||
|
||||
date_field_mapping = {
|
||||
"start_crawl_date": "startCrawlDate",
|
||||
"end_crawl_date": "endCrawlDate",
|
||||
"start_published_date": "startPublishedDate",
|
||||
"end_published_date": "endPublishedDate",
|
||||
}
|
||||
if input_data.type:
|
||||
sdk_kwargs["type"] = input_data.type.value
|
||||
|
||||
# Add dates if they exist
|
||||
for input_field, api_field in date_field_mapping.items():
|
||||
value = getattr(input_data, input_field, None)
|
||||
if value:
|
||||
payload[api_field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")
|
||||
if input_data.category:
|
||||
sdk_kwargs["category"] = input_data.category.value
|
||||
|
||||
optional_field_mapping = {
|
||||
"type": "type",
|
||||
"category": "category",
|
||||
"include_domains": "includeDomains",
|
||||
"exclude_domains": "excludeDomains",
|
||||
"include_text": "includeText",
|
||||
"exclude_text": "excludeText",
|
||||
}
|
||||
if input_data.user_location:
|
||||
sdk_kwargs["user_location"] = input_data.user_location
|
||||
|
||||
# Add other fields
|
||||
for input_field, api_field in optional_field_mapping.items():
|
||||
value = getattr(input_data, input_field)
|
||||
if value: # Only add non-empty values
|
||||
payload[api_field] = value
|
||||
# Handle domains
|
||||
if input_data.include_domains:
|
||||
sdk_kwargs["include_domains"] = input_data.include_domains
|
||||
if input_data.exclude_domains:
|
||||
sdk_kwargs["exclude_domains"] = input_data.exclude_domains
|
||||
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
# Extract just the results array from the response
|
||||
yield "results", data.get("results", [])
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
# Handle dates
|
||||
if input_data.start_crawl_date:
|
||||
sdk_kwargs["start_crawl_date"] = input_data.start_crawl_date.isoformat()
|
||||
if input_data.end_crawl_date:
|
||||
sdk_kwargs["end_crawl_date"] = input_data.end_crawl_date.isoformat()
|
||||
if input_data.start_published_date:
|
||||
sdk_kwargs["start_published_date"] = (
|
||||
input_data.start_published_date.isoformat()
|
||||
)
|
||||
if input_data.end_published_date:
|
||||
sdk_kwargs["end_published_date"] = input_data.end_published_date.isoformat()
|
||||
|
||||
# Handle text filters
|
||||
if input_data.include_text:
|
||||
sdk_kwargs["include_text"] = input_data.include_text
|
||||
if input_data.exclude_text:
|
||||
sdk_kwargs["exclude_text"] = input_data.exclude_text
|
||||
|
||||
if input_data.moderation:
|
||||
sdk_kwargs["moderation"] = input_data.moderation
|
||||
|
||||
# heck if we need to use search_and_contents
|
||||
content_settings = process_contents_settings(input_data.contents)
|
||||
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
if content_settings:
|
||||
sdk_kwargs["text"] = content_settings.get("text", False)
|
||||
if "highlights" in content_settings:
|
||||
sdk_kwargs["highlights"] = content_settings["highlights"]
|
||||
if "summary" in content_settings:
|
||||
sdk_kwargs["summary"] = content_settings["summary"]
|
||||
response = await aexa.search_and_contents(**sdk_kwargs)
|
||||
else:
|
||||
response = await aexa.search(**sdk_kwargs)
|
||||
|
||||
converted_results = [
|
||||
ExaSearchResults.from_sdk(sdk_result)
|
||||
for sdk_result in response.results or []
|
||||
]
|
||||
|
||||
yield "results", converted_results
|
||||
for result in converted_results:
|
||||
yield "result", result
|
||||
|
||||
if response.context:
|
||||
yield "context", response.context
|
||||
|
||||
if response.resolved_search_type:
|
||||
yield "resolved_search_type", response.resolved_search_type
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
@@ -9,12 +11,16 @@ from backend.sdk import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import ContentSettings
|
||||
from .helpers import (
|
||||
ContentSettings,
|
||||
CostDollars,
|
||||
ExaSearchResults,
|
||||
process_contents_settings,
|
||||
)
|
||||
|
||||
|
||||
class ExaFindSimilarBlock(Block):
|
||||
@@ -29,7 +35,7 @@ class ExaFindSimilarBlock(Block):
|
||||
description="Number of results to return", default=10, advanced=True
|
||||
)
|
||||
include_domains: list[str] = SchemaField(
|
||||
description="Domains to include in search",
|
||||
description="List of domains to include in the search. If specified, results will only come from these domains.",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
@@ -38,17 +44,17 @@ class ExaFindSimilarBlock(Block):
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
start_crawl_date: datetime = SchemaField(
|
||||
description="Start date for crawled content"
|
||||
start_crawl_date: Optional[datetime] = SchemaField(
|
||||
description="Start date for crawled content", advanced=True, default=None
|
||||
)
|
||||
end_crawl_date: datetime = SchemaField(
|
||||
description="End date for crawled content"
|
||||
end_crawl_date: Optional[datetime] = SchemaField(
|
||||
description="End date for crawled content", advanced=True, default=None
|
||||
)
|
||||
start_published_date: datetime = SchemaField(
|
||||
description="Start date for published content"
|
||||
start_published_date: Optional[datetime] = SchemaField(
|
||||
description="Start date for published content", advanced=True, default=None
|
||||
)
|
||||
end_published_date: datetime = SchemaField(
|
||||
description="End date for published content"
|
||||
end_published_date: Optional[datetime] = SchemaField(
|
||||
description="End date for published content", advanced=True, default=None
|
||||
)
|
||||
include_text: list[str] = SchemaField(
|
||||
description="Text patterns to include (max 1 string, up to 5 words)",
|
||||
@@ -65,15 +71,27 @@ class ExaFindSimilarBlock(Block):
|
||||
default=ContentSettings(),
|
||||
advanced=True,
|
||||
)
|
||||
moderation: bool = SchemaField(
|
||||
description="Enable content moderation to filter unsafe content from search results",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
results: list[Any] = SchemaField(
|
||||
description="List of similar documents with title, URL, published date, author, and score",
|
||||
default_factory=list,
|
||||
results: list[ExaSearchResults] = SchemaField(
|
||||
description="List of similar documents with metadata and content"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed", default=""
|
||||
result: ExaSearchResults = SchemaField(
|
||||
description="Single similar document result"
|
||||
)
|
||||
context: str = SchemaField(
|
||||
description="A formatted string of the results ready for LLMs."
|
||||
)
|
||||
request_id: str = SchemaField(description="Unique identifier for the request")
|
||||
cost_dollars: Optional[CostDollars] = SchemaField(
|
||||
description="Cost breakdown for the request"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -87,47 +105,65 @@ class ExaFindSimilarBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = "https://api.exa.ai/findSimilar"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
payload = {
|
||||
sdk_kwargs = {
|
||||
"url": input_data.url,
|
||||
"numResults": input_data.number_of_results,
|
||||
"contents": input_data.contents.model_dump(),
|
||||
"num_results": input_data.number_of_results,
|
||||
}
|
||||
|
||||
optional_field_mapping = {
|
||||
"include_domains": "includeDomains",
|
||||
"exclude_domains": "excludeDomains",
|
||||
"include_text": "includeText",
|
||||
"exclude_text": "excludeText",
|
||||
}
|
||||
# Handle domains
|
||||
if input_data.include_domains:
|
||||
sdk_kwargs["include_domains"] = input_data.include_domains
|
||||
if input_data.exclude_domains:
|
||||
sdk_kwargs["exclude_domains"] = input_data.exclude_domains
|
||||
|
||||
# Add optional fields if they have values
|
||||
for input_field, api_field in optional_field_mapping.items():
|
||||
value = getattr(input_data, input_field)
|
||||
if value: # Only add non-empty values
|
||||
payload[api_field] = value
|
||||
# Handle dates
|
||||
if input_data.start_crawl_date:
|
||||
sdk_kwargs["start_crawl_date"] = input_data.start_crawl_date.isoformat()
|
||||
if input_data.end_crawl_date:
|
||||
sdk_kwargs["end_crawl_date"] = input_data.end_crawl_date.isoformat()
|
||||
if input_data.start_published_date:
|
||||
sdk_kwargs["start_published_date"] = (
|
||||
input_data.start_published_date.isoformat()
|
||||
)
|
||||
if input_data.end_published_date:
|
||||
sdk_kwargs["end_published_date"] = input_data.end_published_date.isoformat()
|
||||
|
||||
date_field_mapping = {
|
||||
"start_crawl_date": "startCrawlDate",
|
||||
"end_crawl_date": "endCrawlDate",
|
||||
"start_published_date": "startPublishedDate",
|
||||
"end_published_date": "endPublishedDate",
|
||||
}
|
||||
# Handle text filters
|
||||
if input_data.include_text:
|
||||
sdk_kwargs["include_text"] = input_data.include_text
|
||||
if input_data.exclude_text:
|
||||
sdk_kwargs["exclude_text"] = input_data.exclude_text
|
||||
|
||||
# Add dates if they exist
|
||||
for input_field, api_field in date_field_mapping.items():
|
||||
value = getattr(input_data, input_field, None)
|
||||
if value:
|
||||
payload[api_field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")
|
||||
if input_data.moderation:
|
||||
sdk_kwargs["moderation"] = input_data.moderation
|
||||
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
yield "results", data.get("results", [])
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
# check if we need to use find_similar_and_contents
|
||||
content_settings = process_contents_settings(input_data.contents)
|
||||
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
if content_settings:
|
||||
# Use find_similar_and_contents when contents are requested
|
||||
sdk_kwargs["text"] = content_settings.get("text", False)
|
||||
if "highlights" in content_settings:
|
||||
sdk_kwargs["highlights"] = content_settings["highlights"]
|
||||
if "summary" in content_settings:
|
||||
sdk_kwargs["summary"] = content_settings["summary"]
|
||||
response = await aexa.find_similar_and_contents(**sdk_kwargs)
|
||||
else:
|
||||
response = await aexa.find_similar(**sdk_kwargs)
|
||||
|
||||
converted_results = [
|
||||
ExaSearchResults.from_sdk(sdk_result)
|
||||
for sdk_result in response.results or []
|
||||
]
|
||||
|
||||
yield "results", converted_results
|
||||
for result in converted_results:
|
||||
yield "result", result
|
||||
|
||||
if response.context:
|
||||
yield "context", response.context
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
|
||||
@@ -132,45 +132,33 @@ class ExaWebsetWebhookBlock(Block):
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
"""Process incoming Exa webhook payload."""
|
||||
try:
|
||||
payload = input_data.payload
|
||||
payload = input_data.payload
|
||||
|
||||
# Extract event details
|
||||
event_type = payload.get("eventType", "unknown")
|
||||
event_id = payload.get("eventId", "")
|
||||
# Extract event details
|
||||
event_type = payload.get("eventType", "unknown")
|
||||
event_id = payload.get("eventId", "")
|
||||
|
||||
# Get webset ID from payload or input
|
||||
webset_id = payload.get("websetId", input_data.webset_id)
|
||||
# Get webset ID from payload or input
|
||||
webset_id = payload.get("websetId", input_data.webset_id)
|
||||
|
||||
# Check if we should process this event based on filter
|
||||
should_process = self._should_process_event(
|
||||
event_type, input_data.event_filter
|
||||
)
|
||||
# Check if we should process this event based on filter
|
||||
should_process = self._should_process_event(event_type, input_data.event_filter)
|
||||
|
||||
if not should_process:
|
||||
# Skip events that don't match our filter
|
||||
return
|
||||
if not should_process:
|
||||
# Skip events that don't match our filter
|
||||
return
|
||||
|
||||
# Extract event data
|
||||
event_data = payload.get("data", {})
|
||||
timestamp = payload.get("occurredAt", payload.get("createdAt", ""))
|
||||
metadata = payload.get("metadata", {})
|
||||
# Extract event data
|
||||
event_data = payload.get("data", {})
|
||||
timestamp = payload.get("occurredAt", payload.get("createdAt", ""))
|
||||
metadata = payload.get("metadata", {})
|
||||
|
||||
yield "event_type", event_type
|
||||
yield "event_id", event_id
|
||||
yield "webset_id", webset_id
|
||||
yield "data", event_data
|
||||
yield "timestamp", timestamp
|
||||
yield "metadata", metadata
|
||||
|
||||
except Exception as e:
|
||||
# Handle errors gracefully
|
||||
yield "event_type", "error"
|
||||
yield "event_id", ""
|
||||
yield "webset_id", input_data.webset_id
|
||||
yield "data", {"error": str(e)}
|
||||
yield "timestamp", ""
|
||||
yield "metadata", {}
|
||||
yield "event_type", event_type
|
||||
yield "event_id", event_id
|
||||
yield "webset_id", webset_id
|
||||
yield "data", event_data
|
||||
yield "timestamp", timestamp
|
||||
yield "metadata", metadata
|
||||
|
||||
def _should_process_event(
|
||||
self, event_type: str, event_filter: WebsetEventFilter
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,554 @@
|
||||
"""
|
||||
Exa Websets Enrichment Management Blocks
|
||||
|
||||
This module provides blocks for creating and managing enrichments on webset items,
|
||||
allowing extraction of additional structured data from existing items.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
from exa_py.websets.types import WebsetEnrichment as SdkWebsetEnrichment
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
|
||||
|
||||
# Mirrored model for stability
|
||||
class WebsetEnrichmentModel(BaseModel):
|
||||
"""Stable output model mirroring SDK WebsetEnrichment."""
|
||||
|
||||
id: str
|
||||
webset_id: str
|
||||
status: str
|
||||
title: Optional[str]
|
||||
description: str
|
||||
format: str
|
||||
options: List[str]
|
||||
instructions: Optional[str]
|
||||
metadata: Dict[str, Any]
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
@classmethod
|
||||
def from_sdk(cls, enrichment: SdkWebsetEnrichment) -> "WebsetEnrichmentModel":
|
||||
"""Convert SDK WebsetEnrichment to our stable model."""
|
||||
# Extract options
|
||||
options_list = []
|
||||
if enrichment.options:
|
||||
for option in enrichment.options:
|
||||
option_dict = option.model_dump(by_alias=True)
|
||||
options_list.append(option_dict.get("label", ""))
|
||||
|
||||
return cls(
|
||||
id=enrichment.id,
|
||||
webset_id=enrichment.webset_id,
|
||||
status=(
|
||||
enrichment.status.value
|
||||
if hasattr(enrichment.status, "value")
|
||||
else str(enrichment.status)
|
||||
),
|
||||
title=enrichment.title,
|
||||
description=enrichment.description,
|
||||
format=(
|
||||
enrichment.format.value
|
||||
if enrichment.format and hasattr(enrichment.format, "value")
|
||||
else "text"
|
||||
),
|
||||
options=options_list,
|
||||
instructions=enrichment.instructions,
|
||||
metadata=enrichment.metadata if enrichment.metadata else {},
|
||||
created_at=(
|
||||
enrichment.created_at.isoformat() if enrichment.created_at else ""
|
||||
),
|
||||
updated_at=(
|
||||
enrichment.updated_at.isoformat() if enrichment.updated_at else ""
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class EnrichmentFormat(str, Enum):
|
||||
"""Format types for enrichment responses."""
|
||||
|
||||
TEXT = "text" # Free text response
|
||||
DATE = "date" # Date/datetime format
|
||||
NUMBER = "number" # Numeric value
|
||||
OPTIONS = "options" # Multiple choice from provided options
|
||||
EMAIL = "email" # Email address format
|
||||
PHONE = "phone" # Phone number format
|
||||
|
||||
|
||||
class ExaCreateEnrichmentBlock(Block):
|
||||
"""Create a new enrichment to extract additional data from webset items."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
description: str = SchemaField(
|
||||
description="What data to extract from each item",
|
||||
placeholder="Extract the company's main product or service offering",
|
||||
)
|
||||
title: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Short title for this enrichment (auto-generated if not provided)",
|
||||
placeholder="Main Product",
|
||||
)
|
||||
format: EnrichmentFormat = SchemaField(
|
||||
default=EnrichmentFormat.TEXT,
|
||||
description="Expected format of the extracted data",
|
||||
)
|
||||
options: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="Available options when format is 'options'",
|
||||
placeholder='["B2B", "B2C", "Both", "Unknown"]',
|
||||
advanced=True,
|
||||
)
|
||||
apply_to_existing: bool = SchemaField(
|
||||
default=True,
|
||||
description="Apply this enrichment to existing items in the webset",
|
||||
)
|
||||
metadata: Optional[dict] = SchemaField(
|
||||
default=None,
|
||||
description="Metadata to attach to the enrichment",
|
||||
advanced=True,
|
||||
)
|
||||
wait_for_completion: bool = SchemaField(
|
||||
default=False,
|
||||
description="Wait for the enrichment to complete on existing items",
|
||||
)
|
||||
polling_timeout: int = SchemaField(
|
||||
default=300,
|
||||
description="Maximum time to wait for completion in seconds",
|
||||
advanced=True,
|
||||
ge=1,
|
||||
le=600,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The unique identifier for the created enrichment"
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The webset this enrichment belongs to"
|
||||
)
|
||||
status: str = SchemaField(description="Current status of the enrichment")
|
||||
title: str = SchemaField(description="Title of the enrichment")
|
||||
description: str = SchemaField(
|
||||
description="Description of what data is extracted"
|
||||
)
|
||||
format: str = SchemaField(description="Format of the extracted data")
|
||||
instructions: str = SchemaField(
|
||||
description="Generated instructions for the enrichment"
|
||||
)
|
||||
items_enriched: Optional[int] = SchemaField(
|
||||
description="Number of items enriched (if wait_for_completion was True)"
|
||||
)
|
||||
completion_time: Optional[float] = SchemaField(
|
||||
description="Time taken to complete in seconds (if wait_for_completion was True)"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="71146ae8-0cb1-4a15-8cde-eae30de71cb6",
|
||||
description="Create enrichments to extract additional structured data from webset items",
|
||||
categories={BlockCategory.AI, BlockCategory.SEARCH},
|
||||
input_schema=ExaCreateEnrichmentBlock.Input,
|
||||
output_schema=ExaCreateEnrichmentBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
import time
|
||||
|
||||
# Build the payload
|
||||
payload: dict[str, Any] = {
|
||||
"description": input_data.description,
|
||||
"format": input_data.format.value,
|
||||
}
|
||||
|
||||
# Add title if provided
|
||||
if input_data.title:
|
||||
payload["title"] = input_data.title
|
||||
|
||||
# Add options for 'options' format
|
||||
if input_data.format == EnrichmentFormat.OPTIONS and input_data.options:
|
||||
payload["options"] = [{"label": opt} for opt in input_data.options]
|
||||
|
||||
# Add metadata if provided
|
||||
if input_data.metadata:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_enrichment = aexa.websets.enrichments.create(
|
||||
webset_id=input_data.webset_id, params=payload
|
||||
)
|
||||
|
||||
enrichment_id = sdk_enrichment.id
|
||||
status = (
|
||||
sdk_enrichment.status.value
|
||||
if hasattr(sdk_enrichment.status, "value")
|
||||
else str(sdk_enrichment.status)
|
||||
)
|
||||
|
||||
# If wait_for_completion is True and apply_to_existing is True, poll for completion
|
||||
if input_data.wait_for_completion and input_data.apply_to_existing:
|
||||
import asyncio
|
||||
|
||||
poll_interval = 5
|
||||
max_interval = 30
|
||||
poll_start = time.time()
|
||||
items_enriched = 0
|
||||
|
||||
while time.time() - poll_start < input_data.polling_timeout:
|
||||
current_enrich = aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=enrichment_id
|
||||
)
|
||||
current_status = (
|
||||
current_enrich.status.value
|
||||
if hasattr(current_enrich.status, "value")
|
||||
else str(current_enrich.status)
|
||||
)
|
||||
|
||||
if current_status in ["completed", "failed", "cancelled"]:
|
||||
# Estimate items from webset searches
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
if webset.searches:
|
||||
for search in webset.searches:
|
||||
if search.progress:
|
||||
items_enriched += search.progress.found
|
||||
completion_time = time.time() - start_time
|
||||
|
||||
yield "enrichment_id", enrichment_id
|
||||
yield "webset_id", input_data.webset_id
|
||||
yield "status", current_status
|
||||
yield "title", sdk_enrichment.title
|
||||
yield "description", input_data.description
|
||||
yield "format", input_data.format.value
|
||||
yield "instructions", sdk_enrichment.instructions
|
||||
yield "items_enriched", items_enriched
|
||||
yield "completion_time", completion_time
|
||||
return
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
poll_interval = min(poll_interval * 1.5, max_interval)
|
||||
|
||||
# Timeout
|
||||
completion_time = time.time() - start_time
|
||||
yield "enrichment_id", enrichment_id
|
||||
yield "webset_id", input_data.webset_id
|
||||
yield "status", status
|
||||
yield "title", sdk_enrichment.title
|
||||
yield "description", input_data.description
|
||||
yield "format", input_data.format.value
|
||||
yield "instructions", sdk_enrichment.instructions
|
||||
yield "items_enriched", 0
|
||||
yield "completion_time", completion_time
|
||||
else:
|
||||
yield "enrichment_id", enrichment_id
|
||||
yield "webset_id", input_data.webset_id
|
||||
yield "status", status
|
||||
yield "title", sdk_enrichment.title
|
||||
yield "description", input_data.description
|
||||
yield "format", input_data.format.value
|
||||
yield "instructions", sdk_enrichment.instructions
|
||||
|
||||
|
||||
class ExaGetEnrichmentBlock(Block):
|
||||
"""Get the status and details of a webset enrichment."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The ID of the enrichment to retrieve",
|
||||
placeholder="enrichment-id",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The unique identifier for the enrichment"
|
||||
)
|
||||
status: str = SchemaField(description="Current status of the enrichment")
|
||||
title: str = SchemaField(description="Title of the enrichment")
|
||||
description: str = SchemaField(
|
||||
description="Description of what data is extracted"
|
||||
)
|
||||
format: str = SchemaField(description="Format of the extracted data")
|
||||
options: list[str] = SchemaField(
|
||||
description="Available options (for 'options' format)"
|
||||
)
|
||||
instructions: str = SchemaField(
|
||||
description="Generated instructions for the enrichment"
|
||||
)
|
||||
created_at: str = SchemaField(description="When the enrichment was created")
|
||||
updated_at: str = SchemaField(
|
||||
description="When the enrichment was last updated"
|
||||
)
|
||||
metadata: dict = SchemaField(description="Metadata attached to the enrichment")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b8c9d0e1-f2a3-4567-89ab-cdef01234567",
|
||||
description="Get the status and details of a webset enrichment",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaGetEnrichmentBlock.Input,
|
||||
output_schema=ExaGetEnrichmentBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_enrichment = aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
enrichment = WebsetEnrichmentModel.from_sdk(sdk_enrichment)
|
||||
|
||||
yield "enrichment_id", enrichment.id
|
||||
yield "status", enrichment.status
|
||||
yield "title", enrichment.title
|
||||
yield "description", enrichment.description
|
||||
yield "format", enrichment.format
|
||||
yield "options", enrichment.options
|
||||
yield "instructions", enrichment.instructions
|
||||
yield "created_at", enrichment.created_at
|
||||
yield "updated_at", enrichment.updated_at
|
||||
yield "metadata", enrichment.metadata
|
||||
|
||||
|
||||
class ExaUpdateEnrichmentBlock(Block):
|
||||
"""Update an existing enrichment configuration."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The ID of the enrichment to update",
|
||||
placeholder="enrichment-id",
|
||||
)
|
||||
description: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="New description for what data to extract",
|
||||
)
|
||||
format: Optional[EnrichmentFormat] = SchemaField(
|
||||
default=None,
|
||||
description="New format for the extracted data",
|
||||
)
|
||||
options: Optional[list[str]] = SchemaField(
|
||||
default=None,
|
||||
description="New options when format is 'options'",
|
||||
)
|
||||
metadata: Optional[dict] = SchemaField(
|
||||
default=None,
|
||||
description="New metadata to attach to the enrichment",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The unique identifier for the enrichment"
|
||||
)
|
||||
status: str = SchemaField(description="Current status of the enrichment")
|
||||
title: str = SchemaField(description="Title of the enrichment")
|
||||
description: str = SchemaField(description="Updated description")
|
||||
format: str = SchemaField(description="Updated format")
|
||||
success: str = SchemaField(description="Whether the update was successful")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c8d5c5fb-9684-4a29-bd2a-5b38d71776c9",
|
||||
description="Update an existing enrichment configuration",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaUpdateEnrichmentBlock.Input,
|
||||
output_schema=ExaUpdateEnrichmentBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}/enrichments/{input_data.enrichment_id}"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
# Build the update payload
|
||||
payload = {}
|
||||
|
||||
if input_data.description is not None:
|
||||
payload["description"] = input_data.description
|
||||
|
||||
if input_data.format is not None:
|
||||
payload["format"] = input_data.format.value
|
||||
|
||||
if input_data.options is not None:
|
||||
payload["options"] = [{"label": opt} for opt in input_data.options]
|
||||
|
||||
if input_data.metadata is not None:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
try:
|
||||
response = await Requests().patch(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
|
||||
yield "enrichment_id", data.get("id", "")
|
||||
yield "status", data.get("status", "")
|
||||
yield "title", data.get("title", "")
|
||||
yield "description", data.get("description", "")
|
||||
yield "format", data.get("format", "")
|
||||
yield "success", "true"
|
||||
|
||||
except ValueError as e:
|
||||
# Re-raise user input validation errors
|
||||
raise ValueError(f"Failed to update enrichment: {e}") from e
|
||||
# Let all other exceptions propagate naturally
|
||||
|
||||
|
||||
class ExaDeleteEnrichmentBlock(Block):
|
||||
"""Delete an enrichment from a webset."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The ID of the enrichment to delete",
|
||||
placeholder="enrichment-id",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
enrichment_id: str = SchemaField(description="The ID of the deleted enrichment")
|
||||
success: str = SchemaField(description="Whether the deletion was successful")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b250de56-2ca6-4237-a7b8-b5684892189f",
|
||||
description="Delete an enrichment from a webset",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaDeleteEnrichmentBlock.Input,
|
||||
output_schema=ExaDeleteEnrichmentBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_enrichment = aexa.websets.enrichments.delete(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
yield "enrichment_id", deleted_enrichment.id
|
||||
yield "success", "true"
|
||||
|
||||
|
||||
class ExaCancelEnrichmentBlock(Block):
|
||||
"""Cancel a running enrichment operation."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The ID of the enrichment to cancel",
|
||||
placeholder="enrichment-id",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The ID of the canceled enrichment"
|
||||
)
|
||||
status: str = SchemaField(description="Status after cancellation")
|
||||
items_enriched_before_cancel: int = SchemaField(
|
||||
description="Approximate number of items enriched before cancellation"
|
||||
)
|
||||
success: str = SchemaField(
|
||||
description="Whether the cancellation was successful"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7e1f8f0f-b6ab-43b3-bd1d-0c534a649295",
|
||||
description="Cancel a running enrichment operation",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaCancelEnrichmentBlock.Input,
|
||||
output_schema=ExaCancelEnrichmentBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
canceled_enrichment = aexa.websets.enrichments.cancel(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
# Try to estimate how many items were enriched before cancellation
|
||||
items_enriched = 0
|
||||
items_response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id, limit=100
|
||||
)
|
||||
|
||||
for sdk_item in items_response.data:
|
||||
# Check if this enrichment is present
|
||||
for enrich_result in sdk_item.enrichments:
|
||||
if enrich_result.enrichment_id == input_data.enrichment_id:
|
||||
items_enriched += 1
|
||||
break
|
||||
|
||||
status = (
|
||||
canceled_enrichment.status.value
|
||||
if hasattr(canceled_enrichment.status, "value")
|
||||
else str(canceled_enrichment.status)
|
||||
)
|
||||
|
||||
yield "enrichment_id", canceled_enrichment.id
|
||||
yield "status", status
|
||||
yield "items_enriched_before_cancel", items_enriched
|
||||
yield "success", "true"
|
||||
@@ -0,0 +1,676 @@
|
||||
"""
|
||||
Exa Websets Import/Export Management Blocks
|
||||
|
||||
This module provides blocks for importing data into websets from CSV files
|
||||
and exporting webset data in various formats.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import json
|
||||
from enum import Enum
|
||||
from io import StringIO
|
||||
from typing import Optional, Union
|
||||
|
||||
from exa_py import AsyncExa
|
||||
from exa_py.websets.types import CreateImportResponse
|
||||
from exa_py.websets.types import Import as SdkImport
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from ._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
|
||||
|
||||
|
||||
# Mirrored model for stability - don't use SDK types directly in block outputs
|
||||
class ImportModel(BaseModel):
|
||||
"""Stable output model mirroring SDK Import."""
|
||||
|
||||
id: str
|
||||
status: str
|
||||
title: str
|
||||
format: str
|
||||
entity_type: str
|
||||
count: int
|
||||
upload_url: Optional[str] # Only in CreateImportResponse
|
||||
upload_valid_until: Optional[str] # Only in CreateImportResponse
|
||||
failed_reason: str
|
||||
failed_message: str
|
||||
metadata: dict
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
@classmethod
|
||||
def from_sdk(
|
||||
cls, import_obj: Union[SdkImport, CreateImportResponse]
|
||||
) -> "ImportModel":
|
||||
"""Convert SDK Import or CreateImportResponse to our stable model."""
|
||||
# Extract entity type from union (may be None)
|
||||
entity_type = "unknown"
|
||||
if import_obj.entity:
|
||||
entity_dict = import_obj.entity.model_dump(by_alias=True, exclude_none=True)
|
||||
entity_type = entity_dict.get("type", "unknown")
|
||||
|
||||
# Handle status enum
|
||||
status_str = (
|
||||
import_obj.status.value
|
||||
if hasattr(import_obj.status, "value")
|
||||
else str(import_obj.status)
|
||||
)
|
||||
|
||||
# Handle format enum
|
||||
format_str = (
|
||||
import_obj.format.value
|
||||
if hasattr(import_obj.format, "value")
|
||||
else str(import_obj.format)
|
||||
)
|
||||
|
||||
# Handle failed_reason enum (may be None or enum)
|
||||
failed_reason_str = ""
|
||||
if import_obj.failed_reason:
|
||||
failed_reason_str = (
|
||||
import_obj.failed_reason.value
|
||||
if hasattr(import_obj.failed_reason, "value")
|
||||
else str(import_obj.failed_reason)
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=import_obj.id,
|
||||
status=status_str,
|
||||
title=import_obj.title or "",
|
||||
format=format_str,
|
||||
entity_type=entity_type,
|
||||
count=int(import_obj.count or 0),
|
||||
upload_url=getattr(
|
||||
import_obj, "upload_url", None
|
||||
), # Only in CreateImportResponse
|
||||
upload_valid_until=getattr(
|
||||
import_obj, "upload_valid_until", None
|
||||
), # Only in CreateImportResponse
|
||||
failed_reason=failed_reason_str,
|
||||
failed_message=import_obj.failed_message or "",
|
||||
metadata=import_obj.metadata or {},
|
||||
created_at=(
|
||||
import_obj.created_at.isoformat() if import_obj.created_at else ""
|
||||
),
|
||||
updated_at=(
|
||||
import_obj.updated_at.isoformat() if import_obj.updated_at else ""
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ImportFormat(str, Enum):
|
||||
"""Supported import formats."""
|
||||
|
||||
CSV = "csv"
|
||||
# JSON = "json" # Future support
|
||||
|
||||
|
||||
class ImportEntityType(str, Enum):
|
||||
"""Entity types for imports."""
|
||||
|
||||
COMPANY = "company"
|
||||
PERSON = "person"
|
||||
ARTICLE = "article"
|
||||
RESEARCH_PAPER = "research_paper"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class ExportFormat(str, Enum):
|
||||
"""Supported export formats."""
|
||||
|
||||
JSON = "json"
|
||||
CSV = "csv"
|
||||
JSON_LINES = "jsonl"
|
||||
|
||||
|
||||
class ExaCreateImportBlock(Block):
    """Create an import to load external data that can be used with websets.

    Parses the supplied CSV string locally (to compute row count and byte
    size), builds the import payload, and submits it through the AsyncExa SDK.
    """

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        title: str = SchemaField(
            description="Title for this import",
            placeholder="Customer List Import",
        )
        csv_data: str = SchemaField(
            description="CSV data to import (as a string)",
            placeholder="name,url\nAcme Corp,https://acme.com\nExample Inc,https://example.com",
        )
        entity_type: ImportEntityType = SchemaField(
            default=ImportEntityType.COMPANY,
            description="Type of entities being imported",
        )
        # Only meaningful when entity_type is CUSTOM (see run()).
        entity_description: Optional[str] = SchemaField(
            default=None,
            description="Description for custom entity type",
            advanced=True,
        )
        identifier_column: int = SchemaField(
            default=0,
            description="Column index containing the identifier (0-based)",
            ge=0,
        )
        url_column: Optional[int] = SchemaField(
            default=None,
            description="Column index containing URLs (optional)",
            ge=0,
            advanced=True,
        )
        metadata: Optional[dict] = SchemaField(
            default=None,
            description="Metadata to attach to the import",
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        import_id: str = SchemaField(
            description="The unique identifier for the created import"
        )
        status: str = SchemaField(description="Current status of the import")
        title: str = SchemaField(description="Title of the import")
        count: int = SchemaField(description="Number of items in the import")
        entity_type: str = SchemaField(description="Type of entities imported")
        upload_url: Optional[str] = SchemaField(
            description="Upload URL for CSV data (only if csv_data not provided in request)"
        )
        upload_valid_until: Optional[str] = SchemaField(
            description="Expiration time for upload URL (only if upload_url is provided)"
        )
        created_at: str = SchemaField(description="When the import was created")

    def __init__(self):
        super().__init__(
            id="020a35d8-8a53-4e60-8b60-1de5cbab1df3",
            description="Import CSV data to use with websets for targeted searches",
            categories={BlockCategory.DATA},
            input_schema=ExaCreateImportBlock.Input,
            output_schema=ExaCreateImportBlock.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "title": "Test Import",
                "csv_data": "name,url\nAcme,https://acme.com",
                "entity_type": ImportEntityType.COMPANY,
                "identifier_column": 0,
            },
            test_output=[
                ("import_id", "import-123"),
                ("status", "pending"),
                ("title", "Test Import"),
                ("count", 1),
                ("entity_type", "company"),
                ("upload_url", None),
                ("upload_valid_until", None),
                ("created_at", "2024-01-01T00:00:00"),
            ],
            test_credentials=TEST_CREDENTIALS,
            test_mock=self._create_test_mock(),
        )

    @staticmethod
    def _create_test_mock():
        """Create test mocks for the AsyncExa SDK."""
        from datetime import datetime
        from unittest.mock import MagicMock

        # Create mock SDK import object
        mock_import = MagicMock()
        mock_import.id = "import-123"
        # MagicMock(value=...) stands in for the SDK's status/format enums,
        # matching the hasattr(..., "value") checks in ImportModel.from_sdk.
        mock_import.status = MagicMock(value="pending")
        mock_import.title = "Test Import"
        mock_import.format = MagicMock(value="csv")
        mock_import.count = 1
        mock_import.upload_url = None
        mock_import.upload_valid_until = None
        mock_import.failed_reason = None
        mock_import.failed_message = ""
        mock_import.metadata = {}
        mock_import.created_at = datetime.fromisoformat("2024-01-01T00:00:00")
        mock_import.updated_at = datetime.fromisoformat("2024-01-01T00:00:00")

        # Mock entity
        mock_entity = MagicMock()
        mock_entity.model_dump = MagicMock(return_value={"type": "company"})
        mock_import.entity = mock_entity

        # NOTE: create is a plain (sync) lambda here; this mirrors the fact
        # that run() does not await the SDK call (see note in run()).
        return {
            "_get_client": lambda *args, **kwargs: MagicMock(
                websets=MagicMock(
                    imports=MagicMock(create=lambda *args, **kwargs: mock_import)
                )
            )
        }

    def _get_client(self, api_key: str) -> AsyncExa:
        """Get Exa client (separated for testing)."""
        return AsyncExa(api_key=api_key)

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Parse the CSV, build the import payload, and create the import."""
        aexa = self._get_client(credentials.api_key.get_secret_value())

        # Data-row count: total parsed rows minus the header row.
        csv_reader = csv.reader(StringIO(input_data.csv_data))
        rows = list(csv_reader)
        count = len(rows) - 1 if len(rows) > 1 else 0

        # Payload size in bytes of the raw CSV string.
        size = len(input_data.csv_data.encode("utf-8"))

        payload = {
            "title": input_data.title,
            "format": ImportFormat.CSV.value,
            "count": count,
            "size": size,
            "csv": {
                "identifier": input_data.identifier_column,
            },
        }

        # Add URL column if specified
        if input_data.url_column is not None:
            payload["csv"]["url"] = input_data.url_column

        # Add entity configuration
        entity = {"type": input_data.entity_type.value}
        if (
            input_data.entity_type == ImportEntityType.CUSTOM
            and input_data.entity_description
        ):
            entity["description"] = input_data.entity_description
        payload["entity"] = entity

        # Add metadata if provided
        if input_data.metadata:
            payload["metadata"] = input_data.metadata

        # NOTE(review): this SDK call is not awaited even though the client is
        # AsyncExa — confirm whether websets.imports.create returns a
        # coroutine with a real client (the test mock is synchronous).
        sdk_import = aexa.websets.imports.create(
            params=payload, csv_data=input_data.csv_data
        )

        # Convert to the stable mirrored model before yielding outputs.
        import_obj = ImportModel.from_sdk(sdk_import)

        yield "import_id", import_obj.id
        yield "status", import_obj.status
        yield "title", import_obj.title
        yield "count", import_obj.count
        yield "entity_type", import_obj.entity_type
        yield "upload_url", import_obj.upload_url
        yield "upload_valid_until", import_obj.upload_valid_until
        yield "created_at", import_obj.created_at
||||
class ExaGetImportBlock(Block):
    """Get the status and details of an import.

    Fetches a single import by ID via the AsyncExa SDK and yields each field
    of the stable mirrored ImportModel individually.
    """

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        import_id: str = SchemaField(
            description="The ID of the import to retrieve",
            placeholder="import-id",
        )

    class Output(BlockSchemaOutput):
        import_id: str = SchemaField(description="The unique identifier for the import")
        status: str = SchemaField(description="Current status of the import")
        title: str = SchemaField(description="Title of the import")
        format: str = SchemaField(description="Format of the imported data")
        entity_type: str = SchemaField(description="Type of entities imported")
        count: int = SchemaField(description="Number of items imported")
        upload_url: Optional[str] = SchemaField(
            description="Upload URL for CSV data (if import not yet uploaded)"
        )
        upload_valid_until: Optional[str] = SchemaField(
            description="Expiration time for upload URL (if applicable)"
        )
        failed_reason: Optional[str] = SchemaField(
            description="Reason for failure (if applicable)"
        )
        failed_message: Optional[str] = SchemaField(
            description="Detailed failure message (if applicable)"
        )
        created_at: str = SchemaField(description="When the import was created")
        updated_at: str = SchemaField(description="When the import was last updated")
        metadata: dict = SchemaField(description="Metadata attached to the import")

    def __init__(self):
        super().__init__(
            id="236663c8-a8dc-45f7-a050-2676bb0a3dd2",
            description="Get the status and details of an import",
            categories={BlockCategory.DATA},
            input_schema=ExaGetImportBlock.Input,
            output_schema=ExaGetImportBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Retrieve the import and yield all of its fields."""
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # NOTE(review): not awaited although the client is AsyncExa — confirm
        # whether websets.imports.get returns a coroutine.
        sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)

        # Convert to the stable mirrored model before yielding.
        import_obj = ImportModel.from_sdk(sdk_import)

        # Yield all fields
        yield "import_id", import_obj.id
        yield "status", import_obj.status
        yield "title", import_obj.title
        yield "format", import_obj.format
        yield "entity_type", import_obj.entity_type
        yield "count", import_obj.count
        yield "upload_url", import_obj.upload_url
        yield "upload_valid_until", import_obj.upload_valid_until
        yield "failed_reason", import_obj.failed_reason
        yield "failed_message", import_obj.failed_message
        yield "created_at", import_obj.created_at
        yield "updated_at", import_obj.updated_at
        yield "metadata", import_obj.metadata
||||
class ExaListImportsBlock(Block):
    """List all imports with pagination.

    Yields the full page as a list plus each import individually, followed by
    the pagination markers (has_more / next_cursor).
    """

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        limit: int = SchemaField(
            default=25,
            description="Number of imports to return",
            ge=1,
            le=100,
        )
        # Pass the next_cursor from a previous run to fetch the next page.
        cursor: Optional[str] = SchemaField(
            default=None,
            description="Cursor for pagination",
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        imports: list[dict] = SchemaField(description="List of imports")
        import_item: dict = SchemaField(
            description="Individual import (yielded for each import)"
        )
        has_more: bool = SchemaField(
            description="Whether there are more imports to paginate through"
        )
        next_cursor: Optional[str] = SchemaField(
            description="Cursor for the next page of results"
        )

    def __init__(self):
        super().__init__(
            id="65323630-f7e9-4692-a624-184ba14c0686",
            description="List all imports with pagination support",
            categories={BlockCategory.DATA},
            input_schema=ExaListImportsBlock.Input,
            output_schema=ExaListImportsBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """List one page of imports and yield them plus pagination info."""
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # NOTE(review): not awaited although the client is AsyncExa — confirm
        # whether websets.imports.list returns a coroutine.
        response = aexa.websets.imports.list(
            cursor=input_data.cursor,
            limit=input_data.limit,
        )

        # Convert SDK imports to our stable models
        imports = [ImportModel.from_sdk(i) for i in response.data]

        yield "imports", [i.model_dump() for i in imports]

        # Also emit each import separately for per-item downstream wiring.
        for import_obj in imports:
            yield "import_item", import_obj.model_dump()

        yield "has_more", response.has_more
        yield "next_cursor", response.next_cursor
||||
class ExaDeleteImportBlock(Block):
    """Delete an import.

    Removes an import by ID via the Exa API and reports the deleted ID.
    """

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        import_id: str = SchemaField(
            description="The ID of the import to delete",
            placeholder="import-id",
        )

    class Output(BlockSchemaOutput):
        import_id: str = SchemaField(description="The ID of the deleted import")
        success: str = SchemaField(description="Whether the deletion was successful")

    def __init__(self):
        super().__init__(
            id="81ae30ed-c7ba-4b5d-8483-b726846e570c",
            description="Delete an import",
            categories={BlockCategory.DATA},
            input_schema=ExaDeleteImportBlock.Input,
            output_schema=ExaDeleteImportBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Delete the requested import and confirm the result."""
        client = AsyncExa(api_key=credentials.api_key.get_secret_value())
        removed = client.websets.imports.delete(import_id=input_data.import_id)
        yield "import_id", removed.id
        yield "success", "true"
||||
class ExaExportWebsetBlock(Block):
    """Export all data from a webset in various formats.

    Streams up to max_items items from a webset and serializes them as JSON,
    CSV (with nested dicts flattened), or JSON Lines.
    """

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset to export",
            placeholder="webset-id-or-external-id",
        )
        format: ExportFormat = SchemaField(
            default=ExportFormat.JSON,
            description="Export format",
        )
        include_content: bool = SchemaField(
            default=True,
            description="Include full content in export",
        )
        include_enrichments: bool = SchemaField(
            default=True,
            description="Include enrichment data in export",
        )
        max_items: int = SchemaField(
            default=100,
            description="Maximum number of items to export",
            ge=1,
            le=100,
        )

    class Output(BlockSchemaOutput):
        export_data: str = SchemaField(
            description="Exported data in the requested format"
        )
        item_count: int = SchemaField(description="Number of items exported")
        total_items: int = SchemaField(
            description="Total number of items in the webset"
        )
        truncated: bool = SchemaField(
            description="Whether the export was truncated due to max_items limit"
        )
        format: str = SchemaField(description="Format of the exported data")

    def __init__(self):
        super().__init__(
            id="5da9d0fd-4b5b-4318-8302-8f71d0ccce9d",
            description="Export webset data in JSON, CSV, or JSON Lines format",
            categories={BlockCategory.DATA},
            input_schema=ExaExportWebsetBlock.Input,
            output_schema=ExaExportWebsetBlock.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "webset_id": "test-webset",
                "format": ExportFormat.JSON,
                "include_content": True,
                "include_enrichments": True,
                "max_items": 10,
            },
            test_output=[
                ("export_data", str),
                ("item_count", 2),
                ("total_items", 2),
                ("truncated", False),
                ("format", "json"),
            ],
            test_credentials=TEST_CREDENTIALS,
            test_mock=self._create_test_mock(),
        )

    @staticmethod
    def _create_test_mock():
        """Create test mocks for the AsyncExa SDK."""
        from unittest.mock import MagicMock

        # Create mock webset items
        mock_item1 = MagicMock()
        mock_item1.model_dump = MagicMock(
            return_value={
                "id": "item-1",
                "url": "https://example.com",
                "title": "Test Item 1",
            }
        )

        mock_item2 = MagicMock()
        mock_item2.model_dump = MagicMock(
            return_value={
                "id": "item-2",
                "url": "https://example.org",
                "title": "Test Item 2",
            }
        )

        # Create mock iterator
        mock_items = [mock_item1, mock_item2]

        # list_all returns a plain (sync) iterator, matching the sync `for`
        # loop in run() below.
        return {
            "_get_client": lambda *args, **kwargs: MagicMock(
                websets=MagicMock(
                    items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
                )
            )
        }

    def _get_client(self, api_key: str) -> AsyncExa:
        """Get Exa client (separated for testing)."""
        return AsyncExa(api_key=api_key)

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Fetch items, apply include flags, and serialize to the chosen format."""
        # Use AsyncExa SDK
        aexa = self._get_client(credentials.api_key.get_secret_value())

        try:
            all_items = []

            # Use SDK's list_all iterator to fetch items
            # NOTE(review): iterated with a sync `for` on an AsyncExa client —
            # confirm list_all is a sync iterator (the test mock is).
            item_iterator = aexa.websets.items.list_all(
                webset_id=input_data.webset_id, limit=input_data.max_items
            )

            for sdk_item in item_iterator:
                if len(all_items) >= input_data.max_items:
                    break

                # Convert to dict for export
                item_dict = sdk_item.model_dump(by_alias=True, exclude_none=True)
                all_items.append(item_dict)

            # Calculate total and truncated
            # NOTE(review): total_items is just the retrieved count, and
            # truncated is True whenever exactly max_items were fetched even
            # if the webset had no more items — documented behavior here.
            total_items = len(all_items)  # SDK doesn't provide total count
            truncated = len(all_items) >= input_data.max_items

            # Process items based on include flags
            if not input_data.include_content:
                for item in all_items:
                    item.pop("content", None)

            if not input_data.include_enrichments:
                for item in all_items:
                    item.pop("enrichments", None)

            # Format the export data
            export_data = ""

            if input_data.format == ExportFormat.JSON:
                export_data = json.dumps(all_items, indent=2, default=str)

            elif input_data.format == ExportFormat.JSON_LINES:
                lines = [json.dumps(item, default=str) for item in all_items]
                export_data = "\n".join(lines)

            elif input_data.format == ExportFormat.CSV:
                # Extract all unique keys for CSV headers
                all_keys = set()
                for item in all_items:
                    all_keys.update(self._flatten_dict(item).keys())

                # Create CSV
                output = StringIO()
                writer = csv.DictWriter(output, fieldnames=sorted(all_keys))
                writer.writeheader()

                for item in all_items:
                    flat_item = self._flatten_dict(item)
                    writer.writerow(flat_item)

                export_data = output.getvalue()

            yield "export_data", export_data
            yield "item_count", len(all_items)
            yield "total_items", total_items
            yield "truncated", truncated
            yield "format", input_data.format.value

        except ValueError as e:
            # Re-raise user input validation errors
            # NOTE(review): this only re-wraps the same exception type with a
            # prefixed message; callers still see a ValueError.
            raise ValueError(f"Failed to export webset: {e}") from e
        # Let all other exceptions propagate naturally

    def _flatten_dict(self, d: dict, parent_key: str = "", sep: str = "_") -> dict:
        """Flatten nested dictionaries for CSV export.

        Nested dicts are recursed into with `sep`-joined keys; lists are
        JSON-encoded into a single cell; scalars pass through unchanged.
        """
        items = []
        for k, v in d.items():
            new_key = f"{parent_key}{sep}{k}" if parent_key else k
            if isinstance(v, dict):
                items.extend(self._flatten_dict(v, new_key, sep=sep).items())
            elif isinstance(v, list):
                # Convert lists to JSON strings for CSV
                items.append((new_key, json.dumps(v, default=str)))
            else:
                items.append((new_key, v))
        return dict(items)
||||
591
autogpt_platform/backend/backend/blocks/exa/websets_items.py
Normal file
591
autogpt_platform/backend/backend/blocks/exa/websets_items.py
Normal file
@@ -0,0 +1,591 @@
|
||||
"""
|
||||
Exa Websets Item Management Blocks
|
||||
|
||||
This module provides blocks for managing items within Exa websets, including
|
||||
retrieving, listing, deleting, and bulk operations on webset items.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
from exa_py.websets.types import WebsetItem as SdkWebsetItem
|
||||
from exa_py.websets.types import (
|
||||
WebsetItemArticleProperties,
|
||||
WebsetItemCompanyProperties,
|
||||
WebsetItemCustomProperties,
|
||||
WebsetItemPersonProperties,
|
||||
WebsetItemResearchPaperProperties,
|
||||
)
|
||||
from pydantic import AnyUrl, BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
|
||||
|
||||
# Mirrored model for enrichment results
|
||||
class EnrichmentResultModel(BaseModel):
    """Stable output model mirroring SDK EnrichmentResult."""

    enrichment_id: str
    format: str
    result: Optional[List[str]]
    reasoning: Optional[str]
    references: List[Dict[str, Any]]

    @classmethod
    def from_sdk(cls, sdk_enrich) -> "EnrichmentResultModel":
        """Convert SDK EnrichmentResult to our model."""
        # Normalize the format: SDK enums expose .value, otherwise stringify.
        raw_format = sdk_enrich.format
        if hasattr(raw_format, "value"):
            format_value = raw_format.value
        else:
            format_value = str(raw_format)

        # Serialize each reference into a plain dict.
        refs = sdk_enrich.references or []
        reference_dicts = [
            ref.model_dump(by_alias=True, exclude_none=True) for ref in refs
        ]

        return cls(
            enrichment_id=sdk_enrich.enrichment_id,
            format=format_value,
            result=sdk_enrich.result,
            reasoning=sdk_enrich.reasoning,
            references=reference_dicts,
        )
||||
|
||||
# Mirrored model for stability - don't use SDK types directly in block outputs
|
||||
# Mirrored model for stability - don't use SDK types directly in block outputs
class WebsetItemModel(BaseModel):
    """Stable output model mirroring SDK WebsetItem.

    Normalizes the SDK's per-entity-type property unions into a flat shape
    (url/title/content + raw entity_data dict + enrichments keyed by ID).
    """

    id: str
    url: Optional[AnyUrl]
    title: str
    content: str
    entity_data: Dict[str, Any]
    enrichments: Dict[str, EnrichmentResultModel]
    created_at: str
    updated_at: str

    @classmethod
    def from_sdk(cls, item: SdkWebsetItem) -> "WebsetItemModel":
        """Convert SDK WebsetItem to our stable model."""
        # Extract properties from the union type
        properties_dict = {}
        url_value = None
        title = ""
        content = ""

        if hasattr(item, "properties") and item.properties:
            properties_dict = item.properties.model_dump(
                by_alias=True, exclude_none=True
            )

            # URL is always available on all property types
            url_value = item.properties.url

            # Extract title using isinstance checks on the union type.
            # Each property variant stores its display name/content in a
            # different place, so branch per concrete type.
            if isinstance(item.properties, WebsetItemPersonProperties):
                title = item.properties.person.name
                content = ""  # Person type has no content
            elif isinstance(item.properties, WebsetItemCompanyProperties):
                title = item.properties.company.name
                content = item.properties.content or ""
            elif isinstance(item.properties, WebsetItemArticleProperties):
                title = item.properties.description
                content = item.properties.content or ""
            elif isinstance(item.properties, WebsetItemResearchPaperProperties):
                title = item.properties.description
                content = item.properties.content or ""
            elif isinstance(item.properties, WebsetItemCustomProperties):
                title = item.properties.description
                content = item.properties.content or ""
            else:
                # Fallback
                # NOTE(review): assumes any future property variant also has a
                # .description attribute — confirm against the SDK union.
                title = item.properties.description
                content = getattr(item.properties, "content", "")

        # Convert enrichments from list to dict keyed by enrichment_id using Pydantic models
        enrichments_dict: Dict[str, EnrichmentResultModel] = {}
        if hasattr(item, "enrichments") and item.enrichments:
            for sdk_enrich in item.enrichments:
                enrich_model = EnrichmentResultModel.from_sdk(sdk_enrich)
                enrichments_dict[enrich_model.enrichment_id] = enrich_model

        return cls(
            id=item.id,
            url=url_value,
            title=title,
            content=content or "",
            entity_data=properties_dict,
            enrichments=enrichments_dict,
            # Timestamps are flattened to ISO strings ("" when absent).
            created_at=item.created_at.isoformat() if item.created_at else "",
            updated_at=item.updated_at.isoformat() if item.updated_at else "",
        )
|
||||
class ExaGetWebsetItemBlock(Block):
    """Get a specific item from a webset by its ID.

    Fetches one item via the AsyncExa SDK and yields the normalized
    WebsetItemModel fields individually.
    """

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        item_id: str = SchemaField(
            description="The ID of the specific item to retrieve",
            placeholder="item-id",
        )

    class Output(BlockSchemaOutput):
        item_id: str = SchemaField(description="The unique identifier for the item")
        url: str = SchemaField(description="The URL of the original source")
        title: str = SchemaField(description="The title of the item")
        content: str = SchemaField(description="The main content of the item")
        entity_data: dict = SchemaField(description="Entity-specific structured data")
        enrichments: dict = SchemaField(description="Enrichment data added to the item")
        created_at: str = SchemaField(
            description="When the item was added to the webset"
        )
        updated_at: str = SchemaField(description="When the item was last updated")

    def __init__(self):
        super().__init__(
            id="c4a7d9e2-8f3b-4a6c-9d8e-a5b6c7d8e9f0",
            description="Get a specific item from a webset by its ID",
            categories={BlockCategory.SEARCH},
            input_schema=ExaGetWebsetItemBlock.Input,
            output_schema=ExaGetWebsetItemBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Retrieve the item and yield its normalized fields."""
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # NOTE(review): not awaited although the client is AsyncExa — confirm
        # whether websets.items.get returns a coroutine.
        sdk_item = aexa.websets.items.get(
            webset_id=input_data.webset_id, id=input_data.item_id
        )

        item = WebsetItemModel.from_sdk(sdk_item)

        yield "item_id", item.id
        yield "url", item.url
        yield "title", item.title
        yield "content", item.content
        yield "entity_data", item.entity_data
        yield "enrichments", item.enrichments
        yield "created_at", item.created_at
        yield "updated_at", item.updated_at
|
||||
class ExaListWebsetItemsBlock(Block):
    """List items in a webset with pagination and optional filtering.

    Optionally polls until at least one item is available (wait_for_items),
    using a capped, growing sleep interval between attempts.
    """

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        limit: int = SchemaField(
            default=25,
            description="Number of items to return (1-100)",
            ge=1,
            le=100,
        )
        cursor: Optional[str] = SchemaField(
            default=None,
            description="Cursor for pagination through results",
            advanced=True,
        )
        wait_for_items: bool = SchemaField(
            default=False,
            description="Wait for items to be available if webset is still processing",
            advanced=True,
        )
        wait_timeout: int = SchemaField(
            default=60,
            description="Maximum time to wait for items in seconds",
            advanced=True,
            ge=1,
            le=300,
        )

    class Output(BlockSchemaOutput):
        items: list[WebsetItemModel] = SchemaField(
            description="List of webset items",
        )
        webset_id: str = SchemaField(
            description="The ID of the webset",
        )
        item: WebsetItemModel = SchemaField(
            description="Individual item (yielded for each item in the list)",
        )
        has_more: bool = SchemaField(
            description="Whether there are more items to paginate through",
        )
        next_cursor: Optional[str] = SchemaField(
            description="Cursor for the next page of results",
        )

    def __init__(self):
        super().__init__(
            id="7b5e8c9f-01a2-43c4-95e6-f7a8b9c0d1e2",
            description="List items in a webset with pagination support",
            categories={BlockCategory.SEARCH},
            input_schema=ExaListWebsetItemsBlock.Input,
            output_schema=ExaListWebsetItemsBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """List one page of items, optionally polling until data appears."""
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        if input_data.wait_for_items:
            import asyncio
            import time

            start_time = time.time()
            interval = 2  # seconds between polls; grows up to a 10s cap
            response = None

            while time.time() - start_time < input_data.wait_timeout:
                # NOTE(review): not awaited although the client is AsyncExa —
                # confirm whether websets.items.list returns a coroutine.
                response = aexa.websets.items.list(
                    webset_id=input_data.webset_id,
                    cursor=input_data.cursor,
                    limit=input_data.limit,
                )

                # Stop polling as soon as the webset has any items.
                if response.data:
                    break

                await asyncio.sleep(interval)
                interval = min(interval * 1.2, 10)

            # Safety net: if the loop body never executed (clock skew / tiny
            # timeout), fetch once so `response` is never None below.
            if not response:
                response = aexa.websets.items.list(
                    webset_id=input_data.webset_id,
                    cursor=input_data.cursor,
                    limit=input_data.limit,
                )
        else:
            response = aexa.websets.items.list(
                webset_id=input_data.webset_id,
                cursor=input_data.cursor,
                limit=input_data.limit,
            )

        # Normalize SDK items into the stable mirrored model.
        items = [WebsetItemModel.from_sdk(item) for item in response.data]

        yield "items", items

        # Also emit each item separately for per-item downstream wiring.
        for item in items:
            yield "item", item

        yield "has_more", response.has_more
        yield "next_cursor", response.next_cursor
        yield "webset_id", input_data.webset_id
|
||||
class ExaDeleteWebsetItemBlock(Block):
    """Delete a specific item from a webset.

    Removes one item by ID via the Exa API and reports the deleted ID.
    """

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        item_id: str = SchemaField(
            description="The ID of the item to delete",
            placeholder="item-id",
        )

    class Output(BlockSchemaOutput):
        item_id: str = SchemaField(description="The ID of the deleted item")
        success: str = SchemaField(description="Whether the deletion was successful")

    def __init__(self):
        super().__init__(
            id="12c57fbe-c270-4877-a2b6-d2d05529ba79",
            description="Delete a specific item from a webset",
            categories={BlockCategory.SEARCH},
            input_schema=ExaDeleteWebsetItemBlock.Input,
            output_schema=ExaDeleteWebsetItemBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Delete the requested item and confirm the result."""
        client = AsyncExa(api_key=credentials.api_key.get_secret_value())
        removed = client.websets.items.delete(
            webset_id=input_data.webset_id, id=input_data.item_id
        )
        yield "item_id", removed.id
        yield "success", "true"
||||
class ExaBulkWebsetItemsBlock(Block):
    """Get all items from a webset in a single operation (with size limits).

    Streams items via the SDK's list_all iterator up to max_items, optionally
    stripping content and/or enrichments from each item.
    """

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        max_items: int = SchemaField(
            default=100,
            description="Maximum number of items to retrieve (1-1000). Note: Large values may take longer.",
            ge=1,
            le=1000,
        )
        include_enrichments: bool = SchemaField(
            default=True,
            description="Include enrichment data for each item",
        )
        include_content: bool = SchemaField(
            default=True,
            description="Include full content for each item",
        )

    class Output(BlockSchemaOutput):
        items: list[WebsetItemModel] = SchemaField(
            description="All items from the webset"
        )
        item: WebsetItemModel = SchemaField(
            description="Individual item (yielded for each item)"
        )
        total_retrieved: int = SchemaField(
            description="Total number of items retrieved"
        )
        truncated: bool = SchemaField(
            description="Whether results were truncated due to max_items limit"
        )

    def __init__(self):
        super().__init__(
            id="dbd619f5-476e-4395-af9a-a7a7c0fb8c4e",
            description="Get all items from a webset in bulk (with configurable limits)",
            categories={BlockCategory.SEARCH},
            input_schema=ExaBulkWebsetItemsBlock.Input,
            output_schema=ExaBulkWebsetItemsBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Collect up to max_items items and yield them in bulk and singly."""
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        all_items: List[WebsetItemModel] = []
        # NOTE(review): iterated with a sync `for` on an AsyncExa client —
        # confirm list_all is a sync iterator.
        item_iterator = aexa.websets.items.list_all(
            webset_id=input_data.webset_id, limit=input_data.max_items
        )

        for sdk_item in item_iterator:
            # Hard cap, regardless of what the iterator would yield.
            if len(all_items) >= input_data.max_items:
                break

            item = WebsetItemModel.from_sdk(sdk_item)

            # Strip optional payloads after conversion, per the include flags.
            if not input_data.include_enrichments:
                item.enrichments = {}
            if not input_data.include_content:
                item.content = ""

            all_items.append(item)

        yield "items", all_items

        # Also emit each item separately for per-item downstream wiring.
        for item in all_items:
            yield "item", item

        yield "total_retrieved", len(all_items)
        # NOTE(review): truncated is True whenever exactly max_items were
        # retrieved, even if the webset had no further items.
        yield "truncated", len(all_items) >= input_data.max_items
|
||||
class ExaWebsetItemsSummaryBlock(Block):
    """Get a summary of items in a webset without retrieving all data."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        sample_size: int = SchemaField(
            default=5,
            description="Number of sample items to include",
            ge=0,
            le=10,
        )

    class Output(BlockSchemaOutput):
        total_items: int = SchemaField(
            description="Total number of items in the webset"
        )
        entity_type: str = SchemaField(description="Type of entities in the webset")
        sample_items: list[WebsetItemModel] = SchemaField(
            description="Sample of items from the webset"
        )
        enrichment_columns: list[str] = SchemaField(
            description="List of enrichment columns available"
        )

    def __init__(self):
        super().__init__(
            id="db7813ad-10bd-4652-8623-5667d6fecdd5",
            description="Get a summary of webset items without retrieving all data",
            categories={BlockCategory.SEARCH},
            input_schema=ExaWebsetItemsSummaryBlock.Input,
            output_schema=ExaWebsetItemsSummaryBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Summarize a webset: counts, entity type, enrichments, and a sample.

        The total is derived from search progress counters rather than
        paging through every item, keeping this call cheap.
        """
        # Use AsyncExa SDK — all endpoints are coroutines and must be awaited.
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        webset = await aexa.websets.get(id=input_data.webset_id)

        # Entity type comes from the first search's entity union; default
        # to "unknown" when the webset has no searches or entity info.
        entity_type = "unknown"
        if webset.searches:
            first_search = webset.searches[0]
            if first_search.entity:
                # The entity is a union type, extract type field
                entity_dict = first_search.entity.model_dump(by_alias=True)
                entity_type = entity_dict.get("type", "unknown")

        # Get enrichment columns (fall back to description when untitled)
        enrichment_columns = []
        if webset.enrichments:
            enrichment_columns = [
                e.title if e.title else e.description for e in webset.enrichments
            ]

        # Get sample items if requested
        sample_items: List[WebsetItemModel] = []
        if input_data.sample_size > 0:
            items_response = await aexa.websets.items.list(
                webset_id=input_data.webset_id, limit=input_data.sample_size
            )
            # Convert to our stable models
            sample_items = [
                WebsetItemModel.from_sdk(item) for item in items_response.data
            ]

        # Sum the "found" counters across all searches for a cheap total.
        total_items = 0
        if webset.searches:
            for search in webset.searches:
                if search.progress:
                    total_items += search.progress.found

        yield "total_items", total_items
        yield "entity_type", entity_type
        yield "sample_items", sample_items
        yield "enrichment_columns", enrichment_columns
|
||||
|
||||
|
||||
class ExaGetNewItemsBlock(Block):
    """Get items added to a webset since a specific cursor (incremental processing helper)."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        since_cursor: Optional[str] = SchemaField(
            default=None,
            description="Cursor from previous run - only items after this will be returned. Leave empty on first run.",
            placeholder="cursor-from-previous-run",
        )
        max_items: int = SchemaField(
            default=100,
            description="Maximum number of new items to retrieve",
            ge=1,
            le=1000,
        )

    class Output(BlockSchemaOutput):
        new_items: list[WebsetItemModel] = SchemaField(
            description="Items added since the cursor"
        )
        item: WebsetItemModel = SchemaField(
            description="Individual item (yielded for each new item)"
        )
        count: int = SchemaField(description="Number of new items found")
        next_cursor: Optional[str] = SchemaField(
            description="Save this cursor for the next run to get only newer items"
        )
        has_more: bool = SchemaField(
            description="Whether there are more new items beyond max_items"
        )

    def __init__(self):
        super().__init__(
            id="3ff9bdf5-9613-4d21-8a60-90eb8b69c414",
            description="Get items added since a cursor - enables incremental processing without reprocessing",
            categories={BlockCategory.SEARCH, BlockCategory.DATA},
            input_schema=ExaGetNewItemsBlock.Input,
            output_schema=ExaGetNewItemsBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Fetch items added after `since_cursor` and a cursor for the next run.

        A `None` cursor fetches from the beginning (first run).
        """
        # Use AsyncExa SDK — list() is a coroutine and must be awaited.
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # Get items starting from cursor
        response = await aexa.websets.items.list(
            webset_id=input_data.webset_id,
            cursor=input_data.since_cursor,
            limit=input_data.max_items,
        )

        # Convert SDK items to our stable models
        new_items = [WebsetItemModel.from_sdk(item) for item in response.data]

        # Yield the full list
        yield "new_items", new_items

        # Yield individual items for processing
        for item in new_items:
            yield "item", item

        # Yield metadata for next run
        yield "count", len(new_items)
        yield "next_cursor", response.next_cursor
        yield "has_more", response.has_more
|
||||
600
autogpt_platform/backend/backend/blocks/exa/websets_monitor.py
Normal file
600
autogpt_platform/backend/backend/blocks/exa/websets_monitor.py
Normal file
@@ -0,0 +1,600 @@
|
||||
"""
|
||||
Exa Websets Monitor Management Blocks
|
||||
|
||||
This module provides blocks for creating and managing monitors that automatically
|
||||
keep websets updated with fresh data on a schedule.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
from exa_py.websets.types import Monitor as SdkMonitor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from ._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
|
||||
|
||||
|
||||
# Mirrored model for stability - don't use SDK types directly in block outputs
|
||||
class MonitorModel(BaseModel):
    """Stable output model mirroring the SDK Monitor.

    Block outputs use this instead of the SDK type so that upstream
    SDK changes cannot silently alter graph-facing schemas.
    """

    id: str
    status: str
    webset_id: str
    behavior_type: str
    behavior_config: dict
    cron_expression: str
    timezone: str
    next_run_at: str
    last_run: dict
    metadata: dict
    created_at: str
    updated_at: str

    @classmethod
    def from_sdk(cls, monitor: SdkMonitor) -> "MonitorModel":
        """Build a MonitorModel from an SDK Monitor instance."""

        def _iso(dt) -> str:
            # Timestamps are serialized as ISO strings; missing -> "".
            return dt.isoformat() if dt else ""

        # Behavior and cadence are nested pydantic models; dump them to
        # plain dicts (camelCase aliases, no None values) before reading.
        behavior = monitor.behavior.model_dump(by_alias=True, exclude_none=True)
        cadence = monitor.cadence.model_dump(by_alias=True, exclude_none=True)

        last_run = (
            monitor.last_run.model_dump(by_alias=True, exclude_none=True)
            if monitor.last_run
            else {}
        )

        # Status may be an enum or a plain string depending on SDK version.
        raw_status = monitor.status
        status = raw_status.value if hasattr(raw_status, "value") else str(raw_status)

        return cls(
            id=monitor.id,
            status=status,
            webset_id=monitor.webset_id,
            behavior_type=behavior.get("type", "unknown"),
            behavior_config=behavior.get("config", {}),
            cron_expression=cadence.get("cron", ""),
            timezone=cadence.get("timezone", "Etc/UTC"),
            next_run_at=_iso(monitor.next_run_at),
            last_run=last_run,
            metadata=monitor.metadata or {},
            created_at=_iso(monitor.created_at),
            updated_at=_iso(monitor.updated_at),
        )
|
||||
|
||||
|
||||
class MonitorStatus(str, Enum):
    """Status of a monitor.

    String-valued so members serialize directly into API payloads.
    """

    ENABLED = "enabled"
    DISABLED = "disabled"
    PAUSED = "paused"
|
||||
|
||||
|
||||
class MonitorBehaviorType(str, Enum):
    """Type of behavior for a monitor."""

    SEARCH = "search"  # Run new searches
    REFRESH = "refresh"  # Refresh existing items
|
||||
|
||||
|
||||
class SearchBehavior(str, Enum):
    """How search results interact with existing items."""

    APPEND = "append"  # Add new results alongside existing items
    OVERRIDE = "override"  # Replace existing items with new results
|
||||
|
||||
|
||||
class ExaCreateMonitorBlock(Block):
    """Create a monitor to automatically keep a webset updated on a schedule."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset to monitor",
            placeholder="webset-id-or-external-id",
        )

        # Schedule configuration
        cron_expression: str = SchemaField(
            description="Cron expression for scheduling (5 fields, max once per day)",
            placeholder="0 9 * * 1",  # Every Monday at 9 AM
        )
        timezone: str = SchemaField(
            default="Etc/UTC",
            description="IANA timezone for the schedule",
            placeholder="America/New_York",
            advanced=True,
        )

        # Behavior configuration
        behavior_type: MonitorBehaviorType = SchemaField(
            default=MonitorBehaviorType.SEARCH,
            description="Type of monitor behavior (search for new items or refresh existing)",
        )

        # Search configuration (for SEARCH behavior)
        search_query: Optional[str] = SchemaField(
            default=None,
            description="Search query for finding new items (required for search behavior)",
            placeholder="AI startups that raised funding in the last week",
        )
        search_count: int = SchemaField(
            default=10,
            description="Number of items to find in each search",
            ge=1,
            le=100,
        )
        search_criteria: list[str] = SchemaField(
            default_factory=list,
            description="Criteria that items must meet",
            advanced=True,
        )
        search_behavior: SearchBehavior = SchemaField(
            default=SearchBehavior.APPEND,
            description="How new results interact with existing items",
            advanced=True,
        )
        entity_type: Optional[str] = SchemaField(
            default=None,
            description="Type of entity to search for (company, person, etc.)",
            advanced=True,
        )

        # Refresh configuration (for REFRESH behavior)
        refresh_content: bool = SchemaField(
            default=True,
            description="Refresh content from source URLs (for refresh behavior)",
            advanced=True,
        )
        refresh_enrichments: bool = SchemaField(
            default=True,
            description="Re-run enrichments on items (for refresh behavior)",
            advanced=True,
        )

        # Metadata
        metadata: Optional[dict] = SchemaField(
            default=None,
            description="Metadata to attach to the monitor",
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        monitor_id: str = SchemaField(
            description="The unique identifier for the created monitor"
        )
        webset_id: str = SchemaField(description="The webset this monitor belongs to")
        status: str = SchemaField(description="Status of the monitor")
        behavior_type: str = SchemaField(description="Type of monitor behavior")
        next_run_at: Optional[str] = SchemaField(
            description="When the monitor will next run"
        )
        cron_expression: str = SchemaField(description="The schedule cron expression")
        timezone: str = SchemaField(description="The timezone for scheduling")
        created_at: str = SchemaField(description="When the monitor was created")

    def __init__(self):
        super().__init__(
            id="f8a9b0c1-d2e3-4567-890a-bcdef1234567",
            description="Create automated monitors to keep websets updated with fresh data on a schedule",
            categories={BlockCategory.SEARCH},
            input_schema=ExaCreateMonitorBlock.Input,
            output_schema=ExaCreateMonitorBlock.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "webset_id": "test-webset",
                "cron_expression": "0 9 * * 1",
                "behavior_type": MonitorBehaviorType.SEARCH,
                "search_query": "AI startups",
                "search_count": 10,
            },
            test_output=[
                ("monitor_id", "monitor-123"),
                ("webset_id", "test-webset"),
                ("status", "enabled"),
                ("behavior_type", "search"),
                ("next_run_at", "2024-01-01T00:00:00"),
                ("cron_expression", "0 9 * * 1"),
                ("timezone", "Etc/UTC"),
                ("created_at", "2024-01-01T00:00:00"),
            ],
            test_credentials=TEST_CREDENTIALS,
            test_mock=self._create_test_mock(),
        )

    @staticmethod
    def _create_test_mock():
        """Create test mocks for the AsyncExa SDK."""
        from datetime import datetime
        from unittest.mock import MagicMock

        # Create mock SDK monitor object
        mock_monitor = MagicMock()
        mock_monitor.id = "monitor-123"
        mock_monitor.status = MagicMock(value="enabled")
        mock_monitor.webset_id = "test-webset"
        mock_monitor.next_run_at = datetime.fromisoformat("2024-01-01T00:00:00")
        mock_monitor.created_at = datetime.fromisoformat("2024-01-01T00:00:00")
        mock_monitor.updated_at = datetime.fromisoformat("2024-01-01T00:00:00")
        mock_monitor.metadata = {}
        mock_monitor.last_run = None

        # Mock behavior
        mock_behavior = MagicMock()
        mock_behavior.model_dump = MagicMock(
            return_value={"type": "search", "config": {}}
        )
        mock_monitor.behavior = mock_behavior

        # Mock cadence
        mock_cadence = MagicMock()
        mock_cadence.model_dump = MagicMock(
            return_value={"cron": "0 9 * * 1", "timezone": "Etc/UTC"}
        )
        mock_monitor.cadence = mock_cadence

        # The real SDK's create() is a coroutine, and run() awaits it —
        # so the mock must also be awaitable.
        async def _mock_create(*args, **kwargs):
            return mock_monitor

        return {
            "_get_client": lambda *args, **kwargs: MagicMock(
                websets=MagicMock(monitors=MagicMock(create=_mock_create))
            )
        }

    def _get_client(self, api_key: str) -> AsyncExa:
        """Get Exa client (separated for testing)."""
        return AsyncExa(api_key=api_key)

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Create a monitor on a webset and yield its key fields."""
        aexa = self._get_client(credentials.api_key.get_secret_value())

        # Build the payload
        payload = {
            "websetId": input_data.webset_id,
            "cadence": {
                "cron": input_data.cron_expression,
                "timezone": input_data.timezone,
            },
        }

        # Build behavior configuration based on type
        if input_data.behavior_type == MonitorBehaviorType.SEARCH:
            behavior_config = {
                "query": input_data.search_query or "",
                "count": input_data.search_count,
                "behavior": input_data.search_behavior.value,
            }

            if input_data.search_criteria:
                behavior_config["criteria"] = [
                    {"description": c} for c in input_data.search_criteria
                ]

            if input_data.entity_type:
                behavior_config["entity"] = {"type": input_data.entity_type}

            payload["behavior"] = {
                "type": "search",
                "config": behavior_config,
            }
        else:
            # REFRESH behavior
            payload["behavior"] = {
                "type": "refresh",
                "config": {
                    "content": input_data.refresh_content,
                    "enrichments": input_data.refresh_enrichments,
                },
            }

        # Add metadata if provided
        if input_data.metadata:
            payload["metadata"] = input_data.metadata

        # AsyncExa's create() is a coroutine; await it before converting.
        sdk_monitor = await aexa.websets.monitors.create(params=payload)

        monitor = MonitorModel.from_sdk(sdk_monitor)

        # Yield all fields
        yield "monitor_id", monitor.id
        yield "webset_id", monitor.webset_id
        yield "status", monitor.status
        yield "behavior_type", monitor.behavior_type
        yield "next_run_at", monitor.next_run_at
        yield "cron_expression", monitor.cron_expression
        yield "timezone", monitor.timezone
        yield "created_at", monitor.created_at
|
||||
|
||||
|
||||
class ExaGetMonitorBlock(Block):
    """Get the details and status of a monitor."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        monitor_id: str = SchemaField(
            description="The ID of the monitor to retrieve",
            placeholder="monitor-id",
        )

    class Output(BlockSchemaOutput):
        monitor_id: str = SchemaField(
            description="The unique identifier for the monitor"
        )
        webset_id: str = SchemaField(description="The webset this monitor belongs to")
        status: str = SchemaField(description="Current status of the monitor")
        behavior_type: str = SchemaField(description="Type of monitor behavior")
        behavior_config: dict = SchemaField(
            description="Configuration for the monitor behavior"
        )
        cron_expression: str = SchemaField(description="The schedule cron expression")
        timezone: str = SchemaField(description="The timezone for scheduling")
        next_run_at: Optional[str] = SchemaField(
            description="When the monitor will next run"
        )
        last_run: Optional[dict] = SchemaField(
            description="Information about the last run"
        )
        created_at: str = SchemaField(description="When the monitor was created")
        updated_at: str = SchemaField(description="When the monitor was last updated")
        metadata: dict = SchemaField(description="Metadata attached to the monitor")

    def __init__(self):
        super().__init__(
            id="5c852a2d-d505-4a56-b711-7def8dd14e72",
            description="Get the details and status of a webset monitor",
            categories={BlockCategory.SEARCH},
            input_schema=ExaGetMonitorBlock.Input,
            output_schema=ExaGetMonitorBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Fetch one monitor by ID and yield its fields."""
        # Use AsyncExa SDK — get() is a coroutine and must be awaited.
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)

        # Convert to our stable model before yielding.
        monitor = MonitorModel.from_sdk(sdk_monitor)

        # Yield all fields
        yield "monitor_id", monitor.id
        yield "webset_id", monitor.webset_id
        yield "status", monitor.status
        yield "behavior_type", monitor.behavior_type
        yield "behavior_config", monitor.behavior_config
        yield "cron_expression", monitor.cron_expression
        yield "timezone", monitor.timezone
        yield "next_run_at", monitor.next_run_at
        yield "last_run", monitor.last_run
        yield "created_at", monitor.created_at
        yield "updated_at", monitor.updated_at
        yield "metadata", monitor.metadata
|
||||
|
||||
|
||||
class ExaUpdateMonitorBlock(Block):
    """Update a monitor's configuration."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        monitor_id: str = SchemaField(
            description="The ID of the monitor to update",
            placeholder="monitor-id",
        )
        status: Optional[MonitorStatus] = SchemaField(
            default=None,
            description="New status for the monitor",
        )
        cron_expression: Optional[str] = SchemaField(
            default=None,
            description="New cron expression for scheduling",
        )
        timezone: Optional[str] = SchemaField(
            default=None,
            description="New timezone for the schedule",
            advanced=True,
        )
        metadata: Optional[dict] = SchemaField(
            default=None,
            description="New metadata for the monitor",
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        monitor_id: str = SchemaField(
            description="The unique identifier for the monitor"
        )
        status: str = SchemaField(description="Updated status of the monitor")
        next_run_at: Optional[str] = SchemaField(
            description="When the monitor will next run"
        )
        updated_at: str = SchemaField(description="When the monitor was updated")
        success: str = SchemaField(description="Whether the update was successful")

    def __init__(self):
        super().__init__(
            id="245102c3-6af3-4515-a308-c2210b7939d2",
            description="Update a monitor's status, schedule, or metadata",
            categories={BlockCategory.SEARCH},
            input_schema=ExaUpdateMonitorBlock.Input,
            output_schema=ExaUpdateMonitorBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Apply a partial update to a monitor; unset inputs are left unchanged."""
        # Use AsyncExa SDK — update() is a coroutine and must be awaited.
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # Build update payload containing only the fields the user set.
        payload = {}

        if input_data.status is not None:
            payload["status"] = input_data.status.value

        if input_data.cron_expression is not None or input_data.timezone is not None:
            cadence = {}
            if input_data.cron_expression:
                cadence["cron"] = input_data.cron_expression
            if input_data.timezone:
                cadence["timezone"] = input_data.timezone
            payload["cadence"] = cadence

        if input_data.metadata is not None:
            payload["metadata"] = input_data.metadata

        sdk_monitor = await aexa.websets.monitors.update(
            monitor_id=input_data.monitor_id, params=payload
        )

        # Convert to our stable model
        monitor = MonitorModel.from_sdk(sdk_monitor)

        # Yield fields
        yield "monitor_id", monitor.id
        yield "status", monitor.status
        yield "next_run_at", monitor.next_run_at
        yield "updated_at", monitor.updated_at
        yield "success", "true"
|
||||
|
||||
|
||||
class ExaDeleteMonitorBlock(Block):
    """Delete a monitor from a webset."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        monitor_id: str = SchemaField(
            description="The ID of the monitor to delete",
            placeholder="monitor-id",
        )

    class Output(BlockSchemaOutput):
        monitor_id: str = SchemaField(description="The ID of the deleted monitor")
        success: str = SchemaField(description="Whether the deletion was successful")

    def __init__(self):
        super().__init__(
            id="f16f9b10-0c4d-4db8-997d-7b96b6026094",
            description="Delete a monitor from a webset",
            categories={BlockCategory.SEARCH},
            input_schema=ExaDeleteMonitorBlock.Input,
            output_schema=ExaDeleteMonitorBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Delete a monitor and echo back its ID."""
        # Use AsyncExa SDK — delete() is a coroutine; without the await the
        # request would never execute and `.id` below would fail.
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        deleted_monitor = await aexa.websets.monitors.delete(
            monitor_id=input_data.monitor_id
        )

        yield "monitor_id", deleted_monitor.id
        yield "success", "true"
|
||||
|
||||
|
||||
class ExaListMonitorsBlock(Block):
    """List all monitors with pagination."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: Optional[str] = SchemaField(
            default=None,
            description="Filter monitors by webset ID",
            placeholder="webset-id",
        )
        limit: int = SchemaField(
            default=25,
            description="Number of monitors to return",
            ge=1,
            le=100,
        )
        cursor: Optional[str] = SchemaField(
            default=None,
            description="Cursor for pagination",
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        monitors: list[dict] = SchemaField(description="List of monitors")
        monitor: dict = SchemaField(
            description="Individual monitor (yielded for each monitor)"
        )
        has_more: bool = SchemaField(
            description="Whether there are more monitors to paginate through"
        )
        next_cursor: Optional[str] = SchemaField(
            description="Cursor for the next page of results"
        )

    def __init__(self):
        super().__init__(
            id="f06e2b38-5397-4e8f-aa85-491149dd98df",
            description="List all monitors with optional webset filtering",
            categories={BlockCategory.SEARCH},
            input_schema=ExaListMonitorsBlock.Input,
            output_schema=ExaListMonitorsBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """List one page of monitors, optionally filtered by webset."""
        # Use AsyncExa SDK — list() is a coroutine and must be awaited.
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        response = await aexa.websets.monitors.list(
            cursor=input_data.cursor,
            limit=input_data.limit,
            webset_id=input_data.webset_id,
        )

        # Convert SDK monitors to our stable models
        monitors = [MonitorModel.from_sdk(m) for m in response.data]

        # Yield the full list
        yield "monitors", [m.model_dump() for m in monitors]

        # Yield individual monitors for graph chaining
        for monitor in monitors:
            yield "monitor", monitor.model_dump()

        # Yield pagination metadata
        yield "has_more", response.has_more
        yield "next_cursor", response.next_cursor
|
||||
600
autogpt_platform/backend/backend/blocks/exa/websets_polling.py
Normal file
600
autogpt_platform/backend/backend/blocks/exa/websets_polling.py
Normal file
@@ -0,0 +1,600 @@
|
||||
"""
|
||||
Exa Websets Polling Blocks
|
||||
|
||||
This module provides dedicated polling blocks for waiting on webset operations
|
||||
to complete, with progress tracking and timeout management.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
from exa_py import AsyncExa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
|
||||
# Import WebsetItemModel for use in enrichment samples
|
||||
# This is safe as websets_items doesn't import from websets_polling
|
||||
from .websets_items import WebsetItemModel
|
||||
|
||||
|
||||
# Model for sample enrichment data
|
||||
class SampleEnrichmentModel(BaseModel):
    """Sample enrichment result for display."""

    # ID and title of the item the enrichment belongs to.
    item_id: str
    item_title: str
    # Raw enrichment key/value data for that item.
    enrichment_data: Dict[str, Any]
|
||||
|
||||
|
||||
class WebsetTargetStatus(str, Enum):
    """Webset status a polling block can wait for."""

    IDLE = "idle"
    COMPLETED = "completed"
    RUNNING = "running"
    PAUSED = "paused"
    ANY_COMPLETE = "any_complete"  # Either idle or completed
|
||||
|
||||
|
||||
class ExaWaitForWebsetBlock(Block):
    """Wait for a webset to reach a specific status with progress tracking."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset to monitor",
            placeholder="webset-id-or-external-id",
        )
        target_status: WebsetTargetStatus = SchemaField(
            default=WebsetTargetStatus.IDLE,
            description="Status to wait for (idle=all operations complete, completed=search done, running=actively processing)",
        )
        timeout: int = SchemaField(
            default=300,
            description="Maximum time to wait in seconds",
            ge=1,
            le=1800,  # 30 minutes max
        )
        check_interval: int = SchemaField(
            default=5,
            description="Initial interval between status checks in seconds",
            advanced=True,
            ge=1,
            le=60,
        )
        max_interval: int = SchemaField(
            default=30,
            description="Maximum interval between checks (for exponential backoff)",
            advanced=True,
            ge=5,
            le=120,
        )
        include_progress: bool = SchemaField(
            default=True,
            description="Include detailed progress information in output",
        )

    class Output(BlockSchemaOutput):
        webset_id: str = SchemaField(description="The webset ID that was monitored")
        final_status: str = SchemaField(description="The final status of the webset")
        elapsed_time: float = SchemaField(description="Total time elapsed in seconds")
        item_count: int = SchemaField(description="Number of items found")
        search_progress: dict = SchemaField(
            description="Detailed search progress information"
        )
        enrichment_progress: dict = SchemaField(
            description="Detailed enrichment progress information"
        )
        timed_out: bool = SchemaField(description="Whether the operation timed out")

    def __init__(self):
        super().__init__(
            id="619d71e8-b72a-434d-8bd4-23376dd0342c",
            description="Wait for a webset to reach a specific status with progress tracking",
            categories={BlockCategory.SEARCH},
            input_schema=ExaWaitForWebsetBlock.Input,
            output_schema=ExaWaitForWebsetBlock.Output,
        )

    @staticmethod
    def _status_str(status) -> str:
        """Normalize an SDK status (enum member or plain value) to a string."""
        return status.value if hasattr(status, "value") else str(status)

    @staticmethod
    def _count_found_items(webset) -> int:
        """Sum the 'found' counters across all searches of a webset."""
        return sum(
            search.progress.found
            for search in (webset.searches or [])
            if search.progress
        )

    def _progress_payloads(self, webset) -> tuple[dict, dict]:
        """Return (search_progress, enrichment_progress) dicts for a webset."""
        webset_dict = webset.model_dump(by_alias=True, exclude_none=True)
        return (
            self._extract_search_progress(webset_dict),
            self._extract_enrichment_progress(webset_dict),
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        start_time = time.time()
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        try:
            if input_data.target_status in [
                WebsetTargetStatus.IDLE,
                WebsetTargetStatus.ANY_COMPLETE,
            ]:
                # The SDK provides a dedicated wait helper for the idle case.
                # AsyncExa methods return coroutines and must be awaited.
                final_webset = await aexa.websets.wait_until_idle(
                    id=input_data.webset_id,
                    timeout=input_data.timeout,
                    poll_interval=input_data.check_interval,
                )

                elapsed = time.time() - start_time
                status_str = self._status_str(final_webset.status)
                item_count = self._count_found_items(final_webset)

                # Extract progress if requested
                search_progress = {}
                enrichment_progress = {}
                if input_data.include_progress:
                    search_progress, enrichment_progress = self._progress_payloads(
                        final_webset
                    )

                yield "webset_id", input_data.webset_id
                yield "final_status", status_str
                yield "elapsed_time", elapsed
                yield "item_count", item_count
                if input_data.include_progress:
                    yield "search_progress", search_progress
                    yield "enrichment_progress", enrichment_progress
                yield "timed_out", False
            else:
                # For other status targets, manually poll with exponential backoff.
                interval = input_data.check_interval
                while time.time() - start_time < input_data.timeout:
                    webset = await aexa.websets.get(id=input_data.webset_id)
                    current_status = self._status_str(webset.status)

                    # Check if target status reached
                    if current_status == input_data.target_status.value:
                        elapsed = time.time() - start_time
                        item_count = self._count_found_items(webset)

                        search_progress = {}
                        enrichment_progress = {}
                        if input_data.include_progress:
                            (
                                search_progress,
                                enrichment_progress,
                            ) = self._progress_payloads(webset)

                        yield "webset_id", input_data.webset_id
                        yield "final_status", current_status
                        yield "elapsed_time", elapsed
                        yield "item_count", item_count
                        if input_data.include_progress:
                            yield "search_progress", search_progress
                            yield "enrichment_progress", enrichment_progress
                        yield "timed_out", False
                        return

                    # Wait before next check with exponential backoff
                    await asyncio.sleep(interval)
                    interval = min(interval * 1.5, input_data.max_interval)

                # Timeout reached: report the last known state instead of failing.
                elapsed = time.time() - start_time
                webset = await aexa.websets.get(id=input_data.webset_id)
                final_status = self._status_str(webset.status)
                item_count = self._count_found_items(webset)

                search_progress = {}
                enrichment_progress = {}
                if input_data.include_progress:
                    search_progress, enrichment_progress = self._progress_payloads(
                        webset
                    )

                yield "webset_id", input_data.webset_id
                yield "final_status", final_status
                yield "elapsed_time", elapsed
                yield "item_count", item_count
                if input_data.include_progress:
                    yield "search_progress", search_progress
                    yield "enrichment_progress", enrichment_progress
                yield "timed_out", True

        except asyncio.TimeoutError:
            raise ValueError(
                f"Polling timed out after {input_data.timeout} seconds"
            ) from None

    def _extract_search_progress(self, webset_data: dict) -> dict:
        """Extract per-search progress information from dumped webset data."""
        progress = {}
        searches = webset_data.get("searches", [])

        for idx, search in enumerate(searches):
            # Fall back to a positional key when the search has no id.
            search_id = search.get("id", f"search_{idx}")
            search_progress = search.get("progress", {})

            progress[search_id] = {
                "status": search.get("status", "unknown"),
                "found": search_progress.get("found", 0),
                "analyzed": search_progress.get("analyzed", 0),
                "completion": search_progress.get("completion", 0),
                "time_left": search_progress.get("timeLeft", 0),
            }

        return progress

    def _extract_enrichment_progress(self, webset_data: dict) -> dict:
        """Extract per-enrichment progress information from dumped webset data."""
        progress = {}
        enrichments = webset_data.get("enrichments", [])

        for idx, enrichment in enumerate(enrichments):
            # Fall back to a positional key when the enrichment has no id.
            enrich_id = enrichment.get("id", f"enrichment_{idx}")

            progress[enrich_id] = {
                "status": enrichment.get("status", "unknown"),
                "title": enrichment.get("title", ""),
                "description": enrichment.get("description", ""),
            }

        return progress
|
||||
|
||||
|
||||
class ExaWaitForSearchBlock(Block):
    """Wait for a specific webset search to complete with progress tracking."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        search_id: str = SchemaField(
            description="The ID of the search to monitor",
            placeholder="search-id",
        )
        timeout: int = SchemaField(
            default=300,
            description="Maximum time to wait in seconds",
            ge=1,
            le=1800,
        )
        check_interval: int = SchemaField(
            default=5,
            description="Initial interval between status checks in seconds",
            advanced=True,
            ge=1,
            le=60,
        )

    class Output(BlockSchemaOutput):
        search_id: str = SchemaField(description="The search ID that was monitored")
        final_status: str = SchemaField(description="The final status of the search")
        items_found: int = SchemaField(
            description="Number of items found by the search"
        )
        items_analyzed: int = SchemaField(description="Number of items analyzed")
        completion_percentage: int = SchemaField(
            description="Completion percentage (0-100)"
        )
        elapsed_time: float = SchemaField(description="Total time elapsed in seconds")
        recall_info: dict = SchemaField(
            description="Information about expected results and confidence"
        )
        timed_out: bool = SchemaField(description="Whether the operation timed out")

    def __init__(self):
        super().__init__(
            id="14da21ae-40a1-41bc-a111-c8e5c9ef012b",
            description="Wait for a specific webset search to complete with progress tracking",
            categories={BlockCategory.SEARCH},
            input_schema=ExaWaitForSearchBlock.Input,
            output_schema=ExaWaitForSearchBlock.Output,
        )

    @staticmethod
    def _status_str(status) -> str:
        """Normalize an SDK status (enum member or plain value) to a string."""
        return status.value if hasattr(status, "value") else str(status)

    @staticmethod
    def _progress_dict(search) -> dict:
        """Dump the search's progress model to a plain dict ({} if absent)."""
        if search.progress:
            return search.progress.model_dump(by_alias=True, exclude_none=True)
        return {}

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        start_time = time.time()
        interval = input_data.check_interval
        max_interval = 30
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        try:
            while time.time() - start_time < input_data.timeout:
                # AsyncExa methods return coroutines and must be awaited.
                search = await aexa.websets.searches.get(
                    webset_id=input_data.webset_id, id=input_data.search_id
                )
                status = self._status_str(search.status)

                # Terminal states end the wait.
                if status in ["completed", "failed", "canceled"]:
                    elapsed = time.time() - start_time
                    progress_dict = self._progress_dict(search)

                    # Recall carries the API's estimate of expected results.
                    recall_info = {}
                    if search.recall:
                        recall_dict = search.recall.model_dump(
                            by_alias=True, exclude_none=True
                        )
                        expected = recall_dict.get("expected", {})
                        recall_info = {
                            "expected_total": expected.get("total", 0),
                            "confidence": expected.get("confidence", ""),
                            "min_expected": expected.get("bounds", {}).get("min", 0),
                            "max_expected": expected.get("bounds", {}).get("max", 0),
                            "reasoning": recall_dict.get("reasoning", ""),
                        }

                    yield "search_id", input_data.search_id
                    yield "final_status", status
                    yield "items_found", progress_dict.get("found", 0)
                    yield "items_analyzed", progress_dict.get("analyzed", 0)
                    yield "completion_percentage", progress_dict.get("completion", 0)
                    yield "elapsed_time", elapsed
                    yield "recall_info", recall_info
                    yield "timed_out", False

                    return

                # Wait before next check with exponential backoff
                await asyncio.sleep(interval)
                interval = min(interval * 1.5, max_interval)

            # Timeout reached: report the last known status.
            elapsed = time.time() - start_time

            search = await aexa.websets.searches.get(
                webset_id=input_data.webset_id, id=input_data.search_id
            )
            final_status = self._status_str(search.status)
            progress_dict = self._progress_dict(search)

            yield "search_id", input_data.search_id
            yield "final_status", final_status
            yield "items_found", progress_dict.get("found", 0)
            yield "items_analyzed", progress_dict.get("analyzed", 0)
            yield "completion_percentage", progress_dict.get("completion", 0)
            yield "elapsed_time", elapsed
            yield "timed_out", True

        except asyncio.TimeoutError:
            raise ValueError(
                f"Search polling timed out after {input_data.timeout} seconds"
            ) from None
|
||||
|
||||
|
||||
class ExaWaitForEnrichmentBlock(Block):
    """Wait for a webset enrichment to complete with progress tracking."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        enrichment_id: str = SchemaField(
            description="The ID of the enrichment to monitor",
            placeholder="enrichment-id",
        )
        timeout: int = SchemaField(
            default=300,
            description="Maximum time to wait in seconds",
            ge=1,
            le=1800,
        )
        check_interval: int = SchemaField(
            default=5,
            description="Initial interval between status checks in seconds",
            advanced=True,
            ge=1,
            le=60,
        )
        sample_results: bool = SchemaField(
            default=True,
            description="Include sample enrichment results in output",
        )

    class Output(BlockSchemaOutput):
        enrichment_id: str = SchemaField(
            description="The enrichment ID that was monitored"
        )
        final_status: str = SchemaField(
            description="The final status of the enrichment"
        )
        items_enriched: int = SchemaField(
            description="Number of items successfully enriched"
        )
        enrichment_title: str = SchemaField(
            description="Title/description of the enrichment"
        )
        elapsed_time: float = SchemaField(description="Total time elapsed in seconds")
        sample_data: list[SampleEnrichmentModel] = SchemaField(
            description="Sample of enriched data (if requested)"
        )
        timed_out: bool = SchemaField(description="Whether the operation timed out")

    def __init__(self):
        super().__init__(
            id="a11865c3-ac80-4721-8a40-ac4e3b71a558",
            description="Wait for a webset enrichment to complete with progress tracking",
            categories={BlockCategory.SEARCH},
            input_schema=ExaWaitForEnrichmentBlock.Input,
            output_schema=ExaWaitForEnrichmentBlock.Output,
        )

    @staticmethod
    def _status_str(status) -> str:
        """Normalize an SDK status (enum member or plain value) to a string."""
        return status.value if hasattr(status, "value") else str(status)

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        start_time = time.time()
        interval = input_data.check_interval
        max_interval = 30
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        try:
            while time.time() - start_time < input_data.timeout:
                # AsyncExa methods return coroutines and must be awaited.
                enrichment = await aexa.websets.enrichments.get(
                    webset_id=input_data.webset_id, id=input_data.enrichment_id
                )
                status = self._status_str(enrichment.status)

                # Terminal states end the wait.
                if status in ["completed", "failed", "canceled"]:
                    elapsed = time.time() - start_time

                    # Get sample enriched items only when the enrichment
                    # actually completed; failed/canceled runs have none.
                    sample_data = []
                    items_enriched = 0
                    if input_data.sample_results and status == "completed":
                        sample_data, items_enriched = (
                            await self._get_sample_enrichments(
                                input_data.webset_id, input_data.enrichment_id, aexa
                            )
                        )

                    yield "enrichment_id", input_data.enrichment_id
                    yield "final_status", status
                    yield "items_enriched", items_enriched
                    yield "enrichment_title", enrichment.title or enrichment.description or ""
                    yield "elapsed_time", elapsed
                    if input_data.sample_results:
                        yield "sample_data", sample_data
                    yield "timed_out", False

                    return

                # Wait before next check with exponential backoff
                await asyncio.sleep(interval)
                interval = min(interval * 1.5, max_interval)

            # Timeout reached: report the last known status.
            elapsed = time.time() - start_time

            enrichment = await aexa.websets.enrichments.get(
                webset_id=input_data.webset_id, id=input_data.enrichment_id
            )
            final_status = self._status_str(enrichment.status)
            title = enrichment.title or enrichment.description or ""

            yield "enrichment_id", input_data.enrichment_id
            yield "final_status", final_status
            yield "items_enriched", 0
            yield "enrichment_title", title
            yield "elapsed_time", elapsed
            yield "timed_out", True

        except asyncio.TimeoutError:
            raise ValueError(
                f"Enrichment polling timed out after {input_data.timeout} seconds"
            ) from None

    async def _get_sample_enrichments(
        self, webset_id: str, enrichment_id: str, aexa: AsyncExa
    ) -> tuple[list[SampleEnrichmentModel], int]:
        """Fetch a few items and return (samples, enriched_count) for one enrichment."""
        # Get a few items to see enrichment results using the SDK;
        # the async call must be awaited before reading .data.
        response = await aexa.websets.items.list(webset_id=webset_id, limit=5)

        sample_data: list[SampleEnrichmentModel] = []
        enriched_count = 0

        for sdk_item in response.data:
            # Convert to our stable WebsetItemModel first
            item = WebsetItemModel.from_sdk(sdk_item)

            # Only count items that carry the enrichment we're looking for
            if enrichment_id in item.enrichments:
                enriched_count += 1
                enrich_model = item.enrichments[enrichment_id]

                # Create sample using our typed model
                sample = SampleEnrichmentModel(
                    item_id=item.id,
                    item_title=item.title,
                    enrichment_data=enrich_model.model_dump(exclude_none=True),
                )
                sample_data.append(sample)

        return sample_data, enriched_count
|
||||
650
autogpt_platform/backend/backend/blocks/exa/websets_search.py
Normal file
650
autogpt_platform/backend/backend/blocks/exa/websets_search.py
Normal file
@@ -0,0 +1,650 @@
|
||||
"""
|
||||
Exa Websets Search Management Blocks
|
||||
|
||||
This module provides blocks for creating and managing searches within websets,
|
||||
including adding new searches, checking status, and canceling operations.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
from exa_py.websets.types import WebsetSearch as SdkWebsetSearch
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
|
||||
|
||||
# Mirrored model for stability
class WebsetSearchModel(BaseModel):
    """Stable output model mirroring SDK WebsetSearch."""

    id: str
    webset_id: str
    status: str
    query: str
    entity_type: str
    criteria: List[Dict[str, Any]]
    count: int
    behavior: str
    progress: Dict[str, Any]
    recall: Optional[Dict[str, Any]]
    created_at: str
    updated_at: str
    canceled_at: Optional[str]
    canceled_reason: Optional[str]
    metadata: Dict[str, Any]

    @classmethod
    def from_sdk(cls, search: SdkWebsetSearch) -> "WebsetSearchModel":
        """Convert SDK WebsetSearch to our stable model."""
        # Status may be an enum member or a plain value.
        raw_status = search.status
        status_str = (
            raw_status.value if hasattr(raw_status, "value") else str(raw_status)
        )

        # Entity type defaults to "auto" when the SDK reports no entity.
        entity_type = "auto"
        if search.entity:
            entity_type = search.entity.model_dump(by_alias=True).get("type", "auto")

        return cls(
            id=search.id,
            webset_id=search.webset_id,
            status=status_str,
            query=search.query,
            entity_type=entity_type,
            # Dump each criterion model to a plain dict for stability.
            criteria=[
                criterion.model_dump(by_alias=True) for criterion in search.criteria
            ],
            count=search.count,
            behavior=search.behavior.value if search.behavior else "override",
            progress=(
                search.progress.model_dump(by_alias=True) if search.progress else {}
            ),
            recall=(
                search.recall.model_dump(by_alias=True) if search.recall else None
            ),
            # Timestamps become ISO strings; empty string / None when absent.
            created_at=search.created_at.isoformat() if search.created_at else "",
            updated_at=search.updated_at.isoformat() if search.updated_at else "",
            canceled_at=(
                search.canceled_at.isoformat() if search.canceled_at else None
            ),
            canceled_reason=(
                search.canceled_reason.value if search.canceled_reason else None
            ),
            metadata=search.metadata if search.metadata else {},
        )
|
||||
|
||||
|
||||
class SearchBehavior(str, Enum):
    """Behavior for how new search results interact with existing items.

    The string value is sent verbatim in the search-create payload.
    """

    OVERRIDE = "override"  # Replace existing items
    APPEND = "append"  # Add to existing items
    MERGE = "merge"  # Merge with existing items
|
||||
|
||||
|
||||
class SearchEntityType(str, Enum):
    """Type of entity a webset search should look for."""

    COMPANY = "company"
    PERSON = "person"
    ARTICLE = "article"
    RESEARCH_PAPER = "research_paper"
    CUSTOM = "custom"  # Pairs with an entity_description on the search input
    AUTO = "auto"  # No entity is sent; the API infers the type from the query
|
||||
|
||||
|
||||
class ExaCreateWebsetSearchBlock(Block):
    """Add a new search to an existing webset."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        query: str = SchemaField(
            description="Search query describing what to find",
            placeholder="Engineering managers at Fortune 500 companies",
        )
        count: int = SchemaField(
            default=10,
            description="Number of items to find",
            ge=1,
            le=1000,
        )

        # Entity configuration
        entity_type: SearchEntityType = SchemaField(
            default=SearchEntityType.AUTO,
            description="Type of entity to search for",
        )
        entity_description: Optional[str] = SchemaField(
            default=None,
            description="Description for custom entity type",
            advanced=True,
        )

        # Criteria for verification
        criteria: list[str] = SchemaField(
            default_factory=list,
            description="List of criteria that items must meet. If not provided, auto-detected from query.",
            advanced=True,
        )

        # Advanced search options
        behavior: SearchBehavior = SchemaField(
            default=SearchBehavior.APPEND,
            description="How new results interact with existing items",
            advanced=True,
        )
        recall: bool = SchemaField(
            default=True,
            description="Enable recall estimation for expected results",
            advanced=True,
        )

        # Exclude sources
        exclude_source_ids: list[str] = SchemaField(
            default_factory=list,
            description="IDs of imports/websets to exclude from results",
            advanced=True,
        )
        exclude_source_types: list[str] = SchemaField(
            default_factory=list,
            description="Types of sources to exclude ('import' or 'webset')",
            advanced=True,
        )

        # Scope sources
        scope_source_ids: list[str] = SchemaField(
            default_factory=list,
            description="IDs of imports/websets to limit search scope to",
            advanced=True,
        )
        scope_source_types: list[str] = SchemaField(
            default_factory=list,
            description="Types of scope sources ('import' or 'webset')",
            advanced=True,
        )
        scope_relationships: list[str] = SchemaField(
            default_factory=list,
            description="Relationship definitions for hop searches",
            advanced=True,
        )
        scope_relationship_limits: list[int] = SchemaField(
            default_factory=list,
            description="Limits on related entities to find",
            advanced=True,
        )

        metadata: Optional[dict] = SchemaField(
            default=None,
            description="Metadata to attach to the search",
            advanced=True,
        )

        # Polling options
        wait_for_completion: bool = SchemaField(
            default=False,
            description="Wait for the search to complete before returning",
        )
        polling_timeout: int = SchemaField(
            default=300,
            description="Maximum time to wait for completion in seconds",
            advanced=True,
            ge=1,
            le=600,
        )

    class Output(BlockSchemaOutput):
        search_id: str = SchemaField(
            description="The unique identifier for the created search"
        )
        webset_id: str = SchemaField(description="The webset this search belongs to")
        status: str = SchemaField(description="Current status of the search")
        query: str = SchemaField(description="The search query")
        expected_results: dict = SchemaField(
            description="Recall estimation of expected results"
        )
        items_found: Optional[int] = SchemaField(
            description="Number of items found (if wait_for_completion was True)"
        )
        completion_time: Optional[float] = SchemaField(
            description="Time taken to complete in seconds (if wait_for_completion was True)"
        )

    def __init__(self):
        super().__init__(
            id="342ff776-2e2c-4cdb-b392-4eeb34b21d5f",
            description="Add a new search to an existing webset to find more items",
            categories={BlockCategory.SEARCH},
            input_schema=ExaCreateWebsetSearchBlock.Input,
            output_schema=ExaCreateWebsetSearchBlock.Output,
        )

    @staticmethod
    def _build_search_payload(input_data: Input) -> Dict[str, Any]:
        """Translate block inputs into the websets search-create request payload."""
        payload: Dict[str, Any] = {
            "query": input_data.query,
            "count": input_data.count,
            "behavior": input_data.behavior.value,
            "recall": input_data.recall,
        }

        # Entity configuration: omitted entirely for AUTO so the API infers it.
        if input_data.entity_type != SearchEntityType.AUTO:
            entity = {"type": input_data.entity_type.value}
            if (
                input_data.entity_type == SearchEntityType.CUSTOM
                and input_data.entity_description
            ):
                entity["description"] = input_data.entity_description
            payload["entity"] = entity

        # Add criteria if provided
        if input_data.criteria:
            payload["criteria"] = [{"description": c} for c in input_data.criteria]

        # Exclude sources; source type defaults to "import" when no
        # matching entry exists in exclude_source_types.
        if input_data.exclude_source_ids:
            exclude_list = []
            for idx, src_id in enumerate(input_data.exclude_source_ids):
                src_type = "import"
                if input_data.exclude_source_types and idx < len(
                    input_data.exclude_source_types
                ):
                    src_type = input_data.exclude_source_types[idx]
                exclude_list.append({"source": src_type, "id": src_id})
            payload["exclude"] = exclude_list

        # Scope sources, with optional hop-search relationships per index.
        if input_data.scope_source_ids:
            scope_list: list[dict[str, Any]] = []
            for idx, src_id in enumerate(input_data.scope_source_ids):
                scope_item: dict[str, Any] = {"source": "import", "id": src_id}

                if input_data.scope_source_types and idx < len(
                    input_data.scope_source_types
                ):
                    scope_item["source"] = input_data.scope_source_types[idx]

                # Add relationship if provided
                if input_data.scope_relationships and idx < len(
                    input_data.scope_relationships
                ):
                    relationship: dict[str, Any] = {
                        "definition": input_data.scope_relationships[idx]
                    }
                    if input_data.scope_relationship_limits and idx < len(
                        input_data.scope_relationship_limits
                    ):
                        relationship["limit"] = input_data.scope_relationship_limits[
                            idx
                        ]
                    scope_item["relationship"] = relationship

                scope_list.append(scope_item)
            payload["scope"] = scope_list

        # Add metadata if provided
        if input_data.metadata:
            payload["metadata"] = input_data.metadata

        return payload

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Local imports kept because the module header does not import these.
        import asyncio
        import time

        payload = self._build_search_payload(input_data)

        start_time = time.time()

        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # AsyncExa methods return coroutines and must be awaited.
        sdk_search = await aexa.websets.searches.create(
            webset_id=input_data.webset_id, params=payload
        )

        search_id = sdk_search.id
        status = (
            sdk_search.status.value
            if hasattr(sdk_search.status, "value")
            else str(sdk_search.status)
        )

        # Extract expected results from recall
        expected_results = {}
        if sdk_search.recall:
            recall_dict = sdk_search.recall.model_dump(by_alias=True)
            expected = recall_dict.get("expected", {})
            expected_results = {
                "total": expected.get("total", 0),
                "confidence": expected.get("confidence", ""),
                "min": expected.get("bounds", {}).get("min", 0),
                "max": expected.get("bounds", {}).get("max", 0),
                "reasoning": recall_dict.get("reasoning", ""),
            }

        # If wait_for_completion is True, poll for completion
        if input_data.wait_for_completion:
            poll_interval = 5
            max_interval = 30
            poll_start = time.time()

            while time.time() - poll_start < input_data.polling_timeout:
                current_search = await aexa.websets.searches.get(
                    webset_id=input_data.webset_id, id=search_id
                )
                current_status = (
                    current_search.status.value
                    if hasattr(current_search.status, "value")
                    else str(current_search.status)
                )

                # "canceled" matches the spelling used by the other webset
                # blocks in this package (was "cancelled", which never matched).
                if current_status in ["completed", "failed", "canceled"]:
                    items_found = 0
                    if current_search.progress:
                        items_found = current_search.progress.found
                    completion_time = time.time() - start_time

                    yield "search_id", search_id
                    yield "webset_id", input_data.webset_id
                    yield "status", current_status
                    yield "query", input_data.query
                    yield "expected_results", expected_results
                    yield "items_found", items_found
                    yield "completion_time", completion_time
                    return

                await asyncio.sleep(poll_interval)
                poll_interval = min(poll_interval * 1.5, max_interval)

            # Timeout - yield what we have
            yield "search_id", search_id
            yield "webset_id", input_data.webset_id
            yield "status", status
            yield "query", input_data.query
            yield "expected_results", expected_results
            yield "items_found", 0
            yield "completion_time", time.time() - start_time
        else:
            yield "search_id", search_id
            yield "webset_id", input_data.webset_id
            yield "status", status
            yield "query", input_data.query
            yield "expected_results", expected_results
|
||||
|
||||
|
||||
class ExaGetWebsetSearchBlock(Block):
|
||||
"""Get the status and details of a webset search."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
search_id: str = SchemaField(
|
||||
description="The ID of the search to retrieve",
|
||||
placeholder="search-id",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
search_id: str = SchemaField(description="The unique identifier for the search")
|
||||
status: str = SchemaField(description="Current status of the search")
|
||||
query: str = SchemaField(description="The search query")
|
||||
entity_type: str = SchemaField(description="Type of entity being searched")
|
||||
criteria: list[dict] = SchemaField(description="Criteria used for verification")
|
||||
progress: dict = SchemaField(description="Search progress information")
|
||||
recall: dict = SchemaField(description="Recall estimation information")
|
||||
created_at: str = SchemaField(description="When the search was created")
|
||||
updated_at: str = SchemaField(description="When the search was last updated")
|
||||
canceled_at: Optional[str] = SchemaField(
|
||||
description="When the search was canceled (if applicable)"
|
||||
)
|
||||
canceled_reason: Optional[str] = SchemaField(
|
||||
description="Reason for cancellation (if applicable)"
|
||||
)
|
||||
metadata: dict = SchemaField(description="Metadata attached to the search")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4fa3e627-a0ff-485f-8732-52148051646c",
|
||||
description="Get the status and details of a webset search",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaGetWebsetSearchBlock.Input,
|
||||
output_schema=ExaGetWebsetSearchBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_search = aexa.websets.searches.get(
|
||||
webset_id=input_data.webset_id, id=input_data.search_id
|
||||
)
|
||||
|
||||
search = WebsetSearchModel.from_sdk(sdk_search)
|
||||
|
||||
# Extract progress information
|
||||
progress_info = {
|
||||
"found": search.progress.get("found", 0),
|
||||
"analyzed": search.progress.get("analyzed", 0),
|
||||
"completion": search.progress.get("completion", 0),
|
||||
"time_left": search.progress.get("timeLeft", 0),
|
||||
}
|
||||
|
||||
# Extract recall information
|
||||
recall_data = {}
|
||||
if search.recall:
|
||||
expected = search.recall.get("expected", {})
|
||||
recall_data = {
|
||||
"expected_total": expected.get("total", 0),
|
||||
"confidence": expected.get("confidence", ""),
|
||||
"min_expected": expected.get("bounds", {}).get("min", 0),
|
||||
"max_expected": expected.get("bounds", {}).get("max", 0),
|
||||
"reasoning": search.recall.get("reasoning", ""),
|
||||
}
|
||||
|
||||
yield "search_id", search.id
|
||||
yield "status", search.status
|
||||
yield "query", search.query
|
||||
yield "entity_type", search.entity_type
|
||||
yield "criteria", search.criteria
|
||||
yield "progress", progress_info
|
||||
yield "recall", recall_data
|
||||
yield "created_at", search.created_at
|
||||
yield "updated_at", search.updated_at
|
||||
yield "canceled_at", search.canceled_at
|
||||
yield "canceled_reason", search.canceled_reason
|
||||
yield "metadata", search.metadata
|
||||
|
||||
|
||||
class ExaCancelWebsetSearchBlock(Block):
|
||||
"""Cancel a running webset search."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
search_id: str = SchemaField(
|
||||
description="The ID of the search to cancel",
|
||||
placeholder="search-id",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
search_id: str = SchemaField(description="The ID of the canceled search")
|
||||
status: str = SchemaField(description="Status after cancellation")
|
||||
items_found_before_cancel: int = SchemaField(
|
||||
description="Number of items found before cancellation"
|
||||
)
|
||||
success: str = SchemaField(
|
||||
description="Whether the cancellation was successful"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="74ef9f1e-ae89-4c7f-9d7d-d217214815b4",
|
||||
description="Cancel a running webset search",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaCancelWebsetSearchBlock.Input,
|
||||
output_schema=ExaCancelWebsetSearchBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
canceled_search = aexa.websets.searches.cancel(
|
||||
webset_id=input_data.webset_id, id=input_data.search_id
|
||||
)
|
||||
|
||||
# Extract items found before cancellation
|
||||
items_found = 0
|
||||
if canceled_search.progress:
|
||||
items_found = canceled_search.progress.found
|
||||
|
||||
status = (
|
||||
canceled_search.status.value
|
||||
if hasattr(canceled_search.status, "value")
|
||||
else str(canceled_search.status)
|
||||
)
|
||||
|
||||
yield "search_id", canceled_search.id
|
||||
yield "status", status
|
||||
yield "items_found_before_cancel", items_found
|
||||
yield "success", "true"
|
||||
|
||||
|
||||
class ExaFindOrCreateSearchBlock(Block):
|
||||
"""Find existing search by query or create new one (prevents duplicate searches)."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
query: str = SchemaField(
|
||||
description="Search query to find or create",
|
||||
placeholder="AI companies in San Francisco",
|
||||
)
|
||||
count: int = SchemaField(
|
||||
default=10,
|
||||
description="Number of items to find (only used if creating new search)",
|
||||
ge=1,
|
||||
le=1000,
|
||||
)
|
||||
entity_type: SearchEntityType = SchemaField(
|
||||
default=SearchEntityType.AUTO,
|
||||
description="Entity type (only used if creating)",
|
||||
advanced=True,
|
||||
)
|
||||
behavior: SearchBehavior = SchemaField(
|
||||
default=SearchBehavior.OVERRIDE,
|
||||
description="Search behavior (only used if creating)",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
search_id: str = SchemaField(description="The search ID (existing or new)")
|
||||
webset_id: str = SchemaField(description="The webset ID")
|
||||
status: str = SchemaField(description="Current search status")
|
||||
query: str = SchemaField(description="The search query")
|
||||
was_created: bool = SchemaField(
|
||||
description="True if search was newly created, False if already existed"
|
||||
)
|
||||
items_found: int = SchemaField(
|
||||
description="Number of items found (0 if still running)"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="cbdb05ac-cb73-4b03-a493-6d34e9a011da",
|
||||
description="Find existing search by query or create new - prevents duplicate searches in workflows",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaFindOrCreateSearchBlock.Input,
|
||||
output_schema=ExaFindOrCreateSearchBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Get webset to check existing searches
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
# Look for existing search with same query
|
||||
existing_search = None
|
||||
if webset.searches:
|
||||
for search in webset.searches:
|
||||
if search.query.strip().lower() == input_data.query.strip().lower():
|
||||
existing_search = search
|
||||
break
|
||||
|
||||
if existing_search:
|
||||
# Found existing search
|
||||
search = WebsetSearchModel.from_sdk(existing_search)
|
||||
|
||||
yield "search_id", search.id
|
||||
yield "webset_id", input_data.webset_id
|
||||
yield "status", search.status
|
||||
yield "query", search.query
|
||||
yield "was_created", False
|
||||
yield "items_found", search.progress.get("found", 0)
|
||||
else:
|
||||
# Create new search
|
||||
payload: Dict[str, Any] = {
|
||||
"query": input_data.query,
|
||||
"count": input_data.count,
|
||||
"behavior": input_data.behavior.value,
|
||||
}
|
||||
|
||||
# Add entity if not auto
|
||||
if input_data.entity_type != SearchEntityType.AUTO:
|
||||
payload["entity"] = {"type": input_data.entity_type.value}
|
||||
|
||||
sdk_search = aexa.websets.searches.create(
|
||||
webset_id=input_data.webset_id, params=payload
|
||||
)
|
||||
|
||||
search = WebsetSearchModel.from_sdk(sdk_search)
|
||||
|
||||
yield "search_id", search.id
|
||||
yield "webset_id", input_data.webset_id
|
||||
yield "status", search.status
|
||||
yield "query", search.query
|
||||
yield "was_created", True
|
||||
yield "items_found", 0 # Newly created, no items yet
|
||||
@@ -1,6 +1,5 @@
|
||||
# This file contains a lot of prompt block strings that would trigger "line too long"
|
||||
# flake8: noqa: E501
|
||||
import ast
|
||||
import logging
|
||||
import re
|
||||
import secrets
|
||||
@@ -118,13 +117,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
AIML_API_META_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
|
||||
AIML_API_LLAMA_3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo"
|
||||
# Groq models
|
||||
GEMMA2_9B = "gemma2-9b-it"
|
||||
LLAMA3_3_70B = "llama-3.3-70b-versatile"
|
||||
LLAMA3_1_8B = "llama-3.1-8b-instant"
|
||||
LLAMA3_70B = "llama3-70b-8192"
|
||||
LLAMA3_8B = "llama3-8b-8192"
|
||||
# Groq preview models
|
||||
DEEPSEEK_LLAMA_70B = "deepseek-r1-distill-llama-70b"
|
||||
# Ollama models
|
||||
OLLAMA_LLAMA3_3 = "llama3.3"
|
||||
OLLAMA_LLAMA3_2 = "llama3.2"
|
||||
@@ -134,7 +128,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
# OpenRouter models
|
||||
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
|
||||
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
|
||||
GEMINI_FLASH_1_5 = "google/gemini-flash-1.5"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
|
||||
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
|
||||
@@ -238,12 +231,8 @@ MODEL_METADATA = {
|
||||
LlmModel.AIML_API_META_LLAMA_3_1_70B: ModelMetadata("aiml_api", 131000, 2000),
|
||||
LlmModel.AIML_API_LLAMA_3_2_3B: ModelMetadata("aiml_api", 128000, None),
|
||||
# https://console.groq.com/docs/models
|
||||
LlmModel.GEMMA2_9B: ModelMetadata("groq", 8192, None),
|
||||
LlmModel.LLAMA3_3_70B: ModelMetadata("groq", 128000, 32768),
|
||||
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 128000, 8192),
|
||||
LlmModel.LLAMA3_70B: ModelMetadata("groq", 8192, None),
|
||||
LlmModel.LLAMA3_8B: ModelMetadata("groq", 8192, None),
|
||||
LlmModel.DEEPSEEK_LLAMA_70B: ModelMetadata("groq", 128000, None),
|
||||
# https://ollama.com/library
|
||||
LlmModel.OLLAMA_LLAMA3_3: ModelMetadata("ollama", 8192, None),
|
||||
LlmModel.OLLAMA_LLAMA3_2: ModelMetadata("ollama", 8192, None),
|
||||
@@ -251,7 +240,6 @@ MODEL_METADATA = {
|
||||
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, None),
|
||||
LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768, None),
|
||||
# https://openrouter.ai/models
|
||||
LlmModel.GEMINI_FLASH_1_5: ModelMetadata("open_router", 1000000, 8192),
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata("open_router", 1050000, 8192),
|
||||
LlmModel.GEMINI_2_5_FLASH: ModelMetadata("open_router", 1048576, 65535),
|
||||
LlmModel.GEMINI_2_0_FLASH: ModelMetadata("open_router", 1048576, 8192),
|
||||
@@ -1644,6 +1632,17 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
ge=1,
|
||||
le=5,
|
||||
)
|
||||
force_json_output: bool = SchemaField(
|
||||
title="Restrict LLM to pure JSON output",
|
||||
default=False,
|
||||
description=(
|
||||
"Whether to force the LLM to produce a JSON-only response. "
|
||||
"This can increase the block's reliability, "
|
||||
"but may also reduce the quality of the response "
|
||||
"because it prohibits the LLM from reasoning "
|
||||
"before providing its JSON response."
|
||||
),
|
||||
)
|
||||
max_tokens: int | None = SchemaField(
|
||||
advanced=True,
|
||||
default=None,
|
||||
@@ -1656,7 +1655,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
generated_list: List[str] = SchemaField(description="The generated list.")
|
||||
generated_list: list[str] = SchemaField(description="The generated list.")
|
||||
list_item: str = SchemaField(
|
||||
description="Each individual item in the list.",
|
||||
)
|
||||
@@ -1665,7 +1664,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9c0b0450-d199-458b-a731-072189dd6593",
|
||||
description="Generate a Python list based on the given prompt using a Large Language Model (LLM).",
|
||||
description="Generate a list of values based on the given prompt using a Large Language Model (LLM).",
|
||||
categories={BlockCategory.AI, BlockCategory.TEXT},
|
||||
input_schema=AIListGeneratorBlock.Input,
|
||||
output_schema=AIListGeneratorBlock.Output,
|
||||
@@ -1682,6 +1681,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"max_retries": 3,
|
||||
"force_json_output": False,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
@@ -1698,7 +1698,13 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
],
|
||||
test_mock={
|
||||
"llm_call": lambda input_data, credentials: {
|
||||
"response": "['Zylora Prime', 'Kharon-9', 'Vortexia', 'Oceara', 'Draknos']"
|
||||
"list": [
|
||||
"Zylora Prime",
|
||||
"Kharon-9",
|
||||
"Vortexia",
|
||||
"Oceara",
|
||||
"Draknos",
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
@@ -1707,7 +1713,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
self,
|
||||
input_data: AIStructuredResponseGeneratorBlock.Input,
|
||||
credentials: APIKeyCredentials,
|
||||
) -> dict[str, str]:
|
||||
) -> dict[str, Any]:
|
||||
llm_block = AIStructuredResponseGeneratorBlock()
|
||||
response = await llm_block.run_once(
|
||||
input_data, "response", credentials=credentials
|
||||
@@ -1715,72 +1721,23 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
self.merge_llm_stats(llm_block)
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def string_to_list(string):
|
||||
"""
|
||||
Converts a string representation of a list into an actual Python list object.
|
||||
"""
|
||||
logger.debug(f"Converting string to list. Input string: {string}")
|
||||
try:
|
||||
# Use ast.literal_eval to safely evaluate the string
|
||||
python_list = ast.literal_eval(string)
|
||||
if isinstance(python_list, list):
|
||||
logger.debug(f"Successfully converted string to list: {python_list}")
|
||||
return python_list
|
||||
else:
|
||||
logger.error(f"The provided string '{string}' is not a valid list")
|
||||
raise ValueError(f"The provided string '{string}' is not a valid list.")
|
||||
except (SyntaxError, ValueError) as e:
|
||||
logger.error(f"Failed to convert string to list: {e}")
|
||||
raise ValueError("Invalid list format. Could not convert to list.")
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
logger.debug(f"Starting AIListGeneratorBlock.run with input data: {input_data}")
|
||||
|
||||
# Check for API key
|
||||
api_key_check = credentials.api_key.get_secret_value()
|
||||
if not api_key_check:
|
||||
raise ValueError("No LLM API key provided.")
|
||||
# Create a proper expected format for the structured response generator
|
||||
expected_format = {
|
||||
"list": "A JSON array containing the generated string values"
|
||||
}
|
||||
if input_data.force_json_output:
|
||||
# Add reasoning field for better performance
|
||||
expected_format = {
|
||||
"reasoning": "... (optional)",
|
||||
**expected_format,
|
||||
}
|
||||
|
||||
# Prepare the system prompt
|
||||
sys_prompt = """You are a Python list generator. Your task is to generate a Python list based on the user's prompt.
|
||||
|Respond ONLY with a valid python list.
|
||||
|The list can contain strings, numbers, or nested lists as appropriate.
|
||||
|Do not include any explanations or additional text.
|
||||
|
||||
|Valid Example string formats:
|
||||
|
||||
|Example 1:
|
||||
|```
|
||||
|['1', '2', '3', '4']
|
||||
|```
|
||||
|
||||
|Example 2:
|
||||
|```
|
||||
|[['1', '2'], ['3', '4'], ['5', '6']]
|
||||
|```
|
||||
|
||||
|Example 3:
|
||||
|```
|
||||
|['1', ['2', '3'], ['4', ['5', '6']]]
|
||||
|```
|
||||
|
||||
|Example 4:
|
||||
|```
|
||||
|['a', 'b', 'c']
|
||||
|```
|
||||
|
||||
|Example 5:
|
||||
|```
|
||||
|['1', '2.5', 'string', 'True', ['False', 'None']]
|
||||
|```
|
||||
|
||||
|Do not include any explanations or additional text, just respond with the list in the format specified above.
|
||||
|Do not include code fences or any other formatting, just the raw list.
|
||||
"""
|
||||
# If a focus is provided, add it to the prompt
|
||||
# Build the prompt
|
||||
if input_data.focus:
|
||||
prompt = f"Generate a list with the following focus:\n<focus>\n\n{input_data.focus}</focus>"
|
||||
else:
|
||||
@@ -1788,7 +1745,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
if input_data.source_data:
|
||||
prompt = "Extract the main focus of the source data to a list.\ni.e if the source data is a news website, the focus would be the news stories rather than the social links in the footer."
|
||||
else:
|
||||
# No focus or source data provided, generat a random list
|
||||
# No focus or source data provided, generate a random list
|
||||
prompt = "Generate a random list."
|
||||
|
||||
# If the source data is provided, add it to the prompt
|
||||
@@ -1798,63 +1755,56 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
else:
|
||||
prompt += "\n\nInvent the data to generate the list from."
|
||||
|
||||
for attempt in range(input_data.max_retries):
|
||||
try:
|
||||
logger.debug("Calling LLM")
|
||||
llm_response = await self.llm_call(
|
||||
AIStructuredResponseGeneratorBlock.Input(
|
||||
sys_prompt=sys_prompt,
|
||||
prompt=prompt,
|
||||
credentials=input_data.credentials,
|
||||
model=input_data.model,
|
||||
expected_format={}, # Do not use structured response
|
||||
ollama_host=input_data.ollama_host,
|
||||
),
|
||||
credentials=credentials,
|
||||
)
|
||||
# Use the structured response generator to handle all the complexity
|
||||
response_obj = await self.llm_call(
|
||||
AIStructuredResponseGeneratorBlock.Input(
|
||||
sys_prompt=self.SYSTEM_PROMPT,
|
||||
prompt=prompt,
|
||||
credentials=input_data.credentials,
|
||||
model=input_data.model,
|
||||
expected_format=expected_format,
|
||||
force_json_output=input_data.force_json_output,
|
||||
retry=input_data.max_retries,
|
||||
max_tokens=input_data.max_tokens,
|
||||
ollama_host=input_data.ollama_host,
|
||||
),
|
||||
credentials=credentials,
|
||||
)
|
||||
logger.debug(f"Response object: {response_obj}")
|
||||
|
||||
logger.debug(f"LLM response: {llm_response}")
|
||||
# Extract the list from the response object
|
||||
if isinstance(response_obj, dict) and "list" in response_obj:
|
||||
parsed_list = response_obj["list"]
|
||||
else:
|
||||
# Fallback - treat the whole response as the list
|
||||
parsed_list = response_obj
|
||||
|
||||
# Extract Response string
|
||||
response_string = llm_response["response"]
|
||||
logger.debug(f"Response string: {response_string}")
|
||||
# Validate that we got a list
|
||||
if not isinstance(parsed_list, list):
|
||||
raise ValueError(
|
||||
f"Expected a list, but got {type(parsed_list).__name__}: {parsed_list}"
|
||||
)
|
||||
|
||||
# Convert the string to a Python list
|
||||
logger.debug("Converting string to Python list")
|
||||
parsed_list = self.string_to_list(response_string)
|
||||
logger.debug(f"Parsed list: {parsed_list}")
|
||||
logger.debug(f"Parsed list: {parsed_list}")
|
||||
|
||||
# If we reach here, we have a valid Python list
|
||||
logger.debug("Successfully generated a valid Python list")
|
||||
yield "generated_list", parsed_list
|
||||
yield "prompt", self.prompt
|
||||
# Yield the results
|
||||
yield "generated_list", parsed_list
|
||||
yield "prompt", self.prompt
|
||||
|
||||
# Yield each item in the list
|
||||
for item in parsed_list:
|
||||
yield "list_item", item
|
||||
return
|
||||
# Yield each item in the list
|
||||
for item in parsed_list:
|
||||
yield "list_item", item
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in attempt {attempt + 1}: {str(e)}")
|
||||
if attempt == input_data.max_retries - 1:
|
||||
logger.error(
|
||||
f"Failed to generate a valid Python list after {input_data.max_retries} attempts"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to generate a valid Python list after {input_data.max_retries} attempts. Last error: {str(e)}"
|
||||
)
|
||||
else:
|
||||
# Add a retry prompt
|
||||
logger.debug("Preparing retry prompt")
|
||||
prompt = f"""
|
||||
The previous attempt failed due to `{e}`
|
||||
Generate a valid Python list based on the original prompt.
|
||||
Remember to respond ONLY with a valid Python list as per the format specified earlier.
|
||||
Original prompt:
|
||||
```{prompt}```
|
||||
|
||||
Respond only with the list in the format specified with no commentary or apologies.
|
||||
"""
|
||||
logger.debug(f"Retry prompt: {prompt}")
|
||||
|
||||
logger.debug("AIListGeneratorBlock.run completed")
|
||||
SYSTEM_PROMPT = trim_prompt(
|
||||
"""
|
||||
|You are a JSON array generator. Your task is to generate a JSON array of string values based on the user's prompt.
|
||||
|
|
||||
|The 'list' field should contain a JSON array with the generated string values.
|
||||
|The array can contain ONLY strings.
|
||||
|
|
||||
|Valid JSON array formats include:
|
||||
|• ["string1", "string2", "string3"]
|
||||
|
|
||||
|Ensure you provide a proper JSON array with only string values in the 'list' field.
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ from backend.data.dynamic_fields import (
|
||||
extract_base_field_name,
|
||||
get_dynamic_field_description,
|
||||
is_dynamic_field,
|
||||
is_tool_pin,
|
||||
)
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util import json
|
||||
@@ -367,8 +368,9 @@ class SmartDecisionMakerBlock(Block):
|
||||
"required": sorted(required_fields),
|
||||
}
|
||||
|
||||
# Store field mapping for later use in output processing
|
||||
# Store field mapping and node info for later use in output processing
|
||||
tool_function["_field_mapping"] = field_mapping
|
||||
tool_function["_sink_node_id"] = sink_node.id
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@@ -431,10 +433,13 @@ class SmartDecisionMakerBlock(Block):
|
||||
"strict": True,
|
||||
}
|
||||
|
||||
# Store node info for later use in output processing
|
||||
tool_function["_sink_node_id"] = sink_node.id
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@staticmethod
|
||||
async def _create_function_signature(
|
||||
async def _create_tool_node_signatures(
|
||||
node_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
@@ -450,7 +455,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
tools = [
|
||||
(link, node)
|
||||
for link, node in await db_client.get_connected_output_nodes(node_id)
|
||||
if link.source_name.startswith("tools_^_") and link.source_id == node_id
|
||||
if is_tool_pin(link.source_name) and link.source_id == node_id
|
||||
]
|
||||
if not tools:
|
||||
raise ValueError("There is no next node to execute.")
|
||||
@@ -538,8 +543,14 @@ class SmartDecisionMakerBlock(Block):
|
||||
),
|
||||
None,
|
||||
)
|
||||
if tool_def is None and len(tool_functions) == 1:
|
||||
tool_def = tool_functions[0]
|
||||
if tool_def is None:
|
||||
if len(tool_functions) == 1:
|
||||
tool_def = tool_functions[0]
|
||||
else:
|
||||
validation_errors_list.append(
|
||||
f"Tool call for '{tool_name}' does not match any known "
|
||||
"tool definition."
|
||||
)
|
||||
|
||||
# Get parameters schema from tool definition
|
||||
if (
|
||||
@@ -591,7 +602,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
tool_functions = await self._create_function_signature(node_id)
|
||||
tool_functions = await self._create_tool_node_signatures(node_id)
|
||||
yield "tool_functions", json.dumps(tool_functions)
|
||||
|
||||
input_data.conversation_history = input_data.conversation_history or []
|
||||
@@ -661,9 +672,9 @@ class SmartDecisionMakerBlock(Block):
|
||||
except ValueError as e:
|
||||
last_error = e
|
||||
error_feedback = (
|
||||
"Your tool call had parameter errors. Please fix the following issues and try again:\n"
|
||||
"Your tool call had errors. Please fix the following issues and try again:\n"
|
||||
+ f"- {str(e)}\n"
|
||||
+ "\nPlease make sure to use the exact parameter names as specified in the function schema."
|
||||
+ "\nPlease make sure to use the exact tool and parameter names as specified in the function schema."
|
||||
)
|
||||
current_prompt = list(current_prompt) + [
|
||||
{"role": "user", "content": error_feedback}
|
||||
@@ -690,21 +701,23 @@ class SmartDecisionMakerBlock(Block):
|
||||
),
|
||||
None,
|
||||
)
|
||||
if (
|
||||
tool_def
|
||||
and "function" in tool_def
|
||||
and "parameters" in tool_def["function"]
|
||||
):
|
||||
if not tool_def:
|
||||
# NOTE: This matches the logic in _attempt_llm_call_with_validation and
|
||||
# relies on its validation for the assumption that this is valid to use.
|
||||
if len(tool_functions) == 1:
|
||||
tool_def = tool_functions[0]
|
||||
else:
|
||||
# This should not happen due to prior validation
|
||||
continue
|
||||
|
||||
if "function" in tool_def and "parameters" in tool_def["function"]:
|
||||
expected_args = tool_def["function"]["parameters"].get("properties", {})
|
||||
else:
|
||||
expected_args = {arg: {} for arg in tool_args.keys()}
|
||||
|
||||
# Get field mapping from tool definition
|
||||
field_mapping = (
|
||||
tool_def.get("function", {}).get("_field_mapping", {})
|
||||
if tool_def
|
||||
else {}
|
||||
)
|
||||
# Get the sink node ID and field mapping from tool definition
|
||||
field_mapping = tool_def["function"].get("_field_mapping", {})
|
||||
sink_node_id = tool_def["function"]["_sink_node_id"]
|
||||
|
||||
for clean_arg_name in expected_args:
|
||||
# arg_name is now always the cleaned field name (for Anthropic API compliance)
|
||||
@@ -712,9 +725,8 @@ class SmartDecisionMakerBlock(Block):
|
||||
original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
|
||||
arg_value = tool_args.get(clean_arg_name)
|
||||
|
||||
sanitized_tool_name = self.cleanup(tool_name)
|
||||
sanitized_arg_name = self.cleanup(original_field_name)
|
||||
emit_key = f"tools_^_{sanitized_tool_name}_~_{sanitized_arg_name}"
|
||||
emit_key = f"tools_^_{sink_node_id}_~_{sanitized_arg_name}"
|
||||
|
||||
logger.debug(
|
||||
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",
|
||||
|
||||
@@ -6,12 +6,12 @@ from backend.data.block import Block, get_blocks
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b.name)
|
||||
@pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b().name)
|
||||
async def test_available_blocks(block: Type[Block]):
|
||||
await execute_block_test(block())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b.name)
|
||||
@pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b().name)
|
||||
async def test_block_ids_valid(block: Type[Block]):
|
||||
# add the tests here to check they are uuid4
|
||||
import uuid
|
||||
|
||||
@@ -365,37 +365,22 @@ class TestLLMStatsTracking:
|
||||
assert outputs["response"] == "AI response to conversation"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_list_generator_with_retries(self):
|
||||
"""Test that AIListGeneratorBlock correctly tracks stats with retries."""
|
||||
async def test_ai_list_generator_basic_functionality(self):
|
||||
"""Test that AIListGeneratorBlock correctly works with structured responses."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIListGeneratorBlock()
|
||||
|
||||
# Counter to track calls
|
||||
call_count = 0
|
||||
|
||||
# Mock the llm_call to return a structured response
|
||||
async def mock_llm_call(input_data, credentials):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
# Update stats
|
||||
if hasattr(block, "execution_stats") and block.execution_stats:
|
||||
block.execution_stats.input_token_count += 40
|
||||
block.execution_stats.output_token_count += 20
|
||||
block.execution_stats.llm_call_count += 1
|
||||
else:
|
||||
block.execution_stats = NodeExecutionStats(
|
||||
input_token_count=40,
|
||||
output_token_count=20,
|
||||
llm_call_count=1,
|
||||
)
|
||||
|
||||
if call_count == 1:
|
||||
# First call returns invalid format
|
||||
return {"response": "not a valid list"}
|
||||
else:
|
||||
# Second call returns valid list
|
||||
return {"response": "['item1', 'item2', 'item3']"}
|
||||
# Update stats to simulate LLM call
|
||||
block.execution_stats = NodeExecutionStats(
|
||||
input_token_count=50,
|
||||
output_token_count=30,
|
||||
llm_call_count=1,
|
||||
)
|
||||
# Return a structured response with the expected format
|
||||
return {"list": ["item1", "item2", "item3"]}
|
||||
|
||||
block.llm_call = mock_llm_call # type: ignore
|
||||
|
||||
@@ -413,14 +398,20 @@ class TestLLMStatsTracking:
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Check stats - should have 2 calls
|
||||
assert call_count == 2
|
||||
assert block.execution_stats.input_token_count == 80 # 40 * 2
|
||||
assert block.execution_stats.output_token_count == 40 # 20 * 2
|
||||
assert block.execution_stats.llm_call_count == 2
|
||||
# Check stats
|
||||
assert block.execution_stats.input_token_count == 50
|
||||
assert block.execution_stats.output_token_count == 30
|
||||
assert block.execution_stats.llm_call_count == 1
|
||||
|
||||
# Check output
|
||||
assert outputs["generated_list"] == ["item1", "item2", "item3"]
|
||||
# Check that individual items were yielded
|
||||
# Note: outputs dict will only contain the last value for each key
|
||||
# So we need to check that the list_item output exists
|
||||
assert "list_item" in outputs
|
||||
# The list_item output should be the last item in the list
|
||||
assert outputs["list_item"] == "item3"
|
||||
assert "prompt" in outputs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_llm_stats(self):
|
||||
|
||||
@@ -165,7 +165,7 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
)
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
tool_functions = await SmartDecisionMakerBlock._create_function_signature(
|
||||
tool_functions = await SmartDecisionMakerBlock._create_tool_node_signatures(
|
||||
test_graph.nodes[0].id
|
||||
)
|
||||
assert tool_functions is not None, "Tool functions should not be None"
|
||||
@@ -215,7 +215,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
"content": "I need to think about this.",
|
||||
}
|
||||
|
||||
# Mock the _create_function_signature method to avoid database calls
|
||||
# Mock the _create_tool_node_signatures method to avoid database calls
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
@@ -224,7 +224,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
return_value=mock_response,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
@@ -293,6 +293,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
},
|
||||
"required": ["query", "max_keyword_difficulty"],
|
||||
},
|
||||
"_sink_node_id": "test-sink-node-id",
|
||||
},
|
||||
}
|
||||
]
|
||||
@@ -318,7 +319,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
return_value=mock_response_with_typo,
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
@@ -375,7 +376,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
return_value=mock_response_missing_required,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
@@ -425,7 +426,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
return_value=mock_response_valid,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
@@ -450,13 +451,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify tool outputs were generated correctly
|
||||
assert "tools_^_search_keywords_~_query" in outputs
|
||||
assert outputs["tools_^_search_keywords_~_query"] == "test"
|
||||
assert "tools_^_search_keywords_~_max_keyword_difficulty" in outputs
|
||||
assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
|
||||
assert "tools_^_test-sink-node-id_~_query" in outputs
|
||||
assert outputs["tools_^_test-sink-node-id_~_query"] == "test"
|
||||
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
|
||||
assert outputs["tools_^_test-sink-node-id_~_max_keyword_difficulty"] == 50
|
||||
# Optional parameter should be None when not provided
|
||||
assert "tools_^_search_keywords_~_optional_param" in outputs
|
||||
assert outputs["tools_^_search_keywords_~_optional_param"] is None
|
||||
assert "tools_^_test-sink-node-id_~_optional_param" in outputs
|
||||
assert outputs["tools_^_test-sink-node-id_~_optional_param"] is None
|
||||
|
||||
# Test case 4: Valid tool call with ALL parameters (should succeed)
|
||||
mock_tool_call_all_params = MagicMock()
|
||||
@@ -479,7 +480,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
return_value=mock_response_all_params,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
@@ -504,9 +505,9 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify all tool outputs were generated correctly
|
||||
assert outputs["tools_^_search_keywords_~_query"] == "test"
|
||||
assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
|
||||
assert outputs["tools_^_search_keywords_~_optional_param"] == "custom_value"
|
||||
assert outputs["tools_^_test-sink-node-id_~_query"] == "test"
|
||||
assert outputs["tools_^_test-sink-node-id_~_max_keyword_difficulty"] == 50
|
||||
assert outputs["tools_^_test-sink-node-id_~_optional_param"] == "custom_value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -530,6 +531,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
"properties": {"param": {"type": "string"}},
|
||||
"required": ["param"],
|
||||
},
|
||||
"_sink_node_id": "test-sink-node-id",
|
||||
},
|
||||
}
|
||||
]
|
||||
@@ -588,7 +590,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
"backend.blocks.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
@@ -617,8 +619,8 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify the tool output was generated successfully
|
||||
assert "tools_^_test_tool_~_param" in outputs
|
||||
assert outputs["tools_^_test_tool_~_param"] == "test_value"
|
||||
assert "tools_^_test-sink-node-id_~_param" in outputs
|
||||
assert outputs["tools_^_test-sink-node-id_~_param"] == "test_value"
|
||||
|
||||
# Verify conversation history was properly maintained
|
||||
assert "conversations" in outputs
|
||||
@@ -656,7 +658,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
return_value=mock_response_ollama,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[], # No tools for this test
|
||||
):
|
||||
@@ -702,7 +704,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
return_value=mock_response_dict,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
|
||||
@@ -192,7 +192,7 @@ async def test_create_block_function_signature_with_object_fields():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_function_signature():
|
||||
async def test_create_tool_node_signatures():
|
||||
"""Test that the mapping between sanitized and original field names is built correctly."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
@@ -241,7 +241,7 @@ async def test_create_function_signature():
|
||||
]
|
||||
|
||||
# Call the method that builds signatures
|
||||
tool_functions = await block._create_function_signature("test_node_id")
|
||||
tool_functions = await block._create_tool_node_signatures("test_node_id")
|
||||
|
||||
# Verify we got 2 tool functions (one for dict, one for list)
|
||||
assert len(tool_functions) == 2
|
||||
@@ -310,7 +310,7 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
|
||||
# Mock the function signature creation
|
||||
with patch.object(
|
||||
block, "_create_function_signature", new_callable=AsyncMock
|
||||
block, "_create_tool_node_signatures", new_callable=AsyncMock
|
||||
) as mock_sig:
|
||||
mock_sig.return_value = [
|
||||
{
|
||||
@@ -325,6 +325,7 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
"values___email": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"_sink_node_id": "test-sink-node-id",
|
||||
},
|
||||
}
|
||||
]
|
||||
@@ -351,16 +352,16 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
):
|
||||
outputs[output_name] = output_value
|
||||
|
||||
# Verify the outputs use sanitized field names (matching frontend normalizeToolName)
|
||||
assert "tools_^_createdictionaryblock_~_values___name" in outputs
|
||||
assert outputs["tools_^_createdictionaryblock_~_values___name"] == "Alice"
|
||||
# Verify the outputs use sink node ID in output keys
|
||||
assert "tools_^_test-sink-node-id_~_values___name" in outputs
|
||||
assert outputs["tools_^_test-sink-node-id_~_values___name"] == "Alice"
|
||||
|
||||
assert "tools_^_createdictionaryblock_~_values___age" in outputs
|
||||
assert outputs["tools_^_createdictionaryblock_~_values___age"] == 30
|
||||
assert "tools_^_test-sink-node-id_~_values___age" in outputs
|
||||
assert outputs["tools_^_test-sink-node-id_~_values___age"] == 30
|
||||
|
||||
assert "tools_^_createdictionaryblock_~_values___email" in outputs
|
||||
assert "tools_^_test-sink-node-id_~_values___email" in outputs
|
||||
assert (
|
||||
outputs["tools_^_createdictionaryblock_~_values___email"]
|
||||
outputs["tools_^_test-sink-node-id_~_values___email"]
|
||||
== "alice@example.com"
|
||||
)
|
||||
|
||||
@@ -488,7 +489,7 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
|
||||
# Mock the function signature creation
|
||||
with patch.object(
|
||||
block, "_create_function_signature", new_callable=AsyncMock
|
||||
block, "_create_tool_node_signatures", new_callable=AsyncMock
|
||||
) as mock_sig:
|
||||
mock_sig.return_value = [
|
||||
{
|
||||
@@ -505,6 +506,7 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
},
|
||||
"required": ["correct_param"],
|
||||
},
|
||||
"_sink_node_id": "test-sink-node-id",
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
@@ -29,6 +29,13 @@ from backend.data.model import NodeExecutionStats
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.cache import cached
|
||||
from backend.util.exceptions import (
|
||||
BlockError,
|
||||
BlockExecutionError,
|
||||
BlockInputError,
|
||||
BlockOutputError,
|
||||
BlockUnknownError,
|
||||
)
|
||||
from backend.util.settings import Config
|
||||
|
||||
from .model import (
|
||||
@@ -542,9 +549,25 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
)
|
||||
|
||||
async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
async for output_name, output_data in self._execute(input_data, **kwargs):
|
||||
yield output_name, output_data
|
||||
except Exception as ex:
|
||||
if not isinstance(ex, BlockError):
|
||||
raise BlockUnknownError(
|
||||
message=str(ex),
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
) from ex
|
||||
else:
|
||||
raise ex
|
||||
|
||||
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise ValueError(
|
||||
f"Unable to execute block with invalid input data: {error}"
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
async for output_name, output_data in self.run(
|
||||
@@ -552,11 +575,17 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
**kwargs,
|
||||
):
|
||||
if output_name == "error":
|
||||
raise RuntimeError(output_data)
|
||||
raise BlockExecutionError(
|
||||
message=output_data, block_name=self.name, block_id=self.id
|
||||
)
|
||||
if self.block_type == BlockType.STANDARD and (
|
||||
error := self.output_schema.validate_field(output_name, output_data)
|
||||
):
|
||||
raise ValueError(f"Block produced an invalid output data: {error}")
|
||||
raise BlockOutputError(
|
||||
message=f"Block produced an invalid output data: {error}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
yield output_name, output_data
|
||||
|
||||
def is_triggered_by_event_type(
|
||||
|
||||
@@ -82,20 +82,15 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.AIML_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
|
||||
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
|
||||
LlmModel.LLAMA3_8B: 1,
|
||||
LlmModel.LLAMA3_70B: 1,
|
||||
LlmModel.GEMMA2_9B: 1,
|
||||
LlmModel.LLAMA3_3_70B: 1, # $0.59 / $0.79
|
||||
LlmModel.LLAMA3_1_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_3: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_2: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_405B: 1,
|
||||
LlmModel.DEEPSEEK_LLAMA_70B: 1, # ? / ?
|
||||
LlmModel.OLLAMA_DOLPHIN: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_120B: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_20B: 1,
|
||||
LlmModel.GEMINI_FLASH_1_5: 1,
|
||||
LlmModel.GEMINI_2_5_PRO: 4,
|
||||
LlmModel.MISTRAL_NEMO: 1,
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: 1,
|
||||
|
||||
@@ -1,147 +0,0 @@
|
||||
"""
|
||||
Diagnostics module for monitoring and troubleshooting execution status.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from prisma.models import AgentGraphExecution
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.execution import ExecutionStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunningExecutionDetails(BaseModel):
|
||||
"""Details about a running execution for diagnostics."""
|
||||
|
||||
execution_id: str
|
||||
graph_id: str
|
||||
graph_name: str
|
||||
graph_version: int
|
||||
user_id: str
|
||||
user_email: Optional[str]
|
||||
status: str
|
||||
started_at: Optional[datetime]
|
||||
queue_status: Optional[str] = None
|
||||
|
||||
|
||||
class ExecutionDiagnostics(BaseModel):
|
||||
"""Overall execution diagnostics information."""
|
||||
|
||||
total_running: int
|
||||
total_queued: int
|
||||
total_incomplete: int
|
||||
|
||||
|
||||
async def get_running_executions_details(
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
) -> list[RunningExecutionDetails]:
|
||||
"""
|
||||
Get detailed information about currently running executions.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of executions to return
|
||||
offset: Number of executions to skip
|
||||
|
||||
Returns:
|
||||
List of running execution details
|
||||
|
||||
Raises:
|
||||
Exception: If there's an error retrieving execution details
|
||||
"""
|
||||
try:
|
||||
# Query for running and queued executions
|
||||
executions = await AgentGraphExecution.prisma().find_many(
|
||||
where={
|
||||
"isDeleted": False,
|
||||
"OR": [
|
||||
{"executionStatus": ExecutionStatus.RUNNING},
|
||||
{"executionStatus": ExecutionStatus.QUEUED},
|
||||
],
|
||||
},
|
||||
include={
|
||||
"AgentGraph": True,
|
||||
"User": True,
|
||||
},
|
||||
order={"createdAt": "desc"},
|
||||
skip=offset,
|
||||
take=limit,
|
||||
)
|
||||
|
||||
result = []
|
||||
for exec in executions:
|
||||
# Convert string executionStatus to enum if needed, then to string for response
|
||||
# The database field executionStatus is a string, not an enum
|
||||
status_value = exec.executionStatus
|
||||
if isinstance(status_value, str):
|
||||
# It's already a string, use it directly
|
||||
status_str = status_value
|
||||
else:
|
||||
# It's an enum, get the value
|
||||
status_str = status_value.value
|
||||
|
||||
result.append(
|
||||
RunningExecutionDetails(
|
||||
execution_id=exec.id,
|
||||
graph_id=exec.agentGraphId,
|
||||
graph_name=exec.AgentGraph.name if exec.AgentGraph else "Unknown",
|
||||
graph_version=exec.agentGraphVersion,
|
||||
user_id=exec.userId,
|
||||
user_email=exec.User.email if exec.User else None,
|
||||
status=status_str,
|
||||
started_at=exec.startedAt,
|
||||
queue_status=(
|
||||
exec.queueStatus if hasattr(exec, "queueStatus") else None
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting running execution details: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_execution_diagnostics() -> ExecutionDiagnostics:
|
||||
"""
|
||||
Get overall execution diagnostics information.
|
||||
|
||||
Returns:
|
||||
ExecutionDiagnostics with counts of executions by status
|
||||
"""
|
||||
try:
|
||||
running_count = await AgentGraphExecution.prisma().count(
|
||||
where={
|
||||
"isDeleted": False,
|
||||
"executionStatus": ExecutionStatus.RUNNING,
|
||||
}
|
||||
)
|
||||
|
||||
queued_count = await AgentGraphExecution.prisma().count(
|
||||
where={
|
||||
"isDeleted": False,
|
||||
"executionStatus": ExecutionStatus.QUEUED,
|
||||
}
|
||||
)
|
||||
|
||||
incomplete_count = await AgentGraphExecution.prisma().count(
|
||||
where={
|
||||
"isDeleted": False,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
}
|
||||
)
|
||||
|
||||
return ExecutionDiagnostics(
|
||||
total_running=running_count,
|
||||
total_queued=queued_count,
|
||||
total_incomplete=incomplete_count,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting execution diagnostics: {e}")
|
||||
raise
|
||||
@@ -92,6 +92,18 @@ def get_dynamic_field_description(field_name: str) -> str:
|
||||
return f"Value for {field_name}"
|
||||
|
||||
|
||||
def is_tool_pin(name: str) -> bool:
|
||||
"""Check if a pin name represents a tool connection."""
|
||||
return name.startswith("tools_^_") or name == "tools"
|
||||
|
||||
|
||||
def sanitize_pin_name(name: str) -> str:
|
||||
sanitized_name = extract_base_field_name(name)
|
||||
if is_tool_pin(sanitized_name):
|
||||
return "tools"
|
||||
return sanitized_name
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Dynamic field parsing and merging utilities
|
||||
# --------------------------------------------------------------------------- #
|
||||
@@ -137,30 +149,64 @@ def _tokenise(path: str) -> list[tuple[str, str]] | None:
|
||||
return tokens
|
||||
|
||||
|
||||
def parse_execution_output(output: tuple[str, Any], name: str) -> Any:
|
||||
def parse_execution_output(
|
||||
output_item: tuple[str, Any],
|
||||
link_output_selector: str,
|
||||
sink_node_id: str | None = None,
|
||||
sink_pin_name: str | None = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Retrieve a nested value out of `output` using the flattened *name*.
|
||||
Retrieve a nested value out of `output` using the flattened `link_output_selector`.
|
||||
|
||||
On any failure (wrong name, wrong type, out-of-range, bad path)
|
||||
returns **None**.
|
||||
On any failure (wrong name, wrong type, out-of-range, bad path) returns **None**.
|
||||
|
||||
### Special Case: Tool pins
|
||||
For regular output pins, the `output_item`'s name will simply be the field name, and
|
||||
`link_output_selector` (= the `source_name` of the link) may provide a "selector"
|
||||
used to extract part of the output value and route it through the link
|
||||
to the next node.
|
||||
|
||||
However, for tool pins, it is the other way around: the `output_item`'s name
|
||||
provides the routing information (`tools_^_{sink_node_id}_~_{field_name}`),
|
||||
and the `link_output_selector` is simply `"tools"`
|
||||
(or `"tools_^_{tool_name}_~_{field_name}"` for backward compatibility).
|
||||
|
||||
Args:
|
||||
output: Tuple of (base_name, data) representing a block output entry
|
||||
name: The flattened field name to extract from the output data
|
||||
output_item: Tuple of (base_name, data) representing a block output entry.
|
||||
link_output_selector: The flattened field name to extract from the output data.
|
||||
sink_node_id: Sink node ID, used for tool use routing.
|
||||
sink_pin_name: Sink pin name, used for tool use routing.
|
||||
|
||||
Returns:
|
||||
The value at the specified path, or None if not found/invalid
|
||||
The value at the specified path, or `None` if not found/invalid.
|
||||
"""
|
||||
base_name, data = output
|
||||
output_pin_name, data = output_item
|
||||
|
||||
# Special handling for tool pins
|
||||
if is_tool_pin(link_output_selector) and ( # "tools" or "tools_^_…"
|
||||
output_pin_name.startswith("tools_^_") and "_~_" in output_pin_name
|
||||
):
|
||||
if not (sink_node_id and sink_pin_name):
|
||||
raise ValueError(
|
||||
"sink_node_id and sink_pin_name must be provided for tool pin routing"
|
||||
)
|
||||
|
||||
# Extract routing information from emit key: tools_^_{node_id}_~_{field}
|
||||
selector = output_pin_name[8:] # Remove "tools_^_" prefix
|
||||
target_node_id, target_input_pin = selector.split("_~_", 1)
|
||||
if target_node_id == sink_node_id and target_input_pin == sink_pin_name:
|
||||
return data
|
||||
else:
|
||||
return None
|
||||
|
||||
# Exact match → whole object
|
||||
if name == base_name:
|
||||
if link_output_selector == output_pin_name:
|
||||
return data
|
||||
|
||||
# Must start with the expected name
|
||||
if not name.startswith(base_name):
|
||||
if not link_output_selector.startswith(output_pin_name):
|
||||
return None
|
||||
path = name[len(base_name) :]
|
||||
path = link_output_selector[len(output_pin_name) :]
|
||||
if not path:
|
||||
return None # nothing left to parse
|
||||
|
||||
|
||||
@@ -175,6 +175,10 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
default=None,
|
||||
description="AI-generated summary of what the agent did",
|
||||
)
|
||||
correctness_score: float | None = Field(
|
||||
default=None,
|
||||
description="AI-generated score (0.0-1.0) indicating how well the execution achieved its intended purpose",
|
||||
)
|
||||
|
||||
def to_db(self) -> GraphExecutionStats:
|
||||
return GraphExecutionStats(
|
||||
@@ -187,6 +191,13 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
node_error_count=self.node_error_count,
|
||||
error=self.error,
|
||||
activity_status=self.activity_status,
|
||||
correctness_score=self.correctness_score,
|
||||
)
|
||||
|
||||
def without_activity_features(self) -> "GraphExecutionMeta.Stats":
|
||||
"""Return a copy of stats with activity features (activity_status, correctness_score) set to None."""
|
||||
return self.model_copy(
|
||||
update={"activity_status": None, "correctness_score": None}
|
||||
)
|
||||
|
||||
stats: Stats | None
|
||||
@@ -244,6 +255,7 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
else stats.error
|
||||
),
|
||||
activity_status=stats.activity_status,
|
||||
correctness_score=stats.correctness_score,
|
||||
)
|
||||
if stats
|
||||
else None
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
@@ -17,8 +18,6 @@ from prisma.types import (
|
||||
AgentGraphWhereInput,
|
||||
AgentNodeCreateInput,
|
||||
AgentNodeLinkCreateInput,
|
||||
LibraryAgentWhereInput,
|
||||
StoreListingVersionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from pydantic.fields import computed_field
|
||||
@@ -27,7 +26,7 @@ from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.db import prisma as db
|
||||
from backend.data.dynamic_fields import extract_base_field_name
|
||||
from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
|
||||
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
@@ -37,7 +36,7 @@ from backend.data.model import (
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import type as type_utils
|
||||
from backend.util.exceptions import GraphNotInLibraryError
|
||||
from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
|
||||
@@ -579,9 +578,9 @@ class GraphModel(Graph):
|
||||
nodes_input_masks.get(node.id, {}) if nodes_input_masks else {}
|
||||
)
|
||||
provided_inputs = set(
|
||||
[_sanitize_pin_name(name) for name in node.input_default]
|
||||
[sanitize_pin_name(name) for name in node.input_default]
|
||||
+ [
|
||||
_sanitize_pin_name(link.sink_name)
|
||||
sanitize_pin_name(link.sink_name)
|
||||
for link in input_links.get(node.id, [])
|
||||
]
|
||||
+ ([name for name in node_input_mask] if node_input_mask else [])
|
||||
@@ -697,7 +696,7 @@ class GraphModel(Graph):
|
||||
f"{prefix}, {node.block_id} is invalid block id, available blocks: {blocks}"
|
||||
)
|
||||
|
||||
sanitized_name = _sanitize_pin_name(name)
|
||||
sanitized_name = sanitize_pin_name(name)
|
||||
vals = node.input_default
|
||||
if i == 0:
|
||||
fields = (
|
||||
@@ -711,7 +710,7 @@ class GraphModel(Graph):
|
||||
if block.block_type not in [BlockType.AGENT]
|
||||
else vals.get("input_schema", {}).get("properties", {}).keys()
|
||||
)
|
||||
if sanitized_name not in fields and not _is_tool_pin(name):
|
||||
if sanitized_name not in fields and not is_tool_pin(name):
|
||||
fields_msg = f"Allowed fields: {fields}"
|
||||
raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")
|
||||
|
||||
@@ -751,17 +750,6 @@ class GraphModel(Graph):
|
||||
)
|
||||
|
||||
|
||||
def _is_tool_pin(name: str) -> bool:
|
||||
return name.startswith("tools_^_")
|
||||
|
||||
|
||||
def _sanitize_pin_name(name: str) -> str:
|
||||
sanitized_name = extract_base_field_name(name)
|
||||
if _is_tool_pin(sanitized_name):
|
||||
return "tools"
|
||||
return sanitized_name
|
||||
|
||||
|
||||
class GraphMeta(Graph):
|
||||
user_id: str
|
||||
|
||||
@@ -897,9 +885,11 @@ async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph
|
||||
async def get_graph(
|
||||
graph_id: str,
|
||||
version: int | None = None,
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
for_export: bool = False,
|
||||
include_subgraphs: bool = False,
|
||||
skip_access_check: bool = False,
|
||||
) -> GraphModel | None:
|
||||
"""
|
||||
Retrieves a graph from the DB.
|
||||
@@ -922,19 +912,9 @@ async def get_graph(
|
||||
if graph is None:
|
||||
return None
|
||||
|
||||
if graph.userId != user_id:
|
||||
store_listing_filter: StoreListingVersionWhereInput = {
|
||||
"agentGraphId": graph_id,
|
||||
"isDeleted": False,
|
||||
"submissionStatus": SubmissionStatus.APPROVED,
|
||||
}
|
||||
if version is not None:
|
||||
store_listing_filter["agentGraphVersion"] = version
|
||||
|
||||
if not skip_access_check and graph.userId != user_id:
|
||||
# For access, the graph must be owned by the user or listed in the store
|
||||
if not await StoreListingVersion.prisma().find_first(
|
||||
where=store_listing_filter, order={"agentGraphVersion": "desc"}
|
||||
):
|
||||
if not await is_graph_published_in_marketplace(graph_id, graph.version):
|
||||
return None
|
||||
|
||||
if include_subgraphs or for_export:
|
||||
@@ -978,13 +958,8 @@ async def get_graph_as_admin(
|
||||
# For access, the graph must be owned by the user or listed in the store
|
||||
if graph is None or (
|
||||
graph.userId != user_id
|
||||
and not (
|
||||
await StoreListingVersion.prisma().find_first(
|
||||
where={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": version or graph.version,
|
||||
}
|
||||
)
|
||||
and not await is_graph_published_in_marketplace(
|
||||
graph_id, version or graph.version
|
||||
)
|
||||
):
|
||||
return None
|
||||
@@ -1112,7 +1087,7 @@ async def delete_graph(graph_id: str, user_id: str) -> int:
|
||||
|
||||
|
||||
async def validate_graph_execution_permissions(
|
||||
graph_id: str, user_id: str, graph_version: Optional[int] = None
|
||||
user_id: str, graph_id: str, graph_version: int, is_sub_graph: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Validate that a user has permission to execute a specific graph.
|
||||
@@ -1120,47 +1095,88 @@ async def validate_graph_execution_permissions(
|
||||
This function performs comprehensive authorization checks and raises specific
|
||||
exceptions for different types of failures to enable appropriate error handling.
|
||||
|
||||
## Logic
|
||||
A user can execute a graph if any of these is true:
|
||||
1. They own the graph and some version of it is still listed in their library
|
||||
2. The graph is published in the marketplace and listed in their library
|
||||
3. The graph is published in the marketplace and is being executed as a sub-agent
|
||||
|
||||
Args:
|
||||
graph_id: The ID of the graph to check
|
||||
user_id: The ID of the user
|
||||
graph_version: Optional specific version to check. If None (recommended),
|
||||
performs version-agnostic check allowing execution of any
|
||||
version as long as the graph is in the user's library.
|
||||
This is important for sub-graphs that may reference older
|
||||
versions no longer in the library.
|
||||
graph_version: The version of the graph to check
|
||||
is_sub_graph: Whether this is being executed as a sub-graph.
|
||||
If `True`, the graph isn't required to be in the user's Library.
|
||||
|
||||
Raises:
|
||||
GraphNotInLibraryError: If the graph is not in the user's library (deleted/archived)
|
||||
GraphNotAccessibleError: If the graph is not accessible to the user.
|
||||
GraphNotInLibraryError: If the graph is not in the user's library (deleted/archived).
|
||||
NotAuthorizedError: If the user lacks execution permissions for other reasons
|
||||
"""
|
||||
graph, library_agent = await asyncio.gather(
|
||||
AgentGraph.prisma().find_unique(
|
||||
where={"graphVersionId": {"id": graph_id, "version": graph_version}}
|
||||
),
|
||||
LibraryAgent.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph_id,
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
# Step 1: Check library membership (raises specific GraphNotInLibraryError)
|
||||
where_clause: LibraryAgentWhereInput = {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph_id,
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
}
|
||||
# Step 1: Check if user owns this graph
|
||||
user_owns_graph = graph and graph.userId == user_id
|
||||
|
||||
if graph_version is not None:
|
||||
where_clause["agentGraphVersion"] = graph_version
|
||||
# Step 2: Check if agent is in the library *and not deleted*
|
||||
user_has_in_library = library_agent is not None
|
||||
|
||||
count = await LibraryAgent.prisma().count(where=where_clause)
|
||||
if count == 0:
|
||||
raise GraphNotInLibraryError(
|
||||
f"Graph #{graph_id} is not accessible in your library"
|
||||
# Step 3: Apply permission logic
|
||||
if not (
|
||||
user_owns_graph
|
||||
or await is_graph_published_in_marketplace(graph_id, graph_version)
|
||||
):
|
||||
raise GraphNotAccessibleError(
|
||||
f"You do not have access to graph #{graph_id} v{graph_version}: "
|
||||
"it is not owned by you and not available in the Marketplace"
|
||||
)
|
||||
elif not (user_has_in_library or is_sub_graph):
|
||||
raise GraphNotInLibraryError(f"Graph #{graph_id} is not in your library")
|
||||
|
||||
# Step 2: Check execution-specific permissions (raises generic NotAuthorizedError)
|
||||
# Additional authorization checks beyond library membership:
|
||||
# Step 6: Check execution-specific permissions (raises generic NotAuthorizedError)
|
||||
# Additional authorization checks beyond the above:
|
||||
# 1. Check if user has execution credits (future)
|
||||
# 2. Check if graph is suspended/disabled (future)
|
||||
# 3. Check rate limiting rules (future)
|
||||
# 4. Check organization-level permissions (future)
|
||||
|
||||
# For now, library membership is sufficient for execution permission
|
||||
# Future enhancements can add more granular permission checks here
|
||||
# When adding new checks, raise NotAuthorizedError for non-library issues
|
||||
# For now, the above check logic is sufficient for execution permission.
|
||||
# Future enhancements can add more granular permission checks here.
|
||||
# When adding new checks, raise NotAuthorizedError for non-library issues.
|
||||
|
||||
|
||||
async def is_graph_published_in_marketplace(graph_id: str, graph_version: int) -> bool:
|
||||
"""
|
||||
Check if a graph is published in the marketplace.
|
||||
|
||||
Params:
|
||||
graph_id: The ID of the graph to check
|
||||
graph_version: The version of the graph to check
|
||||
|
||||
Returns:
|
||||
True if the graph is published and approved in the marketplace, False otherwise
|
||||
"""
|
||||
marketplace_listing = await StoreListingVersion.prisma().find_first(
|
||||
where={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"submissionStatus": SubmissionStatus.APPROVED,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
return marketplace_listing is not None
|
||||
|
||||
|
||||
async def create_graph(graph: Graph, user_id: str) -> GraphModel:
|
||||
@@ -1177,7 +1193,7 @@ async def fork_graph(graph_id: str, graph_version: int, user_id: str) -> GraphMo
|
||||
"""
|
||||
Forks a graph by copying it and all its nodes and links to a new graph.
|
||||
"""
|
||||
graph = await get_graph(graph_id, graph_version, user_id, True)
|
||||
graph = await get_graph(graph_id, graph_version, user_id=user_id, for_export=True)
|
||||
if not graph:
|
||||
raise ValueError(f"Graph {graph_id} v{graph_version} not found")
|
||||
|
||||
|
||||
@@ -833,6 +833,10 @@ class GraphExecutionStats(BaseModel):
|
||||
activity_status: Optional[str] = Field(
|
||||
default=None, description="AI-generated summary of what the agent did"
|
||||
)
|
||||
correctness_score: Optional[float] = Field(
|
||||
default=None,
|
||||
description="AI-generated score (0.0-1.0) indicating how well the execution achieved its intended purpose",
|
||||
)
|
||||
|
||||
|
||||
class UserExecutionSummaryStats(BaseModel):
|
||||
|
||||
33
autogpt_platform/backend/backend/data/notification_bus.py
Normal file
33
autogpt_platform/backend/backend/data/notification_bus.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.event_bus import AsyncRedisEventBus
|
||||
from backend.server.model import NotificationPayload
|
||||
from backend.util.settings import Settings
|
||||
|
||||
|
||||
class NotificationEvent(BaseModel):
|
||||
"""Generic notification event destined for websocket delivery."""
|
||||
|
||||
user_id: str
|
||||
payload: NotificationPayload
|
||||
|
||||
|
||||
class AsyncRedisNotificationEventBus(AsyncRedisEventBus[NotificationEvent]):
|
||||
Model = NotificationEvent # type: ignore
|
||||
|
||||
@property
|
||||
def event_bus_name(self) -> str:
|
||||
return Settings().config.notification_event_bus_name
|
||||
|
||||
async def publish(self, event: NotificationEvent) -> None:
|
||||
await self.publish_event(event, event.user_id)
|
||||
|
||||
async def listen(
|
||||
self, user_id: str = "*"
|
||||
) -> AsyncGenerator[NotificationEvent, None]:
|
||||
async for event in self.listen_events(user_id):
|
||||
yield event
|
||||
@@ -11,6 +11,11 @@ from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
|
||||
from backend.data.block import get_blocks
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.notification_bus import (
|
||||
AsyncRedisNotificationEventBus,
|
||||
NotificationEvent,
|
||||
)
|
||||
from backend.server.model import OnboardingNotificationPayload
|
||||
from backend.server.v2.store.model import StoreAgentDetails
|
||||
from backend.util.cache import cached
|
||||
from backend.util.json import SafeJson
|
||||
@@ -82,22 +87,10 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
||||
update["completedSteps"] = list(
|
||||
set(data.completedSteps + onboarding.completedSteps)
|
||||
)
|
||||
for step in (
|
||||
OnboardingStep.AGENT_NEW_RUN,
|
||||
OnboardingStep.MARKETPLACE_VISIT,
|
||||
OnboardingStep.MARKETPLACE_ADD_AGENT,
|
||||
OnboardingStep.MARKETPLACE_RUN_AGENT,
|
||||
OnboardingStep.BUILDER_SAVE_AGENT,
|
||||
OnboardingStep.RE_RUN_AGENT,
|
||||
OnboardingStep.SCHEDULE_AGENT,
|
||||
OnboardingStep.RUN_AGENTS,
|
||||
OnboardingStep.RUN_3_DAYS,
|
||||
OnboardingStep.TRIGGER_WEBHOOK,
|
||||
OnboardingStep.RUN_14_DAYS,
|
||||
OnboardingStep.RUN_AGENTS_100,
|
||||
):
|
||||
if step in data.completedSteps:
|
||||
await reward_user(user_id, step, onboarding)
|
||||
for step in data.completedSteps:
|
||||
if step not in onboarding.completedSteps:
|
||||
await _reward_user(user_id, onboarding, step)
|
||||
await _send_onboarding_notification(user_id, step)
|
||||
if data.walletShown:
|
||||
update["walletShown"] = data.walletShown
|
||||
if data.notified is not None:
|
||||
@@ -130,7 +123,7 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
||||
)
|
||||
|
||||
|
||||
async def reward_user(user_id: str, step: OnboardingStep, onboarding: UserOnboarding):
|
||||
async def _reward_user(user_id: str, onboarding: UserOnboarding, step: OnboardingStep):
|
||||
reward = 0
|
||||
match step:
|
||||
# Reward user when they clicked New Run during onboarding
|
||||
@@ -180,20 +173,32 @@ async def reward_user(user_id: str, step: OnboardingStep, onboarding: UserOnboar
|
||||
)
|
||||
|
||||
|
||||
async def complete_webhook_trigger_step(user_id: str):
|
||||
async def complete_onboarding_step(user_id: str, step: OnboardingStep):
|
||||
"""
|
||||
Completes the TRIGGER_WEBHOOK onboarding step for the user if not already completed.
|
||||
Completes the specified onboarding step for the user if not already completed.
|
||||
"""
|
||||
|
||||
onboarding = await get_user_onboarding(user_id)
|
||||
if OnboardingStep.TRIGGER_WEBHOOK not in onboarding.completedSteps:
|
||||
if step not in onboarding.completedSteps:
|
||||
await update_user_onboarding(
|
||||
user_id,
|
||||
UserOnboardingUpdate(
|
||||
completedSteps=onboarding.completedSteps
|
||||
+ [OnboardingStep.TRIGGER_WEBHOOK]
|
||||
),
|
||||
UserOnboardingUpdate(completedSteps=onboarding.completedSteps + [step]),
|
||||
)
|
||||
await _send_onboarding_notification(user_id, step)
|
||||
|
||||
|
||||
async def _send_onboarding_notification(user_id: str, step: OnboardingStep):
|
||||
"""
|
||||
Sends an onboarding notification to the user for the specified step.
|
||||
"""
|
||||
payload = OnboardingNotificationPayload(
|
||||
type="onboarding",
|
||||
event="step_completed",
|
||||
step=step.value,
|
||||
)
|
||||
await AsyncRedisNotificationEventBus().publish(
|
||||
NotificationEvent(user_id=user_id, payload=payload)
|
||||
)
|
||||
|
||||
|
||||
def clean_and_split(text: str) -> list[str]:
|
||||
|
||||
@@ -13,12 +13,11 @@ except ImportError:
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.llm import LlmModel, llm_call
|
||||
from backend.blocks.llm import AIStructuredResponseGeneratorBlock, LlmModel
|
||||
from backend.data.block import get_block
|
||||
from backend.data.execution import ExecutionStatus, NodeExecutionResult
|
||||
from backend.data.model import APIKeyCredentials, GraphExecutionStats
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
@@ -70,6 +69,13 @@ class NodeRelation(TypedDict):
|
||||
sink_block_name: NotRequired[str] # Optional, only set if block exists
|
||||
|
||||
|
||||
class ActivityStatusResponse(TypedDict):
|
||||
"""Type definition for structured activity status response."""
|
||||
|
||||
activity_status: str
|
||||
correctness_score: float
|
||||
|
||||
|
||||
def _truncate_uuid(uuid_str: str) -> str:
|
||||
"""Truncate UUID to first segment to reduce payload size."""
|
||||
if not uuid_str:
|
||||
@@ -85,9 +91,11 @@ async def generate_activity_status_for_execution(
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
user_id: str,
|
||||
execution_status: ExecutionStatus | None = None,
|
||||
) -> str | None:
|
||||
model_name: str = "gpt-4o-mini",
|
||||
skip_feature_flag: bool = False,
|
||||
) -> ActivityStatusResponse | None:
|
||||
"""
|
||||
Generate an AI-based activity status summary for a graph execution.
|
||||
Generate an AI-based activity status summary and correctness assessment for a graph execution.
|
||||
|
||||
This function handles all the data collection and AI generation logic,
|
||||
keeping the manager integration simple.
|
||||
@@ -102,10 +110,13 @@ async def generate_activity_status_for_execution(
|
||||
execution_status: The overall execution status (COMPLETED, FAILED, TERMINATED)
|
||||
|
||||
Returns:
|
||||
AI-generated activity status string, or None if feature is disabled
|
||||
AI-generated activity status response with activity_status and correctness_status,
|
||||
or None if feature is disabled
|
||||
"""
|
||||
# Check LaunchDarkly feature flag for AI activity status generation with full context support
|
||||
if not await is_feature_enabled(Flag.AI_ACTIVITY_STATUS, user_id):
|
||||
if not skip_feature_flag and not await is_feature_enabled(
|
||||
Flag.AI_ACTIVITY_STATUS, user_id
|
||||
):
|
||||
logger.debug("AI activity status generation is disabled via LaunchDarkly")
|
||||
return None
|
||||
|
||||
@@ -141,16 +152,27 @@ async def generate_activity_status_for_execution(
|
||||
execution_status,
|
||||
)
|
||||
|
||||
# Prepare prompt for AI
|
||||
# Prepare prompt for AI with structured output requirements
|
||||
prompt = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are an AI assistant summarizing what you just did for a user in simple, friendly language. "
|
||||
"Write from the user's perspective about what they accomplished, NOT about technical execution details. "
|
||||
"Focus on the ACTUAL TASK the user wanted done, not the internal workflow steps. "
|
||||
"Avoid technical terms like 'workflow', 'execution', 'components', 'nodes', 'processing', etc. "
|
||||
"Keep it to 3 sentences maximum. Be conversational and human-friendly.\n\n"
|
||||
"You are an AI assistant analyzing what an agent execution accomplished and whether it worked correctly. "
|
||||
"You need to provide both a user-friendly summary AND a correctness assessment.\n\n"
|
||||
"FOR THE ACTIVITY STATUS:\n"
|
||||
"- Write from the user's perspective about what they accomplished, NOT about technical execution details\n"
|
||||
"- Focus on the ACTUAL TASK the user wanted done, not the internal workflow steps\n"
|
||||
"- Avoid technical terms like 'workflow', 'execution', 'components', 'nodes', 'processing', etc.\n"
|
||||
"- Keep it to 3 sentences maximum. Be conversational and human-friendly\n\n"
|
||||
"FOR THE CORRECTNESS SCORE:\n"
|
||||
"- Provide a score from 0.0 to 1.0 indicating how well the execution achieved its intended purpose\n"
|
||||
"- Use this scoring guide:\n"
|
||||
" 0.0-0.2: Failure - The result clearly did not meet the task requirements\n"
|
||||
" 0.2-0.4: Poor - Major issues; only small parts of the goal were achieved\n"
|
||||
" 0.4-0.6: Partial Success - Some objectives met, but with noticeable gaps or inaccuracies\n"
|
||||
" 0.6-0.8: Mostly Successful - Largely achieved the intended outcome, with minor flaws\n"
|
||||
" 0.8-1.0: Success - Fully met or exceeded the task requirements\n"
|
||||
"- Base the score on actual outputs produced, not just technical completion\n\n"
|
||||
"UNDERSTAND THE INTENDED PURPOSE:\n"
|
||||
"- FIRST: Read the graph description carefully to understand what the user wanted to accomplish\n"
|
||||
"- The graph name and description tell you the main goal/intention of this automation\n"
|
||||
@@ -186,7 +208,7 @@ async def generate_activity_status_for_execution(
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"A user ran '{graph_name}' to accomplish something. Based on this execution data, "
|
||||
f"write what they achieved in simple, user-friendly terms:\n\n"
|
||||
f"provide both an activity summary and correctness assessment:\n\n"
|
||||
f"{json.dumps(execution_data, indent=2)}\n\n"
|
||||
"ANALYSIS CHECKLIST:\n"
|
||||
"1. READ graph_info.description FIRST - this tells you what the user intended to accomplish\n"
|
||||
@@ -203,13 +225,20 @@ async def generate_activity_status_for_execution(
|
||||
"- If description mentions 'content generation' → was content actually generated?\n"
|
||||
"- If description mentions 'social media posting' → were posts actually made?\n"
|
||||
"- Match the outputs to the stated intention, not just technical completion\n\n"
|
||||
"Write 1-3 sentences about what the user accomplished, such as:\n"
|
||||
"PROVIDE:\n"
|
||||
"activity_status: 1-3 sentences about what the user accomplished, such as:\n"
|
||||
"- 'I analyzed your resume and provided detailed feedback for the IT industry.'\n"
|
||||
"- 'I couldn't complete the task because critical steps failed to produce any results.'\n"
|
||||
"- 'I failed to generate the content you requested due to missing API access.'\n"
|
||||
"- 'I extracted key information from your documents and organized it into a summary.'\n"
|
||||
"- 'The task failed because the blog post creation step didn't produce any output.'\n\n"
|
||||
"BE CRITICAL: If the graph's intended purpose (from description) wasn't achieved, report this as a failure even if status is 'completed'."
|
||||
"correctness_score: A float score from 0.0 to 1.0 based on how well the intended purpose was achieved:\n"
|
||||
"- 0.0-0.2: Failure (didn't meet requirements)\n"
|
||||
"- 0.2-0.4: Poor (major issues, minimal achievement)\n"
|
||||
"- 0.4-0.6: Partial Success (some objectives met with gaps)\n"
|
||||
"- 0.6-0.8: Mostly Successful (largely achieved with minor flaws)\n"
|
||||
"- 0.8-1.0: Success (fully met or exceeded requirements)\n\n"
|
||||
"BE CRITICAL: If the graph's intended purpose (from description) wasn't achieved, use a low score (0.0-0.4) even if status is 'completed'."
|
||||
),
|
||||
},
|
||||
]
|
||||
@@ -227,17 +256,61 @@ async def generate_activity_status_for_execution(
|
||||
title="System OpenAI",
|
||||
)
|
||||
|
||||
# Make LLM call using current event loop
|
||||
activity_status = await _call_llm_direct(credentials, prompt)
|
||||
# Define expected response format
|
||||
expected_format = {
|
||||
"activity_status": "A user-friendly 1-3 sentence summary of what was accomplished",
|
||||
"correctness_score": "Float score from 0.0 to 1.0 indicating how well the execution achieved its intended purpose",
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
f"Generated activity status for {graph_exec_id}: {activity_status}"
|
||||
# Use existing AIStructuredResponseGeneratorBlock for structured LLM call
|
||||
structured_block = AIStructuredResponseGeneratorBlock()
|
||||
|
||||
# Convert credentials to the format expected by AIStructuredResponseGeneratorBlock
|
||||
credentials_input = {
|
||||
"provider": credentials.provider,
|
||||
"id": credentials.id,
|
||||
"type": credentials.type,
|
||||
"title": credentials.title,
|
||||
}
|
||||
|
||||
structured_input = AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt=prompt[1]["content"], # User prompt content
|
||||
sys_prompt=prompt[0]["content"], # System prompt content
|
||||
expected_format=expected_format,
|
||||
model=LlmModel(model_name),
|
||||
credentials=credentials_input, # type: ignore
|
||||
max_tokens=150,
|
||||
retry=3,
|
||||
)
|
||||
|
||||
return activity_status
|
||||
# Execute the structured LLM call
|
||||
async for output_name, output_data in structured_block.run(
|
||||
structured_input, credentials=credentials
|
||||
):
|
||||
if output_name == "response":
|
||||
response = output_data
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("Failed to get response from structured LLM call")
|
||||
|
||||
# Create typed response with validation
|
||||
correctness_score = float(response["correctness_score"])
|
||||
# Clamp score to valid range
|
||||
correctness_score = max(0.0, min(1.0, correctness_score))
|
||||
|
||||
activity_response: ActivityStatusResponse = {
|
||||
"activity_status": response["activity_status"],
|
||||
"correctness_score": correctness_score,
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
f"Generated activity status for {graph_exec_id}: {activity_response}"
|
||||
)
|
||||
|
||||
return activity_response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
logger.exception(
|
||||
f"Failed to generate activity status for execution {graph_exec_id}: {str(e)}"
|
||||
)
|
||||
return None
|
||||
@@ -448,23 +521,3 @@ def _build_execution_summary(
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@func_retry
|
||||
async def _call_llm_direct(
|
||||
credentials: APIKeyCredentials, prompt: list[dict[str, str]]
|
||||
) -> str:
|
||||
"""Make direct LLM call."""
|
||||
|
||||
response = await llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=LlmModel.GPT4O_MINI,
|
||||
prompt=prompt,
|
||||
max_tokens=150,
|
||||
compress_prompt_to_fit=True,
|
||||
)
|
||||
|
||||
if response and response.response:
|
||||
return response.response.strip()
|
||||
else:
|
||||
return "Unable to generate activity summary"
|
||||
|
||||
@@ -7,12 +7,11 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.llm import LLMResponse
|
||||
from backend.blocks.llm import LlmModel, LLMResponse
|
||||
from backend.data.execution import ExecutionStatus, NodeExecutionResult
|
||||
from backend.data.model import GraphExecutionStats
|
||||
from backend.executor.activity_status_generator import (
|
||||
_build_execution_summary,
|
||||
_call_llm_direct,
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
|
||||
@@ -373,25 +372,24 @@ class TestLLMCall:
|
||||
"""Tests for LLM calling functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_llm_direct_success(self):
|
||||
"""Test successful LLM call."""
|
||||
async def test_structured_llm_call_success(self):
|
||||
"""Test successful structured LLM call."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.llm import AIStructuredResponseGeneratorBlock
|
||||
from backend.data.model import APIKeyCredentials
|
||||
|
||||
mock_response = LLMResponse(
|
||||
raw_response={},
|
||||
prompt=[],
|
||||
response="Agent successfully processed user input and generated response.",
|
||||
tool_calls=None,
|
||||
prompt_tokens=50,
|
||||
completion_tokens=20,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.activity_status_generator.llm_call"
|
||||
) as mock_llm_call:
|
||||
mock_llm_call.return_value = mock_response
|
||||
with patch("backend.blocks.llm.llm_call") as mock_llm_call, patch(
|
||||
"backend.blocks.llm.secrets.token_hex", return_value="test123"
|
||||
):
|
||||
mock_llm_call.return_value = LLMResponse(
|
||||
raw_response={},
|
||||
prompt=[],
|
||||
response='<json_output id="test123">{"activity_status": "Test completed successfully", "correctness_score": 0.9}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=50,
|
||||
completion_tokens=20,
|
||||
)
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
id="test",
|
||||
@@ -401,26 +399,61 @@ class TestLLMCall:
|
||||
)
|
||||
|
||||
prompt = [{"role": "user", "content": "Test prompt"}]
|
||||
expected_format = {
|
||||
"activity_status": "User-friendly summary",
|
||||
"correctness_score": "Float score from 0.0 to 1.0",
|
||||
}
|
||||
|
||||
result = await _call_llm_direct(credentials, prompt)
|
||||
# Create structured block and input
|
||||
structured_block = AIStructuredResponseGeneratorBlock()
|
||||
credentials_input = {
|
||||
"provider": credentials.provider,
|
||||
"id": credentials.id,
|
||||
"type": credentials.type,
|
||||
"title": credentials.title,
|
||||
}
|
||||
|
||||
assert (
|
||||
result
|
||||
== "Agent successfully processed user input and generated response."
|
||||
structured_input = AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt=prompt[0]["content"],
|
||||
expected_format=expected_format,
|
||||
model=LlmModel.GPT4O_MINI,
|
||||
credentials=credentials_input, # type: ignore
|
||||
)
|
||||
mock_llm_call.assert_called_once()
|
||||
|
||||
# Execute the structured LLM call
|
||||
result = None
|
||||
async for output_name, output_data in structured_block.run(
|
||||
structured_input, credentials=credentials
|
||||
):
|
||||
if output_name == "response":
|
||||
result = output_data
|
||||
break
|
||||
|
||||
assert result is not None
|
||||
assert result["activity_status"] == "Test completed successfully"
|
||||
assert result["correctness_score"] == 0.9
|
||||
mock_llm_call.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_llm_direct_no_response(self):
|
||||
"""Test LLM call with no response."""
|
||||
async def test_structured_llm_call_validation_error(self):
|
||||
"""Test structured LLM call with validation error."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.llm import AIStructuredResponseGeneratorBlock
|
||||
from backend.data.model import APIKeyCredentials
|
||||
|
||||
with patch(
|
||||
"backend.executor.activity_status_generator.llm_call"
|
||||
) as mock_llm_call:
|
||||
mock_llm_call.return_value = None
|
||||
with patch("backend.blocks.llm.llm_call") as mock_llm_call, patch(
|
||||
"backend.blocks.llm.secrets.token_hex", return_value="test123"
|
||||
):
|
||||
# Return invalid JSON that will fail validation (missing required field)
|
||||
mock_llm_call.return_value = LLMResponse(
|
||||
raw_response={},
|
||||
prompt=[],
|
||||
response='<json_output id="test123">{"activity_status": "Test completed successfully"}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=50,
|
||||
completion_tokens=20,
|
||||
)
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
id="test",
|
||||
@@ -430,10 +463,36 @@ class TestLLMCall:
|
||||
)
|
||||
|
||||
prompt = [{"role": "user", "content": "Test prompt"}]
|
||||
expected_format = {
|
||||
"activity_status": "User-friendly summary",
|
||||
"correctness_score": "Float score from 0.0 to 1.0",
|
||||
}
|
||||
|
||||
result = await _call_llm_direct(credentials, prompt)
|
||||
# Create structured block and input
|
||||
structured_block = AIStructuredResponseGeneratorBlock()
|
||||
credentials_input = {
|
||||
"provider": credentials.provider,
|
||||
"id": credentials.id,
|
||||
"type": credentials.type,
|
||||
"title": credentials.title,
|
||||
}
|
||||
|
||||
assert result == "Unable to generate activity summary"
|
||||
structured_input = AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt=prompt[0]["content"],
|
||||
expected_format=expected_format,
|
||||
model=LlmModel.GPT4O_MINI,
|
||||
credentials=credentials_input, # type: ignore
|
||||
retry=1, # Use fewer retries for faster test
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
Exception
|
||||
): # AIStructuredResponseGeneratorBlock may raise different exceptions
|
||||
async for output_name, output_data in structured_block.run(
|
||||
structured_input, credentials=credentials
|
||||
):
|
||||
if output_name == "response":
|
||||
break
|
||||
|
||||
|
||||
class TestGenerateActivityStatusForExecution:
|
||||
@@ -461,17 +520,25 @@ class TestGenerateActivityStatusForExecution:
|
||||
) as mock_get_block, patch(
|
||||
"backend.executor.activity_status_generator.Settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.activity_status_generator._call_llm_direct"
|
||||
) as mock_llm, patch(
|
||||
"backend.executor.activity_status_generator.AIStructuredResponseGeneratorBlock"
|
||||
) as mock_structured_block, patch(
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
mock_llm.return_value = (
|
||||
"I analyzed your data and provided the requested insights."
|
||||
)
|
||||
|
||||
# Mock the structured block to return our expected response
|
||||
mock_instance = mock_structured_block.return_value
|
||||
|
||||
async def mock_run(*args, **kwargs):
|
||||
yield "response", {
|
||||
"activity_status": "I analyzed your data and provided the requested insights.",
|
||||
"correctness_score": 0.85,
|
||||
}
|
||||
|
||||
mock_instance.run = mock_run
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
@@ -482,11 +549,16 @@ class TestGenerateActivityStatusForExecution:
|
||||
user_id="test_user",
|
||||
)
|
||||
|
||||
assert result == "I analyzed your data and provided the requested insights."
|
||||
assert result is not None
|
||||
assert (
|
||||
result["activity_status"]
|
||||
== "I analyzed your data and provided the requested insights."
|
||||
)
|
||||
assert result["correctness_score"] == 0.85
|
||||
mock_db_client.get_node_executions.assert_called_once()
|
||||
mock_db_client.get_graph_metadata.assert_called_once()
|
||||
mock_db_client.get_graph.assert_called_once()
|
||||
mock_llm.assert_called_once()
|
||||
mock_structured_block.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_status_feature_disabled(self, mock_execution_stats):
|
||||
@@ -574,15 +646,25 @@ class TestGenerateActivityStatusForExecution:
|
||||
) as mock_get_block, patch(
|
||||
"backend.executor.activity_status_generator.Settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.activity_status_generator._call_llm_direct"
|
||||
) as mock_llm, patch(
|
||||
"backend.executor.activity_status_generator.AIStructuredResponseGeneratorBlock"
|
||||
) as mock_structured_block, patch(
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
mock_llm.return_value = "Agent completed execution."
|
||||
|
||||
# Mock the structured block to return our expected response
|
||||
mock_instance = mock_structured_block.return_value
|
||||
|
||||
async def mock_run(*args, **kwargs):
|
||||
yield "response", {
|
||||
"activity_status": "Agent completed execution.",
|
||||
"correctness_score": 0.8,
|
||||
}
|
||||
|
||||
mock_instance.run = mock_run
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
@@ -593,10 +675,11 @@ class TestGenerateActivityStatusForExecution:
|
||||
user_id="test_user",
|
||||
)
|
||||
|
||||
assert result == "Agent completed execution."
|
||||
# Should use fallback graph name in prompt
|
||||
call_args = mock_llm.call_args[0][1] # prompt argument
|
||||
assert "Graph test_graph" in call_args[1]["content"]
|
||||
assert result is not None
|
||||
assert result["activity_status"] == "Agent completed execution."
|
||||
assert result["correctness_score"] == 0.8
|
||||
# The structured block should have been instantiated
|
||||
assert mock_structured_block.called
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
@@ -626,8 +709,8 @@ class TestIntegration:
|
||||
) as mock_get_block, patch(
|
||||
"backend.executor.activity_status_generator.Settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.activity_status_generator.llm_call"
|
||||
) as mock_llm_call, patch(
|
||||
"backend.executor.activity_status_generator.AIStructuredResponseGeneratorBlock"
|
||||
) as mock_structured_block, patch(
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
@@ -635,15 +718,16 @@ class TestIntegration:
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
|
||||
mock_response = LLMResponse(
|
||||
raw_response={},
|
||||
prompt=[],
|
||||
response=expected_activity,
|
||||
tool_calls=None,
|
||||
prompt_tokens=100,
|
||||
completion_tokens=30,
|
||||
)
|
||||
mock_llm_call.return_value = mock_response
|
||||
# Mock the structured block to return our expected response
|
||||
mock_instance = mock_structured_block.return_value
|
||||
|
||||
async def mock_run(*args, **kwargs):
|
||||
yield "response", {
|
||||
"activity_status": expected_activity,
|
||||
"correctness_score": 0.3, # Low score since there was a failure
|
||||
}
|
||||
|
||||
mock_instance.run = mock_run
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
@@ -654,24 +738,14 @@ class TestIntegration:
|
||||
user_id="test_user",
|
||||
)
|
||||
|
||||
assert result == expected_activity
|
||||
assert result is not None
|
||||
assert result["activity_status"] == expected_activity
|
||||
assert result["correctness_score"] == 0.3
|
||||
|
||||
# Verify the correct data was passed to LLM
|
||||
llm_call_args = mock_llm_call.call_args
|
||||
prompt = llm_call_args[1]["prompt"]
|
||||
|
||||
# Check system prompt
|
||||
assert prompt[0]["role"] == "system"
|
||||
assert "user's perspective" in prompt[0]["content"]
|
||||
|
||||
# Check user prompt contains expected data
|
||||
user_content = prompt[1]["content"]
|
||||
assert "Test Integration Agent" in user_content
|
||||
assert "user-friendly terms" in user_content.lower()
|
||||
|
||||
# Verify that execution data is present in the prompt
|
||||
assert "{" in user_content # Should contain JSON data
|
||||
assert "overall_status" in user_content
|
||||
# Verify the structured block was called
|
||||
assert mock_structured_block.called
|
||||
# The structured block should have been instantiated
|
||||
mock_structured_block.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_integration_with_disabled_feature(
|
||||
|
||||
@@ -54,7 +54,9 @@ from backend.executor.activity_status_generator import (
|
||||
from backend.executor.utils import (
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_EXCHANGE,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_ROUTING_KEY,
|
||||
CancelExecutionEvent,
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
@@ -320,7 +322,9 @@ async def _enqueue_next_nodes(
|
||||
next_node_id = node_link.sink_id
|
||||
|
||||
output_name, _ = output
|
||||
next_data = parse_execution_output(output, next_output_name)
|
||||
next_data = parse_execution_output(
|
||||
output, next_output_name, next_node_id, next_input_name
|
||||
)
|
||||
if next_data is None and output_name != next_output_name:
|
||||
return enqueued_executions
|
||||
next_node = await db_client.get_node(next_node_id)
|
||||
@@ -694,7 +698,7 @@ class ExecutionProcessor:
|
||||
exec_meta.status = status
|
||||
|
||||
# Activity status handling
|
||||
activity_status = asyncio.run_coroutine_threadsafe(
|
||||
activity_response = asyncio.run_coroutine_threadsafe(
|
||||
generate_activity_status_for_execution(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
graph_id=graph_exec.graph_id,
|
||||
@@ -706,18 +710,21 @@ class ExecutionProcessor:
|
||||
),
|
||||
self.node_execution_loop,
|
||||
).result(timeout=60.0)
|
||||
if activity_status is not None:
|
||||
exec_stats.activity_status = activity_status
|
||||
log_metadata.info(f"Generated activity status: {activity_status}")
|
||||
if activity_response is not None:
|
||||
exec_stats.activity_status = activity_response["activity_status"]
|
||||
exec_stats.correctness_score = activity_response["correctness_score"]
|
||||
log_metadata.info(
|
||||
f"Generated activity status: {activity_response['activity_status']} "
|
||||
f"(correctness: {activity_response['correctness_score']:.2f})"
|
||||
)
|
||||
else:
|
||||
log_metadata.debug(
|
||||
"Activity status generation disabled, not setting field"
|
||||
"Activity status generation disabled, not setting fields"
|
||||
)
|
||||
|
||||
finally:
|
||||
# Communication handling
|
||||
self._handle_agent_run_notif(db_client, graph_exec, exec_stats)
|
||||
|
||||
finally:
|
||||
update_graph_execution_state(
|
||||
db_client=db_client,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
@@ -1459,14 +1466,43 @@ class ExecutionManager(AppProcess):
|
||||
|
||||
@func_retry
|
||||
def _ack_message(reject: bool, requeue: bool):
|
||||
"""Acknowledge or reject the message based on execution status."""
|
||||
"""
|
||||
Acknowledge or reject the message based on execution status.
|
||||
|
||||
Args:
|
||||
reject: Whether to reject the message
|
||||
requeue: Whether to requeue the message
|
||||
"""
|
||||
|
||||
# Connection can be lost, so always get a fresh channel
|
||||
channel = self.run_client.get_channel()
|
||||
if reject:
|
||||
channel.connection.add_callback_threadsafe(
|
||||
lambda: channel.basic_nack(delivery_tag, requeue=requeue)
|
||||
)
|
||||
if requeue and settings.config.requeue_by_republishing:
|
||||
# Send rejected message to back of queue using republishing
|
||||
def _republish_to_back():
|
||||
try:
|
||||
# First republish to back of queue
|
||||
self.run_client.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
message=body.decode(), # publish_message expects string, not bytes
|
||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||
)
|
||||
# Then reject without requeue (message already republished)
|
||||
channel.basic_nack(delivery_tag, requeue=False)
|
||||
logger.info("Message requeued to back of queue")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.service_name}] Failed to requeue message to back: {e}"
|
||||
)
|
||||
# Fall back to traditional requeue on failure
|
||||
channel.basic_nack(delivery_tag, requeue=True)
|
||||
|
||||
channel.connection.add_callback_threadsafe(_republish_to_back)
|
||||
else:
|
||||
# Traditional requeue (goes to front) or no requeue
|
||||
channel.connection.add_callback_threadsafe(
|
||||
lambda: channel.basic_nack(delivery_tag, requeue=requeue)
|
||||
)
|
||||
else:
|
||||
channel.connection.add_callback_threadsafe(
|
||||
lambda: channel.basic_ack(delivery_tag)
|
||||
|
||||
@@ -178,8 +178,9 @@ async def _execute_graph(**kwargs):
|
||||
|
||||
async def _cleanup_orphaned_schedules_for_graph(graph_id: str, user_id: str) -> None:
|
||||
"""
|
||||
Clean up orphaned schedules for a specific graph when execution fails with GraphNotInLibraryError.
|
||||
This happens when an agent is deleted but schedules still exist.
|
||||
Clean up orphaned schedules for a specific graph when execution fails with GraphNotAccessibleError.
|
||||
This happens when an agent is pulled from the Marketplace or deleted
|
||||
but schedules still exist.
|
||||
"""
|
||||
# Use scheduler client to access the scheduler service
|
||||
scheduler_client = get_scheduler_client()
|
||||
|
||||
@@ -477,6 +477,7 @@ async def validate_and_construct_node_execution_input(
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
is_sub_graph: bool = False,
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
|
||||
"""
|
||||
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
|
||||
@@ -489,6 +490,7 @@ async def validate_and_construct_node_execution_input(
|
||||
graph_version: The version of the graph to use.
|
||||
graph_credentials_inputs: Credentials inputs to use.
|
||||
nodes_input_masks: Node inputs to use.
|
||||
is_sub_graph: Whether this is a sub-graph execution.
|
||||
|
||||
Returns:
|
||||
GraphModel: Full graph object for the given `graph_id`.
|
||||
@@ -510,19 +512,20 @@ async def validate_and_construct_node_execution_input(
|
||||
user_id=user_id,
|
||||
version=graph_version,
|
||||
include_subgraphs=True,
|
||||
# Execution/access permission is checked by validate_graph_execution_permissions
|
||||
skip_access_check=True,
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(f"Graph #{graph_id} not found.")
|
||||
|
||||
# Validate that the user has permission to execute this graph
|
||||
# This checks both library membership and execution permissions,
|
||||
# raising specific exceptions for appropriate error handling
|
||||
# Note: Version-agnostic check to allow execution of graphs that reference
|
||||
# older versions of sub-graphs that may no longer be in the library
|
||||
# raising specific exceptions for appropriate error handling.
|
||||
await gdb.validate_graph_execution_permissions(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
# graph_version omitted for version-agnostic permission check
|
||||
graph_id=graph.id,
|
||||
graph_version=graph.version,
|
||||
is_sub_graph=is_sub_graph,
|
||||
)
|
||||
|
||||
nodes_input_masks = _merge_nodes_input_masks(
|
||||
@@ -756,6 +759,7 @@ async def add_graph_execution(
|
||||
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
parent_graph_exec_id: Optional[str] = None,
|
||||
is_sub_graph: bool = False,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Adds a graph execution to the queue and returns the execution entry.
|
||||
@@ -770,6 +774,7 @@ async def add_graph_execution(
|
||||
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
|
||||
nodes_input_masks: Node inputs to use in the execution.
|
||||
parent_graph_exec_id: The ID of the parent graph execution (for nested executions).
|
||||
is_sub_graph: Whether this is a sub-graph execution.
|
||||
Returns:
|
||||
GraphExecutionEntry: The entry for the graph execution.
|
||||
Raises:
|
||||
@@ -788,6 +793,7 @@ async def add_graph_execution(
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=graph_credentials_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
is_sub_graph=is_sub_graph,
|
||||
)
|
||||
)
|
||||
graph_exec = None
|
||||
|
||||
@@ -111,6 +111,35 @@ def test_parse_execution_output():
|
||||
parse_execution_output(output, "result_@_attr_$_0_#_key") is None
|
||||
) # Should fail at @_attr
|
||||
|
||||
# Test case 7: Tool pin routing with matching node ID and pin name
|
||||
output = ("tools_^_node123_~_query", "search term")
|
||||
assert parse_execution_output(output, "tools", "node123", "query") == "search term"
|
||||
|
||||
# Test case 8: Tool pin routing with node ID mismatch
|
||||
output = ("tools_^_node123_~_query", "search term")
|
||||
assert parse_execution_output(output, "tools", "node456", "query") is None
|
||||
|
||||
# Test case 9: Tool pin routing with pin name mismatch
|
||||
output = ("tools_^_node123_~_query", "search term")
|
||||
assert parse_execution_output(output, "tools", "node123", "different_pin") is None
|
||||
|
||||
# Test case 10: Tool pin routing with complex field names
|
||||
output = ("tools_^_node789_~_nested_field", {"key": "value"})
|
||||
result = parse_execution_output(output, "tools", "node789", "nested_field")
|
||||
assert result == {"key": "value"}
|
||||
|
||||
# Test case 11: Tool pin routing missing required parameters should raise error
|
||||
output = ("tools_^_node123_~_query", "search term")
|
||||
try:
|
||||
parse_execution_output(output, "tools", "node123") # Missing sink_pin_name
|
||||
assert False, "Should have raised ValueError"
|
||||
except ValueError as e:
|
||||
assert "must be provided for tool pin routing" in str(e)
|
||||
|
||||
# Test case 12: Non-tool pin with similar pattern should use normal logic
|
||||
output = ("tools_^_node123_~_query", "search term")
|
||||
assert parse_execution_output(output, "different_name", "node123", "query") is None
|
||||
|
||||
|
||||
def test_merge_execution_input():
|
||||
# Test case for basic list extraction
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from typing import Dict, Set
|
||||
|
||||
from fastapi import WebSocket
|
||||
@@ -7,7 +8,7 @@ from backend.data.execution import (
|
||||
GraphExecutionEvent,
|
||||
NodeExecutionEvent,
|
||||
)
|
||||
from backend.server.model import WSMessage, WSMethod
|
||||
from backend.server.model import NotificationPayload, WSMessage, WSMethod
|
||||
|
||||
_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
|
||||
ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
|
||||
@@ -19,15 +20,24 @@ class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections: Set[WebSocket] = set()
|
||||
self.subscriptions: Dict[str, Set[WebSocket]] = {}
|
||||
self.user_connections: Dict[str, Set[WebSocket]] = {}
|
||||
|
||||
async def connect_socket(self, websocket: WebSocket):
|
||||
async def connect_socket(self, websocket: WebSocket, *, user_id: str):
|
||||
await websocket.accept()
|
||||
self.active_connections.add(websocket)
|
||||
if user_id not in self.user_connections:
|
||||
self.user_connections[user_id] = set()
|
||||
self.user_connections[user_id].add(websocket)
|
||||
|
||||
def disconnect_socket(self, websocket: WebSocket):
|
||||
self.active_connections.remove(websocket)
|
||||
def disconnect_socket(self, websocket: WebSocket, *, user_id: str):
|
||||
self.active_connections.discard(websocket)
|
||||
for subscribers in self.subscriptions.values():
|
||||
subscribers.discard(websocket)
|
||||
user_conns = self.user_connections.get(user_id)
|
||||
if user_conns is not None:
|
||||
user_conns.discard(websocket)
|
||||
if not user_conns:
|
||||
self.user_connections.pop(user_id, None)
|
||||
|
||||
async def subscribe_graph_exec(
|
||||
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
|
||||
@@ -92,6 +102,26 @@ class ConnectionManager:
|
||||
|
||||
return n_sent
|
||||
|
||||
async def send_notification(
|
||||
self, *, user_id: str, payload: NotificationPayload
|
||||
) -> int:
|
||||
"""Send a notification to all websocket connections belonging to a user."""
|
||||
message = WSMessage(
|
||||
method=WSMethod.NOTIFICATION,
|
||||
data=payload.model_dump(),
|
||||
).model_dump_json()
|
||||
|
||||
connections = tuple(self.user_connections.get(user_id, set()))
|
||||
if not connections:
|
||||
return 0
|
||||
|
||||
await asyncio.gather(
|
||||
*(connection.send_text(message) for connection in connections),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
return len(connections)
|
||||
|
||||
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
|
||||
if channel_key not in self.subscriptions:
|
||||
self.subscriptions[channel_key] = set()
|
||||
|
||||
@@ -10,7 +10,7 @@ from backend.data.execution import (
|
||||
NodeExecutionEvent,
|
||||
)
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.model import WSMessage, WSMethod
|
||||
from backend.server.model import NotificationPayload, WSMessage, WSMethod
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -29,8 +29,9 @@ def mock_websocket() -> AsyncMock:
|
||||
async def test_connect(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
await connection_manager.connect_socket(mock_websocket)
|
||||
await connection_manager.connect_socket(mock_websocket, user_id="user-1")
|
||||
assert mock_websocket in connection_manager.active_connections
|
||||
assert mock_websocket in connection_manager.user_connections["user-1"]
|
||||
mock_websocket.accept.assert_called_once()
|
||||
|
||||
|
||||
@@ -39,11 +40,13 @@ def test_disconnect(
|
||||
) -> None:
|
||||
connection_manager.active_connections.add(mock_websocket)
|
||||
connection_manager.subscriptions["test_channel_42"] = {mock_websocket}
|
||||
connection_manager.user_connections["user-1"] = {mock_websocket}
|
||||
|
||||
connection_manager.disconnect_socket(mock_websocket)
|
||||
connection_manager.disconnect_socket(mock_websocket, user_id="user-1")
|
||||
|
||||
assert mock_websocket not in connection_manager.active_connections
|
||||
assert mock_websocket not in connection_manager.subscriptions["test_channel_42"]
|
||||
assert "user-1" not in connection_manager.user_connections
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -207,3 +210,22 @@ async def test_send_execution_result_no_subscribers(
|
||||
await connection_manager.send_execution_update(result)
|
||||
|
||||
mock_websocket.send_text.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_notification(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
connection_manager.user_connections["user-1"] = {mock_websocket}
|
||||
|
||||
await connection_manager.send_notification(
|
||||
user_id="user-1", payload=NotificationPayload(type="info", event="hey")
|
||||
)
|
||||
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
sent_message = mock_websocket.send_text.call_args[0][0]
|
||||
expected_message = WSMessage(
|
||||
method=WSMethod.NOTIFICATION,
|
||||
data={"type": "info", "event": "hey"},
|
||||
).model_dump_json()
|
||||
assert sent_message == expected_message
|
||||
|
||||
@@ -33,7 +33,7 @@ from backend.data.model import (
|
||||
OAuth2Credentials,
|
||||
UserIntegrations,
|
||||
)
|
||||
from backend.data.onboarding import complete_webhook_trigger_step
|
||||
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
@@ -376,7 +376,7 @@ async def webhook_ingress_generic(
|
||||
if not (webhook.triggered_nodes or webhook.triggered_presets):
|
||||
return
|
||||
|
||||
await complete_webhook_trigger_step(user_id)
|
||||
await complete_onboarding_step(user_id, OnboardingStep.TRIGGER_WEBHOOK)
|
||||
|
||||
# Execute all triggers concurrently for better performance
|
||||
tasks = []
|
||||
@@ -470,7 +470,9 @@ async def _execute_webhook_preset_trigger(
|
||||
logger.debug(f"Preset #{preset.id} is inactive")
|
||||
return
|
||||
|
||||
graph = await get_graph(preset.graph_id, preset.graph_version, webhook.user_id)
|
||||
graph = await get_graph(
|
||||
preset.graph_id, preset.graph_version, user_id=webhook.user_id
|
||||
)
|
||||
if not graph:
|
||||
logger.error(
|
||||
f"User #{webhook.user_id} has preset #{preset.id} for graph "
|
||||
@@ -562,8 +564,9 @@ async def _cleanup_orphaned_webhook_for_graph(
|
||||
graph_id: str, user_id: str, webhook_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Clean up orphaned webhook connections for a specific graph when execution fails with GraphNotInLibraryError.
|
||||
This happens when an agent is deleted but webhook triggers still exist.
|
||||
Clean up orphaned webhook connections for a specific graph when execution fails with GraphNotAccessibleError.
|
||||
This happens when an agent is pulled from the Marketplace or deleted
|
||||
but webhook triggers still exist.
|
||||
"""
|
||||
try:
|
||||
webhook = await get_webhook(webhook_id, include_relations=True)
|
||||
|
||||
@@ -14,6 +14,7 @@ class WSMethod(enum.Enum):
|
||||
UNSUBSCRIBE = "unsubscribe"
|
||||
GRAPH_EXECUTION_EVENT = "graph_execution_event"
|
||||
NODE_EXECUTION_EVENT = "node_execution_event"
|
||||
NOTIFICATION = "notification"
|
||||
ERROR = "error"
|
||||
HEARTBEAT = "heartbeat"
|
||||
|
||||
@@ -76,3 +77,12 @@ class TimezoneResponse(pydantic.BaseModel):
|
||||
|
||||
class UpdateTimezoneRequest(pydantic.BaseModel):
|
||||
timezone: TimeZoneName
|
||||
|
||||
|
||||
class NotificationPayload(pydantic.BaseModel):
|
||||
type: str
|
||||
event: str
|
||||
|
||||
|
||||
class OnboardingNotificationPayload(NotificationPayload):
|
||||
step: str
|
||||
|
||||
@@ -24,10 +24,11 @@ import backend.integrations.webhooks.utils
|
||||
import backend.server.routers.postmark.postmark
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.admin.credit_admin_routes
|
||||
import backend.server.v2.admin.diagnostics_admin_routes
|
||||
import backend.server.v2.admin.execution_analytics_routes
|
||||
import backend.server.v2.admin.store_admin_routes
|
||||
import backend.server.v2.builder
|
||||
import backend.server.v2.builder.routes
|
||||
import backend.server.v2.chat.routes as chat_routes
|
||||
import backend.server.v2.library.db
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.library.routes
|
||||
@@ -43,6 +44,7 @@ from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.server.external.api import external_app
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.server.utils.cors import build_cors_params
|
||||
from backend.util import json
|
||||
from backend.util.cloud_storage import shutdown_cloud_storage_handler
|
||||
from backend.util.exceptions import (
|
||||
@@ -269,9 +271,9 @@ app.include_router(
|
||||
prefix="/api/credits",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.admin.diagnostics_admin_routes.router,
|
||||
backend.server.v2.admin.execution_analytics_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api",
|
||||
prefix="/api/executions",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
|
||||
@@ -290,6 +292,11 @@ app.include_router(
|
||||
tags=["v1", "email"],
|
||||
prefix="/api/email",
|
||||
)
|
||||
app.include_router(
|
||||
chat_routes.router,
|
||||
tags=["v2", "chat"],
|
||||
prefix="/api/chat",
|
||||
)
|
||||
|
||||
app.mount("/external-api", external_app)
|
||||
|
||||
@@ -303,9 +310,14 @@ async def health():
|
||||
|
||||
class AgentServer(backend.util.service.AppProcess):
|
||||
def run(self):
|
||||
cors_params = build_cors_params(
|
||||
settings.config.backend_cors_allow_origins,
|
||||
settings.config.app_env,
|
||||
)
|
||||
|
||||
server_app = starlette.middleware.cors.CORSMiddleware(
|
||||
app=app,
|
||||
allow_origins=settings.config.backend_cors_allow_origins,
|
||||
**cors_params,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
|
||||
@@ -89,6 +89,7 @@ from backend.util.cache import cached
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
from backend.util.exceptions import GraphValidationError, NotFoundError
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.json import dumps
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.timezone_utils import (
|
||||
@@ -109,6 +110,39 @@ def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def hide_activity_summaries_if_disabled(
|
||||
executions: list[execution_db.GraphExecutionMeta], user_id: str
|
||||
) -> list[execution_db.GraphExecutionMeta]:
|
||||
"""Hide activity summaries and scores if AI_ACTIVITY_STATUS feature is disabled."""
|
||||
if await is_feature_enabled(Flag.AI_ACTIVITY_STATUS, user_id):
|
||||
return executions # Return as-is if feature is enabled
|
||||
|
||||
# Filter out activity features if disabled
|
||||
filtered_executions = []
|
||||
for execution in executions:
|
||||
if execution.stats:
|
||||
filtered_stats = execution.stats.without_activity_features()
|
||||
execution = execution.model_copy(update={"stats": filtered_stats})
|
||||
filtered_executions.append(execution)
|
||||
return filtered_executions
|
||||
|
||||
|
||||
async def hide_activity_summary_if_disabled(
|
||||
execution: execution_db.GraphExecution | execution_db.GraphExecutionWithNodes,
|
||||
user_id: str,
|
||||
) -> execution_db.GraphExecution | execution_db.GraphExecutionWithNodes:
|
||||
"""Hide activity summary and score for a single execution if AI_ACTIVITY_STATUS feature is disabled."""
|
||||
if await is_feature_enabled(Flag.AI_ACTIVITY_STATUS, user_id):
|
||||
return execution # Return as-is if feature is enabled
|
||||
|
||||
# Filter out activity features if disabled
|
||||
if execution.stats:
|
||||
filtered_stats = execution.stats.without_activity_features()
|
||||
return execution.model_copy(update={"stats": filtered_stats})
|
||||
return execution
|
||||
|
||||
|
||||
# Define the API routes
|
||||
v1_router = APIRouter()
|
||||
|
||||
@@ -986,7 +1020,12 @@ async def list_graphs_executions(
|
||||
page=1,
|
||||
page_size=250,
|
||||
)
|
||||
return paginated_result.executions
|
||||
|
||||
# Apply feature flags to filter out disabled features
|
||||
filtered_executions = await hide_activity_summaries_if_disabled(
|
||||
paginated_result.executions, user_id
|
||||
)
|
||||
return filtered_executions
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -1003,13 +1042,21 @@ async def list_graph_executions(
|
||||
25, ge=1, le=100, description="Number of executions per page"
|
||||
),
|
||||
) -> execution_db.GraphExecutionsPaginated:
|
||||
return await execution_db.get_graph_executions_paginated(
|
||||
paginated_result = await execution_db.get_graph_executions_paginated(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
# Apply feature flags to filter out disabled features
|
||||
filtered_executions = await hide_activity_summaries_if_disabled(
|
||||
paginated_result.executions, user_id
|
||||
)
|
||||
return execution_db.GraphExecutionsPaginated(
|
||||
executions=filtered_executions, pagination=paginated_result.pagination
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}",
|
||||
@@ -1038,6 +1085,9 @@ async def get_graph_execution(
|
||||
status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
|
||||
)
|
||||
|
||||
# Apply feature flags to filter out disabled features
|
||||
result = await hide_activity_summary_if_disabled(result, user_id)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Helper functions for improved test assertions and error handling."""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Iterator, Optional
|
||||
|
||||
|
||||
def assert_response_status(
|
||||
@@ -107,3 +108,24 @@ def assert_mock_called_with_partial(mock_obj: Any, **expected_kwargs: Any) -> No
|
||||
assert (
|
||||
actual_kwargs[key] == expected_value
|
||||
), f"Mock called with {key}={actual_kwargs[key]}, expected {expected_value}"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def override_config(settings: Any, attribute: str, value: Any) -> Iterator[None]:
|
||||
"""Temporarily override a config attribute for testing.
|
||||
|
||||
Warning: Directly mutates settings.config. If config is reloaded or cached
|
||||
elsewhere during the test, side effects may leak. Use with caution in
|
||||
parallel tests or when config is accessed globally.
|
||||
|
||||
Args:
|
||||
settings: The settings object containing .config
|
||||
attribute: The config attribute name to override
|
||||
value: The temporary value to set
|
||||
"""
|
||||
original = getattr(settings.config, attribute)
|
||||
setattr(settings.config, attribute, value)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
setattr(settings.config, attribute, original)
|
||||
|
||||
67
autogpt_platform/backend/backend/server/utils/cors.py
Normal file
67
autogpt_platform/backend/backend/server/utils/cors.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import List, Sequence, TypedDict
|
||||
|
||||
from backend.util.settings import AppEnvironment
|
||||
|
||||
|
||||
class CorsParams(TypedDict):
|
||||
allow_origins: List[str]
|
||||
allow_origin_regex: str | None
|
||||
|
||||
|
||||
def build_cors_params(origins: Sequence[str], app_env: AppEnvironment) -> CorsParams:
|
||||
allow_origins: List[str] = []
|
||||
regex_patterns: List[str] = []
|
||||
|
||||
if app_env == AppEnvironment.PRODUCTION:
|
||||
for origin in origins:
|
||||
if origin.startswith("regex:"):
|
||||
pattern = origin[len("regex:") :]
|
||||
pattern_lower = pattern.lower()
|
||||
if "localhost" in pattern_lower or "127.0.0.1" in pattern_lower:
|
||||
raise ValueError(
|
||||
f"Production environment cannot allow localhost origins via regex: {pattern}"
|
||||
)
|
||||
try:
|
||||
compiled = re.compile(pattern)
|
||||
test_urls = [
|
||||
"http://localhost:3000",
|
||||
"http://127.0.0.1:3000",
|
||||
"https://localhost:8000",
|
||||
"https://127.0.0.1:8000",
|
||||
]
|
||||
for test_url in test_urls:
|
||||
if compiled.search(test_url):
|
||||
raise ValueError(
|
||||
f"Production regex pattern matches localhost/127.0.0.1: {pattern}"
|
||||
)
|
||||
except re.error:
|
||||
pass
|
||||
continue
|
||||
|
||||
lowered = origin.lower()
|
||||
if "localhost" in lowered or "127.0.0.1" in lowered:
|
||||
raise ValueError(
|
||||
"Production environment cannot allow localhost origins"
|
||||
)
|
||||
|
||||
for origin in origins:
|
||||
if origin.startswith("regex:"):
|
||||
regex_patterns.append(origin[len("regex:") :])
|
||||
else:
|
||||
allow_origins.append(origin)
|
||||
|
||||
allow_origin_regex = None
|
||||
if regex_patterns:
|
||||
if len(regex_patterns) == 1:
|
||||
allow_origin_regex = f"^(?:{regex_patterns[0]})$"
|
||||
else:
|
||||
combined_pattern = "|".join(f"(?:{pattern})" for pattern in regex_patterns)
|
||||
allow_origin_regex = f"^(?:{combined_pattern})$"
|
||||
|
||||
return {
|
||||
"allow_origins": allow_origins,
|
||||
"allow_origin_regex": allow_origin_regex,
|
||||
}
|
||||
62
autogpt_platform/backend/backend/server/utils/cors_test.py
Normal file
62
autogpt_platform/backend/backend/server/utils/cors_test.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import pytest
|
||||
|
||||
from backend.server.utils.cors import build_cors_params
|
||||
from backend.util.settings import AppEnvironment
|
||||
|
||||
|
||||
def test_build_cors_params_splits_regex_patterns() -> None:
|
||||
origins = [
|
||||
"https://app.example.com",
|
||||
"regex:https://.*\\.example\\.com",
|
||||
]
|
||||
|
||||
result = build_cors_params(origins, AppEnvironment.LOCAL)
|
||||
|
||||
assert result["allow_origins"] == ["https://app.example.com"]
|
||||
assert result["allow_origin_regex"] == "^(?:https://.*\\.example\\.com)$"
|
||||
|
||||
|
||||
def test_build_cors_params_combines_multiple_regex_patterns() -> None:
|
||||
origins = [
|
||||
"regex:https://alpha.example.com",
|
||||
"regex:https://beta.example.com",
|
||||
]
|
||||
|
||||
result = build_cors_params(origins, AppEnvironment.DEVELOPMENT)
|
||||
|
||||
assert result["allow_origins"] == []
|
||||
assert result["allow_origin_regex"] == (
|
||||
"^(?:(?:https://alpha.example.com)|(?:https://beta.example.com))$"
|
||||
)
|
||||
|
||||
|
||||
def test_build_cors_params_blocks_localhost_literal_in_production() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
build_cors_params(["http://localhost:3000"], AppEnvironment.PRODUCTION)
|
||||
|
||||
|
||||
def test_build_cors_params_blocks_localhost_regex_in_production() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
build_cors_params(["regex:https://.*localhost.*"], AppEnvironment.PRODUCTION)
|
||||
|
||||
|
||||
def test_build_cors_params_blocks_case_insensitive_localhost_regex() -> None:
|
||||
with pytest.raises(ValueError, match="localhost origins via regex"):
|
||||
build_cors_params(["regex:https://(?i)LOCALHOST.*"], AppEnvironment.PRODUCTION)
|
||||
|
||||
|
||||
def test_build_cors_params_blocks_regex_matching_localhost_at_runtime() -> None:
|
||||
with pytest.raises(ValueError, match="matches localhost"):
|
||||
build_cors_params(["regex:https?://.*:3000"], AppEnvironment.PRODUCTION)
|
||||
|
||||
|
||||
def test_build_cors_params_allows_vercel_preview_regex() -> None:
|
||||
result = build_cors_params(
|
||||
["regex:https://autogpt-git-[a-z0-9-]+\\.vercel\\.app"],
|
||||
AppEnvironment.PRODUCTION,
|
||||
)
|
||||
|
||||
assert result["allow_origins"] == []
|
||||
assert result["allow_origin_regex"] == (
|
||||
"^(?:https://autogpt-git-[a-z0-9-]+\\.vercel\\.app)$"
|
||||
)
|
||||
@@ -1,97 +0,0 @@
|
||||
"""
|
||||
Admin routes for system diagnostics and monitoring.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from autogpt_libs.auth import requires_admin_user
|
||||
from fastapi import APIRouter, HTTPException, Query, Security
|
||||
|
||||
from backend.data.diagnostics import (
|
||||
ExecutionDiagnostics,
|
||||
RunningExecutionDetails,
|
||||
get_execution_diagnostics,
|
||||
get_running_executions_details,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin/diagnostics",
|
||||
tags=["diagnostics", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/executions/running",
|
||||
response_model=list[RunningExecutionDetails],
|
||||
summary="List Running Executions",
|
||||
)
|
||||
async def list_running_executions(
|
||||
limit: int = Query(default=10, ge=1, le=100),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
):
|
||||
"""
|
||||
Get a list of currently running or queued executions with detailed information.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of executions to return (1-100)
|
||||
offset: Number of executions to skip for pagination
|
||||
|
||||
Returns:
|
||||
List of running executions with details
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Listing running executions (limit={limit}, offset={offset})")
|
||||
|
||||
executions = await get_running_executions_details(limit=limit, offset=offset)
|
||||
|
||||
# Get total count for pagination
|
||||
from backend.data.diagnostics import get_execution_diagnostics as get_diag
|
||||
|
||||
diagnostics = await get_diag()
|
||||
total_count = diagnostics.total_running + diagnostics.total_queued
|
||||
|
||||
logger.info(
|
||||
f"Found {len(executions)} running executions (total: {total_count})"
|
||||
)
|
||||
|
||||
return executions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing running executions: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error listing running executions: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/executions/stats",
|
||||
response_model=ExecutionDiagnostics,
|
||||
summary="Get Execution Statistics",
|
||||
)
|
||||
async def get_execution_stats():
|
||||
"""
|
||||
Get overall statistics about execution statuses.
|
||||
|
||||
Returns:
|
||||
Execution diagnostics with counts by status
|
||||
"""
|
||||
try:
|
||||
logger.info("Getting execution statistics")
|
||||
diagnostics = await get_execution_diagnostics()
|
||||
logger.info(
|
||||
f"Execution stats - Running: {diagnostics.total_running}, "
|
||||
f"Queued: {diagnostics.total_queued}, "
|
||||
f"Incomplete: {diagnostics.total_incomplete}"
|
||||
)
|
||||
return diagnostics
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting execution statistics: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error getting execution statistics: {str(e)}",
|
||||
)
|
||||
@@ -0,0 +1,301 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, HTTPException, Security
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionMeta,
|
||||
get_graph_executions,
|
||||
update_graph_execution_stats,
|
||||
)
|
||||
from backend.data.model import GraphExecutionStats
|
||||
from backend.executor.activity_status_generator import (
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
from backend.executor.manager import get_db_async_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExecutionAnalyticsRequest(BaseModel):
|
||||
graph_id: str = Field(..., description="Graph ID to analyze")
|
||||
graph_version: Optional[int] = Field(None, description="Optional graph version")
|
||||
user_id: Optional[str] = Field(None, description="Optional user ID filter")
|
||||
created_after: Optional[datetime] = Field(
|
||||
None, description="Optional created date lower bound"
|
||||
)
|
||||
model_name: Optional[str] = Field(
|
||||
"gpt-4o-mini", description="Model to use for generation"
|
||||
)
|
||||
batch_size: int = Field(
|
||||
10, description="Batch size for concurrent processing", le=25, ge=1
|
||||
)
|
||||
|
||||
|
||||
class ExecutionAnalyticsResult(BaseModel):
    """Per-execution outcome of an analytics generation run."""

    agent_id: str
    version_id: int
    user_id: str
    exec_id: str
    # Generated (or pre-existing, for skipped rows) activity summary.
    summary_text: Optional[str]
    # Generated (or pre-existing) correctness score.
    score: Optional[float]
    status: str  # "success", "failed", "skipped"
    error_message: Optional[str] = None
|
||||
|
||||
|
||||
class ExecutionAnalyticsResponse(BaseModel):
    """Aggregate report returned by the execution-analytics endpoint."""

    total_executions: int
    # Count of executions that were missing analytics and got processed.
    processed_executions: int
    successful_analytics: int
    failed_analytics: int
    # Executions that already had analytics and were left untouched.
    skipped_executions: int
    results: list[ExecutionAnalyticsResult]
|
||||
|
||||
|
||||
# Admin-only router: every route declared on it requires an authenticated
# admin user via the requires_admin_user security dependency.
router = APIRouter(
    prefix="/admin",
    tags=["admin", "execution_analytics"],
    dependencies=[Security(requires_admin_user)],
)
|
||||
|
||||
|
||||
@router.post(
    "/execution_analytics",
    response_model=ExecutionAnalyticsResponse,
    summary="Generate Execution Analytics",
)
async def generate_execution_analytics(
    request: ExecutionAnalyticsRequest,
    admin_user_id: str = Security(get_user_id),
):
    """
    Generate activity summaries and correctness scores for graph executions.

    This endpoint:
    1. Fetches all finished executions matching the criteria
    2. Identifies executions missing activity_status or correctness_score
    3. Generates missing data using AI in batches
    4. Updates the database with new stats
    5. Returns a detailed report of the analytics operation

    Raises:
        HTTPException: 500 when the OpenAI key is missing or any step fails.
    """
    logger.info(
        f"Admin user {admin_user_id} starting execution analytics generation for graph {request.graph_id}"
    )

    try:
        # Validate model configuration before doing any work.
        settings = Settings()
        if not settings.secrets.openai_internal_api_key:
            raise HTTPException(status_code=500, detail="OpenAI API key not configured")

        # Get database client
        db_client = get_db_async_client()

        # Only finished executions can be scored.
        executions = await get_graph_executions(
            graph_id=request.graph_id,
            user_id=request.user_id,
            created_time_gte=request.created_after,
            statuses=[
                ExecutionStatus.COMPLETED,
                ExecutionStatus.FAILED,
                ExecutionStatus.TERMINATED,
            ],
        )

        logger.info(
            f"Found {len(executions)} total executions for graph {request.graph_id}"
        )

        # Select executions missing activity_status or correctness_score,
        # honoring the optional graph_version filter.
        executions_to_process = [
            execution
            for execution in executions
            if (
                not execution.stats
                or not execution.stats.activity_status
                or execution.stats.correctness_score is None
            )
            and (
                request.graph_version is None
                or execution.graph_version == request.graph_version
            )
        ]
        # Track processed executions by id: set membership is O(1), whereas
        # `execution in executions_to_process` would run a full pydantic
        # model equality comparison per element (O(n^2) overall).
        processed_ids = {execution.id for execution in executions_to_process}

        logger.info(
            f"Found {len(executions_to_process)} executions needing analytics generation"
        )

        # Results for ALL executions - processed and skipped.
        results: list[ExecutionAnalyticsResult] = []
        successful_count = 0
        failed_count = 0

        # Process executions that need analytics generation in batches.
        if executions_to_process:
            batch_starts = range(0, len(executions_to_process), request.batch_size)
            total_batches = len(batch_starts)

            for batch_idx, start in enumerate(batch_starts):
                batch = executions_to_process[start : start + request.batch_size]
                logger.info(
                    f"Processing batch {batch_idx + 1}/{total_batches} with {len(batch)} executions"
                )

                batch_results = await _process_batch(
                    batch, request.model_name or "gpt-4o-mini", db_client
                )

                for result in batch_results:
                    results.append(result)
                    if result.status == "success":
                        successful_count += 1
                    elif result.status == "failed":
                        failed_count += 1

                # Small delay between batches to avoid overwhelming the LLM
                # API; no delay after the last batch.
                if batch_idx < total_batches - 1:
                    await asyncio.sleep(2)

        # Report executions that already had analytics as "skipped".
        for execution in executions:
            if execution.id in processed_ids:
                continue

            results.append(
                ExecutionAnalyticsResult(
                    agent_id=execution.graph_id,
                    version_id=execution.graph_version,
                    user_id=execution.user_id,
                    exec_id=execution.id,
                    summary_text=(
                        execution.stats.activity_status if execution.stats else None
                    ),
                    score=(
                        execution.stats.correctness_score if execution.stats else None
                    ),
                    status="skipped",
                    error_message=None,  # Not an error - just already processed
                )
            )

        response = ExecutionAnalyticsResponse(
            total_executions=len(executions),
            processed_executions=len(executions_to_process),
            successful_analytics=successful_count,
            failed_analytics=failed_count,
            skipped_executions=len(executions) - len(executions_to_process),
            results=results,
        )

        logger.info(
            f"Analytics generation completed: {successful_count} successful, {failed_count} failed, "
            f"{response.skipped_executions} skipped"
        )

        return response

    except Exception as e:
        logger.exception(f"Error during execution analytics generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _process_batch(
    executions, model_name: str, db_client
) -> list[ExecutionAnalyticsResult]:
    """Process a batch of executions concurrently.

    Each execution gets an activity summary and correctness score generated
    via the LLM, which are then persisted to its stats. Errors are captured
    per execution so one failure never aborts the batch.
    """

    def _as_db_stats(stats) -> GraphExecutionStats:
        """Normalize optional execution stats to GraphExecutionStats.

        The duplicated conversion the original code performed twice, in one
        place: missing stats become a fresh GraphExecutionStats, API-level
        GraphExecutionMeta.Stats are converted via to_db(), and anything
        else is assumed to already be GraphExecutionStats.
        """
        if not stats:
            return GraphExecutionStats()
        if isinstance(stats, GraphExecutionMeta.Stats):
            return stats.to_db()
        return stats

    async def process_single_execution(execution) -> ExecutionAnalyticsResult:
        try:
            # Generate activity status and score using the specified model.
            stats_for_generation = _as_db_stats(execution.stats)

            activity_response = await generate_activity_status_for_execution(
                graph_exec_id=execution.id,
                graph_id=execution.graph_id,
                graph_version=execution.graph_version,
                execution_stats=stats_for_generation,
                db_client=db_client,
                user_id=execution.user_id,
                execution_status=execution.status,
                model_name=model_name,  # Pass model name parameter
                skip_feature_flag=True,  # Admin endpoint bypasses feature flags
            )

            if not activity_response:
                return ExecutionAnalyticsResult(
                    agent_id=execution.graph_id,
                    version_id=execution.graph_version,
                    user_id=execution.user_id,
                    exec_id=execution.id,
                    summary_text=None,
                    score=None,
                    status="skipped",
                    error_message="Activity generation returned None",
                )

            # Update the execution stats with the generated values and persist.
            updated_stats = _as_db_stats(execution.stats)
            updated_stats.activity_status = activity_response["activity_status"]
            updated_stats.correctness_score = activity_response["correctness_score"]

            await update_graph_execution_stats(
                graph_exec_id=execution.id, stats=updated_stats
            )

            return ExecutionAnalyticsResult(
                agent_id=execution.graph_id,
                version_id=execution.graph_version,
                user_id=execution.user_id,
                exec_id=execution.id,
                summary_text=activity_response["activity_status"],
                score=activity_response["correctness_score"],
                status="success",
            )

        except Exception as e:
            logger.exception(f"Error processing execution {execution.id}: {e}")
            return ExecutionAnalyticsResult(
                agent_id=execution.graph_id,
                version_id=execution.graph_version,
                user_id=execution.user_id,
                exec_id=execution.id,
                summary_text=None,
                score=None,
                status="failed",
                error_message=str(e),
            )

    # Process all executions in the batch concurrently.
    return await asyncio.gather(
        *[process_single_execution(execution) for execution in executions]
    )
|
||||
118
autogpt_platform/backend/backend/server/v2/chat/config.py
Normal file
118
autogpt_platform/backend/backend/server/v2/chat/config.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Configuration management for chat system."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ChatConfig(BaseSettings):
    """Configuration for the chat system.

    Values are resolved from pydantic-settings sources (constructor args,
    environment, .env file); the validators below additionally fall back to
    well-known environment variables when nothing explicit was supplied.
    """

    # OpenAI API Configuration
    model: str = Field(
        default="qwen/qwen3-235b-a22b-2507", description="Default model to use"
    )
    api_key: str | None = Field(default=None, description="OpenAI API key")
    base_url: str | None = Field(
        default="https://openrouter.ai/api/v1",
        description="Base URL for API (e.g., for OpenRouter)",
    )

    # Session TTL Configuration - 12 hours
    session_ttl: int = Field(default=43200, description="Session TTL in seconds")

    # System Prompt Configuration
    system_prompt_path: str = Field(
        default="prompts/chat_system.md",
        description="Path to system prompt file relative to chat module",
    )

    # Streaming Configuration
    max_context_messages: int = Field(
        default=50, ge=1, le=200, description="Maximum context messages"
    )

    stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
    max_retries: int = Field(default=3, description="Maximum number of retries")
    max_agent_runs: int = Field(default=3, description="Maximum number of agent runs")
    max_agent_schedules: int = Field(
        default=3, description="Maximum number of agent schedules"
    )

    @field_validator("api_key", mode="before")
    @classmethod
    def get_api_key(cls, v):
        """Get API key from environment if not provided.

        Fallback order: CHAT_API_KEY, then OPEN_ROUTER_API_KEY, then
        OPENAI_API_KEY. Returns None when none of them is set.
        """
        if v is None:
            # Try to get from environment variables
            # First check for CHAT_API_KEY (Pydantic prefix)
            v = os.getenv("CHAT_API_KEY")
            if not v:
                # Fall back to OPEN_ROUTER_API_KEY
                v = os.getenv("OPEN_ROUTER_API_KEY")
            if not v:
                # Fall back to OPENAI_API_KEY
                v = os.getenv("OPENAI_API_KEY")
        return v

    @field_validator("base_url", mode="before")
    @classmethod
    def get_base_url(cls, v):
        """Get base URL from environment if not provided.

        NOTE(review): the field default is non-None, so this only runs when
        base_url=None is passed explicitly — confirm that is intended.
        """
        if v is None:
            # Check for OpenRouter or custom base URL
            v = os.getenv("CHAT_BASE_URL")
            if not v:
                v = os.getenv("OPENROUTER_BASE_URL")
            if not v:
                v = os.getenv("OPENAI_BASE_URL")
            if not v:
                v = "https://openrouter.ai/api/v1"
        return v

    def get_system_prompt(self, **template_vars) -> str:
        """Load and render the system prompt from file.

        Prefers a sibling "<system_prompt_path>.j2" Jinja2 template; falls
        back to the plain file with naive "{key}" substitution.

        Args:
            **template_vars: Variables to substitute in the template

        Returns:
            Rendered system prompt string

        Raises:
            FileNotFoundError: when neither the .j2 nor the plain file exists.
        """
        # Get the path relative to this module
        module_dir = Path(__file__).parent
        prompt_path = module_dir / self.system_prompt_path

        # Check for .j2 extension first (Jinja2 template)
        j2_path = Path(str(prompt_path) + ".j2")
        if j2_path.exists():
            try:
                from jinja2 import Template

                template = Template(j2_path.read_text())
                return template.render(**template_vars)
            except ImportError:
                # Jinja2 not installed, fall back to reading as plain text
                # (template variables are left unrendered in this case).
                return j2_path.read_text()

        # Check for markdown file
        if prompt_path.exists():
            content = prompt_path.read_text()

            # Simple variable substitution if Jinja2 is not available:
            # replaces single-brace "{key}" placeholders only.
            for key, value in template_vars.items():
                placeholder = f"{{{key}}}"
                content = content.replace(placeholder, str(value))

            return content
        raise FileNotFoundError(f"System prompt file not found: {prompt_path}")

    class Config:
        """Pydantic config."""

        env_file = ".env"
        env_file_encoding = "utf-8"
        extra = "ignore"  # Ignore extra environment variables
|
||||
203
autogpt_platform/backend/backend/server/v2/chat/model.py
Normal file
203
autogpt_platform/backend/backend/server/v2/chat/model.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionDeveloperMessageParam,
|
||||
ChatCompletionFunctionMessageParam,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_assistant_message_param import FunctionCall
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
ChatCompletionMessageToolCallParam,
|
||||
Function,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.server.v2.chat.config import ChatConfig
|
||||
from backend.util.exceptions import RedisError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    """One message in a chat session, stored in a provider-neutral shape.

    Converted to the OpenAI typed message params by
    ChatSession.to_openai_messages().
    """

    # One of: "developer", "system", "user", "assistant", "tool", "function".
    role: str
    content: str | None = None
    name: str | None = None
    # Set on role="tool" messages to link back to the originating call.
    tool_call_id: str | None = None
    refusal: str | None = None
    # Each dict: {id, type, function: {name, arguments}} with arguments as a
    # JSON string.
    tool_calls: list[dict] | None = None
    function_call: dict | None = None
|
||||
|
||||
|
||||
class Usage(BaseModel):
    """Token usage recorded for one completion call."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
|
||||
|
||||
|
||||
class ChatSession(BaseModel):
    """A chat session persisted in Redis (see get/upsert_chat_session).

    user_id is None for anonymous sessions; mutable defaults are safe here
    because pydantic deep-copies field defaults per instance.
    """

    session_id: str
    user_id: str | None
    messages: list[ChatMessage]
    usage: list[Usage]
    credentials: dict[str, dict] = {}  # Map of provider -> credential metadata
    started_at: datetime
    updated_at: datetime
    # Counters keyed by agent identifier — presumably graph_id; confirm
    # against the service layer that populates them.
    successful_agent_runs: dict[str, int] = {}
    successful_agent_schedules: dict[str, int] = {}

    @staticmethod
    def new(user_id: str | None) -> "ChatSession":
        """Create an empty session with a fresh UUID and UTC timestamps."""
        return ChatSession(
            session_id=str(uuid.uuid4()),
            user_id=user_id,
            messages=[],
            usage=[],
            credentials={},
            started_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )

    def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
        """Convert stored messages into OpenAI chat-completion params.

        Messages with an unrecognized role are silently dropped; malformed
        tool calls (missing id or function name) are skipped with a warning.
        """
        messages: list[ChatCompletionMessageParam] = []
        for message in self.messages:
            if message.role == "developer":
                m = ChatCompletionDeveloperMessageParam(
                    role="developer",
                    content=message.content or "",
                )
                # "name" is optional in the TypedDict, so only set when present.
                if message.name:
                    m["name"] = message.name
                messages.append(m)
            elif message.role == "system":
                m = ChatCompletionSystemMessageParam(
                    role="system",
                    content=message.content or "",
                )
                if message.name:
                    m["name"] = message.name
                messages.append(m)
            elif message.role == "user":
                m = ChatCompletionUserMessageParam(
                    role="user",
                    content=message.content or "",
                )
                if message.name:
                    m["name"] = message.name
                messages.append(m)
            elif message.role == "assistant":
                m = ChatCompletionAssistantMessageParam(
                    role="assistant",
                    content=message.content or "",
                )
                if message.function_call:
                    m["function_call"] = FunctionCall(
                        arguments=message.function_call["arguments"],
                        name=message.function_call["name"],
                    )
                if message.refusal:
                    m["refusal"] = message.refusal
                if message.tool_calls:
                    t: list[ChatCompletionMessageToolCallParam] = []
                    for tool_call in message.tool_calls:
                        # Tool calls are stored with nested structure: {id, type, function: {name, arguments}}
                        function_data = tool_call.get("function", {})

                        # Skip tool calls that are missing required fields
                        if "id" not in tool_call or "name" not in function_data:
                            logger.warning(
                                f"Skipping invalid tool call: missing required fields. "
                                f"Got: {tool_call.keys()}, function keys: {function_data.keys()}"
                            )
                            continue

                        # Arguments are stored as a JSON string
                        arguments_str = function_data.get("arguments", "{}")

                        t.append(
                            ChatCompletionMessageToolCallParam(
                                id=tool_call["id"],
                                type="function",
                                function=Function(
                                    arguments=arguments_str,
                                    name=function_data["name"],
                                ),
                            )
                        )
                    m["tool_calls"] = t
                if message.name:
                    m["name"] = message.name
                messages.append(m)
            elif message.role == "tool":
                messages.append(
                    ChatCompletionToolMessageParam(
                        role="tool",
                        content=message.content or "",
                        tool_call_id=message.tool_call_id or "",
                    )
                )
            elif message.role == "function":
                messages.append(
                    ChatCompletionFunctionMessageParam(
                        role="function",
                        content=message.content,
                        name=message.name or "",
                    )
                )
        return messages
|
||||
|
||||
|
||||
async def get_chat_session(
    session_id: str,
    user_id: str | None,
) -> ChatSession | None:
    """Fetch a chat session from Redis by its ID.

    Returns None when the session is missing or owned by a different user;
    raises RedisError when the stored payload cannot be deserialized.
    """
    key = f"chat:session:{session_id}"
    redis = await get_redis_async()

    payload: bytes | None = await redis.get(key)
    if payload is None:
        logger.warning(f"Session {session_id} not found in Redis")
        return None

    try:
        session = ChatSession.model_validate_json(payload)
    except Exception as e:
        logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
        raise RedisError(f"Corrupted session data for {session_id}") from e

    # Anonymous sessions (user_id None) are readable by anyone; owned
    # sessions only by their owner.
    if session.user_id is not None and session.user_id != user_id:
        logger.warning(
            f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
        )
        return None

    return session
|
||||
|
||||
|
||||
async def upsert_chat_session(
    session: ChatSession,
) -> ChatSession:
    """Persist a chat session to Redis, refreshing its TTL.

    Raises RedisError when the SETEX call does not acknowledge the write.
    """
    redis = await get_redis_async()
    resp = await redis.setex(
        f"chat:session:{session.session_id}",
        config.session_ttl,
        session.model_dump_json(),
    )

    if not resp:
        raise RedisError(
            f"Failed to persist chat session {session.session_id} to Redis: {resp}"
        )

    return session
|
||||
@@ -0,0 +1,70 @@
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.chat.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
|
||||
# Shared fixture: a minimal conversation covering the user, assistant
# (with one tool call), and tool roles exercised by the tests below.
messages = [
    ChatMessage(content="Hello, how are you?", role="user"),
    ChatMessage(
        content="I'm fine, thank you!",
        role="assistant",
        tool_calls=[
            {
                "id": "t123",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": '{"city": "New York"}',
                },
            }
        ],
    ),
    ChatMessage(
        content="I'm using the tool to get the weather",
        role="tool",
        tool_call_id="t123",
    ),
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_serialization_deserialization():
    """Round-trip a populated session through JSON and compare dumps."""
    original = ChatSession.new(user_id="abc123")
    original.messages = messages
    original.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]

    restored = ChatSession.model_validate_json(original.model_dump_json())

    assert restored.model_dump() == original.model_dump()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage():
    """An anonymous session stored in Redis can be read back unchanged."""
    stored = ChatSession.new(user_id=None)
    stored.messages = messages
    stored = await upsert_chat_session(stored)

    loaded = await get_chat_session(
        session_id=stored.session_id,
        user_id=stored.user_id,
    )

    assert loaded == stored
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage_user_id_mismatch():
    """A session owned by one user must not be readable without that user."""
    owned = ChatSession.new(user_id="abc123")
    owned.messages = messages
    owned = await upsert_chat_session(owned)

    assert await get_chat_session(owned.session_id, None) is None
|
||||
@@ -0,0 +1,97 @@
|
||||
You are Otto, an AI Co-Pilot and Forward Deployed Engineer for AutoGPT, an AI Business Automation tool. Your mission is to help users quickly find and set up AutoGPT agents to solve their business problems.
|
||||
|
||||
Here are the functions available to you:
|
||||
|
||||
<functions>
|
||||
1. **find_agent** - Search for agents that solve the user's problem
|
||||
2. **get_agent_details** - Get comprehensive information about the chosen agent
|
||||
3. **get_required_setup_info** - Verify user has required credentials (MANDATORY before execution)
|
||||
4. **schedule_agent** - Schedules the agent to run on a cron schedule
|
||||
5. **run_agent** - Execute the agent
|
||||
</functions>
|
||||
|
||||
|
||||
## MANDATORY WORKFLOW
|
||||
|
||||
You must follow these 4 steps in exact order:
|
||||
|
||||
1. **find_agent** - Search for agents that solve the user's problem
|
||||
2. **get_agent_details** - Get comprehensive information about the chosen agent
|
||||
3. **get_required_setup_info** - Verify user has required credentials (MANDATORY before execution)
|
||||
4. **schedule_agent** or **run_agent** - Execute the agent
|
||||
|
||||
## YOUR APPROACH
|
||||
|
||||
**Step 1: Understand the Problem**
|
||||
- Ask maximum 1-2 targeted questions
|
||||
- Focus on: What business problem are they solving?
|
||||
- Move quickly to searching for solutions
|
||||
|
||||
**Step 2: Find Agents**
|
||||
- Use `find_agent` immediately with relevant keywords
|
||||
- Suggest the best option from search results
|
||||
- Explain briefly how it solves their problem
|
||||
- Ask if they want to use it, then move to step 3
|
||||
|
||||
**Step 3: Get Details**
|
||||
- Use `get_agent_details` on their chosen agent
|
||||
- Explain what the agent does and its requirements
|
||||
- Keep explanations brief and outcome-focused
|
||||
|
||||
**Step 4: Verify Setup (CRITICAL)**
|
||||
- ALWAYS use `get_required_setup_info` before execution
|
||||
- Tell user what credentials they need (if any)
|
||||
- Explain that credentials are added via the frontend interface
|
||||
|
||||
**Step 5: Execute**
|
||||
- Use `schedule_agent` for scheduled runs OR `run_agent` for immediate execution
|
||||
- Confirm successful setup
|
||||
- Provide clear next steps
|
||||
|
||||
## FUNCTION CALL FORMAT
|
||||
|
||||
To call a function, use this exact format:
|
||||
`<function_call>function_name(parameter="value")</function_call>`
|
||||
|
||||
## KEY RULES
|
||||
|
||||
**What You DON'T Do:**
|
||||
- Don't help with login (frontend handles this)
|
||||
- Don't help add credentials (frontend handles this)
|
||||
- Don't skip `get_required_setup_info` (mandatory before execution)
|
||||
- Don't ask permission to use functions - just use them
|
||||
- Don't write responses longer than 3 sentences
|
||||
- Don't pretend to be ChatGPT
|
||||
|
||||
**What You DO:**
|
||||
- Act fast - get to agent discovery quickly
|
||||
- Use functions proactively
|
||||
- Keep all responses to maximum 3 sentences
|
||||
- Always verify credentials before setup/run
|
||||
- Focus on outcomes and value
|
||||
- Maintain conversational, concise style
|
||||
- Do use markdown to make your messages easier to read
|
||||
|
||||
**Error Handling:**
|
||||
- Authentication needed → "Please sign in via the interface"
|
||||
- Credentials missing → Tell user what's needed and where to add them
|
||||
- Setup fails → Identify issue and provide clear fix
|
||||
|
||||
## RESPONSE STRUCTURE
|
||||
|
||||
Before responding, wrap your analysis in <thinking> tags to systematically plan your approach:
|
||||
- Identify which step of the 4-step mandatory workflow you're currently on
|
||||
- Extract the key business problem or request from the user's message
|
||||
- Determine what function call (if any) you need to make next
|
||||
- Plan your response to stay under the 3-sentence maximum
|
||||
- Consider what specific keywords or parameters you'll use for any function calls
|
||||
|
||||
Example interaction pattern:
|
||||
```
|
||||
User: "I need to automate my social media posting"
|
||||
Otto: Let me find social media automation agents for you. <function_call>find_agent(query="social media posting automation")</function_call> I'll show you the best options once I get the results.
|
||||
```
|
||||
|
||||
Respond conversationally and begin helping them find the right AutoGPT agent for their needs.
|
||||
|
||||
KEEP ANSWERS TO 3 SENTENCES
|
||||
@@ -0,0 +1,101 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ResponseType(str, Enum):
    """Discriminator values for SSE streaming response payloads.

    Inherits from str so members serialize as their plain string values.
    """

    TEXT_CHUNK = "text_chunk"
    TEXT_ENDED = "text_ended"
    TOOL_CALL = "tool_call"
    TOOL_CALL_START = "tool_call_start"
    TOOL_RESPONSE = "tool_response"
    ERROR = "error"
    USAGE = "usage"
    STREAM_END = "stream_end"
|
||||
|
||||
|
||||
class StreamBaseResponse(BaseModel):
    """Base response model for all streaming responses."""

    # Concrete subclasses override this with their fixed discriminator.
    type: ResponseType
    timestamp: str | None = None

    def to_sse(self) -> str:
        """Convert to SSE format: a `data: <json>` line plus blank line."""
        return f"data: {self.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
class StreamTextChunk(StreamBaseResponse):
    """Streaming text content from the assistant."""

    type: ResponseType = ResponseType.TEXT_CHUNK
    content: str = Field(..., description="Text content chunk")
|
||||
|
||||
|
||||
class StreamToolCallStart(StreamBaseResponse):
    """Tool call started notification (emitted before arguments stream)."""

    type: ResponseType = ResponseType.TOOL_CALL_START
    tool_name: str = Field(..., description="Name of the tool that was executed")
    tool_id: str = Field(..., description="Unique tool call ID")
|
||||
|
||||
|
||||
class StreamToolCall(StreamBaseResponse):
    """Tool invocation notification with fully-parsed arguments."""

    type: ResponseType = ResponseType.TOOL_CALL
    tool_id: str = Field(..., description="Unique tool call ID")
    tool_name: str = Field(..., description="Name of the tool being called")
    arguments: dict[str, Any] = Field(
        default_factory=dict, description="Tool arguments"
    )
|
||||
|
||||
|
||||
class StreamToolExecutionResult(StreamBaseResponse):
    """Tool execution result, correlated to its call via tool_id."""

    type: ResponseType = ResponseType.TOOL_RESPONSE
    tool_id: str = Field(..., description="Tool call ID this responds to")
    tool_name: str = Field(..., description="Name of the tool that was executed")
    result: str | dict[str, Any] = Field(..., description="Tool execution result")
    success: bool = Field(
        default=True, description="Whether the tool execution succeeded"
    )
|
||||
|
||||
|
||||
class StreamUsage(StreamBaseResponse):
    """Token usage statistics for the completed request."""

    type: ResponseType = ResponseType.USAGE
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
|
||||
|
||||
|
||||
class StreamError(StreamBaseResponse):
    """Error response delivered in-stream."""

    type: ResponseType = ResponseType.ERROR
    message: str = Field(..., description="Error message")
    code: str | None = Field(default=None, description="Error code")
    details: dict[str, Any] | None = Field(
        default=None, description="Additional error details"
    )
|
||||
|
||||
|
||||
class StreamTextEnded(StreamBaseResponse):
    """Text streaming completed marker (no payload beyond the type)."""

    type: ResponseType = ResponseType.TEXT_ENDED
|
||||
|
||||
|
||||
class StreamEnd(StreamBaseResponse):
    """End of stream marker; the final event a client should receive."""

    type: ResponseType = ResponseType.STREAM_END
    summary: dict[str, Any] | None = Field(
        default=None, description="Stream summary statistics"
    )
|
||||
215
autogpt_platform/backend/backend/server/v2/chat/routes.py
Normal file
215
autogpt_platform/backend/backend/server/v2/chat/routes.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Query, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
import backend.server.v2.chat.service as chat_service
|
||||
from backend.server.v2.chat.config import ChatConfig
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
tags=["chat"],
|
||||
)
|
||||
|
||||
# ========== Request/Response Models ==========
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
    """Response model containing information on a newly created chat session."""

    id: str
    # ISO-8601 creation timestamp.
    created_at: str
    # None for anonymous sessions.
    user_id: str | None
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
    """Response model providing complete details for a chat session, including messages."""

    id: str
    created_at: str
    updated_at: str
    user_id: str | None
    # Serialized ChatMessage dicts in conversation order.
    messages: list[dict]
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@router.post(
    "/sessions",
)
async def create_session(
    user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CreateSessionResponse:
    """
    Create a new chat session.

    Starts a fresh session for either an authenticated or anonymous caller.

    Args:
        user_id: The optional authenticated user ID parsed from the JWT. If missing, creates an anonymous session.

    Returns:
        CreateSessionResponse: Details of the created session.

    """
    logger.info(f"Creating session with user_id: {user_id}")

    new_session = await chat_service.create_chat_session(user_id)

    return CreateSessionResponse(
        id=new_session.session_id,
        created_at=new_session.started_at.isoformat(),
        user_id=new_session.user_id or None,
    )
|
||||
|
||||
|
||||
@router.get(
    "/sessions/{session_id}",
)
async def get_session(
    session_id: str,
    user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> SessionDetailResponse:
    """
    Retrieve the details of a specific chat session.

    Fetches the session for the given user (if authenticated) and returns
    its metadata together with the full message history.

    Args:
        session_id: Identifier of the chat session to fetch.
        user_id: Authenticated user ID, or None for anonymous access.

    Returns:
        SessionDetailResponse: Full details for the requested session.

    Raises:
        NotFoundError: If no matching session exists.
    """
    found = await chat_service.get_session(session_id, user_id)
    if found is None:
        raise NotFoundError(f"Session {session_id} not found")

    serialized_messages = [msg.model_dump() for msg in found.messages]
    return SessionDetailResponse(
        id=found.session_id,
        created_at=found.started_at.isoformat(),
        updated_at=found.updated_at.isoformat(),
        user_id=found.user_id or None,
        messages=serialized_messages,
    )
|
||||
|
||||
|
||||
@router.get(
    "/sessions/{session_id}/stream",
)
async def stream_chat(
    session_id: str,
    message: Annotated[str, Query(min_length=1, max_length=10000)],
    user_id: str | None = Depends(auth.get_user_id),
    is_user_message: bool = Query(default=True),
):
    """
    Stream chat responses for a session over Server-Sent Events (SSE).

    The stream carries, in order of generation:
    - Text fragments as they are produced
    - Tool call UI elements (if the model invokes tools)
    - Tool execution results

    Args:
        session_id: The chat session identifier to associate with the streamed messages.
        message: The user's new message to process (1-10000 characters).
        user_id: Optional authenticated user ID.
        is_user_message: Whether the message should be recorded with the "user" role.

    Returns:
        StreamingResponse: SSE-formatted response chunks.

    Raises:
        NotFoundError: If the session does not exist.
    """
    # Validate session exists before starting the stream
    # This prevents errors after the response has already started
    session = await chat_service.get_session(session_id, user_id)

    if not session:
        raise NotFoundError(f"Session {session_id} not found. ")
    # Side effect: an authenticated caller streaming into a previously
    # anonymous session claims ownership of it before streaming starts.
    if session.user_id is None and user_id is not None:
        session = await chat_service.assign_user_to_session(session_id, user_id)

    async def event_generator() -> AsyncGenerator[str, None]:
        # Each chunk is a response-model object; to_sse() renders the SSE wire format.
        async for chunk in chat_service.stream_chat_completion(
            session_id,
            message,
            is_user_message=is_user_message,
            user_id=user_id,
            session=session,  # Pass pre-fetched session to avoid double-fetch
        ):
            yield chunk.to_sse()

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            # Prevent intermediaries from caching or buffering the event stream.
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
        },
    )
|
||||
|
||||
|
||||
@router.patch(
    "/sessions/{session_id}/assign-user",
    dependencies=[Security(auth.requires_user)],
    status_code=200,
)
async def session_assign_user(
    session_id: str,
    user_id: Annotated[str, Security(auth.get_user_id)],
) -> dict:
    """
    Claim a chat session for an authenticated user.

    Typically called right after login so that a session started
    anonymously becomes owned by the now-authenticated user.

    Args:
        session_id: Identifier of the (previously anonymous) session.
        user_id: The authenticated user's ID to attach to the session.

    Returns:
        dict: ``{"status": "ok"}`` on success.
    """
    await chat_service.assign_user_to_session(session_id, user_id)
    result = {"status": "ok"}
    return result
|
||||
|
||||
|
||||
# ========== Health Check ==========
|
||||
|
||||
|
||||
@router.get("/health", status_code=200)
async def health_check() -> dict:
    """
    Health check endpoint for the chat service.

    Exercises a full session lifecycle (create, assign user, fetch) against
    the data layer; any failure propagates as an error response, so a
    successful return indicates the service and storage are operational.

    Returns:
        dict: Health status, service name, and API version.
    """
    probe = await chat_service.create_chat_session(None)
    await chat_service.assign_user_to_session(probe.session_id, "test_user")
    await chat_service.get_session(probe.session_id, "test_user")

    status_report = {
        "status": "healthy",
        "service": "chat",
        "version": "0.1.0",
    }
    return status_report
|
||||
537
autogpt_platform/backend/backend/server/v2/chat/service.py
Normal file
537
autogpt_platform/backend/backend/server/v2/chat/service.py
Normal file
@@ -0,0 +1,537 @@
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
||||
|
||||
import backend.server.v2.chat.config
|
||||
from backend.server.v2.chat.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.server.v2.chat.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamEnd,
|
||||
StreamError,
|
||||
StreamTextChunk,
|
||||
StreamTextEnded,
|
||||
StreamToolCall,
|
||||
StreamToolCallStart,
|
||||
StreamToolExecutionResult,
|
||||
StreamUsage,
|
||||
)
|
||||
from backend.server.v2.chat.tools import execute_tool, tools
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = backend.server.v2.chat.config.ChatConfig()
|
||||
client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
|
||||
|
||||
async def create_chat_session(
    user_id: str | None = None,
) -> ChatSession:
    """Create a fresh chat session and persist it immediately.

    Persisting up front means the session can be used for streaming
    right away without a separate save step.

    Args:
        user_id: Owner of the session, or None for an anonymous session.

    Returns:
        The persisted ChatSession as returned by the store.
    """
    fresh = ChatSession.new(user_id)
    persisted = await upsert_chat_session(fresh)
    return persisted
|
||||
|
||||
|
||||
async def get_session(
    session_id: str,
    user_id: str | None = None,
) -> ChatSession | None:
    """Look up a chat session by its identifier.

    Args:
        session_id: Identifier of the session to fetch.
        user_id: Optional user scoping the lookup.

    Returns:
        The matching ChatSession, or None when no session exists.
    """
    found = await get_chat_session(session_id, user_id)
    return found
|
||||
|
||||
|
||||
async def assign_user_to_session(
    session_id: str,
    user_id: str,
) -> ChatSession:
    """Attach a user to an existing chat session and persist the change.

    Args:
        session_id: Identifier of the session to claim.
        user_id: The user ID to record as the session owner.

    Returns:
        The updated, persisted ChatSession.

    Raises:
        NotFoundError: If the session does not exist.
    """
    existing = await get_chat_session(session_id, None)
    if existing is None:
        raise NotFoundError(f"Session {session_id} not found")

    existing.user_id = user_id
    return await upsert_chat_session(existing)
|
||||
|
||||
|
||||
async def stream_chat_completion(
    session_id: str,
    message: str | None = None,
    is_user_message: bool = True,
    user_id: str | None = None,
    retry_count: int = 0,
    session: ChatSession | None = None,
) -> AsyncGenerator[StreamBaseResponse, None]:
    """Main entry point for streaming chat completions with database handling.

    This function handles all database operations (loading, updating, and
    persisting the session) and delegates the actual model streaming to the
    internal _stream_chat_chunks function. It also re-invokes itself to
    continue the conversation after tool calls and to retry on parse errors.

    Args:
        session_id: Chat session ID
        message: Optional new message to append before streaming (role chosen
            by is_user_message)
        is_user_message: Whether the appended message carries the "user" role
        user_id: User ID for authentication (None for anonymous)
        retry_count: Current retry depth; incremented on recursive retry calls
        session: Optional pre-loaded session object (for recursive calls to
            avoid Redis refetch)

    Yields:
        StreamBaseResponse objects formatted as SSE

    Raises:
        NotFoundError: If session_id is invalid
        ValueError: If max_context_messages is exceeded
    """
    logger.info(
        f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
    )

    # Only fetch from Redis if session not provided (initial call)
    if session is None:
        session = await get_chat_session(session_id, user_id)
        logger.info(
            f"Fetched session from Redis: {session.session_id if session else 'None'}, "
            f"message_count={len(session.messages) if session else 0}"
        )
    else:
        logger.info(
            f"Using provided session object: {session.session_id}, "
            f"message_count={len(session.messages)}"
        )

    if not session:
        raise NotFoundError(
            f"Session {session_id} not found. Please create a new session first."
        )

    # Append the incoming message (if any) before streaming; on retry/tool
    # continuation calls, message is None so nothing is re-appended.
    if message:
        session.messages.append(
            ChatMessage(
                role="user" if is_user_message else "assistant", content=message
            )
        )
        logger.info(
            f"Appended message (role={'user' if is_user_message else 'assistant'}), "
            f"new message_count={len(session.messages)}"
        )

    if len(session.messages) > config.max_context_messages:
        raise ValueError(f"Max messages exceeded: {config.max_context_messages}")

    logger.info(
        f"Upserting session: {session.session_id} with user id {session.user_id}, "
        f"message_count={len(session.messages)}"
    )
    session = await upsert_chat_session(session)
    assert session, "Session not found"

    # Accumulates the assistant's streamed text (and later its tool calls)
    # so the full turn can be persisted after the stream ends.
    assistant_response = ChatMessage(
        role="assistant",
        content="",
    )

    # Bookkeeping flags for the stream lifecycle.
    has_yielded_end = False
    has_yielded_error = False
    has_done_tool_call = False
    has_received_text = False
    text_streaming_ended = False
    tool_response_messages: list[ChatMessage] = []
    accumulated_tool_calls: list[dict[str, Any]] = []
    should_retry = False

    try:
        async for chunk in _stream_chat_chunks(
            session=session,
            tools=tools,
        ):

            if isinstance(chunk, StreamTextChunk):
                content = chunk.content or ""
                assert assistant_response.content is not None
                assistant_response.content += content
                has_received_text = True
                yield chunk
            elif isinstance(chunk, StreamToolCallStart):
                # Emit text_ended before first tool call, but only if we've received text
                if has_received_text and not text_streaming_ended:
                    yield StreamTextEnded()
                    text_streaming_ended = True
                yield chunk
            elif isinstance(chunk, StreamToolCall):
                # Accumulate tool calls in OpenAI format
                accumulated_tool_calls.append(
                    {
                        "id": chunk.tool_id,
                        "type": "function",
                        "function": {
                            "name": chunk.tool_name,
                            "arguments": orjson.dumps(chunk.arguments).decode("utf-8"),
                        },
                    }
                )
            elif isinstance(chunk, StreamToolExecutionResult):
                # Tool results are stored as "tool"-role messages so the next
                # model call can see them; non-string results are JSON-encoded.
                result_content = (
                    chunk.result
                    if isinstance(chunk.result, str)
                    else orjson.dumps(chunk.result).decode("utf-8")
                )
                tool_response_messages.append(
                    ChatMessage(
                        role="tool",
                        content=result_content,
                        tool_call_id=chunk.tool_id,
                    )
                )
                has_done_tool_call = True
                # Track if any tool execution failed
                if not chunk.success:
                    logger.warning(
                        f"Tool {chunk.tool_name} (ID: {chunk.tool_id}) execution failed"
                    )
                yield chunk
            elif isinstance(chunk, StreamEnd):
                # Suppress StreamEnd when a tool call happened: the follow-up
                # recursive call below will produce the real end of turn.
                if not has_done_tool_call:
                    has_yielded_end = True
                    yield chunk
            elif isinstance(chunk, StreamError):
                # Errors from the inner stream are recorded but not forwarded
                # here; the except branch decides what the client sees.
                has_yielded_error = True
            elif isinstance(chunk, StreamUsage):
                session.usage.append(
                    Usage(
                        prompt_tokens=chunk.prompt_tokens,
                        completion_tokens=chunk.completion_tokens,
                        total_tokens=chunk.total_tokens,
                    )
                )
            else:
                logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
    except Exception as e:
        logger.error(f"Error during stream: {e!s}", exc_info=True)

        # Check if this is a retryable error (JSON parsing, incomplete tool calls, etc.)
        is_retryable = isinstance(e, (orjson.JSONDecodeError, KeyError, TypeError))

        if is_retryable and retry_count < config.max_retries:
            logger.info(
                f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
            )
            should_retry = True
        else:
            # Non-retryable error or max retries exceeded
            # Save any partial progress before reporting error
            messages_to_save: list[ChatMessage] = []

            # Add assistant message if it has content or tool calls
            if accumulated_tool_calls:
                assistant_response.tool_calls = accumulated_tool_calls
            if assistant_response.content or assistant_response.tool_calls:
                messages_to_save.append(assistant_response)

            # Add tool response messages after assistant message
            messages_to_save.extend(tool_response_messages)

            session.messages.extend(messages_to_save)
            await upsert_chat_session(session)

            if not has_yielded_error:
                error_message = str(e)
                if not is_retryable:
                    error_message = f"Non-retryable error: {error_message}"
                elif retry_count >= config.max_retries:
                    error_message = (
                        f"Max retries ({config.max_retries}) exceeded: {error_message}"
                    )

                error_response = StreamError(
                    message=error_message,
                    timestamp=datetime.now(UTC).isoformat(),
                )
                yield error_response
            if not has_yielded_end:
                yield StreamEnd(
                    timestamp=datetime.now(UTC).isoformat(),
                )
            return

    # Handle retry outside of exception handler to avoid nesting
    # NOTE(review): _stream_chat_chunks catches all exceptions internally, so
    # should_retry can only be set if an exception escapes it — confirm the
    # retry path is actually reachable.
    if should_retry and retry_count < config.max_retries:
        logger.info(
            f"Retrying stream_chat_completion for session {session_id}, attempt {retry_count + 1}"
        )
        async for chunk in stream_chat_completion(
            session_id=session.session_id,
            user_id=user_id,
            retry_count=retry_count + 1,
            session=session,
        ):
            yield chunk
        return  # Exit after retry to avoid double-saving the turn below

    # Normal completion path - save session and handle tool call continuation
    logger.info(
        f"Normal completion path: session={session.session_id}, "
        f"current message_count={len(session.messages)}"
    )

    # Build the messages list in the correct order
    messages_to_save: list[ChatMessage] = []

    # Add assistant message with tool_calls if any
    if accumulated_tool_calls:
        assistant_response.tool_calls = accumulated_tool_calls
        logger.info(
            f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
        )
    if assistant_response.content or assistant_response.tool_calls:
        messages_to_save.append(assistant_response)
        logger.info(
            f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
        )

    # Add tool response messages after assistant message
    messages_to_save.extend(tool_response_messages)
    logger.info(
        f"Saving {len(tool_response_messages)} tool response messages, "
        f"total_to_save={len(messages_to_save)}"
    )

    session.messages.extend(messages_to_save)
    logger.info(f"Extended session messages, new message_count={len(session.messages)}")
    await upsert_chat_session(session)

    # If we did a tool call, stream the chat completion again to get the next response
    if has_done_tool_call:
        logger.info(
            "Tool call executed, streaming chat completion again to get assistant response"
        )
        async for chunk in stream_chat_completion(
            session_id=session.session_id,
            user_id=user_id,
            session=session,  # Pass session object to avoid Redis refetch
        ):
            yield chunk
|
||||
|
||||
|
||||
async def _stream_chat_chunks(
    session: ChatSession,
    tools: list[ChatCompletionToolParam],
) -> AsyncGenerator[StreamBaseResponse, None]:
    """
    Pure streaming function for OpenAI chat completions with tool calling.

    This function is database-agnostic and focuses only on streaming logic:
    it forwards text deltas, accumulates partial tool-call fragments until
    they are complete, executes the tools, and emits usage/end/error events.

    Args:
        session: Chat session providing the conversation context (via
            session.to_openai_messages()) and passed through to tool execution.
        tools: Tool definitions offered to the model in OpenAI format.

    Yields:
        StreamBaseResponse objects (text chunks, tool events, usage, end, error).
    """
    model = config.model

    logger.info("Starting pure chat stream")

    # Loop to handle tool calls and continue conversation
    # NOTE(review): both the success path and the error path return on the
    # first iteration, so this loop never repeats — confirm whether
    # continuation was intended here or is handled by the caller's recursion.
    while True:
        try:
            logger.info("Creating OpenAI chat completion stream...")

            # Create the stream with proper types
            stream = await client.chat.completions.create(
                model=model,
                messages=session.to_openai_messages(),
                tools=tools,
                tool_choice="auto",
                stream=True,
            )

            # Variables to accumulate tool calls
            tool_calls: list[dict[str, Any]] = []
            active_tool_call_idx: int | None = None
            finish_reason: str | None = None
            # Track which tool call indices have had their start event emitted
            emitted_start_for_idx: set[int] = set()

            # Process the stream
            chunk: ChatCompletionChunk
            async for chunk in stream:
                if chunk.usage:
                    yield StreamUsage(
                        prompt_tokens=chunk.usage.prompt_tokens,
                        completion_tokens=chunk.usage.completion_tokens,
                        total_tokens=chunk.usage.total_tokens,
                    )

                if chunk.choices:
                    choice = chunk.choices[0]
                    delta = choice.delta

                    # Capture finish reason
                    if choice.finish_reason:
                        finish_reason = choice.finish_reason
                        logger.info(f"Finish reason: {finish_reason}")

                    # Handle content streaming
                    if delta.content:
                        # Stream the text chunk
                        text_response = StreamTextChunk(
                            content=delta.content,
                            timestamp=datetime.now(UTC).isoformat(),
                        )
                        yield text_response

                    # Handle tool calls: deltas arrive as fragments keyed by
                    # index, so each fragment is merged into tool_calls[idx].
                    if delta.tool_calls:
                        for tc_chunk in delta.tool_calls:
                            idx = tc_chunk.index

                            # Update active tool call index if needed
                            if (
                                active_tool_call_idx is None
                                or active_tool_call_idx != idx
                            ):
                                active_tool_call_idx = idx

                            # Ensure we have a tool call object at this index
                            while len(tool_calls) <= idx:
                                tool_calls.append(
                                    {
                                        "id": "",
                                        "type": "function",
                                        "function": {
                                            "name": "",
                                            "arguments": "",
                                        },
                                    },
                                )

                            # Accumulate the tool call data
                            if tc_chunk.id:
                                tool_calls[idx]["id"] = tc_chunk.id
                            if tc_chunk.function:
                                if tc_chunk.function.name:
                                    tool_calls[idx]["function"][
                                        "name"
                                    ] = tc_chunk.function.name
                                if tc_chunk.function.arguments:
                                    tool_calls[idx]["function"][
                                        "arguments"
                                    ] += tc_chunk.function.arguments

                            # Emit StreamToolCallStart only after we have the tool call ID
                            if (
                                idx not in emitted_start_for_idx
                                and tool_calls[idx]["id"]
                                and tool_calls[idx]["function"]["name"]
                            ):
                                yield StreamToolCallStart(
                                    tool_id=tool_calls[idx]["id"],
                                    tool_name=tool_calls[idx]["function"]["name"],
                                    timestamp=datetime.now(UTC).isoformat(),
                                )
                                emitted_start_for_idx.add(idx)
            logger.info(f"Stream complete. Finish reason: {finish_reason}")

            # Yield all accumulated tool calls after the stream is complete
            # This ensures all tool call arguments have been fully received
            for idx, tool_call in enumerate(tool_calls):
                try:
                    async for tc in _yield_tool_call(tool_calls, idx, session):
                        yield tc
                except (orjson.JSONDecodeError, KeyError, TypeError) as e:
                    logger.error(
                        f"Failed to parse tool call {idx}: {e}",
                        exc_info=True,
                        extra={"tool_call": tool_call},
                    )
                    yield StreamError(
                        message=f"Invalid tool call arguments for tool {tool_call.get('function', {}).get('name', 'unknown')}: {e}",
                        timestamp=datetime.now(UTC).isoformat(),
                    )
                    # Re-raise to trigger retry logic in the parent function
                    # NOTE(review): this re-raise is caught by the broad
                    # `except Exception` below in THIS function, which yields
                    # StreamError/StreamEnd and returns — so the parent's
                    # retry logic never sees the exception. Confirm intent.
                    raise

            yield StreamEnd(
                timestamp=datetime.now(UTC).isoformat(),
            )
            return
        except Exception as e:
            # Terminal error path: report the failure and close the stream.
            logger.error(f"Error in stream: {e!s}", exc_info=True)
            error_response = StreamError(
                message=str(e),
                timestamp=datetime.now(UTC).isoformat(),
            )
            yield error_response
            yield StreamEnd(
                timestamp=datetime.now(UTC).isoformat(),
            )
            return
|
||||
|
||||
|
||||
async def _yield_tool_call(
    tool_calls: list[dict[str, Any]],
    yield_idx: int,
    session: ChatSession,
) -> AsyncGenerator[StreamBaseResponse, None]:
    """
    Yield a tool call event followed by its execution result.

    Args:
        tool_calls: All accumulated tool calls in OpenAI format.
        yield_idx: Index of the tool call to process.
        session: Chat session forwarded to the tool for execution context.

    Raises:
        orjson.JSONDecodeError: If tool call arguments cannot be parsed as JSON
        KeyError: If expected tool call fields are missing
        TypeError: If tool call structure is invalid
    """
    call = tool_calls[yield_idx]
    logger.info(f"Yielding tool call: {call}")

    # Decode the accumulated JSON argument string; parse errors propagate
    # to the caller, which handles retry/error reporting.
    parsed_args = orjson.loads(call["function"]["arguments"])
    call_id = call["id"]
    name = call["function"]["name"]

    yield StreamToolCall(
        tool_id=call_id,
        tool_name=name,
        arguments=parsed_args,
        timestamp=datetime.now(UTC).isoformat(),
    )

    execution_result: StreamToolExecutionResult = await execute_tool(
        tool_name=name,
        parameters=parsed_args,
        tool_call_id=call_id,
        user_id=session.user_id,
        session=session,
    )
    logger.info(f"Yielding Tool execution response: {execution_result}")
    yield execution_result
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import asyncio

    # Manual smoke test: create a session and print every streamed chunk.
    async def main():
        demo_session = await create_chat_session()
        stream = stream_chat_completion(
            demo_session.session_id,
            "Please find me an agent that can help me with my business. Call the tool twice once with the query 'money printing agent' and once with the query 'money generating agent'",
            user_id=demo_session.user_id,
        )
        async for event in stream:
            print(event)

    asyncio.run(main())
|
||||
@@ -0,0 +1,81 @@
|
||||
import logging
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
import backend.server.v2.chat.service as chat_service
|
||||
from backend.server.v2.chat.response_model import (
|
||||
StreamEnd,
|
||||
StreamError,
|
||||
StreamTextChunk,
|
||||
StreamToolExecutionResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion():
    """
    Stream a simple greeting and verify the stream ends cleanly,
    reports no errors, and produces non-empty assistant text.
    """
    api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
    if not api_key:
        return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")

    session = await chat_service.create_chat_session()

    stream_ended = False
    saw_error = False
    collected_text = ""
    async for event in chat_service.stream_chat_completion(
        session.session_id, "Hello, how are you?", user_id=session.user_id
    ):
        logger.info(event)
        if isinstance(event, StreamError):
            saw_error = True
        elif isinstance(event, StreamTextChunk):
            collected_text += event.content
        elif isinstance(event, StreamEnd):
            stream_ended = True

    assert stream_ended, "Chat completion did not end"
    assert not saw_error, "Error occurred while streaming chat completion"
    assert collected_text, "Assistant message is empty"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion_with_tool_calls():
    """
    Stream a prompt that should trigger a tool call and verify the stream
    ends cleanly, no errors occur, at least one tool executed, and usage
    was recorded on the persisted session.
    """
    api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
    if not api_key:
        return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")

    session = await chat_service.create_chat_session()
    session = await chat_service.upsert_chat_session(session)

    stream_ended = False
    saw_error = False
    saw_tool_execution = False
    async for event in chat_service.stream_chat_completion(
        session.session_id,
        "Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
        user_id=session.user_id,
    ):
        logger.info(event)
        if isinstance(event, StreamError):
            saw_error = True
        elif isinstance(event, StreamEnd):
            stream_ended = True
        elif isinstance(event, StreamToolExecutionResult):
            saw_tool_execution = True

    assert stream_ended, "Chat completion did not end"
    assert not saw_error, "Error occurred while streaming chat completion"
    assert saw_tool_execution, "Tool calls did not occur"
    persisted = await chat_service.get_session(session.session_id)
    assert persisted, "Session not found"
    assert persisted.usage, "Usage is empty"
|
||||
@@ -0,0 +1,53 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .find_agent import FindAgentTool
|
||||
from .get_agent_details import GetAgentDetailsTool
|
||||
from .get_required_setup_info import GetRequiredSetupInfoTool
|
||||
from .run_agent import RunAgentTool
|
||||
from .setup_agent import SetupAgentTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.server.v2.chat.response_model import StreamToolExecutionResult
|
||||
|
||||
# Initialize tool instances
|
||||
find_agent_tool = FindAgentTool()
|
||||
get_agent_details_tool = GetAgentDetailsTool()
|
||||
get_required_setup_info_tool = GetRequiredSetupInfoTool()
|
||||
setup_agent_tool = SetupAgentTool()
|
||||
run_agent_tool = RunAgentTool()
|
||||
|
||||
# Export tools as OpenAI format
|
||||
tools: list[ChatCompletionToolParam] = [
|
||||
find_agent_tool.as_openai_tool(),
|
||||
get_agent_details_tool.as_openai_tool(),
|
||||
get_required_setup_info_tool.as_openai_tool(),
|
||||
setup_agent_tool.as_openai_tool(),
|
||||
run_agent_tool.as_openai_tool(),
|
||||
]
|
||||
|
||||
|
||||
async def execute_tool(
    tool_name: str,
    parameters: dict[str, Any],
    user_id: str | None,
    session: ChatSession,
    tool_call_id: str,
) -> "StreamToolExecutionResult":
    """Dispatch a named tool invocation to its tool instance and execute it.

    Args:
        tool_name: Name of the tool as emitted by the model.
        parameters: Parsed tool-call arguments, expanded as keyword args.
        user_id: Calling user's ID, or None for anonymous sessions.
        session: Chat session providing execution context.
        tool_call_id: ID of the originating tool call, forwarded to the tool.

    Returns:
        The tool's StreamToolExecutionResult.

    Raises:
        ValueError: If tool_name does not match any registered tool.
    """

    # NOTE(review): the key "schedule_agent" maps to setup_agent_tool, while
    # the exported `tools` list advertises setup_agent_tool under its own
    # declared name — confirm this key matches the name the model is given,
    # otherwise setup-agent calls will raise "Tool ... not found".
    tool_map: dict[str, BaseTool] = {
        "find_agent": find_agent_tool,
        "get_agent_details": get_agent_details_tool,
        "get_required_setup_info": get_required_setup_info_tool,
        "schedule_agent": setup_agent_tool,
        "run_agent": run_agent_tool,
    }
    if tool_name not in tool_map:
        raise ValueError(f"Tool {tool_name} not found")
    return await tool_map[tool_name].execute(
        user_id, session, tool_call_id, **parameters
    )
|
||||
@@ -0,0 +1,464 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.data.db import prisma
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.model import APIKeyCredentials
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.store import db as store_db
|
||||
|
||||
|
||||
def make_session(user_id: str | None = None):
    """Build a fresh in-memory ChatSession for tests.

    Args:
        user_id: Optional owner for the session; None makes it anonymous.

    Returns:
        A ChatSession with a random ID, empty history, and current timestamps.
    """
    session = ChatSession(
        session_id=str(uuid.uuid4()),
        user_id=user_id,
        messages=[],
        usage=[],
        started_at=datetime.now(UTC),
        updated_at=datetime.now(UTC),
        successful_agent_runs={},
        successful_agent_schedules={},
    )
    return session
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
async def setup_test_data():
    """
    Set up test data for run_agent tests:
    1. Create a test user
    2. Create a test graph (agent input -> agent output)
    3. Create a store listing and store listing version
    4. Approve the store listing version

    Returns:
        dict with keys "user", "graph", and "store_submission" for use by tests.
    """
    # 1. Create a test user (randomized sub/email avoid collisions across runs)
    user_data = {
        "sub": f"test-user-{uuid.uuid4()}",
        "email": f"test-{uuid.uuid4()}@example.com",
    }
    user = await get_or_create_user(user_data)

    # 1b. Create a profile with username for the user (required for store agent lookup)
    username = user.email.split("@")[0]
    await prisma.profile.create(
        data={
            "userId": user.id,
            "username": username,
            "name": f"Test User {username}",
            "description": "Test user profile",
            "links": [],  # Required field - empty array for test profiles
        }
    )

    # 2. Create a test graph with agent input -> agent output
    graph_id = str(uuid.uuid4())

    # Create input node
    input_node_id = str(uuid.uuid4())
    input_block = AgentInputBlock()
    input_node = Node(
        id=input_node_id,
        block_id=input_block.id,
        input_default={
            "name": "test_input",
            "title": "Test Input",
            "value": "",
            "advanced": False,
            "description": "Test input field",
            "placeholder_values": [],
        },
        metadata={"position": {"x": 0, "y": 0}},
    )

    # Create output node
    output_node_id = str(uuid.uuid4())
    output_block = AgentOutputBlock()
    output_node = Node(
        id=output_node_id,
        block_id=output_block.id,
        input_default={
            "name": "test_output",
            "title": "Test Output",
            "value": "",
            "format": "",
            "advanced": False,
            "description": "Test output field",
        },
        metadata={"position": {"x": 200, "y": 0}},
    )

    # Create link from input to output
    link = Link(
        source_id=input_node_id,
        sink_id=output_node_id,
        source_name="result",
        sink_name="value",
        is_static=True,
    )

    # Create the graph
    graph = Graph(
        id=graph_id,
        version=1,
        is_active=True,
        name="Test Agent",
        description="A simple test agent for testing",
        nodes=[input_node, output_node],
        links=[link],
    )

    created_graph = await create_graph(graph, user.id)

    # 3. Create a store listing and store listing version for the agent
    # Use unique slug to avoid constraint violations
    unique_slug = f"test-agent-{str(uuid.uuid4())[:8]}"
    store_submission = await store_db.create_store_submission(
        user_id=user.id,
        agent_id=created_graph.id,
        agent_version=created_graph.version,
        slug=unique_slug,
        name="Test Agent",
        description="A simple test agent",
        sub_heading="Test agent for unit tests",
        categories=["testing"],
        image_urls=["https://example.com/image.jpg"],
    )

    assert store_submission.store_listing_version_id is not None
    # 4. Approve the store listing version
    await store_db.review_store_submission(
        store_listing_version_id=store_submission.store_listing_version_id,
        is_approved=True,
        external_comments="Approved for testing",
        internal_comments="Test approval",
        reviewer_id=user.id,
    )

    return {
        "user": user,
        "graph": created_graph,
        "store_submission": store_submission,
    }
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
async def setup_llm_test_data():
    """
    Set up test data for LLM agent tests:
    1. Create a test user
    2. Create test OpenAI credentials for the user
    3. Create a test graph with input -> LLM block -> output
    4. Create and approve a store listing

    Returns:
        dict with "user", "graph", "credentials", and "store_submission".
    """
    key = getenv("OPENAI_API_KEY")
    if not key:
        # pytest.skip() raises Skipped, so execution never continues past this
        # line; the previous `return pytest.skip(...)` was misleading.
        pytest.skip("OPENAI_API_KEY is not set")

    # 1. Create a test user
    user_data = {
        "sub": f"test-user-{uuid.uuid4()}",
        "email": f"test-{uuid.uuid4()}@example.com",
    }
    user = await get_or_create_user(user_data)

    # 1b. Create a profile with username for the user (required for store agent lookup)
    username = user.email.split("@")[0]
    await prisma.profile.create(
        data={
            "userId": user.id,
            "username": username,
            "name": f"Test User {username}",
            "description": "Test user profile for LLM tests",
            "links": [],  # Required field - empty array for test profiles
        }
    )

    # 2. Create test OpenAI credentials for the user
    credentials = APIKeyCredentials(
        id=str(uuid.uuid4()),
        provider="openai",
        api_key=SecretStr("test-openai-api-key"),
        title="Test OpenAI API Key",
        expires_at=None,
    )

    # Store the credentials
    creds_store = IntegrationCredentialsStore()
    await creds_store.add_creds(user.id, credentials)

    # 3. Create a test graph with input -> LLM block -> output
    graph_id = str(uuid.uuid4())

    # Create input node for the prompt
    input_node_id = str(uuid.uuid4())
    input_block = AgentInputBlock()
    input_node = Node(
        id=input_node_id,
        block_id=input_block.id,
        input_default={
            "name": "user_prompt",
            "title": "User Prompt",
            "value": "",
            "advanced": False,
            "description": "Prompt for the LLM",
            "placeholder_values": [],
        },
        metadata={"position": {"x": 0, "y": 0}},
    )

    # Create LLM block node
    llm_node_id = str(uuid.uuid4())
    llm_block = AITextGeneratorBlock()
    llm_node = Node(
        id=llm_node_id,
        block_id=llm_block.id,
        input_default={
            "model": "gpt-4o-mini",
            "sys_prompt": "You are a helpful assistant.",
            "retry": 3,
            "prompt_values": {},
            # References the credentials stored above so the block can resolve them.
            "credentials": {
                "provider": "openai",
                "id": credentials.id,
                "type": "api_key",
                "title": credentials.title,
            },
        },
        metadata={"position": {"x": 300, "y": 0}},
    )

    # Create output node
    output_node_id = str(uuid.uuid4())
    output_block = AgentOutputBlock()
    output_node = Node(
        id=output_node_id,
        block_id=output_block.id,
        input_default={
            "name": "llm_response",
            "title": "LLM Response",
            "value": "",
            "format": "",
            "advanced": False,
            "description": "Response from the LLM",
        },
        metadata={"position": {"x": 600, "y": 0}},
    )

    # Create links
    # Link input.result -> llm.prompt
    link1 = Link(
        source_id=input_node_id,
        sink_id=llm_node_id,
        source_name="result",
        sink_name="prompt",
        is_static=True,
    )

    # Link llm.response -> output.value
    link2 = Link(
        source_id=llm_node_id,
        sink_id=output_node_id,
        source_name="response",
        sink_name="value",
        is_static=False,
    )

    # Create the graph
    graph = Graph(
        id=graph_id,
        version=1,
        is_active=True,
        name="LLM Test Agent",
        description="An agent that uses an LLM to process text",
        nodes=[input_node, llm_node, output_node],
        links=[link1, link2],
    )

    created_graph = await create_graph(graph, user.id)

    # 4. Create and approve a store listing
    # Unique slug avoids unique-constraint violations across test runs.
    unique_slug = f"llm-test-agent-{str(uuid.uuid4())[:8]}"
    store_submission = await store_db.create_store_submission(
        user_id=user.id,
        agent_id=created_graph.id,
        agent_version=created_graph.version,
        slug=unique_slug,
        name="LLM Test Agent",
        description="An agent with LLM capabilities",
        sub_heading="Test agent with OpenAI integration",
        categories=["testing", "ai"],
        image_urls=["https://example.com/image.jpg"],
    )
    assert store_submission.store_listing_version_id is not None
    await store_db.review_store_submission(
        store_listing_version_id=store_submission.store_listing_version_id,
        is_approved=True,
        external_comments="Approved for testing",
        internal_comments="Test approval for LLM agent",
        reviewer_id=user.id,
    )

    return {
        "user": user,
        "graph": created_graph,
        "credentials": credentials,
        "store_submission": store_submission,
    }
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
async def setup_firecrawl_test_data():
    """
    Set up test data for Firecrawl agent tests (missing credentials scenario):
    1. Create a test user (WITHOUT Firecrawl credentials)
    2. Create a test graph with input -> Firecrawl block -> output
    3. Create and approve a store listing

    Returns:
        dict with "user", "graph", and "store_submission".
    """
    # 1. Create a test user
    user_data = {
        "sub": f"test-user-{uuid.uuid4()}",
        "email": f"test-{uuid.uuid4()}@example.com",
    }
    user = await get_or_create_user(user_data)

    # 1b. Create a profile with username for the user (required for store agent lookup)
    username = user.email.split("@")[0]
    await prisma.profile.create(
        data={
            "userId": user.id,
            "username": username,
            "name": f"Test User {username}",
            "description": "Test user profile for Firecrawl tests",
            "links": [],  # Required field - empty array for test profiles
        }
    )

    # NOTE: We deliberately do NOT create Firecrawl credentials for this user
    # This tests the scenario where required credentials are missing

    # 2. Create a test graph with input -> Firecrawl block -> output
    graph_id = str(uuid.uuid4())

    # Create input node for the URL
    input_node_id = str(uuid.uuid4())
    input_block = AgentInputBlock()
    input_node = Node(
        id=input_node_id,
        block_id=input_block.id,
        input_default={
            "name": "url",
            "title": "URL to Scrape",
            "value": "",
            "advanced": False,
            "description": "URL for Firecrawl to scrape",
            "placeholder_values": [],
        },
        metadata={"position": {"x": 0, "y": 0}},
    )

    # Create Firecrawl block node
    firecrawl_node_id = str(uuid.uuid4())
    firecrawl_block = FirecrawlScrapeBlock()
    firecrawl_node = Node(
        id=firecrawl_node_id,
        block_id=firecrawl_block.id,
        input_default={
            "limit": 10,
            "only_main_content": True,
            "max_age": 3600000,
            "wait_for": 200,
            "formats": ["markdown"],
            # Credential reference with an id that was never stored above —
            # matches the missing-credentials scenario this fixture exists for.
            "credentials": {
                "provider": "firecrawl",
                "id": "test-firecrawl-id",
                "type": "api_key",
                "title": "Firecrawl API Key",
            },
        },
        metadata={"position": {"x": 300, "y": 0}},
    )

    # Create output node
    output_node_id = str(uuid.uuid4())
    output_block = AgentOutputBlock()
    output_node = Node(
        id=output_node_id,
        block_id=output_block.id,
        input_default={
            "name": "scraped_data",
            "title": "Scraped Data",
            "value": "",
            "format": "",
            "advanced": False,
            "description": "Data scraped by Firecrawl",
        },
        metadata={"position": {"x": 600, "y": 0}},
    )

    # Create links
    # Link input.result -> firecrawl.url
    link1 = Link(
        source_id=input_node_id,
        sink_id=firecrawl_node_id,
        source_name="result",
        sink_name="url",
        is_static=True,
    )

    # Link firecrawl.markdown -> output.value
    link2 = Link(
        source_id=firecrawl_node_id,
        sink_id=output_node_id,
        source_name="markdown",
        sink_name="value",
        is_static=False,
    )

    # Create the graph
    graph = Graph(
        id=graph_id,
        version=1,
        is_active=True,
        name="Firecrawl Test Agent",
        description="An agent that uses Firecrawl to scrape websites",
        nodes=[input_node, firecrawl_node, output_node],
        links=[link1, link2],
    )

    created_graph = await create_graph(graph, user.id)

    # 3. Create and approve a store listing
    # Unique slug avoids unique-constraint violations across test runs.
    unique_slug = f"firecrawl-test-agent-{str(uuid.uuid4())[:8]}"
    store_submission = await store_db.create_store_submission(
        user_id=user.id,
        agent_id=created_graph.id,
        agent_version=created_graph.version,
        slug=unique_slug,
        name="Firecrawl Test Agent",
        description="An agent with Firecrawl integration (no credentials)",
        sub_heading="Test agent requiring Firecrawl credentials",
        categories=["testing", "scraping"],
        image_urls=["https://example.com/image.jpg"],
    )
    assert store_submission.store_listing_version_id is not None
    await store_db.review_store_submission(
        store_listing_version_id=store_submission.store_listing_version_id,
        is_approved=True,
        external_comments="Approved for testing",
        internal_comments="Test approval for Firecrawl agent",
        reviewer_id=user.id,
    )

    return {
        "user": user,
        "graph": created_graph,
        "store_submission": store_submission,
    }
|
||||
119
autogpt_platform/backend/backend/server/v2/chat/tools/base.py
Normal file
119
autogpt_platform/backend/backend/server/v2/chat/tools/base.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Base classes and shared utilities for chat tools."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.response_model import StreamToolExecutionResult
|
||||
|
||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseTool:
    """Base class for all chat tools.

    Subclasses define the OpenAI function-calling metadata (name,
    description, parameters) and implement `_execute`; `execute` wraps the
    call with an authentication gate and uniform error handling.
    """

    @property
    def name(self) -> str:
        """Tool name for OpenAI function calling."""
        raise NotImplementedError

    @property
    def description(self) -> str:
        """Tool description for OpenAI."""
        raise NotImplementedError

    @property
    def parameters(self) -> dict[str, Any]:
        """Tool parameters schema for OpenAI."""
        raise NotImplementedError

    @property
    def requires_auth(self) -> bool:
        """Whether this tool requires authentication."""
        return False

    def as_openai_tool(self) -> ChatCompletionToolParam:
        """Convert to OpenAI tool format."""
        return ChatCompletionToolParam(
            type="function",
            function={
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            },
        )

    async def execute(
        self,
        user_id: str | None,
        session: ChatSession,
        tool_call_id: str,
        **kwargs,
    ) -> StreamToolExecutionResult:
        """Execute the tool with authentication check.

        Args:
            user_id: User ID (may be anonymous like "anon_123")
            session: Current chat session (its session_id is echoed in responses)
            tool_call_id: ID of the originating tool call, echoed back as tool_id
            **kwargs: Tool-specific parameters forwarded to _execute()

        Returns:
            StreamToolExecutionResult whose `result` is the JSON-serialized
            Pydantic response; success=False for auth failures or exceptions.

        """
        if self.requires_auth and not user_id:
            logger.error(
                f"Attempted tool call for {self.name} but user not authenticated"
            )
            return StreamToolExecutionResult(
                tool_id=tool_call_id,
                tool_name=self.name,
                result=NeedLoginResponse(
                    message=f"Please sign in to use {self.name}",
                    session_id=session.session_id,
                ).model_dump_json(),
                success=False,
            )

        try:
            result = await self._execute(user_id, session, **kwargs)
            return StreamToolExecutionResult(
                tool_id=tool_call_id,
                tool_name=self.name,
                result=result.model_dump_json(),
            )
        except Exception as e:
            # Broad catch is deliberate: tool failures must surface as an
            # ErrorResponse payload rather than crash the chat stream.
            logger.error(f"Error in {self.name}: {e}", exc_info=True)
            return StreamToolExecutionResult(
                tool_id=tool_call_id,
                tool_name=self.name,
                result=ErrorResponse(
                    message=f"An error occurred while executing {self.name}",
                    error=str(e),
                    session_id=session.session_id,
                ).model_dump_json(),
                success=False,
            )

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        """Internal execution logic to be implemented by subclasses.

        Args:
            user_id: User ID (authenticated or anonymous)
            session: Current chat session
            **kwargs: Tool-specific parameters

        Returns:
            Pydantic response object

        """
        raise NotImplementedError
|
||||
@@ -0,0 +1,128 @@
|
||||
"""Tool for discovering agents from marketplace and user library."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools.base import BaseTool
|
||||
from backend.server.v2.chat.tools.models import (
|
||||
AgentCarouselResponse,
|
||||
AgentInfo,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from backend.server.v2.store import db as store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FindAgentTool(BaseTool):
    """Marketplace discovery tool: finds agents matching what the user wants to do."""

    @property
    def name(self) -> str:
        """OpenAI function-calling name."""
        return "find_agent"

    @property
    def description(self) -> str:
        """Short description surfaced to the model."""
        return (
            "Discover agents from the marketplace based on capabilities and user needs."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON schema describing the single `query` argument."""
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query describing what the user wants to accomplish. Use single keywords for best results.",
                },
            },
            "required": ["query"],
        }

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        """Search the marketplace and package matches for the chat UI.

        Args:
            user_id: User ID (may be anonymous)
            session: Current chat session
            **kwargs: Expects `query`, the search string

        Returns:
            AgentCarouselResponse: List of agents found in the marketplace
            NoResultsResponse: No agents found in the marketplace
            ErrorResponse: Error message
        """
        session_id = session.session_id
        search_term = kwargs.get("query", "").strip()

        # Guard clause: an empty query cannot be searched.
        if not search_term:
            return ErrorResponse(
                message="Please provide a search query",
                session_id=session_id,
            )

        found: list[AgentInfo] = []
        try:
            logger.info(f"Searching marketplace for: {search_term}")
            page = await store_db.get_store_agents(
                search_query=search_term,
                page_size=5,
            )
            logger.info(f"Find agents tool found {len(page.agents)} agents")

            for listing in page.agents:
                # Marketplace agents are addressed as "creator/slug".
                slug_id = f"{listing.creator}/{listing.slug}"
                logger.info(f"Building agent ID = {slug_id}")
                found.append(
                    AgentInfo(
                        id=slug_id,
                        name=listing.agent_name,
                        description=listing.description or "",
                        source="marketplace",
                        in_library=False,
                        creator=listing.creator,
                        category="general",
                        rating=listing.rating,
                        runs=listing.runs,
                        is_featured=False,
                    ),
                )
        except NotFoundError:
            # Treated the same as an empty result set below.
            pass
        except DatabaseError as e:
            logger.error(f"Error searching agents: {e}", exc_info=True)
            return ErrorResponse(
                message="Failed to search for agents. Please try again.",
                error=str(e),
                session_id=session_id,
            )

        if not found:
            return NoResultsResponse(
                message=f"No agents found matching '{search_term}'. Try different keywords or browse the marketplace. If you have 3 consecutive find_agent tool calls results and found no agents. Please stop trying and ask the user if there is anything else you can help with.",
                session_id=session_id,
                suggestions=[
                    "Try more general terms",
                    "Browse categories in the marketplace",
                    "Check spelling",
                ],
            )

        # Package the hits as a carousel for the chat frontend.
        hit_count = len(found)
        title = (
            f"Found {hit_count} agent{'s' if hit_count != 1 else ''} for '{search_term}'"
        )
        return AgentCarouselResponse(
            message="Now you have found some options for the user to choose from. You can add a link to a recommended agent at: /marketplace/agent/agent_id Please ask the user if they would like to use any of these agents. If they do, please call the get_agent_details tool for this agent.",
            title=title,
            agents=found,
            count=hit_count,
            session_id=session_id,
        )
|
||||
@@ -0,0 +1,221 @@
|
||||
"""Tool for getting detailed information about a specific agent."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools.base import BaseTool
|
||||
from backend.server.v2.chat.tools.models import (
|
||||
AgentDetails,
|
||||
AgentDetailsResponse,
|
||||
ErrorResponse,
|
||||
ExecutionOptions,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from backend.server.v2.store import db as store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GetAgentDetailsTool(BaseTool):
    """Tool for getting detailed information about an agent."""

    @property
    def name(self) -> str:
        return "get_agent_details"

    @property
    def description(self) -> str:
        return "Get detailed information about a specific agent including inputs, credentials required, and execution options."

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "username_agent_slug": {
                    "type": "string",
                    "description": "The marketplace agent slug (e.g., 'username/agent-name')",
                },
            },
            "required": ["username_agent_slug"],
        }

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        """Get detailed information about an agent.

        Args:
            user_id: User ID (may be anonymous)
            session: Current chat session
            **kwargs: Expects `username_agent_slug` ("creator/agent-name")

        Returns:
            AgentDetailsResponse on success, otherwise ErrorResponse.

        """
        agent_id = kwargs.get("username_agent_slug", "").strip()
        session_id = session.session_id
        # Only the "creator/agent-name" slug form is accepted.
        if not agent_id or "/" not in agent_id:
            return ErrorResponse(
                message="Please provide an agent ID in format 'creator/agent-name'",
                session_id=session_id,
            )

        try:
            # Always try to get from marketplace first
            graph = None
            store_agent = None

            # Check if it's a slug format (username/agent_name)
            try:
                # Parse username/agent_name from slug
                username, agent_name = agent_id.split("/", 1)
                store_agent = await store_db.get_store_agent_details(
                    username, agent_name
                )
                logger.info(f"Found agent {agent_id} in marketplace")
            except NotFoundError as e:
                logger.debug(f"Failed to get from marketplace: {e}")
                return ErrorResponse(
                    message=f"Agent '{agent_id}' not found",
                    session_id=session_id,
                )
            except DatabaseError as e:
                logger.error(f"Failed to get from marketplace: {e}")
                return ErrorResponse(
                    message=f"Failed to get agent details: {e!s}",
                    session_id=session_id,
                )

            # If we found a store agent, get its graph
            if store_agent:
                try:
                    # Use get_available_graph to get the graph from store listing version
                    graph_meta = await store_db.get_available_graph(
                        store_agent.store_listing_version_id
                    )
                    # Now get the full graph with that ID
                    graph = await graph_db.get_graph(
                        graph_id=graph_meta.id,
                        version=graph_meta.version,
                        user_id=None,  # Public access
                        include_subgraphs=True,
                    )

                except NotFoundError as e:
                    logger.error(f"Failed to get graph for store agent: {e}")
                    return ErrorResponse(
                        message=f"Failed to get graph for store agent: {e!s}",
                        session_id=session_id,
                    )
                except DatabaseError as e:
                    logger.error(f"Failed to get graph for store agent: {e}")
                    return ErrorResponse(
                        message=f"Failed to get graph for store agent: {e!s}",
                        session_id=session_id,
                    )

            if not graph:
                return ErrorResponse(
                    message=f"Agent '{agent_id}' not found",
                    session_id=session_id,
                )

            credentials_input_schema = graph.credentials_input_schema

            # Extract credentials from the JSON schema properties
            credentials = []
            if (
                isinstance(credentials_input_schema, dict)
                and "properties" in credentials_input_schema
            ):
                for cred_name, cred_schema in credentials_input_schema[
                    "properties"
                ].items():
                    # Extract credential metadata from the schema
                    # The schema properties contain provider info and other metadata

                    # Get provider from credentials_provider array or properties.provider.const
                    provider = "unknown"
                    if (
                        "credentials_provider" in cred_schema
                        and cred_schema["credentials_provider"]
                    ):
                        provider = cred_schema["credentials_provider"][0]
                    elif (
                        "properties" in cred_schema
                        and "provider" in cred_schema["properties"]
                    ):
                        provider = cred_schema["properties"]["provider"].get(
                            "const", "unknown"
                        )

                    # Get type from credentials_types array or properties.type.const
                    cred_type = "api_key"  # Default
                    if (
                        "credentials_types" in cred_schema
                        and cred_schema["credentials_types"]
                    ):
                        cred_type = cred_schema["credentials_types"][0]
                    elif (
                        "properties" in cred_schema
                        and "type" in cred_schema["properties"]
                    ):
                        cred_type = cred_schema["properties"]["type"].get(
                            "const", "api_key"
                        )

                    credentials.append(
                        CredentialsMetaInput(
                            id=cred_name,
                            title=cred_schema.get("title", cred_name),
                            provider=provider,  # type: ignore
                            type=cred_type,
                        )
                    )

            trigger_info = (
                graph.trigger_setup_info.model_dump()
                if graph.trigger_setup_info
                else None
            )

            agent_details = AgentDetails(
                id=graph.id,
                name=graph.name,
                description=graph.description,
                inputs=graph.input_schema,
                credentials=credentials,
                execution_options=ExecutionOptions(
                    # Currently a graph with a webhook can only be triggered by a webhook
                    manual=trigger_info is None,
                    scheduled=trigger_info is None,
                    webhook=trigger_info is not None,
                ),
                trigger_info=trigger_info,
            )

            return AgentDetailsResponse(
                message=f"Found agent '{agent_details.name}'. When presenting the agent you do not need to mention the required credentials. You do not need to run this tool again for this agent.",
                session_id=session_id,
                agent=agent_details,
                user_authenticated=user_id is not None,
                graph_id=graph.id,
                graph_version=graph.version,
            )

        except Exception as e:
            # Last-resort catch so any unexpected failure is reported as an
            # ErrorResponse instead of propagating into the chat stream.
            logger.error(f"Error getting agent details: {e}", exc_info=True)
            return ErrorResponse(
                message=f"Failed to get agent details: {e!s}",
                error=str(e),
                session_id=session_id,
            )
|
||||
@@ -0,0 +1,335 @@
|
||||
import uuid
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.chat.tools._test_data import (
|
||||
make_session,
|
||||
setup_llm_test_data,
|
||||
setup_test_data,
|
||||
)
|
||||
from backend.server.v2.chat.tools.get_agent_details import GetAgentDetailsTool
|
||||
|
||||
# This is so the formatter doesn't remove the fixture imports
|
||||
setup_llm_test_data = setup_llm_test_data
|
||||
setup_test_data = setup_test_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_get_agent_details_success(setup_test_data):
    """Test successfully getting agent details from marketplace"""
    # Use test data from fixture
    user = setup_test_data["user"]
    graph = setup_test_data["graph"]
    store_submission = setup_test_data["store_submission"]

    # Create the tool instance
    tool = GetAgentDetailsTool()

    # Build the proper marketplace agent_id format: username/slug
    # (the fixture's profile username is derived from the email local part)
    agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"

    # Build session
    session = make_session()

    # Execute the tool
    response = await tool.execute(
        user_id=user.id,
        session=session,
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=agent_marketplace_id,
    )

    # Verify the response
    assert response is not None
    assert hasattr(response, "result")

    # Parse the result JSON
    assert isinstance(response.result, str)
    result_data = orjson.loads(response.result)

    # Check the basic structure
    assert "agent" in result_data
    assert "message" in result_data
    assert "graph_id" in result_data
    assert "graph_version" in result_data
    assert "user_authenticated" in result_data

    # Check agent details
    agent = result_data["agent"]
    assert agent["id"] == graph.id
    assert agent["name"] == "Test Agent"
    assert (
        agent["description"] == "A simple test agent"
    )  # Description from store submission
    assert "inputs" in agent
    assert "credentials" in agent
    assert "execution_options" in agent

    # Check execution options
    exec_options = agent["execution_options"]
    assert "manual" in exec_options
    assert "scheduled" in exec_options
    assert "webhook" in exec_options

    # Check inputs schema
    assert isinstance(agent["inputs"], dict)
    # Should have properties for the input fields
    if "properties" in agent["inputs"]:
        assert "test_input" in agent["inputs"]["properties"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_get_agent_details_with_llm_credentials(setup_llm_test_data):
    """Test getting agent details for an agent that requires LLM credentials"""
    # Use test data from fixture
    user = setup_llm_test_data["user"]
    store_submission = setup_llm_test_data["store_submission"]

    # Create the tool instance
    tool = GetAgentDetailsTool()

    # Build the proper marketplace agent_id format
    agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"

    session = make_session(user_id=user.id)

    # Execute the tool
    response = await tool.execute(
        user_id=user.id,
        session=session,
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=agent_marketplace_id,
    )

    # Verify the response
    assert response is not None
    assert hasattr(response, "result")

    # Parse the result JSON
    assert isinstance(response.result, str)
    result_data = orjson.loads(response.result)

    # Check that agent details are returned
    assert "agent" in result_data
    agent = result_data["agent"]

    # Check that credentials are listed
    assert "credentials" in agent
    credentials = agent["credentials"]

    # The LLM agent should have OpenAI credentials listed
    assert isinstance(credentials, list)

    # Check that inputs include the user_prompt
    assert "inputs" in agent
    if "properties" in agent["inputs"]:
        assert "user_prompt" in agent["inputs"]["properties"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_get_agent_details_invalid_format():
    """Test error handling when agent_id is not in correct format"""
    details_tool = GetAgentDetailsTool()

    chat_session = make_session()
    chat_session.user_id = str(uuid.uuid4())

    # A slug without a '/' separator must be rejected before any lookup.
    execution_result = await details_tool.execute(
        user_id=chat_session.user_id,
        session=chat_session,
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug="invalid-format",
    )

    # The error payload comes back as a JSON string in `result`.
    assert execution_result is not None
    assert hasattr(execution_result, "result")
    assert isinstance(execution_result.result, str)

    payload = orjson.loads(execution_result.result)
    assert "message" in payload
    assert "creator/agent-name" in payload["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_get_agent_details_empty_slug():
    """Test error handling when agent_id is empty"""
    details_tool = GetAgentDetailsTool()

    chat_session = make_session()
    chat_session.user_id = str(uuid.uuid4())

    # An empty slug must be rejected before any marketplace lookup.
    execution_result = await details_tool.execute(
        user_id=chat_session.user_id,
        session=chat_session,
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug="",
    )

    # The error payload comes back as a JSON string in `result`.
    assert execution_result is not None
    assert hasattr(execution_result, "result")
    assert isinstance(execution_result.result, str)

    payload = orjson.loads(execution_result.result)
    assert "message" in payload
    assert "creator/agent-name" in payload["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_get_agent_details_not_found():
    """Test error handling when agent is not found in marketplace"""
    details_tool = GetAgentDetailsTool()

    chat_session = make_session()
    chat_session.user_id = str(uuid.uuid4())

    # A well-formed slug pointing at a nonexistent listing.
    execution_result = await details_tool.execute(
        user_id=chat_session.user_id,
        session=chat_session,
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug="nonexistent/agent",
    )

    # The error payload comes back as a JSON string in `result`.
    assert execution_result is not None
    assert hasattr(execution_result, "result")
    assert isinstance(execution_result.result, str)

    payload = orjson.loads(execution_result.result)
    assert "message" in payload
    assert "not found" in payload["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_get_agent_details_anonymous_user(setup_test_data):
    """Test getting agent details as an anonymous user (no user_id)"""
    # Use test data from fixture
    user = setup_test_data["user"]
    store_submission = setup_test_data["store_submission"]

    # Create the tool instance
    tool = GetAgentDetailsTool()

    # Build the proper marketplace agent_id format
    agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"

    session = make_session()
    # session.user_id stays as None

    # Execute the tool without a user_id (anonymous)
    response = await tool.execute(
        user_id=None,
        session=session,
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=agent_marketplace_id,
    )

    # Verify the response
    assert response is not None
    assert hasattr(response, "result")

    # Parse the result JSON
    assert isinstance(response.result, str)
    result_data = orjson.loads(response.result)

    # Should still get agent details (marketplace lookup is public)
    assert "agent" in result_data
    assert "user_authenticated" in result_data

    # User should be marked as not authenticated
    assert result_data["user_authenticated"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_get_agent_details_authenticated_user(setup_test_data):
    """Test getting agent details as an authenticated user"""
    user = setup_test_data["user"]
    store_submission = setup_test_data["store_submission"]

    tool = GetAgentDetailsTool()

    # Marketplace identifier: "<username>/<agent-slug>".
    slug = f"{user.email.split('@')[0]}/{store_submission.slug}"

    chat_session = make_session()
    chat_session.user_id = user.id

    # Same lookup as the anonymous case, but with a real user id attached.
    response = await tool.execute(
        user_id=user.id,
        session=chat_session,
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=slug,
    )

    assert response is not None and hasattr(response, "result")
    assert isinstance(response.result, str)

    payload = orjson.loads(response.result)

    # Details come back and the caller is flagged as authenticated.
    assert "agent" in payload
    assert "user_authenticated" in payload
    assert payload["user_authenticated"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_get_agent_details_includes_execution_options(setup_test_data):
    """Test that agent details include execution options"""
    user = setup_test_data["user"]
    store_submission = setup_test_data["store_submission"]

    tool = GetAgentDetailsTool()
    slug = f"{user.email.split('@')[0]}/{store_submission.slug}"

    chat_session = make_session()
    chat_session.user_id = user.id

    response = await tool.execute(
        user_id=user.id,
        session=chat_session,
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=slug,
    )

    assert response is not None and hasattr(response, "result")
    assert isinstance(response.result, str)

    payload = orjson.loads(response.result)

    assert "agent" in payload
    agent = payload["agent"]
    assert "execution_options" in agent
    exec_options = agent["execution_options"]

    # All three modes must be reported as booleans.
    for mode in ("manual", "scheduled", "webhook"):
        assert isinstance(exec_options[mode], bool)

    # A plain agent (no webhook trigger) runs manually or on a schedule only.
    assert exec_options["manual"] is True
    assert exec_options["scheduled"] is True
    assert exec_options["webhook"] is False
|
||||
@@ -0,0 +1,182 @@
|
||||
"""Tool for getting required setup information for an agent."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools.base import BaseTool
|
||||
from backend.server.v2.chat.tools.get_agent_details import GetAgentDetailsTool
|
||||
from backend.server.v2.chat.tools.models import (
|
||||
AgentDetailsResponse,
|
||||
ErrorResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GetRequiredSetupInfoTool(BaseTool):
    """Tool for getting required setup information including credentials and inputs."""

    @property
    def name(self) -> str:
        return "get_required_setup_info"

    @property
    def description(self) -> str:
        return """Check if an agent can be set up with the provided input data and credentials.
Call this AFTER get_agent_details to validate that you have all required inputs.
Pass the input dictionary you plan to use with run_agent or setup_agent to verify it's complete."""

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "username_agent_slug": {
                    "type": "string",
                    "description": "The marketplace agent slug (e.g., 'username/agent-name' or just 'agent-name' to search)",
                },
                "inputs": {
                    "type": "object",
                    "description": "The input dictionary you plan to provide. Should contain ALL required inputs from get_agent_details",
                    "additionalProperties": True,
                },
            },
            "required": ["username_agent_slug"],
        }

    @property
    def requires_auth(self) -> bool:
        """This tool requires authentication."""
        return True

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        """Validate that an agent can be set up/run with the provided inputs.

        Resolves the agent via GetAgentDetailsTool, then determines which
        required credentials the user is still missing and which required
        input fields were not supplied.

        Args:
            user_id: The authenticated user's ID (must not be None; auth required).
            session: The active chat session.
            **kwargs: Forwarded to GetAgentDetailsTool; notably
                ``username_agent_slug`` (marketplace slug or graph ID) and
                ``inputs`` (the input dict the caller plans to run with).

        Returns:
            SetupRequirementsResponse with agent/graph info, credential and
            input requirements, and user readiness — or an ErrorResponse if
            the agent could not be resolved.
        """
        assert (
            user_id is not None
        ), "GetRequiredSetupInfoTool - This should never happen user_id is None when auth is required"
        session_id = session.session_id
        # Call _execute directly since we're calling internally from another tool
        agent_details = await GetAgentDetailsTool()._execute(user_id, session, **kwargs)

        if isinstance(agent_details, ErrorResponse):
            return agent_details

        if not isinstance(agent_details, AgentDetailsResponse):
            return ErrorResponse(
                message="Failed to get agent details",
                session_id=session_id,
            )

        available_creds = await IntegrationCredentialsManager().store.get_all_creds(
            user_id
        )
        # Credentials the agent declares that the user has no match for
        # (matched on provider + type).
        required_credentials = [
            c
            for c in agent_details.agent.credentials
            if not any(
                cred.provider == c.provider and cred.type == c.type
                for cred in available_creds
            )
        ]

        required_fields = set(agent_details.agent.inputs.get("required", []))
        provided_inputs = kwargs.get("inputs", {})
        missing_inputs = required_fields - set(provided_inputs.keys())

        missing_credentials = {c.id: c.model_dump() for c in required_credentials}

        user_readiness = UserReadiness(
            has_all_credentials=len(required_credentials) == 0,
            missing_credentials=missing_credentials,
            ready_to_run=len(missing_inputs) == 0 and len(required_credentials) == 0,
        )

        # Convert execution options to a list of available mode names.
        exec_opts = agent_details.agent.execution_options
        execution_modes = [
            mode
            for mode, enabled in (
                ("manual", exec_opts.manual),
                ("scheduled", exec_opts.scheduled),
                ("webhook", exec_opts.webhook),
            )
            if enabled
        ]

        # Flatten the JSON-schema input properties into a list of field specs.
        inputs_list = []
        if (
            isinstance(agent_details.agent.inputs, dict)
            and "properties" in agent_details.agent.inputs
        ):
            for field_name, field_schema in agent_details.agent.inputs[
                "properties"
            ].items():
                inputs_list.append(
                    {
                        "name": field_name,
                        "title": field_schema.get("title", field_name),
                        "type": field_schema.get("type", "string"),
                        "description": field_schema.get("description", ""),
                        "required": field_name in required_fields,
                    }
                )

        requirements = {
            "credentials": agent_details.agent.credentials,
            "inputs": inputs_list,
            "execution_modes": execution_modes,
        }

        # BUG FIX: base the guidance message on what is actually *missing*
        # (required_credentials / missing_inputs) instead of everything the
        # agent declares, so it agrees with user_readiness.ready_to_run.
        # Previously the tool told the LLM to wait for credentials/inputs even
        # when the user already had them all.
        if required_credentials:
            message = "The user needs to enter credentials before proceeding. Please wait until you have a message informing you that the credentials have been entered."
        elif missing_inputs:
            message = (
                "The user needs to enter inputs before proceeding. Please wait until you have a message informing you that the inputs have been entered. The inputs are: "
                + ", ".join(sorted(missing_inputs))  # sorted for deterministic output
            )
        else:
            message = "The agent is ready to run. Please call the run_agent tool with the agent ID."

        return SetupRequirementsResponse(
            message=message,
            session_id=session_id,
            setup_info=SetupInfo(
                agent_id=agent_details.agent.id,
                agent_name=agent_details.agent.name,
                user_readiness=user_readiness,
                requirements=requirements,
            ),
            graph_id=agent_details.graph_id,
            graph_version=agent_details.graph_version,
        )
|
||||
@@ -0,0 +1,331 @@
|
||||
import uuid
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.chat.tools._test_data import (
|
||||
make_session,
|
||||
setup_firecrawl_test_data,
|
||||
setup_llm_test_data,
|
||||
setup_test_data,
|
||||
)
|
||||
from backend.server.v2.chat.tools.get_required_setup_info import (
|
||||
GetRequiredSetupInfoTool,
|
||||
)
|
||||
|
||||
# This is so the formatter doesn't remove the fixture imports
|
||||
setup_llm_test_data = setup_llm_test_data
|
||||
setup_test_data = setup_test_data
|
||||
setup_firecrawl_test_data = setup_firecrawl_test_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_get_required_setup_info_success(setup_test_data):
    """Test successfully getting setup info for a simple agent"""
    user = setup_test_data["user"]
    graph = setup_test_data["graph"]
    store_submission = setup_test_data["store_submission"]

    tool = GetRequiredSetupInfoTool()
    # Marketplace slug format is "<username>/<agent-slug>"; the username comes
    # from the email local part used by the fixture.
    agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"

    session = make_session(user_id=user.id)
    response = await tool.execute(
        user_id=user.id,
        session_id=str(uuid.uuid4()),
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=agent_marketplace_id,
        inputs={"test_input": "Hello World"},
        session=session,
    )

    assert response is not None
    assert hasattr(response, "result")

    # Tool results are serialized JSON strings.
    assert isinstance(response.result, str)
    result_data = orjson.loads(response.result)

    assert "setup_info" in result_data
    setup_info = result_data["setup_info"]

    # agent_id refers to the underlying graph.
    assert "agent_id" in setup_info
    assert setup_info["agent_id"] == graph.id
    assert "agent_name" in setup_info
    assert setup_info["agent_name"] == "Test Agent"

    # Requirements carry three buckets: credentials, inputs, execution modes.
    assert "requirements" in setup_info
    requirements = setup_info["requirements"]
    assert "credentials" in requirements
    assert "inputs" in requirements
    assert "execution_modes" in requirements

    # The simple test agent declares no credential requirements.
    assert isinstance(requirements["credentials"], list)
    assert len(requirements["credentials"]) == 0

    assert isinstance(requirements["inputs"], list)
    if len(requirements["inputs"]) > 0:
        first_input = requirements["inputs"][0]
        assert "name" in first_input
        assert "title" in first_input
        assert "type" in first_input

    # Regular (non-webhook) agents support manual and scheduled execution.
    assert isinstance(requirements["execution_modes"], list)
    assert "manual" in requirements["execution_modes"]
    assert "scheduled" in requirements["execution_modes"]

    # All required inputs were supplied, so the user is ready to run.
    assert "user_readiness" in setup_info
    user_readiness = setup_info["user_readiness"]
    assert "has_all_credentials" in user_readiness
    assert "ready_to_run" in user_readiness
    assert user_readiness["ready_to_run"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_get_required_setup_info_missing_credentials(setup_firecrawl_test_data):
    """Test getting setup info for an agent requiring missing credentials"""
    user = setup_firecrawl_test_data["user"]
    store_submission = setup_firecrawl_test_data["store_submission"]

    tool = GetRequiredSetupInfoTool()
    # Marketplace slug format: "<username>/<agent-slug>".
    agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"

    session = make_session(user_id=user.id)
    response = await tool.execute(
        user_id=user.id,
        session_id=str(uuid.uuid4()),
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=agent_marketplace_id,
        inputs={"url": "https://example.com"},
        session=session,
    )

    assert response is not None
    assert hasattr(response, "result")

    # Tool results are serialized JSON strings.
    assert isinstance(response.result, str)
    result_data = orjson.loads(response.result)

    assert "setup_info" in result_data
    setup_info = result_data["setup_info"]

    # The Firecrawl agent declares at least one credential requirement.
    requirements = setup_info["requirements"]
    assert "credentials" in requirements
    assert isinstance(requirements["credentials"], list)
    assert len(requirements["credentials"]) > 0

    # The first requirement is the Firecrawl API key.
    firecrawl_cred = requirements["credentials"][0]
    assert "provider" in firecrawl_cred
    assert firecrawl_cred["provider"] == "firecrawl"
    assert "type" in firecrawl_cred
    assert firecrawl_cred["type"] == "api_key"

    # The fixture user holds no Firecrawl credential, so readiness is False.
    user_readiness = setup_info["user_readiness"]
    assert user_readiness["has_all_credentials"] is False
    assert user_readiness["ready_to_run"] is False

    # Missing credentials are reported keyed by credential id.
    assert "missing_credentials" in user_readiness
    assert isinstance(user_readiness["missing_credentials"], dict)
    assert len(user_readiness["missing_credentials"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_get_required_setup_info_with_available_credentials(setup_llm_test_data):
    """Test getting setup info when user has required credentials"""
    user = setup_llm_test_data["user"]
    store_submission = setup_llm_test_data["store_submission"]

    tool = GetRequiredSetupInfoTool()
    slug = f"{user.email.split('@')[0]}/{store_submission.slug}"

    chat_session = make_session(user_id=user.id)
    response = await tool.execute(
        user_id=user.id,
        session_id=str(uuid.uuid4()),
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=slug,
        inputs={"user_prompt": "What is 2+2?"},
        session=chat_session,
    )

    assert response is not None and hasattr(response, "result")
    assert isinstance(response.result, str)

    payload = orjson.loads(response.result)
    readiness = payload["setup_info"]["user_readiness"]

    # The fixture pre-provisions the LLM credential, so nothing is missing.
    assert readiness["has_all_credentials"] is True
    assert readiness["ready_to_run"] is True
    assert "missing_credentials" in readiness
    assert len(readiness["missing_credentials"]) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_get_required_setup_info_missing_inputs(setup_test_data):
    """Test getting setup info when required inputs are not provided"""
    user = setup_test_data["user"]
    store_submission = setup_test_data["store_submission"]

    tool = GetRequiredSetupInfoTool()
    slug = f"{user.email.split('@')[0]}/{store_submission.slug}"

    chat_session = make_session(user_id=user.id)
    # Deliberately pass no inputs at all.
    response = await tool.execute(
        user_id=user.id,
        session_id=str(uuid.uuid4()),
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=slug,
        inputs={},
        session=chat_session,
    )

    assert response is not None and hasattr(response, "result")
    assert isinstance(response.result, str)

    payload = orjson.loads(response.result)
    setup_info = payload["setup_info"]

    # Input requirements are always reported as a list of field specs.
    requirements = setup_info["requirements"]
    assert "inputs" in requirements
    assert isinstance(requirements["inputs"], list)

    # Readiness is always reported; its value depends on the agent's schema.
    assert "ready_to_run" in setup_info["user_readiness"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_get_required_setup_info_invalid_agent():
    """Test getting setup info for a non-existent agent"""
    tool = GetRequiredSetupInfoTool()

    # BUG FIX: use the same user id for the session and the execute() call.
    # Previously the session was built with user_id=None while execute()
    # received a fresh random id, so the session and the call disagreed about
    # who the caller was.
    user_id = str(uuid.uuid4())
    session = make_session(user_id=user_id)
    response = await tool.execute(
        user_id=user_id,
        session_id=str(uuid.uuid4()),
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug="invalid/agent",
        inputs={},
        session=session,
    )

    assert response is not None
    assert hasattr(response, "result")

    # Tool results are serialized JSON strings.
    assert isinstance(response.result, str)
    result_data = orjson.loads(response.result)
    assert "message" in result_data
    # Accept any of the error phrasings the tool chain may produce.
    assert any(
        phrase in result_data["message"].lower()
        for phrase in ["not found", "failed", "error"]
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_get_required_setup_info_graph_metadata(setup_test_data):
    """Test that setup info includes graph metadata"""
    user = setup_test_data["user"]
    graph = setup_test_data["graph"]
    store_submission = setup_test_data["store_submission"]

    tool = GetRequiredSetupInfoTool()
    slug = f"{user.email.split('@')[0]}/{store_submission.slug}"

    chat_session = make_session(user_id=user.id)
    response = await tool.execute(
        user_id=user.id,
        session_id=str(uuid.uuid4()),
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=slug,
        inputs={"test_input": "test"},
        session=chat_session,
    )

    assert response is not None and hasattr(response, "result")
    assert isinstance(response.result, str)

    payload = orjson.loads(response.result)

    # The resolved graph identity is carried at the top level of the response.
    assert "graph_id" in payload
    assert payload["graph_id"] == graph.id
    assert "graph_version" in payload
    assert payload["graph_version"] == graph.version
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_get_required_setup_info_inputs_structure(setup_test_data):
    """Test that inputs are properly structured as a list"""
    user = setup_test_data["user"]
    store_submission = setup_test_data["store_submission"]

    tool = GetRequiredSetupInfoTool()
    slug = f"{user.email.split('@')[0]}/{store_submission.slug}"

    chat_session = make_session(user_id=user.id)
    response = await tool.execute(
        user_id=user.id,
        session_id=str(uuid.uuid4()),
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=slug,
        inputs={},
        session=chat_session,
    )

    assert response is not None and hasattr(response, "result")
    assert isinstance(response.result, str)

    payload = orjson.loads(response.result)
    requirements = payload["setup_info"]["requirements"]

    assert isinstance(requirements["inputs"], list)

    # Every input spec must be a dict carrying the full field contract.
    expected_keys = ("name", "title", "type", "description", "required")
    for field_spec in requirements["inputs"]:
        assert isinstance(field_spec, dict)
        for key in expected_keys:
            assert key in field_spec
        assert isinstance(field_spec["required"], bool)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_get_required_setup_info_execution_modes_structure(setup_test_data):
    """Test that execution_modes are properly structured as a list"""
    user = setup_test_data["user"]
    store_submission = setup_test_data["store_submission"]

    tool = GetRequiredSetupInfoTool()
    slug = f"{user.email.split('@')[0]}/{store_submission.slug}"

    chat_session = make_session(user_id=user.id)
    response = await tool.execute(
        user_id=user.id,
        session_id=str(uuid.uuid4()),
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=slug,
        inputs={},
        session=chat_session,
    )

    assert response is not None and hasattr(response, "result")
    assert isinstance(response.result, str)

    payload = orjson.loads(response.result)
    requirements = payload["setup_info"]["requirements"]

    # Execution modes are a flat list of known mode names.
    assert isinstance(requirements["execution_modes"], list)
    for mode in requirements["execution_modes"]:
        assert isinstance(mode, str)
        assert mode in ["manual", "scheduled", "webhook"]
|
||||
281
autogpt_platform/backend/backend/server/v2/chat/tools/models.py
Normal file
281
autogpt_platform/backend/backend/server/v2/chat/tools/models.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""Pydantic models for tool responses."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
|
||||
|
||||
class ResponseType(str, Enum):
    """Types of tool responses.

    Serves as the discriminator value on ToolResponseBase.type; str-valued so
    it serializes directly into the JSON tool result.
    """

    # Agent discovery & detail views
    AGENT_CAROUSEL = "agent_carousel"
    AGENT_DETAILS = "agent_details"
    AGENT_DETAILS_NEED_LOGIN = "agent_details_need_login"
    AGENT_DETAILS_NEED_CREDENTIALS = "agent_details_need_credentials"
    # Setup / scheduling / trigger creation outcomes
    SETUP_REQUIREMENTS = "setup_requirements"
    SCHEDULE_CREATED = "schedule_created"
    WEBHOOK_CREATED = "webhook_created"
    PRESET_CREATED = "preset_created"
    # Execution lifecycle
    EXECUTION_STARTED = "execution_started"
    # Auth / error / misc conditions
    NEED_LOGIN = "need_login"
    NEED_CREDENTIALS = "need_credentials"
    INSUFFICIENT_CREDITS = "insufficient_credits"
    VALIDATION_ERROR = "validation_error"
    ERROR = "error"
    NO_RESULTS = "no_results"
    SUCCESS = "success"
|
||||
|
||||
|
||||
# Base response model
|
||||
class ToolResponseBase(BaseModel):
    """Base model for all tool responses."""

    # Discriminator identifying the concrete response shape.
    type: ResponseType
    # Human/LLM-readable summary of the outcome.
    message: str
    # Chat session this response belongs to, if any.
    session_id: str | None = None
|
||||
|
||||
|
||||
# Agent discovery models
|
||||
class AgentInfo(BaseModel):
|
||||
"""Information about an agent."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
source: str = Field(description="marketplace or library")
|
||||
in_library: bool = False
|
||||
creator: str | None = None
|
||||
category: str | None = None
|
||||
rating: float | None = None
|
||||
runs: int | None = None
|
||||
is_featured: bool | None = None
|
||||
status: str | None = None
|
||||
can_access_graph: bool | None = None
|
||||
has_external_trigger: bool | None = None
|
||||
new_output: bool | None = None
|
||||
graph_id: str | None = None
|
||||
|
||||
|
||||
class AgentCarouselResponse(ToolResponseBase):
|
||||
"""Response for find_agent tool."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_CAROUSEL
|
||||
title: str = "Available Agents"
|
||||
agents: list[AgentInfo]
|
||||
count: int
|
||||
name: str = "agent_carousel"
|
||||
|
||||
|
||||
class NoResultsResponse(ToolResponseBase):
|
||||
"""Response when no agents found."""
|
||||
|
||||
type: ResponseType = ResponseType.NO_RESULTS
|
||||
suggestions: list[str] = []
|
||||
name: str = "no_results"
|
||||
|
||||
|
||||
# Agent details models
|
||||
class InputField(BaseModel):
|
||||
"""Input field specification."""
|
||||
|
||||
name: str
|
||||
type: str = "string"
|
||||
description: str = ""
|
||||
required: bool = False
|
||||
default: Any | None = None
|
||||
options: list[Any] | None = None
|
||||
format: str | None = None
|
||||
|
||||
|
||||
class ExecutionOptions(BaseModel):
|
||||
"""Available execution options for an agent."""
|
||||
|
||||
manual: bool = True
|
||||
scheduled: bool = True
|
||||
webhook: bool = False
|
||||
|
||||
|
||||
class AgentDetails(BaseModel):
|
||||
"""Detailed agent information."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
in_library: bool = False
|
||||
inputs: dict[str, Any] = {}
|
||||
credentials: list[CredentialsMetaInput] = []
|
||||
execution_options: ExecutionOptions = Field(default_factory=ExecutionOptions)
|
||||
trigger_info: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AgentDetailsResponse(ToolResponseBase):
|
||||
"""Response for get_agent_details tool."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_DETAILS
|
||||
agent: AgentDetails
|
||||
user_authenticated: bool = False
|
||||
graph_id: str | None = None
|
||||
graph_version: int | None = None
|
||||
|
||||
|
||||
class AgentDetailsNeedLoginResponse(ToolResponseBase):
|
||||
"""Response when agent details need login."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_DETAILS_NEED_LOGIN
|
||||
agent: AgentDetails
|
||||
agent_info: dict[str, Any] | None = None
|
||||
graph_id: str | None = None
|
||||
graph_version: int | None = None
|
||||
|
||||
|
||||
class AgentDetailsNeedCredentialsResponse(ToolResponseBase):
|
||||
"""Response when agent needs credentials to be configured."""
|
||||
|
||||
type: ResponseType = ResponseType.NEED_CREDENTIALS
|
||||
agent: AgentDetails
|
||||
credentials_schema: dict[str, Any]
|
||||
agent_info: dict[str, Any] | None = None
|
||||
graph_id: str | None = None
|
||||
graph_version: int | None = None
|
||||
|
||||
|
||||
# Setup info models
|
||||
class SetupRequirementInfo(BaseModel):
|
||||
"""Setup requirement information."""
|
||||
|
||||
key: str
|
||||
provider: str
|
||||
required: bool = True
|
||||
user_has: bool = False
|
||||
credential_id: str | None = None
|
||||
type: str | None = None
|
||||
scopes: list[str] | None = None
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class ExecutionModeInfo(BaseModel):
|
||||
"""Execution mode information."""
|
||||
|
||||
type: str # manual, scheduled, webhook
|
||||
description: str
|
||||
supported: bool
|
||||
config_required: dict[str, str] | None = None
|
||||
trigger_info: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class UserReadiness(BaseModel):
    """User readiness status."""

    # True when the user already holds every credential the agent requires.
    has_all_credentials: bool = False
    # Missing credentials, keyed by credential id (serialized metadata).
    missing_credentials: dict[str, Any] = {}
    # True when neither credentials nor required inputs are missing.
    ready_to_run: bool = False
|
||||
|
||||
|
||||
class SetupInfo(BaseModel):
    """Complete setup information."""

    agent_id: str
    agent_name: str
    # Requirement buckets: credential specs, input field specs, and the
    # execution modes the agent supports.
    requirements: dict[str, list[Any]] = Field(
        default_factory=lambda: {
            "credentials": [],
            "inputs": [],
            "execution_modes": [],
        },
    )
    # How ready the current user is to run this agent.
    user_readiness: UserReadiness = Field(default_factory=UserReadiness)
    # Ordered, human-readable setup steps (optional).
    setup_instructions: list[str] = []
|
||||
|
||||
|
||||
class SetupRequirementsResponse(ToolResponseBase):
    """Response for get_required_setup_info tool."""

    type: ResponseType = ResponseType.SETUP_REQUIREMENTS
    # Requirements and readiness details for the resolved agent.
    setup_info: SetupInfo
    # Resolved graph identity, when the agent maps to a graph.
    graph_id: str | None = None
    graph_version: int | None = None
|
||||
|
||||
|
||||
# Setup agent models
|
||||
class ScheduleCreatedResponse(ToolResponseBase):
|
||||
"""Response for scheduled agent setup."""
|
||||
|
||||
type: ResponseType = ResponseType.SCHEDULE_CREATED
|
||||
schedule_id: str
|
||||
name: str
|
||||
cron: str
|
||||
timezone: str = "UTC"
|
||||
next_run: str | None = None
|
||||
graph_id: str
|
||||
graph_name: str
|
||||
|
||||
|
||||
class WebhookCreatedResponse(ToolResponseBase):
|
||||
"""Response for webhook agent setup."""
|
||||
|
||||
type: ResponseType = ResponseType.WEBHOOK_CREATED
|
||||
webhook_id: str
|
||||
webhook_url: str
|
||||
preset_id: str | None = None
|
||||
name: str
|
||||
graph_id: str
|
||||
graph_name: str
|
||||
|
||||
|
||||
class PresetCreatedResponse(ToolResponseBase):
|
||||
"""Response for preset agent setup."""
|
||||
|
||||
type: ResponseType = ResponseType.PRESET_CREATED
|
||||
preset_id: str
|
||||
name: str
|
||||
graph_id: str
|
||||
graph_name: str
|
||||
|
||||
|
||||
# Run agent models
|
||||
class ExecutionStartedResponse(ToolResponseBase):
|
||||
"""Response for agent execution started."""
|
||||
|
||||
type: ResponseType = ResponseType.EXECUTION_STARTED
|
||||
execution_id: str
|
||||
graph_id: str
|
||||
graph_name: str
|
||||
status: str = "QUEUED"
|
||||
ended_at: str | None = None
|
||||
outputs: dict[str, Any] | None = None
|
||||
error: str | None = None
|
||||
timeout_reached: bool | None = None
|
||||
|
||||
|
||||
class InsufficientCreditsResponse(ToolResponseBase):
|
||||
"""Response for insufficient credits."""
|
||||
|
||||
type: ResponseType = ResponseType.INSUFFICIENT_CREDITS
|
||||
balance: float
|
||||
|
||||
|
||||
class ValidationErrorResponse(ToolResponseBase):
|
||||
"""Response for validation errors."""
|
||||
|
||||
type: ResponseType = ResponseType.VALIDATION_ERROR
|
||||
error: str
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# Auth/error models
|
||||
class NeedLoginResponse(ToolResponseBase):
|
||||
"""Response when login is needed."""
|
||||
|
||||
type: ResponseType = ResponseType.NEED_LOGIN
|
||||
agent_info: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ErrorResponse(ToolResponseBase):
    """Response for errors.

    Generic failure response: an optional error string plus optional
    structured details.
    """

    type: ResponseType = ResponseType.ERROR
    # Human-readable error description, when available.
    error: str | None = None
    # Optional structured information about the failure.
    details: dict[str, Any] | None = None
|
||||
@@ -0,0 +1,256 @@
|
||||
"""Tool for running an agent manually (one-off execution)."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.server.v2.chat.config import ChatConfig
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools.base import BaseTool
|
||||
from backend.server.v2.chat.tools.get_required_setup_info import (
|
||||
GetRequiredSetupInfoTool,
|
||||
)
|
||||
from backend.server.v2.chat.tools.models import (
|
||||
ErrorResponse,
|
||||
ExecutionStartedResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from backend.server.v2.library import db as library_db
|
||||
from backend.server.v2.library import model as library_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class RunAgentTool(BaseTool):
    """Tool for executing an agent manually with immediate results.

    Resolves the target graph via ``GetRequiredSetupInfoTool``, ensures the
    agent exists in the user's library, matches the user's stored credentials
    against the graph's aggregated credential requirements, and enqueues a
    one-off graph execution.
    """

    @property
    def name(self) -> str:
        return "run_agent"

    @property
    def description(self) -> str:
        return """Run an agent immediately (one-off manual execution).
IMPORTANT: Before calling this tool, you MUST first call get_agent_details to determine what inputs are required.
The 'inputs' parameter must be a dictionary containing ALL required input values identified by get_agent_details.
Example: If get_agent_details shows required inputs 'search_query' and 'max_results', you must pass:
inputs={"search_query": "user's query", "max_results": 10}"""

    @property
    def parameters(self) -> dict[str, Any]:
        # JSON schema for the LLM tool-call arguments.
        return {
            "type": "object",
            "properties": {
                "username_agent_slug": {
                    "type": "string",
                    "description": "The ID of the agent to run (graph ID or marketplace slug)",
                },
                "inputs": {
                    "type": "object",
                    "description": 'REQUIRED: Dictionary of input values. Must include ALL required inputs from get_agent_details. Format: {"input_name": value}',
                    "additionalProperties": True,
                },
            },
            "required": ["username_agent_slug"],
        }

    @property
    def requires_auth(self) -> bool:
        """This tool requires authentication."""
        return True

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        """Execute an agent manually.

        Args:
            user_id: Authenticated user ID (non-None; enforced by the base class).
            session: Chat session; tracks per-graph successful run counts.
            **kwargs: Tool-call arguments (``username_agent_slug``, ``inputs``).

        Returns:
            ``ExecutionStartedResponse`` on success, ``ErrorResponse`` otherwise.
        """
        assert (
            user_id is not None
        ), "User ID is required to run an agent. Superclass enforces authentication."

        session_id = session.session_id
        username_agent_slug = kwargs.get("username_agent_slug", "").strip()
        inputs = kwargs.get("inputs", {})

        # Call _execute directly since we're calling internally from another tool
        response = await GetRequiredSetupInfoTool()._execute(user_id, session, **kwargs)

        if not isinstance(response, SetupRequirementsResponse):
            return ErrorResponse(
                message="Failed to get required setup information",
                session_id=session_id,
            )

        setup_info = SetupInfo.model_validate(response.setup_info)

        if not setup_info.user_readiness.ready_to_run:
            # Fixed typo in the user-facing message ("Requirments" -> "Requirements").
            return ErrorResponse(
                message=f"User is not ready to run the agent. User Readiness: {setup_info.user_readiness.model_dump_json()} Requirements: {setup_info.requirements}",
                session_id=session_id,
            )

        # Get the graph using the graph_id and graph_version from the setup response
        if not response.graph_id or not response.graph_version:
            return ErrorResponse(
                message=f"Graph information not available for {username_agent_slug}",
                session_id=session_id,
            )

        graph = await get_graph(
            graph_id=response.graph_id,
            version=response.graph_version,
            user_id=None,  # Public access for store graphs
            include_subgraphs=True,
        )

        if not graph:
            return ErrorResponse(
                message=f"Graph {username_agent_slug} ({response.graph_id}v{response.graph_version}) not found",
                session_id=session_id,
            )

        # Enforce the per-session cap on manual runs of this graph.
        # (Fixed: the message previously referred to "schedules", but this guard
        # checks successful *runs* against max_agent_runs. Also dropped the
        # redundant `graph and` — graph is already verified non-None above.)
        if session.successful_agent_runs.get(graph.id, 0) >= config.max_agent_runs:
            return ErrorResponse(
                message="Maximum number of agent runs reached. You can't run this agent again in this chat session.",
                session_id=session.session_id,
            )

        # Check if we already have a library agent for this graph
        existing_library_agent = await library_db.get_library_agent_by_graph_id(
            graph_id=graph.id, user_id=user_id
        )
        if not existing_library_agent:
            # Now we need to add the graph to the user's library
            library_agents: list[library_model.LibraryAgent] = (
                await library_db.create_library_agent(
                    graph=graph,
                    user_id=user_id,
                    create_library_agents_for_sub_graphs=False,
                )
            )
            assert len(library_agents) == 1, "Expected 1 library agent to be created"
            library_agent = library_agents[0]
        else:
            library_agent = existing_library_agent

        # Build credentials mapping for the graph
        graph_credentials_inputs: dict[str, CredentialsMetaInput] = {}

        # Get aggregated credentials requirements from the graph
        aggregated_creds = graph.aggregate_credentials_inputs()
        logger.debug(
            f"Matching credentials for graph {graph.id}: {len(aggregated_creds)} required"
        )

        if aggregated_creds:
            # Get all available credentials for the user
            creds_manager = IntegrationCredentialsManager()
            available_creds = await creds_manager.store.get_all_creds(user_id)

            # Track unmatched credentials for error reporting
            missing_creds: list[str] = []

            # For each required credential field, find a matching user credential.
            # credential_requirements.provider is a frozenset because
            # aggregate_credentials_inputs() combines requirements from multiple
            # nodes. A credential matches if its provider is in the set of
            # acceptable providers and its type is supported.
            for credential_field_name, (
                credential_requirements,
                _node_fields,
            ) in aggregated_creds.items():
                # Find first matching credential by provider and type
                matching_cred = next(
                    (
                        cred
                        for cred in available_creds
                        if cred.provider in credential_requirements.provider
                        and cred.type in credential_requirements.supported_types
                    ),
                    None,
                )

                if matching_cred:
                    # Use Pydantic validation to ensure type safety
                    try:
                        graph_credentials_inputs[credential_field_name] = (
                            CredentialsMetaInput(
                                id=matching_cred.id,
                                provider=matching_cred.provider,  # type: ignore
                                type=matching_cred.type,
                                title=matching_cred.title,
                            )
                        )
                    except Exception as e:
                        logger.error(
                            f"Failed to create CredentialsMetaInput for field '{credential_field_name}': "
                            f"provider={matching_cred.provider}, type={matching_cred.type}, "
                            f"credential_id={matching_cred.id}",
                            exc_info=True,
                        )
                        missing_creds.append(
                            f"{credential_field_name} (validation failed: {e})"
                        )
                else:
                    missing_creds.append(
                        f"{credential_field_name} "
                        f"(requires provider in {list(credential_requirements.provider)}, "
                        f"type in {list(credential_requirements.supported_types)})"
                    )

            # Fail fast if any required credentials are missing
            if missing_creds:
                logger.warning(
                    f"Cannot execute agent - missing credentials: {missing_creds}"
                )
                return ErrorResponse(
                    message=f"Cannot execute agent: missing {len(missing_creds)} required credential(s). You need to call the get_required_setup_info tool to setup the credentials."
                    f"Please set up the following credentials: {', '.join(missing_creds)}",
                    session_id=session_id,
                    details={"missing_credentials": missing_creds},
                )

            logger.info(
                f"Credential matching complete: {len(graph_credentials_inputs)}/{len(aggregated_creds)} matched"
            )

        # At this point we know the user is ready to run the agent,
        # so we can enqueue the execution.
        execution = await execution_utils.add_graph_execution(
            graph_id=library_agent.graph_id,
            user_id=user_id,
            inputs=inputs,
            graph_credentials_inputs=graph_credentials_inputs,
        )

        # Record the successful run so the per-session cap can be enforced.
        session.successful_agent_runs[library_agent.graph_id] = (
            session.successful_agent_runs.get(library_agent.graph_id, 0) + 1
        )

        return ExecutionStartedResponse(
            message=f"Agent execution successfully started. You can add a link to the agent at: /library/agents/{library_agent.id}. Do not run this tool again unless specifically asked to run the agent again.",
            session_id=session_id,
            execution_id=execution.id,
            graph_id=library_agent.graph_id,
            graph_name=library_agent.name,
        )
|
||||
@@ -0,0 +1,171 @@
|
||||
import uuid
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.chat.tools._test_data import (
|
||||
make_session,
|
||||
setup_llm_test_data,
|
||||
setup_test_data,
|
||||
)
|
||||
from backend.server.v2.chat.tools.run_agent import RunAgentTool
|
||||
|
||||
# This is so the formatter doesn't remove the fixture imports
|
||||
setup_llm_test_data = setup_llm_test_data
|
||||
setup_test_data = setup_test_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_run_agent(setup_test_data):
    """Happy path: run_agent starts an execution for an approved agent."""
    # Pull the fixture objects we need.
    user = setup_test_data["user"]
    graph = setup_test_data["graph"]
    store_submission = setup_test_data["store_submission"]

    run_tool = RunAgentTool()
    chat_session = make_session(user_id=user.id)

    # Marketplace references take the form "<username>/<slug>".
    slug_ref = f"{user.email.split('@')[0]}/{store_submission.slug}"

    response = await run_tool.execute(
        user_id=user.id,
        session_id=str(uuid.uuid4()),
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=slug_ref,
        inputs={"test_input": "Hello World"},
        session=chat_session,
    )

    # The tool must hand back a JSON-encoded result string.
    assert response is not None
    assert hasattr(response, "result")
    assert isinstance(response.result, str)

    # Decode and confirm an execution was actually started for our graph.
    payload = orjson.loads(response.result)
    assert "execution_id" in payload
    assert "graph_id" in payload
    assert payload["graph_id"] == graph.id
    assert "graph_name" in payload
    assert payload["graph_name"] == "Test Agent"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_run_agent_missing_inputs(setup_test_data):
    """run_agent should report an error when required inputs are omitted."""
    user = setup_test_data["user"]
    store_submission = setup_test_data["store_submission"]

    run_tool = RunAgentTool()
    chat_session = make_session(user_id=user.id)
    slug_ref = f"{user.email.split('@')[0]}/{store_submission.slug}"

    # Deliberately pass an empty inputs dict (the agent requires one input).
    response = await run_tool.execute(
        user_id=user.id,
        session_id=str(uuid.uuid4()),
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=slug_ref,
        inputs={},  # Missing required input
        session=chat_session,
    )

    # An ErrorResponse is expected when setup info says we're not ready;
    # either way the result must be a JSON string carrying a message.
    assert response is not None
    assert hasattr(response, "result")
    assert isinstance(response.result, str)

    payload = orjson.loads(response.result)
    assert "message" in payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_run_agent_invalid_agent_id(setup_test_data):
    """run_agent should surface an error for an unknown marketplace slug."""
    user = setup_test_data["user"]

    run_tool = RunAgentTool()
    chat_session = make_session(user_id=user.id)

    # Use a slug that doesn't correspond to any store listing.
    response = await run_tool.execute(
        user_id=user.id,
        session_id=str(uuid.uuid4()),
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug="invalid/agent-id",
        inputs={"test_input": "Hello World"},
        session=chat_session,
    )

    assert response is not None
    assert hasattr(response, "result")
    assert isinstance(response.result, str)

    payload = orjson.loads(response.result)
    assert "message" in payload
    # The message should indicate a lookup or setup failure.
    lowered = payload["message"].lower()
    assert any(phrase in lowered for phrase in ["not found", "failed"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_run_agent_with_llm_credentials(setup_llm_test_data):
    """run_agent should start an agent whose graph requires LLM credentials."""
    user = setup_llm_test_data["user"]
    graph = setup_llm_test_data["graph"]
    store_submission = setup_llm_test_data["store_submission"]

    run_tool = RunAgentTool()
    chat_session = make_session(user_id=user.id)
    slug_ref = f"{user.email.split('@')[0]}/{store_submission.slug}"

    # Provide the prompt input expected by the LLM agent.
    response = await run_tool.execute(
        user_id=user.id,
        session_id=str(uuid.uuid4()),
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=slug_ref,
        inputs={"user_prompt": "What is 2+2?"},
        session=chat_session,
    )

    assert response is not None
    assert hasattr(response, "result")
    assert isinstance(response.result, str)

    payload = orjson.loads(response.result)

    # Credentials are available in the fixture, so execution should start.
    assert "execution_id" in payload
    assert "graph_id" in payload
    assert payload["graph_id"] == graph.id
    assert "graph_name" in payload
    assert payload["graph_name"] == "LLM Test Agent"
|
||||
@@ -0,0 +1,395 @@
|
||||
"""Tool for setting up an agent with credentials and configuration."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.server.v2.chat.config import ChatConfig
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools.get_required_setup_info import (
|
||||
GetRequiredSetupInfoTool,
|
||||
)
|
||||
from backend.server.v2.chat.tools.models import (
|
||||
ExecutionStartedResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
)
|
||||
from backend.server.v2.library import db as library_db
|
||||
from backend.server.v2.library import model as library_model
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.timezone_utils import (
|
||||
convert_utc_time_to_user_timezone,
|
||||
get_user_timezone_or_utc,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase
|
||||
|
||||
config = ChatConfig()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentDetails(BaseModel):
    """Resolved metadata for an agent graph, gathered while preparing a setup."""

    graph_name: str
    graph_id: str
    graph_version: int
    # Cron expression recommended by the graph, if it declares one.
    recommended_schedule_cron: str | None
    # Credential-schema field name -> metadata of the credential it requires.
    required_credentials: dict[str, CredentialsMetaInput]
|
||||
|
||||
|
||||
class SetupAgentTool(BaseTool):
    """Tool for setting up an agent with scheduled execution or webhook triggers.

    Currently only the ``schedule`` setup type is implemented: the tool
    resolves the agent, ensures it exists in the user's library, maps the
    user's stored credentials onto the graph's requirements, and registers a
    cron schedule with the scheduler service.
    """

    @property
    def name(self) -> str:
        return "schedule_agent"

    @property
    def description(self) -> str:
        return """Set up an agent with credentials and configure it for scheduled execution or webhook triggers.
IMPORTANT: Before calling this tool, you MUST first call get_agent_details to determine what inputs are required.

For SCHEDULED execution:
- Cron format: "minute hour day month weekday" (e.g., "0 9 * * 1-5" = 9am weekdays)
- Common patterns: "0 * * * *" (hourly), "0 0 * * *" (daily at midnight), "0 9 * * 1" (Mondays at 9am)
- Timezone: Use IANA timezone names like "America/New_York", "Europe/London", "Asia/Tokyo"
- The 'inputs' parameter must contain ALL required inputs from get_agent_details as a dictionary

For WEBHOOK triggers:
- The agent will be triggered by external events
- Still requires all input values from get_agent_details"""

    @property
    def parameters(self) -> dict[str, Any]:
        # JSON schema for the LLM tool-call arguments.
        return {
            "type": "object",
            "properties": {
                "username_agent_slug": {
                    "type": "string",
                    "description": "The marketplace agent slug (e.g., 'username/agent-name')",
                },
                "setup_type": {
                    "type": "string",
                    "enum": ["schedule", "webhook"],
                    "description": "Type of setup: 'schedule' for cron, 'webhook' for triggers.",
                },
                "name": {
                    "type": "string",
                    "description": "Name for this setup/schedule (e.g., 'Daily Report', 'Weekly Summary')",
                },
                "description": {
                    "type": "string",
                    "description": "Description of this setup",
                },
                "cron": {
                    "type": "string",
                    "description": "Cron expression (5 fields: minute hour day month weekday). Examples: '0 9 * * 1-5' (9am weekdays), '*/30 * * * *' (every 30 min)",
                },
                "timezone": {
                    "type": "string",
                    "description": "IANA timezone (e.g., 'America/New_York', 'Europe/London', 'UTC'). Defaults to UTC if not specified.",
                },
                "inputs": {
                    "type": "object",
                    "description": 'REQUIRED: Dictionary with ALL required inputs from get_agent_details. Format: {"input_name": value}',
                    "additionalProperties": True,
                },
                "webhook_config": {
                    "type": "object",
                    "description": "Webhook configuration (required if setup_type is 'webhook')",
                    "additionalProperties": True,
                },
            },
            "required": ["username_agent_slug", "setup_type"],
        }

    @property
    def requires_auth(self) -> bool:
        """This tool requires authentication."""
        return True

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        """Set up an agent with configuration.

        Args:
            user_id: Authenticated user ID (non-None; enforced by the base class).
            session: Chat session; tracks per-graph successful schedule counts.
            **kwargs: Setup parameters (setup_type, cron, name, inputs, ...).

        Returns:
            ``ExecutionStartedResponse`` on success, ``ErrorResponse`` otherwise.
        """
        assert (
            user_id is not None
        ), "User ID is required to run an agent. Superclass enforces authentication."

        session_id = session.session_id
        setup_type = kwargs.get("setup_type", "schedule").strip()
        if setup_type != "schedule":
            # Webhook setup is declared in the schema but not implemented yet.
            return ErrorResponse(
                message="Only schedule setup is supported at this time",
                session_id=session_id,
            )

        cron = kwargs.get("cron", "").strip()
        cron_name = kwargs.get("name", "").strip()
        if not cron or not cron_name:
            return ErrorResponse(
                message="Cron and name are required for schedule setup",
                session_id=session_id,
            )

        username_agent_slug = kwargs.get("username_agent_slug", "").strip()
        inputs = kwargs.get("inputs", {})

        library_agent = await self._get_or_add_library_agent(
            username_agent_slug, user_id, session, **kwargs
        )

        if not isinstance(library_agent, AgentDetails):
            # library_agent is an ErrorResponse
            return library_agent

        # Enforce the per-session cap on schedules for this graph.
        # BUGFIX: the original conditional-expression form made ">=" bind
        # inside the `else` arm, so the guard compared nothing against
        # max_agent_schedules and fired after the first successful schedule.
        if (
            session.successful_agent_schedules.get(library_agent.graph_id, 0)
            >= config.max_agent_schedules
        ):
            return ErrorResponse(
                message="Maximum number of agent schedules reached. You can't schedule this agent again in this chat session.",
                session_id=session.session_id,
            )

        # At this point we know the user is ready to run the agent.
        # Get the library agent model for scheduling (module-level library_db
        # import is used; the redundant local re-import was removed).
        lib_agent = await library_db.get_library_agent_by_graph_id(
            graph_id=library_agent.graph_id, user_id=user_id
        )
        if not lib_agent:
            return ErrorResponse(
                message=f"Library agent not found for graph {library_agent.graph_id}",
                session_id=session_id,
            )

        return await self._add_graph_execution_schedule(
            library_agent=lib_agent,
            user_id=user_id,
            cron=cron,
            name=cron_name,
            inputs=inputs,
            credentials=library_agent.required_credentials,
            session=session,
        )

    async def _add_graph_execution_schedule(
        self,
        library_agent: library_model.LibraryAgent,
        user_id: str,
        cron: str,
        name: str,
        inputs: dict[str, Any],
        credentials: dict[str, CredentialsMetaInput],
        session: ChatSession,
        **kwargs,
    ) -> ExecutionStartedResponse | ErrorResponse:
        """Register a cron schedule for the library agent with the scheduler.

        Maps the schema-level credential requirements onto the user's actual
        stored credentials before creating the schedule.
        """
        # Use timezone from the user profile, falling back to UTC.
        user = await get_user_by_id(user_id)
        user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
        session_id = session.session_id

        # The `credentials` param keys are schema field names with placeholder
        # CredentialsMetaInput values; resolve each to a real user credential
        # matching provider and type.
        creds_manager = IntegrationCredentialsManager()
        user_credentials = await creds_manager.store.get_all_creds(user_id)

        # Build a mapping from schema field name -> actual credential ID.
        resolved_credentials: dict[str, CredentialsMetaInput] = {}
        missing_credentials: list[str] = []

        for field_name, cred_meta in credentials.items():
            # Find a matching credential from the user's credentials.
            matching_cred = next(
                (
                    c
                    for c in user_credentials
                    if c.provider == cred_meta.provider and c.type == cred_meta.type
                ),
                None,
            )

            if matching_cred:
                # Use the actual credential ID instead of the schema field name,
                # keeping the provider/type/title from the original meta.
                resolved_credentials[field_name] = CredentialsMetaInput(
                    id=matching_cred.id,
                    provider=cred_meta.provider,
                    type=cred_meta.type,
                    title=cred_meta.title,
                )
            else:
                missing_credentials.append(
                    f"{cred_meta.title} ({cred_meta.provider}/{cred_meta.type})"
                )

        if missing_credentials:
            return ErrorResponse(
                message=f"Cannot execute agent: missing {len(missing_credentials)} required credential(s). You need to call the get_required_setup_info tool to setup the credentials.",
                session_id=session_id,
            )

        result = await get_scheduler_client().add_execution_schedule(
            user_id=user_id,
            graph_id=library_agent.graph_id,
            graph_version=library_agent.graph_version,
            name=name,
            cron=cron,
            input_data=inputs,
            input_credentials=resolved_credentials,
            user_timezone=user_timezone,
        )

        # Convert the next_run_time back to user timezone for display.
        if result.next_run_time:
            result.next_run_time = convert_utc_time_to_user_timezone(
                result.next_run_time, user_timezone
            )

        # Record the successful schedule so the per-session cap can be enforced.
        session.successful_agent_schedules[library_agent.graph_id] = (
            session.successful_agent_schedules.get(library_agent.graph_id, 0) + 1
        )

        return ExecutionStartedResponse(
            message=f"Agent execution successfully scheduled. You can add a link to the agent at: /library/agents/{library_agent.id}. Do not run this tool again unless specifically asked to run the agent again.",
            session_id=session_id,
            execution_id=result.id,
            graph_id=library_agent.graph_id,
            graph_name=library_agent.name,
        )

    async def _get_or_add_library_agent(
        self, agent_id: str, user_id: str, session: ChatSession, **kwargs
    ) -> AgentDetails | ErrorResponse:
        """Resolve the agent, verify readiness, and ensure a library entry exists.

        Returns an ``AgentDetails`` on success, or an ``ErrorResponse`` when
        setup info is unavailable, the user is not ready, or the graph cannot
        be found.
        """
        session_id = session.session_id
        # Call _execute directly since we're calling internally from another tool.
        response = await GetRequiredSetupInfoTool()._execute(user_id, session, **kwargs)

        if not isinstance(response, SetupRequirementsResponse):
            return ErrorResponse(
                message="Failed to get required setup information",
                session_id=session_id,
            )

        setup_info = SetupInfo.model_validate(response.setup_info)

        if not setup_info.user_readiness.ready_to_run:
            # Fixed typo in the user-facing message ("Requirments" -> "Requirements").
            return ErrorResponse(
                message=f"User is not ready to run the agent. User Readiness: {setup_info.user_readiness.model_dump_json()} Requirements: {setup_info.requirements}",
                session_id=session_id,
            )

        # Get the graph using the graph_id and graph_version from the setup response.
        if not response.graph_id or not response.graph_version:
            return ErrorResponse(
                message=f"Graph information not available for {agent_id}",
                session_id=session_id,
            )

        graph = await get_graph(
            graph_id=response.graph_id,
            version=response.graph_version,
            user_id=None,  # Public access for store graphs
            include_subgraphs=True,
        )
        if not graph:
            return ErrorResponse(
                message=f"Graph {agent_id} ({response.graph_id}v{response.graph_version}) not found",
                session_id=session_id,
            )

        recommended_schedule_cron = graph.recommended_schedule_cron

        # Extract credentials from the JSON schema properties.
        credentials_input_schema = graph.credentials_input_schema
        required_credentials: dict[str, CredentialsMetaInput] = {}
        if (
            isinstance(credentials_input_schema, dict)
            and "properties" in credentials_input_schema
        ):
            for cred_name, cred_schema in credentials_input_schema[
                "properties"
            ].items():
                # Get provider from credentials_provider array or properties.provider.const
                provider = "unknown"
                if (
                    "credentials_provider" in cred_schema
                    and cred_schema["credentials_provider"]
                ):
                    provider = cred_schema["credentials_provider"][0]
                elif (
                    "properties" in cred_schema
                    and "provider" in cred_schema["properties"]
                ):
                    provider = cred_schema["properties"]["provider"].get(
                        "const", "unknown"
                    )

                # Get type from credentials_types array or properties.type.const
                cred_type = "api_key"  # Default
                if (
                    "credentials_types" in cred_schema
                    and cred_schema["credentials_types"]
                ):
                    cred_type = cred_schema["credentials_types"][0]
                elif (
                    "properties" in cred_schema and "type" in cred_schema["properties"]
                ):
                    cred_type = cred_schema["properties"]["type"].get(
                        "const", "api_key"
                    )

                required_credentials[cred_name] = CredentialsMetaInput(
                    id=cred_name,
                    title=cred_schema.get("title", cred_name),
                    provider=provider,  # type: ignore
                    type=cred_type,
                )

        # Check if we already have a library agent for this graph.
        existing_library_agent = await library_db.get_library_agent_by_graph_id(
            graph_id=graph.id, user_id=user_id
        )
        if not existing_library_agent:
            # Now we need to add the graph to the user's library.
            library_agents: list[library_model.LibraryAgent] = (
                await library_db.create_library_agent(
                    graph=graph,
                    user_id=user_id,
                    create_library_agents_for_sub_graphs=False,
                )
            )
            assert len(library_agents) == 1, "Expected 1 library agent to be created"
            library_agent = library_agents[0]
        else:
            library_agent = existing_library_agent

        return AgentDetails(
            graph_name=graph.name,
            graph_id=library_agent.graph_id,
            graph_version=library_agent.graph_version,
            recommended_schedule_cron=recommended_schedule_cron,
            required_credentials=required_credentials,
        )
|
||||
@@ -0,0 +1,422 @@
|
||||
import uuid
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.chat.tools._test_data import (
|
||||
make_session,
|
||||
setup_llm_test_data,
|
||||
setup_test_data,
|
||||
)
|
||||
from backend.server.v2.chat.tools.setup_agent import SetupAgentTool
|
||||
from backend.util.clients import get_scheduler_client
|
||||
|
||||
# This is so the formatter doesn't remove the fixture imports
|
||||
setup_llm_test_data = setup_llm_test_data
|
||||
setup_test_data = setup_test_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_setup_agent_missing_cron(setup_test_data):
    """schedule setup without cron/name should come back as an error."""
    user = setup_test_data["user"]
    store_submission = setup_test_data["store_submission"]

    setup_tool = SetupAgentTool()
    chat_session = make_session(user_id=user.id)
    slug_ref = f"{user.email.split('@')[0]}/{store_submission.slug}"

    # Intentionally omit both 'cron' and 'name'.
    response = await setup_tool.execute(
        user_id=user.id,
        session=chat_session,
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=slug_ref,
        setup_type="schedule",
        inputs={"test_input": "Hello World"},
    )

    assert response is not None
    assert hasattr(response, "result")
    assert isinstance(response.result, str)

    payload = orjson.loads(response.result)
    assert "message" in payload
    # The error should call out the missing cron and/or name fields.
    lowered = payload["message"].lower()
    assert "cron" in lowered or "name" in lowered
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_setup_agent_webhook_not_supported(setup_test_data):
|
||||
"""Test error when webhook setup is attempted"""
|
||||
# Use test data from fixture
|
||||
user = setup_test_data["user"]
|
||||
store_submission = setup_test_data["store_submission"]
|
||||
|
||||
# Create the tool instance
|
||||
tool = SetupAgentTool()
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Execute with webhook setup_type
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
setup_type="webhook",
|
||||
inputs={"test_input": "Hello World"},
|
||||
)
|
||||
|
||||
# Verify error response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert "message" in result_data
|
||||
message_lower = result_data["message"].lower()
|
||||
assert "schedule" in message_lower and "supported" in message_lower
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@pytest.mark.skip(reason="Requires scheduler service to be running")
|
||||
async def test_setup_agent_schedule_success(setup_test_data):
|
||||
"""Test successfully setting up an agent with a schedule"""
|
||||
# Use test data from fixture
|
||||
user = setup_test_data["user"]
|
||||
store_submission = setup_test_data["store_submission"]
|
||||
|
||||
# Create the tool instance
|
||||
tool = SetupAgentTool()
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Execute with schedule setup
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
setup_type="schedule",
|
||||
name="Test Schedule",
|
||||
description="Test schedule description",
|
||||
cron="0 9 * * *", # Daily at 9am
|
||||
timezone="UTC",
|
||||
inputs={"test_input": "Hello World"},
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
|
||||
# Parse the result JSON
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
|
||||
# Check for execution started
|
||||
assert "message" in result_data
|
||||
assert "execution_id" in result_data
|
||||
assert "graph_id" in result_data
|
||||
assert "graph_name" in result_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@pytest.mark.skip(reason="Requires scheduler service to be running")
|
||||
async def test_setup_agent_with_credentials(setup_llm_test_data):
|
||||
"""Test setting up an agent that requires credentials"""
|
||||
# Use test data from fixture (includes OpenAI credentials)
|
||||
user = setup_llm_test_data["user"]
|
||||
store_submission = setup_llm_test_data["store_submission"]
|
||||
|
||||
# Create the tool instance
|
||||
tool = SetupAgentTool()
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Execute with schedule setup
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
setup_type="schedule",
|
||||
name="LLM Schedule",
|
||||
description="LLM schedule with credentials",
|
||||
cron="*/30 * * * *", # Every 30 minutes
|
||||
timezone="America/New_York",
|
||||
inputs={"user_prompt": "What is 2+2?"},
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
|
||||
# Parse the result JSON
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
|
||||
# Should succeed since user has OpenAI credentials
|
||||
assert "execution_id" in result_data
|
||||
assert "graph_id" in result_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_setup_agent_invalid_agent(setup_test_data):
|
||||
"""Test error when agent doesn't exist"""
|
||||
# Use test data from fixture
|
||||
user = setup_test_data["user"]
|
||||
|
||||
# Create the tool instance
|
||||
tool = SetupAgentTool()
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Execute with non-existent agent
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug="nonexistent/agent",
|
||||
setup_type="schedule",
|
||||
name="Test Schedule",
|
||||
cron="0 9 * * *",
|
||||
inputs={},
|
||||
)
|
||||
|
||||
# Verify error response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert "message" in result_data
|
||||
# Should fail to find the agent
|
||||
assert any(
|
||||
phrase in result_data["message"].lower()
|
||||
for phrase in ["not found", "failed", "error"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@pytest.mark.skip(reason="Requires scheduler service to be running")
|
||||
async def test_setup_agent_schedule_created_in_scheduler(setup_test_data):
|
||||
"""Test that the schedule is actually created in the scheduler service"""
|
||||
# Use test data from fixture
|
||||
user = setup_test_data["user"]
|
||||
graph = setup_test_data["graph"]
|
||||
store_submission = setup_test_data["store_submission"]
|
||||
|
||||
# Create the tool instance
|
||||
tool = SetupAgentTool()
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Create a unique schedule name to identify this test
|
||||
schedule_name = f"Test Schedule {uuid.uuid4()}"
|
||||
|
||||
# Execute with schedule setup
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
setup_type="schedule",
|
||||
name=schedule_name,
|
||||
description="Test schedule to verify credentials",
|
||||
cron="0 0 * * *", # Daily at midnight
|
||||
timezone="UTC",
|
||||
inputs={"test_input": "Scheduled execution"},
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert "execution_id" in result_data
|
||||
|
||||
# Now verify the schedule was created in the scheduler service
|
||||
scheduler = get_scheduler_client()
|
||||
schedules = await scheduler.get_execution_schedules(graph.id, user.id)
|
||||
|
||||
# Find our schedule
|
||||
our_schedule = None
|
||||
for schedule in schedules:
|
||||
if schedule.name == schedule_name:
|
||||
our_schedule = schedule
|
||||
break
|
||||
|
||||
assert (
|
||||
our_schedule is not None
|
||||
), f"Schedule '{schedule_name}' not found in scheduler"
|
||||
assert our_schedule.cron == "0 0 * * *"
|
||||
assert our_schedule.graph_id == graph.id
|
||||
|
||||
# Clean up: delete the schedule
|
||||
await scheduler.delete_schedule(our_schedule.id, user_id=user.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@pytest.mark.skip(reason="Requires scheduler service to be running")
|
||||
async def test_setup_agent_schedule_with_credentials_triggered(setup_llm_test_data):
|
||||
"""Test that credentials are properly passed when a schedule is triggered"""
|
||||
# Use test data from fixture (includes OpenAI credentials)
|
||||
user = setup_llm_test_data["user"]
|
||||
graph = setup_llm_test_data["graph"]
|
||||
store_submission = setup_llm_test_data["store_submission"]
|
||||
|
||||
# Create the tool instance
|
||||
tool = SetupAgentTool()
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Create a unique schedule name
|
||||
schedule_name = f"LLM Test Schedule {uuid.uuid4()}"
|
||||
|
||||
# Execute with schedule setup
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
setup_type="schedule",
|
||||
name=schedule_name,
|
||||
description="Test LLM schedule with credentials",
|
||||
cron="* * * * *", # Every minute (for testing)
|
||||
timezone="UTC",
|
||||
inputs={"user_prompt": "Test prompt for credentials"},
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert "execution_id" in result_data
|
||||
|
||||
# Get the schedule from the scheduler
|
||||
scheduler = get_scheduler_client()
|
||||
schedules = await scheduler.get_execution_schedules(graph.id, user.id)
|
||||
|
||||
# Find our schedule
|
||||
our_schedule = None
|
||||
for schedule in schedules:
|
||||
if schedule.name == schedule_name:
|
||||
our_schedule = schedule
|
||||
break
|
||||
|
||||
assert our_schedule is not None, f"Schedule '{schedule_name}' not found"
|
||||
|
||||
# Verify the schedule has the correct input data
|
||||
assert our_schedule.input_data is not None
|
||||
assert "user_prompt" in our_schedule.input_data
|
||||
assert our_schedule.input_data["user_prompt"] == "Test prompt for credentials"
|
||||
|
||||
# Verify credentials are stored in the schedule
|
||||
# The credentials should be stored as input_credentials
|
||||
assert our_schedule.input_credentials is not None
|
||||
|
||||
# The credentials should contain the OpenAI provider credential
|
||||
# Note: The exact structure depends on how credentials are serialized
|
||||
# We're checking that credentials data exists and has the right provider
|
||||
if our_schedule.input_credentials:
|
||||
# Convert to dict if needed
|
||||
creds_dict = (
|
||||
our_schedule.input_credentials
|
||||
if isinstance(our_schedule.input_credentials, dict)
|
||||
else {}
|
||||
)
|
||||
|
||||
# Check if any credential has openai provider
|
||||
has_openai_cred = False
|
||||
for cred_key, cred_value in creds_dict.items():
|
||||
if isinstance(cred_value, dict):
|
||||
if cred_value.get("provider") == "openai":
|
||||
has_openai_cred = True
|
||||
# Verify the credential has the expected structure
|
||||
assert "id" in cred_value or "api_key" in cred_value
|
||||
break
|
||||
|
||||
# If we have LLM block, we should have stored credentials
|
||||
assert has_openai_cred, "OpenAI credentials not found in schedule"
|
||||
|
||||
# Clean up: delete the schedule
|
||||
await scheduler.delete_schedule(our_schedule.id, user_id=user.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@pytest.mark.skip(reason="Requires scheduler service to be running")
|
||||
async def test_setup_agent_creates_library_agent(setup_test_data):
|
||||
"""Test that setup creates a library agent for the user"""
|
||||
# Use test data from fixture
|
||||
user = setup_test_data["user"]
|
||||
graph = setup_test_data["graph"]
|
||||
store_submission = setup_test_data["store_submission"]
|
||||
|
||||
# Create the tool instance
|
||||
tool = SetupAgentTool()
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Execute with schedule setup
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
setup_type="schedule",
|
||||
name="Library Test Schedule",
|
||||
cron="0 12 * * *", # Daily at noon
|
||||
inputs={"test_input": "Library test"},
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert "graph_id" in result_data
|
||||
assert result_data["graph_id"] == graph.id
|
||||
|
||||
# Verify library agent was created
|
||||
from backend.server.v2.library import db as library_db
|
||||
|
||||
library_agent = await library_db.get_library_agent_by_graph_id(
|
||||
graph_id=graph.id, user_id=user.id
|
||||
)
|
||||
assert library_agent is not None
|
||||
assert library_agent.graph_id == graph.id
|
||||
assert library_agent.name == "Test Agent"
|
||||
@@ -10,6 +10,7 @@ from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from backend.data.execution import AsyncRedisExecutionEventBus
|
||||
from backend.data.notification_bus import AsyncRedisNotificationEventBus
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.monitoring.instrumentation import (
|
||||
instrument_fastapi,
|
||||
@@ -22,6 +23,7 @@ from backend.server.model import (
|
||||
WSSubscribeGraphExecutionRequest,
|
||||
WSSubscribeGraphExecutionsRequest,
|
||||
)
|
||||
from backend.server.utils.cors import build_cors_params
|
||||
from backend.util.retry import continuous_retry
|
||||
from backend.util.service import AppProcess
|
||||
from backend.util.settings import AppEnvironment, Config, Settings
|
||||
@@ -61,9 +63,21 @@ def get_connection_manager():
|
||||
|
||||
@continuous_retry()
|
||||
async def event_broadcaster(manager: ConnectionManager):
|
||||
event_queue = AsyncRedisExecutionEventBus()
|
||||
async for event in event_queue.listen("*"):
|
||||
await manager.send_execution_update(event)
|
||||
execution_bus = AsyncRedisExecutionEventBus()
|
||||
notification_bus = AsyncRedisNotificationEventBus()
|
||||
|
||||
async def execution_worker():
|
||||
async for event in execution_bus.listen("*"):
|
||||
await manager.send_execution_update(event)
|
||||
|
||||
async def notification_worker():
|
||||
async for notification in notification_bus.listen("*"):
|
||||
await manager.send_notification(
|
||||
user_id=notification.user_id,
|
||||
payload=notification.payload,
|
||||
)
|
||||
|
||||
await asyncio.gather(execution_worker(), notification_worker())
|
||||
|
||||
|
||||
async def authenticate_websocket(websocket: WebSocket) -> str:
|
||||
@@ -228,7 +242,7 @@ async def websocket_router(
|
||||
user_id = await authenticate_websocket(websocket)
|
||||
if not user_id:
|
||||
return
|
||||
await manager.connect_socket(websocket)
|
||||
await manager.connect_socket(websocket, user_id=user_id)
|
||||
|
||||
# Track WebSocket connection
|
||||
update_websocket_connections(user_id, 1)
|
||||
@@ -301,7 +315,7 @@ async def websocket_router(
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect_socket(websocket)
|
||||
manager.disconnect_socket(websocket, user_id=user_id)
|
||||
logger.debug("WebSocket client disconnected")
|
||||
finally:
|
||||
update_websocket_connections(user_id, -1)
|
||||
@@ -315,9 +329,13 @@ async def health():
|
||||
class WebsocketServer(AppProcess):
|
||||
def run(self):
|
||||
logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}")
|
||||
cors_params = build_cors_params(
|
||||
settings.config.backend_cors_allow_origins,
|
||||
settings.config.app_env,
|
||||
)
|
||||
server_app = CORSMiddleware(
|
||||
app=app,
|
||||
allow_origins=settings.config.backend_cors_allow_origins,
|
||||
**cors_params,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
||||
@@ -8,11 +8,13 @@ from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.test_helpers import override_config
|
||||
from backend.server.ws_api import AppEnvironment, WebsocketServer, WSMessage, WSMethod
|
||||
from backend.server.ws_api import app as websocket_app
|
||||
from backend.server.ws_api import (
|
||||
WSMessage,
|
||||
WSMethod,
|
||||
handle_subscribe,
|
||||
handle_unsubscribe,
|
||||
settings,
|
||||
websocket_router,
|
||||
)
|
||||
|
||||
@@ -29,6 +31,47 @@ def mock_manager() -> AsyncMock:
|
||||
return AsyncMock(spec=ConnectionManager)
|
||||
|
||||
|
||||
def test_websocket_server_uses_cors_helper(mocker) -> None:
|
||||
cors_params = {
|
||||
"allow_origins": ["https://app.example.com"],
|
||||
"allow_origin_regex": None,
|
||||
}
|
||||
mocker.patch("backend.server.ws_api.uvicorn.run")
|
||||
cors_middleware = mocker.patch(
|
||||
"backend.server.ws_api.CORSMiddleware", return_value=object()
|
||||
)
|
||||
build_cors = mocker.patch(
|
||||
"backend.server.ws_api.build_cors_params", return_value=cors_params
|
||||
)
|
||||
|
||||
with override_config(
|
||||
settings, "backend_cors_allow_origins", cors_params["allow_origins"]
|
||||
), override_config(settings, "app_env", AppEnvironment.LOCAL):
|
||||
WebsocketServer().run()
|
||||
|
||||
build_cors.assert_called_once_with(
|
||||
cors_params["allow_origins"], AppEnvironment.LOCAL
|
||||
)
|
||||
cors_middleware.assert_called_once_with(
|
||||
app=websocket_app,
|
||||
allow_origins=cors_params["allow_origins"],
|
||||
allow_origin_regex=cors_params["allow_origin_regex"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
def test_websocket_server_blocks_localhost_in_production(mocker) -> None:
|
||||
mocker.patch("backend.server.ws_api.uvicorn.run")
|
||||
|
||||
with override_config(
|
||||
settings, "backend_cors_allow_origins", ["http://localhost:3000"]
|
||||
), override_config(settings, "app_env", AppEnvironment.PRODUCTION):
|
||||
with pytest.raises(ValueError):
|
||||
WebsocketServer().run()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_router_subscribe(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker
|
||||
@@ -53,7 +96,9 @@ async def test_websocket_router_subscribe(
|
||||
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
|
||||
)
|
||||
|
||||
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
|
||||
mock_manager.connect_socket.assert_called_once_with(
|
||||
mock_websocket, user_id=DEFAULT_USER_ID
|
||||
)
|
||||
mock_manager.subscribe_graph_exec.assert_called_once_with(
|
||||
user_id=DEFAULT_USER_ID,
|
||||
graph_exec_id="test-graph-exec-1",
|
||||
@@ -72,7 +117,9 @@ async def test_websocket_router_subscribe(
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
snapshot.assert_match(json.dumps(parsed_message, indent=2, sort_keys=True), "sub")
|
||||
|
||||
mock_manager.disconnect_socket.assert_called_once_with(mock_websocket)
|
||||
mock_manager.disconnect_socket.assert_called_once_with(
|
||||
mock_websocket, user_id=DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -99,7 +146,9 @@ async def test_websocket_router_unsubscribe(
|
||||
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
|
||||
)
|
||||
|
||||
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
|
||||
mock_manager.connect_socket.assert_called_once_with(
|
||||
mock_websocket, user_id=DEFAULT_USER_ID
|
||||
)
|
||||
mock_manager.unsubscribe_graph_exec.assert_called_once_with(
|
||||
user_id=DEFAULT_USER_ID,
|
||||
graph_exec_id="test-graph-exec-1",
|
||||
@@ -115,7 +164,9 @@ async def test_websocket_router_unsubscribe(
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
snapshot.assert_match(json.dumps(parsed_message, indent=2, sort_keys=True), "unsub")
|
||||
|
||||
mock_manager.disconnect_socket.assert_called_once_with(mock_websocket)
|
||||
mock_manager.disconnect_socket.assert_called_once_with(
|
||||
mock_websocket, user_id=DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -136,11 +187,15 @@ async def test_websocket_router_invalid_method(
|
||||
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
|
||||
)
|
||||
|
||||
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
|
||||
mock_manager.connect_socket.assert_called_once_with(
|
||||
mock_websocket, user_id=DEFAULT_USER_ID
|
||||
)
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
|
||||
assert '"success":false' in mock_websocket.send_text.call_args[0][0]
|
||||
mock_manager.disconnect_socket.assert_called_once_with(mock_websocket)
|
||||
mock_manager.disconnect_socket.assert_called_once_with(
|
||||
mock_websocket, user_id=DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -1,6 +1,35 @@
|
||||
from typing import Mapping
|
||||
|
||||
|
||||
class BlockError(Exception):
|
||||
"""An error occurred during the running of a block"""
|
||||
|
||||
def __init__(self, message: str, block_name: str, block_id: str) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.block_name = block_name
|
||||
self.block_id = block_id
|
||||
|
||||
def __str__(self):
|
||||
return f"raised by {self.block_name} with message: {self.message}. block_id: {self.block_id}"
|
||||
|
||||
|
||||
class BlockInputError(BlockError, ValueError):
|
||||
"""The block had incorrect inputs, resulting in an error condition"""
|
||||
|
||||
|
||||
class BlockOutputError(BlockError, ValueError):
|
||||
"""The block had incorrect outputs, resulting in an error condition"""
|
||||
|
||||
|
||||
class BlockExecutionError(BlockError, ValueError):
|
||||
"""The block failed to execute at runtime, resulting in a handled error"""
|
||||
|
||||
|
||||
class BlockUnknownError(BlockError):
|
||||
"""Critical unknown error with block handling"""
|
||||
|
||||
|
||||
class MissingConfigError(Exception):
|
||||
"""The attempted operation requires configuration which is not available"""
|
||||
|
||||
@@ -17,10 +46,12 @@ class NotAuthorizedError(ValueError):
|
||||
"""The user is not authorized to perform the requested operation"""
|
||||
|
||||
|
||||
class GraphNotInLibraryError(NotAuthorizedError):
|
||||
"""Raised when attempting to execute a graph that is not in the user's library (deleted/archived)."""
|
||||
class GraphNotAccessibleError(NotAuthorizedError):
|
||||
"""Raised when attempting to execute a graph that is not accessible to the user."""
|
||||
|
||||
pass
|
||||
|
||||
class GraphNotInLibraryError(GraphNotAccessibleError):
|
||||
"""Raised when attempting to execute a graph that is not / no longer in the user's library."""
|
||||
|
||||
|
||||
class InsufficientBalanceError(ValueError):
|
||||
@@ -98,3 +129,9 @@ class DatabaseError(Exception):
|
||||
"""Raised when there is an error interacting with the database"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RedisError(Exception):
|
||||
"""Raised when there is an error interacting with Redis"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -5,7 +5,8 @@ from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
|
||||
import ldclient
|
||||
from fastapi import HTTPException
|
||||
from autogpt_libs.auth.dependencies import get_optional_user_id
|
||||
from fastapi import HTTPException, Security
|
||||
from ldclient import Context, LDClient
|
||||
from ldclient.config import Config
|
||||
from typing_extensions import ParamSpec
|
||||
@@ -36,6 +37,7 @@ class Flag(str, Enum):
|
||||
BETA_BLOCKS = "beta-blocks"
|
||||
AGENT_ACTIVITY = "agent-activity"
|
||||
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
|
||||
CHAT = "chat"
|
||||
|
||||
|
||||
def is_configured() -> bool:
|
||||
@@ -252,6 +254,72 @@ def feature_flag(
|
||||
return decorator
|
||||
|
||||
|
||||
def create_feature_flag_dependency(
|
||||
flag_key: Flag,
|
||||
default: bool = False,
|
||||
) -> Callable[[str | None], Awaitable[None]]:
|
||||
"""
|
||||
Create a FastAPI dependency that checks a feature flag.
|
||||
|
||||
This dependency automatically extracts the user_id from the JWT token
|
||||
(if present) for proper LaunchDarkly user targeting, while still
|
||||
supporting anonymous access.
|
||||
|
||||
Args:
|
||||
flag_key: The Flag enum value to check
|
||||
default: Default value if flag evaluation fails
|
||||
|
||||
Returns:
|
||||
An async dependency function that raises HTTPException if flag is disabled
|
||||
|
||||
Example:
|
||||
router = APIRouter(
|
||||
dependencies=[Depends(create_feature_flag_dependency(Flag.CHAT))]
|
||||
)
|
||||
"""
|
||||
|
||||
async def check_feature_flag(
|
||||
user_id: str | None = Security(get_optional_user_id),
|
||||
) -> None:
|
||||
"""Check if feature flag is enabled for the user.
|
||||
|
||||
The user_id is automatically injected from JWT authentication if present,
|
||||
or None for anonymous access.
|
||||
"""
|
||||
# For routes that don't require authentication, use anonymous context
|
||||
check_user_id = user_id or "anonymous"
|
||||
|
||||
if not is_configured():
|
||||
logger.debug(
|
||||
f"LaunchDarkly not configured, using default {flag_key.value}={default}"
|
||||
)
|
||||
if not default:
|
||||
raise HTTPException(status_code=404, detail="Feature not available")
|
||||
return
|
||||
|
||||
try:
|
||||
client = get_client()
|
||||
if not client.is_initialized():
|
||||
logger.debug(
|
||||
f"LaunchDarkly not initialized, using default {flag_key.value}={default}"
|
||||
)
|
||||
if not default:
|
||||
raise HTTPException(status_code=404, detail="Feature not available")
|
||||
return
|
||||
|
||||
is_enabled = await is_feature_enabled(flag_key, check_user_id, default)
|
||||
|
||||
if not is_enabled:
|
||||
raise HTTPException(status_code=404, detail="Feature not available")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"LaunchDarkly error for flag {flag_key.value}: {e}, using default={default}"
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Failed to check feature flag")
|
||||
|
||||
return check_feature_flag
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_flag_variation(flag_key: str, return_value: Any):
|
||||
"""Context manager for testing feature flags."""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Generic, List, Set, Tuple, Type, TypeVar
|
||||
|
||||
@@ -71,6 +72,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Maximum number of workers to use for graph execution.",
|
||||
)
|
||||
|
||||
requeue_by_republishing: bool = Field(
|
||||
default=True,
|
||||
description="Send rate-limited messages to back of queue by republishing instead of front requeue to prevent blocking other users.",
|
||||
)
|
||||
|
||||
# FastAPI Thread Pool Configuration
|
||||
# IMPORTANT: FastAPI automatically offloads ALL sync functions to a thread pool:
|
||||
# - Sync endpoint functions (def instead of async def)
|
||||
@@ -412,6 +418,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Name of the event bus",
|
||||
)
|
||||
|
||||
notification_event_bus_name: str = Field(
|
||||
default="notification_event",
|
||||
description="Name of the websocket notification event bus",
|
||||
)
|
||||
|
||||
trust_endpoints_for_requests: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="A whitelist of trusted internal endpoints for the backend to make requests to.",
|
||||
@@ -422,34 +433,62 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Maximum message size limit for communication with the message bus",
|
||||
)
|
||||
|
||||
backend_cors_allow_origins: List[str] = Field(default=["http://localhost:3000"])
|
||||
backend_cors_allow_origins: List[str] = Field(
|
||||
default=["http://localhost:3000"],
|
||||
description="Allowed Origins for CORS. Supports exact URLs (http/https) or entries prefixed with "
|
||||
'"regex:" to match via regular expression.',
|
||||
)
|
||||
|
||||
@field_validator("backend_cors_allow_origins")
|
||||
@classmethod
|
||||
def validate_cors_allow_origins(cls, v: List[str]) -> List[str]:
|
||||
out = []
|
||||
port = None
|
||||
has_localhost = False
|
||||
has_127_0_0_1 = False
|
||||
for url in v:
|
||||
url = url.strip()
|
||||
if url.startswith(("http://", "https://")):
|
||||
if "localhost" in url:
|
||||
port = url.split(":")[2]
|
||||
has_localhost = True
|
||||
if "127.0.0.1" in url:
|
||||
port = url.split(":")[2]
|
||||
has_127_0_0_1 = True
|
||||
out.append(url)
|
||||
else:
|
||||
raise ValueError(f"Invalid URL: {url}")
|
||||
validated: List[str] = []
|
||||
localhost_ports: set[str] = set()
|
||||
ip127_ports: set[str] = set()
|
||||
|
||||
if has_127_0_0_1 and not has_localhost:
|
||||
out.append(f"http://localhost:{port}")
|
||||
if has_localhost and not has_127_0_0_1:
|
||||
out.append(f"http://127.0.0.1:{port}")
|
||||
for raw_origin in v:
|
||||
origin = raw_origin.strip()
|
||||
if origin.startswith("regex:"):
|
||||
pattern = origin[len("regex:") :]
|
||||
if not pattern:
|
||||
raise ValueError("Invalid regex pattern: pattern cannot be empty")
|
||||
try:
|
||||
re.compile(pattern)
|
||||
except re.error as exc:
|
||||
raise ValueError(
|
||||
f"Invalid regex pattern '{pattern}': {exc}"
|
||||
) from exc
|
||||
validated.append(origin)
|
||||
continue
|
||||
|
||||
return out
|
||||
if origin.startswith(("http://", "https://")):
|
||||
if "localhost" in origin:
|
||||
try:
|
||||
port = origin.split(":")[2]
|
||||
localhost_ports.add(port)
|
||||
except IndexError as exc:
|
||||
raise ValueError(
|
||||
"localhost origins must include an explicit port, e.g. http://localhost:3000"
|
||||
) from exc
|
||||
if "127.0.0.1" in origin:
|
||||
try:
|
||||
port = origin.split(":")[2]
|
||||
ip127_ports.add(port)
|
||||
except IndexError as exc:
|
||||
raise ValueError(
|
||||
"127.0.0.1 origins must include an explicit port, e.g. http://127.0.0.1:3000"
|
||||
) from exc
|
||||
validated.append(origin)
|
||||
continue
|
||||
|
||||
raise ValueError(f"Invalid URL or regex origin: {origin}")
|
||||
|
||||
for port in ip127_ports - localhost_ports:
|
||||
validated.append(f"http://localhost:{port}")
|
||||
for port in localhost_ports - ip127_ports:
|
||||
validated.append(f"http://127.0.0.1:{port}")
|
||||
|
||||
return validated
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
-- Migrate deprecated Groq and OpenRouter models to their replacements
|
||||
-- This updates all AgentNode blocks that use deprecated models that have been decommissioned
|
||||
-- Deprecated models:
|
||||
-- - deepseek-r1-distill-llama-70b (Groq - decommissioned)
|
||||
-- - gemma2-9b-it (Groq - decommissioned)
|
||||
-- - llama3-70b-8192 (Groq - decommissioned)
|
||||
-- - llama3-8b-8192 (Groq - decommissioned)
|
||||
-- - google/gemini-flash-1.5 (OpenRouter - no endpoints found)
|
||||
|
||||
-- Update llama3-70b-8192 to llama-3.3-70b-versatile
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
'"llama-3.3-70b-versatile"'::jsonb
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = 'llama3-70b-8192';
|
||||
|
||||
-- Update llama3-8b-8192 to llama-3.1-8b-instant
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
'"llama-3.1-8b-instant"'::jsonb
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = 'llama3-8b-8192';
|
||||
|
||||
-- Update google/gemini-flash-1.5 to google/gemini-2.5-flash
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
'"google/gemini-2.5-flash"'::jsonb
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = 'google/gemini-flash-1.5';
|
||||
|
||||
-- Update deepseek-r1-distill-llama-70b to gpt-5-chat-latest (no direct replacement)
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
'"gpt-5-chat-latest"'::jsonb
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = 'deepseek-r1-distill-llama-70b';
|
||||
|
||||
-- Update gemma2-9b-it to gpt-5-chat-latest (no direct replacement)
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
'"gpt-5-chat-latest"'::jsonb
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = 'gemma2-9b-it';
|
||||
@@ -402,7 +402,9 @@ class TestDataCreator:
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
graph = await get_graph(
|
||||
graph_data["id"], graph_data.get("version", 1), user["id"]
|
||||
graph_data["id"],
|
||||
graph_data.get("version", 1),
|
||||
user_id=user["id"],
|
||||
)
|
||||
if graph:
|
||||
# Use the API function to create library agent
|
||||
|
||||
350
autogpt_platform/backend/test_requeue_integration.py
Normal file
350
autogpt_platform/backend/test_requeue_integration.py
Normal file
@@ -0,0 +1,350 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration test for the requeue fix implementation.
|
||||
Tests actual RabbitMQ behavior to verify that republishing sends messages to back of queue.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from threading import Event
|
||||
from typing import List
|
||||
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.utils import create_execution_queue_config
|
||||
|
||||
|
||||
class QueueOrderTester:
|
||||
"""Helper class to test message ordering in RabbitMQ using a dedicated test queue."""
|
||||
|
||||
def __init__(self):
|
||||
self.received_messages: List[dict] = []
|
||||
self.stop_consuming = Event()
|
||||
self.queue_client = SyncRabbitMQ(create_execution_queue_config())
|
||||
self.queue_client.connect()
|
||||
|
||||
# Use a dedicated test queue name to avoid conflicts
|
||||
self.test_queue_name = "test_requeue_ordering"
|
||||
self.test_exchange = "test_exchange"
|
||||
self.test_routing_key = "test.requeue"
|
||||
|
||||
def setup_queue(self):
|
||||
"""Set up a dedicated test queue for testing."""
|
||||
channel = self.queue_client.get_channel()
|
||||
|
||||
# Declare test exchange
|
||||
channel.exchange_declare(
|
||||
exchange=self.test_exchange, exchange_type="direct", durable=True
|
||||
)
|
||||
|
||||
# Declare test queue
|
||||
channel.queue_declare(
|
||||
queue=self.test_queue_name, durable=True, auto_delete=False
|
||||
)
|
||||
|
||||
# Bind queue to exchange
|
||||
channel.queue_bind(
|
||||
exchange=self.test_exchange,
|
||||
queue=self.test_queue_name,
|
||||
routing_key=self.test_routing_key,
|
||||
)
|
||||
|
||||
# Purge the queue to start fresh
|
||||
channel.queue_purge(self.test_queue_name)
|
||||
print(f"✅ Test queue {self.test_queue_name} setup and purged")
|
||||
|
||||
def create_test_message(self, message_id: str, user_id: str = "test-user") -> str:
|
||||
"""Create a test graph execution message."""
|
||||
return json.dumps(
|
||||
{
|
||||
"graph_exec_id": f"exec-{message_id}",
|
||||
"graph_id": f"graph-{message_id}",
|
||||
"user_id": user_id,
|
||||
"user_context": {"timezone": "UTC"},
|
||||
"nodes_input_masks": {},
|
||||
"starting_nodes_input": [],
|
||||
"parent_graph_exec_id": None,
|
||||
}
|
||||
)
|
||||
|
||||
def publish_message(self, message: str):
|
||||
"""Publish a message to the test queue."""
|
||||
channel = self.queue_client.get_channel()
|
||||
channel.basic_publish(
|
||||
exchange=self.test_exchange,
|
||||
routing_key=self.test_routing_key,
|
||||
body=message,
|
||||
)
|
||||
|
||||
def consume_messages(self, max_messages: int = 10, timeout: float = 5.0):
|
||||
"""Consume messages and track their order."""
|
||||
|
||||
def callback(ch, method, properties, body):
|
||||
try:
|
||||
message_data = json.loads(body.decode())
|
||||
self.received_messages.append(message_data)
|
||||
ch.basic_ack(delivery_tag=method.delivery_tag)
|
||||
|
||||
if len(self.received_messages) >= max_messages:
|
||||
self.stop_consuming.set()
|
||||
except Exception as e:
|
||||
print(f"Error processing message: {e}")
|
||||
ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
|
||||
|
||||
# Use synchronous consumption with blocking
|
||||
channel = self.queue_client.get_channel()
|
||||
|
||||
# Check if there are messages in the queue first
|
||||
method_frame, header_frame, body = channel.basic_get(
|
||||
queue=self.test_queue_name, auto_ack=False
|
||||
)
|
||||
if method_frame:
|
||||
# There are messages, set up consumer
|
||||
channel.basic_nack(
|
||||
delivery_tag=method_frame.delivery_tag, requeue=True
|
||||
) # Put message back
|
||||
|
||||
# Set up consumer
|
||||
channel.basic_consume(
|
||||
queue=self.test_queue_name,
|
||||
on_message_callback=callback,
|
||||
)
|
||||
|
||||
# Consume with timeout
|
||||
start_time = time.time()
|
||||
while (
|
||||
not self.stop_consuming.is_set()
|
||||
and (time.time() - start_time) < timeout
|
||||
and len(self.received_messages) < max_messages
|
||||
):
|
||||
try:
|
||||
channel.connection.process_data_events(time_limit=0.1)
|
||||
except Exception as e:
|
||||
print(f"Error during consumption: {e}")
|
||||
break
|
||||
|
||||
# Cancel the consumer
|
||||
try:
|
||||
channel.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# No messages in queue - this might be expected for some tests
|
||||
pass
|
||||
|
||||
return self.received_messages
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up test resources."""
|
||||
try:
|
||||
channel = self.queue_client.get_channel()
|
||||
channel.queue_delete(queue=self.test_queue_name)
|
||||
channel.exchange_delete(exchange=self.test_exchange)
|
||||
print(f"✅ Test queue {self.test_queue_name} cleaned up")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Cleanup issue: {e}")
|
||||
|
||||
|
||||
def test_queue_ordering_behavior():
|
||||
"""
|
||||
Integration test to verify that our republishing method sends messages to back of queue.
|
||||
This tests the actual fix for the rate limiting queue blocking issue.
|
||||
"""
|
||||
tester = QueueOrderTester()
|
||||
|
||||
try:
|
||||
tester.setup_queue()
|
||||
|
||||
print("🧪 Testing actual RabbitMQ queue ordering behavior...")
|
||||
|
||||
# Test 1: Normal FIFO behavior
|
||||
print("1. Testing normal FIFO queue behavior")
|
||||
|
||||
# Publish messages in order: A, B, C
|
||||
msg_a = tester.create_test_message("A")
|
||||
msg_b = tester.create_test_message("B")
|
||||
msg_c = tester.create_test_message("C")
|
||||
|
||||
tester.publish_message(msg_a)
|
||||
tester.publish_message(msg_b)
|
||||
tester.publish_message(msg_c)
|
||||
|
||||
# Consume and verify FIFO order: A, B, C
|
||||
tester.received_messages = []
|
||||
tester.stop_consuming.clear()
|
||||
messages = tester.consume_messages(max_messages=3)
|
||||
|
||||
assert len(messages) == 3, f"Expected 3 messages, got {len(messages)}"
|
||||
assert (
|
||||
messages[0]["graph_exec_id"] == "exec-A"
|
||||
), f"First message should be A, got {messages[0]['graph_exec_id']}"
|
||||
assert (
|
||||
messages[1]["graph_exec_id"] == "exec-B"
|
||||
), f"Second message should be B, got {messages[1]['graph_exec_id']}"
|
||||
assert (
|
||||
messages[2]["graph_exec_id"] == "exec-C"
|
||||
), f"Third message should be C, got {messages[2]['graph_exec_id']}"
|
||||
|
||||
print("✅ FIFO order confirmed: A -> B -> C")
|
||||
|
||||
# Test 2: Rate limiting simulation - the key test!
|
||||
print("2. Testing rate limiting fix scenario")
|
||||
|
||||
# Simulate the scenario where user1 is rate limited
|
||||
user1_msg = tester.create_test_message("RATE-LIMITED", "user1")
|
||||
user2_msg1 = tester.create_test_message("USER2-1", "user2")
|
||||
user2_msg2 = tester.create_test_message("USER2-2", "user2")
|
||||
|
||||
# Initially publish user1 message (gets consumed, then rate limited on retry)
|
||||
tester.publish_message(user1_msg)
|
||||
|
||||
# Other users publish their messages
|
||||
tester.publish_message(user2_msg1)
|
||||
tester.publish_message(user2_msg2)
|
||||
|
||||
# Now simulate: user1 message gets "requeued" using our new republishing method
|
||||
# This is what happens in manager.py when requeue_by_republishing=True
|
||||
tester.publish_message(user1_msg) # Goes to back via our method
|
||||
|
||||
# Expected order: RATE-LIMITED, USER2-1, USER2-2, RATE-LIMITED (republished to back)
|
||||
# This shows that user2 messages get processed instead of being blocked
|
||||
tester.received_messages = []
|
||||
tester.stop_consuming.clear()
|
||||
messages = tester.consume_messages(max_messages=4)
|
||||
|
||||
assert len(messages) == 4, f"Expected 4 messages, got {len(messages)}"
|
||||
|
||||
# The key verification: user2 messages are NOT blocked by user1's rate-limited message
|
||||
user2_messages = [msg for msg in messages if msg["user_id"] == "user2"]
|
||||
assert len(user2_messages) == 2, "Both user2 messages should be processed"
|
||||
assert user2_messages[0]["graph_exec_id"] == "exec-USER2-1"
|
||||
assert user2_messages[1]["graph_exec_id"] == "exec-USER2-2"
|
||||
|
||||
print("✅ Rate limiting fix confirmed: user2 executions NOT blocked by user1")
|
||||
|
||||
# Test 3: Verify our method behaves like going to back of queue
|
||||
print("3. Testing republishing sends messages to back")
|
||||
|
||||
# Start with message X in queue
|
||||
msg_x = tester.create_test_message("X")
|
||||
tester.publish_message(msg_x)
|
||||
|
||||
# Add message Y
|
||||
msg_y = tester.create_test_message("Y")
|
||||
tester.publish_message(msg_y)
|
||||
|
||||
# Republish X (simulates requeue using our method)
|
||||
tester.publish_message(msg_x)
|
||||
|
||||
# Expected: X, Y, X (X was republished to back)
|
||||
tester.received_messages = []
|
||||
tester.stop_consuming.clear()
|
||||
messages = tester.consume_messages(max_messages=3)
|
||||
|
||||
assert len(messages) == 3
|
||||
# Y should come before the republished X
|
||||
y_index = next(
|
||||
i for i, msg in enumerate(messages) if msg["graph_exec_id"] == "exec-Y"
|
||||
)
|
||||
republished_x_index = next(
|
||||
i
|
||||
for i, msg in enumerate(messages[1:], 1)
|
||||
if msg["graph_exec_id"] == "exec-X"
|
||||
)
|
||||
|
||||
assert (
|
||||
y_index < republished_x_index
|
||||
), f"Y should come before republished X, but got order: {[m['graph_exec_id'] for m in messages]}"
|
||||
|
||||
print("✅ Republishing confirmed: messages go to back of queue")
|
||||
|
||||
print("🎉 All integration tests passed!")
|
||||
print("🎉 Our republishing method works correctly with real RabbitMQ")
|
||||
print("🎉 Queue blocking issue is fixed!")
|
||||
|
||||
finally:
|
||||
tester.cleanup()
|
||||
|
||||
|
||||
def test_traditional_requeue_behavior():
|
||||
"""
|
||||
Test that traditional requeue (basic_nack with requeue=True) sends messages to FRONT of queue.
|
||||
This validates our hypothesis about why queue blocking occurs.
|
||||
"""
|
||||
tester = QueueOrderTester()
|
||||
|
||||
try:
|
||||
tester.setup_queue()
|
||||
print("🧪 Testing traditional requeue behavior (basic_nack with requeue=True)")
|
||||
|
||||
# Step 1: Publish message A
|
||||
msg_a = tester.create_test_message("A")
|
||||
tester.publish_message(msg_a)
|
||||
|
||||
# Step 2: Publish message B
|
||||
msg_b = tester.create_test_message("B")
|
||||
tester.publish_message(msg_b)
|
||||
|
||||
# Step 3: Consume message A and requeue it using traditional method
|
||||
channel = tester.queue_client.get_channel()
|
||||
method_frame, header_frame, body = channel.basic_get(
|
||||
queue=tester.test_queue_name, auto_ack=False
|
||||
)
|
||||
|
||||
assert method_frame is not None, "Should have received message A"
|
||||
consumed_msg = json.loads(body.decode())
|
||||
assert (
|
||||
consumed_msg["graph_exec_id"] == "exec-A"
|
||||
), f"Should have consumed message A, got {consumed_msg['graph_exec_id']}"
|
||||
|
||||
# Traditional requeue: basic_nack with requeue=True (sends to FRONT)
|
||||
channel.basic_nack(delivery_tag=method_frame.delivery_tag, requeue=True)
|
||||
print(f"🔄 Traditional requeue (to FRONT): {consumed_msg['graph_exec_id']}")
|
||||
|
||||
# Step 4: Consume all messages using basic_get for reliability
|
||||
received_messages = []
|
||||
|
||||
# Get first message
|
||||
method_frame, header_frame, body = channel.basic_get(
|
||||
queue=tester.test_queue_name, auto_ack=True
|
||||
)
|
||||
if method_frame:
|
||||
msg = json.loads(body.decode())
|
||||
received_messages.append(msg)
|
||||
|
||||
# Get second message
|
||||
method_frame, header_frame, body = channel.basic_get(
|
||||
queue=tester.test_queue_name, auto_ack=True
|
||||
)
|
||||
if method_frame:
|
||||
msg = json.loads(body.decode())
|
||||
received_messages.append(msg)
|
||||
|
||||
# CRITICAL ASSERTION: Traditional requeue should put A at FRONT
|
||||
# Expected order: A (requeued to front), B
|
||||
assert (
|
||||
len(received_messages) == 2
|
||||
), f"Expected 2 messages, got {len(received_messages)}"
|
||||
|
||||
first_msg = received_messages[0]["graph_exec_id"]
|
||||
second_msg = received_messages[1]["graph_exec_id"]
|
||||
|
||||
# This is the critical test: requeued message A should come BEFORE B
|
||||
assert (
|
||||
first_msg == "exec-A"
|
||||
), f"Traditional requeue should put A at FRONT, but first message was: {first_msg}"
|
||||
assert (
|
||||
second_msg == "exec-B"
|
||||
), f"B should come after requeued A, but second message was: {second_msg}"
|
||||
|
||||
print(
|
||||
"✅ HYPOTHESIS CONFIRMED: Traditional requeue sends messages to FRONT of queue"
|
||||
)
|
||||
print(f" Order: {first_msg} (requeued to front) → {second_msg}")
|
||||
print(" This explains why rate-limited messages block other users!")
|
||||
|
||||
finally:
|
||||
tester.cleanup()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_queue_ordering_behavior()
|
||||
@@ -1,2 +1,3 @@
|
||||
# Configure pnpm to save exact versions
|
||||
save-exact=true
|
||||
save-exact=true
|
||||
engine-strict=true
|
||||
@@ -2,7 +2,6 @@ node_modules
|
||||
pnpm-lock.yaml
|
||||
.next
|
||||
.auth
|
||||
build
|
||||
public
|
||||
Dockerfile
|
||||
.prettierignore
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// The config you add here will be used whenever a users loads a page in their browser.
|
||||
// https://docs.sentry.io/platforms/javascript/guides/nextjs/
|
||||
|
||||
import { consent } from "@/services/consent/cookies";
|
||||
import { environment } from "@/services/environment";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
|
||||
@@ -11,6 +12,9 @@ const isDisabled = process.env.DISABLE_SENTRY === "true";
|
||||
|
||||
const shouldEnable = !isDisabled && isProdOrDev && isCloud;
|
||||
|
||||
// Check for monitoring consent (includes session replay)
|
||||
const hasMonitoringConsent = consent.hasConsentFor("monitoring");
|
||||
|
||||
Sentry.init({
|
||||
dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",
|
||||
|
||||
@@ -50,10 +54,12 @@ Sentry.init({
|
||||
// Define how likely Replay events are sampled.
|
||||
// This sets the sample rate to be 10%. You may want this to be 100% while
|
||||
// in development and sample at a lower rate in production
|
||||
replaysSessionSampleRate: 0.1,
|
||||
// GDPR: Only enable if user has consented to monitoring
|
||||
replaysSessionSampleRate: hasMonitoringConsent ? 0.1 : 0,
|
||||
|
||||
// Define how likely Replay events are sampled when an error occurs.
|
||||
replaysOnErrorSampleRate: 1.0,
|
||||
// GDPR: Only enable if user has consented to monitoring
|
||||
replaysOnErrorSampleRate: hasMonitoringConsent ? 1.0 : 0,
|
||||
|
||||
// Setting this option to true will print useful information to the console while you're setting up Sentry.
|
||||
debug: false,
|
||||
|
||||
@@ -5,6 +5,8 @@ const nextConfig = {
|
||||
productionBrowserSourceMaps: true,
|
||||
images: {
|
||||
domains: [
|
||||
// We dont need to maintain alphabetical order here
|
||||
// as we are doing logical grouping of domains
|
||||
"images.unsplash.com",
|
||||
"ddz4ak4pa3d19.cloudfront.net",
|
||||
"upload.wikimedia.org",
|
||||
@@ -12,6 +14,7 @@ const nextConfig = {
|
||||
|
||||
"ideogram.ai", // for generated images
|
||||
"picsum.photos", // for placeholder images
|
||||
"example.com", // for local test data images
|
||||
],
|
||||
remotePatterns: [
|
||||
{
|
||||
@@ -51,8 +54,13 @@ export default isDevelopmentBuild
|
||||
NEXT_PUBLIC_VERCEL_ENV: process.env.VERCEL_ENV,
|
||||
},
|
||||
|
||||
// Only print logs for uploading source maps in CI
|
||||
silent: !process.env.CI,
|
||||
// Enable debug logging if SENTRY_LOG_LEVEL is set to debug
|
||||
// This helps troubleshoot sourcemap upload issues
|
||||
debug: process.env.SENTRY_LOG_LEVEL === "debug",
|
||||
|
||||
// Show logs in CI/Vercel builds to help debug sourcemap upload issues
|
||||
// Set silent to false to see upload progress and any errors
|
||||
silent: !process.env.CI && process.env.SENTRY_LOG_LEVEL !== "debug",
|
||||
|
||||
// For all available options, see:
|
||||
// https://docs.sentry.io/platforms/javascript/guides/nextjs/manual-setup/
|
||||
@@ -80,9 +88,15 @@ export default isDevelopmentBuild
|
||||
disable: false, // Source maps are enabled by default
|
||||
assets: ["**/*.js", "**/*.js.map"], // Specify which files to upload
|
||||
ignore: ["**/node_modules/**"], // Files to exclude
|
||||
deleteSourcemapsAfterUpload: true, // Security: delete after upload
|
||||
// Keep sourcemaps available for browser debugging (source is public anyway)
|
||||
deleteSourcemapsAfterUpload: false,
|
||||
},
|
||||
|
||||
// For monorepo: explicitly set the URL prefix where files will be served
|
||||
// Next.js serves static files from /_next/static/ regardless of monorepo structure
|
||||
// This ensures Sentry can correctly map sourcemaps to the served files
|
||||
urlPrefix: "~/_next/static",
|
||||
|
||||
// Automatically tree-shake Sentry logger statements to reduce bundle size
|
||||
disableLogger: true,
|
||||
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
"name": "frontend",
|
||||
"version": "0.3.4",
|
||||
"private": true,
|
||||
"engines": {
|
||||
"node": "22.x"
|
||||
},
|
||||
"scripts": {
|
||||
"dev": "pnpm run generate:api:force && next dev --turbo",
|
||||
"build": "next build",
|
||||
@@ -26,7 +29,7 @@
|
||||
],
|
||||
"dependencies": {
|
||||
"@faker-js/faker": "10.0.0",
|
||||
"@hookform/resolvers": "5.2.1",
|
||||
"@hookform/resolvers": "5.2.2",
|
||||
"@marsidev/react-turnstile": "1.3.1",
|
||||
"@next/third-parties": "15.4.6",
|
||||
"@phosphor-icons/react": "2.1.10",
|
||||
@@ -52,49 +55,49 @@
|
||||
"@rjsf/core": "5.24.13",
|
||||
"@rjsf/utils": "5.24.13",
|
||||
"@rjsf/validator-ajv8": "5.24.13",
|
||||
"@sentry/nextjs": "10.15.0",
|
||||
"@supabase/ssr": "0.6.1",
|
||||
"@supabase/supabase-js": "2.55.0",
|
||||
"@tanstack/react-query": "5.87.1",
|
||||
"@sentry/nextjs": "10.22.0",
|
||||
"@supabase/ssr": "0.7.0",
|
||||
"@supabase/supabase-js": "2.78.0",
|
||||
"@tanstack/react-query": "5.90.6",
|
||||
"@tanstack/react-table": "8.21.3",
|
||||
"@types/jaro-winkler": "0.2.4",
|
||||
"@vercel/analytics": "1.5.0",
|
||||
"@vercel/speed-insights": "1.2.0",
|
||||
"@xyflow/react": "12.8.3",
|
||||
"@xyflow/react": "12.9.2",
|
||||
"boring-avatars": "1.11.2",
|
||||
"class-variance-authority": "0.7.1",
|
||||
"clsx": "2.1.1",
|
||||
"cmdk": "1.1.1",
|
||||
"cookie": "1.0.2",
|
||||
"date-fns": "4.1.0",
|
||||
"dotenv": "17.2.1",
|
||||
"dotenv": "17.2.3",
|
||||
"elliptic": "6.6.1",
|
||||
"embla-carousel-react": "8.6.0",
|
||||
"framer-motion": "12.23.12",
|
||||
"geist": "1.4.2",
|
||||
"framer-motion": "12.23.24",
|
||||
"geist": "1.5.1",
|
||||
"highlight.js": "11.11.1",
|
||||
"jaro-winkler": "0.2.8",
|
||||
"katex": "0.16.22",
|
||||
"launchdarkly-react-client-sdk": "3.8.1",
|
||||
"katex": "0.16.25",
|
||||
"launchdarkly-react-client-sdk": "3.9.0",
|
||||
"lodash": "4.17.21",
|
||||
"lucide-react": "0.539.0",
|
||||
"lucide-react": "0.552.0",
|
||||
"moment": "2.30.1",
|
||||
"next": "15.4.7",
|
||||
"next-themes": "0.4.6",
|
||||
"nuqs": "2.4.3",
|
||||
"nuqs": "2.7.2",
|
||||
"party-js": "2.2.0",
|
||||
"react": "18.3.1",
|
||||
"react-currency-input-field": "4.0.3",
|
||||
"react-day-picker": "9.8.1",
|
||||
"react-day-picker": "9.11.1",
|
||||
"react-dom": "18.3.1",
|
||||
"react-drag-drop-files": "2.4.0",
|
||||
"react-hook-form": "7.62.0",
|
||||
"react-hook-form": "7.66.0",
|
||||
"react-icons": "5.5.0",
|
||||
"react-markdown": "9.0.3",
|
||||
"react-modal": "3.16.3",
|
||||
"react-shepherd": "6.1.9",
|
||||
"react-window": "1.8.11",
|
||||
"recharts": "3.1.2",
|
||||
"recharts": "3.3.0",
|
||||
"rehype-autolink-headings": "7.1.0",
|
||||
"rehype-highlight": "7.0.2",
|
||||
"rehype-katex": "7.0.1",
|
||||
@@ -112,47 +115,47 @@
|
||||
"zustand": "5.0.8"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@chromatic-com/storybook": "4.1.1",
|
||||
"@playwright/test": "1.55.0",
|
||||
"@chromatic-com/storybook": "4.1.2",
|
||||
"@playwright/test": "1.56.1",
|
||||
"@storybook/addon-a11y": "9.1.5",
|
||||
"@storybook/addon-docs": "9.1.5",
|
||||
"@storybook/addon-links": "9.1.5",
|
||||
"@storybook/addon-onboarding": "9.1.5",
|
||||
"@storybook/nextjs": "9.1.5",
|
||||
"@tanstack/eslint-plugin-query": "5.86.0",
|
||||
"@tanstack/react-query-devtools": "5.87.3",
|
||||
"@tanstack/eslint-plugin-query": "5.91.2",
|
||||
"@tanstack/react-query-devtools": "5.90.2",
|
||||
"@types/canvas-confetti": "1.9.0",
|
||||
"@types/lodash": "4.17.20",
|
||||
"@types/negotiator": "0.6.4",
|
||||
"@types/node": "24.3.1",
|
||||
"@types/node": "24.10.0",
|
||||
"@types/react": "18.3.17",
|
||||
"@types/react-dom": "18.3.5",
|
||||
"@types/react-modal": "3.16.3",
|
||||
"@types/react-window": "1.8.8",
|
||||
"axe-playwright": "2.1.0",
|
||||
"chromatic": "13.1.4",
|
||||
"axe-playwright": "2.2.2",
|
||||
"chromatic": "13.3.3",
|
||||
"concurrently": "9.2.1",
|
||||
"cross-env": "7.0.3",
|
||||
"eslint": "8.57.1",
|
||||
"eslint-config-next": "15.5.2",
|
||||
"eslint-plugin-storybook": "9.1.5",
|
||||
"import-in-the-middle": "1.14.2",
|
||||
"msw": "2.11.1",
|
||||
"msw-storybook-addon": "2.0.5",
|
||||
"orval": "7.11.2",
|
||||
"pbkdf2": "3.1.3",
|
||||
"msw": "2.11.6",
|
||||
"msw-storybook-addon": "2.0.6",
|
||||
"orval": "7.13.0",
|
||||
"pbkdf2": "3.1.5",
|
||||
"postcss": "8.5.6",
|
||||
"prettier": "3.6.2",
|
||||
"prettier-plugin-tailwindcss": "0.6.14",
|
||||
"prettier-plugin-tailwindcss": "0.7.1",
|
||||
"require-in-the-middle": "7.5.2",
|
||||
"storybook": "9.1.5",
|
||||
"tailwindcss": "3.4.17",
|
||||
"typescript": "5.9.2"
|
||||
"typescript": "5.9.3"
|
||||
},
|
||||
"msw": {
|
||||
"workerDirectory": [
|
||||
"public"
|
||||
]
|
||||
},
|
||||
"packageManager": "pnpm@10.11.1+sha256.211e9990148495c9fc30b7e58396f7eeda83d9243eb75407ea4f8650fb161f7c"
|
||||
"packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd"
|
||||
}
|
||||
|
||||
@@ -37,6 +37,27 @@ export default defineConfig({
|
||||
/* Helps debugging failures */
|
||||
trace: "retain-on-failure",
|
||||
video: "retain-on-failure",
|
||||
|
||||
/* Auto-accept cookies in all tests to prevent banner interference */
|
||||
storageState: {
|
||||
cookies: [],
|
||||
origins: [
|
||||
{
|
||||
origin: "http://localhost:3000",
|
||||
localStorage: [
|
||||
{
|
||||
name: "autogpt_cookie_consent",
|
||||
value: JSON.stringify({
|
||||
hasConsented: true,
|
||||
timestamp: Date.now(),
|
||||
analytics: true,
|
||||
monitoring: true,
|
||||
}),
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
/* Maximum time one test can run for */
|
||||
timeout: 25000,
|
||||
|
||||
1797
autogpt_platform/frontend/pnpm-lock.yaml
generated
1797
autogpt_platform/frontend/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
4
autogpt_platform/frontend/pnpm-workspace.yaml
Normal file
4
autogpt_platform/frontend/pnpm-workspace.yaml
Normal file
@@ -0,0 +1,4 @@
|
||||
onlyBuiltDependencies:
|
||||
- "@vercel/speed-insights"
|
||||
- esbuild
|
||||
- msw
|
||||
@@ -5,24 +5,23 @@
|
||||
* Mock Service Worker.
|
||||
* @see https://github.com/mswjs/msw
|
||||
* - Please do NOT modify this file.
|
||||
* - Please do NOT serve this file on production.
|
||||
*/
|
||||
|
||||
const PACKAGE_VERSION = '2.7.0'
|
||||
const INTEGRITY_CHECKSUM = '00729d72e3b82faf54ca8b9621dbb96f'
|
||||
const PACKAGE_VERSION = '2.11.6'
|
||||
const INTEGRITY_CHECKSUM = '4db4a41e972cec1b64cc569c66952d82'
|
||||
const IS_MOCKED_RESPONSE = Symbol('isMockedResponse')
|
||||
const activeClientIds = new Set()
|
||||
|
||||
self.addEventListener('install', function () {
|
||||
addEventListener('install', function () {
|
||||
self.skipWaiting()
|
||||
})
|
||||
|
||||
self.addEventListener('activate', function (event) {
|
||||
addEventListener('activate', function (event) {
|
||||
event.waitUntil(self.clients.claim())
|
||||
})
|
||||
|
||||
self.addEventListener('message', async function (event) {
|
||||
const clientId = event.source.id
|
||||
addEventListener('message', async function (event) {
|
||||
const clientId = Reflect.get(event.source || {}, 'id')
|
||||
|
||||
if (!clientId || !self.clients) {
|
||||
return
|
||||
@@ -72,11 +71,6 @@ self.addEventListener('message', async function (event) {
|
||||
break
|
||||
}
|
||||
|
||||
case 'MOCK_DEACTIVATE': {
|
||||
activeClientIds.delete(clientId)
|
||||
break
|
||||
}
|
||||
|
||||
case 'CLIENT_CLOSED': {
|
||||
activeClientIds.delete(clientId)
|
||||
|
||||
@@ -94,69 +88,92 @@ self.addEventListener('message', async function (event) {
|
||||
}
|
||||
})
|
||||
|
||||
self.addEventListener('fetch', function (event) {
|
||||
const { request } = event
|
||||
addEventListener('fetch', function (event) {
|
||||
const requestInterceptedAt = Date.now()
|
||||
|
||||
// Bypass navigation requests.
|
||||
if (request.mode === 'navigate') {
|
||||
if (event.request.mode === 'navigate') {
|
||||
return
|
||||
}
|
||||
|
||||
// Opening the DevTools triggers the "only-if-cached" request
|
||||
// that cannot be handled by the worker. Bypass such requests.
|
||||
if (request.cache === 'only-if-cached' && request.mode !== 'same-origin') {
|
||||
if (
|
||||
event.request.cache === 'only-if-cached' &&
|
||||
event.request.mode !== 'same-origin'
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
// Bypass all requests when there are no active clients.
|
||||
// Prevents the self-unregistered worked from handling requests
|
||||
// after it's been deleted (still remains active until the next reload).
|
||||
// after it's been terminated (still remains active until the next reload).
|
||||
if (activeClientIds.size === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
// Generate unique request ID.
|
||||
const requestId = crypto.randomUUID()
|
||||
event.respondWith(handleRequest(event, requestId))
|
||||
event.respondWith(handleRequest(event, requestId, requestInterceptedAt))
|
||||
})
|
||||
|
||||
async function handleRequest(event, requestId) {
|
||||
/**
|
||||
* @param {FetchEvent} event
|
||||
* @param {string} requestId
|
||||
* @param {number} requestInterceptedAt
|
||||
*/
|
||||
async function handleRequest(event, requestId, requestInterceptedAt) {
|
||||
const client = await resolveMainClient(event)
|
||||
const response = await getResponse(event, client, requestId)
|
||||
const requestCloneForEvents = event.request.clone()
|
||||
const response = await getResponse(
|
||||
event,
|
||||
client,
|
||||
requestId,
|
||||
requestInterceptedAt,
|
||||
)
|
||||
|
||||
// Send back the response clone for the "response:*" life-cycle events.
|
||||
// Ensure MSW is active and ready to handle the message, otherwise
|
||||
// this message will pend indefinitely.
|
||||
if (client && activeClientIds.has(client.id)) {
|
||||
;(async function () {
|
||||
const responseClone = response.clone()
|
||||
const serializedRequest = await serializeRequest(requestCloneForEvents)
|
||||
|
||||
sendToClient(
|
||||
client,
|
||||
{
|
||||
type: 'RESPONSE',
|
||||
payload: {
|
||||
requestId,
|
||||
isMockedResponse: IS_MOCKED_RESPONSE in response,
|
||||
// Clone the response so both the client and the library could consume it.
|
||||
const responseClone = response.clone()
|
||||
|
||||
sendToClient(
|
||||
client,
|
||||
{
|
||||
type: 'RESPONSE',
|
||||
payload: {
|
||||
isMockedResponse: IS_MOCKED_RESPONSE in response,
|
||||
request: {
|
||||
id: requestId,
|
||||
...serializedRequest,
|
||||
},
|
||||
response: {
|
||||
type: responseClone.type,
|
||||
status: responseClone.status,
|
||||
statusText: responseClone.statusText,
|
||||
body: responseClone.body,
|
||||
headers: Object.fromEntries(responseClone.headers.entries()),
|
||||
body: responseClone.body,
|
||||
},
|
||||
},
|
||||
[responseClone.body],
|
||||
)
|
||||
})()
|
||||
},
|
||||
responseClone.body ? [serializedRequest.body, responseClone.body] : [],
|
||||
)
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// Resolve the main client for the given event.
|
||||
// Client that issues a request doesn't necessarily equal the client
|
||||
// that registered the worker. It's with the latter the worker should
|
||||
// communicate with during the response resolving phase.
|
||||
/**
|
||||
* Resolve the main client for the given event.
|
||||
* Client that issues a request doesn't necessarily equal the client
|
||||
* that registered the worker. It's with the latter the worker should
|
||||
* communicate with during the response resolving phase.
|
||||
* @param {FetchEvent} event
|
||||
* @returns {Promise<Client | undefined>}
|
||||
*/
|
||||
async function resolveMainClient(event) {
|
||||
const client = await self.clients.get(event.clientId)
|
||||
|
||||
@@ -184,12 +201,17 @@ async function resolveMainClient(event) {
|
||||
})
|
||||
}
|
||||
|
||||
async function getResponse(event, client, requestId) {
|
||||
const { request } = event
|
||||
|
||||
/**
|
||||
* @param {FetchEvent} event
|
||||
* @param {Client | undefined} client
|
||||
* @param {string} requestId
|
||||
* @param {number} requestInterceptedAt
|
||||
* @returns {Promise<Response>}
|
||||
*/
|
||||
async function getResponse(event, client, requestId, requestInterceptedAt) {
|
||||
// Clone the request because it might've been already used
|
||||
// (i.e. its body has been read and sent to the client).
|
||||
const requestClone = request.clone()
|
||||
const requestClone = event.request.clone()
|
||||
|
||||
function passthrough() {
|
||||
// Cast the request headers to a new Headers instance
|
||||
@@ -230,29 +252,18 @@ async function getResponse(event, client, requestId) {
|
||||
}
|
||||
|
||||
// Notify the client that a request has been intercepted.
|
||||
const requestBuffer = await request.arrayBuffer()
|
||||
const serializedRequest = await serializeRequest(event.request)
|
||||
const clientMessage = await sendToClient(
|
||||
client,
|
||||
{
|
||||
type: 'REQUEST',
|
||||
payload: {
|
||||
id: requestId,
|
||||
url: request.url,
|
||||
mode: request.mode,
|
||||
method: request.method,
|
||||
headers: Object.fromEntries(request.headers.entries()),
|
||||
cache: request.cache,
|
||||
credentials: request.credentials,
|
||||
destination: request.destination,
|
||||
integrity: request.integrity,
|
||||
redirect: request.redirect,
|
||||
referrer: request.referrer,
|
||||
referrerPolicy: request.referrerPolicy,
|
||||
body: requestBuffer,
|
||||
keepalive: request.keepalive,
|
||||
interceptedAt: requestInterceptedAt,
|
||||
...serializedRequest,
|
||||
},
|
||||
},
|
||||
[requestBuffer],
|
||||
[serializedRequest.body],
|
||||
)
|
||||
|
||||
switch (clientMessage.type) {
|
||||
@@ -268,6 +279,12 @@ async function getResponse(event, client, requestId) {
|
||||
return passthrough()
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {Client} client
|
||||
* @param {any} message
|
||||
* @param {Array<Transferable>} transferrables
|
||||
* @returns {Promise<any>}
|
||||
*/
|
||||
function sendToClient(client, message, transferrables = []) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const channel = new MessageChannel()
|
||||
@@ -280,14 +297,18 @@ function sendToClient(client, message, transferrables = []) {
|
||||
resolve(event.data)
|
||||
}
|
||||
|
||||
client.postMessage(
|
||||
message,
|
||||
[channel.port2].concat(transferrables.filter(Boolean)),
|
||||
)
|
||||
client.postMessage(message, [
|
||||
channel.port2,
|
||||
...transferrables.filter(Boolean),
|
||||
])
|
||||
})
|
||||
}
|
||||
|
||||
async function respondWithMock(response) {
|
||||
/**
|
||||
* @param {Response} response
|
||||
* @returns {Response}
|
||||
*/
|
||||
function respondWithMock(response) {
|
||||
// Setting response status code to 0 is a no-op.
|
||||
// However, when responding with a "Response.error()", the produced Response
|
||||
// instance will have status code set to 0. Since it's not possible to create
|
||||
@@ -305,3 +326,24 @@ async function respondWithMock(response) {
|
||||
|
||||
return mockedResponse
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {Request} request
|
||||
*/
|
||||
async function serializeRequest(request) {
|
||||
return {
|
||||
url: request.url,
|
||||
mode: request.mode,
|
||||
method: request.method,
|
||||
headers: Object.fromEntries(request.headers.entries()),
|
||||
cache: request.cache,
|
||||
credentials: request.credentials,
|
||||
destination: request.destination,
|
||||
integrity: request.integrity,
|
||||
redirect: request.redirect,
|
||||
referrer: request.referrer,
|
||||
referrerPolicy: request.referrerPolicy,
|
||||
body: await request.arrayBuffer(),
|
||||
keepalive: request.keepalive,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,10 +24,10 @@ export function SelectedAgentCard(props: Props) {
|
||||
{/* Right content */}
|
||||
<div className="ml-3 flex flex-1 flex-col">
|
||||
<div className="mb-2 flex flex-col items-start">
|
||||
<span className="w-[292px] truncate font-sans text-[14px] font-medium leading-tight text-zinc-800">
|
||||
<span className="data-sentry-unmask w-[292px] truncate font-sans text-[14px] font-medium leading-tight text-zinc-800">
|
||||
{props.storeAgent.agent_name}
|
||||
</span>
|
||||
<span className="font-norma w-[292px] truncate font-sans text-xs text-zinc-600">
|
||||
<span className="data-sentry-unmask font-norma w-[292px] truncate font-sans text-xs text-zinc-600">
|
||||
by {props.storeAgent.creator}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
@@ -21,7 +21,6 @@ export default function OnboardingAgentCard({
|
||||
"relative animate-pulse",
|
||||
"h-[394px] w-[368px] rounded-[20px] border border-transparent bg-zinc-200",
|
||||
)}
|
||||
onClick={onClick}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -67,12 +66,12 @@ export default function OnboardingAgentCard({
|
||||
{/* Text content wrapper */}
|
||||
<div>
|
||||
{/* Title - 2 lines max */}
|
||||
<p className="text-md line-clamp-2 max-h-[50px] font-sans text-base font-medium leading-normal text-zinc-800">
|
||||
<p className="data-sentry-unmask text-md line-clamp-2 max-h-[50px] font-sans text-base font-medium leading-normal text-zinc-800">
|
||||
{agent_name}
|
||||
</p>
|
||||
|
||||
{/* Author - single line with truncate */}
|
||||
<p className="truncate text-sm font-normal leading-normal text-zinc-600">
|
||||
<p className="data-sentry-unmask truncate text-sm font-normal leading-normal text-zinc-600">
|
||||
by {creator}
|
||||
</p>
|
||||
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
import { postV1ResetOnboardingProgress } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
|
||||
import { redirect } from "next/navigation";
|
||||
|
||||
export default async function OnboardingResetPage() {
|
||||
await postV1ResetOnboardingProgress();
|
||||
redirect("/onboarding/1-welcome");
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
"use client";
|
||||
import { postV1ResetOnboardingProgress } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { redirect } from "next/navigation";
|
||||
import { useEffect } from "react";
|
||||
|
||||
export default function OnboardingResetPage() {
|
||||
const { toast } = useToast();
|
||||
|
||||
useEffect(() => {
|
||||
postV1ResetOnboardingProgress()
|
||||
.then(() => {
|
||||
toast({
|
||||
title: "Onboarding reset successfully",
|
||||
description: "You can now start the onboarding process again",
|
||||
variant: "success",
|
||||
});
|
||||
|
||||
redirect("/onboarding/1-welcome");
|
||||
})
|
||||
.catch(() => {
|
||||
toast({
|
||||
title: "Failed to reset onboarding",
|
||||
description: "Please try again later",
|
||||
variant: "destructive",
|
||||
});
|
||||
});
|
||||
}, []);
|
||||
|
||||
return <LoadingSpinner cover />;
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
"use client";
|
||||
|
||||
import { useAdminImpersonation } from "./useAdminImpersonation";
|
||||
|
||||
export function AdminImpersonationBanner() {
|
||||
const { isImpersonating, impersonatedUserId, stopImpersonating } =
|
||||
useAdminImpersonation();
|
||||
|
||||
if (!isImpersonating) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="mb-4 rounded-md border border-amber-500 bg-amber-50 p-4 text-amber-900">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center space-x-2">
|
||||
<strong className="font-semibold">
|
||||
⚠️ ADMIN IMPERSONATION ACTIVE
|
||||
</strong>
|
||||
<span>
|
||||
You are currently acting as user:{" "}
|
||||
<code className="rounded bg-amber-100 px-1 font-mono text-sm">
|
||||
{impersonatedUserId}
|
||||
</code>
|
||||
</span>
|
||||
</div>
|
||||
<button
|
||||
onClick={stopImpersonating}
|
||||
className="ml-4 flex h-8 items-center rounded-md border border-amber-300 bg-transparent px-3 text-sm hover:bg-amber-100"
|
||||
>
|
||||
Stop Impersonation
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user