mirror of https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-12 00:28:31 -05:00

Compare commits: fix/sql-in...fix/accoun
120 commits
| SHA1 |
|---|
| cc85a37305 |
| 8daec53230 |
| ec6f593edc |
| e6ed83462d |
| 1851264a6a |
| 8f25d43089 |
| 0c435c4afa |
| 18002cb8f0 |
| 240a65e7b3 |
| 07368468a4 |
| 52aac09577 |
| 64a775dfa7 |
| 5d97706bb8 |
| 244f3c7c71 |
| 355219acbd |
| 1ab66eaed4 |
| 126d5838a0 |
| 643aea849b |
| 3b092f34d8 |
| 0921d23628 |
| 0edc669874 |
| e64d3d9b99 |
| 41dc39b97d |
| 80e573f33b |
| 06d20e7e4c |
| 07b5fe859a |
| 746dbbac84 |
| 901bb31e14 |
| 9438817702 |
| 184a73de7d |
| 1154f86a5c |
| 73c93cf554 |
| 02757d68f3 |
| 2569576d78 |
| 3b34c04a7a |
| 34c9ecf6bc |
| a66219fc1f |
| 8b3a741f60 |
| 7c48598f44 |
| 804e3b403a |
| 9c3f679f30 |
| 9977144b3d |
| 81d61a0c94 |
| e1e0fb7b25 |
| a054740aac |
| f78a6df96c |
| 4d43570552 |
| d3d78660da |
| 749be06599 |
| 27d886f05c |
| be01a1316a |
| 32bb6705d1 |
| 536e2a5ec8 |
| a3e5f7fce2 |
| d674cb80e2 |
| c3a6235cee |
| 43638defa2 |
| 9371528aab |
| 711c439642 |
| 33989f09d0 |
| d6ee402483 |
| 08a4a6652d |
| 58928b516b |
| 18bb78d93e |
| 8058b9487b |
| e68896a25a |
| dfed092869 |
| 5559d978d7 |
| dcecb17bd1 |
| a056d9e71a |
| 42b643579f |
| 5b52ca9227 |
| 8e83586d13 |
| df9850a141 |
| cbe4086e79 |
| fbd5f34a61 |
| 0b680d4990 |
| 6037f80502 |
| 37b3e4e82e |
| de7c5b5c31 |
| d68dceb9c1 |
| 193866232c |
| 979826f559 |
| 2f87e13d17 |
| 2ad5a88a5c |
| e9cd40c0d4 |
| 4744675ef9 |
| 910fd2640d |
| eae2616fb5 |
| ad3ea59d90 |
| 69b6b732a2 |
| b1a2d21892 |
| a78b08f5e7 |
| 5359f20070 |
| 427c7eb1d4 |
| c17a2f807d |
| f80739d38c |
| f97e19f418 |
| 42b9facd4a |
| a02b8d9ad7 |
| 834617d221 |
| e6fb649ced |
| 2f8cdf62ba |
| 3dc5208f71 |
| 04493598e2 |
| 4140331731 |
| 594b1adcf7 |
| cab6590908 |
| a1ac109356 |
| 5506d59da1 |
| 749341100b |
| 4922f88851 |
| 5fb142c656 |
| e14594ff4a |
| de70ede54a |
| 59657eb42e |
| 5e5f45a713 |
| b31d60276a |
| b52e95e1fc |
| f4ba02f2f1 |
.github/workflows/claude-dependabot.yml (vendored): 2 lines changed

@@ -80,7 +80,7 @@ jobs:
       - name: Set up Node.js
         uses: actions/setup-node@v4
         with:
-          node-version: "21"
+          node-version: "22"

       - name: Enable corepack
         run: corepack enable
.github/workflows/claude.yml (vendored): 2 lines changed

@@ -90,7 +90,7 @@ jobs:
       - name: Set up Node.js
         uses: actions/setup-node@v4
         with:
-          node-version: "21"
+          node-version: "22"

       - name: Enable corepack
         run: corepack enable
.github/workflows/copilot-setup-steps.yml (vendored): 4 lines changed

@@ -78,7 +78,7 @@ jobs:
       - name: Set up Node.js
         uses: actions/setup-node@v4
         with:
-          node-version: "21"
+          node-version: "22"

       - name: Enable corepack
         run: corepack enable

@@ -299,4 +299,4 @@ jobs:
           echo "✅ AutoGPT Platform development environment setup complete!"
           echo "🚀 Ready for development with Docker services running"
           echo "📝 Backend server: poetry run serve (port 8000)"
-          echo "🌐 Frontend server: pnpm dev (port 3000)"
+          echo "🌐 Frontend server: pnpm dev (port 3000)"
.github/workflows/platform-frontend-ci.yml (vendored): 8 lines changed

@@ -30,7 +30,7 @@ jobs:
       - name: Set up Node.js
         uses: actions/setup-node@v4
         with:
-          node-version: "21"
+          node-version: "22.18.0"

       - name: Enable corepack
         run: corepack enable

@@ -62,7 +62,7 @@ jobs:
       - name: Set up Node.js
         uses: actions/setup-node@v4
         with:
-          node-version: "21"
+          node-version: "22.18.0"

       - name: Enable corepack
         run: corepack enable

@@ -97,7 +97,7 @@ jobs:
       - name: Set up Node.js
         uses: actions/setup-node@v4
         with:
-          node-version: "21"
+          node-version: "22.18.0"

       - name: Enable corepack
         run: corepack enable

@@ -138,7 +138,7 @@ jobs:
       - name: Set up Node.js
         uses: actions/setup-node@v4
         with:
-          node-version: "21"
+          node-version: "22.18.0"

       - name: Enable corepack
         run: corepack enable
.github/workflows/platform-fullstack-ci.yml (vendored): 4 lines changed

@@ -30,7 +30,7 @@ jobs:
       - name: Set up Node.js
         uses: actions/setup-node@v4
         with:
-          node-version: "21"
+          node-version: "22.18.0"

       - name: Enable corepack
         run: corepack enable

@@ -66,7 +66,7 @@ jobs:
       - name: Set up Node.js
         uses: actions/setup-node@v4
         with:
-          node-version: "21"
+          node-version: "22.18.0"

       - name: Enable corepack
         run: corepack enable
.gitignore (vendored): 1 line changed

@@ -178,3 +178,4 @@ autogpt_platform/backend/settings.py
 *.ign.*
 .test-contents
 .claude/settings.local.json
+/autogpt_platform/backend/logs
@@ -192,6 +192,8 @@ Quick steps:
+Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
+ex: do the inputs and outputs tie well together?

 If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.

 **Modifying the API:**

 1. Update route in `/backend/backend/server/routers/`
@@ -1,5 +1,10 @@
 from .config import verify_settings
-from .dependencies import get_user_id, requires_admin_user, requires_user
+from .dependencies import (
+    get_optional_user_id,
+    get_user_id,
+    requires_admin_user,
+    requires_user,
+)
 from .helpers import add_auth_responses_to_openapi
 from .models import User

@@ -8,6 +13,7 @@ __all__ = [
     "get_user_id",
     "requires_admin_user",
     "requires_user",
+    "get_optional_user_id",
     "add_auth_responses_to_openapi",
     "User",
 ]
@@ -4,11 +4,53 @@ FastAPI dependency functions for JWT-based authentication and authorization.
 These are the high-level dependency functions used in route definitions.
 """

+import logging
+
 import fastapi
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

 from .jwt_utils import get_jwt_payload, verify_user
 from .models import User

+optional_bearer = HTTPBearer(auto_error=False)
+
+# Header name for admin impersonation
+IMPERSONATION_HEADER_NAME = "X-Act-As-User-Id"
+
+logger = logging.getLogger(__name__)
+
+
+def get_optional_user_id(
+    credentials: HTTPAuthorizationCredentials | None = fastapi.Security(
+        optional_bearer
+    ),
+) -> str | None:
+    """
+    Attempts to extract the user ID ("sub" claim) from a Bearer JWT if provided.
+
+    This dependency allows for both authenticated and anonymous access. If a valid bearer token is
+    supplied, it parses the JWT and extracts the user ID. If the token is missing or invalid, it returns None,
+    treating the request as anonymous.
+
+    Args:
+        credentials: Optional HTTPAuthorizationCredentials object from FastAPI Security dependency.
+
+    Returns:
+        The user ID (str) extracted from the JWT "sub" claim, or None if no valid token is present.
+    """
+    if not credentials:
+        return None
+
+    try:
+        # Parse JWT token to get user ID
+        from autogpt_libs.auth.jwt_utils import parse_jwt_token
+
+        payload = parse_jwt_token(credentials.credentials)
+        return payload.get("sub")
+    except Exception as e:
+        logger.debug(f"Auth token validation failed (anonymous access): {e}")
+        return None
+
+
 async def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
     """

@@ -32,16 +74,44 @@ async def requires_admin_user(
     return verify_user(jwt_payload, admin_only=True)


-async def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
+async def get_user_id(
+    request: fastapi.Request, jwt_payload: dict = fastapi.Security(get_jwt_payload)
+) -> str:
     """
     FastAPI dependency that returns the ID of the authenticated user.
+
+    Supports admin impersonation via X-Act-As-User-Id header:
+    - If the header is present and user is admin, returns the impersonated user ID
+    - Otherwise returns the authenticated user's own ID
+    - Logs all impersonation actions for audit trail
+
+    Raises:
+        HTTPException: 401 for authentication failures or missing user ID
+        HTTPException: 403 if non-admin tries to use impersonation
     """
+    # Get the authenticated user's ID from JWT
     user_id = jwt_payload.get("sub")
     if not user_id:
         raise fastapi.HTTPException(
             status_code=401, detail="User ID not found in token"
         )

+    # Check for admin impersonation header
+    impersonate_header = request.headers.get(IMPERSONATION_HEADER_NAME, "").strip()
+    if impersonate_header:
+        # Verify the authenticated user is an admin
+        authenticated_user = verify_user(jwt_payload, admin_only=False)
+        if authenticated_user.role != "admin":
+            raise fastapi.HTTPException(
+                status_code=403, detail="Only admin users can impersonate other users"
+            )
+
+        # Log the impersonation for audit trail
+        logger.info(
+            f"Admin impersonation: {authenticated_user.user_id} ({authenticated_user.email}) "
+            f"acting as user {impersonate_header} for requesting {request.method} {request.url}"
+        )
+
+        return impersonate_header
+
     return user_id
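Taken together, the dependencies above give a route author three access modes: anonymous-tolerant via get_optional_user_id, authenticated via get_user_id (now impersonation-aware), and role-gated via requires_admin_user. A minimal usage sketch; the app, route paths, and handler bodies below are illustrative and not part of this diff:

from fastapi import FastAPI, Security

from autogpt_libs.auth import get_optional_user_id, get_user_id

app = FastAPI()


@app.get("/feed")  # hypothetical endpoint
async def feed(user_id: str | None = Security(get_optional_user_id)):
    # None for anonymous callers; the JWT "sub" claim otherwise.
    return {"personalized": user_id is not None}


@app.get("/me")  # hypothetical endpoint
async def me(user_id: str = Security(get_user_id)):
    # With a valid admin JWT plus an "X-Act-As-User-Id: <target>" header,
    # this resolves to the impersonated user's ID instead of the admin's own.
    return {"user_id": user_id}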
@@ -4,9 +4,10 @@ Tests the full authentication flow from HTTP requests to user validation.
 """

 import os
+from unittest.mock import Mock

 import pytest
-from fastapi import FastAPI, HTTPException, Security
+from fastapi import FastAPI, HTTPException, Request, Security
 from fastapi.testclient import TestClient
 from pytest_mock import MockerFixture

@@ -45,6 +46,7 @@ class TestAuthDependencies:
         """Create a test client."""
         return TestClient(app)

+    @pytest.mark.asyncio
     async def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
         """Test requires_user with valid JWT payload."""
         jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}

@@ -58,6 +60,7 @@
         assert user.user_id == "user-123"
         assert user.role == "user"

+    @pytest.mark.asyncio
     async def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
         """Test requires_user accepts admin users."""
         jwt_payload = {

@@ -73,6 +76,7 @@
         assert user.user_id == "admin-456"
         assert user.role == "admin"

+    @pytest.mark.asyncio
     async def test_requires_user_missing_sub(self):
         """Test requires_user with missing user ID."""
         jwt_payload = {"role": "user", "email": "user@example.com"}

@@ -82,6 +86,7 @@
         assert exc_info.value.status_code == 401
         assert "User ID not found" in exc_info.value.detail

+    @pytest.mark.asyncio
     async def test_requires_user_empty_sub(self):
         """Test requires_user with empty user ID."""
         jwt_payload = {"sub": "", "role": "user"}

@@ -90,6 +95,7 @@
             await requires_user(jwt_payload)
         assert exc_info.value.status_code == 401

+    @pytest.mark.asyncio
     async def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
         """Test requires_admin_user with admin role."""
         jwt_payload = {

@@ -105,6 +111,7 @@
         assert user.user_id == "admin-789"
         assert user.role == "admin"

+    @pytest.mark.asyncio
     async def test_requires_admin_user_with_regular_user(self):
         """Test requires_admin_user rejects regular users."""
         jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}

@@ -114,6 +121,7 @@
         assert exc_info.value.status_code == 403
         assert "Admin access required" in exc_info.value.detail

+    @pytest.mark.asyncio
     async def test_requires_admin_user_missing_role(self):
         """Test requires_admin_user with missing role."""
         jwt_payload = {"sub": "user-123", "email": "user@example.com"}

@@ -121,31 +129,40 @@
         with pytest.raises(KeyError):
             await requires_admin_user(jwt_payload)

+    @pytest.mark.asyncio
     async def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
         """Test get_user_id extracts user ID correctly."""
+        request = Mock(spec=Request)
+        request.headers = {}
         jwt_payload = {"sub": "user-id-xyz", "role": "user"}

         mocker.patch(
             "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
         )
-        user_id = await get_user_id(jwt_payload)
+        user_id = await get_user_id(request, jwt_payload)
         assert user_id == "user-id-xyz"

+    @pytest.mark.asyncio
     async def test_get_user_id_missing_sub(self):
         """Test get_user_id with missing user ID."""
+        request = Mock(spec=Request)
+        request.headers = {}
         jwt_payload = {"role": "user"}

         with pytest.raises(HTTPException) as exc_info:
-            await get_user_id(jwt_payload)
+            await get_user_id(request, jwt_payload)
         assert exc_info.value.status_code == 401
         assert "User ID not found" in exc_info.value.detail

+    @pytest.mark.asyncio
     async def test_get_user_id_none_sub(self):
         """Test get_user_id with None user ID."""
+        request = Mock(spec=Request)
+        request.headers = {}
         jwt_payload = {"sub": None, "role": "user"}

         with pytest.raises(HTTPException) as exc_info:
-            await get_user_id(jwt_payload)
+            await get_user_id(request, jwt_payload)
         assert exc_info.value.status_code == 401

@@ -170,6 +187,7 @@ class TestAuthDependenciesIntegration:

         return _create_token

+    @pytest.mark.asyncio
     async def test_endpoint_auth_enabled_no_token(self):
         """Test endpoints require token when auth is enabled."""
         app = FastAPI()

@@ -184,6 +202,7 @@
         response = client.get("/test")
         assert response.status_code == 401

+    @pytest.mark.asyncio
     async def test_endpoint_with_valid_token(self, create_token):
         """Test endpoint with valid JWT token."""
         app = FastAPI()

@@ -203,6 +222,7 @@
         assert response.status_code == 200
         assert response.json()["user_id"] == "test-user"

+    @pytest.mark.asyncio
     async def test_admin_endpoint_requires_admin_role(self, create_token):
         """Test admin endpoint rejects non-admin users."""
         app = FastAPI()

@@ -240,6 +260,7 @@
 class TestAuthDependenciesEdgeCases:
     """Edge case tests for authentication dependencies."""

+    @pytest.mark.asyncio
     async def test_dependency_with_complex_payload(self):
         """Test dependencies handle complex JWT payloads."""
         complex_payload = {

@@ -263,6 +284,7 @@
         admin = await requires_admin_user(complex_payload)
         assert admin.role == "admin"

+    @pytest.mark.asyncio
     async def test_dependency_with_unicode_in_payload(self):
         """Test dependencies handle unicode in JWT payloads."""
         unicode_payload = {

@@ -276,6 +298,7 @@
         assert "😀" in user.user_id
         assert user.email == "测试@example.com"

+    @pytest.mark.asyncio
     async def test_dependency_with_null_values(self):
         """Test dependencies handle null values in payload."""
         null_payload = {

@@ -290,6 +313,7 @@
         assert user.user_id == "user-123"
         assert user.email is None

+    @pytest.mark.asyncio
     async def test_concurrent_requests_isolation(self):
         """Test that concurrent requests don't interfere with each other."""
         payload1 = {"sub": "user-1", "role": "user"}

@@ -314,6 +338,7 @@
             ({"sub": "user", "role": "user"}, "Admin access required", True),
         ],
     )
+    @pytest.mark.asyncio
     async def test_dependency_error_cases(
         self, payload, expected_error: str, admin_only: bool
     ):

@@ -325,6 +350,7 @@
             verify_user(payload, admin_only=admin_only)
         assert expected_error in exc_info.value.detail

+    @pytest.mark.asyncio
     async def test_dependency_valid_user(self):
         """Test valid user case for dependency."""
         # Import verify_user to test it directly since dependencies use FastAPI Security

@@ -333,3 +359,196 @@
         # Valid case
         user = verify_user({"sub": "user", "role": "user"}, admin_only=False)
         assert user.user_id == "user"
+
+
+class TestAdminImpersonation:
+    """Test suite for admin user impersonation functionality."""
+
+    @pytest.mark.asyncio
+    async def test_admin_impersonation_success(self, mocker: MockerFixture):
+        """Test admin successfully impersonating another user."""
+        request = Mock(spec=Request)
+        request.headers = {"X-Act-As-User-Id": "target-user-123"}
+        jwt_payload = {
+            "sub": "admin-456",
+            "role": "admin",
+            "email": "admin@example.com",
+        }
+
+        # Mock verify_user to return admin user data
+        mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
+        mock_verify_user.return_value = Mock(
+            user_id="admin-456", email="admin@example.com", role="admin"
+        )
+
+        # Mock logger to verify audit logging
+        mock_logger = mocker.patch("autogpt_libs.auth.dependencies.logger")
+
+        mocker.patch(
+            "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
+        )
+
+        user_id = await get_user_id(request, jwt_payload)
+
+        # Should return the impersonated user ID
+        assert user_id == "target-user-123"
+
+        # Should log the impersonation attempt
+        mock_logger.info.assert_called_once()
+        log_call = mock_logger.info.call_args[0][0]
+        assert "Admin impersonation:" in log_call
+        assert "admin@example.com" in log_call
+        assert "target-user-123" in log_call
+
+    @pytest.mark.asyncio
+    async def test_non_admin_impersonation_attempt(self, mocker: MockerFixture):
+        """Test non-admin user attempting impersonation returns 403."""
+        request = Mock(spec=Request)
+        request.headers = {"X-Act-As-User-Id": "target-user-123"}
+        jwt_payload = {
+            "sub": "regular-user",
+            "role": "user",
+            "email": "user@example.com",
+        }
+
+        # Mock verify_user to return regular user data
+        mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
+        mock_verify_user.return_value = Mock(
+            user_id="regular-user", email="user@example.com", role="user"
+        )
+
+        mocker.patch(
+            "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
+        )
+
+        with pytest.raises(HTTPException) as exc_info:
+            await get_user_id(request, jwt_payload)
+
+        assert exc_info.value.status_code == 403
+        assert "Only admin users can impersonate other users" in exc_info.value.detail
+
+    @pytest.mark.asyncio
+    async def test_impersonation_empty_header(self, mocker: MockerFixture):
+        """Test impersonation with empty header falls back to regular user ID."""
+        request = Mock(spec=Request)
+        request.headers = {"X-Act-As-User-Id": ""}
+        jwt_payload = {
+            "sub": "admin-456",
+            "role": "admin",
+            "email": "admin@example.com",
+        }
+
+        mocker.patch(
+            "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
+        )
+
+        user_id = await get_user_id(request, jwt_payload)
+
+        # Should fall back to the admin's own user ID
+        assert user_id == "admin-456"
+
+    @pytest.mark.asyncio
+    async def test_impersonation_missing_header(self, mocker: MockerFixture):
+        """Test normal behavior when impersonation header is missing."""
+        request = Mock(spec=Request)
+        request.headers = {}  # No impersonation header
+        jwt_payload = {
+            "sub": "admin-456",
+            "role": "admin",
+            "email": "admin@example.com",
+        }
+
+        mocker.patch(
+            "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
+        )
+
+        user_id = await get_user_id(request, jwt_payload)
+
+        # Should return the admin's own user ID
+        assert user_id == "admin-456"
+
+    @pytest.mark.asyncio
+    async def test_impersonation_audit_logging_details(self, mocker: MockerFixture):
+        """Test that impersonation audit logging includes all required details."""
+        request = Mock(spec=Request)
+        request.headers = {"X-Act-As-User-Id": "victim-user-789"}
+        jwt_payload = {
+            "sub": "admin-999",
+            "role": "admin",
+            "email": "superadmin@company.com",
+        }
+
+        # Mock verify_user to return admin user data
+        mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
+        mock_verify_user.return_value = Mock(
+            user_id="admin-999", email="superadmin@company.com", role="admin"
+        )
+
+        # Mock logger to capture audit trail
+        mock_logger = mocker.patch("autogpt_libs.auth.dependencies.logger")
+
+        mocker.patch(
+            "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
+        )
+
+        user_id = await get_user_id(request, jwt_payload)
+
+        # Verify all audit details are logged
+        assert user_id == "victim-user-789"
+        mock_logger.info.assert_called_once()
+
+        log_message = mock_logger.info.call_args[0][0]
+        assert "Admin impersonation:" in log_message
+        assert "superadmin@company.com" in log_message
+        assert "victim-user-789" in log_message
+
+    @pytest.mark.asyncio
+    async def test_impersonation_header_case_sensitivity(self, mocker: MockerFixture):
+        """Test that impersonation header is case-sensitive."""
+        request = Mock(spec=Request)
+        # Use wrong case - should not trigger impersonation
+        request.headers = {"x-act-as-user-id": "target-user-123"}
+        jwt_payload = {
+            "sub": "admin-456",
+            "role": "admin",
+            "email": "admin@example.com",
+        }
+
+        mocker.patch(
+            "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
+        )
+
+        user_id = await get_user_id(request, jwt_payload)
+
+        # Should fall back to admin's own ID (header case mismatch)
+        assert user_id == "admin-456"
+
+    @pytest.mark.asyncio
+    async def test_impersonation_with_whitespace_header(self, mocker: MockerFixture):
+        """Test impersonation with whitespace in header value."""
+        request = Mock(spec=Request)
+        request.headers = {"X-Act-As-User-Id": "  target-user-123  "}
+        jwt_payload = {
+            "sub": "admin-456",
+            "role": "admin",
+            "email": "admin@example.com",
+        }
+
+        # Mock verify_user to return admin user data
+        mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
+        mock_verify_user.return_value = Mock(
+            user_id="admin-456", email="admin@example.com", role="admin"
+        )
+
+        # Mock logger
+        mock_logger = mocker.patch("autogpt_libs.auth.dependencies.logger")
+
+        mocker.patch(
+            "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
+        )
+
+        user_id = await get_user_id(request, jwt_payload)
+
+        # Should strip whitespace and impersonate successfully
+        assert user_id == "target-user-123"
+        mock_logger.info.assert_called_once()
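The tests above pin down the header contract: the name must match "X-Act-As-User-Id" exactly, surrounding whitespace is stripped, an empty or missing value falls back to the caller's own ID, and a non-admin caller gets a 403 before any impersonation happens. A hedged client-side sketch of the happy path using httpx; the base URL, endpoint, and token are placeholders, not part of this diff:

import httpx

ADMIN_JWT = "<admin-jwt>"  # placeholder: any valid admin bearer token

resp = httpx.get(
    "http://localhost:8000/me",  # assumed endpoint wired to the get_user_id dependency
    headers={
        "Authorization": f"Bearer {ADMIN_JWT}",
        "X-Act-As-User-Id": "target-user-123",  # header name from IMPERSONATION_HEADER_NAME
    },
)
# Admin token: the response reflects target-user-123 and the server logs the audit line.
# Non-admin token: 403 "Only admin users can impersonate other users".
print(resp.status_code, resp.json())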
@@ -134,13 +134,6 @@ POSTMARK_WEBHOOK_TOKEN=
 # Error Tracking
 SENTRY_DSN=

-# Cloudflare Turnstile (CAPTCHA) Configuration
-# Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
-# This is the backend secret key
-TURNSTILE_SECRET_KEY=
-# This is the verify URL
-TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
-
 # Feature Flags
 LAUNCH_DARKLY_SDK_KEY=
@@ -7,6 +7,7 @@ from backend.data.block import (
     BlockInput,
     BlockOutput,
     BlockSchema,
+    BlockSchemaInput,
     BlockType,
     get_block,
 )

@@ -19,7 +20,7 @@ _logger = logging.getLogger(__name__)


 class AgentExecutorBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         user_id: str = SchemaField(description="User ID")
         graph_id: str = SchemaField(description="Graph ID")
         graph_version: int = SchemaField(description="Graph Version")

@@ -53,6 +54,7 @@
         return validate_with_jsonschema(cls.get_input_schema(data), data)

     class Output(BlockSchema):
+        # Use BlockSchema to avoid automatic error field that could clash with graph outputs
         pass

     def __init__(self):

@@ -65,7 +67,13 @@
             categories={BlockCategory.AGENT},
         )

-    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
+    async def run(
+        self,
+        input_data: Input,
+        *,
+        graph_exec_id: str,
+        **kwargs,
+    ) -> BlockOutput:

         from backend.executor import utils as execution_utils

@@ -75,6 +83,8 @@
             user_id=input_data.user_id,
             inputs=input_data.inputs,
             nodes_input_masks=input_data.nodes_input_masks,
+            parent_graph_exec_id=graph_exec_id,
+            is_sub_graph=True,  # AgentExecutorBlock executions are always sub-graphs
         )

         logger = execution_utils.LogMetadata(
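The BlockSchema → BlockSchemaInput/BlockSchemaOutput swap seen here repeats across every block file in the rest of this comparison, and the explicit error fields removed in later hunks suggest BlockSchemaOutput contributes a standard error output on its own (exactly what AgentExecutorBlock.Output opts out of above). A minimal before/after sketch of the migration; the block name and fields are invented for illustration, and constructor wiring is omitted:

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.model import SchemaField


class CountCharsBlock(Block):  # hypothetical block, not part of the diff
    class Input(BlockSchemaInput):  # was: class Input(BlockSchema)
        text: str = SchemaField(description="Text to measure")

    class Output(BlockSchemaOutput):  # was: class Output(BlockSchema)
        length: int = SchemaField(description="Character count")
        # No explicit `error` field: the pattern in the hunks below is that
        # BlockSchemaOutput is assumed to provide it automatically.

    # __init__ (id, description, categories, schema wiring, ...) omitted.

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "length", len(input_data.text)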
@@ -10,7 +10,12 @@ from backend.blocks.llm import (
     LLMResponse,
     llm_call,
 )
-from backend.data.block import BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField

@@ -23,7 +28,7 @@ class AIConditionBlock(AIBlockBase):
     It provides the same yes/no data pass-through functionality as the standard ConditionBlock.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         input_value: Any = SchemaField(
             description="The input value to evaluate with the AI condition",
             placeholder="Enter the value to be evaluated (text, number, or any data)",

@@ -50,7 +55,7 @@
         )
         credentials: AICredentials = AICredentialsField()

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: bool = SchemaField(
             description="The result of the AI condition evaluation (True or False)"
         )
@@ -5,7 +5,13 @@ from pydantic import SecretStr
 from replicate.client import Client as ReplicateClient
 from replicate.helpers import FileOutput

-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,

@@ -18,6 +24,7 @@ from backend.util.file import MediaFileType

 class GeminiImageModel(str, Enum):
     NANO_BANANA = "google/nano-banana"
+    NANO_BANANA_PRO = "google/nano-banana-pro"


 class OutputFormat(str, Enum):

@@ -42,7 +49,7 @@ TEST_CREDENTIALS_INPUT = {


 class AIImageCustomizerBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput[
             Literal[ProviderName.REPLICATE], Literal["api_key"]
         ] = CredentialsField(

@@ -68,9 +75,8 @@
             title="Output Format",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         image_url: MediaFileType = SchemaField(description="URL of the generated image")
-        error: str = SchemaField(description="Error message if generation failed")

     def __init__(self):
         super().__init__(
@@ -5,7 +5,7 @@ from pydantic import SecretStr
 from replicate.client import Client as ReplicateClient
 from replicate.helpers import FileOutput

-from backend.data.block import Block, BlockCategory, BlockSchema
+from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,

@@ -60,6 +60,14 @@ SIZE_TO_RECRAFT_DIMENSIONS = {
     ImageSize.TALL: "1024x1536",
 }

+SIZE_TO_NANO_BANANA_RATIO = {
+    ImageSize.SQUARE: "1:1",
+    ImageSize.LANDSCAPE: "4:3",
+    ImageSize.PORTRAIT: "3:4",
+    ImageSize.WIDE: "16:9",
+    ImageSize.TALL: "9:16",
+}
+

 class ImageStyle(str, Enum):
     """

@@ -98,10 +106,11 @@ class ImageGenModel(str, Enum):
     FLUX_ULTRA = "Flux 1.1 Pro Ultra"
     RECRAFT = "Recraft v3"
     SD3_5 = "Stable Diffusion 3.5 Medium"
+    NANO_BANANA_PRO = "Nano Banana Pro"


 class AIImageGeneratorBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput[
             Literal[ProviderName.REPLICATE], Literal["api_key"]
         ] = CredentialsField(

@@ -135,9 +144,8 @@
             title="Image Style",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         image_url: str = SchemaField(description="URL of the generated image")
-        error: str = SchemaField(description="Error message if generation failed")

     def __init__(self):
         super().__init__(

@@ -262,6 +270,20 @@
                 )
                 return output

+            elif input_data.model == ImageGenModel.NANO_BANANA_PRO:
+                # Use Nano Banana Pro (Google Gemini 3 Pro Image)
+                input_params = {
+                    "prompt": modified_prompt,
+                    "aspect_ratio": SIZE_TO_NANO_BANANA_RATIO[input_data.size],
+                    "resolution": "2K",  # Default to 2K for good quality/cost balance
+                    "output_format": "jpg",
+                    "safety_filter_level": "block_only_high",  # Most permissive
+                }
+                output = await self._run_client(
+                    credentials, "google/nano-banana-pro", input_params
+                )
+                return output
+
         except Exception as e:
             raise RuntimeError(f"Failed to generate image: {str(e)}")
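The new branch maps the block's generic ImageSize options onto Nano Banana Pro's aspect-ratio strings and pins the resolution at 2K. A standalone sketch of the equivalent Replicate call; it uses the public replicate client directly instead of the block's internal _run_client helper, and the token and prompt are placeholders:

from replicate.client import Client as ReplicateClient

client = ReplicateClient(api_token="r8_...")  # placeholder token

# Mirrors the input_params built in the diff for ImageGenModel.NANO_BANANA_PRO.
input_params = {
    "prompt": "a lighthouse at dusk, watercolor",
    "aspect_ratio": "16:9",  # SIZE_TO_NANO_BANANA_RATIO[ImageSize.WIDE]
    "resolution": "2K",  # default chosen in the diff for quality/cost balance
    "output_format": "jpg",
    "safety_filter_level": "block_only_high",  # most permissive, per the diff
}

# Client.run executes the model and returns its output (a file URL here).
output = client.run("google/nano-banana-pro", input=input_params)
print(output)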
@@ -6,7 +6,13 @@ from typing import Literal

 from pydantic import SecretStr

-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,

@@ -54,7 +60,7 @@ class NormalizationStrategy(str, Enum):


 class AIMusicGeneratorBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput[
             Literal[ProviderName.REPLICATE], Literal["api_key"]
         ] = CredentialsField(

@@ -107,9 +113,8 @@
             title="Normalization Strategy",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: str = SchemaField(description="URL of the generated audio file")
-        error: str = SchemaField(description="Error message if the model run failed")

     def __init__(self):
         super().__init__(
@@ -6,7 +6,13 @@ from typing import Literal

 from pydantic import SecretStr

-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,

@@ -148,7 +154,7 @@ logger = logging.getLogger(__name__)
 class AIShortformVideoCreatorBlock(Block):
     """Creates a short‑form text‑to‑video clip using stock or AI imagery."""

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput[
             Literal[ProviderName.REVID], Literal["api_key"]
         ] = CredentialsField(

@@ -187,9 +193,8 @@
             placeholder=VisualMediaType.STOCK_VIDEOS,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         video_url: str = SchemaField(description="The URL of the created video")
-        error: str = SchemaField(description="Error message if the request failed")

     async def create_webhook(self) -> tuple[str, str]:
         """Create a new webhook URL for receiving notifications."""

@@ -336,7 +341,7 @@
 class AIAdMakerVideoCreatorBlock(Block):
     """Generates a 30‑second vertical AI advert using optional user‑supplied imagery."""

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput[
             Literal[ProviderName.REVID], Literal["api_key"]
         ] = CredentialsField(

@@ -364,9 +369,8 @@
             description="Restrict visuals to supplied images only.", default=True
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         video_url: str = SchemaField(description="URL of the finished advert")
-        error: str = SchemaField(description="Error message on failure")

     async def create_webhook(self) -> tuple[str, str]:
         """Create a new webhook URL for receiving notifications."""

@@ -524,7 +528,7 @@
 class AIScreenshotToVideoAdBlock(Block):
     """Creates an advert where the supplied screenshot is narrated by an AI avatar."""

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput[
             Literal[ProviderName.REVID], Literal["api_key"]
         ] = CredentialsField(description="Revid.ai API key")

@@ -542,9 +546,8 @@
             default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         video_url: str = SchemaField(description="Rendered video URL")
-        error: str = SchemaField(description="Error, if encountered")

     async def create_webhook(self) -> tuple[str, str]:
         """Create a new webhook URL for receiving notifications."""
@@ -9,7 +9,8 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaInput,
+    BlockSchemaOutput,
     CredentialsMetaInput,
     SchemaField,
 )

@@ -23,7 +24,7 @@ class AirtableCreateBaseBlock(Block):
     Creates a new base in an Airtable workspace, or returns existing base if one with the same name exists.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )

@@ -53,7 +54,7 @@
         ],
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         base_id: str = SchemaField(description="The ID of the created or found base")
         tables: list[dict] = SchemaField(description="Array of table objects")
         table: dict = SchemaField(description="A single table object")

@@ -118,7 +119,7 @@ class AirtableListBasesBlock(Block):
     Lists all bases in an Airtable workspace that the user has access to.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )

@@ -129,7 +130,7 @@
             description="Pagination offset from previous request", default=""
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         bases: list[dict] = SchemaField(description="Array of base objects")
         offset: Optional[str] = SchemaField(
             description="Offset for next page (null if no more bases)", default=None
@@ -9,7 +9,8 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaInput,
+    BlockSchemaOutput,
     CredentialsMetaInput,
     SchemaField,
 )

@@ -31,7 +32,7 @@ class AirtableListRecordsBlock(Block):
     Lists records from an Airtable table with optional filtering, sorting, and pagination.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )

@@ -65,7 +66,7 @@
             default=False,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         records: list[dict] = SchemaField(description="Array of record objects")
         offset: Optional[str] = SchemaField(
             description="Offset for next page (null if no more records)", default=None

@@ -137,7 +138,7 @@ class AirtableGetRecordBlock(Block):
     Retrieves a single record from an Airtable table by its ID.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )

@@ -153,7 +154,7 @@
             default=False,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         id: str = SchemaField(description="The record ID")
         fields: dict = SchemaField(description="The record fields")
         created_time: str = SchemaField(description="The record created time")

@@ -217,7 +218,7 @@ class AirtableCreateRecordsBlock(Block):
     Creates one or more records in an Airtable table.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )

@@ -239,7 +240,7 @@
             default=None,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         records: list[dict] = SchemaField(description="Array of created record objects")
         details: dict = SchemaField(description="Details of the created records")

@@ -290,7 +291,7 @@ class AirtableUpdateRecordsBlock(Block):
     Updates one or more existing records in an Airtable table.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )

@@ -306,7 +307,7 @@
             default=None,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         records: list[dict] = SchemaField(description="Array of updated record objects")

     def __init__(self):

@@ -339,7 +340,7 @@ class AirtableDeleteRecordsBlock(Block):
     Deletes one or more records from an Airtable table.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )

@@ -351,7 +352,7 @@
             description="Array of upto 10 record IDs to delete"
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         records: list[dict] = SchemaField(description="Array of deletion results")

     def __init__(self):
@@ -7,7 +7,8 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaInput,
+    BlockSchemaOutput,
     CredentialsMetaInput,
     Requests,
     SchemaField,
 )

@@ -23,13 +24,13 @@ class AirtableListSchemaBlock(Block):
     fields, and views.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )
         base_id: str = SchemaField(description="The Airtable base ID")

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         base_schema: dict = SchemaField(
             description="Complete base schema with tables, fields, and views"
         )

@@ -66,7 +67,7 @@ class AirtableCreateTableBlock(Block):
     Creates a new table in an Airtable base with specified fields and views.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )

@@ -77,7 +78,7 @@
             default=[{"name": "Name", "type": "singleLineText"}],
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         table: dict = SchemaField(description="Created table object")
         table_id: str = SchemaField(description="ID of the created table")

@@ -109,7 +110,7 @@ class AirtableUpdateTableBlock(Block):
     Updates an existing table's properties such as name or description.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )

@@ -125,7 +126,7 @@
             description="The date dependency of the table to update", default=None
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         table: dict = SchemaField(description="Updated table object")

     def __init__(self):

@@ -157,7 +158,7 @@ class AirtableCreateFieldBlock(Block):
     Adds a new field (column) to an existing Airtable table.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )

@@ -176,7 +177,7 @@
             description="The options of the field to create", default=None
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         field: dict = SchemaField(description="Created field object")
         field_id: str = SchemaField(description="ID of the created field")

@@ -209,7 +210,7 @@ class AirtableUpdateFieldBlock(Block):
     Updates an existing field's properties in an Airtable table.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )

@@ -225,7 +226,7 @@
             advanced=False,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         field: dict = SchemaField(description="Updated field object")

     def __init__(self):
@@ -3,7 +3,8 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaInput,
+    BlockSchemaOutput,
     BlockType,
     BlockWebhookConfig,
     CredentialsMetaInput,

@@ -32,7 +33,7 @@ class AirtableWebhookTriggerBlock(Block):
     Thin wrapper just forwards the payloads one at a time to the next block.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )

@@ -43,7 +44,7 @@
             description="Airtable webhook event filter"
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         payload: WebhookPayload = SchemaField(description="Airtable webhook payload")

     def __init__(self):
@@ -10,14 +10,20 @@ from backend.blocks.apollo.models import (
     PrimaryPhone,
     SearchOrganizationsRequest,
 )
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import CredentialsField, SchemaField


 class SearchOrganizationsBlock(Block):
     """Search for organizations in Apollo"""

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         organization_num_employees_range: list[int] = SchemaField(
             description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.

@@ -69,7 +75,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
             description="Apollo credentials",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         organizations: list[Organization] = SchemaField(
             description="List of organizations found",
             default_factory=list,
@@ -14,14 +14,20 @@ from backend.blocks.apollo.models import (
     SearchPeopleRequest,
     SenorityLevels,
 )
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import CredentialsField, SchemaField


 class SearchPeopleBlock(Block):
     """Search for people in Apollo"""

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         person_titles: list[str] = SchemaField(
             description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.

@@ -109,7 +115,7 @@ class SearchPeopleBlock(Block):
             description="Apollo credentials",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         people: list[Contact] = SchemaField(
             description="List of people found",
             default_factory=list,
@@ -6,14 +6,20 @@ from backend.blocks.apollo._auth import (
     ApolloCredentialsInput,
 )
 from backend.blocks.apollo.models import Contact, EnrichPersonRequest
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import CredentialsField, SchemaField


 class GetPersonDetailBlock(Block):
     """Get detailed person data with Apollo API, including email reveal"""

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         person_id: str = SchemaField(
             description="Apollo person ID to enrich (most accurate method)",
             default="",

@@ -68,7 +74,7 @@
             description="Apollo credentials",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         contact: Contact = SchemaField(
             description="Enriched contact information",
         )
@@ -3,7 +3,7 @@ from typing import Optional

 from pydantic import BaseModel, Field

-from backend.data.block import BlockSchema
+from backend.data.block import BlockSchemaInput
 from backend.data.model import SchemaField, UserIntegrations
 from backend.integrations.ayrshare import AyrshareClient
 from backend.util.clients import get_database_manager_async_client

@@ -17,7 +17,7 @@ async def get_profile_key(user_id: str):
     return user_integrations.managed_credentials.ayrshare_profile_key


-class BaseAyrshareInput(BlockSchema):
+class BaseAyrshareInput(BlockSchemaInput):
     """Base input model for Ayrshare social media posts with common fields."""

     post: str = SchemaField(
@@ -3,7 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaOutput,
     BlockType,
     SchemaField,
 )

@@ -38,7 +38,7 @@ class PostToBlueskyBlock(Block):
         advanced=True,
     )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")
@@ -3,7 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaOutput,
     BlockType,
     SchemaField,
 )

@@ -101,7 +101,7 @@ class PostToFacebookBlock(Block):
         description="URL for custom link preview", default="", advanced=True
     )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")
@@ -3,7 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaOutput,
     BlockType,
     SchemaField,
 )

@@ -94,7 +94,7 @@ class PostToGMBBlock(Block):
         advanced=True,
     )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")
@@ -5,7 +5,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaOutput,
     BlockType,
     SchemaField,
 )

@@ -94,7 +94,7 @@ class PostToInstagramBlock(Block):
         advanced=True,
     )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")
@@ -3,7 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaOutput,
     BlockType,
     SchemaField,
 )

@@ -94,7 +94,7 @@ class PostToLinkedInBlock(Block):
         advanced=True,
     )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")
@@ -3,7 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaOutput,
     BlockType,
     SchemaField,
 )

@@ -73,7 +73,7 @@ class PostToPinterestBlock(Block):
         advanced=True,
     )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")
@@ -3,7 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaOutput,
     BlockType,
     SchemaField,
 )

@@ -19,7 +19,7 @@ class PostToRedditBlock(Block):

     pass  # Uses all base fields

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -43,7 +43,7 @@ class PostToSnapchatBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -38,7 +38,7 @@ class PostToTelegramBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -31,7 +31,7 @@ class PostToThreadsBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -98,7 +98,7 @@ class PostToTikTokBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -97,7 +97,7 @@ class PostToXBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -119,7 +119,7 @@ class PostToYouTubeBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -9,7 +9,8 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaInput,
+    BlockSchemaOutput,
     CredentialsMetaInput,
     SchemaField,
 )
@@ -23,7 +24,7 @@ class BaasBotJoinMeetingBlock(Block):
     Deploy a bot immediately or at a scheduled start_time to join and record a meeting.
     """
 
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = baas.credentials_field(
             description="Meeting BaaS API credentials"
         )
@@ -57,7 +58,7 @@ class BaasBotJoinMeetingBlock(Block):
             description="Custom metadata to attach to the bot", default={}
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         bot_id: str = SchemaField(description="UUID of the deployed bot")
         join_response: dict = SchemaField(
             description="Full response from join operation"
@@ -103,13 +104,13 @@ class BaasBotLeaveMeetingBlock(Block):
     Force the bot to exit the call.
     """
 
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = baas.credentials_field(
             description="Meeting BaaS API credentials"
         )
         bot_id: str = SchemaField(description="UUID of the bot to remove from meeting")
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         left: bool = SchemaField(description="Whether the bot successfully left")
 
     def __init__(self):
@@ -138,7 +139,7 @@ class BaasBotFetchMeetingDataBlock(Block):
     Pull MP4 URL, transcript & metadata for a completed meeting.
     """
 
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = baas.credentials_field(
             description="Meeting BaaS API credentials"
         )
@@ -147,7 +148,7 @@ class BaasBotFetchMeetingDataBlock(Block):
             description="Include transcript data in response", default=True
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         mp4_url: str = SchemaField(
             description="URL to download the meeting recording (time-limited)"
         )
@@ -185,13 +186,13 @@ class BaasBotDeleteRecordingBlock(Block):
     Purge MP4 + transcript data for privacy or storage management.
     """
 
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = baas.credentials_field(
             description="Meeting BaaS API credentials"
         )
         bot_id: str = SchemaField(description="UUID of the bot whose data to delete")
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         deleted: bool = SchemaField(
             description="Whether the data was successfully deleted"
         )
@@ -11,7 +11,8 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaInput,
+    BlockSchemaOutput,
     CredentialsMetaInput,
     Requests,
     SchemaField,
@@ -27,7 +28,7 @@ TEST_CREDENTIALS = APIKeyCredentials(
 )
 
 
-class TextModification(BlockSchema):
+class TextModification(BlockSchemaInput):
     name: str = SchemaField(
         description="The name of the layer to modify in the template"
     )
@@ -60,7 +61,7 @@ class TextModification(BlockSchema):
 
 
 class BannerbearTextOverlayBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = bannerbear.credentials_field(
             description="API credentials for Bannerbear"
         )
@@ -96,7 +97,7 @@ class BannerbearTextOverlayBlock(Block):
             advanced=True,
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         success: bool = SchemaField(
             description="Whether the image generation was successfully initiated"
         )
@@ -105,7 +106,6 @@ class BannerbearTextOverlayBlock(Block):
         )
         uid: str = SchemaField(description="Unique identifier for the generated image")
         status: str = SchemaField(description="Status of the image generation")
-        error: str = SchemaField(description="Error message if the operation failed")
 
     def __init__(self):
         super().__init__(
@@ -1,14 +1,21 @@
 import enum
 from typing import Any
 
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+    BlockType,
+)
 from backend.data.model import SchemaField
 from backend.util.file import store_media_file
 from backend.util.type import MediaFileType, convert
 
 
 class FileStoreBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         file_in: MediaFileType = SchemaField(
             description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
         )
@@ -19,7 +26,7 @@ class FileStoreBlock(Block):
             title="Produce Base64 Output",
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         file_out: MediaFileType = SchemaField(
             description="The relative path to the stored file in the temporary directory."
         )
@@ -57,7 +64,7 @@ class StoreValueBlock(Block):
     The block output will be static, the output can be consumed multiple times.
     """
 
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         input: Any = SchemaField(
             description="Trigger the block to produce the output. "
             "The value is only used when `data` is None."
@@ -68,7 +75,7 @@ class StoreValueBlock(Block):
             default=None,
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         output: Any = SchemaField(description="The stored data retained in the block.")
 
     def __init__(self):
@@ -94,10 +101,10 @@ class StoreValueBlock(Block):
 
 
 class PrintToConsoleBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         text: Any = SchemaField(description="The data to print to the console.")
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         output: Any = SchemaField(description="The data printed to the console.")
         status: str = SchemaField(description="The status of the print operation.")
@@ -121,10 +128,10 @@ class PrintToConsoleBlock(Block):
 
 
 class NoteBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         text: str = SchemaField(description="The text to display in the sticky note.")
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         output: str = SchemaField(description="The text to display in the sticky note.")
 
     def __init__(self):
@@ -154,15 +161,14 @@ class TypeOptions(enum.Enum):
 
 
 class UniversalTypeConverterBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         value: Any = SchemaField(
             description="The value to convert to a universal type."
         )
         type: TypeOptions = SchemaField(description="The type to convert the value to.")
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         value: Any = SchemaField(description="The converted value.")
-        error: str = SchemaField(description="Error message if conversion failed.")
 
     def __init__(self):
         super().__init__(
@@ -195,10 +201,10 @@ class ReverseListOrderBlock(Block):
     A block which takes in a list and returns it in the opposite order.
     """
 
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         input_list: list[Any] = SchemaField(description="The list to reverse")
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         reversed_list: list[Any] = SchemaField(description="The list in reversed order")
 
     def __init__(self):
@@ -2,7 +2,13 @@ import os
 import re
 from typing import Type
 
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField
@@ -15,12 +21,12 @@ class BlockInstallationBlock(Block):
     for development purposes only.
     """
 
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         code: str = SchemaField(
             description="Python code of the block to be installed",
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         success: str = SchemaField(
             description="Success message if the block is installed successfully",
         )
@@ -1,7 +1,13 @@
 from enum import Enum
 from typing import Any
 
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField
 from backend.util.type import convert
@@ -16,7 +22,7 @@ class ComparisonOperator(Enum):
 
 
 class ConditionBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         value1: Any = SchemaField(
             description="Enter the first value for comparison",
             placeholder="For example: 10 or 'hello' or True",
@@ -40,7 +46,7 @@ class ConditionBlock(Block):
             default=None,
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: bool = SchemaField(
             description="The result of the condition evaluation (True or False)"
         )
@@ -111,7 +117,7 @@ class ConditionBlock(Block):
 
 
 class IfInputMatchesBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         input: Any = SchemaField(
             description="The input to match against",
             placeholder="For example: 10 or 'hello' or True",
@@ -131,7 +137,7 @@ class IfInputMatchesBlock(Block):
             default=None,
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: bool = SchemaField(
             description="The result of the condition evaluation (True or False)"
         )
@@ -4,9 +4,15 @@ from typing import Any, Literal, Optional
 from e2b_code_interpreter import AsyncSandbox
 from e2b_code_interpreter import Result as E2BExecutionResult
 from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
-from pydantic import BaseModel, JsonValue, SecretStr
+from pydantic import BaseModel, Field, JsonValue, SecretStr
 
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -61,7 +67,7 @@ class MainCodeExecutionResult(BaseModel):
     jpeg: Optional[str] = None
     pdf: Optional[str] = None
     latex: Optional[str] = None
-    json: Optional[JsonValue] = None  # type: ignore (reportIncompatibleMethodOverride)
+    json_data: Optional[JsonValue] = Field(None, alias="json")
     javascript: Optional[str] = None
     data: Optional[dict] = None
     chart: Optional[Chart] = None
@@ -159,7 +165,7 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
     # TODO : Add support to upload and download files
     # NOTE: Currently, you can only customize the CPU and Memory
     # by creating a pre customized sandbox template
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput[
             Literal[ProviderName.E2B], Literal["api_key"]
         ] = CredentialsField(
@@ -217,7 +223,7 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
             advanced=True,
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         main_result: MainCodeExecutionResult = SchemaField(
             title="Main Result", description="The main result from the code execution"
         )
@@ -232,7 +238,6 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
             description="Standard output logs from execution"
         )
         stderr_logs: str = SchemaField(description="Standard error logs from execution")
-        error: str = SchemaField(description="Error message if execution failed")
 
     def __init__(self):
         super().__init__(
@@ -296,7 +301,7 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
 
 
 class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput[
             Literal[ProviderName.E2B], Literal["api_key"]
         ] = CredentialsField(
@@ -346,7 +351,7 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
             advanced=True,
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         sandbox_id: str = SchemaField(description="ID of the sandbox instance")
         response: str = SchemaField(
             title="Text Result",
@@ -356,7 +361,6 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
             description="Standard output logs from execution"
         )
         stderr_logs: str = SchemaField(description="Standard error logs from execution")
-        error: str = SchemaField(description="Error message if execution failed")
 
     def __init__(self):
         super().__init__(
@@ -421,7 +425,7 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
 
 
 class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput[
             Literal[ProviderName.E2B], Literal["api_key"]
         ] = CredentialsField(
@@ -454,7 +458,7 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
             default=False,
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         main_result: MainCodeExecutionResult = SchemaField(
             title="Main Result", description="The main result from the code execution"
         )
@@ -469,7 +473,6 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
             description="Standard output logs from execution"
         )
         stderr_logs: str = SchemaField(description="Standard error logs from execution")
-        error: str = SchemaField(description="Error message if execution failed")
 
     def __init__(self):
         super().__init__(
@@ -1,17 +1,23 @@
 import re
 
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField
 
 
 class CodeExtractionBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         text: str = SchemaField(
             description="Text containing code blocks to extract (e.g., AI response)",
             placeholder="Enter text containing code blocks",
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         html: str = SchemaField(description="Extracted HTML code")
         css: str = SchemaField(description="Extracted CSS code")
         javascript: str = SchemaField(description="Extracted JavaScript code")
autogpt_platform/backend/backend/blocks/codex.py (new file, 224 lines)
@@ -0,0 +1,224 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any, Literal

from openai import AsyncOpenAI
from openai.types.responses import Response as OpenAIResponse
from pydantic import SecretStr

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    NodeExecutionStats,
    SchemaField,
)
from backend.integrations.providers import ProviderName


@dataclass
class CodexCallResult:
    """Structured response returned by Codex invocations."""

    response: str
    reasoning: str
    response_id: str


class CodexModel(str, Enum):
    """Codex-capable OpenAI models."""

    GPT5_1_CODEX = "gpt-5.1-codex"


class CodexReasoningEffort(str, Enum):
    """Configuration for the Responses API reasoning effort."""

    NONE = "none"
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"


CodexCredentials = CredentialsMetaInput[
    Literal[ProviderName.OPENAI], Literal["api_key"]
]

TEST_CREDENTIALS = APIKeyCredentials(
    id="e2fcb203-3f2d-4ad4-a344-8df3bc7db36b",
    provider="openai",
    api_key=SecretStr("mock-openai-api-key"),
    title="Mock OpenAI API key",
    expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}


def CodexCredentialsField() -> CodexCredentials:
    return CredentialsField(
        description="OpenAI API key with access to Codex models (Responses API).",
    )


class CodeGenerationBlock(Block):
    """Block that talks to Codex models via the OpenAI Responses API."""

    class Input(BlockSchemaInput):
        prompt: str = SchemaField(
            description="Primary coding request passed to the Codex model.",
            placeholder="Generate a Python function that reverses a list.",
        )
        system_prompt: str = SchemaField(
            title="System Prompt",
            default=(
                "You are Codex, an elite software engineer. "
                "Favor concise, working code and highlight important caveats."
            ),
            description="Optional instructions injected via the Responses API instructions field.",
            advanced=True,
        )
        model: CodexModel = SchemaField(
            title="Codex Model",
            default=CodexModel.GPT5_1_CODEX,
            description="Codex-optimized model served via the Responses API.",
            advanced=False,
        )
        reasoning_effort: CodexReasoningEffort = SchemaField(
            title="Reasoning Effort",
            default=CodexReasoningEffort.MEDIUM,
            description="Controls the Responses API reasoning budget. Select 'none' to skip reasoning configs.",
            advanced=True,
        )
        max_output_tokens: int | None = SchemaField(
            title="Max Output Tokens",
            default=2048,
            description="Upper bound for generated tokens (hard limit 128,000). Leave blank to let OpenAI decide.",
            advanced=True,
        )
        credentials: CodexCredentials = CodexCredentialsField()

    class Output(BlockSchemaOutput):
        response: str = SchemaField(
            description="Code-focused response returned by the Codex model."
        )
        reasoning: str = SchemaField(
            description="Reasoning summary returned by the model, if available.",
            default="",
        )
        response_id: str = SchemaField(
            description="ID of the Responses API call for auditing/debugging.",
            default="",
        )

    def __init__(self):
        super().__init__(
            id="86a2a099-30df-47b4-b7e4-34ae5f83e0d5",
            description="Generate or refactor code using OpenAI's Codex (Responses API).",
            categories={BlockCategory.AI, BlockCategory.DEVELOPER_TOOLS},
            input_schema=CodeGenerationBlock.Input,
            output_schema=CodeGenerationBlock.Output,
            test_input=[
                {
                    "prompt": "Write a TypeScript function that deduplicates an array.",
                    "credentials": TEST_CREDENTIALS_INPUT,
                }
            ],
            test_output=[
                ("response", str),
                ("reasoning", str),
                ("response_id", str),
            ],
            test_mock={
                "call_codex": lambda *_args, **_kwargs: CodexCallResult(
                    response="function dedupe<T>(items: T[]): T[] { return [...new Set(items)]; }",
                    reasoning="Used Set to remove duplicates in O(n).",
                    response_id="resp_test",
                )
            },
            test_credentials=TEST_CREDENTIALS,
        )
        self.execution_stats = NodeExecutionStats()

    async def call_codex(
        self,
        *,
        credentials: APIKeyCredentials,
        model: CodexModel,
        prompt: str,
        system_prompt: str,
        max_output_tokens: int | None,
        reasoning_effort: CodexReasoningEffort,
    ) -> CodexCallResult:
        """Invoke the OpenAI Responses API."""
        client = AsyncOpenAI(api_key=credentials.api_key.get_secret_value())

        request_payload: dict[str, Any] = {
            "model": model.value,
            "input": prompt,
        }
        if system_prompt:
            request_payload["instructions"] = system_prompt
        if max_output_tokens is not None:
            request_payload["max_output_tokens"] = max_output_tokens
        if reasoning_effort != CodexReasoningEffort.NONE:
            request_payload["reasoning"] = {"effort": reasoning_effort.value}

        response = await client.responses.create(**request_payload)
        if not isinstance(response, OpenAIResponse):
            raise TypeError(f"Expected OpenAIResponse, got {type(response).__name__}")

        # Extract data directly from typed response
        text_output = response.output_text or ""
        reasoning_summary = (
            str(response.reasoning.summary)
            if response.reasoning and response.reasoning.summary
            else ""
        )
        response_id = response.id or ""

        # Update usage stats
        self.execution_stats.input_token_count = (
            response.usage.input_tokens if response.usage else 0
        )
        self.execution_stats.output_token_count = (
            response.usage.output_tokens if response.usage else 0
        )
        self.execution_stats.llm_call_count += 1

        return CodexCallResult(
            response=text_output,
            reasoning=reasoning_summary,
            response_id=response_id,
        )

    async def run(
        self,
        input_data: Input,
        *,
        credentials: APIKeyCredentials,
        **_kwargs,
    ) -> BlockOutput:
        result = await self.call_codex(
            credentials=credentials,
            model=input_data.model,
            prompt=input_data.prompt,
            system_prompt=input_data.system_prompt,
            max_output_tokens=input_data.max_output_tokens,
            reasoning_effort=input_data.reasoning_effort,
        )

        yield "response", result.response
        yield "reasoning", result.reasoning
        yield "response_id", result.response_id
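For reference, the call_codex flow in the new block above reduces to a single Responses API request. Below is a minimal standalone sketch of that pattern, assuming a valid OPENAI_API_KEY environment variable; the model id and payload keys are taken from the diff, while the script itself is illustrative and not part of the commit.

import asyncio
import os

from openai import AsyncOpenAI


async def main() -> None:
    # Same payload shape the block assembles in call_codex:
    # "instructions" carries the system prompt; "reasoning" is optional.
    client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])  # assumed env var
    response = await client.responses.create(
        model="gpt-5.1-codex",
        input="Write a Python function that reverses a list.",
        instructions="You are Codex, an elite software engineer.",
        max_output_tokens=2048,
        reasoning={"effort": "medium"},
    )
    print(response.output_text)


asyncio.run(main())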
@@ -5,7 +5,8 @@ from backend.data.block import (
     BlockCategory,
     BlockManualWebhookConfig,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaInput,
+    BlockSchemaOutput,
 )
 from backend.data.model import SchemaField
 from backend.integrations.providers import ProviderName
@@ -27,10 +28,10 @@ class TranscriptionDataModel(BaseModel):
 
 
 class CompassAITriggerBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         payload: TranscriptionDataModel = SchemaField(hidden=True)
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         transcription: str = SchemaField(
             description="The contents of the compass transcription."
         )
@@ -1,16 +1,22 @@
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField
 
 
 class WordCharacterCountBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         text: str = SchemaField(
             description="Input text to count words and characters",
             placeholder="Enter your text here",
             advanced=False,
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         word_count: int = SchemaField(description="Number of words in the input text")
         character_count: int = SchemaField(
             description="Number of characters in the input text"
@@ -1,6 +1,12 @@
 from typing import Any, List
 
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField
 from backend.util.json import loads
 from backend.util.mock import MockObject
@@ -12,13 +18,13 @@ from backend.util.prompt import estimate_token_count_str
 
 
 class CreateDictionaryBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         values: dict[str, Any] = SchemaField(
             description="Key-value pairs to create the dictionary with",
             placeholder="e.g., {'name': 'Alice', 'age': 25}",
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         dictionary: dict[str, Any] = SchemaField(
             description="The created dictionary containing the specified key-value pairs"
         )
@@ -62,7 +68,7 @@ class CreateDictionaryBlock(Block):
 
 
 class AddToDictionaryBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         dictionary: dict[Any, Any] = SchemaField(
             default_factory=dict,
             description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
@@ -86,11 +92,10 @@ class AddToDictionaryBlock(Block):
             advanced=True,
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         updated_dictionary: dict = SchemaField(
             description="The dictionary with the new entry added."
         )
-        error: str = SchemaField(description="Error message if the operation failed.")
 
     def __init__(self):
         super().__init__(
@@ -141,11 +146,11 @@ class AddToDictionaryBlock(Block):
 
 
 class FindInDictionaryBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         input: Any = SchemaField(description="Dictionary to lookup from")
         key: str | int = SchemaField(description="Key to lookup in the dictionary")
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         output: Any = SchemaField(description="Value found for the given key")
         missing: Any = SchemaField(
             description="Value of the input that is missing the key"
@@ -201,7 +206,7 @@ class FindInDictionaryBlock(Block):
 
 
 class RemoveFromDictionaryBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         dictionary: dict[Any, Any] = SchemaField(
             description="The dictionary to modify."
         )
@@ -210,12 +215,11 @@ class RemoveFromDictionaryBlock(Block):
             default=False, description="Whether to return the removed value."
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         updated_dictionary: dict[Any, Any] = SchemaField(
             description="The dictionary after removal."
         )
         removed_value: Any = SchemaField(description="The removed value if requested.")
-        error: str = SchemaField(description="Error message if the operation failed.")
 
     def __init__(self):
         super().__init__(
@@ -251,19 +255,18 @@ class RemoveFromDictionaryBlock(Block):
 
 
 class ReplaceDictionaryValueBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         dictionary: dict[Any, Any] = SchemaField(
             description="The dictionary to modify."
         )
         key: str | int = SchemaField(description="Key to replace the value for.")
         value: Any = SchemaField(description="The new value for the given key.")
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         updated_dictionary: dict[Any, Any] = SchemaField(
             description="The dictionary after replacement."
         )
         old_value: Any = SchemaField(description="The value that was replaced.")
-        error: str = SchemaField(description="Error message if the operation failed.")
 
     def __init__(self):
         super().__init__(
@@ -300,10 +303,10 @@ class ReplaceDictionaryValueBlock(Block):
 
 
 class DictionaryIsEmptyBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         dictionary: dict[Any, Any] = SchemaField(description="The dictionary to check.")
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         is_empty: bool = SchemaField(description="True if the dictionary is empty.")
 
     def __init__(self):
@@ -327,7 +330,7 @@ class DictionaryIsEmptyBlock(Block):
 
 
 class CreateListBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         values: List[Any] = SchemaField(
             description="A list of values to be combined into a new list.",
             placeholder="e.g., ['Alice', 25, True]",
@@ -343,11 +346,10 @@ class CreateListBlock(Block):
             advanced=True,
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         list: List[Any] = SchemaField(
             description="The created list containing the specified values."
         )
-        error: str = SchemaField(description="Error message if list creation failed.")
 
     def __init__(self):
         super().__init__(
@@ -404,7 +406,7 @@ class CreateListBlock(Block):
 
 
 class AddToListBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         list: List[Any] = SchemaField(
             default_factory=list,
             advanced=False,
@@ -425,11 +427,10 @@ class AddToListBlock(Block):
             description="The position to insert the new entry. If not provided, the entry will be appended to the end of the list.",
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         updated_list: List[Any] = SchemaField(
             description="The list with the new entry added."
         )
-        error: str = SchemaField(description="Error message if the operation failed.")
 
     def __init__(self):
         super().__init__(
@@ -484,11 +485,11 @@ class AddToListBlock(Block):
 
 
 class FindInListBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         list: List[Any] = SchemaField(description="The list to search in.")
         value: Any = SchemaField(description="The value to search for.")
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         index: int = SchemaField(description="The index of the value in the list.")
         found: bool = SchemaField(
             description="Whether the value was found in the list."
@@ -526,15 +527,14 @@ class FindInListBlock(Block):
 
 
 class GetListItemBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         list: List[Any] = SchemaField(description="The list to get the item from.")
         index: int = SchemaField(
             description="The 0-based index of the item (supports negative indices)."
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         item: Any = SchemaField(description="The item at the specified index.")
-        error: str = SchemaField(description="Error message if the operation failed.")
 
     def __init__(self):
         super().__init__(
@@ -561,7 +561,7 @@ class GetListItemBlock(Block):
 
 
 class RemoveFromListBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         list: List[Any] = SchemaField(description="The list to modify.")
         value: Any = SchemaField(
             default=None, description="Value to remove from the list."
@@ -574,10 +574,9 @@ class RemoveFromListBlock(Block):
             default=False, description="Whether to return the removed item."
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         updated_list: List[Any] = SchemaField(description="The list after removal.")
         removed_item: Any = SchemaField(description="The removed item if requested.")
-        error: str = SchemaField(description="Error message if the operation failed.")
 
     def __init__(self):
         super().__init__(
@@ -618,17 +617,16 @@ class RemoveFromListBlock(Block):
 
 
 class ReplaceListItemBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         list: List[Any] = SchemaField(description="The list to modify.")
         index: int = SchemaField(
             description="Index of the item to replace (supports negative indices)."
         )
         value: Any = SchemaField(description="The new value for the given index.")
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         updated_list: List[Any] = SchemaField(description="The list after replacement.")
         old_item: Any = SchemaField(description="The item that was replaced.")
-        error: str = SchemaField(description="Error message if the operation failed.")
 
     def __init__(self):
         super().__init__(
@@ -663,10 +661,10 @@ class ReplaceListItemBlock(Block):
 
 
 class ListIsEmptyBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         list: List[Any] = SchemaField(description="The list to check.")
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         is_empty: bool = SchemaField(description="True if the list is empty.")
 
     def __init__(self):
@@ -8,7 +8,8 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaInput,
+    BlockSchemaOutput,
     CredentialsMetaInput,
     SchemaField,
     UserPasswordCredentials,
@@ -18,7 +19,7 @@ from ._api import DataForSeoClient
 from ._config import dataforseo
 
 
-class KeywordSuggestion(BlockSchema):
+class KeywordSuggestion(BlockSchemaInput):
     """Schema for a keyword suggestion result."""
 
     keyword: str = SchemaField(description="The keyword suggestion")
@@ -45,7 +46,7 @@ class KeywordSuggestion(BlockSchema):
 class DataForSeoKeywordSuggestionsBlock(Block):
     """Block for getting keyword suggestions from DataForSEO Labs."""
 
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = dataforseo.credentials_field(
             description="DataForSEO credentials (username and password)"
         )
@@ -77,7 +78,7 @@ class DataForSeoKeywordSuggestionsBlock(Block):
             le=3000,
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         suggestions: List[KeywordSuggestion] = SchemaField(
             description="List of keyword suggestions with metrics"
         )
@@ -90,7 +91,6 @@ class DataForSeoKeywordSuggestionsBlock(Block):
         seed_keyword: str = SchemaField(
             description="The seed keyword used for the query"
        )
-        error: str = SchemaField(description="Error message if the API call failed")
 
     def __init__(self):
         super().__init__(
@@ -213,12 +213,12 @@ class DataForSeoKeywordSuggestionsBlock(Block):
 class KeywordSuggestionExtractorBlock(Block):
     """Extracts individual fields from a KeywordSuggestion object."""
 
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         suggestion: KeywordSuggestion = SchemaField(
             description="The keyword suggestion object to extract fields from"
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         keyword: str = SchemaField(description="The keyword suggestion")
         search_volume: Optional[int] = SchemaField(
             description="Monthly search volume", default=None
@@ -8,7 +8,8 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaInput,
+    BlockSchemaOutput,
     CredentialsMetaInput,
     SchemaField,
     UserPasswordCredentials,
@@ -18,7 +19,7 @@ from ._api import DataForSeoClient
 from ._config import dataforseo
 
 
-class RelatedKeyword(BlockSchema):
+class RelatedKeyword(BlockSchemaInput):
     """Schema for a related keyword result."""
 
     keyword: str = SchemaField(description="The related keyword")
@@ -45,7 +46,7 @@ class RelatedKeyword(BlockSchema):
 class DataForSeoRelatedKeywordsBlock(Block):
     """Block for getting related keywords from DataForSEO Labs."""
 
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = dataforseo.credentials_field(
             description="DataForSEO credentials (username and password)"
         )
@@ -85,7 +86,7 @@ class DataForSeoRelatedKeywordsBlock(Block):
             le=4,
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         related_keywords: List[RelatedKeyword] = SchemaField(
             description="List of related keywords with metrics"
         )
@@ -98,7 +99,6 @@ class DataForSeoRelatedKeywordsBlock(Block):
         seed_keyword: str = SchemaField(
             description="The seed keyword used for the query"
        )
-        error: str = SchemaField(description="Error message if the API call failed")
 
     def __init__(self):
         super().__init__(
@@ -231,12 +231,12 @@ class DataForSeoRelatedKeywordsBlock(Block):
 class RelatedKeywordExtractorBlock(Block):
     """Extracts individual fields from a RelatedKeyword object."""
 
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         related_keyword: RelatedKeyword = SchemaField(
             description="The related keyword object to extract fields from"
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         keyword: str = SchemaField(description="The related keyword")
         search_volume: Optional[int] = SchemaField(
             description="Monthly search volume", default=None
@@ -1,17 +1,23 @@
 import codecs
 
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField
 
 
 class TextDecoderBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         text: str = SchemaField(
             description="A string containing escaped characters to be decoded",
             placeholder='Your entire text block with \\n and \\" escaped characters',
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         decoded_text: str = SchemaField(
             description="The decoded text with escape sequences processed"
         )
@@ -7,7 +7,13 @@ from typing import Any
 import discord
 from pydantic import SecretStr
 
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import APIKeyCredentials, SchemaField
 from backend.util.file import store_media_file
 from backend.util.request import Requests
@@ -28,10 +34,10 @@ TEST_CREDENTIALS_INPUT = TEST_BOT_CREDENTIALS_INPUT
 
 
 class ReadDiscordMessagesBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: DiscordCredentials = DiscordCredentialsField()
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         message_content: str = SchemaField(
             description="The content of the message received"
         )
@@ -164,7 +170,7 @@ class ReadDiscordMessagesBlock(Block):
 
 
 class SendDiscordMessageBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: DiscordCredentials = DiscordCredentialsField()
         message_content: str = SchemaField(
             description="The content of the message to send"
@@ -178,7 +184,7 @@ class SendDiscordMessageBlock(Block):
             default="",
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         status: str = SchemaField(
             description="The status of the operation (e.g., 'Message sent', 'Error')"
         )
@@ -310,7 +316,7 @@ class SendDiscordMessageBlock(Block):
 
 
 class SendDiscordDMBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: DiscordCredentials = DiscordCredentialsField()
         user_id: str = SchemaField(
             description="The Discord user ID to send the DM to (e.g., '123456789012345678')"
@@ -319,7 +325,7 @@ class SendDiscordDMBlock(Block):
             description="The content of the direct message to send"
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         status: str = SchemaField(description="The status of the operation")
         message_id: str = SchemaField(description="The ID of the sent message")
@@ -399,7 +405,7 @@ class SendDiscordDMBlock(Block):
 
 
 class SendDiscordEmbedBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: DiscordCredentials = DiscordCredentialsField()
         channel_identifier: str = SchemaField(
             description="Channel ID or channel name to send the embed to"
@@ -436,7 +442,7 @@ class SendDiscordEmbedBlock(Block):
             default=[],
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         status: str = SchemaField(description="Operation status")
         message_id: str = SchemaField(description="ID of the sent embed message")
@@ -586,7 +592,7 @@ class SendDiscordEmbedBlock(Block):
 
 
 class SendDiscordFileBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: DiscordCredentials = DiscordCredentialsField()
         channel_identifier: str = SchemaField(
             description="Channel ID or channel name to send the file to"
@@ -607,7 +613,7 @@ class SendDiscordFileBlock(Block):
             description="Optional message to send with the file", default=""
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         status: str = SchemaField(description="Operation status")
         message_id: str = SchemaField(description="ID of the sent message")
@@ -788,7 +794,7 @@ class SendDiscordFileBlock(Block):
 
 
 class ReplyToDiscordMessageBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: DiscordCredentials = DiscordCredentialsField()
         channel_id: str = SchemaField(
             description="The channel ID where the message to reply to is located"
@@ -799,7 +805,7 @@ class ReplyToDiscordMessageBlock(Block):
             description="Whether to mention the original message author", default=True
        )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         status: str = SchemaField(description="Operation status")
         reply_id: str = SchemaField(description="ID of the reply message")
@@ -913,13 +919,13 @@ class ReplyToDiscordMessageBlock(Block):
 
 
 class DiscordUserInfoBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: DiscordCredentials = DiscordCredentialsField()
         user_id: str = SchemaField(
             description="The Discord user ID to get information about"
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         user_id: str = SchemaField(
             description="The user's ID (passed through for chaining)"
         )
@@ -1030,7 +1036,7 @@ class DiscordUserInfoBlock(Block):
 
 
 class DiscordChannelInfoBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: DiscordCredentials = DiscordCredentialsField()
         channel_identifier: str = SchemaField(
             description="Channel name or channel ID to look up"
@@ -1041,7 +1047,7 @@ class DiscordChannelInfoBlock(Block):
             default="",
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         channel_id: str = SchemaField(description="The channel's ID")
         channel_name: str = SchemaField(description="The channel's name")
         server_id: str = SchemaField(description="The server's ID")
@@ -2,7 +2,13 @@
 Discord OAuth-based blocks.
 """
 
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import OAuth2Credentials, SchemaField
 
 from ._api import DiscordOAuthUser, get_current_user
@@ -21,12 +27,12 @@ class DiscordGetCurrentUserBlock(Block):
     This block requires Discord OAuth2 credentials (not bot tokens).
     """
 
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: DiscordOAuthCredentialsInput = DiscordOAuthCredentialsField(
             ["identify"]
         )
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         user_id: str = SchemaField(description="The authenticated user's Discord ID")
         username: str = SchemaField(description="The user's username")
         avatar_url: str = SchemaField(description="URL to the user's avatar image")
@@ -1,11 +1,19 @@
 import smtplib
+import socket
+import ssl
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
 from typing import Literal
 
 from pydantic import BaseModel, ConfigDict, SecretStr
 
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import (
     CredentialsField,
     CredentialsMetaInput,
@@ -42,16 +50,14 @@ def SMTPCredentialsField() -> SMTPCredentialsInput:
 
 
 class SMTPConfig(BaseModel):
-    smtp_server: str = SchemaField(
-        default="smtp.example.com", description="SMTP server address"
-    )
+    smtp_server: str = SchemaField(description="SMTP server address")
     smtp_port: int = SchemaField(default=25, description="SMTP port number")
 
     model_config = ConfigDict(title="SMTP Config")
 
 
 class SendEmailBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         to_email: str = SchemaField(
             description="Recipient email address", placeholder="recipient@example.com"
         )
@@ -61,13 +67,10 @@ class SendEmailBlock(Block):
         body: str = SchemaField(
             description="Body of the email", placeholder="Enter the email body"
         )
-        config: SMTPConfig = SchemaField(
-            description="SMTP Config",
-            default=SMTPConfig(),
-        )
+        config: SMTPConfig = SchemaField(description="SMTP Config")
         credentials: SMTPCredentialsInput = SMTPCredentialsField()
 
-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         status: str = SchemaField(description="Status of the email sending operation")
         error: str = SchemaField(
             description="Error message if the email sending failed"
@@ -114,7 +117,7 @@ class SendEmailBlock(Block):
         msg["Subject"] = subject
         msg.attach(MIMEText(body, "plain"))
 
-        with smtplib.SMTP(smtp_server, smtp_port) as server:
+        with smtplib.SMTP(smtp_server, smtp_port, timeout=30) as server:
             server.starttls()
             server.login(smtp_username, smtp_password)
             server.sendmail(smtp_username, to_email, msg.as_string())
@@ -124,10 +127,59 @@ class SendEmailBlock(Block):
     async def run(
         self, input_data: Input, *, credentials: SMTPCredentials, **kwargs
     ) -> BlockOutput:
-        yield "status", self.send_email(
-            config=input_data.config,
-            to_email=input_data.to_email,
-            subject=input_data.subject,
-            body=input_data.body,
-            credentials=credentials,
-        )
+        try:
+            status = self.send_email(
+                config=input_data.config,
+                to_email=input_data.to_email,
+                subject=input_data.subject,
+                body=input_data.body,
+                credentials=credentials,
+            )
+            yield "status", status
+        except socket.gaierror:
+            yield "error", (
+                f"Cannot connect to SMTP server '{input_data.config.smtp_server}'. "
+                "Please verify the server address is correct."
+            )
+        except socket.timeout:
+            yield "error", (
+                f"Connection timeout to '{input_data.config.smtp_server}' "
+                f"on port {input_data.config.smtp_port}. "
+                "The server may be down or unreachable."
+            )
+        except ConnectionRefusedError:
+            yield "error", (
+                f"Connection refused to '{input_data.config.smtp_server}' "
+                f"on port {input_data.config.smtp_port}. "
+                "Common SMTP ports are: 587 (TLS), 465 (SSL), 25 (plain). "
+                "Please verify the port is correct."
+            )
+        except smtplib.SMTPNotSupportedError:
+            yield "error", (
+                f"STARTTLS not supported by server '{input_data.config.smtp_server}'. "
+                "Try using port 465 for SSL or port 25 for unencrypted connection."
+            )
+        except ssl.SSLError as e:
+            yield "error", (
+                f"SSL/TLS error when connecting to '{input_data.config.smtp_server}': {str(e)}. "
+                "The server may require a different security protocol."
+            )
+        except smtplib.SMTPAuthenticationError:
+            yield "error", (
+                "Authentication failed. Please verify your username and password are correct."
+            )
+        except smtplib.SMTPRecipientsRefused:
+            yield "error", (
+                f"Recipient email address '{input_data.to_email}' was rejected by the server. "
+                "Please verify the email address is valid."
+            )
+        except smtplib.SMTPSenderRefused:
+            yield "error", (
+                "Sender email address defined in the credentials that were used "
+                "was rejected by the server. "
+                "Please verify your account is authorized to send emails."
+            )
+        except smtplib.SMTPDataError as e:
+            yield "error", f"Email data rejected by server: {str(e)}"
+        except Exception as e:
+            raise e
@@ -8,7 +8,13 @@ which provides access to LinkedIn profile data and related information.
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
@@ -29,7 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
class GetLinkedinProfileBlock(Block):
|
||||
"""Block to fetch LinkedIn profile data using Enrichlayer API."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
"""Input schema for GetLinkedinProfileBlock."""
|
||||
|
||||
linkedin_url: str = SchemaField(
|
||||
@@ -80,13 +86,12 @@ class GetLinkedinProfileBlock(Block):
|
||||
description="Enrichlayer API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
"""Output schema for GetLinkedinProfileBlock."""
|
||||
|
||||
profile: PersonProfileResponse = SchemaField(
|
||||
description="LinkedIn profile data"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize GetLinkedinProfileBlock."""
|
||||
@@ -199,7 +204,7 @@ class GetLinkedinProfileBlock(Block):
|
||||
class LinkedinPersonLookupBlock(Block):
|
||||
"""Block to look up LinkedIn profiles by person's information using Enrichlayer API."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
"""Input schema for LinkedinPersonLookupBlock."""
|
||||
|
||||
first_name: str = SchemaField(
|
||||
@@ -242,13 +247,12 @@ class LinkedinPersonLookupBlock(Block):
|
||||
description="Enrichlayer API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
"""Output schema for LinkedinPersonLookupBlock."""
|
||||
|
||||
lookup_result: PersonLookupResponse = SchemaField(
|
||||
description="LinkedIn profile lookup result"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize LinkedinPersonLookupBlock."""
|
||||
@@ -346,7 +350,7 @@ class LinkedinPersonLookupBlock(Block):
|
||||
class LinkedinRoleLookupBlock(Block):
|
||||
"""Block to look up LinkedIn profiles by role in a company using Enrichlayer API."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
"""Input schema for LinkedinRoleLookupBlock."""
|
||||
|
||||
role: str = SchemaField(
|
||||
@@ -366,13 +370,12 @@ class LinkedinRoleLookupBlock(Block):
|
||||
description="Enrichlayer API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
"""Output schema for LinkedinRoleLookupBlock."""
|
||||
|
||||
role_lookup_result: RoleLookupResponse = SchemaField(
|
||||
description="LinkedIn role lookup result"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize LinkedinRoleLookupBlock."""
|
||||
@@ -449,7 +452,7 @@ class LinkedinRoleLookupBlock(Block):
|
||||
class GetLinkedinProfilePictureBlock(Block):
|
||||
"""Block to get LinkedIn profile pictures using Enrichlayer API."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
"""Input schema for GetLinkedinProfilePictureBlock."""
|
||||
|
||||
linkedin_profile_url: str = SchemaField(
|
||||
@@ -460,13 +463,12 @@ class GetLinkedinProfilePictureBlock(Block):
|
||||
description="Enrichlayer API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
"""Output schema for GetLinkedinProfilePictureBlock."""
|
||||
|
||||
profile_picture_url: MediaFileType = SchemaField(
|
||||
description="LinkedIn profile picture URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize GetLinkedinProfilePictureBlock."""
|
||||
|
||||
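Every hunk in this file applies the same mechanical migration: the nested Input and Output schemas move from the shared BlockSchema base to the direction-specific BlockSchemaInput and BlockSchemaOutput bases. A schematic sketch of the pattern; the block and field names are illustrative, and it assumes the backend package is importable:

```python
from backend.data.block import Block, BlockSchemaInput, BlockSchemaOutput
from backend.data.model import SchemaField


class ExampleBlock(Block):
    class Input(BlockSchemaInput):  # previously: class Input(BlockSchema)
        linkedin_url: str = SchemaField(description="Profile URL to fetch")

    class Output(BlockSchemaOutput):  # previously: class Output(BlockSchema)
        error: str = SchemaField(description="Error message if the request failed")
```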
autogpt_platform/backend/backend/blocks/exa/_test.py (new file, 22 lines)
@@ -0,0 +1,22 @@
"""
Test credentials and helpers for Exa blocks.
"""

from pydantic import SecretStr

from backend.data.model import APIKeyCredentials

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="exa",
    api_key=SecretStr("mock-exa-api-key"),
    title="Mock Exa API key",
    expires_at=None,
)

TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}
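A sketch of how these fixtures are typically consumed by a block's self-test wiring; the test_input and test_credentials keyword names are assumptions based on the platform's usual Block constructor, not something this diff shows:

```python
# Hypothetical wiring for a block under test; keyword names are assumed.
from backend.blocks.exa._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT

example_kwargs = {
    "test_input": {"query": "hello", "credentials": TEST_CREDENTIALS_INPUT},
    "test_credentials": TEST_CREDENTIALS,
}
```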
@@ -1,55 +1,59 @@
from typing import Optional

from exa_py import AsyncExa
from exa_py.api import AnswerResponse
from pydantic import BaseModel

from backend.sdk import (
    APIKeyCredentials,
    BaseModel,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    BlockSchemaInput,
    BlockSchemaOutput,
    CredentialsMetaInput,
    Requests,
    MediaFileType,
    SchemaField,
)

from ._config import exa


class CostBreakdown(BaseModel):
    keywordSearch: float
    neuralSearch: float
    contentText: float
    contentHighlight: float
    contentSummary: float
class AnswerCitation(BaseModel):
    """Citation model for answer endpoint."""

    id: str = SchemaField(description="The temporary ID for the document")
    url: str = SchemaField(description="The URL of the search result")
    title: Optional[str] = SchemaField(description="The title of the search result")
    author: Optional[str] = SchemaField(description="The author of the content")
    publishedDate: Optional[str] = SchemaField(
        description="An estimate of the creation date"
    )
    text: Optional[str] = SchemaField(description="The full text content of the source")
    image: Optional[MediaFileType] = SchemaField(
        description="The URL of the image associated with the result"
    )
    favicon: Optional[MediaFileType] = SchemaField(
        description="The URL of the favicon for the domain"
    )

class SearchBreakdown(BaseModel):
    search: float
    contents: float
    breakdown: CostBreakdown


class PerRequestPrices(BaseModel):
    neuralSearch_1_25_results: float
    neuralSearch_26_100_results: float
    neuralSearch_100_plus_results: float
    keywordSearch_1_100_results: float
    keywordSearch_100_plus_results: float


class PerPagePrices(BaseModel):
    contentText: float
    contentHighlight: float
    contentSummary: float


class CostDollars(BaseModel):
    total: float
    breakDown: list[SearchBreakdown]
    perRequestPrices: PerRequestPrices
    perPagePrices: PerPagePrices
    @classmethod
    def from_sdk(cls, sdk_citation) -> "AnswerCitation":
        """Convert SDK AnswerResult (dataclass) to our Pydantic model."""
        return cls(
            id=getattr(sdk_citation, "id", ""),
            url=getattr(sdk_citation, "url", ""),
            title=getattr(sdk_citation, "title", None),
            author=getattr(sdk_citation, "author", None),
            publishedDate=getattr(sdk_citation, "published_date", None),
            text=getattr(sdk_citation, "text", None),
            image=getattr(sdk_citation, "image", None),
            favicon=getattr(sdk_citation, "favicon", None),
        )


class ExaAnswerBlock(Block):
    class Input(BlockSchema):
    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
@@ -58,31 +62,21 @@ class ExaAnswerBlock(Block):
            placeholder="What is the latest valuation of SpaceX?",
        )
        text: bool = SchemaField(
            default=False,
            description="If true, the response includes full text content in the search results",
            advanced=True,
        )
        model: str = SchemaField(
            default="exa",
            description="The search model to use (exa or exa-pro)",
            placeholder="exa",
            advanced=True,
            description="Include full text content in the search results used for the answer",
            default=True,
        )

    class Output(BlockSchema):
    class Output(BlockSchemaOutput):
        answer: str = SchemaField(
            description="The generated answer based on search results"
        )
        citations: list[dict] = SchemaField(
            description="Search results used to generate the answer",
            default_factory=list,
        citations: list[AnswerCitation] = SchemaField(
            description="Search results used to generate the answer"
        )
        cost_dollars: CostDollars = SchemaField(
            description="Cost breakdown of the request"
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        citation: AnswerCitation = SchemaField(
            description="Individual citation from the answer"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
@@ -96,26 +90,24 @@ class ExaAnswerBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/answer"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # Build the payload
        payload = {
            "query": input_data.query,
            "text": input_data.text,
            "model": input_data.model,
        }
        # Get answer using SDK (stream=False for blocks) - this IS async, needs await
        response = await aexa.answer(
            query=input_data.query, text=input_data.text, stream=False
        )

        try:
            response = await Requests().post(url, headers=headers, json=payload)
            data = response.json()
        # this should remain true as long as they don't start defaulting to streaming only.
        # provides a bit of safety for sdk updates.
        assert type(response) is AnswerResponse

            yield "answer", data.get("answer", "")
            yield "citations", data.get("citations", [])
            yield "cost_dollars", data.get("costDollars", {})
        yield "answer", response.answer

        except Exception as e:
            yield "error", str(e)
        citations = [
            AnswerCitation.from_sdk(sdk_citation)
            for sdk_citation in response.citations or []
        ]

        yield "citations", citations
        for citation in citations:
            yield "citation", citation

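The rewrite above drops the hand-rolled POST to api.exa.ai/answer in favor of the exa_py SDK. A minimal standalone sketch of the same call shape, mirroring the block's run(); the API key is a placeholder:

```python
import asyncio

from exa_py import AsyncExa


async def main() -> None:
    # Non-streaming answer with citations, as in the block's run().
    aexa = AsyncExa(api_key="YOUR-EXA-API-KEY")  # placeholder key
    response = await aexa.answer(
        query="What is the latest valuation of SpaceX?", text=True, stream=False
    )
    print(response.answer)
    for citation in response.citations or []:
        print(citation.url)


asyncio.run(main())
```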
autogpt_platform/backend/backend/blocks/exa/code_context.py (new file, 118 lines)
@@ -0,0 +1,118 @@
"""
Exa Code Context Block

Provides code search capabilities to find relevant code snippets and examples
from open source repositories, documentation, and Stack Overflow.
"""

from typing import Union

from pydantic import BaseModel

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)

from ._config import exa


class CodeContextResponse(BaseModel):
    """Stable output model for code context responses."""

    request_id: str
    query: str
    response: str
    results_count: int
    cost_dollars: str
    search_time: float
    output_tokens: int

    @classmethod
    def from_api(cls, data: dict) -> "CodeContextResponse":
        """Convert API response to our stable model."""
        return cls(
            request_id=data.get("requestId", ""),
            query=data.get("query", ""),
            response=data.get("response", ""),
            results_count=data.get("resultsCount", 0),
            cost_dollars=data.get("costDollars", ""),
            search_time=data.get("searchTime", 0.0),
            output_tokens=data.get("outputTokens", 0),
        )


class ExaCodeContextBlock(Block):
    """Get relevant code snippets and examples from open source repositories."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        query: str = SchemaField(
            description="Search query to find relevant code snippets. Describe what you're trying to do or what code you're looking for.",
            placeholder="how to use React hooks for state management",
        )
        tokens_num: Union[str, int] = SchemaField(
            default="dynamic",
            description="Token limit for response. Use 'dynamic' for automatic sizing, 5000 for standard queries, or 10000 for comprehensive examples.",
            placeholder="dynamic",
        )

    class Output(BlockSchemaOutput):
        request_id: str = SchemaField(description="Unique identifier for this request")
        query: str = SchemaField(description="The search query used")
        response: str = SchemaField(
            description="Formatted code snippets and contextual examples with sources"
        )
        results_count: int = SchemaField(
            description="Number of code sources found and included"
        )
        cost_dollars: str = SchemaField(description="Cost of this request in dollars")
        search_time: float = SchemaField(
            description="Time taken to search in milliseconds"
        )
        output_tokens: int = SchemaField(description="Number of tokens in the response")

    def __init__(self):
        super().__init__(
            id="8f9e0d1c-2b3a-4567-8901-23456789abcd",
            description="Search billions of GitHub repos, docs, and Stack Overflow for relevant code examples",
            categories={BlockCategory.SEARCH, BlockCategory.DEVELOPER_TOOLS},
            input_schema=ExaCodeContextBlock.Input,
            output_schema=ExaCodeContextBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/context"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        payload = {
            "query": input_data.query,
            "tokensNum": input_data.tokens_num,
        }

        response = await Requests().post(url, headers=headers, json=payload)
        data = response.json()

        context = CodeContextResponse.from_api(data)

        yield "request_id", context.request_id
        yield "query", context.query
        yield "response", context.response
        yield "results_count", context.results_count
        yield "cost_dollars", context.cost_dollars
        yield "search_time", context.search_time
        yield "output_tokens", context.output_tokens
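The new block boils down to a single POST against the context endpoint. A minimal sketch of the equivalent request outside the platform, using the plain requests library for illustration (the block itself goes through the platform's async Requests wrapper, and the key is a placeholder):

```python
import requests  # sketch only; the block uses the platform's async Requests wrapper

response = requests.post(
    "https://api.exa.ai/context",
    headers={
        "Content-Type": "application/json",
        "x-api-key": "YOUR-EXA-API-KEY",  # placeholder
    },
    json={"query": "how to use React hooks for state management", "tokensNum": "dynamic"},
)
data = response.json()
print(data.get("response", ""))  # formatted code snippets with sources
```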
@@ -1,39 +1,127 @@
from enum import Enum
from typing import Optional

from exa_py import AsyncExa
from pydantic import BaseModel

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    BlockSchemaInput,
    BlockSchemaOutput,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)

from ._config import exa
from .helpers import ContentSettings
from .helpers import (
    CostDollars,
    ExaSearchResults,
    ExtrasSettings,
    HighlightSettings,
    LivecrawlTypes,
    SummarySettings,
)


class ContentStatusTag(str, Enum):
    CRAWL_NOT_FOUND = "CRAWL_NOT_FOUND"
    CRAWL_TIMEOUT = "CRAWL_TIMEOUT"
    CRAWL_LIVECRAWL_TIMEOUT = "CRAWL_LIVECRAWL_TIMEOUT"
    SOURCE_NOT_AVAILABLE = "SOURCE_NOT_AVAILABLE"
    CRAWL_UNKNOWN_ERROR = "CRAWL_UNKNOWN_ERROR"


class ContentError(BaseModel):
    tag: Optional[ContentStatusTag] = SchemaField(
        default=None, description="Specific error type"
    )
    httpStatusCode: Optional[int] = SchemaField(
        default=None, description="The corresponding HTTP status code"
    )


class ContentStatus(BaseModel):
    id: str = SchemaField(description="The URL that was requested")
    status: str = SchemaField(
        description="Status of the content fetch operation (success or error)"
    )
    error: Optional[ContentError] = SchemaField(
        default=None, description="Error details, only present when status is 'error'"
    )


class ExaContentsBlock(Block):
    class Input(BlockSchema):
    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        ids: list[str] = SchemaField(
            description="Array of document IDs obtained from searches"
        urls: list[str] = SchemaField(
            description="Array of URLs to crawl (preferred over 'ids')",
            default_factory=list,
            advanced=False,
        )
        contents: ContentSettings = SchemaField(
            description="Content retrieval settings",
            default=ContentSettings(),
        ids: list[str] = SchemaField(
            description="[DEPRECATED - use 'urls' instead] Array of document IDs obtained from searches",
            default_factory=list,
            advanced=True,
        )
        text: bool = SchemaField(
            description="Retrieve text content from pages",
            default=True,
        )
        highlights: HighlightSettings = SchemaField(
            description="Text snippets most relevant from each page",
            default=HighlightSettings(),
        )
        summary: SummarySettings = SchemaField(
            description="LLM-generated summary of the webpage",
            default=SummarySettings(),
        )
        livecrawl: Optional[LivecrawlTypes] = SchemaField(
            description="Livecrawling options: never, fallback (default), always, preferred",
            default=LivecrawlTypes.FALLBACK,
            advanced=True,
        )
        livecrawl_timeout: Optional[int] = SchemaField(
            description="Timeout for livecrawling in milliseconds",
            default=10000,
            advanced=True,
        )
        subpages: Optional[int] = SchemaField(
            description="Number of subpages to crawl", default=0, ge=0, advanced=True
        )
        subpage_target: Optional[str | list[str]] = SchemaField(
            description="Keyword(s) to find specific subpages of search results",
            default=None,
            advanced=True,
        )
        extras: ExtrasSettings = SchemaField(
            description="Extra parameters for additional content",
            default=ExtrasSettings(),
            advanced=True,
        )

    class Output(BlockSchema):
        results: list = SchemaField(
            description="List of document contents", default_factory=list
    class Output(BlockSchemaOutput):
        results: list[ExaSearchResults] = SchemaField(
            description="List of document contents with metadata"
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        result: ExaSearchResults = SchemaField(
            description="Single document content result"
        )
        context: str = SchemaField(
            description="A formatted string of the results ready for LLMs"
        )
        request_id: str = SchemaField(description="Unique identifier for the request")
        statuses: list[ContentStatus] = SchemaField(
            description="Status information for each requested URL"
        )
        cost_dollars: Optional[CostDollars] = SchemaField(
            description="Cost breakdown for the request"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
@@ -47,23 +135,91 @@ class ExaContentsBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/contents"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }
        if not input_data.urls and not input_data.ids:
            raise ValueError("Either 'urls' or 'ids' must be provided")

        # Convert ContentSettings to API format
        payload = {
            "ids": input_data.ids,
            "text": input_data.contents.text,
            "highlights": input_data.contents.highlights,
            "summary": input_data.contents.summary,
        }
        sdk_kwargs = {}

        try:
            response = await Requests().post(url, headers=headers, json=payload)
            data = response.json()
            yield "results", data.get("results", [])
        except Exception as e:
            yield "error", str(e)
        # Prefer urls over ids
        if input_data.urls:
            sdk_kwargs["urls"] = input_data.urls
        elif input_data.ids:
            sdk_kwargs["ids"] = input_data.ids

        if input_data.text:
            sdk_kwargs["text"] = {"includeHtmlTags": True}

        # Handle highlights - only include if modified from defaults
        if input_data.highlights and (
            input_data.highlights.num_sentences != 1
            or input_data.highlights.highlights_per_url != 1
            or input_data.highlights.query is not None
        ):
            highlights_dict = {}
            highlights_dict["numSentences"] = input_data.highlights.num_sentences
            highlights_dict["highlightsPerUrl"] = (
                input_data.highlights.highlights_per_url
            )
            if input_data.highlights.query:
                highlights_dict["query"] = input_data.highlights.query
            sdk_kwargs["highlights"] = highlights_dict

        # Handle summary - only include if modified from defaults
        if input_data.summary and (
            input_data.summary.query is not None
            or input_data.summary.schema is not None
        ):
            summary_dict = {}
            if input_data.summary.query:
                summary_dict["query"] = input_data.summary.query
            if input_data.summary.schema:
                summary_dict["schema"] = input_data.summary.schema
            sdk_kwargs["summary"] = summary_dict

        if input_data.livecrawl:
            sdk_kwargs["livecrawl"] = input_data.livecrawl.value

        if input_data.livecrawl_timeout is not None:
            sdk_kwargs["livecrawl_timeout"] = input_data.livecrawl_timeout

        if input_data.subpages is not None:
            sdk_kwargs["subpages"] = input_data.subpages

        if input_data.subpage_target:
            sdk_kwargs["subpage_target"] = input_data.subpage_target

        # Handle extras - only include if modified from defaults
        if input_data.extras and (
            input_data.extras.links > 0 or input_data.extras.image_links > 0
        ):
            extras_dict = {}
            if input_data.extras.links:
                extras_dict["links"] = input_data.extras.links
            if input_data.extras.image_links:
                extras_dict["image_links"] = input_data.extras.image_links
            sdk_kwargs["extras"] = extras_dict

        # Always enable context for LLM-ready output
        sdk_kwargs["context"] = True

        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
        response = await aexa.get_contents(**sdk_kwargs)

        converted_results = [
            ExaSearchResults.from_sdk(sdk_result)
            for sdk_result in response.results or []
        ]

        yield "results", converted_results

        for result in converted_results:
            yield "result", result

        if response.context:
            yield "context", response.context

        if response.statuses:
            yield "statuses", response.statuses

        if response.cost_dollars:
            yield "cost_dollars", response.cost_dollars

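The run() above maps block inputs onto exa_py keyword arguments, forwarding only settings that differ from their defaults and always enabling context. A condensed sketch of the resulting SDK call, with illustrative values:

```python
import asyncio

from exa_py import AsyncExa


async def fetch_contents() -> None:
    aexa = AsyncExa(api_key="YOUR-EXA-API-KEY")  # placeholder key
    # Mirrors the kwargs assembled above: urls preferred over ids, text with
    # HTML tags enabled, and context always on for LLM-ready output.
    response = await aexa.get_contents(
        urls=["https://example.com/article"],  # illustrative URL
        text={"includeHtmlTags": True},
        livecrawl="fallback",
        context=True,
    )
    for result in response.results or []:
        print(result.url, result.title or "")


asyncio.run(fetch_contents())
```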
@@ -1,51 +1,150 @@
from typing import Optional
from enum import Enum
from typing import Any, Dict, Literal, Optional, Union

from backend.sdk import BaseModel, SchemaField
from backend.sdk import BaseModel, MediaFileType, SchemaField


class TextSettings(BaseModel):
    max_characters: int = SchemaField(
        default=1000,
class LivecrawlTypes(str, Enum):
    NEVER = "never"
    FALLBACK = "fallback"
    ALWAYS = "always"
    PREFERRED = "preferred"


class TextEnabled(BaseModel):
    discriminator: Literal["enabled"] = "enabled"


class TextDisabled(BaseModel):
    discriminator: Literal["disabled"] = "disabled"


class TextAdvanced(BaseModel):
    discriminator: Literal["advanced"] = "advanced"
    max_characters: Optional[int] = SchemaField(
        default=None,
        description="Maximum number of characters to return",
        placeholder="1000",
    )
    include_html_tags: bool = SchemaField(
        default=False,
        description="Whether to include HTML tags in the text",
        description="Include HTML tags in the response, helps LLMs understand text structure",
        placeholder="False",
    )


class HighlightSettings(BaseModel):
    num_sentences: int = SchemaField(
        default=3,
        default=1,
        description="Number of sentences per highlight",
        placeholder="3",
        placeholder="1",
        ge=1,
    )
    highlights_per_url: int = SchemaField(
        default=3,
        default=1,
        description="Number of highlights per URL",
        placeholder="3",
        placeholder="1",
        ge=1,
    )
    query: Optional[str] = SchemaField(
        default=None,
        description="Custom query to direct the LLM's selection of highlights",
        placeholder="Key advancements",
    )


class SummarySettings(BaseModel):
    query: Optional[str] = SchemaField(
        default="",
        description="Query string for summarization",
        placeholder="Enter query",
        default=None,
        description="Custom query for the LLM-generated summary",
        placeholder="Main developments",
    )
    schema: Optional[dict] = SchemaField(  # type: ignore
        default=None,
        description="JSON schema for structured output from summary",
        advanced=True,
    )


class ExtrasSettings(BaseModel):
    links: int = SchemaField(
        default=0,
        description="Number of URLs to return from each webpage",
        placeholder="1",
        ge=0,
    )
    image_links: int = SchemaField(
        default=0,
        description="Number of images to return for each result",
        placeholder="1",
        ge=0,
    )


class ContextEnabled(BaseModel):
    discriminator: Literal["enabled"] = "enabled"


class ContextDisabled(BaseModel):
    discriminator: Literal["disabled"] = "disabled"


class ContextAdvanced(BaseModel):
    discriminator: Literal["advanced"] = "advanced"
    max_characters: Optional[int] = SchemaField(
        default=None,
        description="Maximum character limit for context string",
        placeholder="10000",
    )


class ContentSettings(BaseModel):
    text: TextSettings = SchemaField(
        default=TextSettings(),
    text: Optional[Union[bool, TextEnabled, TextDisabled, TextAdvanced]] = SchemaField(
        default=None,
        description="Text content retrieval. Boolean for simple enable/disable or object for advanced settings",
    )
    highlights: HighlightSettings = SchemaField(
        default=HighlightSettings(),
    highlights: Optional[HighlightSettings] = SchemaField(
        default=None,
        description="Text snippets most relevant from each page",
    )
    summary: SummarySettings = SchemaField(
        default=SummarySettings(),
    summary: Optional[SummarySettings] = SchemaField(
        default=None,
        description="LLM-generated summary of the webpage",
    )
    livecrawl: Optional[LivecrawlTypes] = SchemaField(
        default=None,
        description="Livecrawling options: never, fallback, always, preferred",
        advanced=True,
    )
    livecrawl_timeout: Optional[int] = SchemaField(
        default=None,
        description="Timeout for livecrawling in milliseconds",
        placeholder="10000",
        advanced=True,
    )
    subpages: Optional[int] = SchemaField(
        default=None,
        description="Number of subpages to crawl",
        placeholder="0",
        ge=0,
        advanced=True,
    )
    subpage_target: Optional[Union[str, list[str]]] = SchemaField(
        default=None,
        description="Keyword(s) to find specific subpages of search results",
        advanced=True,
    )
    extras: Optional[ExtrasSettings] = SchemaField(
        default=None,
        description="Extra parameters for additional content",
        advanced=True,
    )
    context: Optional[Union[bool, ContextEnabled, ContextDisabled, ContextAdvanced]] = (
        SchemaField(
            default=None,
            description="Format search results into a context string for LLMs",
            advanced=True,
        )
    )


@@ -127,3 +226,225 @@ class WebsetEnrichmentConfig(BaseModel):
        default=None,
        description="Options for the enrichment",
    )


# Shared result models
class ExaSearchExtras(BaseModel):
    links: list[str] = SchemaField(
        default_factory=list, description="Array of links from the search result"
    )
    imageLinks: list[str] = SchemaField(
        default_factory=list, description="Array of image links from the search result"
    )


class ExaSearchResults(BaseModel):
    title: str | None = None
    url: str | None = None
    publishedDate: str | None = None
    author: str | None = None
    id: str
    image: MediaFileType | None = None
    favicon: MediaFileType | None = None
    text: str | None = None
    highlights: list[str] = SchemaField(default_factory=list)
    highlightScores: list[float] = SchemaField(default_factory=list)
    summary: str | None = None
    subpages: list[dict] = SchemaField(default_factory=list)
    extras: ExaSearchExtras | None = None

    @classmethod
    def from_sdk(cls, sdk_result) -> "ExaSearchResults":
        """Convert SDK Result (dataclass) to our Pydantic model."""
        return cls(
            id=getattr(sdk_result, "id", ""),
            url=getattr(sdk_result, "url", None),
            title=getattr(sdk_result, "title", None),
            author=getattr(sdk_result, "author", None),
            publishedDate=getattr(sdk_result, "published_date", None),
            text=getattr(sdk_result, "text", None),
            highlights=getattr(sdk_result, "highlights", None) or [],
            highlightScores=getattr(sdk_result, "highlight_scores", None) or [],
            summary=getattr(sdk_result, "summary", None),
            subpages=getattr(sdk_result, "subpages", None) or [],
            image=getattr(sdk_result, "image", None),
            favicon=getattr(sdk_result, "favicon", None),
            extras=getattr(sdk_result, "extras", None),
        )


# Cost tracking models
class CostBreakdown(BaseModel):
    keywordSearch: float = SchemaField(default=0.0)
    neuralSearch: float = SchemaField(default=0.0)
    contentText: float = SchemaField(default=0.0)
    contentHighlight: float = SchemaField(default=0.0)
    contentSummary: float = SchemaField(default=0.0)


class CostBreakdownItem(BaseModel):
    search: float = SchemaField(default=0.0)
    contents: float = SchemaField(default=0.0)
    breakdown: CostBreakdown = SchemaField(default_factory=CostBreakdown)


class PerRequestPrices(BaseModel):
    neuralSearch_1_25_results: float = SchemaField(default=0.005)
    neuralSearch_26_100_results: float = SchemaField(default=0.025)
    neuralSearch_100_plus_results: float = SchemaField(default=1.0)
    keywordSearch_1_100_results: float = SchemaField(default=0.0025)
    keywordSearch_100_plus_results: float = SchemaField(default=3.0)


class PerPagePrices(BaseModel):
    contentText: float = SchemaField(default=0.001)
    contentHighlight: float = SchemaField(default=0.001)
    contentSummary: float = SchemaField(default=0.001)


class CostDollars(BaseModel):
    total: float = SchemaField(description="Total dollar cost for your request")
    breakDown: list[CostBreakdownItem] = SchemaField(
        default_factory=list, description="Breakdown of costs by operation type"
    )
    perRequestPrices: PerRequestPrices = SchemaField(
        default_factory=PerRequestPrices,
        description="Standard price per request for different operations",
    )
    perPagePrices: PerPagePrices = SchemaField(
        default_factory=PerPagePrices,
        description="Standard price per page for different content operations",
    )


# Helper functions for payload processing
def process_text_field(
    text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None]
) -> Optional[Union[bool, Dict[str, Any]]]:
    """Process text field for API payload."""
    if text is None:
        return None

    # Handle backward compatibility with boolean
    if isinstance(text, bool):
        return text
    elif isinstance(text, TextDisabled):
        return False
    elif isinstance(text, TextEnabled):
        return True
    elif isinstance(text, TextAdvanced):
        text_dict = {}
        if text.max_characters:
            text_dict["maxCharacters"] = text.max_characters
        if text.include_html_tags:
            text_dict["includeHtmlTags"] = text.include_html_tags
        return text_dict if text_dict else True
    return None


def process_contents_settings(contents: Optional[ContentSettings]) -> Dict[str, Any]:
    """Process ContentSettings into API payload format."""
    if not contents:
        return {}

    content_settings = {}

    # Handle text field (can be boolean or object)
    text_value = process_text_field(contents.text)
    if text_value is not None:
        content_settings["text"] = text_value

    # Handle highlights
    if contents.highlights:
        highlights_dict: Dict[str, Any] = {
            "numSentences": contents.highlights.num_sentences,
            "highlightsPerUrl": contents.highlights.highlights_per_url,
        }
        if contents.highlights.query:
            highlights_dict["query"] = contents.highlights.query
        content_settings["highlights"] = highlights_dict

    if contents.summary:
        summary_dict = {}
        if contents.summary.query:
            summary_dict["query"] = contents.summary.query
        if contents.summary.schema:
            summary_dict["schema"] = contents.summary.schema
        content_settings["summary"] = summary_dict

    if contents.livecrawl:
        content_settings["livecrawl"] = contents.livecrawl.value

    if contents.livecrawl_timeout is not None:
        content_settings["livecrawlTimeout"] = contents.livecrawl_timeout

    if contents.subpages is not None:
        content_settings["subpages"] = contents.subpages

    if contents.subpage_target:
        content_settings["subpageTarget"] = contents.subpage_target

    if contents.extras:
        extras_dict = {}
        if contents.extras.links:
            extras_dict["links"] = contents.extras.links
        if contents.extras.image_links:
            extras_dict["imageLinks"] = contents.extras.image_links
        content_settings["extras"] = extras_dict

    context_value = process_context_field(contents.context)
    if context_value is not None:
        content_settings["context"] = context_value

    return content_settings


def process_context_field(
    context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None]
) -> Optional[Union[bool, Dict[str, int]]]:
    """Process context field for API payload."""
    if context is None:
        return None

    # Handle backward compatibility with boolean
    if isinstance(context, bool):
        return context if context else None
    elif isinstance(context, dict) and "maxCharacters" in context:
        return {"maxCharacters": context["maxCharacters"]}
    elif isinstance(context, ContextDisabled):
        return None  # Don't send context field at all when disabled
    elif isinstance(context, ContextEnabled):
        return True
    elif isinstance(context, ContextAdvanced):
        if context.max_characters:
            return {"maxCharacters": context.max_characters}
        return True
    return None


def format_date_fields(
    input_data: Any, date_field_mapping: Dict[str, str]
) -> Dict[str, str]:
    """Format datetime fields for API payload."""
    formatted_dates = {}
    for input_field, api_field in date_field_mapping.items():
        value = getattr(input_data, input_field, None)
        if value:
            formatted_dates[api_field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")
    return formatted_dates


def add_optional_fields(
    input_data: Any,
    field_mapping: Dict[str, str],
    payload: Dict[str, Any],
    process_enums: bool = False,
) -> None:
    """Add optional fields to payload if they have values."""
    for input_field, api_field in field_mapping.items():
        value = getattr(input_data, input_field, None)
        if value:  # Only add non-empty values
            if process_enums and hasattr(value, "value"):
                payload[api_field] = value.value
            else:
                payload[api_field] = value

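The helpers above keep backward compatibility with plain booleans while supporting the new discriminated-union settings. A small sketch of the mapping process_text_field performs, assuming the definitions above are importable from the blocks package:

```python
# Assumes the helpers defined above live at this module path.
from backend.blocks.exa.helpers import (
    TextAdvanced,
    TextDisabled,
    process_text_field,
)

assert process_text_field(True) is True             # plain booleans pass through
assert process_text_field(TextDisabled()) is False  # discriminated "disabled" maps to False
assert process_text_field(TextAdvanced()) is True   # advanced with no options set maps to True
assert process_text_field(TextAdvanced(max_characters=500)) == {"maxCharacters": 500}
assert process_text_field(None) is None             # omitted: field not sent at all
```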
@@ -1,247 +0,0 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field


# Enum definitions based on available options
class WebsetStatus(str, Enum):
    IDLE = "idle"
    PENDING = "pending"
    RUNNING = "running"
    PAUSED = "paused"


class WebsetSearchStatus(str, Enum):
    CREATED = "created"
    # Add more if known, based on example it's "created"


class ImportStatus(str, Enum):
    PENDING = "pending"
    # Add more if known


class ImportFormat(str, Enum):
    CSV = "csv"
    # Add more if known


class EnrichmentStatus(str, Enum):
    PENDING = "pending"
    # Add more if known


class EnrichmentFormat(str, Enum):
    TEXT = "text"
    # Add more if known


class MonitorStatus(str, Enum):
    ENABLED = "enabled"
    # Add more if known


class MonitorBehaviorType(str, Enum):
    SEARCH = "search"
    # Add more if known


class MonitorRunStatus(str, Enum):
    CREATED = "created"
    # Add more if known


class CanceledReason(str, Enum):
    WEBSET_DELETED = "webset_deleted"
    # Add more if known


class FailedReason(str, Enum):
    INVALID_FORMAT = "invalid_format"
    # Add more if known


class Confidence(str, Enum):
    HIGH = "high"
    # Add more if known


# Nested models


class Entity(BaseModel):
    type: str


class Criterion(BaseModel):
    description: str
    successRate: Optional[int] = None


class ExcludeItem(BaseModel):
    source: str = Field(default="import")
    id: str


class Relationship(BaseModel):
    definition: str
    limit: Optional[float] = None


class ScopeItem(BaseModel):
    source: str = Field(default="import")
    id: str
    relationship: Optional[Relationship] = None


class Progress(BaseModel):
    found: int
    analyzed: int
    completion: int
    timeLeft: int


class Bounds(BaseModel):
    min: int
    max: int


class Expected(BaseModel):
    total: int
    confidence: str = Field(default="high")  # Use str or Confidence enum
    bounds: Bounds


class Recall(BaseModel):
    expected: Expected
    reasoning: str


class WebsetSearch(BaseModel):
    id: str
    object: str = Field(default="webset_search")
    status: str = Field(default="created")  # Or use WebsetSearchStatus
    websetId: str
    query: str
    entity: Entity
    criteria: List[Criterion]
    count: int
    behavior: str = Field(default="override")
    exclude: List[ExcludeItem]
    scope: List[ScopeItem]
    progress: Progress
    recall: Recall
    metadata: Dict[str, Any] = Field(default_factory=dict)
    canceledAt: Optional[datetime] = None
    canceledReason: Optional[str] = Field(default=None)  # Or use CanceledReason
    createdAt: datetime
    updatedAt: datetime


class ImportEntity(BaseModel):
    type: str


class Import(BaseModel):
    id: str
    object: str = Field(default="import")
    status: str = Field(default="pending")  # Or use ImportStatus
    format: str = Field(default="csv")  # Or use ImportFormat
    entity: ImportEntity
    title: str
    count: int
    metadata: Dict[str, Any] = Field(default_factory=dict)
    failedReason: Optional[str] = Field(default=None)  # Or use FailedReason
    failedAt: Optional[datetime] = None
    failedMessage: Optional[str] = None
    createdAt: datetime
    updatedAt: datetime


class Option(BaseModel):
    label: str


class WebsetEnrichment(BaseModel):
    id: str
    object: str = Field(default="webset_enrichment")
    status: str = Field(default="pending")  # Or use EnrichmentStatus
    websetId: str
    title: str
    description: str
    format: str = Field(default="text")  # Or use EnrichmentFormat
    options: List[Option]
    instructions: str
    metadata: Dict[str, Any] = Field(default_factory=dict)
    createdAt: datetime
    updatedAt: datetime


class Cadence(BaseModel):
    cron: str
    timezone: str = Field(default="Etc/UTC")


class BehaviorConfig(BaseModel):
    query: Optional[str] = None
    criteria: Optional[List[Criterion]] = None
    entity: Optional[Entity] = None
    count: Optional[int] = None
    behavior: Optional[str] = Field(default=None)


class Behavior(BaseModel):
    type: str = Field(default="search")  # Or use MonitorBehaviorType
    config: BehaviorConfig


class MonitorRun(BaseModel):
    id: str
    object: str = Field(default="monitor_run")
    status: str = Field(default="created")  # Or use MonitorRunStatus
    monitorId: str
    type: str = Field(default="search")
    completedAt: Optional[datetime] = None
    failedAt: Optional[datetime] = None
    failedReason: Optional[str] = None
    canceledAt: Optional[datetime] = None
    createdAt: datetime
    updatedAt: datetime


class Monitor(BaseModel):
    id: str
    object: str = Field(default="monitor")
    status: str = Field(default="enabled")  # Or use MonitorStatus
    websetId: str
    cadence: Cadence
    behavior: Behavior
    lastRun: Optional[MonitorRun] = None
    nextRunAt: Optional[datetime] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    createdAt: datetime
    updatedAt: datetime


class Webset(BaseModel):
    id: str
    object: str = Field(default="webset")
    status: WebsetStatus
    externalId: Optional[str] = None
    title: Optional[str] = None
    searches: List[WebsetSearch]
    imports: List[Import]
    enrichments: List[WebsetEnrichment]
    monitors: List[Monitor]
    streams: List[Any]
    createdAt: datetime
    updatedAt: datetime
    metadata: Dict[str, Any] = Field(default_factory=dict)


class ListWebsets(BaseModel):
    data: List[Webset]
    hasMore: bool
    nextCursor: Optional[str] = None
autogpt_platform/backend/backend/blocks/exa/research.py (new file, 518 lines)
@@ -0,0 +1,518 @@
"""
Exa Research Task Blocks

Provides asynchronous research capabilities that explore the web, gather sources,
synthesize findings, and return structured results with citations.
"""

import asyncio
import time
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)

from ._config import exa


class ResearchModel(str, Enum):
    """Available research models."""

    FAST = "exa-research-fast"
    STANDARD = "exa-research"
    PRO = "exa-research-pro"


class ResearchStatus(str, Enum):
    """Research task status."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    CANCELED = "canceled"
    FAILED = "failed"


class ResearchCostModel(BaseModel):
    """Cost breakdown for a research request."""

    total: float
    num_searches: int
    num_pages: int
    reasoning_tokens: int

    @classmethod
    def from_api(cls, data: dict) -> "ResearchCostModel":
        """Convert API response, rounding fractional counts to integers."""
        return cls(
            total=data.get("total", 0.0),
            num_searches=int(round(data.get("numSearches", 0))),
            num_pages=int(round(data.get("numPages", 0))),
            reasoning_tokens=int(round(data.get("reasoningTokens", 0))),
        )


class ResearchOutputModel(BaseModel):
    """Research output with content and optional structured data."""

    content: str
    parsed: Optional[Dict[str, Any]] = None


class ResearchTaskModel(BaseModel):
    """Stable output model for research tasks."""

    research_id: str
    created_at: int
    model: str
    instructions: str
    status: str
    output_schema: Optional[Dict[str, Any]] = None
    output: Optional[ResearchOutputModel] = None
    cost_dollars: Optional[ResearchCostModel] = None
    finished_at: Optional[int] = None
    error: Optional[str] = None

    @classmethod
    def from_api(cls, data: dict) -> "ResearchTaskModel":
        """Convert API response to our stable model."""
        output_data = data.get("output")
        output = None
        if output_data:
            output = ResearchOutputModel(
                content=output_data.get("content", ""),
                parsed=output_data.get("parsed"),
            )

        cost_data = data.get("costDollars")
        cost = None
        if cost_data:
            cost = ResearchCostModel.from_api(cost_data)

        return cls(
            research_id=data.get("researchId", ""),
            created_at=data.get("createdAt", 0),
            model=data.get("model", "exa-research"),
            instructions=data.get("instructions", ""),
            status=data.get("status", "pending"),
            output_schema=data.get("outputSchema"),
            output=output,
            cost_dollars=cost,
            finished_at=data.get("finishedAt"),
            error=data.get("error"),
        )


class ExaCreateResearchBlock(Block):
    """Create an asynchronous research task that explores the web and synthesizes findings."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        instructions: str = SchemaField(
            description="Research instructions - clearly define what information to find, how to conduct research, and desired output format.",
            placeholder="Research the top 5 AI coding assistants, their features, pricing, and user reviews",
        )
        model: ResearchModel = SchemaField(
            default=ResearchModel.STANDARD,
            description="Research model: 'fast' for quick results, 'standard' for balanced quality, 'pro' for thorough analysis",
        )
        output_schema: Optional[dict] = SchemaField(
            default=None,
            description="JSON Schema to enforce structured output. When provided, results are validated and returned as parsed JSON.",
            advanced=True,
        )
        wait_for_completion: bool = SchemaField(
            default=True,
            description="Wait for research to complete before returning. Ensures you get results immediately.",
        )
        polling_timeout: int = SchemaField(
            default=600,
            description="Maximum time to wait for completion in seconds (only if wait_for_completion is True)",
            advanced=True,
            ge=1,
            le=3600,
        )

    class Output(BlockSchemaOutput):
        research_id: str = SchemaField(
            description="Unique identifier for tracking this research request"
        )
        status: str = SchemaField(description="Final status of the research")
        model: str = SchemaField(description="The research model used")
        instructions: str = SchemaField(
            description="The research instructions provided"
        )
        created_at: int = SchemaField(
            description="When the research was created (Unix timestamp in ms)"
        )
        output_content: Optional[str] = SchemaField(
            description="Research output as text (only if wait_for_completion was True and completed)"
        )
        output_parsed: Optional[dict] = SchemaField(
            description="Structured JSON output (only if wait_for_completion and outputSchema were provided)"
        )
        cost_total: Optional[float] = SchemaField(
            description="Total cost in USD (only if wait_for_completion was True and completed)"
        )
        elapsed_time: Optional[float] = SchemaField(
            description="Time taken to complete in seconds (only if wait_for_completion was True)"
        )

    def __init__(self):
        super().__init__(
            id="a1f2e3d4-c5b6-4a78-9012-3456789abcde",
            description="Create research task with optional waiting - explores web and synthesizes findings with citations",
            categories={BlockCategory.SEARCH, BlockCategory.AI},
            input_schema=ExaCreateResearchBlock.Input,
            output_schema=ExaCreateResearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/research/v1"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        payload: Dict[str, Any] = {
            "model": input_data.model.value,
            "instructions": input_data.instructions,
        }

        if input_data.output_schema:
            payload["outputSchema"] = input_data.output_schema

        response = await Requests().post(url, headers=headers, json=payload)
        data = response.json()

        research_id = data.get("researchId", "")

        if input_data.wait_for_completion:
            start_time = time.time()
            get_url = f"https://api.exa.ai/research/v1/{research_id}"
            get_headers = {"x-api-key": credentials.api_key.get_secret_value()}
            check_interval = 10

            while time.time() - start_time < input_data.polling_timeout:
                poll_response = await Requests().get(url=get_url, headers=get_headers)
                poll_data = poll_response.json()

                status = poll_data.get("status", "")

                if status in ["completed", "failed", "canceled"]:
                    elapsed = time.time() - start_time
                    research = ResearchTaskModel.from_api(poll_data)

                    yield "research_id", research.research_id
                    yield "status", research.status
                    yield "model", research.model
                    yield "instructions", research.instructions
                    yield "created_at", research.created_at
                    yield "elapsed_time", elapsed

                    if research.output:
                        yield "output_content", research.output.content
                        yield "output_parsed", research.output.parsed

                    if research.cost_dollars:
                        yield "cost_total", research.cost_dollars.total
                    return

                await asyncio.sleep(check_interval)

            raise ValueError(
                f"Research did not complete within {input_data.polling_timeout} seconds"
            )
        else:
            yield "research_id", research_id
            yield "status", data.get("status", "pending")
            yield "model", data.get("model", input_data.model.value)
            yield "instructions", data.get("instructions", input_data.instructions)
            yield "created_at", data.get("createdAt", 0)

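The block above wraps a plain REST pattern: POST to create the task, then GET the same resource until it reaches a terminal state. A compact sketch of that flow outside the block machinery (placeholder API key; endpoint and status values as used above, shown with the plain requests library):

```python
import time

import requests  # sketch only; the block uses the platform's async Requests wrapper

API = "https://api.exa.ai/research/v1"
HEADERS = {"x-api-key": "YOUR-EXA-API-KEY", "Content-Type": "application/json"}

created = requests.post(
    API,
    headers=HEADERS,
    json={"model": "exa-research", "instructions": "Summarize recent LLM papers"},
).json()
research_id = created.get("researchId", "")

# Poll until the task reaches a terminal state, as the block does.
while True:
    task = requests.get(f"{API}/{research_id}", headers=HEADERS).json()
    if task.get("status") in ["completed", "failed", "canceled"]:
        break
    time.sleep(10)

print(task.get("status"), (task.get("output") or {}).get("content", ""))
```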
class ExaGetResearchBlock(Block):
|
||||
"""Get the status and results of a research task."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
research_id: str = SchemaField(
|
||||
description="The ID of the research task to retrieve",
|
||||
placeholder="01jszdfs0052sg4jc552sg4jc5",
|
||||
)
|
||||
        include_events: bool = SchemaField(
            default=False,
            description="Include detailed event log of research operations",
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        research_id: str = SchemaField(description="The research task identifier")
        status: str = SchemaField(
            description="Current status: pending, running, completed, canceled, or failed"
        )
        instructions: str = SchemaField(
            description="The original research instructions"
        )
        model: str = SchemaField(description="The research model used")
        created_at: int = SchemaField(
            description="When research was created (Unix timestamp in ms)"
        )
        finished_at: Optional[int] = SchemaField(
            description="When research finished (Unix timestamp in ms, if completed/canceled/failed)"
        )
        output_content: Optional[str] = SchemaField(
            description="Research output as text (if completed)"
        )
        output_parsed: Optional[dict] = SchemaField(
            description="Structured JSON output matching outputSchema (if provided and completed)"
        )
        cost_total: Optional[float] = SchemaField(
            description="Total cost in USD (if completed)"
        )
        cost_searches: Optional[int] = SchemaField(
            description="Number of searches performed (if completed)"
        )
        cost_pages: Optional[int] = SchemaField(
            description="Number of pages crawled (if completed)"
        )
        cost_reasoning_tokens: Optional[int] = SchemaField(
            description="AI tokens used for reasoning (if completed)"
        )
        error_message: Optional[str] = SchemaField(
            description="Error message if research failed"
        )
        events: Optional[List[dict]] = SchemaField(
            description="Detailed event log (if include_events was True)"
        )

    def __init__(self):
        super().__init__(
            id="b2e3f4a5-6789-4bcd-9012-3456789abcde",
            description="Get status and results of a research task",
            categories={BlockCategory.SEARCH},
            input_schema=ExaGetResearchBlock.Input,
            output_schema=ExaGetResearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = f"https://api.exa.ai/research/v1/{input_data.research_id}"
        headers = {
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        params = {}
        if input_data.include_events:
            params["events"] = "true"

        response = await Requests().get(url, headers=headers, params=params)
        data = response.json()

        research = ResearchTaskModel.from_api(data)

        yield "research_id", research.research_id
        yield "status", research.status
        yield "instructions", research.instructions
        yield "model", research.model
        yield "created_at", research.created_at
        yield "finished_at", research.finished_at

        if research.output:
            yield "output_content", research.output.content
            yield "output_parsed", research.output.parsed

        if research.cost_dollars:
            yield "cost_total", research.cost_dollars.total
            yield "cost_searches", research.cost_dollars.num_searches
            yield "cost_pages", research.cost_dollars.num_pages
            yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens

        yield "error_message", research.error

        if input_data.include_events:
            yield "events", data.get("events", [])


class ExaWaitForResearchBlock(Block):
    """Wait for a research task to complete with progress tracking."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        research_id: str = SchemaField(
            description="The ID of the research task to wait for",
            placeholder="01jszdfs0052sg4jc552sg4jc5",
        )
        timeout: int = SchemaField(
            default=600,
            description="Maximum time to wait in seconds",
            ge=1,
            le=3600,
        )
        check_interval: int = SchemaField(
            default=10,
            description="Seconds between status checks",
            advanced=True,
            ge=1,
            le=60,
        )

    class Output(BlockSchemaOutput):
        research_id: str = SchemaField(description="The research task identifier")
        final_status: str = SchemaField(description="Final status when polling stopped")
        output_content: Optional[str] = SchemaField(
            description="Research output as text (if completed)"
        )
        output_parsed: Optional[dict] = SchemaField(
            description="Structured JSON output (if outputSchema was provided and completed)"
        )
        cost_total: Optional[float] = SchemaField(description="Total cost in USD")
        elapsed_time: float = SchemaField(description="Total time waited in seconds")
        timed_out: bool = SchemaField(
            description="Whether polling timed out before completion"
        )

    def __init__(self):
        super().__init__(
            id="c3d4e5f6-7890-4abc-9012-3456789abcde",
            description="Wait for a research task to complete with configurable timeout",
            categories={BlockCategory.SEARCH},
            input_schema=ExaWaitForResearchBlock.Input,
            output_schema=ExaWaitForResearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        start_time = time.time()
        url = f"https://api.exa.ai/research/v1/{input_data.research_id}"
        headers = {
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        while time.time() - start_time < input_data.timeout:
            response = await Requests().get(url, headers=headers)
            data = response.json()

            status = data.get("status", "")

            if status in ["completed", "failed", "canceled"]:
                elapsed = time.time() - start_time
                research = ResearchTaskModel.from_api(data)

                yield "research_id", research.research_id
                yield "final_status", research.status
                yield "elapsed_time", elapsed
                yield "timed_out", False

                if research.output:
                    yield "output_content", research.output.content
                    yield "output_parsed", research.output.parsed

                if research.cost_dollars:
                    yield "cost_total", research.cost_dollars.total

                return

            await asyncio.sleep(input_data.check_interval)

        elapsed = time.time() - start_time
        response = await Requests().get(url, headers=headers)
        data = response.json()

        yield "research_id", input_data.research_id
        yield "final_status", data.get("status", "unknown")
        yield "elapsed_time", elapsed
        yield "timed_out", True

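# Note on the fallthrough above: when the wait loop times out, the block makes
# one final GET so that "final_status" reflects the freshest server-side state
# rather than the result of the last poll before the deadline.

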
class ExaListResearchBlock(Block):
    """List all research tasks with pagination support."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        cursor: Optional[str] = SchemaField(
            default=None,
            description="Cursor for pagination through results",
            advanced=True,
        )
        limit: int = SchemaField(
            default=10,
            description="Number of research tasks to return (1-50)",
            ge=1,
            le=50,
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        research_tasks: List[ResearchTaskModel] = SchemaField(
            description="List of research tasks ordered by creation time (newest first)"
        )
        research_task: ResearchTaskModel = SchemaField(
            description="Individual research task (yielded for each task)"
        )
        has_more: bool = SchemaField(
            description="Whether there are more tasks to paginate through"
        )
        next_cursor: Optional[str] = SchemaField(
            description="Cursor for the next page of results"
        )

    def __init__(self):
        super().__init__(
            id="d4e5f6a7-8901-4bcd-9012-3456789abcde",
            description="List all research tasks with pagination support",
            categories={BlockCategory.SEARCH},
            input_schema=ExaListResearchBlock.Input,
            output_schema=ExaListResearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/research/v1"
        headers = {
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        params: Dict[str, Any] = {
            "limit": input_data.limit,
        }
        if input_data.cursor:
            params["cursor"] = input_data.cursor

        response = await Requests().get(url, headers=headers, params=params)
        data = response.json()

        tasks = [ResearchTaskModel.from_api(task) for task in data.get("data", [])]

        yield "research_tasks", tasks

        for task in tasks:
            yield "research_task", task

        yield "has_more", data.get("hasMore", False)
        yield "next_cursor", data.get("nextCursor")
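A minimal pagination sketch against the same endpoint, for anyone driving it by
hand. It follows the cursor/hasMore/nextCursor fields used by the block above;
the httpx client and function name are illustrative assumptions, not part of
the blocks:

import httpx

async def list_all_research(api_key: str, limit: int = 50) -> list[dict]:
    # Follow nextCursor until hasMore is False (sketch, not production code).
    tasks: list[dict] = []
    cursor: str | None = None
    async with httpx.AsyncClient() as client:
        while True:
            params: dict = {"limit": limit}
            if cursor:
                params["cursor"] = cursor
            resp = await client.get(
                "https://api.exa.ai/research/v1",
                headers={"x-api-key": api_key},
                params=params,
            )
            data = resp.json()
            tasks.extend(data.get("data", []))
            if not data.get("hasMore", False):
                return tasks
            cursor = data.get("nextCursor")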
@@ -1,32 +1,66 @@
 from datetime import datetime
+from enum import Enum
+from typing import Optional

+from exa_py import AsyncExa
+
 from backend.sdk import (
     APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaInput,
+    BlockSchemaOutput,
     CredentialsMetaInput,
     Requests,
     SchemaField,
 )

 from ._config import exa
-from .helpers import ContentSettings
+from .helpers import (
+    ContentSettings,
+    CostDollars,
+    ExaSearchResults,
+    process_contents_settings,
+)
+
+
+class ExaSearchTypes(Enum):
+    KEYWORD = "keyword"
+    NEURAL = "neural"
+    FAST = "fast"
+    AUTO = "auto"
+
+
+class ExaSearchCategories(Enum):
+    COMPANY = "company"
+    RESEARCH_PAPER = "research paper"
+    NEWS = "news"
+    PDF = "pdf"
+    GITHUB = "github"
+    TWEET = "tweet"
+    PERSONAL_SITE = "personal site"
+    LINKEDIN_PROFILE = "linkedin profile"
+    FINANCIAL_REPORT = "financial report"


 class ExaSearchBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = exa.credentials_field(
             description="The Exa integration requires an API Key."
         )
         query: str = SchemaField(description="The search query")
-        use_auto_prompt: bool = SchemaField(
-            description="Whether to use autoprompt", default=True, advanced=True
+        type: ExaSearchTypes = SchemaField(
+            description="Type of search", default=ExaSearchTypes.AUTO, advanced=True
         )
-        type: str = SchemaField(description="Type of search", default="", advanced=True)
-        category: str = SchemaField(
-            description="Category to search within", default="", advanced=True
+        category: ExaSearchCategories | None = SchemaField(
+            description="Category to search within: company, research paper, news, pdf, github, tweet, personal site, linkedin profile, financial report",
+            default=None,
+            advanced=True,
         )
+        user_location: str | None = SchemaField(
+            description="The two-letter ISO country code of the user (e.g., 'US')",
+            default=None,
+            advanced=True,
+        )
         number_of_results: int = SchemaField(
             description="Number of results to return", default=10, advanced=True
@@ -39,17 +73,17 @@ class ExaSearchBlock(Block):
             default_factory=list,
             advanced=True,
         )
-        start_crawl_date: datetime = SchemaField(
-            description="Start date for crawled content"
+        start_crawl_date: datetime | None = SchemaField(
+            description="Start date for crawled content", advanced=True, default=None
         )
-        end_crawl_date: datetime = SchemaField(
-            description="End date for crawled content"
+        end_crawl_date: datetime | None = SchemaField(
+            description="End date for crawled content", advanced=True, default=None
         )
-        start_published_date: datetime = SchemaField(
-            description="Start date for published content"
+        start_published_date: datetime | None = SchemaField(
+            description="Start date for published content", advanced=True, default=None
         )
-        end_published_date: datetime = SchemaField(
-            description="End date for published content"
+        end_published_date: datetime | None = SchemaField(
+            description="End date for published content", advanced=True, default=None
         )
         include_text: list[str] = SchemaField(
             description="Text patterns to include", default_factory=list, advanced=True
@@ -62,14 +96,30 @@ class ExaSearchBlock(Block):
             default=ContentSettings(),
             advanced=True,
         )
+        moderation: bool = SchemaField(
+            description="Enable content moderation to filter unsafe content from search results",
+            default=False,
+            advanced=True,
+        )

-    class Output(BlockSchema):
-        results: list = SchemaField(
-            description="List of search results", default_factory=list
+    class Output(BlockSchemaOutput):
+        results: list[ExaSearchResults] = SchemaField(
+            description="List of search results"
         )
-        error: str = SchemaField(
-            description="Error message if the request failed",
+        result: ExaSearchResults = SchemaField(description="Single search result")
+        context: str = SchemaField(
+            description="A formatted string of the search results ready for LLMs."
         )
+        search_type: str = SchemaField(
+            description="For auto searches, indicates which search type was selected."
+        )
+        resolved_search_type: str = SchemaField(
+            description="The search type that was actually used for this request (neural or keyword)"
+        )
+        cost_dollars: Optional[CostDollars] = SchemaField(
+            description="Cost breakdown for the request"
+        )
+        error: str = SchemaField(description="Error message if the request failed")

     def __init__(self):
         super().__init__(
@@ -83,51 +133,76 @@ class ExaSearchBlock(Block):
     async def run(
         self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
     ) -> BlockOutput:
-        url = "https://api.exa.ai/search"
-        headers = {
-            "Content-Type": "application/json",
-            "x-api-key": credentials.api_key.get_secret_value(),
-        }
-
-        payload = {
+        sdk_kwargs = {
             "query": input_data.query,
-            "useAutoprompt": input_data.use_auto_prompt,
-            "numResults": input_data.number_of_results,
-            "contents": input_data.contents.model_dump(),
+            "num_results": input_data.number_of_results,
         }

-        date_field_mapping = {
-            "start_crawl_date": "startCrawlDate",
-            "end_crawl_date": "endCrawlDate",
-            "start_published_date": "startPublishedDate",
-            "end_published_date": "endPublishedDate",
-        }
+        if input_data.type:
+            sdk_kwargs["type"] = input_data.type.value

-        # Add dates if they exist
-        for input_field, api_field in date_field_mapping.items():
-            value = getattr(input_data, input_field, None)
-            if value:
-                payload[api_field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")
+        if input_data.category:
+            sdk_kwargs["category"] = input_data.category.value

-        optional_field_mapping = {
-            "type": "type",
-            "category": "category",
-            "include_domains": "includeDomains",
-            "exclude_domains": "excludeDomains",
-            "include_text": "includeText",
-            "exclude_text": "excludeText",
-        }
+        if input_data.user_location:
+            sdk_kwargs["user_location"] = input_data.user_location

-        # Add other fields
-        for input_field, api_field in optional_field_mapping.items():
-            value = getattr(input_data, input_field)
-            if value:  # Only add non-empty values
-                payload[api_field] = value
+        # Handle domains
+        if input_data.include_domains:
+            sdk_kwargs["include_domains"] = input_data.include_domains
+        if input_data.exclude_domains:
+            sdk_kwargs["exclude_domains"] = input_data.exclude_domains

-        try:
-            response = await Requests().post(url, headers=headers, json=payload)
-            data = response.json()
-            # Extract just the results array from the response
-            yield "results", data.get("results", [])
-        except Exception as e:
-            yield "error", str(e)
+        # Handle dates
+        if input_data.start_crawl_date:
+            sdk_kwargs["start_crawl_date"] = input_data.start_crawl_date.isoformat()
+        if input_data.end_crawl_date:
+            sdk_kwargs["end_crawl_date"] = input_data.end_crawl_date.isoformat()
+        if input_data.start_published_date:
+            sdk_kwargs["start_published_date"] = (
+                input_data.start_published_date.isoformat()
+            )
+        if input_data.end_published_date:
+            sdk_kwargs["end_published_date"] = input_data.end_published_date.isoformat()
+
+        # Handle text filters
+        if input_data.include_text:
+            sdk_kwargs["include_text"] = input_data.include_text
+        if input_data.exclude_text:
+            sdk_kwargs["exclude_text"] = input_data.exclude_text
+
+        if input_data.moderation:
+            sdk_kwargs["moderation"] = input_data.moderation
+
+        # check if we need to use search_and_contents
+        content_settings = process_contents_settings(input_data.contents)
+
+        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
+
+        if content_settings:
+            sdk_kwargs["text"] = content_settings.get("text", False)
+            if "highlights" in content_settings:
+                sdk_kwargs["highlights"] = content_settings["highlights"]
+            if "summary" in content_settings:
+                sdk_kwargs["summary"] = content_settings["summary"]
+            response = await aexa.search_and_contents(**sdk_kwargs)
+        else:
+            response = await aexa.search(**sdk_kwargs)
+
+        converted_results = [
+            ExaSearchResults.from_sdk(sdk_result)
+            for sdk_result in response.results or []
+        ]
+
+        yield "results", converted_results
+        for result in converted_results:
+            yield "result", result
+
+        if response.context:
+            yield "context", response.context
+
+        if response.resolved_search_type:
+            yield "resolved_search_type", response.resolved_search_type
+
+        if response.cost_dollars:
+            yield "cost_dollars", response.cost_dollars

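For reference, the new code path above boils down to this SDK shape (a sketch
using only calls that appear in the diff; the wrapper function and literal
argument values are illustrative):

from exa_py import AsyncExa

async def quick_search(api_key: str, query: str):
    aexa = AsyncExa(api_key=api_key)
    # Mirrors the content_settings branch above: search_and_contents when
    # text/highlights/summary are requested, plain search otherwise.
    response = await aexa.search_and_contents(
        query=query,
        type="auto",
        num_results=10,
        text=True,
    )
    return response.results or []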
@@ -1,23 +1,30 @@
 from datetime import datetime
-from typing import Any
+from typing import Optional

+from exa_py import AsyncExa
+
 from backend.sdk import (
     APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaInput,
+    BlockSchemaOutput,
     CredentialsMetaInput,
     Requests,
     SchemaField,
 )

 from ._config import exa
-from .helpers import ContentSettings
+from .helpers import (
+    ContentSettings,
+    CostDollars,
+    ExaSearchResults,
+    process_contents_settings,
+)


 class ExaFindSimilarBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = exa.credentials_field(
             description="The Exa integration requires an API Key."
         )
@@ -28,7 +35,7 @@ class ExaFindSimilarBlock(Block):
             description="Number of results to return", default=10, advanced=True
         )
         include_domains: list[str] = SchemaField(
-            description="Domains to include in search",
+            description="List of domains to include in the search. If specified, results will only come from these domains.",
             default_factory=list,
             advanced=True,
         )
@@ -37,17 +44,17 @@ class ExaFindSimilarBlock(Block):
             default_factory=list,
             advanced=True,
         )
-        start_crawl_date: datetime = SchemaField(
-            description="Start date for crawled content"
+        start_crawl_date: Optional[datetime] = SchemaField(
+            description="Start date for crawled content", advanced=True, default=None
         )
-        end_crawl_date: datetime = SchemaField(
-            description="End date for crawled content"
+        end_crawl_date: Optional[datetime] = SchemaField(
+            description="End date for crawled content", advanced=True, default=None
        )
-        start_published_date: datetime = SchemaField(
-            description="Start date for published content"
+        start_published_date: Optional[datetime] = SchemaField(
+            description="Start date for published content", advanced=True, default=None
        )
-        end_published_date: datetime = SchemaField(
-            description="End date for published content"
+        end_published_date: Optional[datetime] = SchemaField(
+            description="End date for published content", advanced=True, default=None
        )
         include_text: list[str] = SchemaField(
             description="Text patterns to include (max 1 string, up to 5 words)",
@@ -64,15 +71,27 @@ class ExaFindSimilarBlock(Block):
             default=ContentSettings(),
             advanced=True,
         )
+        moderation: bool = SchemaField(
+            description="Enable content moderation to filter unsafe content from search results",
+            default=False,
+            advanced=True,
+        )

-    class Output(BlockSchema):
-        results: list[Any] = SchemaField(
-            description="List of similar documents with title, URL, published date, author, and score",
-            default_factory=list,
+    class Output(BlockSchemaOutput):
+        results: list[ExaSearchResults] = SchemaField(
+            description="List of similar documents with metadata and content"
         )
-        error: str = SchemaField(
-            description="Error message if the request failed", default=""
+        result: ExaSearchResults = SchemaField(
+            description="Single similar document result"
         )
+        context: str = SchemaField(
+            description="A formatted string of the results ready for LLMs."
+        )
+        request_id: str = SchemaField(description="Unique identifier for the request")
+        cost_dollars: Optional[CostDollars] = SchemaField(
+            description="Cost breakdown for the request"
+        )
+        error: str = SchemaField(description="Error message if the request failed")

     def __init__(self):
         super().__init__(
@@ -86,47 +105,65 @@ class ExaFindSimilarBlock(Block):
     async def run(
         self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
     ) -> BlockOutput:
-        url = "https://api.exa.ai/findSimilar"
-        headers = {
-            "Content-Type": "application/json",
-            "x-api-key": credentials.api_key.get_secret_value(),
-        }
-
-        payload = {
+        sdk_kwargs = {
             "url": input_data.url,
-            "numResults": input_data.number_of_results,
-            "contents": input_data.contents.model_dump(),
+            "num_results": input_data.number_of_results,
         }

-        optional_field_mapping = {
-            "include_domains": "includeDomains",
-            "exclude_domains": "excludeDomains",
-            "include_text": "includeText",
-            "exclude_text": "excludeText",
-        }
+        # Handle domains
+        if input_data.include_domains:
+            sdk_kwargs["include_domains"] = input_data.include_domains
+        if input_data.exclude_domains:
+            sdk_kwargs["exclude_domains"] = input_data.exclude_domains

-        # Add optional fields if they have values
-        for input_field, api_field in optional_field_mapping.items():
-            value = getattr(input_data, input_field)
-            if value:  # Only add non-empty values
-                payload[api_field] = value
+        # Handle dates
+        if input_data.start_crawl_date:
+            sdk_kwargs["start_crawl_date"] = input_data.start_crawl_date.isoformat()
+        if input_data.end_crawl_date:
+            sdk_kwargs["end_crawl_date"] = input_data.end_crawl_date.isoformat()
+        if input_data.start_published_date:
+            sdk_kwargs["start_published_date"] = (
+                input_data.start_published_date.isoformat()
+            )
+        if input_data.end_published_date:
+            sdk_kwargs["end_published_date"] = input_data.end_published_date.isoformat()

-        date_field_mapping = {
-            "start_crawl_date": "startCrawlDate",
-            "end_crawl_date": "endCrawlDate",
-            "start_published_date": "startPublishedDate",
-            "end_published_date": "endPublishedDate",
-        }
+        # Handle text filters
+        if input_data.include_text:
+            sdk_kwargs["include_text"] = input_data.include_text
+        if input_data.exclude_text:
+            sdk_kwargs["exclude_text"] = input_data.exclude_text

-        # Add dates if they exist
-        for input_field, api_field in date_field_mapping.items():
-            value = getattr(input_data, input_field, None)
-            if value:
-                payload[api_field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")
+        if input_data.moderation:
+            sdk_kwargs["moderation"] = input_data.moderation

-        try:
-            response = await Requests().post(url, headers=headers, json=payload)
-            data = response.json()
-            yield "results", data.get("results", [])
-        except Exception as e:
-            yield "error", str(e)
+        # check if we need to use find_similar_and_contents
+        content_settings = process_contents_settings(input_data.contents)
+
+        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
+
+        if content_settings:
+            # Use find_similar_and_contents when contents are requested
+            sdk_kwargs["text"] = content_settings.get("text", False)
+            if "highlights" in content_settings:
+                sdk_kwargs["highlights"] = content_settings["highlights"]
+            if "summary" in content_settings:
+                sdk_kwargs["summary"] = content_settings["summary"]
+            response = await aexa.find_similar_and_contents(**sdk_kwargs)
+        else:
+            response = await aexa.find_similar(**sdk_kwargs)
+
+        converted_results = [
+            ExaSearchResults.from_sdk(sdk_result)
+            for sdk_result in response.results or []
+        ]
+
+        yield "results", converted_results
+        for result in converted_results:
+            yield "result", result
+
+        if response.context:
+            yield "context", response.context
+
+        if response.cost_dollars:
+            yield "cost_dollars", response.cost_dollars

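One behavioral nuance of the date handling change in both diffs above: the old
strftime("%Y-%m-%dT%H:%M:%S.000Z") always appended a literal "Z", while
isoformat() only carries an offset when the datetime is timezone-aware. A quick
illustration:

from datetime import datetime, timezone

naive = datetime(2024, 1, 1, 12, 30)
aware = naive.replace(tzinfo=timezone.utc)

print(naive.strftime("%Y-%m-%dT%H:%M:%S.000Z"))  # 2024-01-01T12:30:00.000Z
print(naive.isoformat())                          # 2024-01-01T12:30:00
print(aware.isoformat())                          # 2024-01-01T12:30:00+00:00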
@@ -9,7 +9,8 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaInput,
+    BlockSchemaOutput,
     BlockType,
     BlockWebhookConfig,
     CredentialsMetaInput,
@@ -84,7 +85,7 @@ class ExaWebsetWebhookBlock(Block):
     including creation, updates, searches, and exports.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = exa.credentials_field(
             description="Exa API credentials for webhook management"
         )
@@ -104,7 +105,7 @@ class ExaWebsetWebhookBlock(Block):
             description="Webhook payload data", default={}, hidden=True
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         event_type: str = SchemaField(description="Type of event that occurred")
         event_id: str = SchemaField(description="Unique identifier for this event")
         webset_id: str = SchemaField(description="ID of the affected webset")
@@ -131,45 +132,33 @@ class ExaWebsetWebhookBlock(Block):

     async def run(self, input_data: Input, **kwargs) -> BlockOutput:
         """Process incoming Exa webhook payload."""
-        try:
-            payload = input_data.payload
+        payload = input_data.payload

-            # Extract event details
-            event_type = payload.get("eventType", "unknown")
-            event_id = payload.get("eventId", "")
+        # Extract event details
+        event_type = payload.get("eventType", "unknown")
+        event_id = payload.get("eventId", "")

-            # Get webset ID from payload or input
-            webset_id = payload.get("websetId", input_data.webset_id)
+        # Get webset ID from payload or input
+        webset_id = payload.get("websetId", input_data.webset_id)

-            # Check if we should process this event based on filter
-            should_process = self._should_process_event(
-                event_type, input_data.event_filter
-            )
+        # Check if we should process this event based on filter
+        should_process = self._should_process_event(event_type, input_data.event_filter)

-            if not should_process:
-                # Skip events that don't match our filter
-                return
+        if not should_process:
+            # Skip events that don't match our filter
+            return

-            # Extract event data
-            event_data = payload.get("data", {})
-            timestamp = payload.get("occurredAt", payload.get("createdAt", ""))
-            metadata = payload.get("metadata", {})
+        # Extract event data
+        event_data = payload.get("data", {})
+        timestamp = payload.get("occurredAt", payload.get("createdAt", ""))
+        metadata = payload.get("metadata", {})

-            yield "event_type", event_type
-            yield "event_id", event_id
-            yield "webset_id", webset_id
-            yield "data", event_data
-            yield "timestamp", timestamp
-            yield "metadata", metadata
-
-        except Exception as e:
-            # Handle errors gracefully
-            yield "event_type", "error"
-            yield "event_id", ""
-            yield "webset_id", input_data.webset_id
-            yield "data", {"error": str(e)}
-            yield "timestamp", ""
-            yield "metadata", {}
+        yield "event_type", event_type
+        yield "event_id", event_id
+        yield "webset_id", webset_id
+        yield "data", event_data
+        yield "timestamp", timestamp
+        yield "metadata", metadata

     def _should_process_event(
         self, event_type: str, event_filter: WebsetEventFilter

File diff suppressed because it is too large
@@ -0,0 +1,554 @@
"""
Exa Websets Enrichment Management Blocks

This module provides blocks for creating and managing enrichments on webset items,
allowing extraction of additional structured data from existing items.
"""

from enum import Enum
from typing import Any, Dict, List, Optional

from exa_py import AsyncExa
from exa_py.websets.types import WebsetEnrichment as SdkWebsetEnrichment
from pydantic import BaseModel

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)

from ._config import exa


# Mirrored model for stability
class WebsetEnrichmentModel(BaseModel):
    """Stable output model mirroring SDK WebsetEnrichment."""

    id: str
    webset_id: str
    status: str
    title: Optional[str]
    description: str
    format: str
    options: List[str]
    instructions: Optional[str]
    metadata: Dict[str, Any]
    created_at: str
    updated_at: str

    @classmethod
    def from_sdk(cls, enrichment: SdkWebsetEnrichment) -> "WebsetEnrichmentModel":
        """Convert SDK WebsetEnrichment to our stable model."""
        # Extract options
        options_list = []
        if enrichment.options:
            for option in enrichment.options:
                option_dict = option.model_dump(by_alias=True)
                options_list.append(option_dict.get("label", ""))

        return cls(
            id=enrichment.id,
            webset_id=enrichment.webset_id,
            status=(
                enrichment.status.value
                if hasattr(enrichment.status, "value")
                else str(enrichment.status)
            ),
            title=enrichment.title,
            description=enrichment.description,
            format=(
                enrichment.format.value
                if enrichment.format and hasattr(enrichment.format, "value")
                else "text"
            ),
            options=options_list,
            instructions=enrichment.instructions,
            metadata=enrichment.metadata if enrichment.metadata else {},
            created_at=(
                enrichment.created_at.isoformat() if enrichment.created_at else ""
            ),
            updated_at=(
                enrichment.updated_at.isoformat() if enrichment.updated_at else ""
            ),
        )

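# Illustrative usage (not part of the module): round-tripping an SDK object
# through the mirrored model keeps block outputs stable even if the SDK's
# types change shape. The `aexa` client and IDs below are assumed to exist:
#
#     sdk_enrichment = aexa.websets.enrichments.get(webset_id=ws_id, id=enr_id)
#     stable = WebsetEnrichmentModel.from_sdk(sdk_enrichment)
#     stable.model_dump()  # plain dict, safe to expose downstream

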
class EnrichmentFormat(str, Enum):
    """Format types for enrichment responses."""

    TEXT = "text"  # Free text response
    DATE = "date"  # Date/datetime format
    NUMBER = "number"  # Numeric value
    OPTIONS = "options"  # Multiple choice from provided options
    EMAIL = "email"  # Email address format
    PHONE = "phone"  # Phone number format


class ExaCreateEnrichmentBlock(Block):
    """Create a new enrichment to extract additional data from webset items."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        description: str = SchemaField(
            description="What data to extract from each item",
            placeholder="Extract the company's main product or service offering",
        )
        title: Optional[str] = SchemaField(
            default=None,
            description="Short title for this enrichment (auto-generated if not provided)",
            placeholder="Main Product",
        )
        format: EnrichmentFormat = SchemaField(
            default=EnrichmentFormat.TEXT,
            description="Expected format of the extracted data",
        )
        options: list[str] = SchemaField(
            default_factory=list,
            description="Available options when format is 'options'",
            placeholder='["B2B", "B2C", "Both", "Unknown"]',
            advanced=True,
        )
        apply_to_existing: bool = SchemaField(
            default=True,
            description="Apply this enrichment to existing items in the webset",
        )
        metadata: Optional[dict] = SchemaField(
            default=None,
            description="Metadata to attach to the enrichment",
            advanced=True,
        )
        wait_for_completion: bool = SchemaField(
            default=False,
            description="Wait for the enrichment to complete on existing items",
        )
        polling_timeout: int = SchemaField(
            default=300,
            description="Maximum time to wait for completion in seconds",
            advanced=True,
            ge=1,
            le=600,
        )

    class Output(BlockSchemaOutput):
        enrichment_id: str = SchemaField(
            description="The unique identifier for the created enrichment"
        )
        webset_id: str = SchemaField(
            description="The webset this enrichment belongs to"
        )
        status: str = SchemaField(description="Current status of the enrichment")
        title: str = SchemaField(description="Title of the enrichment")
        description: str = SchemaField(
            description="Description of what data is extracted"
        )
        format: str = SchemaField(description="Format of the extracted data")
        instructions: str = SchemaField(
            description="Generated instructions for the enrichment"
        )
        items_enriched: Optional[int] = SchemaField(
            description="Number of items enriched (if wait_for_completion was True)"
        )
        completion_time: Optional[float] = SchemaField(
            description="Time taken to complete in seconds (if wait_for_completion was True)"
        )

    def __init__(self):
        super().__init__(
            id="71146ae8-0cb1-4a15-8cde-eae30de71cb6",
            description="Create enrichments to extract additional structured data from webset items",
            categories={BlockCategory.AI, BlockCategory.SEARCH},
            input_schema=ExaCreateEnrichmentBlock.Input,
            output_schema=ExaCreateEnrichmentBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        import time

        # Build the payload
        payload: dict[str, Any] = {
            "description": input_data.description,
            "format": input_data.format.value,
        }

        # Add title if provided
        if input_data.title:
            payload["title"] = input_data.title

        # Add options for 'options' format
        if input_data.format == EnrichmentFormat.OPTIONS and input_data.options:
            payload["options"] = [{"label": opt} for opt in input_data.options]

        # Add metadata if provided
        if input_data.metadata:
            payload["metadata"] = input_data.metadata

        start_time = time.time()

        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_enrichment = aexa.websets.enrichments.create(
            webset_id=input_data.webset_id, params=payload
        )

        enrichment_id = sdk_enrichment.id
        status = (
            sdk_enrichment.status.value
            if hasattr(sdk_enrichment.status, "value")
            else str(sdk_enrichment.status)
        )

        # If wait_for_completion is True and apply_to_existing is True, poll for completion
        if input_data.wait_for_completion and input_data.apply_to_existing:
            import asyncio

            poll_interval = 5
            max_interval = 30
            poll_start = time.time()
            items_enriched = 0

            while time.time() - poll_start < input_data.polling_timeout:
                current_enrich = aexa.websets.enrichments.get(
                    webset_id=input_data.webset_id, id=enrichment_id
                )
                current_status = (
                    current_enrich.status.value
                    if hasattr(current_enrich.status, "value")
                    else str(current_enrich.status)
                )

                if current_status in ["completed", "failed", "cancelled"]:
                    # Estimate items from webset searches
                    webset = aexa.websets.get(id=input_data.webset_id)
                    if webset.searches:
                        for search in webset.searches:
                            if search.progress:
                                items_enriched += search.progress.found
                    completion_time = time.time() - start_time

                    yield "enrichment_id", enrichment_id
                    yield "webset_id", input_data.webset_id
                    yield "status", current_status
                    yield "title", sdk_enrichment.title
                    yield "description", input_data.description
                    yield "format", input_data.format.value
                    yield "instructions", sdk_enrichment.instructions
                    yield "items_enriched", items_enriched
                    yield "completion_time", completion_time
                    return

                await asyncio.sleep(poll_interval)
                poll_interval = min(poll_interval * 1.5, max_interval)

            # Timeout
            completion_time = time.time() - start_time
            yield "enrichment_id", enrichment_id
            yield "webset_id", input_data.webset_id
            yield "status", status
            yield "title", sdk_enrichment.title
            yield "description", input_data.description
            yield "format", input_data.format.value
            yield "instructions", sdk_enrichment.instructions
            yield "items_enriched", 0
            yield "completion_time", completion_time
        else:
            yield "enrichment_id", enrichment_id
            yield "webset_id", input_data.webset_id
            yield "status", status
            yield "title", sdk_enrichment.title
            yield "description", input_data.description
            yield "format", input_data.format.value
            yield "instructions", sdk_enrichment.instructions


class ExaGetEnrichmentBlock(Block):
    """Get the status and details of a webset enrichment."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        enrichment_id: str = SchemaField(
            description="The ID of the enrichment to retrieve",
            placeholder="enrichment-id",
        )

    class Output(BlockSchemaOutput):
        enrichment_id: str = SchemaField(
            description="The unique identifier for the enrichment"
        )
        status: str = SchemaField(description="Current status of the enrichment")
        title: str = SchemaField(description="Title of the enrichment")
        description: str = SchemaField(
            description="Description of what data is extracted"
        )
        format: str = SchemaField(description="Format of the extracted data")
        options: list[str] = SchemaField(
            description="Available options (for 'options' format)"
        )
        instructions: str = SchemaField(
            description="Generated instructions for the enrichment"
        )
        created_at: str = SchemaField(description="When the enrichment was created")
        updated_at: str = SchemaField(
            description="When the enrichment was last updated"
        )
        metadata: dict = SchemaField(description="Metadata attached to the enrichment")

    def __init__(self):
        super().__init__(
            id="b8c9d0e1-f2a3-4567-89ab-cdef01234567",
            description="Get the status and details of a webset enrichment",
            categories={BlockCategory.SEARCH},
            input_schema=ExaGetEnrichmentBlock.Input,
            output_schema=ExaGetEnrichmentBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_enrichment = aexa.websets.enrichments.get(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )

        enrichment = WebsetEnrichmentModel.from_sdk(sdk_enrichment)

        yield "enrichment_id", enrichment.id
        yield "status", enrichment.status
        yield "title", enrichment.title
        yield "description", enrichment.description
        yield "format", enrichment.format
        yield "options", enrichment.options
        yield "instructions", enrichment.instructions
        yield "created_at", enrichment.created_at
        yield "updated_at", enrichment.updated_at
        yield "metadata", enrichment.metadata


class ExaUpdateEnrichmentBlock(Block):
    """Update an existing enrichment configuration."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        enrichment_id: str = SchemaField(
            description="The ID of the enrichment to update",
            placeholder="enrichment-id",
        )
        description: Optional[str] = SchemaField(
            default=None,
            description="New description for what data to extract",
        )
        format: Optional[EnrichmentFormat] = SchemaField(
            default=None,
            description="New format for the extracted data",
        )
        options: Optional[list[str]] = SchemaField(
            default=None,
            description="New options when format is 'options'",
        )
        metadata: Optional[dict] = SchemaField(
            default=None,
            description="New metadata to attach to the enrichment",
        )

    class Output(BlockSchemaOutput):
        enrichment_id: str = SchemaField(
            description="The unique identifier for the enrichment"
        )
        status: str = SchemaField(description="Current status of the enrichment")
        title: str = SchemaField(description="Title of the enrichment")
        description: str = SchemaField(description="Updated description")
        format: str = SchemaField(description="Updated format")
        success: str = SchemaField(description="Whether the update was successful")

    def __init__(self):
        super().__init__(
            id="c8d5c5fb-9684-4a29-bd2a-5b38d71776c9",
            description="Update an existing enrichment configuration",
            categories={BlockCategory.SEARCH},
            input_schema=ExaUpdateEnrichmentBlock.Input,
            output_schema=ExaUpdateEnrichmentBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}/enrichments/{input_data.enrichment_id}"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        # Build the update payload
        payload = {}

        if input_data.description is not None:
            payload["description"] = input_data.description

        if input_data.format is not None:
            payload["format"] = input_data.format.value

        if input_data.options is not None:
            payload["options"] = [{"label": opt} for opt in input_data.options]

        if input_data.metadata is not None:
            payload["metadata"] = input_data.metadata

        try:
            response = await Requests().patch(url, headers=headers, json=payload)
            data = response.json()

            yield "enrichment_id", data.get("id", "")
            yield "status", data.get("status", "")
            yield "title", data.get("title", "")
            yield "description", data.get("description", "")
            yield "format", data.get("format", "")
            yield "success", "true"

        except ValueError as e:
            # Re-raise user input validation errors
            raise ValueError(f"Failed to update enrichment: {e}") from e
        # Let all other exceptions propagate naturally


class ExaDeleteEnrichmentBlock(Block):
    """Delete an enrichment from a webset."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        enrichment_id: str = SchemaField(
            description="The ID of the enrichment to delete",
            placeholder="enrichment-id",
        )

    class Output(BlockSchemaOutput):
        enrichment_id: str = SchemaField(description="The ID of the deleted enrichment")
        success: str = SchemaField(description="Whether the deletion was successful")

    def __init__(self):
        super().__init__(
            id="b250de56-2ca6-4237-a7b8-b5684892189f",
            description="Delete an enrichment from a webset",
            categories={BlockCategory.SEARCH},
            input_schema=ExaDeleteEnrichmentBlock.Input,
            output_schema=ExaDeleteEnrichmentBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        deleted_enrichment = aexa.websets.enrichments.delete(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )

        yield "enrichment_id", deleted_enrichment.id
        yield "success", "true"


class ExaCancelEnrichmentBlock(Block):
    """Cancel a running enrichment operation."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        enrichment_id: str = SchemaField(
            description="The ID of the enrichment to cancel",
            placeholder="enrichment-id",
        )

    class Output(BlockSchemaOutput):
        enrichment_id: str = SchemaField(
            description="The ID of the canceled enrichment"
        )
        status: str = SchemaField(description="Status after cancellation")
        items_enriched_before_cancel: int = SchemaField(
            description="Approximate number of items enriched before cancellation"
        )
        success: str = SchemaField(
            description="Whether the cancellation was successful"
        )

    def __init__(self):
        super().__init__(
            id="7e1f8f0f-b6ab-43b3-bd1d-0c534a649295",
            description="Cancel a running enrichment operation",
            categories={BlockCategory.SEARCH},
            input_schema=ExaCancelEnrichmentBlock.Input,
            output_schema=ExaCancelEnrichmentBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        canceled_enrichment = aexa.websets.enrichments.cancel(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )

        # Try to estimate how many items were enriched before cancellation
        items_enriched = 0
        items_response = aexa.websets.items.list(
            webset_id=input_data.webset_id, limit=100
        )

        for sdk_item in items_response.data:
            # Check if this enrichment is present
            for enrich_result in sdk_item.enrichments:
                if enrich_result.enrichment_id == input_data.enrichment_id:
                    items_enriched += 1
                    break

        status = (
            canceled_enrichment.status.value
            if hasattr(canceled_enrichment.status, "value")
            else str(canceled_enrichment.status)
        )

        yield "enrichment_id", canceled_enrichment.id
        yield "status", status
        yield "items_enriched_before_cancel", items_enriched
        yield "success", "true"
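A worked example of the options handling used by the create/update blocks
above, which maps plain strings to the API's labeled-object shape (values
illustrative):

options = ["B2B", "B2C", "Both", "Unknown"]
payload_options = [{"label": opt} for opt in options]
# -> [{"label": "B2B"}, {"label": "B2C"}, {"label": "Both"}, {"label": "Unknown"}]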
@@ -0,0 +1,676 @@
"""
Exa Websets Import/Export Management Blocks

This module provides blocks for importing data into websets from CSV files
and exporting webset data in various formats.
"""

import csv
import json
from enum import Enum
from io import StringIO
from typing import Optional, Union

from exa_py import AsyncExa
from exa_py.websets.types import CreateImportResponse
from exa_py.websets.types import Import as SdkImport
from pydantic import BaseModel

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    CredentialsMetaInput,
    SchemaField,
)

from ._config import exa
from ._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT


# Mirrored model for stability - don't use SDK types directly in block outputs
class ImportModel(BaseModel):
    """Stable output model mirroring SDK Import."""

    id: str
    status: str
    title: str
    format: str
    entity_type: str
    count: int
    upload_url: Optional[str]  # Only in CreateImportResponse
    upload_valid_until: Optional[str]  # Only in CreateImportResponse
    failed_reason: str
    failed_message: str
    metadata: dict
    created_at: str
    updated_at: str

    @classmethod
    def from_sdk(
        cls, import_obj: Union[SdkImport, CreateImportResponse]
    ) -> "ImportModel":
        """Convert SDK Import or CreateImportResponse to our stable model."""
        # Extract entity type from union (may be None)
        entity_type = "unknown"
        if import_obj.entity:
            entity_dict = import_obj.entity.model_dump(by_alias=True, exclude_none=True)
            entity_type = entity_dict.get("type", "unknown")

        # Handle status enum
        status_str = (
            import_obj.status.value
            if hasattr(import_obj.status, "value")
            else str(import_obj.status)
        )

        # Handle format enum
        format_str = (
            import_obj.format.value
            if hasattr(import_obj.format, "value")
            else str(import_obj.format)
        )

        # Handle failed_reason enum (may be None or enum)
        failed_reason_str = ""
        if import_obj.failed_reason:
            failed_reason_str = (
                import_obj.failed_reason.value
                if hasattr(import_obj.failed_reason, "value")
                else str(import_obj.failed_reason)
            )

        return cls(
            id=import_obj.id,
            status=status_str,
            title=import_obj.title or "",
            format=format_str,
            entity_type=entity_type,
            count=int(import_obj.count or 0),
            upload_url=getattr(
                import_obj, "upload_url", None
            ),  # Only in CreateImportResponse
            upload_valid_until=getattr(
                import_obj, "upload_valid_until", None
            ),  # Only in CreateImportResponse
            failed_reason=failed_reason_str,
            failed_message=import_obj.failed_message or "",
            metadata=import_obj.metadata or {},
            created_at=(
                import_obj.created_at.isoformat() if import_obj.created_at else ""
            ),
            updated_at=(
                import_obj.updated_at.isoformat() if import_obj.updated_at else ""
            ),
        )


class ImportFormat(str, Enum):
    """Supported import formats."""

    CSV = "csv"
    # JSON = "json"  # Future support


class ImportEntityType(str, Enum):
    """Entity types for imports."""

    COMPANY = "company"
    PERSON = "person"
    ARTICLE = "article"
    RESEARCH_PAPER = "research_paper"
    CUSTOM = "custom"


class ExportFormat(str, Enum):
    """Supported export formats."""

    JSON = "json"
    CSV = "csv"
    JSON_LINES = "jsonl"


class ExaCreateImportBlock(Block):
    """Create an import to load external data that can be used with websets."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        title: str = SchemaField(
            description="Title for this import",
            placeholder="Customer List Import",
        )
        csv_data: str = SchemaField(
            description="CSV data to import (as a string)",
            placeholder="name,url\nAcme Corp,https://acme.com\nExample Inc,https://example.com",
        )
        entity_type: ImportEntityType = SchemaField(
            default=ImportEntityType.COMPANY,
            description="Type of entities being imported",
        )
        entity_description: Optional[str] = SchemaField(
            default=None,
            description="Description for custom entity type",
            advanced=True,
        )
        identifier_column: int = SchemaField(
            default=0,
            description="Column index containing the identifier (0-based)",
            ge=0,
        )
        url_column: Optional[int] = SchemaField(
            default=None,
            description="Column index containing URLs (optional)",
            ge=0,
            advanced=True,
        )
        metadata: Optional[dict] = SchemaField(
            default=None,
            description="Metadata to attach to the import",
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        import_id: str = SchemaField(
            description="The unique identifier for the created import"
        )
        status: str = SchemaField(description="Current status of the import")
        title: str = SchemaField(description="Title of the import")
        count: int = SchemaField(description="Number of items in the import")
        entity_type: str = SchemaField(description="Type of entities imported")
        upload_url: Optional[str] = SchemaField(
            description="Upload URL for CSV data (only if csv_data not provided in request)"
        )
        upload_valid_until: Optional[str] = SchemaField(
            description="Expiration time for upload URL (only if upload_url is provided)"
        )
        created_at: str = SchemaField(description="When the import was created")

    def __init__(self):
        super().__init__(
            id="020a35d8-8a53-4e60-8b60-1de5cbab1df3",
            description="Import CSV data to use with websets for targeted searches",
            categories={BlockCategory.DATA},
            input_schema=ExaCreateImportBlock.Input,
            output_schema=ExaCreateImportBlock.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "title": "Test Import",
                "csv_data": "name,url\nAcme,https://acme.com",
                "entity_type": ImportEntityType.COMPANY,
                "identifier_column": 0,
            },
            test_output=[
                ("import_id", "import-123"),
                ("status", "pending"),
                ("title", "Test Import"),
                ("count", 1),
                ("entity_type", "company"),
                ("upload_url", None),
                ("upload_valid_until", None),
                ("created_at", "2024-01-01T00:00:00"),
            ],
            test_credentials=TEST_CREDENTIALS,
            test_mock=self._create_test_mock(),
        )

    @staticmethod
    def _create_test_mock():
        """Create test mocks for the AsyncExa SDK."""
        from datetime import datetime
        from unittest.mock import MagicMock

        # Create mock SDK import object
        mock_import = MagicMock()
        mock_import.id = "import-123"
        mock_import.status = MagicMock(value="pending")
        mock_import.title = "Test Import"
        mock_import.format = MagicMock(value="csv")
        mock_import.count = 1
        mock_import.upload_url = None
        mock_import.upload_valid_until = None
        mock_import.failed_reason = None
        mock_import.failed_message = ""
        mock_import.metadata = {}
        mock_import.created_at = datetime.fromisoformat("2024-01-01T00:00:00")
        mock_import.updated_at = datetime.fromisoformat("2024-01-01T00:00:00")

        # Mock entity
        mock_entity = MagicMock()
        mock_entity.model_dump = MagicMock(return_value={"type": "company"})
        mock_import.entity = mock_entity

        return {
            "_get_client": lambda *args, **kwargs: MagicMock(
                websets=MagicMock(
                    imports=MagicMock(create=lambda *args, **kwargs: mock_import)
                )
            )
        }

    def _get_client(self, api_key: str) -> AsyncExa:
        """Get Exa client (separated for testing)."""
        return AsyncExa(api_key=api_key)

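    # Note on the pattern above: the mapping returned by _create_test_mock
    # presumably replaces _get_client during block tests, so run() below talks
    # to a MagicMock rather than the live API. Keeping client construction
    # behind this one-line method is what makes that swap possible.
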
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
aexa = self._get_client(credentials.api_key.get_secret_value())
|
||||
|
||||
csv_reader = csv.reader(StringIO(input_data.csv_data))
|
||||
rows = list(csv_reader)
|
||||
count = len(rows) - 1 if len(rows) > 1 else 0
|
||||
|
||||
size = len(input_data.csv_data.encode("utf-8"))
|
||||
|
||||
payload = {
|
||||
"title": input_data.title,
|
||||
"format": ImportFormat.CSV.value,
|
||||
"count": count,
|
||||
"size": size,
|
||||
"csv": {
|
||||
"identifier": input_data.identifier_column,
|
||||
},
|
||||
}
|
||||
|
||||
# Add URL column if specified
|
||||
if input_data.url_column is not None:
|
||||
payload["csv"]["url"] = input_data.url_column
|
||||
|
||||
# Add entity configuration
|
||||
entity = {"type": input_data.entity_type.value}
|
||||
if (
|
||||
input_data.entity_type == ImportEntityType.CUSTOM
|
||||
and input_data.entity_description
|
||||
):
|
||||
entity["description"] = input_data.entity_description
|
||||
payload["entity"] = entity
|
||||
|
||||
# Add metadata if provided
|
||||
if input_data.metadata:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_import = aexa.websets.imports.create(
|
||||
params=payload, csv_data=input_data.csv_data
|
||||
)
|
||||
|
||||
import_obj = ImportModel.from_sdk(sdk_import)
|
||||
|
||||
yield "import_id", import_obj.id
|
||||
yield "status", import_obj.status
|
||||
yield "title", import_obj.title
|
||||
yield "count", import_obj.count
|
||||
yield "entity_type", import_obj.entity_type
|
||||
yield "upload_url", import_obj.upload_url
|
||||
yield "upload_valid_until", import_obj.upload_valid_until
|
||||
yield "created_at", import_obj.created_at
|
||||
|
||||
|
||||
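A minimal sketch (not part of the module) of the count and size math in run() above, under the same header-row convention: the first CSV row is a header and is excluded from count, while size is the UTF-8 byte length of the raw text.

    import csv
    from io import StringIO

    csv_data = "name,url\nAcme,https://acme.com"
    rows = list(csv.reader(StringIO(csv_data)))
    count = len(rows) - 1 if len(rows) > 1 else 0  # header excluded -> 1
    size = len(csv_data.encode("utf-8"))           # bytes, not characters
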
class ExaGetImportBlock(Block):
    """Get the status and details of an import."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        import_id: str = SchemaField(
            description="The ID of the import to retrieve",
            placeholder="import-id",
        )

    class Output(BlockSchemaOutput):
        import_id: str = SchemaField(description="The unique identifier for the import")
        status: str = SchemaField(description="Current status of the import")
        title: str = SchemaField(description="Title of the import")
        format: str = SchemaField(description="Format of the imported data")
        entity_type: str = SchemaField(description="Type of entities imported")
        count: int = SchemaField(description="Number of items imported")
        upload_url: Optional[str] = SchemaField(
            description="Upload URL for CSV data (if import not yet uploaded)"
        )
        upload_valid_until: Optional[str] = SchemaField(
            description="Expiration time for upload URL (if applicable)"
        )
        failed_reason: Optional[str] = SchemaField(
            description="Reason for failure (if applicable)"
        )
        failed_message: Optional[str] = SchemaField(
            description="Detailed failure message (if applicable)"
        )
        created_at: str = SchemaField(description="When the import was created")
        updated_at: str = SchemaField(description="When the import was last updated")
        metadata: dict = SchemaField(description="Metadata attached to the import")

    def __init__(self):
        super().__init__(
            id="236663c8-a8dc-45f7-a050-2676bb0a3dd2",
            description="Get the status and details of an import",
            categories={BlockCategory.DATA},
            input_schema=ExaGetImportBlock.Input,
            output_schema=ExaGetImportBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)

        import_obj = ImportModel.from_sdk(sdk_import)

        # Yield all fields
        yield "import_id", import_obj.id
        yield "status", import_obj.status
        yield "title", import_obj.title
        yield "format", import_obj.format
        yield "entity_type", import_obj.entity_type
        yield "count", import_obj.count
        yield "upload_url", import_obj.upload_url
        yield "upload_valid_until", import_obj.upload_valid_until
        yield "failed_reason", import_obj.failed_reason
        yield "failed_message", import_obj.failed_message
        yield "created_at", import_obj.created_at
        yield "updated_at", import_obj.updated_at
        yield "metadata", import_obj.metadata


class ExaListImportsBlock(Block):
    """List all imports with pagination."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        limit: int = SchemaField(
            default=25,
            description="Number of imports to return",
            ge=1,
            le=100,
        )
        cursor: Optional[str] = SchemaField(
            default=None,
            description="Cursor for pagination",
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        imports: list[dict] = SchemaField(description="List of imports")
        import_item: dict = SchemaField(
            description="Individual import (yielded for each import)"
        )
        has_more: bool = SchemaField(
            description="Whether there are more imports to paginate through"
        )
        next_cursor: Optional[str] = SchemaField(
            description="Cursor for the next page of results"
        )

    def __init__(self):
        super().__init__(
            id="65323630-f7e9-4692-a624-184ba14c0686",
            description="List all imports with pagination support",
            categories={BlockCategory.DATA},
            input_schema=ExaListImportsBlock.Input,
            output_schema=ExaListImportsBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        response = aexa.websets.imports.list(
            cursor=input_data.cursor,
            limit=input_data.limit,
        )

        # Convert SDK imports to our stable models
        imports = [ImportModel.from_sdk(i) for i in response.data]

        yield "imports", [i.model_dump() for i in imports]

        for import_obj in imports:
            yield "import_item", import_obj.model_dump()

        yield "has_more", response.has_more
        yield "next_cursor", response.next_cursor

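A sketch of how a caller could drive this cursor pagination end to end: feed next_cursor back in until has_more is False. The list_page callable is a hypothetical stand-in for the block's list call, not a real API.

    def collect_all(list_page, limit=25):
        items, cursor = [], None
        while True:
            page = list_page(cursor=cursor, limit=limit)
            items.extend(page["imports"])
            if not page["has_more"]:
                return items
            cursor = page["next_cursor"]
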
class ExaDeleteImportBlock(Block):
    """Delete an import."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        import_id: str = SchemaField(
            description="The ID of the import to delete",
            placeholder="import-id",
        )

    class Output(BlockSchemaOutput):
        import_id: str = SchemaField(description="The ID of the deleted import")
        success: str = SchemaField(description="Whether the deletion was successful")

    def __init__(self):
        super().__init__(
            id="81ae30ed-c7ba-4b5d-8483-b726846e570c",
            description="Delete an import",
            categories={BlockCategory.DATA},
            input_schema=ExaDeleteImportBlock.Input,
            output_schema=ExaDeleteImportBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)

        yield "import_id", deleted_import.id
        yield "success", "true"


class ExaExportWebsetBlock(Block):
    """Export all data from a webset in various formats."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset to export",
            placeholder="webset-id-or-external-id",
        )
        format: ExportFormat = SchemaField(
            default=ExportFormat.JSON,
            description="Export format",
        )
        include_content: bool = SchemaField(
            default=True,
            description="Include full content in export",
        )
        include_enrichments: bool = SchemaField(
            default=True,
            description="Include enrichment data in export",
        )
        max_items: int = SchemaField(
            default=100,
            description="Maximum number of items to export",
            ge=1,
            le=100,
        )

    class Output(BlockSchemaOutput):
        export_data: str = SchemaField(
            description="Exported data in the requested format"
        )
        item_count: int = SchemaField(description="Number of items exported")
        total_items: int = SchemaField(
            description="Total number of items in the webset"
        )
        truncated: bool = SchemaField(
            description="Whether the export was truncated due to max_items limit"
        )
        format: str = SchemaField(description="Format of the exported data")

    def __init__(self):
        super().__init__(
            id="5da9d0fd-4b5b-4318-8302-8f71d0ccce9d",
            description="Export webset data in JSON, CSV, or JSON Lines format",
            categories={BlockCategory.DATA},
            input_schema=ExaExportWebsetBlock.Input,
            output_schema=ExaExportWebsetBlock.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "webset_id": "test-webset",
                "format": ExportFormat.JSON,
                "include_content": True,
                "include_enrichments": True,
                "max_items": 10,
            },
            test_output=[
                ("export_data", str),
                ("item_count", 2),
                ("total_items", 2),
                ("truncated", False),
                ("format", "json"),
            ],
            test_credentials=TEST_CREDENTIALS,
            test_mock=self._create_test_mock(),
        )

    @staticmethod
    def _create_test_mock():
        """Create test mocks for the AsyncExa SDK."""
        from unittest.mock import MagicMock

        # Create mock webset items
        mock_item1 = MagicMock()
        mock_item1.model_dump = MagicMock(
            return_value={
                "id": "item-1",
                "url": "https://example.com",
                "title": "Test Item 1",
            }
        )

        mock_item2 = MagicMock()
        mock_item2.model_dump = MagicMock(
            return_value={
                "id": "item-2",
                "url": "https://example.org",
                "title": "Test Item 2",
            }
        )

        # Create mock iterator
        mock_items = [mock_item1, mock_item2]

        return {
            "_get_client": lambda *args, **kwargs: MagicMock(
                websets=MagicMock(
                    items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
                )
            )
        }

    def _get_client(self, api_key: str) -> AsyncExa:
        """Get Exa client (separated for testing)."""
        return AsyncExa(api_key=api_key)

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = self._get_client(credentials.api_key.get_secret_value())

        try:
            all_items = []

            # Use SDK's list_all iterator to fetch items
            item_iterator = aexa.websets.items.list_all(
                webset_id=input_data.webset_id, limit=input_data.max_items
            )

            for sdk_item in item_iterator:
                if len(all_items) >= input_data.max_items:
                    break

                # Convert to dict for export
                item_dict = sdk_item.model_dump(by_alias=True, exclude_none=True)
                all_items.append(item_dict)

            # Calculate total and truncated
            total_items = len(all_items)  # SDK doesn't provide total count
            truncated = len(all_items) >= input_data.max_items

            # Process items based on include flags
            if not input_data.include_content:
                for item in all_items:
                    item.pop("content", None)

            if not input_data.include_enrichments:
                for item in all_items:
                    item.pop("enrichments", None)

            # Format the export data
            export_data = ""

            if input_data.format == ExportFormat.JSON:
                export_data = json.dumps(all_items, indent=2, default=str)

            elif input_data.format == ExportFormat.JSON_LINES:
                lines = [json.dumps(item, default=str) for item in all_items]
                export_data = "\n".join(lines)

            elif input_data.format == ExportFormat.CSV:
                # Extract all unique keys for CSV headers
                all_keys = set()
                for item in all_items:
                    all_keys.update(self._flatten_dict(item).keys())

                # Create CSV
                output = StringIO()
                writer = csv.DictWriter(output, fieldnames=sorted(all_keys))
                writer.writeheader()

                for item in all_items:
                    flat_item = self._flatten_dict(item)
                    writer.writerow(flat_item)

                export_data = output.getvalue()

            yield "export_data", export_data
            yield "item_count", len(all_items)
            yield "total_items", total_items
            yield "truncated", truncated
            yield "format", input_data.format.value

        except ValueError as e:
            # Re-raise user input validation errors
            raise ValueError(f"Failed to export webset: {e}") from e
        # Let all other exceptions propagate naturally

    def _flatten_dict(self, d: dict, parent_key: str = "", sep: str = "_") -> dict:
        """Flatten nested dictionaries for CSV export."""
        items = []
        for k, v in d.items():
            new_key = f"{parent_key}{sep}{k}" if parent_key else k
            if isinstance(v, dict):
                items.extend(self._flatten_dict(v, new_key, sep=sep).items())
            elif isinstance(v, list):
                # Convert lists to JSON strings for CSV
                items.append((new_key, json.dumps(v, default=str)))
            else:
                items.append((new_key, v))
        return dict(items)
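For reference, a sketch of what the _flatten_dict helper above produces for a nested item (values made up): nested keys are joined with the separator, and lists become JSON strings so they fit in a single CSV cell.

    item = {"id": "item-1", "company": {"name": "Acme"}, "tags": ["ai", "saas"]}
    # _flatten_dict(item) ->
    # {"id": "item-1", "company_name": "Acme", "tags": '["ai", "saas"]'}
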
591
autogpt_platform/backend/backend/blocks/exa/websets_items.py
Normal file
@@ -0,0 +1,591 @@
"""
Exa Websets Item Management Blocks

This module provides blocks for managing items within Exa websets, including
retrieving, listing, deleting, and bulk operations on webset items.
"""

from typing import Any, Dict, List, Optional

from exa_py import AsyncExa
from exa_py.websets.types import WebsetItem as SdkWebsetItem
from exa_py.websets.types import (
    WebsetItemArticleProperties,
    WebsetItemCompanyProperties,
    WebsetItemCustomProperties,
    WebsetItemPersonProperties,
    WebsetItemResearchPaperProperties,
)
from pydantic import AnyUrl, BaseModel

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    CredentialsMetaInput,
    SchemaField,
)

from ._config import exa


# Mirrored model for enrichment results
class EnrichmentResultModel(BaseModel):
    """Stable output model mirroring SDK EnrichmentResult."""

    enrichment_id: str
    format: str
    result: Optional[List[str]]
    reasoning: Optional[str]
    references: List[Dict[str, Any]]

    @classmethod
    def from_sdk(cls, sdk_enrich) -> "EnrichmentResultModel":
        """Convert SDK EnrichmentResult to our model."""
        format_str = (
            sdk_enrich.format.value
            if hasattr(sdk_enrich.format, "value")
            else str(sdk_enrich.format)
        )

        # Convert references to dicts
        references_list = []
        if sdk_enrich.references:
            for ref in sdk_enrich.references:
                references_list.append(ref.model_dump(by_alias=True, exclude_none=True))

        return cls(
            enrichment_id=sdk_enrich.enrichment_id,
            format=format_str,
            result=sdk_enrich.result,
            reasoning=sdk_enrich.reasoning,
            references=references_list,
        )


# Mirrored model for stability - don't use SDK types directly in block outputs
class WebsetItemModel(BaseModel):
    """Stable output model mirroring SDK WebsetItem."""

    id: str
    url: Optional[AnyUrl]
    title: str
    content: str
    entity_data: Dict[str, Any]
    enrichments: Dict[str, EnrichmentResultModel]
    created_at: str
    updated_at: str

    @classmethod
    def from_sdk(cls, item: SdkWebsetItem) -> "WebsetItemModel":
        """Convert SDK WebsetItem to our stable model."""
        # Extract properties from the union type
        properties_dict = {}
        url_value = None
        title = ""
        content = ""

        if hasattr(item, "properties") and item.properties:
            properties_dict = item.properties.model_dump(
                by_alias=True, exclude_none=True
            )

            # URL is always available on all property types
            url_value = item.properties.url

            # Extract title using isinstance checks on the union type
            if isinstance(item.properties, WebsetItemPersonProperties):
                title = item.properties.person.name
                content = ""  # Person type has no content
            elif isinstance(item.properties, WebsetItemCompanyProperties):
                title = item.properties.company.name
                content = item.properties.content or ""
            elif isinstance(item.properties, WebsetItemArticleProperties):
                title = item.properties.description
                content = item.properties.content or ""
            elif isinstance(item.properties, WebsetItemResearchPaperProperties):
                title = item.properties.description
                content = item.properties.content or ""
            elif isinstance(item.properties, WebsetItemCustomProperties):
                title = item.properties.description
                content = item.properties.content or ""
            else:
                # Fallback
                title = item.properties.description
                content = getattr(item.properties, "content", "")

        # Convert enrichments from list to dict keyed by enrichment_id using Pydantic models
        enrichments_dict: Dict[str, EnrichmentResultModel] = {}
        if hasattr(item, "enrichments") and item.enrichments:
            for sdk_enrich in item.enrichments:
                enrich_model = EnrichmentResultModel.from_sdk(sdk_enrich)
                enrichments_dict[enrich_model.enrichment_id] = enrich_model

        return cls(
            id=item.id,
            url=url_value,
            title=title,
            content=content or "",
            entity_data=properties_dict,
            enrichments=enrichments_dict,
            created_at=item.created_at.isoformat() if item.created_at else "",
            updated_at=item.updated_at.isoformat() if item.updated_at else "",
        )

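A minimal sketch (hypothetical data) of the list-to-dict reshaping from_sdk applies to enrichments, keying each result by its enrichment_id so downstream blocks can look results up directly.

    sdk_enrichments = [
        {"enrichment_id": "enr-1", "format": "text", "result": ["CEO: Jane"]},
        {"enrichment_id": "enr-2", "format": "text", "result": ["HQ: Berlin"]},
    ]
    enrichments_dict = {e["enrichment_id"]: e for e in sdk_enrichments}
    assert set(enrichments_dict) == {"enr-1", "enr-2"}
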
class ExaGetWebsetItemBlock(Block):
    """Get a specific item from a webset by its ID."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        item_id: str = SchemaField(
            description="The ID of the specific item to retrieve",
            placeholder="item-id",
        )

    class Output(BlockSchemaOutput):
        item_id: str = SchemaField(description="The unique identifier for the item")
        url: str = SchemaField(description="The URL of the original source")
        title: str = SchemaField(description="The title of the item")
        content: str = SchemaField(description="The main content of the item")
        entity_data: dict = SchemaField(description="Entity-specific structured data")
        enrichments: dict = SchemaField(description="Enrichment data added to the item")
        created_at: str = SchemaField(
            description="When the item was added to the webset"
        )
        updated_at: str = SchemaField(description="When the item was last updated")

    def __init__(self):
        super().__init__(
            id="c4a7d9e2-8f3b-4a6c-9d8e-a5b6c7d8e9f0",
            description="Get a specific item from a webset by its ID",
            categories={BlockCategory.SEARCH},
            input_schema=ExaGetWebsetItemBlock.Input,
            output_schema=ExaGetWebsetItemBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_item = aexa.websets.items.get(
            webset_id=input_data.webset_id, id=input_data.item_id
        )

        item = WebsetItemModel.from_sdk(sdk_item)

        yield "item_id", item.id
        yield "url", item.url
        yield "title", item.title
        yield "content", item.content
        yield "entity_data", item.entity_data
        yield "enrichments", item.enrichments
        yield "created_at", item.created_at
        yield "updated_at", item.updated_at


class ExaListWebsetItemsBlock(Block):
    """List items in a webset with pagination and optional filtering."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        limit: int = SchemaField(
            default=25,
            description="Number of items to return (1-100)",
            ge=1,
            le=100,
        )
        cursor: Optional[str] = SchemaField(
            default=None,
            description="Cursor for pagination through results",
            advanced=True,
        )
        wait_for_items: bool = SchemaField(
            default=False,
            description="Wait for items to be available if webset is still processing",
            advanced=True,
        )
        wait_timeout: int = SchemaField(
            default=60,
            description="Maximum time to wait for items in seconds",
            advanced=True,
            ge=1,
            le=300,
        )

    class Output(BlockSchemaOutput):
        items: list[WebsetItemModel] = SchemaField(
            description="List of webset items",
        )
        webset_id: str = SchemaField(
            description="The ID of the webset",
        )
        item: WebsetItemModel = SchemaField(
            description="Individual item (yielded for each item in the list)",
        )
        has_more: bool = SchemaField(
            description="Whether there are more items to paginate through",
        )
        next_cursor: Optional[str] = SchemaField(
            description="Cursor for the next page of results",
        )

    def __init__(self):
        super().__init__(
            id="7b5e8c9f-01a2-43c4-95e6-f7a8b9c0d1e2",
            description="List items in a webset with pagination support",
            categories={BlockCategory.SEARCH},
            input_schema=ExaListWebsetItemsBlock.Input,
            output_schema=ExaListWebsetItemsBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        if input_data.wait_for_items:
            import asyncio
            import time

            start_time = time.time()
            interval = 2
            response = None

            while time.time() - start_time < input_data.wait_timeout:
                response = aexa.websets.items.list(
                    webset_id=input_data.webset_id,
                    cursor=input_data.cursor,
                    limit=input_data.limit,
                )

                if response.data:
                    break

                await asyncio.sleep(interval)
                interval = min(interval * 1.2, 10)

            if not response:
                response = aexa.websets.items.list(
                    webset_id=input_data.webset_id,
                    cursor=input_data.cursor,
                    limit=input_data.limit,
                )
        else:
            response = aexa.websets.items.list(
                webset_id=input_data.webset_id,
                cursor=input_data.cursor,
                limit=input_data.limit,
            )

        items = [WebsetItemModel.from_sdk(item) for item in response.data]

        yield "items", items

        for item in items:
            yield "item", item

        yield "has_more", response.has_more
        yield "next_cursor", response.next_cursor
        yield "webset_id", input_data.webset_id


class ExaDeleteWebsetItemBlock(Block):
    """Delete a specific item from a webset."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        item_id: str = SchemaField(
            description="The ID of the item to delete",
            placeholder="item-id",
        )

    class Output(BlockSchemaOutput):
        item_id: str = SchemaField(description="The ID of the deleted item")
        success: str = SchemaField(description="Whether the deletion was successful")

    def __init__(self):
        super().__init__(
            id="12c57fbe-c270-4877-a2b6-d2d05529ba79",
            description="Delete a specific item from a webset",
            categories={BlockCategory.SEARCH},
            input_schema=ExaDeleteWebsetItemBlock.Input,
            output_schema=ExaDeleteWebsetItemBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        deleted_item = aexa.websets.items.delete(
            webset_id=input_data.webset_id, id=input_data.item_id
        )

        yield "item_id", deleted_item.id
        yield "success", "true"


class ExaBulkWebsetItemsBlock(Block):
    """Get all items from a webset in a single operation (with size limits)."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        max_items: int = SchemaField(
            default=100,
            description="Maximum number of items to retrieve (1-1000). Note: Large values may take longer.",
            ge=1,
            le=1000,
        )
        include_enrichments: bool = SchemaField(
            default=True,
            description="Include enrichment data for each item",
        )
        include_content: bool = SchemaField(
            default=True,
            description="Include full content for each item",
        )

    class Output(BlockSchemaOutput):
        items: list[WebsetItemModel] = SchemaField(
            description="All items from the webset"
        )
        item: WebsetItemModel = SchemaField(
            description="Individual item (yielded for each item)"
        )
        total_retrieved: int = SchemaField(
            description="Total number of items retrieved"
        )
        truncated: bool = SchemaField(
            description="Whether results were truncated due to max_items limit"
        )

    def __init__(self):
        super().__init__(
            id="dbd619f5-476e-4395-af9a-a7a7c0fb8c4e",
            description="Get all items from a webset in bulk (with configurable limits)",
            categories={BlockCategory.SEARCH},
            input_schema=ExaBulkWebsetItemsBlock.Input,
            output_schema=ExaBulkWebsetItemsBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        all_items: List[WebsetItemModel] = []
        item_iterator = aexa.websets.items.list_all(
            webset_id=input_data.webset_id, limit=input_data.max_items
        )

        for sdk_item in item_iterator:
            if len(all_items) >= input_data.max_items:
                break

            item = WebsetItemModel.from_sdk(sdk_item)

            if not input_data.include_enrichments:
                item.enrichments = {}
            if not input_data.include_content:
                item.content = ""

            all_items.append(item)

        yield "items", all_items

        for item in all_items:
            yield "item", item

        yield "total_retrieved", len(all_items)
        yield "truncated", len(all_items) >= input_data.max_items


class ExaWebsetItemsSummaryBlock(Block):
    """Get a summary of items in a webset without retrieving all data."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        sample_size: int = SchemaField(
            default=5,
            description="Number of sample items to include",
            ge=0,
            le=10,
        )

    class Output(BlockSchemaOutput):
        total_items: int = SchemaField(
            description="Total number of items in the webset"
        )
        entity_type: str = SchemaField(description="Type of entities in the webset")
        sample_items: list[WebsetItemModel] = SchemaField(
            description="Sample of items from the webset"
        )
        enrichment_columns: list[str] = SchemaField(
            description="List of enrichment columns available"
        )

    def __init__(self):
        super().__init__(
            id="db7813ad-10bd-4652-8623-5667d6fecdd5",
            description="Get a summary of webset items without retrieving all data",
            categories={BlockCategory.SEARCH},
            input_schema=ExaWebsetItemsSummaryBlock.Input,
            output_schema=ExaWebsetItemsSummaryBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        webset = aexa.websets.get(id=input_data.webset_id)

        entity_type = "unknown"
        if webset.searches:
            first_search = webset.searches[0]
            if first_search.entity:
                # The entity is a union type, extract type field
                entity_dict = first_search.entity.model_dump(by_alias=True)
                entity_type = entity_dict.get("type", "unknown")

        # Get enrichment columns
        enrichment_columns = []
        if webset.enrichments:
            enrichment_columns = [
                e.title if e.title else e.description for e in webset.enrichments
            ]

        # Get sample items if requested
        sample_items: List[WebsetItemModel] = []
        if input_data.sample_size > 0:
            items_response = aexa.websets.items.list(
                webset_id=input_data.webset_id, limit=input_data.sample_size
            )
            # Convert to our stable models
            sample_items = [
                WebsetItemModel.from_sdk(item) for item in items_response.data
            ]

        total_items = 0
        if webset.searches:
            for search in webset.searches:
                if search.progress:
                    total_items += search.progress.found

        yield "total_items", total_items
        yield "entity_type", entity_type
        yield "sample_items", sample_items
        yield "enrichment_columns", enrichment_columns


class ExaGetNewItemsBlock(Block):
    """Get items added to a webset since a specific cursor (incremental processing helper)."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        since_cursor: Optional[str] = SchemaField(
            default=None,
            description="Cursor from previous run - only items after this will be returned. Leave empty on first run.",
            placeholder="cursor-from-previous-run",
        )
        max_items: int = SchemaField(
            default=100,
            description="Maximum number of new items to retrieve",
            ge=1,
            le=1000,
        )

    class Output(BlockSchemaOutput):
        new_items: list[WebsetItemModel] = SchemaField(
            description="Items added since the cursor"
        )
        item: WebsetItemModel = SchemaField(
            description="Individual item (yielded for each new item)"
        )
        count: int = SchemaField(description="Number of new items found")
        next_cursor: Optional[str] = SchemaField(
            description="Save this cursor for the next run to get only newer items"
        )
        has_more: bool = SchemaField(
            description="Whether there are more new items beyond max_items"
        )

    def __init__(self):
        super().__init__(
            id="3ff9bdf5-9613-4d21-8a60-90eb8b69c414",
            description="Get items added since a cursor - enables incremental processing without reprocessing",
            categories={BlockCategory.SEARCH, BlockCategory.DATA},
            input_schema=ExaGetNewItemsBlock.Input,
            output_schema=ExaGetNewItemsBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # Get items starting from cursor
        response = aexa.websets.items.list(
            webset_id=input_data.webset_id,
            cursor=input_data.since_cursor,
            limit=input_data.max_items,
        )

        # Convert SDK items to our stable models
        new_items = [WebsetItemModel.from_sdk(item) for item in response.data]

        # Yield the full list
        yield "new_items", new_items

        # Yield individual items for processing
        for item in new_items:
            yield "item", item

        # Yield metadata for next run
        yield "count", len(new_items)
        yield "next_cursor", response.next_cursor
        yield "has_more", response.has_more
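A sketch of the incremental pattern ExaGetNewItemsBlock enables: persist next_cursor between runs so each run handles only items added since the last one. The load_cursor, save_cursor, fetch_new_items, and handle callables are placeholders, not part of this module.

    def process_new_items(load_cursor, save_cursor, fetch_new_items, handle):
        cursor = load_cursor()  # None on the first run
        result = fetch_new_items(since_cursor=cursor, max_items=100)
        for item in result["new_items"]:
            handle(item)  # each item is processed exactly once across runs
        save_cursor(result["next_cursor"])
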
600
autogpt_platform/backend/backend/blocks/exa/websets_monitor.py
Normal file
@@ -0,0 +1,600 @@
"""
Exa Websets Monitor Management Blocks

This module provides blocks for creating and managing monitors that automatically
keep websets updated with fresh data on a schedule.
"""

from enum import Enum
from typing import Optional

from exa_py import AsyncExa
from exa_py.websets.types import Monitor as SdkMonitor
from pydantic import BaseModel

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    CredentialsMetaInput,
    SchemaField,
)

from ._config import exa
from ._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT


# Mirrored model for stability - don't use SDK types directly in block outputs
class MonitorModel(BaseModel):
    """Stable output model mirroring SDK Monitor."""

    id: str
    status: str
    webset_id: str
    behavior_type: str
    behavior_config: dict
    cron_expression: str
    timezone: str
    next_run_at: str
    last_run: dict
    metadata: dict
    created_at: str
    updated_at: str

    @classmethod
    def from_sdk(cls, monitor: SdkMonitor) -> "MonitorModel":
        """Convert SDK Monitor to our stable model."""
        # Extract behavior information
        behavior_dict = monitor.behavior.model_dump(by_alias=True, exclude_none=True)
        behavior_type = behavior_dict.get("type", "unknown")
        behavior_config = behavior_dict.get("config", {})

        # Extract cadence information
        cadence_dict = monitor.cadence.model_dump(by_alias=True, exclude_none=True)
        cron_expr = cadence_dict.get("cron", "")
        timezone = cadence_dict.get("timezone", "Etc/UTC")

        # Extract last run information
        last_run_dict = {}
        if monitor.last_run:
            last_run_dict = monitor.last_run.model_dump(
                by_alias=True, exclude_none=True
            )

        # Handle status enum
        status_str = (
            monitor.status.value
            if hasattr(monitor.status, "value")
            else str(monitor.status)
        )

        return cls(
            id=monitor.id,
            status=status_str,
            webset_id=monitor.webset_id,
            behavior_type=behavior_type,
            behavior_config=behavior_config,
            cron_expression=cron_expr,
            timezone=timezone,
            next_run_at=monitor.next_run_at.isoformat() if monitor.next_run_at else "",
            last_run=last_run_dict,
            metadata=monitor.metadata or {},
            created_at=monitor.created_at.isoformat() if monitor.created_at else "",
            updated_at=monitor.updated_at.isoformat() if monitor.updated_at else "",
        )


class MonitorStatus(str, Enum):
    """Status of a monitor."""

    ENABLED = "enabled"
    DISABLED = "disabled"
    PAUSED = "paused"


class MonitorBehaviorType(str, Enum):
    """Type of behavior for a monitor."""

    SEARCH = "search"  # Run new searches
    REFRESH = "refresh"  # Refresh existing items


class SearchBehavior(str, Enum):
    """How search results interact with existing items."""

    APPEND = "append"
    OVERRIDE = "override"


class ExaCreateMonitorBlock(Block):
    """Create a monitor to automatically keep a webset updated on a schedule."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset to monitor",
            placeholder="webset-id-or-external-id",
        )

        # Schedule configuration
        cron_expression: str = SchemaField(
            description="Cron expression for scheduling (5 fields, max once per day)",
            placeholder="0 9 * * 1",  # Every Monday at 9 AM
        )
        timezone: str = SchemaField(
            default="Etc/UTC",
            description="IANA timezone for the schedule",
            placeholder="America/New_York",
            advanced=True,
        )

        # Behavior configuration
        behavior_type: MonitorBehaviorType = SchemaField(
            default=MonitorBehaviorType.SEARCH,
            description="Type of monitor behavior (search for new items or refresh existing)",
        )

        # Search configuration (for SEARCH behavior)
        search_query: Optional[str] = SchemaField(
            default=None,
            description="Search query for finding new items (required for search behavior)",
            placeholder="AI startups that raised funding in the last week",
        )
        search_count: int = SchemaField(
            default=10,
            description="Number of items to find in each search",
            ge=1,
            le=100,
        )
        search_criteria: list[str] = SchemaField(
            default_factory=list,
            description="Criteria that items must meet",
            advanced=True,
        )
        search_behavior: SearchBehavior = SchemaField(
            default=SearchBehavior.APPEND,
            description="How new results interact with existing items",
            advanced=True,
        )
        entity_type: Optional[str] = SchemaField(
            default=None,
            description="Type of entity to search for (company, person, etc.)",
            advanced=True,
        )

        # Refresh configuration (for REFRESH behavior)
        refresh_content: bool = SchemaField(
            default=True,
            description="Refresh content from source URLs (for refresh behavior)",
            advanced=True,
        )
        refresh_enrichments: bool = SchemaField(
            default=True,
            description="Re-run enrichments on items (for refresh behavior)",
            advanced=True,
        )

        # Metadata
        metadata: Optional[dict] = SchemaField(
            default=None,
            description="Metadata to attach to the monitor",
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        monitor_id: str = SchemaField(
            description="The unique identifier for the created monitor"
        )
        webset_id: str = SchemaField(description="The webset this monitor belongs to")
        status: str = SchemaField(description="Status of the monitor")
        behavior_type: str = SchemaField(description="Type of monitor behavior")
        next_run_at: Optional[str] = SchemaField(
            description="When the monitor will next run"
        )
        cron_expression: str = SchemaField(description="The schedule cron expression")
        timezone: str = SchemaField(description="The timezone for scheduling")
        created_at: str = SchemaField(description="When the monitor was created")

    def __init__(self):
        super().__init__(
            id="f8a9b0c1-d2e3-4567-890a-bcdef1234567",
            description="Create automated monitors to keep websets updated with fresh data on a schedule",
            categories={BlockCategory.SEARCH},
            input_schema=ExaCreateMonitorBlock.Input,
            output_schema=ExaCreateMonitorBlock.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "webset_id": "test-webset",
                "cron_expression": "0 9 * * 1",
                "behavior_type": MonitorBehaviorType.SEARCH,
                "search_query": "AI startups",
                "search_count": 10,
            },
            test_output=[
                ("monitor_id", "monitor-123"),
                ("webset_id", "test-webset"),
                ("status", "enabled"),
                ("behavior_type", "search"),
                ("next_run_at", "2024-01-01T00:00:00"),
                ("cron_expression", "0 9 * * 1"),
                ("timezone", "Etc/UTC"),
                ("created_at", "2024-01-01T00:00:00"),
            ],
            test_credentials=TEST_CREDENTIALS,
            test_mock=self._create_test_mock(),
        )

    @staticmethod
    def _create_test_mock():
        """Create test mocks for the AsyncExa SDK."""
        from datetime import datetime
        from unittest.mock import MagicMock

        # Create mock SDK monitor object
        mock_monitor = MagicMock()
        mock_monitor.id = "monitor-123"
        mock_monitor.status = MagicMock(value="enabled")
        mock_monitor.webset_id = "test-webset"
        mock_monitor.next_run_at = datetime.fromisoformat("2024-01-01T00:00:00")
        mock_monitor.created_at = datetime.fromisoformat("2024-01-01T00:00:00")
        mock_monitor.updated_at = datetime.fromisoformat("2024-01-01T00:00:00")
        mock_monitor.metadata = {}
        mock_monitor.last_run = None

        # Mock behavior
        mock_behavior = MagicMock()
        mock_behavior.model_dump = MagicMock(
            return_value={"type": "search", "config": {}}
        )
        mock_monitor.behavior = mock_behavior

        # Mock cadence
        mock_cadence = MagicMock()
        mock_cadence.model_dump = MagicMock(
            return_value={"cron": "0 9 * * 1", "timezone": "Etc/UTC"}
        )
        mock_monitor.cadence = mock_cadence

        return {
            "_get_client": lambda *args, **kwargs: MagicMock(
                websets=MagicMock(
                    monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
                )
            )
        }

    def _get_client(self, api_key: str) -> AsyncExa:
        """Get Exa client (separated for testing)."""
        return AsyncExa(api_key=api_key)

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        aexa = self._get_client(credentials.api_key.get_secret_value())

        # Build the payload
        payload = {
            "websetId": input_data.webset_id,
            "cadence": {
                "cron": input_data.cron_expression,
                "timezone": input_data.timezone,
            },
        }

        # Build behavior configuration based on type
        if input_data.behavior_type == MonitorBehaviorType.SEARCH:
            behavior_config = {
                "query": input_data.search_query or "",
                "count": input_data.search_count,
                "behavior": input_data.search_behavior.value,
            }

            if input_data.search_criteria:
                behavior_config["criteria"] = [
                    {"description": c} for c in input_data.search_criteria
                ]

            if input_data.entity_type:
                behavior_config["entity"] = {"type": input_data.entity_type}

            payload["behavior"] = {
                "type": "search",
                "config": behavior_config,
            }
        else:
            # REFRESH behavior
            payload["behavior"] = {
                "type": "refresh",
                "config": {
                    "content": input_data.refresh_content,
                    "enrichments": input_data.refresh_enrichments,
                },
            }

        # Add metadata if provided
        if input_data.metadata:
            payload["metadata"] = input_data.metadata

        sdk_monitor = aexa.websets.monitors.create(params=payload)

        monitor = MonitorModel.from_sdk(sdk_monitor)

        # Yield all fields
        yield "monitor_id", monitor.id
        yield "webset_id", monitor.webset_id
        yield "status", monitor.status
        yield "behavior_type", monitor.behavior_type
        yield "next_run_at", monitor.next_run_at
        yield "cron_expression", monitor.cron_expression
        yield "timezone", monitor.timezone
        yield "created_at", monitor.created_at

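For reference, the two behavior payload shapes run() assembles above, shown with hypothetical values: "search" monitors find new items on the schedule, while "refresh" monitors re-fetch content and re-run enrichments on existing items.

    search_behavior = {
        "type": "search",
        "config": {"query": "AI startups", "count": 10, "behavior": "append"},
    }
    refresh_behavior = {
        "type": "refresh",
        "config": {"content": True, "enrichments": True},
    }
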
class ExaGetMonitorBlock(Block):
    """Get the details and status of a monitor."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        monitor_id: str = SchemaField(
            description="The ID of the monitor to retrieve",
            placeholder="monitor-id",
        )

    class Output(BlockSchemaOutput):
        monitor_id: str = SchemaField(
            description="The unique identifier for the monitor"
        )
        webset_id: str = SchemaField(description="The webset this monitor belongs to")
        status: str = SchemaField(description="Current status of the monitor")
        behavior_type: str = SchemaField(description="Type of monitor behavior")
        behavior_config: dict = SchemaField(
            description="Configuration for the monitor behavior"
        )
        cron_expression: str = SchemaField(description="The schedule cron expression")
        timezone: str = SchemaField(description="The timezone for scheduling")
        next_run_at: Optional[str] = SchemaField(
            description="When the monitor will next run"
        )
        last_run: Optional[dict] = SchemaField(
            description="Information about the last run"
        )
        created_at: str = SchemaField(description="When the monitor was created")
        updated_at: str = SchemaField(description="When the monitor was last updated")
        metadata: dict = SchemaField(description="Metadata attached to the monitor")

    def __init__(self):
        super().__init__(
            id="5c852a2d-d505-4a56-b711-7def8dd14e72",
            description="Get the details and status of a webset monitor",
            categories={BlockCategory.SEARCH},
            input_schema=ExaGetMonitorBlock.Input,
            output_schema=ExaGetMonitorBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)

        monitor = MonitorModel.from_sdk(sdk_monitor)

        # Yield all fields
        yield "monitor_id", monitor.id
        yield "webset_id", monitor.webset_id
        yield "status", monitor.status
        yield "behavior_type", monitor.behavior_type
        yield "behavior_config", monitor.behavior_config
        yield "cron_expression", monitor.cron_expression
        yield "timezone", monitor.timezone
        yield "next_run_at", monitor.next_run_at
        yield "last_run", monitor.last_run
        yield "created_at", monitor.created_at
        yield "updated_at", monitor.updated_at
        yield "metadata", monitor.metadata


class ExaUpdateMonitorBlock(Block):
    """Update a monitor's configuration."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        monitor_id: str = SchemaField(
            description="The ID of the monitor to update",
            placeholder="monitor-id",
        )
        status: Optional[MonitorStatus] = SchemaField(
            default=None,
            description="New status for the monitor",
        )
        cron_expression: Optional[str] = SchemaField(
            default=None,
            description="New cron expression for scheduling",
        )
        timezone: Optional[str] = SchemaField(
            default=None,
            description="New timezone for the schedule",
            advanced=True,
        )
        metadata: Optional[dict] = SchemaField(
            default=None,
            description="New metadata for the monitor",
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        monitor_id: str = SchemaField(
            description="The unique identifier for the monitor"
        )
        status: str = SchemaField(description="Updated status of the monitor")
        next_run_at: Optional[str] = SchemaField(
            description="When the monitor will next run"
        )
        updated_at: str = SchemaField(description="When the monitor was updated")
        success: str = SchemaField(description="Whether the update was successful")

    def __init__(self):
        super().__init__(
            id="245102c3-6af3-4515-a308-c2210b7939d2",
            description="Update a monitor's status, schedule, or metadata",
            categories={BlockCategory.SEARCH},
            input_schema=ExaUpdateMonitorBlock.Input,
            output_schema=ExaUpdateMonitorBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # Build update payload
        payload = {}

        if input_data.status is not None:
            payload["status"] = input_data.status.value

        if input_data.cron_expression is not None or input_data.timezone is not None:
            cadence = {}
            if input_data.cron_expression:
                cadence["cron"] = input_data.cron_expression
            if input_data.timezone:
                cadence["timezone"] = input_data.timezone
            payload["cadence"] = cadence

        if input_data.metadata is not None:
            payload["metadata"] = input_data.metadata

        sdk_monitor = aexa.websets.monitors.update(
            monitor_id=input_data.monitor_id, params=payload
        )

        # Convert to our stable model
        monitor = MonitorModel.from_sdk(sdk_monitor)

        # Yield fields
        yield "monitor_id", monitor.id
        yield "status", monitor.status
        yield "next_run_at", monitor.next_run_at
        yield "updated_at", monitor.updated_at
        yield "success", "true"


class ExaDeleteMonitorBlock(Block):
    """Delete a monitor from a webset."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        monitor_id: str = SchemaField(
            description="The ID of the monitor to delete",
            placeholder="monitor-id",
        )

    class Output(BlockSchemaOutput):
        monitor_id: str = SchemaField(description="The ID of the deleted monitor")
        success: str = SchemaField(description="Whether the deletion was successful")

    def __init__(self):
        super().__init__(
            id="f16f9b10-0c4d-4db8-997d-7b96b6026094",
            description="Delete a monitor from a webset",
            categories={BlockCategory.SEARCH},
            input_schema=ExaDeleteMonitorBlock.Input,
            output_schema=ExaDeleteMonitorBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)

        yield "monitor_id", deleted_monitor.id
        yield "success", "true"


class ExaListMonitorsBlock(Block):
    """List all monitors with pagination."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: Optional[str] = SchemaField(
            default=None,
            description="Filter monitors by webset ID",
            placeholder="webset-id",
        )
        limit: int = SchemaField(
            default=25,
            description="Number of monitors to return",
            ge=1,
            le=100,
        )
        cursor: Optional[str] = SchemaField(
            default=None,
            description="Cursor for pagination",
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        monitors: list[dict] = SchemaField(description="List of monitors")
        monitor: dict = SchemaField(
            description="Individual monitor (yielded for each monitor)"
        )
        has_more: bool = SchemaField(
            description="Whether there are more monitors to paginate through"
        )
        next_cursor: Optional[str] = SchemaField(
            description="Cursor for the next page of results"
        )

    def __init__(self):
        super().__init__(
            id="f06e2b38-5397-4e8f-aa85-491149dd98df",
            description="List all monitors with optional webset filtering",
            categories={BlockCategory.SEARCH},
            input_schema=ExaListMonitorsBlock.Input,
            output_schema=ExaListMonitorsBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        response = aexa.websets.monitors.list(
            cursor=input_data.cursor,
            limit=input_data.limit,
            webset_id=input_data.webset_id,
        )

        # Convert SDK monitors to our stable models
        monitors = [MonitorModel.from_sdk(m) for m in response.data]

        # Yield the full list
        yield "monitors", [m.model_dump() for m in monitors]

        # Yield individual monitors for graph chaining
        for monitor in monitors:
            yield "monitor", monitor.model_dump()

        # Yield pagination metadata
        yield "has_more", response.has_more
        yield "next_cursor", response.next_cursor
600
autogpt_platform/backend/backend/blocks/exa/websets_polling.py
Normal file
@@ -0,0 +1,600 @@
"""
Exa Websets Polling Blocks

This module provides dedicated polling blocks for waiting on webset operations
to complete, with progress tracking and timeout management.
"""

import asyncio
import time
from enum import Enum
from typing import Any, Dict

from exa_py import AsyncExa
from pydantic import BaseModel

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    CredentialsMetaInput,
    SchemaField,
)

from ._config import exa

# Import WebsetItemModel for use in enrichment samples
# This is safe as websets_items doesn't import from websets_polling
from .websets_items import WebsetItemModel


# Model for sample enrichment data
class SampleEnrichmentModel(BaseModel):
    """Sample enrichment result for display."""

    item_id: str
    item_title: str
    enrichment_data: Dict[str, Any]


class WebsetTargetStatus(str, Enum):
    IDLE = "idle"
    COMPLETED = "completed"
    RUNNING = "running"
    PAUSED = "paused"
    ANY_COMPLETE = "any_complete"  # Either idle or completed

class ExaWaitForWebsetBlock(Block):
|
||||
"""Wait for a webset to reach a specific status with progress tracking."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset to monitor",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
target_status: WebsetTargetStatus = SchemaField(
|
||||
default=WebsetTargetStatus.IDLE,
|
||||
description="Status to wait for (idle=all operations complete, completed=search done, running=actively processing)",
|
||||
)
|
||||
timeout: int = SchemaField(
|
||||
default=300,
|
||||
description="Maximum time to wait in seconds",
|
||||
ge=1,
|
||||
le=1800, # 30 minutes max
|
||||
)
|
||||
check_interval: int = SchemaField(
|
||||
default=5,
|
||||
description="Initial interval between status checks in seconds",
|
||||
advanced=True,
|
||||
ge=1,
|
||||
le=60,
|
||||
)
|
||||
max_interval: int = SchemaField(
|
||||
default=30,
|
||||
description="Maximum interval between checks (for exponential backoff)",
|
||||
advanced=True,
|
||||
ge=5,
|
||||
le=120,
|
||||
)
|
||||
include_progress: bool = SchemaField(
|
||||
default=True,
|
||||
description="Include detailed progress information in output",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
webset_id: str = SchemaField(description="The webset ID that was monitored")
|
||||
final_status: str = SchemaField(description="The final status of the webset")
|
||||
elapsed_time: float = SchemaField(description="Total time elapsed in seconds")
|
||||
item_count: int = SchemaField(description="Number of items found")
|
||||
search_progress: dict = SchemaField(
|
||||
description="Detailed search progress information"
|
||||
)
|
||||
enrichment_progress: dict = SchemaField(
|
||||
description="Detailed enrichment progress information"
|
||||
)
|
||||
timed_out: bool = SchemaField(description="Whether the operation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="619d71e8-b72a-434d-8bd4-23376dd0342c",
|
||||
description="Wait for a webset to reach a specific status with progress tracking",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaWaitForWebsetBlock.Input,
|
||||
output_schema=ExaWaitForWebsetBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
start_time = time.time()
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
try:
|
||||
if input_data.target_status in [
|
||||
WebsetTargetStatus.IDLE,
|
||||
WebsetTargetStatus.ANY_COMPLETE,
|
||||
]:
|
||||
final_webset = aexa.websets.wait_until_idle(
|
||||
id=input_data.webset_id,
|
||||
timeout=input_data.timeout,
|
||||
poll_interval=input_data.check_interval,
|
||||
)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
status_str = (
|
||||
final_webset.status.value
|
||||
if hasattr(final_webset.status, "value")
|
||||
else str(final_webset.status)
|
||||
)
|
||||
|
||||
item_count = 0
|
||||
if final_webset.searches:
|
||||
for search in final_webset.searches:
|
||||
if search.progress:
|
||||
item_count += search.progress.found
|
||||
|
||||
# Extract progress if requested
|
||||
search_progress = {}
|
||||
enrichment_progress = {}
|
||||
if input_data.include_progress:
|
||||
webset_dict = final_webset.model_dump(
|
||||
by_alias=True, exclude_none=True
|
||||
)
|
||||
search_progress = self._extract_search_progress(webset_dict)
|
||||
enrichment_progress = self._extract_enrichment_progress(webset_dict)
|
||||
|
||||
yield "webset_id", input_data.webset_id
|
||||
yield "final_status", status_str
|
||||
yield "elapsed_time", elapsed
|
||||
yield "item_count", item_count
|
||||
if input_data.include_progress:
|
||||
yield "search_progress", search_progress
|
||||
yield "enrichment_progress", enrichment_progress
|
||||
yield "timed_out", False
|
||||
else:
|
||||
# For other status targets, manually poll
|
||||
interval = input_data.check_interval
|
||||
while time.time() - start_time < input_data.timeout:
|
||||
# Get current webset status
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
current_status = (
|
||||
webset.status.value
|
||||
if hasattr(webset.status, "value")
|
||||
else str(webset.status)
|
||||
)
|
||||
|
||||
# Check if target status reached
|
||||
if current_status == input_data.target_status.value:
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Estimate item count from search progress
|
||||
item_count = 0
|
||||
if webset.searches:
|
||||
for search in webset.searches:
|
||||
if search.progress:
|
||||
item_count += search.progress.found
|
||||
|
||||
search_progress = {}
|
||||
enrichment_progress = {}
|
||||
if input_data.include_progress:
|
||||
webset_dict = webset.model_dump(
|
||||
by_alias=True, exclude_none=True
|
||||
)
|
||||
search_progress = self._extract_search_progress(webset_dict)
|
||||
enrichment_progress = self._extract_enrichment_progress(
|
||||
webset_dict
|
||||
)
|
||||
|
||||
yield "webset_id", input_data.webset_id
|
||||
yield "final_status", current_status
|
||||
yield "elapsed_time", elapsed
|
||||
yield "item_count", item_count
|
||||
if input_data.include_progress:
|
||||
yield "search_progress", search_progress
|
||||
yield "enrichment_progress", enrichment_progress
|
||||
yield "timed_out", False
|
||||
return
|
||||
|
||||
# Wait before next check with exponential backoff
|
||||
await asyncio.sleep(interval)
|
||||
interval = min(interval * 1.5, input_data.max_interval)
|
||||
|
||||
# Timeout reached
|
||||
elapsed = time.time() - start_time
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
final_status = (
|
||||
webset.status.value
|
||||
if hasattr(webset.status, "value")
|
||||
else str(webset.status)
|
||||
)
|
||||
|
||||
item_count = 0
|
||||
if webset.searches:
|
||||
for search in webset.searches:
|
||||
if search.progress:
|
||||
item_count += search.progress.found
|
||||
|
||||
search_progress = {}
|
||||
enrichment_progress = {}
|
||||
if input_data.include_progress:
|
||||
webset_dict = webset.model_dump(by_alias=True, exclude_none=True)
|
||||
search_progress = self._extract_search_progress(webset_dict)
|
||||
enrichment_progress = self._extract_enrichment_progress(webset_dict)
|
||||
|
||||
yield "webset_id", input_data.webset_id
|
||||
yield "final_status", final_status
|
||||
yield "elapsed_time", elapsed
|
||||
yield "item_count", item_count
|
||||
if input_data.include_progress:
|
||||
yield "search_progress", search_progress
|
||||
yield "enrichment_progress", enrichment_progress
|
||||
yield "timed_out", True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise ValueError(
|
||||
f"Polling timed out after {input_data.timeout} seconds"
|
||||
) from None
|
||||
|
||||
def _extract_search_progress(self, webset_data: dict) -> dict:
|
||||
"""Extract search progress information from webset data."""
|
||||
progress = {}
|
||||
searches = webset_data.get("searches", [])
|
||||
|
||||
for idx, search in enumerate(searches):
|
||||
search_id = search.get("id", f"search_{idx}")
|
||||
search_progress = search.get("progress", {})
|
||||
|
||||
progress[search_id] = {
|
||||
"status": search.get("status", "unknown"),
|
||||
"found": search_progress.get("found", 0),
|
||||
"analyzed": search_progress.get("analyzed", 0),
|
||||
"completion": search_progress.get("completion", 0),
|
||||
"time_left": search_progress.get("timeLeft", 0),
|
||||
}
|
||||
|
||||
return progress
|
||||
|
||||
def _extract_enrichment_progress(self, webset_data: dict) -> dict:
|
||||
"""Extract enrichment progress information from webset data."""
|
||||
progress = {}
|
||||
enrichments = webset_data.get("enrichments", [])
|
||||
|
||||
for idx, enrichment in enumerate(enrichments):
|
||||
enrich_id = enrichment.get("id", f"enrichment_{idx}")
|
||||
|
||||
progress[enrich_id] = {
|
||||
"status": enrichment.get("status", "unknown"),
|
||||
"title": enrichment.get("title", ""),
|
||||
"description": enrichment.get("description", ""),
|
||||
}
|
||||
|
||||
return progress
|
||||
|
||||
|
||||
class ExaWaitForSearchBlock(Block):
|
||||
"""Wait for a specific webset search to complete with progress tracking."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
search_id: str = SchemaField(
|
||||
description="The ID of the search to monitor",
|
||||
placeholder="search-id",
|
||||
)
|
||||
timeout: int = SchemaField(
|
||||
default=300,
|
||||
description="Maximum time to wait in seconds",
|
||||
ge=1,
|
||||
le=1800,
|
||||
)
|
||||
check_interval: int = SchemaField(
|
||||
default=5,
|
||||
description="Initial interval between status checks in seconds",
|
||||
advanced=True,
|
||||
ge=1,
|
||||
le=60,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
search_id: str = SchemaField(description="The search ID that was monitored")
|
||||
final_status: str = SchemaField(description="The final status of the search")
|
||||
items_found: int = SchemaField(
|
||||
description="Number of items found by the search"
|
||||
)
|
||||
items_analyzed: int = SchemaField(description="Number of items analyzed")
|
||||
completion_percentage: int = SchemaField(
|
||||
description="Completion percentage (0-100)"
|
||||
)
|
||||
elapsed_time: float = SchemaField(description="Total time elapsed in seconds")
|
||||
recall_info: dict = SchemaField(
|
||||
description="Information about expected results and confidence"
|
||||
)
|
||||
timed_out: bool = SchemaField(description="Whether the operation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="14da21ae-40a1-41bc-a111-c8e5c9ef012b",
|
||||
description="Wait for a specific webset search to complete with progress tracking",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaWaitForSearchBlock.Input,
|
||||
output_schema=ExaWaitForSearchBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
start_time = time.time()
|
||||
interval = input_data.check_interval
|
||||
max_interval = 30
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
try:
|
||||
while time.time() - start_time < input_data.timeout:
|
||||
# Get current search status using SDK
|
||||
search = aexa.websets.searches.get(
|
||||
webset_id=input_data.webset_id, id=input_data.search_id
|
||||
)
|
||||
|
||||
# Extract status
|
||||
status = (
|
||||
search.status.value
|
||||
if hasattr(search.status, "value")
|
||||
else str(search.status)
|
||||
)
|
||||
|
||||
# Check if search is complete
|
||||
if status in ["completed", "failed", "canceled"]:
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Extract progress information
|
||||
progress_dict = {}
|
||||
if search.progress:
|
||||
progress_dict = search.progress.model_dump(
|
||||
by_alias=True, exclude_none=True
|
||||
)
|
||||
|
||||
# Extract recall information
|
||||
recall_info = {}
|
||||
if search.recall:
|
||||
recall_dict = search.recall.model_dump(
|
||||
by_alias=True, exclude_none=True
|
||||
)
|
||||
expected = recall_dict.get("expected", {})
|
||||
recall_info = {
|
||||
"expected_total": expected.get("total", 0),
|
||||
"confidence": expected.get("confidence", ""),
|
||||
"min_expected": expected.get("bounds", {}).get("min", 0),
|
||||
"max_expected": expected.get("bounds", {}).get("max", 0),
|
||||
"reasoning": recall_dict.get("reasoning", ""),
|
||||
}
|
||||
|
||||
yield "search_id", input_data.search_id
|
||||
yield "final_status", status
|
||||
yield "items_found", progress_dict.get("found", 0)
|
||||
yield "items_analyzed", progress_dict.get("analyzed", 0)
|
||||
yield "completion_percentage", progress_dict.get("completion", 0)
|
||||
yield "elapsed_time", elapsed
|
||||
yield "recall_info", recall_info
|
||||
yield "timed_out", False
|
||||
|
||||
return
|
||||
|
||||
# Wait before next check with exponential backoff
|
||||
await asyncio.sleep(interval)
|
||||
interval = min(interval * 1.5, max_interval)
|
||||
|
||||
# Timeout reached
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Get last known status
|
||||
search = aexa.websets.searches.get(
|
||||
webset_id=input_data.webset_id, id=input_data.search_id
|
||||
)
|
||||
final_status = (
|
||||
search.status.value
|
||||
if hasattr(search.status, "value")
|
||||
else str(search.status)
|
||||
)
|
||||
|
||||
progress_dict = {}
|
||||
if search.progress:
|
||||
progress_dict = search.progress.model_dump(
|
||||
by_alias=True, exclude_none=True
|
||||
)
|
||||
|
||||
yield "search_id", input_data.search_id
|
||||
yield "final_status", final_status
|
||||
yield "items_found", progress_dict.get("found", 0)
|
||||
yield "items_analyzed", progress_dict.get("analyzed", 0)
|
||||
yield "completion_percentage", progress_dict.get("completion", 0)
|
||||
yield "elapsed_time", elapsed
|
||||
yield "timed_out", True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise ValueError(
|
||||
f"Search polling timed out after {input_data.timeout} seconds"
|
||||
) from None
|
||||
|
||||
|
||||
class ExaWaitForEnrichmentBlock(Block):
|
||||
"""Wait for a webset enrichment to complete with progress tracking."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The ID of the enrichment to monitor",
|
||||
placeholder="enrichment-id",
|
||||
)
|
||||
timeout: int = SchemaField(
|
||||
default=300,
|
||||
description="Maximum time to wait in seconds",
|
||||
ge=1,
|
||||
le=1800,
|
||||
)
|
||||
check_interval: int = SchemaField(
|
||||
default=5,
|
||||
description="Initial interval between status checks in seconds",
|
||||
advanced=True,
|
||||
ge=1,
|
||||
le=60,
|
||||
)
|
||||
sample_results: bool = SchemaField(
|
||||
default=True,
|
||||
description="Include sample enrichment results in output",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The enrichment ID that was monitored"
|
||||
)
|
||||
final_status: str = SchemaField(
|
||||
description="The final status of the enrichment"
|
||||
)
|
||||
items_enriched: int = SchemaField(
|
||||
description="Number of items successfully enriched"
|
||||
)
|
||||
enrichment_title: str = SchemaField(
|
||||
description="Title/description of the enrichment"
|
||||
)
|
||||
elapsed_time: float = SchemaField(description="Total time elapsed in seconds")
|
||||
sample_data: list[SampleEnrichmentModel] = SchemaField(
|
||||
description="Sample of enriched data (if requested)"
|
||||
)
|
||||
timed_out: bool = SchemaField(description="Whether the operation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a11865c3-ac80-4721-8a40-ac4e3b71a558",
|
||||
description="Wait for a webset enrichment to complete with progress tracking",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaWaitForEnrichmentBlock.Input,
|
||||
output_schema=ExaWaitForEnrichmentBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
start_time = time.time()
|
||||
interval = input_data.check_interval
|
||||
max_interval = 30
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
try:
|
||||
while time.time() - start_time < input_data.timeout:
|
||||
# Get current enrichment status using SDK
|
||||
enrichment = aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
# Extract status
|
||||
status = (
|
||||
enrichment.status.value
|
||||
if hasattr(enrichment.status, "value")
|
||||
else str(enrichment.status)
|
||||
)
|
||||
|
||||
# Check if enrichment is complete
|
||||
if status in ["completed", "failed", "canceled"]:
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Get sample enriched items if requested
|
||||
sample_data = []
|
||||
items_enriched = 0
|
||||
|
||||
if input_data.sample_results and status == "completed":
|
||||
sample_data, items_enriched = (
|
||||
await self._get_sample_enrichments(
|
||||
input_data.webset_id, input_data.enrichment_id, aexa
|
||||
)
|
||||
)
|
||||
|
||||
yield "enrichment_id", input_data.enrichment_id
|
||||
yield "final_status", status
|
||||
yield "items_enriched", items_enriched
|
||||
yield "enrichment_title", enrichment.title or enrichment.description or ""
|
||||
yield "elapsed_time", elapsed
|
||||
if input_data.sample_results:
|
||||
yield "sample_data", sample_data
|
||||
yield "timed_out", False
|
||||
|
||||
return
|
||||
|
||||
# Wait before next check with exponential backoff
|
||||
await asyncio.sleep(interval)
|
||||
interval = min(interval * 1.5, max_interval)
|
||||
|
||||
# Timeout reached
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Get last known status
|
||||
enrichment = aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
final_status = (
|
||||
enrichment.status.value
|
||||
if hasattr(enrichment.status, "value")
|
||||
else str(enrichment.status)
|
||||
)
|
||||
title = enrichment.title or enrichment.description or ""
|
||||
|
||||
yield "enrichment_id", input_data.enrichment_id
|
||||
yield "final_status", final_status
|
||||
yield "items_enriched", 0
|
||||
yield "enrichment_title", title
|
||||
yield "elapsed_time", elapsed
|
||||
yield "timed_out", True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise ValueError(
|
||||
f"Enrichment polling timed out after {input_data.timeout} seconds"
|
||||
) from None
|
||||
|
||||
async def _get_sample_enrichments(
|
||||
self, webset_id: str, enrichment_id: str, aexa: AsyncExa
|
||||
) -> tuple[list[SampleEnrichmentModel], int]:
|
||||
"""Get sample enriched data and count."""
|
||||
# Get a few items to see enrichment results using SDK
|
||||
response = aexa.websets.items.list(webset_id=webset_id, limit=5)
|
||||
|
||||
sample_data: list[SampleEnrichmentModel] = []
|
||||
enriched_count = 0
|
||||
|
||||
for sdk_item in response.data:
|
||||
# Convert to our WebsetItemModel first
|
||||
item = WebsetItemModel.from_sdk(sdk_item)
|
||||
|
||||
# Check if this item has the enrichment we're looking for
|
||||
if enrichment_id in item.enrichments:
|
||||
enriched_count += 1
|
||||
enrich_model = item.enrichments[enrichment_id]
|
||||
|
||||
# Create sample using our typed model
|
||||
sample = SampleEnrichmentModel(
|
||||
item_id=item.id,
|
||||
item_title=item.title,
|
||||
enrichment_data=enrich_model.model_dump(exclude_none=True),
|
||||
)
|
||||
sample_data.append(sample)
|
||||
|
||||
return sample_data, enriched_count
|
||||
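All three polling blocks in this file share the same capped exponential-backoff loop: sleep `check_interval`, grow the interval by 1.5x per poll, cap it at `max_interval`. A minimal standalone sketch of just that schedule, not part of the diff; the `poll` callable is a hypothetical stand-in for the SDK status checks above:

import asyncio
import time
from typing import Awaitable, Callable


async def poll_with_backoff(
    poll: Callable[[], Awaitable[bool]],
    timeout: float = 300.0,
    check_interval: float = 5.0,
    max_interval: float = 30.0,
) -> bool:
    """Return True when poll() reports completion, False on timeout."""
    start = time.time()
    interval = check_interval
    while time.time() - start < timeout:
        if await poll():
            return True  # target status reached
        await asyncio.sleep(interval)
        interval = min(interval * 1.5, max_interval)  # exponential backoff, capped
    return False  # timed out; the blocks above yield timed_out=True instead of raising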
650
autogpt_platform/backend/backend/blocks/exa/websets_search.py
Normal file
@@ -0,0 +1,650 @@
"""
Exa Websets Search Management Blocks

This module provides blocks for creating and managing searches within websets,
including adding new searches, checking status, and canceling operations.
"""

from enum import Enum
from typing import Any, Dict, List, Optional

from exa_py import AsyncExa
from exa_py.websets.types import WebsetSearch as SdkWebsetSearch
from pydantic import BaseModel

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    CredentialsMetaInput,
    SchemaField,
)

from ._config import exa


# Mirrored model for stability
class WebsetSearchModel(BaseModel):
    """Stable output model mirroring SDK WebsetSearch."""

    id: str
    webset_id: str
    status: str
    query: str
    entity_type: str
    criteria: List[Dict[str, Any]]
    count: int
    behavior: str
    progress: Dict[str, Any]
    recall: Optional[Dict[str, Any]]
    created_at: str
    updated_at: str
    canceled_at: Optional[str]
    canceled_reason: Optional[str]
    metadata: Dict[str, Any]

    @classmethod
    def from_sdk(cls, search: SdkWebsetSearch) -> "WebsetSearchModel":
        """Convert SDK WebsetSearch to our stable model."""
        # Extract entity type
        entity_type = "auto"
        if search.entity:
            entity_dict = search.entity.model_dump(by_alias=True)
            entity_type = entity_dict.get("type", "auto")

        # Convert criteria
        criteria = [c.model_dump(by_alias=True) for c in search.criteria]

        # Convert progress
        progress_dict = {}
        if search.progress:
            progress_dict = search.progress.model_dump(by_alias=True)

        # Convert recall
        recall_dict = None
        if search.recall:
            recall_dict = search.recall.model_dump(by_alias=True)

        return cls(
            id=search.id,
            webset_id=search.webset_id,
            status=(
                search.status.value
                if hasattr(search.status, "value")
                else str(search.status)
            ),
            query=search.query,
            entity_type=entity_type,
            criteria=criteria,
            count=search.count,
            behavior=search.behavior.value if search.behavior else "override",
            progress=progress_dict,
            recall=recall_dict,
            created_at=search.created_at.isoformat() if search.created_at else "",
            updated_at=search.updated_at.isoformat() if search.updated_at else "",
            canceled_at=search.canceled_at.isoformat() if search.canceled_at else None,
            canceled_reason=(
                search.canceled_reason.value if search.canceled_reason else None
            ),
            metadata=search.metadata if search.metadata else {},
        )


class SearchBehavior(str, Enum):
    """Behavior for how new search results interact with existing items."""

    OVERRIDE = "override"  # Replace existing items
    APPEND = "append"  # Add to existing items
    MERGE = "merge"  # Merge with existing items


class SearchEntityType(str, Enum):
    COMPANY = "company"
    PERSON = "person"
    ARTICLE = "article"
    RESEARCH_PAPER = "research_paper"
    CUSTOM = "custom"
    AUTO = "auto"


class ExaCreateWebsetSearchBlock(Block):
    """Add a new search to an existing webset."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        query: str = SchemaField(
            description="Search query describing what to find",
            placeholder="Engineering managers at Fortune 500 companies",
        )
        count: int = SchemaField(
            default=10,
            description="Number of items to find",
            ge=1,
            le=1000,
        )

        # Entity configuration
        entity_type: SearchEntityType = SchemaField(
            default=SearchEntityType.AUTO,
            description="Type of entity to search for",
        )
        entity_description: Optional[str] = SchemaField(
            default=None,
            description="Description for custom entity type",
            advanced=True,
        )

        # Criteria for verification
        criteria: list[str] = SchemaField(
            default_factory=list,
            description="List of criteria that items must meet. If not provided, auto-detected from query.",
            advanced=True,
        )

        # Advanced search options
        behavior: SearchBehavior = SchemaField(
            default=SearchBehavior.APPEND,
            description="How new results interact with existing items",
            advanced=True,
        )
        recall: bool = SchemaField(
            default=True,
            description="Enable recall estimation for expected results",
            advanced=True,
        )

        # Exclude sources
        exclude_source_ids: list[str] = SchemaField(
            default_factory=list,
            description="IDs of imports/websets to exclude from results",
            advanced=True,
        )
        exclude_source_types: list[str] = SchemaField(
            default_factory=list,
            description="Types of sources to exclude ('import' or 'webset')",
            advanced=True,
        )

        # Scope sources
        scope_source_ids: list[str] = SchemaField(
            default_factory=list,
            description="IDs of imports/websets to limit search scope to",
            advanced=True,
        )
        scope_source_types: list[str] = SchemaField(
            default_factory=list,
            description="Types of scope sources ('import' or 'webset')",
            advanced=True,
        )
        scope_relationships: list[str] = SchemaField(
            default_factory=list,
            description="Relationship definitions for hop searches",
            advanced=True,
        )
        scope_relationship_limits: list[int] = SchemaField(
            default_factory=list,
            description="Limits on related entities to find",
            advanced=True,
        )

        metadata: Optional[dict] = SchemaField(
            default=None,
            description="Metadata to attach to the search",
            advanced=True,
        )

        # Polling options
        wait_for_completion: bool = SchemaField(
            default=False,
            description="Wait for the search to complete before returning",
        )
        polling_timeout: int = SchemaField(
            default=300,
            description="Maximum time to wait for completion in seconds",
            advanced=True,
            ge=1,
            le=600,
        )

    class Output(BlockSchemaOutput):
        search_id: str = SchemaField(
            description="The unique identifier for the created search"
        )
        webset_id: str = SchemaField(description="The webset this search belongs to")
        status: str = SchemaField(description="Current status of the search")
        query: str = SchemaField(description="The search query")
        expected_results: dict = SchemaField(
            description="Recall estimation of expected results"
        )
        items_found: Optional[int] = SchemaField(
            description="Number of items found (if wait_for_completion was True)"
        )
        completion_time: Optional[float] = SchemaField(
            description="Time taken to complete in seconds (if wait_for_completion was True)"
        )

    def __init__(self):
        super().__init__(
            id="342ff776-2e2c-4cdb-b392-4eeb34b21d5f",
            description="Add a new search to an existing webset to find more items",
            categories={BlockCategory.SEARCH},
            input_schema=ExaCreateWebsetSearchBlock.Input,
            output_schema=ExaCreateWebsetSearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        import time

        # Build the payload
        payload = {
            "query": input_data.query,
            "count": input_data.count,
            "behavior": input_data.behavior.value,
            "recall": input_data.recall,
        }

        # Add entity configuration
        if input_data.entity_type != SearchEntityType.AUTO:
            entity = {"type": input_data.entity_type.value}
            if (
                input_data.entity_type == SearchEntityType.CUSTOM
                and input_data.entity_description
            ):
                entity["description"] = input_data.entity_description
            payload["entity"] = entity

        # Add criteria if provided
        if input_data.criteria:
            payload["criteria"] = [{"description": c} for c in input_data.criteria]

        # Add exclude sources
        if input_data.exclude_source_ids:
            exclude_list = []
            for idx, src_id in enumerate(input_data.exclude_source_ids):
                src_type = "import"
                if input_data.exclude_source_types and idx < len(
                    input_data.exclude_source_types
                ):
                    src_type = input_data.exclude_source_types[idx]
                exclude_list.append({"source": src_type, "id": src_id})
            payload["exclude"] = exclude_list

        # Add scope sources
        if input_data.scope_source_ids:
            scope_list: list[dict[str, Any]] = []
            for idx, src_id in enumerate(input_data.scope_source_ids):
                scope_item: dict[str, Any] = {"source": "import", "id": src_id}

                if input_data.scope_source_types and idx < len(
                    input_data.scope_source_types
                ):
                    scope_item["source"] = input_data.scope_source_types[idx]

                # Add relationship if provided
                if input_data.scope_relationships and idx < len(
                    input_data.scope_relationships
                ):
                    relationship: dict[str, Any] = {
                        "definition": input_data.scope_relationships[idx]
                    }
                    if input_data.scope_relationship_limits and idx < len(
                        input_data.scope_relationship_limits
                    ):
                        relationship["limit"] = input_data.scope_relationship_limits[
                            idx
                        ]
                    scope_item["relationship"] = relationship

                scope_list.append(scope_item)
            payload["scope"] = scope_list

        # Add metadata if provided
        if input_data.metadata:
            payload["metadata"] = input_data.metadata

        start_time = time.time()

        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_search = await aexa.websets.searches.create(
            webset_id=input_data.webset_id, params=payload
        )

        search_id = sdk_search.id
        status = (
            sdk_search.status.value
            if hasattr(sdk_search.status, "value")
            else str(sdk_search.status)
        )

        # Extract expected results from recall
        expected_results = {}
        if sdk_search.recall:
            recall_dict = sdk_search.recall.model_dump(by_alias=True)
            expected = recall_dict.get("expected", {})
            expected_results = {
                "total": expected.get("total", 0),
                "confidence": expected.get("confidence", ""),
                "min": expected.get("bounds", {}).get("min", 0),
                "max": expected.get("bounds", {}).get("max", 0),
                "reasoning": recall_dict.get("reasoning", ""),
            }

        # If wait_for_completion is True, poll for completion
        if input_data.wait_for_completion:
            import asyncio

            poll_interval = 5
            max_interval = 30
            poll_start = time.time()

            while time.time() - poll_start < input_data.polling_timeout:
                current_search = await aexa.websets.searches.get(
                    webset_id=input_data.webset_id, id=search_id
                )
                current_status = (
                    current_search.status.value
                    if hasattr(current_search.status, "value")
                    else str(current_search.status)
                )

                if current_status in ["completed", "failed", "cancelled"]:
                    items_found = 0
                    if current_search.progress:
                        items_found = current_search.progress.found
                    completion_time = time.time() - start_time

                    yield "search_id", search_id
                    yield "webset_id", input_data.webset_id
                    yield "status", current_status
                    yield "query", input_data.query
                    yield "expected_results", expected_results
                    yield "items_found", items_found
                    yield "completion_time", completion_time
                    return

                await asyncio.sleep(poll_interval)
                poll_interval = min(poll_interval * 1.5, max_interval)

            # Timeout - yield what we have
            yield "search_id", search_id
            yield "webset_id", input_data.webset_id
            yield "status", status
            yield "query", input_data.query
            yield "expected_results", expected_results
            yield "items_found", 0
            yield "completion_time", time.time() - start_time
        else:
            yield "search_id", search_id
            yield "webset_id", input_data.webset_id
            yield "status", status
            yield "query", input_data.query
            yield "expected_results", expected_results


class ExaGetWebsetSearchBlock(Block):
    """Get the status and details of a webset search."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        search_id: str = SchemaField(
            description="The ID of the search to retrieve",
            placeholder="search-id",
        )

    class Output(BlockSchemaOutput):
        search_id: str = SchemaField(description="The unique identifier for the search")
        status: str = SchemaField(description="Current status of the search")
        query: str = SchemaField(description="The search query")
        entity_type: str = SchemaField(description="Type of entity being searched")
        criteria: list[dict] = SchemaField(description="Criteria used for verification")
        progress: dict = SchemaField(description="Search progress information")
        recall: dict = SchemaField(description="Recall estimation information")
        created_at: str = SchemaField(description="When the search was created")
        updated_at: str = SchemaField(description="When the search was last updated")
        canceled_at: Optional[str] = SchemaField(
            description="When the search was canceled (if applicable)"
        )
        canceled_reason: Optional[str] = SchemaField(
            description="Reason for cancellation (if applicable)"
        )
        metadata: dict = SchemaField(description="Metadata attached to the search")

    def __init__(self):
        super().__init__(
            id="4fa3e627-a0ff-485f-8732-52148051646c",
            description="Get the status and details of a webset search",
            categories={BlockCategory.SEARCH},
            input_schema=ExaGetWebsetSearchBlock.Input,
            output_schema=ExaGetWebsetSearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_search = await aexa.websets.searches.get(
            webset_id=input_data.webset_id, id=input_data.search_id
        )

        search = WebsetSearchModel.from_sdk(sdk_search)

        # Extract progress information
        progress_info = {
            "found": search.progress.get("found", 0),
            "analyzed": search.progress.get("analyzed", 0),
            "completion": search.progress.get("completion", 0),
            "time_left": search.progress.get("timeLeft", 0),
        }

        # Extract recall information
        recall_data = {}
        if search.recall:
            expected = search.recall.get("expected", {})
            recall_data = {
                "expected_total": expected.get("total", 0),
                "confidence": expected.get("confidence", ""),
                "min_expected": expected.get("bounds", {}).get("min", 0),
                "max_expected": expected.get("bounds", {}).get("max", 0),
                "reasoning": search.recall.get("reasoning", ""),
            }

        yield "search_id", search.id
        yield "status", search.status
        yield "query", search.query
        yield "entity_type", search.entity_type
        yield "criteria", search.criteria
        yield "progress", progress_info
        yield "recall", recall_data
        yield "created_at", search.created_at
        yield "updated_at", search.updated_at
        yield "canceled_at", search.canceled_at
        yield "canceled_reason", search.canceled_reason
        yield "metadata", search.metadata


class ExaCancelWebsetSearchBlock(Block):
    """Cancel a running webset search."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        search_id: str = SchemaField(
            description="The ID of the search to cancel",
            placeholder="search-id",
        )

    class Output(BlockSchemaOutput):
        search_id: str = SchemaField(description="The ID of the canceled search")
        status: str = SchemaField(description="Status after cancellation")
        items_found_before_cancel: int = SchemaField(
            description="Number of items found before cancellation"
        )
        success: str = SchemaField(
            description="Whether the cancellation was successful"
        )

    def __init__(self):
        super().__init__(
            id="74ef9f1e-ae89-4c7f-9d7d-d217214815b4",
            description="Cancel a running webset search",
            categories={BlockCategory.SEARCH},
            input_schema=ExaCancelWebsetSearchBlock.Input,
            output_schema=ExaCancelWebsetSearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        canceled_search = await aexa.websets.searches.cancel(
            webset_id=input_data.webset_id, id=input_data.search_id
        )

        # Extract items found before cancellation
        items_found = 0
        if canceled_search.progress:
            items_found = canceled_search.progress.found

        status = (
            canceled_search.status.value
            if hasattr(canceled_search.status, "value")
            else str(canceled_search.status)
        )

        yield "search_id", canceled_search.id
        yield "status", status
        yield "items_found_before_cancel", items_found
        yield "success", "true"


class ExaFindOrCreateSearchBlock(Block):
    """Find existing search by query or create new one (prevents duplicate searches)."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        query: str = SchemaField(
            description="Search query to find or create",
            placeholder="AI companies in San Francisco",
        )
        count: int = SchemaField(
            default=10,
            description="Number of items to find (only used if creating new search)",
            ge=1,
            le=1000,
        )
        entity_type: SearchEntityType = SchemaField(
            default=SearchEntityType.AUTO,
            description="Entity type (only used if creating)",
            advanced=True,
        )
        behavior: SearchBehavior = SchemaField(
            default=SearchBehavior.OVERRIDE,
            description="Search behavior (only used if creating)",
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        search_id: str = SchemaField(description="The search ID (existing or new)")
        webset_id: str = SchemaField(description="The webset ID")
        status: str = SchemaField(description="Current search status")
        query: str = SchemaField(description="The search query")
        was_created: bool = SchemaField(
            description="True if search was newly created, False if already existed"
        )
        items_found: int = SchemaField(
            description="Number of items found (0 if still running)"
        )

    def __init__(self):
        super().__init__(
            id="cbdb05ac-cb73-4b03-a493-6d34e9a011da",
            description="Find existing search by query or create new - prevents duplicate searches in workflows",
            categories={BlockCategory.SEARCH},
            input_schema=ExaFindOrCreateSearchBlock.Input,
            output_schema=ExaFindOrCreateSearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # Get webset to check existing searches
        webset = await aexa.websets.get(id=input_data.webset_id)

        # Look for existing search with same query
        existing_search = None
        if webset.searches:
            for search in webset.searches:
                if search.query.strip().lower() == input_data.query.strip().lower():
                    existing_search = search
                    break

        if existing_search:
            # Found existing search
            search = WebsetSearchModel.from_sdk(existing_search)

            yield "search_id", search.id
            yield "webset_id", input_data.webset_id
            yield "status", search.status
            yield "query", search.query
            yield "was_created", False
            yield "items_found", search.progress.get("found", 0)
        else:
            # Create new search
            payload: Dict[str, Any] = {
                "query": input_data.query,
                "count": input_data.count,
                "behavior": input_data.behavior.value,
            }

            # Add entity if not auto
            if input_data.entity_type != SearchEntityType.AUTO:
                payload["entity"] = {"type": input_data.entity_type.value}

            sdk_search = await aexa.websets.searches.create(
                webset_id=input_data.webset_id, params=payload
            )

            search = WebsetSearchModel.from_sdk(sdk_search)

            yield "search_id", search.id
            yield "webset_id", input_data.webset_id
            yield "status", search.status
            yield "query", search.query
            yield "was_created", True
            yield "items_found", 0  # Newly created, no items yet
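To make the payload construction in ExaCreateWebsetSearchBlock concrete, here is one hedged example of the dict the block could send for a scoped person search. The keys match the code above; every concrete ID and value is illustrative only, not taken from the diff:

payload = {
    "query": "Engineering managers at Fortune 500 companies",
    "count": 25,
    "behavior": "append",
    "recall": True,
    # entity_type != AUTO adds an entity object
    "entity": {"type": "person"},
    # each criteria string becomes a {"description": ...} object
    "criteria": [{"description": "Currently employed as an engineering manager"}],
    # exclude_source_ids/exclude_source_types zip into source references
    "exclude": [{"source": "webset", "id": "webset-previous-run"}],
    # scope entries may carry an optional relationship for hop searches
    "scope": [
        {
            "source": "import",
            "id": "import-crm-accounts",
            "relationship": {"definition": "employees of", "limit": 10},
        }
    ],
}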
@@ -10,7 +10,13 @@ from backend.blocks.fal._auth import (
    FalCredentialsField,
    FalCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.model import SchemaField
from backend.util.request import ClientResponseError, Requests

@@ -24,7 +30,7 @@ class FalModel(str, Enum):


class AIVideoGeneratorBlock(Block):
    class Input(BlockSchema):
    class Input(BlockSchemaInput):
        prompt: str = SchemaField(
            description="Description of the video to generate.",
            placeholder="A dog running in a field.",
@@ -36,7 +42,7 @@ class AIVideoGeneratorBlock(Block):
        )
        credentials: FalCredentialsInput = FalCredentialsField()

    class Output(BlockSchema):
    class Output(BlockSchemaOutput):
        video_url: str = SchemaField(description="The URL of the generated video.")
        error: str = SchemaField(
            description="Error message if video generation failed."
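The fal hunk above shows the migration pattern that every remaining hunk in this range repeats: nested Input/Output schemas move from the shared BlockSchema base onto the split BlockSchemaInput/BlockSchemaOutput bases, with field definitions untouched. On a hypothetical block (ExampleBlock and its fields are illustrative, not from the diff), the change looks like:

from backend.data.block import (
    Block,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.model import SchemaField


class ExampleBlock(Block):
    class Input(BlockSchemaInput):  # was: class Input(BlockSchema)
        url: str = SchemaField(description="The URL to process")

    class Output(BlockSchemaOutput):  # was: class Output(BlockSchema)
        result: str = SchemaField(description="The processed result")
        error: str = SchemaField(description="Error message on failure", default="")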
@@ -9,7 +9,8 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -19,7 +20,7 @@ from ._format_utils import convert_to_format_options
|
||||
|
||||
|
||||
class FirecrawlCrawlBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
url: str = SchemaField(description="The URL to crawl")
|
||||
limit: int = SchemaField(description="The number of pages to crawl", default=10)
|
||||
@@ -39,7 +40,7 @@ class FirecrawlCrawlBlock(Block):
|
||||
description="The format of the crawl", default=[ScrapeFormat.MARKDOWN]
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
data: list[dict[str, Any]] = SchemaField(description="The result of the crawl")
|
||||
markdown: str = SchemaField(description="The markdown of the crawl")
|
||||
html: str = SchemaField(description="The html of the crawl")
|
||||
@@ -55,6 +56,10 @@ class FirecrawlCrawlBlock(Block):
|
||||
change_tracking: dict[str, Any] = SchemaField(
|
||||
description="The change tracking of the crawl"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the crawl failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -9,7 +9,8 @@ from backend.sdk import (
|
||||
BlockCost,
|
||||
BlockCostType,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
cost,
|
||||
@@ -20,7 +21,7 @@ from ._config import firecrawl
|
||||
|
||||
@cost(BlockCost(2, BlockCostType.RUN))
|
||||
class FirecrawlExtractBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
urls: list[str] = SchemaField(
|
||||
description="The URLs to crawl - at least one is required. Wildcards are supported. (/*)"
|
||||
@@ -37,8 +38,12 @@ class FirecrawlExtractBlock(Block):
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
data: dict[str, Any] = SchemaField(description="The result of the crawl")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the extraction failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -7,7 +7,8 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -16,16 +17,20 @@ from ._config import firecrawl
|
||||
|
||||
|
||||
class FirecrawlMapWebsiteBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
|
||||
url: str = SchemaField(description="The website url to map")
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
links: list[str] = SchemaField(description="List of URLs found on the website")
|
||||
results: list[dict[str, Any]] = SchemaField(
|
||||
description="List of search results with url, title, and description"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the map failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -8,7 +8,8 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -18,7 +19,7 @@ from ._format_utils import convert_to_format_options
|
||||
|
||||
|
||||
class FirecrawlScrapeBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
url: str = SchemaField(description="The URL to crawl")
|
||||
limit: int = SchemaField(description="The number of pages to crawl", default=10)
|
||||
@@ -38,7 +39,7 @@ class FirecrawlScrapeBlock(Block):
|
||||
description="The format of the crawl", default=[ScrapeFormat.MARKDOWN]
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
data: dict[str, Any] = SchemaField(description="The result of the crawl")
|
||||
markdown: str = SchemaField(description="The markdown of the crawl")
|
||||
html: str = SchemaField(description="The html of the crawl")
|
||||
@@ -54,6 +55,10 @@ class FirecrawlScrapeBlock(Block):
|
||||
change_tracking: dict[str, Any] = SchemaField(
|
||||
description="The change tracking of the crawl"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the scrape failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -9,7 +9,8 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -19,7 +20,7 @@ from ._format_utils import convert_to_format_options
|
||||
|
||||
|
||||
class FirecrawlSearchBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
query: str = SchemaField(description="The query to search for")
|
||||
limit: int = SchemaField(description="The number of pages to crawl", default=10)
|
||||
@@ -35,9 +36,13 @@ class FirecrawlSearchBlock(Block):
|
||||
description="Returns the content of the search if specified", default=[]
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
data: dict[str, Any] = SchemaField(description="The result of the search")
|
||||
site: dict[str, Any] = SchemaField(description="The site of the search")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the search failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -5,7 +5,13 @@ from pydantic import SecretStr
|
||||
from replicate.client import Client as ReplicateClient
|
||||
from replicate.helpers import FileOutput
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -57,7 +63,7 @@ class AspectRatio(str, Enum):
|
||||
|
||||
|
||||
class AIImageEditorBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
@@ -90,11 +96,10 @@ class AIImageEditorBlock(Block):
|
||||
title="Model",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
output_image: MediaFileType = SchemaField(
|
||||
description="URL of the transformed image"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if generation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -3,7 +3,8 @@ from backend.sdk import (
|
||||
BlockCategory,
|
||||
BlockManualWebhookConfig,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
ProviderBuilder,
|
||||
ProviderName,
|
||||
SchemaField,
|
||||
@@ -19,14 +20,14 @@ generic_webhook = (
|
||||
|
||||
|
||||
class GenericWebhookTriggerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
payload: dict = SchemaField(hidden=True, default_factory=dict)
|
||||
constants: dict = SchemaField(
|
||||
description="The constants to be set when the block is put on the graph",
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
payload: dict = SchemaField(
|
||||
description="The complete webhook payload that was received from the generic webhook."
|
||||
)
|
||||
|
||||
@@ -3,7 +3,13 @@ from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
@@ -39,7 +45,7 @@ class ChecksConclusion(Enum):
|
||||
class GithubCreateCheckRunBlock(Block):
|
||||
"""Block for creating a new check run on a GitHub repository."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo:status")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
@@ -76,7 +82,7 @@ class GithubCreateCheckRunBlock(Block):
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
class CheckRunResult(BaseModel):
|
||||
id: int
|
||||
html_url: str
|
||||
@@ -211,7 +217,7 @@ class GithubCreateCheckRunBlock(Block):
|
||||
class GithubUpdateCheckRunBlock(Block):
|
||||
"""Block for updating an existing check run on a GitHub repository."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo:status")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
@@ -239,7 +245,7 @@ class GithubUpdateCheckRunBlock(Block):
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
class CheckRunResult(BaseModel):
|
||||
id: int
|
||||
html_url: str
|
||||
@@ -249,7 +255,6 @@ class GithubUpdateCheckRunBlock(Block):
|
||||
check_run: CheckRunResult = SchemaField(
|
||||
description="Details of the updated check run"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if check run update failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -5,7 +5,13 @@ from typing import Optional
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
@@ -37,7 +43,7 @@ class CheckRunConclusion(Enum):
|
||||
|
||||
|
||||
class GithubGetCIResultsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo: str = SchemaField(
|
||||
description="GitHub repository",
|
||||
@@ -60,7 +66,7 @@ class GithubGetCIResultsBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
class CheckRunItem(TypedDict, total=False):
|
||||
id: int
|
||||
name: str
|
||||
@@ -104,7 +110,6 @@ class GithubGetCIResultsBlock(Block):
|
||||
total_checks: int = SchemaField(description="Total number of CI checks")
|
||||
passed_checks: int = SchemaField(description="Number of passed checks")
|
||||
failed_checks: int = SchemaField(description="Number of failed checks")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -3,7 +3,13 @@ from urllib.parse import urlparse

 from typing_extensions import TypedDict

-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField

 from ._api import convert_comment_url_to_api_endpoint, get_api
@@ -24,7 +30,7 @@ def is_github_url(url: str) -> bool:

 # --8<-- [start:GithubCommentBlockExample]
 class GithubCommentBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         issue_url: str = SchemaField(
             description="URL of the GitHub issue or pull request",
@@ -35,7 +41,7 @@ class GithubCommentBlock(Block):
             placeholder="Enter your comment",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         id: int = SchemaField(description="ID of the created comment")
         url: str = SchemaField(description="URL to the comment on GitHub")
         error: str = SchemaField(
@@ -112,7 +118,7 @@ class GithubCommentBlock(Block):


 class GithubUpdateCommentBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         comment_url: str = SchemaField(
             description="URL of the GitHub comment",
@@ -135,7 +141,7 @@ class GithubUpdateCommentBlock(Block):
             placeholder="Enter your comment",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         id: int = SchemaField(description="ID of the updated comment")
         url: str = SchemaField(description="URL to the comment on GitHub")
         error: str = SchemaField(
@@ -219,14 +225,14 @@ class GithubUpdateCommentBlock(Block):


 class GithubListCommentsBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         issue_url: str = SchemaField(
             description="URL of the GitHub issue or pull request",
             placeholder="https://github.com/owner/repo/issues/1",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         class CommentItem(TypedDict):
             id: int
             body: str
@@ -239,7 +245,6 @@ class GithubListCommentsBlock(Block):
         comments: list[CommentItem] = SchemaField(
             description="List of comments with their ID, body, user, and URL"
         )
-        error: str = SchemaField(description="Error message if listing comments failed")

     def __init__(self):
         super().__init__(
@@ -335,7 +340,7 @@ class GithubListCommentsBlock(Block):


 class GithubMakeIssueBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -348,7 +353,7 @@ class GithubMakeIssueBlock(Block):
             description="Body of the issue", placeholder="Enter the issue body"
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         number: int = SchemaField(description="Number of the created issue")
         url: str = SchemaField(description="URL of the created issue")
         error: str = SchemaField(
@@ -410,14 +415,14 @@ class GithubMakeIssueBlock(Block):


 class GithubReadIssueBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         issue_url: str = SchemaField(
             description="URL of the GitHub issue",
             placeholder="https://github.com/owner/repo/issues/1",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         title: str = SchemaField(description="Title of the issue")
         body: str = SchemaField(description="Body of the issue")
         user: str = SchemaField(description="User who created the issue")
@@ -483,14 +488,14 @@ class GithubReadIssueBlock(Block):


 class GithubListIssuesBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
             placeholder="https://github.com/owner/repo",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         class IssueItem(TypedDict):
             title: str
             url: str
@@ -501,7 +506,6 @@ class GithubListIssuesBlock(Block):
         issues: list[IssueItem] = SchemaField(
             description="List of issues with their title and URL"
         )
-        error: str = SchemaField(description="Error message if listing issues failed")

     def __init__(self):
         super().__init__(
@@ -573,7 +577,7 @@ class GithubListIssuesBlock(Block):


 class GithubAddLabelBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         issue_url: str = SchemaField(
             description="URL of the GitHub issue or pull request",
@@ -584,7 +588,7 @@ class GithubAddLabelBlock(Block):
             placeholder="Enter the label",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         status: str = SchemaField(description="Status of the label addition operation")
         error: str = SchemaField(
             description="Error message if the label addition failed"
@@ -633,7 +637,7 @@ class GithubAddLabelBlock(Block):


 class GithubRemoveLabelBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         issue_url: str = SchemaField(
             description="URL of the GitHub issue or pull request",
@@ -644,7 +648,7 @@ class GithubRemoveLabelBlock(Block):
             placeholder="Enter the label",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         status: str = SchemaField(description="Status of the label removal operation")
         error: str = SchemaField(
             description="Error message if the label removal failed"
@@ -694,7 +698,7 @@ class GithubRemoveLabelBlock(Block):


 class GithubAssignIssueBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         issue_url: str = SchemaField(
             description="URL of the GitHub issue",
@@ -705,7 +709,7 @@ class GithubAssignIssueBlock(Block):
             placeholder="Enter the username",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         status: str = SchemaField(
             description="Status of the issue assignment operation"
         )
@@ -760,7 +764,7 @@ class GithubAssignIssueBlock(Block):


 class GithubUnassignIssueBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         issue_url: str = SchemaField(
             description="URL of the GitHub issue",
@@ -771,7 +775,7 @@ class GithubUnassignIssueBlock(Block):
             placeholder="Enter the username",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         status: str = SchemaField(
             description="Status of the issue unassignment operation"
         )

@@ -2,7 +2,13 @@ import re

 from typing_extensions import TypedDict

-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField

 from ._api import get_api
@@ -16,14 +22,14 @@ from ._auth import (


 class GithubListPullRequestsBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
             placeholder="https://github.com/owner/repo",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         class PRItem(TypedDict):
             title: str
             url: str
@@ -108,7 +114,7 @@ class GithubListPullRequestsBlock(Block):


 class GithubMakePullRequestBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -135,7 +141,7 @@ class GithubMakePullRequestBlock(Block):
             placeholder="Enter the base branch",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         number: int = SchemaField(description="Number of the created pull request")
         url: str = SchemaField(description="URL of the created pull request")
         error: str = SchemaField(
@@ -209,7 +215,7 @@ class GithubMakePullRequestBlock(Block):


 class GithubReadPullRequestBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         pr_url: str = SchemaField(
             description="URL of the GitHub pull request",
@@ -221,7 +227,7 @@ class GithubReadPullRequestBlock(Block):
             advanced=False,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         title: str = SchemaField(description="Title of the pull request")
         body: str = SchemaField(description="Body of the pull request")
         author: str = SchemaField(description="User who created the pull request")
@@ -325,7 +331,7 @@ class GithubReadPullRequestBlock(Block):


 class GithubAssignPRReviewerBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         pr_url: str = SchemaField(
             description="URL of the GitHub pull request",
@@ -336,7 +342,7 @@ class GithubAssignPRReviewerBlock(Block):
             placeholder="Enter the reviewer's username",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         status: str = SchemaField(
             description="Status of the reviewer assignment operation"
         )
@@ -392,7 +398,7 @@ class GithubAssignPRReviewerBlock(Block):


 class GithubUnassignPRReviewerBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         pr_url: str = SchemaField(
             description="URL of the GitHub pull request",
@@ -403,7 +409,7 @@ class GithubUnassignPRReviewerBlock(Block):
             placeholder="Enter the reviewer's username",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         status: str = SchemaField(
             description="Status of the reviewer unassignment operation"
         )
@@ -459,14 +465,14 @@ class GithubUnassignPRReviewerBlock(Block):


 class GithubListPRReviewersBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         pr_url: str = SchemaField(
             description="URL of the GitHub pull request",
             placeholder="https://github.com/owner/repo/pull/1",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         class ReviewerItem(TypedDict):
             username: str
             url: str

@@ -2,7 +2,13 @@ import base64

 from typing_extensions import TypedDict

-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField

 from ._api import get_api
@@ -16,14 +22,14 @@ from ._auth import (


 class GithubListTagsBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
             placeholder="https://github.com/owner/repo",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         class TagItem(TypedDict):
             name: str
             url: str
@@ -34,7 +40,6 @@ class GithubListTagsBlock(Block):
         tags: list[TagItem] = SchemaField(
             description="List of tags with their name and file tree browser URL"
         )
-        error: str = SchemaField(description="Error message if listing tags failed")

     def __init__(self):
         super().__init__(
@@ -111,14 +116,14 @@ class GithubListTagsBlock(Block):


 class GithubListBranchesBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
             placeholder="https://github.com/owner/repo",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         class BranchItem(TypedDict):
             name: str
             url: str
@@ -130,7 +135,6 @@ class GithubListBranchesBlock(Block):
         branches: list[BranchItem] = SchemaField(
             description="List of branches with their name and file tree browser URL"
         )
-        error: str = SchemaField(description="Error message if listing branches failed")

     def __init__(self):
         super().__init__(
@@ -207,7 +211,7 @@ class GithubListBranchesBlock(Block):


 class GithubListDiscussionsBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -217,7 +221,7 @@ class GithubListDiscussionsBlock(Block):
             description="Number of discussions to fetch", default=5
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         class DiscussionItem(TypedDict):
             title: str
             url: str
@@ -323,14 +327,14 @@ class GithubListDiscussionsBlock(Block):


 class GithubListReleasesBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
             placeholder="https://github.com/owner/repo",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         class ReleaseItem(TypedDict):
             name: str
             url: str
@@ -342,7 +346,6 @@ class GithubListReleasesBlock(Block):
         releases: list[ReleaseItem] = SchemaField(
             description="List of releases with their name and file tree browser URL"
         )
-        error: str = SchemaField(description="Error message if listing releases failed")

     def __init__(self):
         super().__init__(
@@ -414,7 +417,7 @@ class GithubListReleasesBlock(Block):


 class GithubReadFileBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -430,7 +433,7 @@ class GithubReadFileBlock(Block):
             default="master",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         text_content: str = SchemaField(
             description="Content of the file (decoded as UTF-8 text)"
         )
@@ -438,7 +441,6 @@ class GithubReadFileBlock(Block):
             description="Raw base64-encoded content of the file"
         )
         size: int = SchemaField(description="The size of the file (in bytes)")
-        error: str = SchemaField(description="Error message if the file reading failed")

     def __init__(self):
         super().__init__(
@@ -501,7 +503,7 @@ class GithubReadFileBlock(Block):


 class GithubReadFolderBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -517,7 +519,7 @@ class GithubReadFolderBlock(Block):
             default="master",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         class DirEntry(TypedDict):
             name: str
             path: str
@@ -625,7 +627,7 @@ class GithubReadFolderBlock(Block):


 class GithubMakeBranchBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -640,7 +642,7 @@ class GithubMakeBranchBlock(Block):
             placeholder="source_branch_name",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         status: str = SchemaField(description="Status of the branch creation operation")
         error: str = SchemaField(
             description="Error message if the branch creation failed"
@@ -705,7 +707,7 @@ class GithubMakeBranchBlock(Block):


 class GithubDeleteBranchBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -716,7 +718,7 @@ class GithubDeleteBranchBlock(Block):
             placeholder="branch_name",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         status: str = SchemaField(description="Status of the branch deletion operation")
         error: str = SchemaField(
             description="Error message if the branch deletion failed"
@@ -766,7 +768,7 @@ class GithubDeleteBranchBlock(Block):


 class GithubCreateFileBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -789,7 +791,7 @@ class GithubCreateFileBlock(Block):
             default="Create new file",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         url: str = SchemaField(description="URL of the created file")
         sha: str = SchemaField(description="SHA of the commit")
         error: str = SchemaField(
@@ -868,7 +870,7 @@ class GithubCreateFileBlock(Block):


 class GithubUpdateFileBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -891,10 +893,9 @@ class GithubUpdateFileBlock(Block):
             default="Update file",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         url: str = SchemaField(description="URL of the updated file")
         sha: str = SchemaField(description="SHA of the commit")
-        error: str = SchemaField(description="Error message if the file update failed")

     def __init__(self):
         super().__init__(
@@ -974,7 +975,7 @@ class GithubUpdateFileBlock(Block):


 class GithubCreateRepositoryBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         name: str = SchemaField(
             description="Name of the repository to create",
@@ -998,7 +999,7 @@ class GithubCreateRepositoryBlock(Block):
             default="",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         url: str = SchemaField(description="URL of the created repository")
         clone_url: str = SchemaField(description="Git clone URL of the repository")
         error: str = SchemaField(
@@ -1077,14 +1078,14 @@ class GithubCreateRepositoryBlock(Block):


 class GithubListStargazersBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
             placeholder="https://github.com/owner/repo",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         class StargazerItem(TypedDict):
             username: str
             url: str

@@ -4,7 +4,13 @@ from typing import Any, List, Optional

 from typing_extensions import TypedDict

-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField

 from ._api import get_api
@@ -26,7 +32,7 @@ class ReviewEvent(Enum):


 class GithubCreatePRReviewBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         class ReviewComment(TypedDict, total=False):
             path: str
             position: Optional[int]
@@ -61,7 +67,7 @@ class GithubCreatePRReviewBlock(Block):
             advanced=True,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         review_id: int = SchemaField(description="ID of the created review")
         state: str = SchemaField(
             description="State of the review (e.g., PENDING, COMMENTED, APPROVED, CHANGES_REQUESTED)"
@@ -197,7 +203,7 @@ class GithubCreatePRReviewBlock(Block):


 class GithubListPRReviewsBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo: str = SchemaField(
             description="GitHub repository",
@@ -208,7 +214,7 @@ class GithubListPRReviewsBlock(Block):
             placeholder="123",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         class ReviewItem(TypedDict):
             id: int
             user: str
@@ -223,7 +229,6 @@ class GithubListPRReviewsBlock(Block):
         reviews: list[ReviewItem] = SchemaField(
             description="List of all reviews on the pull request"
         )
-        error: str = SchemaField(description="Error message if listing reviews failed")

     def __init__(self):
         super().__init__(
@@ -317,7 +322,7 @@ class GithubListPRReviewsBlock(Block):


 class GithubSubmitPendingReviewBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo: str = SchemaField(
             description="GitHub repository",
@@ -336,7 +341,7 @@ class GithubSubmitPendingReviewBlock(Block):
             default=ReviewEvent.COMMENT,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         state: str = SchemaField(description="State of the submitted review")
         html_url: str = SchemaField(description="URL of the submitted review")
         error: str = SchemaField(
@@ -415,7 +420,7 @@ class GithubSubmitPendingReviewBlock(Block):


 class GithubResolveReviewDiscussionBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo: str = SchemaField(
             description="GitHub repository",
@@ -434,9 +439,8 @@ class GithubResolveReviewDiscussionBlock(Block):
             default=True,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         success: bool = SchemaField(description="Whether the operation was successful")
-        error: str = SchemaField(description="Error message if the operation failed")

     def __init__(self):
         super().__init__(
@@ -579,7 +583,7 @@ class GithubResolveReviewDiscussionBlock(Block):


 class GithubGetPRReviewCommentsBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo: str = SchemaField(
             description="GitHub repository",
@@ -596,7 +600,7 @@ class GithubGetPRReviewCommentsBlock(Block):
             advanced=True,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         class CommentItem(TypedDict):
             id: int
             user: str
@@ -616,7 +620,6 @@ class GithubGetPRReviewCommentsBlock(Block):
         comments: list[CommentItem] = SchemaField(
             description="List of all review comments on the pull request"
         )
-        error: str = SchemaField(description="Error message if getting comments failed")

     def __init__(self):
         super().__init__(
@@ -744,7 +747,7 @@ class GithubGetPRReviewCommentsBlock(Block):


 class GithubCreateCommentObjectBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         path: str = SchemaField(
             description="The file path to comment on",
             placeholder="src/main.py",
@@ -781,7 +784,7 @@ class GithubCreateCommentObjectBlock(Block):
             advanced=True,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         comment_object: dict = SchemaField(
             description="The comment object formatted for GitHub API"
         )

@@ -3,7 +3,13 @@ from typing import Optional

 from pydantic import BaseModel

-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField

 from ._api import get_api
@@ -26,7 +32,7 @@ class StatusState(Enum):
 class GithubCreateStatusBlock(Block):
     """Block for creating a commit status on a GitHub repository."""

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubFineGrainedAPICredentialsInput = (
             GithubFineGrainedAPICredentialsField("repo:status")
         )
@@ -54,7 +60,7 @@ class GithubCreateStatusBlock(Block):
             advanced=False,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         class StatusResult(BaseModel):
             id: int
             url: str
@@ -66,7 +72,6 @@ class GithubCreateStatusBlock(Block):
             updated_at: str

         status: StatusResult = SchemaField(description="Details of the created status")
-        error: str = SchemaField(description="Error message if status creation failed")

     def __init__(self):
         super().__init__(

@@ -8,7 +8,8 @@ from backend.data.block import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchema,
+    BlockSchemaInput,
+    BlockSchemaOutput,
     BlockWebhookConfig,
 )
 from backend.data.model import SchemaField
@@ -26,7 +27,7 @@ logger = logging.getLogger(__name__)

 # --8<-- [start:GithubTriggerExample]
 class GitHubTriggerBase:
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo: str = SchemaField(
             description=(
@@ -40,7 +41,7 @@ class GitHubTriggerBase:
         payload: dict = SchemaField(hidden=True, default_factory=dict)
         # --8<-- [end:example-payload-field]

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         payload: dict = SchemaField(
             description="The complete webhook payload that was received from GitHub. "
             "Includes information about the affected resource (e.g. pull request), "

@@ -8,7 +8,13 @@ from google.oauth2.credentials import Credentials
 from googleapiclient.discovery import build
 from pydantic import BaseModel

-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField
 from backend.util.settings import Settings

@@ -43,7 +49,7 @@ class CalendarEvent(BaseModel):


 class GoogleCalendarReadEventsBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/calendar.readonly"]
         )
@@ -73,7 +79,7 @@ class GoogleCalendarReadEventsBlock(Block):
             description="Include events you've declined", default=False
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         events: list[CalendarEvent] = SchemaField(
             description="List of calendar events in the requested time range",
             default_factory=list,
@@ -379,7 +385,7 @@ class RecurringEvent(BaseModel):


 class GoogleCalendarCreateEventBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/calendar"]
         )
@@ -433,12 +439,11 @@ class GoogleCalendarCreateEventBlock(Block):
             default_factory=lambda: [ReminderPreset.TEN_MINUTES],
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         event_id: str = SchemaField(description="ID of the created event")
         event_link: str = SchemaField(
             description="Link to view the event in Google Calendar"
         )
-        error: str = SchemaField(description="Error message if event creation failed")

     def __init__(self):
         super().__init__(

@@ -14,7 +14,13 @@
 from googleapiclient.discovery import build
 from pydantic import BaseModel, Field

-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField
 from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
 from backend.util.settings import Settings
@@ -320,7 +326,7 @@ class GmailBase(Block, ABC):


 class GmailReadBlock(GmailBase):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/gmail.readonly"]
         )
@@ -333,7 +339,7 @@ class GmailReadBlock(GmailBase):
             default=10,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         email: Email = SchemaField(
             description="Email data",
         )
@@ -516,7 +522,7 @@ class GmailSendBlock(GmailBase):
     - Attachment support for multiple files
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/gmail.send"]
         )
@@ -540,7 +546,7 @@ class GmailSendBlock(GmailBase):
             description="Files to attach", default_factory=list, advanced=True
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: GmailSendResult = SchemaField(
             description="Send confirmation",
         )
@@ -618,7 +624,7 @@ class GmailCreateDraftBlock(GmailBase):
     - Attachment support for multiple files
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/gmail.modify"]
         )
@@ -642,7 +648,7 @@ class GmailCreateDraftBlock(GmailBase):
             description="Files to attach", default_factory=list, advanced=True
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: GmailDraftResult = SchemaField(
             description="Draft creation result",
         )
@@ -721,12 +727,12 @@ class GmailCreateDraftBlock(GmailBase):


 class GmailListLabelsBlock(GmailBase):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/gmail.labels"]
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: list[dict] = SchemaField(
             description="List of labels",
         )
@@ -779,7 +785,7 @@ class GmailListLabelsBlock(GmailBase):


 class GmailAddLabelBlock(GmailBase):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/gmail.modify"]
         )
@@ -790,7 +796,7 @@ class GmailAddLabelBlock(GmailBase):
             description="Label name to add",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: GmailLabelResult = SchemaField(
             description="Label addition result",
         )
@@ -865,7 +871,7 @@ class GmailAddLabelBlock(GmailBase):


 class GmailRemoveLabelBlock(GmailBase):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/gmail.modify"]
         )
@@ -876,7 +882,7 @@ class GmailRemoveLabelBlock(GmailBase):
             description="Label name to remove",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: GmailLabelResult = SchemaField(
             description="Label removal result",
         )
@@ -941,17 +947,16 @@ class GmailRemoveLabelBlock(GmailBase):


 class GmailGetThreadBlock(GmailBase):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/gmail.readonly"]
         )
         threadId: str = SchemaField(description="Gmail thread ID")

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         thread: Thread = SchemaField(
             description="Gmail thread with decoded message bodies"
         )
-        error: str = SchemaField(description="Error message if any")

     def __init__(self):
         super().__init__(
@@ -1218,7 +1223,7 @@ class GmailReplyBlock(GmailBase):
     - Full Unicode/emoji support with UTF-8 encoding
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             [
                 "https://www.googleapis.com/auth/gmail.send",
@@ -1246,14 +1251,13 @@ class GmailReplyBlock(GmailBase):
             description="Files to attach", default_factory=list, advanced=True
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         messageId: str = SchemaField(description="Sent message ID")
         threadId: str = SchemaField(description="Thread ID")
         message: dict = SchemaField(description="Raw Gmail message object")
         email: Email = SchemaField(
             description="Parsed email object with decoded body and attachments"
         )
-        error: str = SchemaField(description="Error message if any")

     def __init__(self):
         super().__init__(
@@ -1368,7 +1372,7 @@ class GmailDraftReplyBlock(GmailBase):
     - Full Unicode/emoji support with UTF-8 encoding
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             [
                 "https://www.googleapis.com/auth/gmail.modify",
@@ -1396,12 +1400,11 @@ class GmailDraftReplyBlock(GmailBase):
             description="Files to attach", default_factory=list, advanced=True
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         draftId: str = SchemaField(description="Created draft ID")
         messageId: str = SchemaField(description="Draft message ID")
         threadId: str = SchemaField(description="Thread ID")
         status: str = SchemaField(description="Draft creation status")
-        error: str = SchemaField(description="Error message if any")

     def __init__(self):
         super().__init__(
@@ -1482,14 +1485,13 @@ class GmailDraftReplyBlock(GmailBase):


 class GmailGetProfileBlock(GmailBase):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/gmail.readonly"]
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         profile: Profile = SchemaField(description="Gmail user profile information")
-        error: str = SchemaField(description="Error message if any")

     def __init__(self):
         super().__init__(
@@ -1555,7 +1557,7 @@ class GmailForwardBlock(GmailBase):
     - Manual content type override option
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             [
                 "https://www.googleapis.com/auth/gmail.send",
@@ -1589,11 +1591,10 @@ class GmailForwardBlock(GmailBase):
             advanced=True,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         messageId: str = SchemaField(description="Forwarded message ID")
         threadId: str = SchemaField(description="Thread ID")
         status: str = SchemaField(description="Forward status")
-        error: str = SchemaField(description="Error message if any")

     def __init__(self):
         super().__init__(

@@ -5,7 +5,13 @@ from typing import Any
 from google.oauth2.credentials import Credentials
 from googleapiclient.discovery import build

-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField
 from backend.util.settings import Settings

@@ -195,7 +201,7 @@ class BatchOperationType(str, Enum):
     CLEAR = "clear"


-class BatchOperation(BlockSchema):
+class BatchOperation(BlockSchemaInput):
     type: BatchOperationType = SchemaField(
         description="The type of operation to perform"
     )
@@ -206,7 +212,7 @@ class BatchOperation(BlockSchema):


 class GoogleSheetsReadBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/spreadsheets.readonly"]
         )
@@ -218,7 +224,7 @@ class GoogleSheetsReadBlock(Block):
             description="The A1 notation of the range to read",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: list[list[str]] = SchemaField(
             description="The data read from the spreadsheet",
         )
@@ -274,7 +280,7 @@ class GoogleSheetsReadBlock(Block):


 class GoogleSheetsWriteBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/spreadsheets"]
         )
@@ -289,7 +295,7 @@ class GoogleSheetsWriteBlock(Block):
             description="The data to write to the spreadsheet",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: dict = SchemaField(
             description="The result of the write operation",
         )
@@ -363,7 +369,7 @@ class GoogleSheetsWriteBlock(Block):


 class GoogleSheetsAppendBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/spreadsheets"]
         )
@@ -403,9 +409,8 @@ class GoogleSheetsAppendBlock(Block):
             advanced=True,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: dict = SchemaField(description="Append API response")
-        error: str = SchemaField(description="Error message, if any")

     def __init__(self):
         super().__init__(
@@ -503,7 +508,7 @@ class GoogleSheetsAppendBlock(Block):


 class GoogleSheetsClearBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/spreadsheets"]
         )
@@ -515,7 +520,7 @@ class GoogleSheetsClearBlock(Block):
             description="The A1 notation of the range to clear",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: dict = SchemaField(
             description="The result of the clear operation",
         )
@@ -571,7 +576,7 @@ class GoogleSheetsClearBlock(Block):


 class GoogleSheetsMetadataBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/spreadsheets.readonly"]
         )
@@ -580,7 +585,7 @@ class GoogleSheetsMetadataBlock(Block):
             title="Spreadsheet ID or URL",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: dict = SchemaField(
             description="The metadata of the spreadsheet including sheets info",
         )
@@ -652,7 +657,7 @@ class GoogleSheetsMetadataBlock(Block):


 class GoogleSheetsManageSheetBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/spreadsheets"]
         )
@@ -672,9 +677,8 @@ class GoogleSheetsManageSheetBlock(Block):
             description="New sheet name for copy", default=""
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: dict = SchemaField(description="Operation result")
-        error: str = SchemaField(description="Error message, if any")

     def __init__(self):
         super().__init__(
@@ -760,7 +764,7 @@ class GoogleSheetsManageSheetBlock(Block):


 class GoogleSheetsBatchOperationsBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/spreadsheets"]
         )
@@ -772,7 +776,7 @@ class GoogleSheetsBatchOperationsBlock(Block):
             description="List of operations to perform",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: dict = SchemaField(
             description="The result of the batch operations",
         )
@@ -877,7 +881,7 @@ class GoogleSheetsBatchOperationsBlock(Block):


 class GoogleSheetsFindReplaceBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/spreadsheets"]
         )
@@ -904,7 +908,7 @@ class GoogleSheetsFindReplaceBlock(Block):
             default=False,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: dict = SchemaField(
             description="The result of the find/replace operation including number of replacements",
         )
@@ -987,7 +991,7 @@ class GoogleSheetsFindReplaceBlock(Block):


 class GoogleSheetsFindBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/spreadsheets.readonly"]
         )
@@ -1020,7 +1024,7 @@ class GoogleSheetsFindBlock(Block):
             advanced=True,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: dict = SchemaField(
             description="The result of the find operation including locations and count",
         )
@@ -1255,7 +1259,7 @@ class GoogleSheetsFindBlock(Block):


 class GoogleSheetsFormatBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/spreadsheets"]
         )
@@ -1270,9 +1274,8 @@ class GoogleSheetsFormatBlock(Block):
         italic: bool = SchemaField(default=False)
         font_size: int = SchemaField(default=10)

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: dict = SchemaField(description="API response or success flag")
-        error: str = SchemaField(description="Error message, if any")

     def __init__(self):
         super().__init__(
@@ -1383,7 +1386,7 @@ class GoogleSheetsFormatBlock(Block):


 class GoogleSheetsCreateSpreadsheetBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
             ["https://www.googleapis.com/auth/spreadsheets"]
         )
@@ -1395,7 +1398,7 @@ class GoogleSheetsCreateSpreadsheetBlock(Block):
             default=["Sheet1"],
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: dict = SchemaField(
             description="The result containing spreadsheet ID and URL",
         )

|
||||
import googlemaps
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -37,7 +43,7 @@ class Place(BaseModel):
|
||||
|
||||
|
||||
class GoogleMapsSearchBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.GOOGLE_MAPS], Literal["api_key"]
|
||||
] = CredentialsField(description="Google Maps API Key")
|
||||
@@ -58,9 +64,8 @@ class GoogleMapsSearchBlock(Block):
|
||||
le=60,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
place: Place = SchemaField(description="Place found")
|
||||
error: str = SchemaField(description="Error message if the search failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -8,7 +8,13 @@ from typing import Literal
 import aiofiles
 from pydantic import SecretStr

-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import (
     CredentialsField,
     CredentialsMetaInput,
@@ -62,7 +68,7 @@ class HttpMethod(Enum):


 class SendWebRequestBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         url: str = SchemaField(
             description="The URL to send the request to",
             placeholder="https://api.example.com",
@@ -93,7 +99,7 @@ class SendWebRequestBlock(Block):
             default_factory=list,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         response: object = SchemaField(description="The response from the server")
         client_error: object = SchemaField(description="Errors on 4xx status codes")
         server_error: object = SchemaField(description="Errors on 5xx status codes")

@@ -3,13 +3,19 @@ from backend.blocks.hubspot._auth import (
     HubSpotCredentialsField,
     HubSpotCredentialsInput,
 )
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests


 class HubSpotCompanyBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: HubSpotCredentialsInput = HubSpotCredentialsField()
         operation: str = SchemaField(
             description="Operation to perform (create, update, get)", default="get"
@@ -22,7 +28,7 @@ class HubSpotCompanyBlock(Block):
             description="Company domain for get/update operations", default=""
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         company: dict = SchemaField(description="Company information")
         status: str = SchemaField(description="Operation status")

@@ -3,13 +3,19 @@ from backend.blocks.hubspot._auth import (
     HubSpotCredentialsField,
     HubSpotCredentialsInput,
 )
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests


 class HubSpotContactBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: HubSpotCredentialsInput = HubSpotCredentialsField()
         operation: str = SchemaField(
             description="Operation to perform (create, update, get)", default="get"
@@ -22,7 +28,7 @@ class HubSpotContactBlock(Block):
             description="Email address for get/update operations", default=""
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         contact: dict = SchemaField(description="Contact information")
         status: str = SchemaField(description="Operation status")

@@ -5,13 +5,19 @@ from backend.blocks.hubspot._auth import (
     HubSpotCredentialsField,
     HubSpotCredentialsInput,
 )
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests


 class HubSpotEngagementBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: HubSpotCredentialsInput = HubSpotCredentialsField()
         operation: str = SchemaField(
             description="Operation to perform (send_email, track_engagement)",
@@ -29,7 +35,7 @@ class HubSpotEngagementBlock(Block):
             default=30,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: dict = SchemaField(description="Operation result")
         status: str = SchemaField(description="Operation status")

@@ -4,7 +4,13 @@ from typing import Any, Dict, Literal, Optional
 from pydantic import SecretStr
 from requests.exceptions import RequestException

-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -84,7 +90,7 @@ class UpscaleOption(str, Enum):


 class IdeogramModelBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput[
             Literal[ProviderName.IDEOGRAM], Literal["api_key"]
         ] = CredentialsField(
@@ -154,9 +160,8 @@ class IdeogramModelBlock(Block):
             advanced=True,
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         result: str = SchemaField(description="Generated image URL")
-        error: str = SchemaField(description="Error message if the model run failed")

     def __init__(self):
         super().__init__(

@@ -2,7 +2,14 @@ import copy
 from datetime import date, time
 from typing import Any, Optional

-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchema,
+    BlockSchemaInput,
+    BlockType,
+)
 from backend.data.model import SchemaField
 from backend.util.file import store_media_file
 from backend.util.mock import MockObject
@@ -22,7 +29,7 @@ class AgentInputBlock(Block):
     It Outputs the value passed as input.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         name: str = SchemaField(description="The name of the input.")
         value: Any = SchemaField(
             description="The value to be passed as input.",
@@ -60,6 +67,7 @@ class AgentInputBlock(Block):
         return schema

     class Output(BlockSchema):
+        # Use BlockSchema to avoid automatic error field for interface definition
         result: Any = SchemaField(description="The value passed as input.")

     def __init__(self, **kwargs):
@@ -109,7 +117,7 @@ class AgentOutputBlock(Block):
     If formatting fails or no `format` is provided, the raw `value` is output.
     """

-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         value: Any = SchemaField(
             description="The value to be recorded as output.",
             default=None,
@@ -151,6 +159,7 @@ class AgentOutputBlock(Block):
         return self.get_field_schema("value")

     class Output(BlockSchema):
+        # Use BlockSchema to avoid automatic error field for interface definition
         output: Any = SchemaField(description="The value recorded as output.")
         name: Any = SchemaField(description="The name of the value recorded as output.")

@@ -1,12 +1,18 @@
 from typing import Any

-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.block import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.model import SchemaField
 from backend.util.json import loads


 class StepThroughItemsBlock(Block):
-    class Input(BlockSchema):
+    class Input(BlockSchemaInput):
         items: list = SchemaField(
             advanced=False,
             description="The list or dictionary of items to iterate over",
@@ -26,7 +32,7 @@ class StepThroughItemsBlock(Block):
             default="",
         )

-    class Output(BlockSchema):
+    class Output(BlockSchemaOutput):
         item: Any = SchemaField(description="The current item in the iteration")
         key: Any = SchemaField(
             description="The key or index of the current item in the iteration",

Some files were not shown because too many files have changed in this diff