mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-12 00:28:31 -05:00
Compare commits
1 Commits
hackathon/
...
fix/chat-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
51d2028f0c |
@@ -11,7 +11,7 @@ jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v10
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
# operations-per-run: 5000
|
||||
stale-issue-message: >
|
||||
|
||||
2
.github/workflows/repo-pr-label.yml
vendored
2
.github/workflows/repo-pr-label.yml
vendored
@@ -61,6 +61,6 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v6
|
||||
- uses: actions/labeler@v5
|
||||
with:
|
||||
sync-labels: true
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(ls:*)",
|
||||
"WebFetch(domain:langfuse.com)",
|
||||
"Bash(poetry install:*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
.PHONY: start-core stop-core logs-core format lint migrate run-backend stop-backend run-frontend load-store-agents backfill-store-embeddings
|
||||
.PHONY: start-core stop-core logs-core format lint migrate run-backend run-frontend load-store-agents
|
||||
|
||||
# Run just Supabase + Redis + RabbitMQ
|
||||
start-core:
|
||||
@@ -34,14 +34,7 @@ migrate:
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
|
||||
stop-backend:
|
||||
@echo "Stopping backend processes..."
|
||||
@cd backend && poetry run cli stop 2>/dev/null || true
|
||||
@echo "Killing any processes using backend ports..."
|
||||
@lsof -ti:8001,8002,8003,8004,8005,8006,8007 | xargs kill -9 2>/dev/null || true
|
||||
@echo "Backend stopped"
|
||||
|
||||
run-backend: stop-backend
|
||||
run-backend:
|
||||
cd backend && poetry run app
|
||||
|
||||
run-frontend:
|
||||
@@ -53,9 +46,6 @@ test-data:
|
||||
load-store-agents:
|
||||
cd backend && poetry run load-store-agents
|
||||
|
||||
backfill-store-embeddings:
|
||||
cd backend && poetry run python -m backend.api.features.store.backfill_embeddings
|
||||
|
||||
help:
|
||||
@echo "Usage: make <target>"
|
||||
@echo "Targets:"
|
||||
@@ -65,9 +55,7 @@ help:
|
||||
@echo " logs-core - Tail the logs for core services"
|
||||
@echo " format - Format & lint backend (Python) and frontend (TypeScript) code"
|
||||
@echo " migrate - Run backend database migrations"
|
||||
@echo " stop-backend - Stop any running backend processes"
|
||||
@echo " run-backend - Run the backend FastAPI server (stops existing processes first)"
|
||||
@echo " run-backend - Run the backend FastAPI server"
|
||||
@echo " run-frontend - Run the frontend Next.js development server"
|
||||
@echo " test-data - Run the test data creator"
|
||||
@echo " load-store-agents - Load store agents from agents/ folder into test database"
|
||||
@echo " backfill-store-embeddings - Generate embeddings for store agents that don't have them"
|
||||
@echo " load-store-agents - Load store agents from agents/ folder into test database"
|
||||
@@ -57,9 +57,6 @@ class APIKeySmith:
|
||||
|
||||
def hash_key(self, raw_key: str) -> tuple[str, str]:
|
||||
"""Migrate a legacy hash to secure hash format."""
|
||||
if not raw_key.startswith(self.PREFIX):
|
||||
raise ValueError("Key without 'agpt_' prefix would fail validation")
|
||||
|
||||
salt = self._generate_salt()
|
||||
hash = self._hash_key_with_salt(raw_key, salt)
|
||||
return hash, salt.hex()
|
||||
|
||||
@@ -1,25 +1,29 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from .jwt_utils import bearer_jwt_auth
|
||||
|
||||
|
||||
def add_auth_responses_to_openapi(app: FastAPI) -> None:
|
||||
"""
|
||||
Patch a FastAPI instance's `openapi()` method to add 401 responses
|
||||
Set up custom OpenAPI schema generation that adds 401 responses
|
||||
to all authenticated endpoints.
|
||||
|
||||
This is needed when using HTTPBearer with auto_error=False to get proper
|
||||
401 responses instead of 403, but FastAPI only automatically adds security
|
||||
responses when auto_error=True.
|
||||
"""
|
||||
# Wrap current method to allow stacking OpenAPI schema modifiers like this
|
||||
wrapped_openapi = app.openapi
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = wrapped_openapi()
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
# Add 401 response to all endpoints that have security requirements
|
||||
for path, methods in openapi_schema["paths"].items():
|
||||
|
||||
@@ -58,13 +58,6 @@ V0_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
NVIDIA_API_KEY=
|
||||
|
||||
# Langfuse Prompt Management
|
||||
# Used for managing the CoPilot system prompt externally
|
||||
# Get credentials from https://cloud.langfuse.com or your self-hosted instance
|
||||
LANGFUSE_PUBLIC_KEY=
|
||||
LANGFUSE_SECRET_KEY=
|
||||
LANGFUSE_HOST=https://cloud.langfuse.com
|
||||
|
||||
# OAuth Credentials
|
||||
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
|
||||
# e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
|
||||
@@ -108,7 +108,7 @@ import fastapi.testclient
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.api.features.myroute import router
|
||||
from backend.server.v2.myroute import router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
@@ -149,7 +149,7 @@ These provide the easiest way to set up authentication mocking in test modules:
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
from backend.api.features.myroute import router
|
||||
from backend.server.v2.myroute import router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.api.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
|
||||
from .v1.routes import v1_router
|
||||
|
||||
external_api = FastAPI(
|
||||
title="AutoGPT External API",
|
||||
description="External API for AutoGPT integrations",
|
||||
docs_url="/docs",
|
||||
version="1.0",
|
||||
)
|
||||
|
||||
external_api.add_middleware(SecurityHeadersMiddleware)
|
||||
external_api.include_router(v1_router, prefix="/v1")
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
external_api,
|
||||
service_name="external-api",
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=True,
|
||||
)
|
||||
@@ -1,107 +0,0 @@
|
||||
from fastapi import HTTPException, Security, status
|
||||
from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer
|
||||
from prisma.enums import APIKeyPermission
|
||||
|
||||
from backend.data.auth.api_key import APIKeyInfo, validate_api_key
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.auth.oauth import (
|
||||
InvalidClientError,
|
||||
InvalidTokenError,
|
||||
OAuthAccessTokenInfo,
|
||||
validate_access_token,
|
||||
)
|
||||
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
bearer_auth = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
|
||||
"""Middleware for API key authentication only"""
|
||||
if api_key is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing API key"
|
||||
)
|
||||
|
||||
api_key_obj = await validate_api_key(api_key)
|
||||
|
||||
if not api_key_obj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
||||
)
|
||||
|
||||
return api_key_obj
|
||||
|
||||
|
||||
async def require_access_token(
|
||||
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
|
||||
) -> OAuthAccessTokenInfo:
|
||||
"""Middleware for OAuth access token authentication only"""
|
||||
if bearer is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing Authorization header",
|
||||
)
|
||||
|
||||
try:
|
||||
token_info, _ = await validate_access_token(bearer.credentials)
|
||||
except (InvalidClientError, InvalidTokenError) as e:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
|
||||
|
||||
return token_info
|
||||
|
||||
|
||||
async def require_auth(
|
||||
api_key: str | None = Security(api_key_header),
|
||||
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
|
||||
) -> APIAuthorizationInfo:
|
||||
"""
|
||||
Unified authentication middleware supporting both API keys and OAuth tokens.
|
||||
|
||||
Supports two authentication methods, which are checked in order:
|
||||
1. X-API-Key header (existing API key authentication)
|
||||
2. Authorization: Bearer <token> header (OAuth access token)
|
||||
|
||||
Returns:
|
||||
APIAuthorizationInfo: base class of both APIKeyInfo and OAuthAccessTokenInfo.
|
||||
"""
|
||||
# Try API key first
|
||||
if api_key is not None:
|
||||
api_key_info = await validate_api_key(api_key)
|
||||
if api_key_info:
|
||||
return api_key_info
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
||||
)
|
||||
|
||||
# Try OAuth bearer token
|
||||
if bearer is not None:
|
||||
try:
|
||||
token_info, _ = await validate_access_token(bearer.credentials)
|
||||
return token_info
|
||||
except (InvalidClientError, InvalidTokenError) as e:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
|
||||
|
||||
# No credentials provided
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing authentication. Provide API key or access token.",
|
||||
)
|
||||
|
||||
|
||||
def require_permission(permission: APIKeyPermission):
|
||||
"""
|
||||
Dependency function for checking specific permissions
|
||||
(works with API keys and OAuth tokens)
|
||||
"""
|
||||
|
||||
async def check_permission(
|
||||
auth: APIAuthorizationInfo = Security(require_auth),
|
||||
) -> APIAuthorizationInfo:
|
||||
if permission not in auth.scopes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing required permission: {permission.value}",
|
||||
)
|
||||
return auth
|
||||
|
||||
return check_permission
|
||||
@@ -1,340 +0,0 @@
|
||||
"""Tests for analytics API endpoints."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from .analytics import router as analytics_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(analytics_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module."""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# /log_raw_metric endpoint tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_log_raw_metric_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw metric logging."""
|
||||
mock_result = Mock(id="metric-123-uuid")
|
||||
mock_log_metric = mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "page_load_time",
|
||||
"metric_value": 2.5,
|
||||
"data_string": "/dashboard",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||
assert response.json() == "metric-123-uuid"
|
||||
|
||||
mock_log_metric.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
metric_name="page_load_time",
|
||||
metric_value=2.5,
|
||||
data_string="/dashboard",
|
||||
)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"metric_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_metric_success",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"metric_value,metric_name,data_string,test_id",
|
||||
[
|
||||
(100, "api_calls_count", "external_api", "integer_value"),
|
||||
(0, "error_count", "no_errors", "zero_value"),
|
||||
(-5.2, "temperature_delta", "cooling", "negative_value"),
|
||||
(1.23456789, "precision_test", "float_precision", "float_precision"),
|
||||
(999999999, "large_number", "max_value", "large_number"),
|
||||
(0.0000001, "tiny_number", "min_value", "tiny_number"),
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_various_values(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
metric_value: float,
|
||||
metric_name: str,
|
||||
data_string: str,
|
||||
test_id: str,
|
||||
) -> None:
|
||||
"""Test raw metric logging with various metric values."""
|
||||
mock_result = Mock(id=f"metric-{test_id}-uuid")
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": metric_name,
|
||||
"metric_value": metric_value,
|
||||
"data_string": data_string,
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Failed for {test_id}: {response.text}"
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"metric_id": response.json(), "test_case": test_id},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
f"analytics_metric_{test_id}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"),
|
||||
({"metric_name": "test"}, "Field required"),
|
||||
(
|
||||
{"metric_name": "test", "metric_value": "not_a_number", "data_string": "x"},
|
||||
"Input should be a valid number",
|
||||
),
|
||||
(
|
||||
{"metric_name": "", "metric_value": 1.0, "data_string": "test"},
|
||||
"String should have at least 1 character",
|
||||
),
|
||||
(
|
||||
{"metric_name": "test", "metric_value": 1.0, "data_string": ""},
|
||||
"String should have at least 1 character",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"empty_request",
|
||||
"missing_metric_value_and_data_string",
|
||||
"invalid_metric_value_type",
|
||||
"empty_metric_name",
|
||||
"empty_data_string",
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_validation_errors(
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test validation errors for invalid metric requests."""
|
||||
response = client.post("/log_raw_metric", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||
|
||||
error_text = json.dumps(error_detail)
|
||||
assert (
|
||||
expected_error in error_text
|
||||
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||
|
||||
|
||||
def test_log_raw_metric_service_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error handling when analytics service fails."""
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Database connection failed"),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "test_metric",
|
||||
"metric_value": 1.0,
|
||||
"data_string": "test",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_detail = response.json()["detail"]
|
||||
assert "Database connection failed" in error_detail["message"]
|
||||
assert "hint" in error_detail
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# /log_raw_analytics endpoint tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_log_raw_analytics_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw analytics logging."""
|
||||
mock_result = Mock(id="analytics-789-uuid")
|
||||
mock_log_analytics = mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "user_action",
|
||||
"data": {
|
||||
"action": "button_click",
|
||||
"button_id": "submit_form",
|
||||
"timestamp": "2023-01-01T00:00:00Z",
|
||||
"metadata": {"form_type": "registration", "fields_filled": 5},
|
||||
},
|
||||
"data_index": "button_click_submit_form",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||
assert response.json() == "analytics-789-uuid"
|
||||
|
||||
mock_log_analytics.assert_called_once_with(
|
||||
test_user_id,
|
||||
"user_action",
|
||||
request_data["data"],
|
||||
"button_click_submit_form",
|
||||
)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"analytics_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_analytics_success",
|
||||
)
|
||||
|
||||
|
||||
def test_log_raw_analytics_complex_data(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test raw analytics logging with complex nested data structures."""
|
||||
mock_result = Mock(id="analytics-complex-uuid")
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "agent_execution",
|
||||
"data": {
|
||||
"agent_id": "agent_123",
|
||||
"execution_id": "exec_456",
|
||||
"status": "completed",
|
||||
"duration_ms": 3500,
|
||||
"nodes_executed": 15,
|
||||
"blocks_used": [
|
||||
{"block_id": "llm_block", "count": 3},
|
||||
{"block_id": "http_block", "count": 5},
|
||||
{"block_id": "code_block", "count": 2},
|
||||
],
|
||||
"errors": [],
|
||||
"metadata": {
|
||||
"trigger": "manual",
|
||||
"user_tier": "premium",
|
||||
"environment": "production",
|
||||
},
|
||||
},
|
||||
"data_index": "agent_123_exec_456",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"analytics_id": response.json(), "logged_data": request_data["data"]},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
"analytics_log_analytics_complex_data",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"),
|
||||
({"type": "test"}, "Field required"),
|
||||
(
|
||||
{"type": "test", "data": "not_a_dict", "data_index": "test"},
|
||||
"Input should be a valid dictionary",
|
||||
),
|
||||
({"type": "test", "data": {"key": "value"}}, "Field required"),
|
||||
],
|
||||
ids=[
|
||||
"empty_request",
|
||||
"missing_data_and_data_index",
|
||||
"invalid_data_type",
|
||||
"missing_data_index",
|
||||
],
|
||||
)
|
||||
def test_log_raw_analytics_validation_errors(
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test validation errors for invalid analytics requests."""
|
||||
response = client.post("/log_raw_analytics", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||
|
||||
error_text = json.dumps(error_detail)
|
||||
assert (
|
||||
expected_error in error_text
|
||||
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||
|
||||
|
||||
def test_log_raw_analytics_service_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error handling when analytics service fails."""
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Analytics DB unreachable"),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "test_event",
|
||||
"data": {"key": "value"},
|
||||
"data_index": "test_index",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_detail = response.json()["detail"]
|
||||
assert "Analytics DB unreachable" in error_detail["message"]
|
||||
assert "hint" in error_detail
|
||||
@@ -1,195 +0,0 @@
|
||||
"""Database operations for chat sessions."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import ChatSessionUpdateInput
|
||||
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
|
||||
"""Get a chat session by ID from the database."""
|
||||
session = await PrismaChatSession.prisma().find_unique(
|
||||
where={"id": session_id},
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
# Sort messages by sequence in Python since Prisma doesn't support order_by in include
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> PrismaChatSession:
|
||||
"""Create a new chat session in the database."""
|
||||
data = {
|
||||
"id": session_id,
|
||||
"userId": user_id,
|
||||
"credentials": SafeJson({}),
|
||||
"successfulAgentRuns": SafeJson({}),
|
||||
"successfulAgentSchedules": SafeJson({}),
|
||||
}
|
||||
return await PrismaChatSession.prisma().create(
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
|
||||
|
||||
async def update_chat_session(
|
||||
session_id: str,
|
||||
credentials: dict[str, Any] | None = None,
|
||||
successful_agent_runs: dict[str, Any] | None = None,
|
||||
successful_agent_schedules: dict[str, Any] | None = None,
|
||||
total_prompt_tokens: int | None = None,
|
||||
total_completion_tokens: int | None = None,
|
||||
title: str | None = None,
|
||||
) -> PrismaChatSession | None:
|
||||
"""Update a chat session's metadata."""
|
||||
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
|
||||
|
||||
if credentials is not None:
|
||||
data["credentials"] = SafeJson(credentials)
|
||||
if successful_agent_runs is not None:
|
||||
data["successfulAgentRuns"] = SafeJson(successful_agent_runs)
|
||||
if successful_agent_schedules is not None:
|
||||
data["successfulAgentSchedules"] = SafeJson(successful_agent_schedules)
|
||||
if total_prompt_tokens is not None:
|
||||
data["totalPromptTokens"] = total_prompt_tokens
|
||||
if total_completion_tokens is not None:
|
||||
data["totalCompletionTokens"] = total_completion_tokens
|
||||
if title is not None:
|
||||
data["title"] = title
|
||||
|
||||
session = await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def add_chat_message(
|
||||
session_id: str,
|
||||
role: str,
|
||||
sequence: int,
|
||||
content: str | None = None,
|
||||
name: str | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
refusal: str | None = None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
function_call: dict[str, Any] | None = None,
|
||||
) -> PrismaChatMessage:
|
||||
"""Add a message to a chat session."""
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": role,
|
||||
"sequence": sequence,
|
||||
}
|
||||
|
||||
if content is not None:
|
||||
data["content"] = content
|
||||
if name is not None:
|
||||
data["name"] = name
|
||||
if tool_call_id is not None:
|
||||
data["toolCallId"] = tool_call_id
|
||||
if refusal is not None:
|
||||
data["refusal"] = refusal
|
||||
if tool_calls is not None:
|
||||
data["toolCalls"] = SafeJson(tool_calls)
|
||||
if function_call is not None:
|
||||
data["functionCall"] = SafeJson(function_call)
|
||||
|
||||
# Update session's updatedAt timestamp
|
||||
await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return await PrismaChatMessage.prisma().create(data=data)
|
||||
|
||||
|
||||
async def add_chat_messages_batch(
|
||||
session_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
start_sequence: int,
|
||||
) -> list[PrismaChatMessage]:
|
||||
"""Add multiple messages to a chat session in a batch."""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
created_messages = []
|
||||
for i, msg in enumerate(messages):
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma().create(data=data)
|
||||
created_messages.append(created)
|
||||
|
||||
# Update session's updatedAt timestamp
|
||||
await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return created_messages
|
||||
|
||||
|
||||
async def get_user_chat_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[PrismaChatSession]:
|
||||
"""Get chat sessions for a user, ordered by most recent."""
|
||||
return await PrismaChatSession.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order={"updatedAt": "desc"},
|
||||
take=limit,
|
||||
skip=offset,
|
||||
)
|
||||
|
||||
|
||||
async def get_user_session_count(user_id: str) -> int:
|
||||
"""Get the total number of chat sessions for a user."""
|
||||
return await PrismaChatSession.prisma().count(where={"userId": user_id})
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str) -> bool:
|
||||
"""Delete a chat session and all its messages."""
|
||||
try:
|
||||
await PrismaChatSession.prisma().delete(where={"id": session_id})
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete chat session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_chat_session_message_count(session_id: str) -> int:
|
||||
"""Get the number of messages in a chat session."""
|
||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
||||
return count
|
||||
@@ -1,473 +0,0 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionDeveloperMessageParam,
|
||||
ChatCompletionFunctionMessageParam,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_assistant_message_param import FunctionCall
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
ChatCompletionMessageToolCallParam,
|
||||
Function,
|
||||
)
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util import json
|
||||
from backend.util.exceptions import RedisError
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str | None = None
|
||||
name: str | None = None
|
||||
tool_call_id: str | None = None
|
||||
refusal: str | None = None
|
||||
tool_calls: list[dict] | None = None
|
||||
function_call: dict | None = None
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class ChatSession(BaseModel):
|
||||
session_id: str
|
||||
user_id: str | None
|
||||
title: str | None = None
|
||||
messages: list[ChatMessage]
|
||||
usage: list[Usage]
|
||||
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
|
||||
started_at: datetime
|
||||
updated_at: datetime
|
||||
successful_agent_runs: dict[str, int] = {}
|
||||
successful_agent_schedules: dict[str, int] = {}
|
||||
|
||||
@staticmethod
|
||||
def new(user_id: str | None) -> "ChatSession":
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
title=None,
|
||||
messages=[],
|
||||
usage=[],
|
||||
credentials={},
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_prisma(
|
||||
prisma_session: PrismaChatSession,
|
||||
prisma_messages: list[PrismaChatMessage] | None = None,
|
||||
) -> "ChatSession":
|
||||
"""Convert Prisma models to Pydantic ChatSession."""
|
||||
messages = []
|
||||
if prisma_messages:
|
||||
for msg in prisma_messages:
|
||||
tool_calls = None
|
||||
if msg.toolCalls:
|
||||
tool_calls = (
|
||||
json.loads(msg.toolCalls)
|
||||
if isinstance(msg.toolCalls, str)
|
||||
else msg.toolCalls
|
||||
)
|
||||
|
||||
function_call = None
|
||||
if msg.functionCall:
|
||||
function_call = (
|
||||
json.loads(msg.functionCall)
|
||||
if isinstance(msg.functionCall, str)
|
||||
else msg.functionCall
|
||||
)
|
||||
|
||||
messages.append(
|
||||
ChatMessage(
|
||||
role=msg.role,
|
||||
content=msg.content,
|
||||
name=msg.name,
|
||||
tool_call_id=msg.toolCallId,
|
||||
refusal=msg.refusal,
|
||||
tool_calls=tool_calls,
|
||||
function_call=function_call,
|
||||
)
|
||||
)
|
||||
|
||||
# Parse JSON fields from Prisma
|
||||
credentials = (
|
||||
json.loads(prisma_session.credentials)
|
||||
if isinstance(prisma_session.credentials, str)
|
||||
else prisma_session.credentials or {}
|
||||
)
|
||||
successful_agent_runs = (
|
||||
json.loads(prisma_session.successfulAgentRuns)
|
||||
if isinstance(prisma_session.successfulAgentRuns, str)
|
||||
else prisma_session.successfulAgentRuns or {}
|
||||
)
|
||||
successful_agent_schedules = (
|
||||
json.loads(prisma_session.successfulAgentSchedules)
|
||||
if isinstance(prisma_session.successfulAgentSchedules, str)
|
||||
else prisma_session.successfulAgentSchedules or {}
|
||||
)
|
||||
|
||||
# Calculate usage from token counts
|
||||
usage = []
|
||||
if prisma_session.totalPromptTokens or prisma_session.totalCompletionTokens:
|
||||
usage.append(
|
||||
Usage(
|
||||
prompt_tokens=prisma_session.totalPromptTokens or 0,
|
||||
completion_tokens=prisma_session.totalCompletionTokens or 0,
|
||||
total_tokens=(prisma_session.totalPromptTokens or 0)
|
||||
+ (prisma_session.totalCompletionTokens or 0),
|
||||
)
|
||||
)
|
||||
|
||||
return ChatSession(
|
||||
session_id=prisma_session.id,
|
||||
user_id=prisma_session.userId,
|
||||
title=prisma_session.title,
|
||||
messages=messages,
|
||||
usage=usage,
|
||||
credentials=credentials,
|
||||
started_at=prisma_session.createdAt,
|
||||
updated_at=prisma_session.updatedAt,
|
||||
successful_agent_runs=successful_agent_runs,
|
||||
successful_agent_schedules=successful_agent_schedules,
|
||||
)
|
||||
|
||||
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
|
||||
messages = []
|
||||
for message in self.messages:
|
||||
if message.role == "developer":
|
||||
m = ChatCompletionDeveloperMessageParam(
|
||||
role="developer",
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
messages.append(m)
|
||||
elif message.role == "system":
|
||||
m = ChatCompletionSystemMessageParam(
|
||||
role="system",
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
messages.append(m)
|
||||
elif message.role == "user":
|
||||
m = ChatCompletionUserMessageParam(
|
||||
role="user",
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
messages.append(m)
|
||||
elif message.role == "assistant":
|
||||
m = ChatCompletionAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.function_call:
|
||||
m["function_call"] = FunctionCall(
|
||||
arguments=message.function_call["arguments"],
|
||||
name=message.function_call["name"],
|
||||
)
|
||||
if message.refusal:
|
||||
m["refusal"] = message.refusal
|
||||
if message.tool_calls:
|
||||
t: list[ChatCompletionMessageToolCallParam] = []
|
||||
for tool_call in message.tool_calls:
|
||||
# Tool calls are stored with nested structure: {id, type, function: {name, arguments}}
|
||||
function_data = tool_call.get("function", {})
|
||||
|
||||
# Skip tool calls that are missing required fields
|
||||
if "id" not in tool_call or "name" not in function_data:
|
||||
logger.warning(
|
||||
f"Skipping invalid tool call: missing required fields. "
|
||||
f"Got: {tool_call.keys()}, function keys: {function_data.keys()}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Arguments are stored as a JSON string
|
||||
arguments_str = function_data.get("arguments", "{}")
|
||||
|
||||
t.append(
|
||||
ChatCompletionMessageToolCallParam(
|
||||
id=tool_call["id"],
|
||||
type="function",
|
||||
function=Function(
|
||||
arguments=arguments_str,
|
||||
name=function_data["name"],
|
||||
),
|
||||
)
|
||||
)
|
||||
m["tool_calls"] = t
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
messages.append(m)
|
||||
elif message.role == "tool":
|
||||
messages.append(
|
||||
ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
content=message.content or "",
|
||||
tool_call_id=message.tool_call_id or "",
|
||||
)
|
||||
)
|
||||
elif message.role == "function":
|
||||
messages.append(
|
||||
ChatCompletionFunctionMessageParam(
|
||||
role="function",
|
||||
content=message.content,
|
||||
name=message.name or "",
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from Redis cache."""
|
||||
redis_key = f"chat:session:{session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||
|
||||
if raw_session is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"Loading session {session_id} from cache: "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
||||
raise RedisError(f"Corrupted session data for {session_id}") from e
|
||||
|
||||
|
||||
async def _cache_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session in Redis."""
|
||||
redis_key = f"chat:session:{session.session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
prisma_session = await chat_db.get_chat_session(session_id)
|
||||
if not prisma_session:
|
||||
return None
|
||||
|
||||
messages = prisma_session.Messages
|
||||
logger.info(
|
||||
f"Loading session {session_id} from DB: "
|
||||
f"has_messages={messages is not None}, "
|
||||
f"message_count={len(messages) if messages else 0}, "
|
||||
f"roles={[m.role for m in messages] if messages else []}"
|
||||
)
|
||||
|
||||
return ChatSession.from_prisma(prisma_session, messages)
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
session: ChatSession, existing_message_count: int
|
||||
) -> None:
|
||||
"""Save or update a chat session in the database."""
|
||||
# Check if session exists in DB
|
||||
existing = await chat_db.get_chat_session(session.session_id)
|
||||
|
||||
if not existing:
|
||||
# Create new session
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
# Calculate total tokens from usage
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
# Update session metadata
|
||||
await chat_db.update_chat_session(
|
||||
session_id=session.session_id,
|
||||
credentials=session.credentials,
|
||||
successful_agent_runs=session.successful_agent_runs,
|
||||
successful_agent_schedules=session.successful_agent_schedules,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
new_messages = session.messages[existing_message_count:]
|
||||
if new_messages:
|
||||
messages_data = []
|
||||
for msg in new_messages:
|
||||
messages_data.append(
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"name": msg.name,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"refusal": msg.refusal,
|
||||
"tool_calls": msg.tool_calls,
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
)
|
||||
await chat_db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
messages=messages_data,
|
||||
start_sequence=existing_message_count,
|
||||
)
|
||||
|
||||
|
||||
async def get_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> ChatSession | None:
|
||||
"""Get a chat session by ID.
|
||||
|
||||
Checks Redis cache first, falls back to database if not found.
|
||||
Caches database results back to Redis.
|
||||
"""
|
||||
# Try cache first
|
||||
try:
|
||||
session = await _get_session_from_cache(session_id)
|
||||
if session:
|
||||
# Verify user ownership
|
||||
if session.user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||
)
|
||||
return None
|
||||
return session
|
||||
except RedisError:
|
||||
logger.warning(f"Cache error for session {session_id}, trying database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
||||
|
||||
# Fall back to database
|
||||
logger.info(f"Session {session_id} not in cache, checking database")
|
||||
session = await _get_session_from_db(session_id)
|
||||
|
||||
if session is None:
|
||||
logger.warning(f"Session {session_id} not found in cache or database")
|
||||
return None
|
||||
|
||||
# Verify user ownership
|
||||
if session.user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Cache the session from DB
|
||||
try:
|
||||
await _cache_session(session)
|
||||
logger.info(f"Cached session {session_id} from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def upsert_chat_session(
|
||||
session: ChatSession,
|
||||
) -> ChatSession:
|
||||
"""Update a chat session in both cache and database."""
|
||||
# Get existing message count from DB for incremental saves
|
||||
existing_message_count = await chat_db.get_chat_session_message_count(
|
||||
session.session_id
|
||||
)
|
||||
|
||||
# Save to database
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save session {session.session_id} to database: {e}")
|
||||
# Continue to cache even if DB fails
|
||||
|
||||
# Save to cache
|
||||
try:
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
raise RedisError(
|
||||
f"Failed to persist chat session {session.session_id} to Redis: {e}"
|
||||
) from e
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(user_id: str | None) -> ChatSession:
|
||||
"""Create a new chat session and persist it."""
|
||||
session = ChatSession.new(user_id)
|
||||
|
||||
# Create in database first
|
||||
try:
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create session in database: {e}")
|
||||
# Continue even if DB fails - cache will still work
|
||||
|
||||
# Cache the session
|
||||
try:
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache new session: {e}")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def get_user_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[ChatSession]:
|
||||
"""Get all chat sessions for a user from the database."""
|
||||
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
|
||||
|
||||
sessions = []
|
||||
for prisma_session in prisma_sessions:
|
||||
# Convert without messages for listing (lighter weight)
|
||||
sessions.append(ChatSession.from_prisma(prisma_session, None))
|
||||
|
||||
return sessions
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str) -> bool:
|
||||
"""Delete a chat session from both cache and database."""
|
||||
# Delete from cache
|
||||
try:
|
||||
redis_key = f"chat:session:{session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
|
||||
|
||||
# Delete from database
|
||||
return await chat_db.delete_chat_session(session_id)
|
||||
@@ -1,117 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
|
||||
messages = [
|
||||
ChatMessage(content="Hello, how are you?", role="user"),
|
||||
ChatMessage(
|
||||
content="I'm fine, thank you!",
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "t123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "New York"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
ChatMessage(
|
||||
content="I'm using the tool to get the weather",
|
||||
role="tool",
|
||||
tool_call_id="t123",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_serialization_deserialization():
|
||||
s = ChatSession.new(user_id="abc123")
|
||||
s.messages = messages
|
||||
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
|
||||
serialized = s.model_dump_json()
|
||||
s2 = ChatSession.model_validate_json(serialized)
|
||||
assert s2.model_dump() == s.model_dump()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_redis_storage():
|
||||
|
||||
s = ChatSession.new(user_id=None)
|
||||
s.messages = messages
|
||||
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
s2 = await get_chat_session(
|
||||
session_id=s.session_id,
|
||||
user_id=s.user_id,
|
||||
)
|
||||
|
||||
assert s2 == s
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_redis_storage_user_id_mismatch():
|
||||
|
||||
s = ChatSession.new(user_id="abc123")
|
||||
s.messages = messages
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
s2 = await get_chat_session(s.session_id, None)
|
||||
|
||||
assert s2 is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_db_storage():
|
||||
"""Test that messages are correctly saved to and loaded from DB (not cache)."""
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
# Create session with messages including assistant message
|
||||
s = ChatSession.new(user_id=None)
|
||||
s.messages = messages # Contains user, assistant, and tool messages
|
||||
|
||||
# Upsert to save to both cache and DB
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
# Clear the Redis cache to force DB load
|
||||
redis_key = f"chat:session:{s.session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
|
||||
# Load from DB (cache was cleared)
|
||||
s2 = await get_chat_session(
|
||||
session_id=s.session_id,
|
||||
user_id=s.user_id,
|
||||
)
|
||||
|
||||
assert s2 is not None, "Session not found after loading from DB"
|
||||
assert len(s2.messages) == len(
|
||||
s.messages
|
||||
), f"Message count mismatch: expected {len(s.messages)}, got {len(s2.messages)}"
|
||||
|
||||
# Verify all roles are present
|
||||
roles = [m.role for m in s2.messages]
|
||||
assert "user" in roles, f"User message missing. Roles found: {roles}"
|
||||
assert "assistant" in roles, f"Assistant message missing. Roles found: {roles}"
|
||||
assert "tool" in roles, f"Tool message missing. Roles found: {roles}"
|
||||
|
||||
# Verify message content
|
||||
for orig, loaded in zip(s.messages, s2.messages):
|
||||
assert orig.role == loaded.role, f"Role mismatch: {orig.role} != {loaded.role}"
|
||||
assert (
|
||||
orig.content == loaded.content
|
||||
), f"Content mismatch for {orig.role}: {orig.content} != {loaded.content}"
|
||||
if orig.tool_calls:
|
||||
assert (
|
||||
loaded.tool_calls is not None
|
||||
), f"Tool calls missing for {orig.role} message"
|
||||
assert len(orig.tool_calls) == len(loaded.tool_calls)
|
||||
@@ -1,192 +0,0 @@
|
||||
You are Otto, an AI Co-Pilot and Forward Deployed Engineer for AutoGPT, an AI Business Automation tool. Your mission is to help users quickly find, create, and set up AutoGPT agents to solve their business problems.
|
||||
|
||||
Here are the functions available to you:
|
||||
|
||||
<functions>
|
||||
**Understanding & Discovery:**
|
||||
1. **add_understanding** - Save information about the user's business context (use this as you learn about them)
|
||||
2. **find_agent** - Search the marketplace for pre-built agents that solve the user's problem
|
||||
3. **find_library_agent** - Search the user's personal library of saved agents
|
||||
4. **find_block** - Search for individual blocks (building components for agents)
|
||||
5. **search_platform_docs** - Search AutoGPT documentation for help
|
||||
|
||||
**Agent Creation & Editing:**
|
||||
6. **create_agent** - Create a new custom agent from scratch based on user requirements
|
||||
7. **edit_agent** - Modify an existing agent (add/remove blocks, change configuration)
|
||||
|
||||
**Execution & Output:**
|
||||
8. **run_agent** - Run or schedule an agent (automatically handles setup)
|
||||
9. **run_block** - Run a single block directly without creating an agent
|
||||
10. **agent_output** - Get the output/results from a running or completed agent execution
|
||||
</functions>
|
||||
|
||||
## ALWAYS GET THE USER'S NAME
|
||||
|
||||
**This is critical:** If you don't know the user's name, ask for it in your first response. Use a friendly, natural approach:
|
||||
- "Hi! I'm Otto. What's your name?"
|
||||
- "Hey there! Before we dive in, what should I call you?"
|
||||
|
||||
Once you have their name, immediately save it with `add_understanding(user_name="...")` and use it throughout the conversation.
|
||||
|
||||
## BUILDING USER UNDERSTANDING
|
||||
|
||||
**If no User Business Context is provided below**, gather information naturally during conversation - don't interrogate them.
|
||||
|
||||
**Key information to gather (in priority order):**
|
||||
1. Their name (ALWAYS first if unknown)
|
||||
2. Their job title and role
|
||||
3. Their business/company and industry
|
||||
4. Pain points and what they want to automate
|
||||
5. Tools they currently use
|
||||
|
||||
**How to gather this information:**
|
||||
- Ask naturally as part of helping them (e.g., "What's your role?" or "What industry are you in?")
|
||||
- When they share information, immediately save it using `add_understanding`
|
||||
- Don't ask all questions at once - spread them across the conversation
|
||||
- Prioritize understanding their immediate problem first
|
||||
|
||||
**Example:**
|
||||
```
|
||||
User: "I need help automating my social media"
|
||||
Otto: I can help with that! I'm Otto - what's your name?
|
||||
User: "I'm Sarah"
|
||||
Otto: [calls add_understanding with user_name="Sarah"]
|
||||
Nice to meet you, Sarah! What's your role - are you a social media manager or business owner?
|
||||
User: "I'm the marketing director at a fintech startup"
|
||||
Otto: [calls add_understanding with job_title="Marketing Director", industry="fintech", business_size="startup"]
|
||||
Great! Let me find social media automation agents for you.
|
||||
[calls find_agent with query="social media automation marketing"]
|
||||
```
|
||||
|
||||
## WHEN TO USE WHICH TOOL
|
||||
|
||||
**Finding existing agents:**
|
||||
- `find_agent` - Search the marketplace for pre-built agents others have created
|
||||
- `find_library_agent` - Search agents the user has already saved to their library
|
||||
|
||||
**Creating/editing agents:**
|
||||
- `create_agent` - When user wants a custom agent that doesn't exist, or has specific requirements
|
||||
- `edit_agent` - When user wants to modify an existing agent (change inputs, add blocks, etc.)
|
||||
|
||||
**Running agents:**
|
||||
- `run_agent` - To execute an agent (handles credentials and inputs automatically)
|
||||
- `agent_output` - To check the results of a running or completed agent execution
|
||||
|
||||
**Direct execution:**
|
||||
- `run_block` - Run a single block directly without needing a full agent
|
||||
|
||||
## HOW run_agent WORKS
|
||||
|
||||
The `run_agent` tool automatically handles the entire setup flow:
|
||||
|
||||
1. **First call** (no inputs) → Returns available inputs so user can decide what values to use
|
||||
2. **Credentials check** → If missing, UI automatically prompts user to add them (you don't need to mention this)
|
||||
3. **Execution** → Runs when you provide `inputs` OR set `use_defaults=true`
|
||||
|
||||
Parameters:
|
||||
- `username_agent_slug` (required): Agent identifier like "creator/agent-name"
|
||||
- `inputs`: Object with input values for the agent
|
||||
- `use_defaults`: Set to `true` to run with default values (only after user confirms)
|
||||
- `schedule_name` + `cron`: For scheduled execution
|
||||
|
||||
## HOW create_agent WORKS
|
||||
|
||||
Use `create_agent` when the user wants to build a custom automation:
|
||||
- Describe what the agent should do
|
||||
- The tool will create the agent structure with appropriate blocks
|
||||
- Returns the agent ID for further editing or running
|
||||
|
||||
## HOW agent_output WORKS
|
||||
|
||||
Use `agent_output` to get results from agent executions:
|
||||
- Pass the execution_id from a run_agent response
|
||||
- Returns the current status and any outputs produced
|
||||
- Useful for checking if an agent has completed and what it produced
|
||||
|
||||
## WORKFLOW
|
||||
|
||||
1. **Get their name** - If unknown, ask for it first
|
||||
2. **Understand context** - Ask 1-2 questions about their problem while helping
|
||||
3. **Find or create** - Use find_agent for existing solutions, create_agent for custom needs
|
||||
4. **Set up and run** - Use run_agent to execute, agent_output to get results
|
||||
|
||||
## YOUR APPROACH
|
||||
|
||||
**Step 1: Greet and Identify**
|
||||
- If you don't know their name, ask for it
|
||||
- Be friendly and conversational
|
||||
|
||||
**Step 2: Understand the Problem**
|
||||
- Ask maximum 1-2 targeted questions
|
||||
- Focus on: What business problem are they solving?
|
||||
- If they want to create/edit an agent, understand what it should do
|
||||
|
||||
**Step 3: Find or Create**
|
||||
- For existing solutions: Use `find_agent` with relevant keywords
|
||||
- For custom needs: Use `create_agent` with their requirements
|
||||
- For modifications: Use `edit_agent` on an existing agent
|
||||
|
||||
**Step 4: Execute**
|
||||
- Call `run_agent` without inputs first to see what's available
|
||||
- Ask user what values they want or if defaults are okay
|
||||
- Call `run_agent` again with inputs or `use_defaults=true`
|
||||
- Use `agent_output` to check results when needed
|
||||
|
||||
## USING add_understanding
|
||||
|
||||
Call `add_understanding` whenever you learn something about the user:
|
||||
|
||||
**User info:** `user_name`, `job_title`
|
||||
**Business:** `business_name`, `industry`, `business_size` (1-10, 11-50, 51-200, 201-1000, 1000+), `user_role` (decision maker, implementer, end user)
|
||||
**Processes:** `key_workflows` (array), `daily_activities` (array)
|
||||
**Pain points:** `pain_points` (array), `bottlenecks` (array), `manual_tasks` (array), `automation_goals` (array)
|
||||
**Tools:** `current_software` (array), `existing_automation` (array)
|
||||
**Other:** `additional_notes`
|
||||
|
||||
Example: `add_understanding(user_name="Sarah", job_title="Marketing Director", industry="fintech")`
|
||||
|
||||
## KEY RULES
|
||||
|
||||
**What You DON'T Do:**
|
||||
- Don't help with login (frontend handles this)
|
||||
- Don't mention or explain credentials to the user (frontend handles this automatically)
|
||||
- Don't run agents without first showing available inputs to the user
|
||||
- Don't use `use_defaults=true` without user explicitly confirming
|
||||
- Don't write responses longer than 3 sentences
|
||||
- Don't interrogate users with many questions - gather info naturally
|
||||
|
||||
**What You DO:**
|
||||
- ALWAYS ask for user's name if you don't have it
|
||||
- Save user information with `add_understanding` as you learn it
|
||||
- Use their name when addressing them
|
||||
- Always call run_agent first without inputs to see what's available
|
||||
- Ask user what values they want OR if they want to use defaults
|
||||
- Keep all responses to maximum 3 sentences
|
||||
- Include the agent link in your response after successful execution
|
||||
|
||||
**Error Handling:**
|
||||
- Authentication needed → "Please sign in via the interface"
|
||||
- Credentials missing → The UI handles this automatically. Focus on asking the user about input values instead.
|
||||
|
||||
## RESPONSE STRUCTURE
|
||||
|
||||
Before responding, wrap your analysis in <thinking> tags to systematically plan your approach:
|
||||
- Check if you know the user's name - if not, ask for it
|
||||
- Check if you have user context - if not, plan to gather some naturally
|
||||
- Extract the key business problem or request from the user's message
|
||||
- Determine what function call (if any) you need to make next
|
||||
- Plan your response to stay under the 3-sentence maximum
|
||||
|
||||
Example interaction:
|
||||
```
|
||||
User: "Hi, I want to build an agent that monitors my competitors"
|
||||
Otto: <thinking>I don't know this user's name. I should ask for it while acknowledging their request.</thinking>
|
||||
Hi! I'm Otto and I'd love to help you build a competitor monitoring agent. What's your name?
|
||||
User: "I'm Mike"
|
||||
Otto: [calls add_understanding with user_name="Mike"]
|
||||
<thinking>Now I know Mike wants competitor monitoring. I should search for existing agents first.</thinking>
|
||||
Great to meet you, Mike! Let me search for competitor monitoring agents.
|
||||
[calls find_agent with query="competitor monitoring analysis"]
|
||||
```
|
||||
|
||||
KEEP ANSWERS TO 3 SENTENCES
|
||||
@@ -1,155 +0,0 @@
|
||||
You are Otto, an AI Co-Pilot helping new users get started with AutoGPT, an AI Business Automation platform. Your mission is to welcome them, learn about their needs, and help them run their first successful agent.
|
||||
|
||||
Here are the functions available to you:
|
||||
|
||||
<functions>
|
||||
**Understanding & Discovery:**
|
||||
1. **add_understanding** - Save information about the user's business context (use this as you learn about them)
|
||||
2. **find_agent** - Search the marketplace for pre-built agents that solve the user's problem
|
||||
3. **find_library_agent** - Search the user's personal library of saved agents
|
||||
4. **find_block** - Search for individual blocks (building components for agents)
|
||||
5. **search_platform_docs** - Search AutoGPT documentation for help
|
||||
|
||||
**Agent Creation & Editing:**
|
||||
6. **create_agent** - Create a new custom agent from scratch based on user requirements
|
||||
7. **edit_agent** - Modify an existing agent (add/remove blocks, change configuration)
|
||||
|
||||
**Execution & Output:**
|
||||
8. **run_agent** - Run or schedule an agent (automatically handles setup)
|
||||
9. **run_block** - Run a single block directly without creating an agent
|
||||
10. **agent_output** - Get the output/results from a running or completed agent execution
|
||||
</functions>
|
||||
|
||||
## YOUR ONBOARDING MISSION
|
||||
|
||||
You are guiding a new user through their first experience with AutoGPT. Your goal is to:
|
||||
1. Welcome them warmly and get their name
|
||||
2. Learn about them and their business
|
||||
3. Find or create an agent that solves a real problem for them
|
||||
4. Get that agent running successfully
|
||||
5. Celebrate their success and point them to next steps
|
||||
|
||||
## PHASE 1: WELCOME & INTRODUCTION
|
||||
|
||||
**Start every conversation by:**
|
||||
- Giving a warm, friendly greeting
|
||||
- Introducing yourself as Otto, their AI assistant
|
||||
- Asking for their name immediately
|
||||
|
||||
**Example opening:**
|
||||
```
|
||||
Hi! I'm Otto, your AI assistant. Welcome to AutoGPT! I'm here to help you set up your first automation. What's your name?
|
||||
```
|
||||
|
||||
Once you have their name, save it immediately with `add_understanding(user_name="...")` and use it throughout.
|
||||
|
||||
## PHASE 2: DISCOVERY
|
||||
|
||||
**After getting their name, learn about them:**
|
||||
- What's their role/job title?
|
||||
- What industry/business are they in?
|
||||
- What's one thing they'd love to automate?
|
||||
|
||||
**Keep it conversational - don't interrogate. Example:**
|
||||
```
|
||||
Nice to meet you, Sarah! What do you do for work, and what's one task you wish you could automate?
|
||||
```
|
||||
|
||||
Save everything you learn with `add_understanding`.
|
||||
|
||||
## PHASE 3: FIND OR CREATE AN AGENT
|
||||
|
||||
**Once you understand their need:**
|
||||
- Search for existing agents with `find_agent`
|
||||
- Present the best match and explain how it helps them
|
||||
- If nothing fits, offer to create a custom agent with `create_agent`
|
||||
|
||||
**Be enthusiastic about the solution:**
|
||||
```
|
||||
I found a great agent for you! The "Social Media Scheduler" can automatically post to your accounts on a schedule. Want to try it?
|
||||
```
|
||||
|
||||
## PHASE 4: SETUP & RUN
|
||||
|
||||
**Guide them through running the agent:**
|
||||
1. Call `run_agent` without inputs first to see what's needed
|
||||
2. Explain each input in simple terms
|
||||
3. Ask what values they want to use
|
||||
4. Run the agent with their inputs or defaults
|
||||
|
||||
**Don't mention credentials** - the UI handles that automatically.
|
||||
|
||||
## PHASE 5: CELEBRATE & HANDOFF
|
||||
|
||||
**After successful execution:**
|
||||
- Congratulate them on their first automation!
|
||||
- Tell them where to find this agent (their Library)
|
||||
- Mention they can explore more agents in the Marketplace
|
||||
- Offer to help with anything else
|
||||
|
||||
**Example:**
|
||||
```
|
||||
You did it! Your first agent is running. You can find it anytime in your Library. Ready to explore more automations?
|
||||
```
|
||||
|
||||
## KEY RULES
|
||||
|
||||
**What You DON'T Do:**
|
||||
- Don't help with login (frontend handles this)
|
||||
- Don't mention credentials (UI handles automatically)
|
||||
- Don't run agents without showing inputs first
|
||||
- Don't use `use_defaults=true` without explicit confirmation
|
||||
- Don't write responses longer than 3 sentences
|
||||
- Don't overwhelm with too many questions at once
|
||||
|
||||
**What You DO:**
|
||||
- ALWAYS get the user's name first
|
||||
- Be warm, encouraging, and celebratory
|
||||
- Save info with `add_understanding` as you learn it
|
||||
- Use their name when addressing them
|
||||
- Keep responses to maximum 3 sentences
|
||||
- Make them feel successful at each step
|
||||
|
||||
## USING add_understanding
|
||||
|
||||
Save information as you learn it:
|
||||
|
||||
**User info:** `user_name`, `job_title`
|
||||
**Business:** `business_name`, `industry`, `business_size`, `user_role`
|
||||
**Pain points:** `pain_points`, `manual_tasks`, `automation_goals`
|
||||
**Tools:** `current_software`
|
||||
|
||||
Example: `add_understanding(user_name="Sarah", job_title="Marketing Manager", automation_goals=["social media scheduling"])`
|
||||
|
||||
## HOW run_agent WORKS
|
||||
|
||||
1. **First call** (no inputs) → Shows available inputs
|
||||
2. **Credentials** → UI handles automatically (don't mention)
|
||||
3. **Execution** → Run with `inputs={...}` or `use_defaults=true`
|
||||
|
||||
## RESPONSE STRUCTURE
|
||||
|
||||
Before responding, plan your approach in <thinking> tags:
|
||||
- What phase am I in? (Welcome/Discovery/Find/Setup/Celebrate)
|
||||
- Do I know their name? If not, ask for it
|
||||
- What's the next step to move them forward?
|
||||
- Keep response under 3 sentences
|
||||
|
||||
**Example flow:**
|
||||
```
|
||||
User: "Hi"
|
||||
Otto: <thinking>Phase 1 - I need to welcome them and get their name.</thinking>
|
||||
Hi! I'm Otto, welcome to AutoGPT! I'm here to help you set up your first automation - what's your name?
|
||||
|
||||
User: "I'm Alex"
|
||||
Otto: [calls add_understanding with user_name="Alex"]
|
||||
<thinking>Got their name. Phase 2 - learn about them.</thinking>
|
||||
Great to meet you, Alex! What do you do for work, and what's one task you'd love to automate?
|
||||
|
||||
User: "I run an e-commerce store and spend hours on customer support emails"
|
||||
Otto: [calls add_understanding with industry="e-commerce", pain_points=["customer support emails"]]
|
||||
<thinking>Phase 3 - search for agents.</thinking>
|
||||
[calls find_agent with query="customer support email automation"]
|
||||
```
|
||||
|
||||
KEEP ANSWERS TO 3 SENTENCES - Be warm, helpful, and focused on their success!
|
||||
@@ -1,472 +0,0 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Query, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from .config import ChatConfig
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
tags=["chat"],
|
||||
)
|
||||
|
||||
# ========== Request/Response Models ==========
|
||||
|
||||
|
||||
class StreamChatRequest(BaseModel):
|
||||
"""Request model for streaming chat with optional context."""
|
||||
|
||||
message: str
|
||||
is_user_message: bool = True
|
||||
context: dict[str, str] | None = None # {url: str, content: str}
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
"""Response model containing information on a newly created chat session."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
user_id: str | None
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
"""Response model providing complete details for a chat session, including messages."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
"""Response model for a session summary (without messages)."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
title: str | None = None
|
||||
|
||||
|
||||
class ListSessionsResponse(BaseModel):
|
||||
"""Response model for listing chat sessions."""
|
||||
|
||||
sessions: list[SessionSummaryResponse]
|
||||
total: int
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def list_sessions(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
limit: int = Query(default=50, ge=1, le=100),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
) -> ListSessionsResponse:
|
||||
"""
|
||||
List chat sessions for the authenticated user.
|
||||
|
||||
Returns a paginated list of chat sessions belonging to the current user,
|
||||
ordered by most recently updated.
|
||||
|
||||
Args:
|
||||
user_id: The authenticated user's ID.
|
||||
limit: Maximum number of sessions to return (1-100).
|
||||
offset: Number of sessions to skip for pagination.
|
||||
|
||||
Returns:
|
||||
ListSessionsResponse: List of session summaries and total count.
|
||||
"""
|
||||
sessions = await chat_service.get_user_sessions(user_id, limit, offset)
|
||||
|
||||
return ListSessionsResponse(
|
||||
sessions=[
|
||||
SessionSummaryResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
title=None, # TODO: Add title support
|
||||
)
|
||||
for session in sessions
|
||||
],
|
||||
total=len(sessions),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions",
|
||||
)
|
||||
async def create_session(
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new chat session.
|
||||
|
||||
Initiates a new chat session for either an authenticated or anonymous user.
|
||||
|
||||
Args:
|
||||
user_id: The optional authenticated user ID parsed from the JWT. If missing, creates an anonymous session.
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: Details of the created session.
|
||||
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating session with user_id: "
|
||||
f"...{user_id[-8:] if user_id and len(user_id) > 8 else '<redacted>'}"
|
||||
)
|
||||
|
||||
session = await chat_service.create_chat_session(user_id)
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}",
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.
|
||||
|
||||
"""
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
logger.info(
|
||||
f"Returning session {session_id}: "
|
||||
f"message_count={len(messages)}, "
|
||||
f"roles={[m.get('role') for m in messages]}"
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_chat_post(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (POST with context support).
|
||||
|
||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||
- Text fragments as they are generated
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
# Validate session exists before starting the stream
|
||||
# This prevents errors after the response has already started
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found. ")
|
||||
if session.user_id is None and user_id is not None:
|
||||
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_chat_get(
|
||||
session_id: str,
|
||||
message: Annotated[str, Query(min_length=1, max_length=10000)],
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
is_user_message: bool = Query(default=True),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (GET - legacy endpoint).
|
||||
|
||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||
- Text fragments as they are generated
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
message: The user's new message to process.
|
||||
user_id: Optional authenticated user ID.
|
||||
is_user_message: Whether the message is a user message.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
# Validate session exists before starting the stream
|
||||
# This prevents errors after the response has already started
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found. ")
|
||||
if session.user_id is None and user_id is not None:
|
||||
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
message,
|
||||
is_user_message=is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/assign-user",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=200,
|
||||
)
|
||||
async def session_assign_user(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> dict:
|
||||
"""
|
||||
Assign an authenticated user to a chat session.
|
||||
|
||||
Used (typically post-login) to claim an existing anonymous session as the current authenticated user.
|
||||
|
||||
Args:
|
||||
session_id: The identifier for the (previously anonymous) session.
|
||||
user_id: The authenticated user's ID to associate with the session.
|
||||
|
||||
Returns:
|
||||
dict: Status of the assignment.
|
||||
|
||||
"""
|
||||
await chat_service.assign_user_to_session(session_id, user_id)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Onboarding Routes ==========
|
||||
# These routes use a specialized onboarding system prompt
|
||||
|
||||
|
||||
@router.post(
|
||||
"/onboarding/sessions",
|
||||
)
|
||||
async def create_onboarding_session(
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new onboarding chat session.
|
||||
|
||||
Initiates a new chat session specifically for user onboarding,
|
||||
using a specialized prompt that guides users through their first
|
||||
experience with AutoGPT.
|
||||
|
||||
Args:
|
||||
user_id: The optional authenticated user ID parsed from the JWT.
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: Details of the created onboarding session.
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating onboarding session with user_id: "
|
||||
f"...{user_id[-8:] if user_id and len(user_id) > 8 else '<redacted>'}"
|
||||
)
|
||||
|
||||
session = await chat_service.create_chat_session(user_id)
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/onboarding/sessions/{session_id}",
|
||||
)
|
||||
async def get_onboarding_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of an onboarding chat session.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the onboarding session.
|
||||
user_id: The optional authenticated user ID.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session.
|
||||
"""
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
logger.info(
|
||||
f"Returning onboarding session {session_id}: "
|
||||
f"message_count={len(messages)}, "
|
||||
f"roles={[m.get('role') for m in messages]}"
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/onboarding/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_onboarding_chat(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Stream onboarding chat responses for a session.
|
||||
|
||||
Uses the specialized onboarding system prompt to guide new users
|
||||
through their first experience with AutoGPT. Streams AI responses
|
||||
in real time over Server-Sent Events (SSE).
|
||||
|
||||
Args:
|
||||
session_id: The onboarding session identifier.
|
||||
request: Request body containing message and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
"""
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
if session.user_id is None and user_id is not None:
|
||||
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
context=request.context,
|
||||
prompt_type="onboarding", # Use onboarding system prompt
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ========== Health Check ==========
|
||||
|
||||
|
||||
@router.get("/health", status_code=200)
|
||||
async def health_check() -> dict:
|
||||
"""
|
||||
Health check endpoint for the chat service.
|
||||
|
||||
Performs a full cycle test of session creation, assignment, and retrieval. Should always return healthy
|
||||
if the service and data layer are operational.
|
||||
|
||||
Returns:
|
||||
dict: A status dictionary indicating health, service name, and API version.
|
||||
|
||||
"""
|
||||
session = await chat_service.create_chat_session(None)
|
||||
await chat_service.assign_user_to_session(session.session_id, "test_user")
|
||||
await chat_service.get_session(session.session_id, "test_user")
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "chat",
|
||||
"version": "0.1.0",
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .add_understanding import AddUnderstandingTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .edit_agent import EditAgentTool
|
||||
from .find_agent import FindAgentTool
|
||||
from .find_block import FindBlockTool
|
||||
from .find_library_agent import FindLibraryAgentTool
|
||||
from .run_agent import RunAgentTool
|
||||
from .run_block import RunBlockTool
|
||||
from .search_docs import SearchDocsTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
||||
|
||||
# Initialize tool instances
|
||||
add_understanding_tool = AddUnderstandingTool()
|
||||
create_agent_tool = CreateAgentTool()
|
||||
edit_agent_tool = EditAgentTool()
|
||||
find_agent_tool = FindAgentTool()
|
||||
find_block_tool = FindBlockTool()
|
||||
find_library_agent_tool = FindLibraryAgentTool()
|
||||
run_agent_tool = RunAgentTool()
|
||||
run_block_tool = RunBlockTool()
|
||||
search_docs_tool = SearchDocsTool()
|
||||
agent_output_tool = AgentOutputTool()
|
||||
|
||||
# Export tools as OpenAI format
|
||||
tools: list[ChatCompletionToolParam] = [
|
||||
add_understanding_tool.as_openai_tool(),
|
||||
create_agent_tool.as_openai_tool(),
|
||||
edit_agent_tool.as_openai_tool(),
|
||||
find_agent_tool.as_openai_tool(),
|
||||
find_block_tool.as_openai_tool(),
|
||||
find_library_agent_tool.as_openai_tool(),
|
||||
run_agent_tool.as_openai_tool(),
|
||||
run_block_tool.as_openai_tool(),
|
||||
search_docs_tool.as_openai_tool(),
|
||||
agent_output_tool.as_openai_tool(),
|
||||
]
|
||||
|
||||
|
||||
async def execute_tool(
|
||||
tool_name: str,
|
||||
parameters: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
tool_call_id: str,
|
||||
) -> "StreamToolExecutionResult":
|
||||
|
||||
tool_map: dict[str, BaseTool] = {
|
||||
"add_understanding": add_understanding_tool,
|
||||
"create_agent": create_agent_tool,
|
||||
"edit_agent": edit_agent_tool,
|
||||
"find_agent": find_agent_tool,
|
||||
"find_block": find_block_tool,
|
||||
"find_library_agent": find_library_agent_tool,
|
||||
"run_agent": run_agent_tool,
|
||||
"run_block": run_block_tool,
|
||||
"search_platform_docs": search_docs_tool,
|
||||
"agent_output": agent_output_tool,
|
||||
}
|
||||
if tool_name not in tool_map:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
return await tool_map[tool_name].execute(
|
||||
user_id, session, tool_call_id, **parameters
|
||||
)
|
||||
@@ -1,206 +0,0 @@
|
||||
"""Tool for capturing user business understanding incrementally."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AddUnderstandingTool(BaseTool):
|
||||
"""Tool for capturing user's business understanding incrementally."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "add_understanding"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Capture and store information about the user's business context,
|
||||
workflows, pain points, and automation goals. Call this tool whenever the user
|
||||
shares information about their business. Each call incrementally adds to the
|
||||
existing understanding - you don't need to provide all fields at once.
|
||||
|
||||
Use this to build a comprehensive profile that helps recommend better agents
|
||||
and automations for the user's specific needs."""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user_name": {
|
||||
"type": "string",
|
||||
"description": "The user's name",
|
||||
},
|
||||
"job_title": {
|
||||
"type": "string",
|
||||
"description": "The user's job title (e.g., 'Marketing Manager', 'CEO', 'Software Engineer')",
|
||||
},
|
||||
"business_name": {
|
||||
"type": "string",
|
||||
"description": "Name of the user's business or organization",
|
||||
},
|
||||
"industry": {
|
||||
"type": "string",
|
||||
"description": "Industry or sector (e.g., 'e-commerce', 'healthcare', 'finance')",
|
||||
},
|
||||
"business_size": {
|
||||
"type": "string",
|
||||
"description": "Company size: '1-10', '11-50', '51-200', '201-1000', or '1000+'",
|
||||
},
|
||||
"user_role": {
|
||||
"type": "string",
|
||||
"description": "User's role in organization context (e.g., 'decision maker', 'implementer', 'end user')",
|
||||
},
|
||||
"key_workflows": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Key business workflows (e.g., 'lead qualification', 'content publishing')",
|
||||
},
|
||||
"daily_activities": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Regular daily activities the user performs",
|
||||
},
|
||||
"pain_points": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Current pain points or challenges",
|
||||
},
|
||||
"bottlenecks": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Process bottlenecks slowing things down",
|
||||
},
|
||||
"manual_tasks": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Manual or repetitive tasks that could be automated",
|
||||
},
|
||||
"automation_goals": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Desired automation outcomes or goals",
|
||||
},
|
||||
"current_software": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Software and tools currently in use",
|
||||
},
|
||||
"existing_automation": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Any existing automations or integrations",
|
||||
},
|
||||
"additional_notes": {
|
||||
"type": "string",
|
||||
"description": "Any other relevant context or notes",
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
"""Requires authentication to store user-specific data."""
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""
|
||||
Capture and store business understanding incrementally.
|
||||
|
||||
Each call merges new data with existing understanding:
|
||||
- String fields are overwritten if provided
|
||||
- List fields are appended (with deduplication)
|
||||
"""
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required to save business understanding.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if any data was provided
|
||||
if not any(v is not None for v in kwargs.values()):
|
||||
return ErrorResponse(
|
||||
message="Please provide at least one field to update.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Build input model
|
||||
input_data = BusinessUnderstandingInput(
|
||||
user_name=kwargs.get("user_name"),
|
||||
job_title=kwargs.get("job_title"),
|
||||
business_name=kwargs.get("business_name"),
|
||||
industry=kwargs.get("industry"),
|
||||
business_size=kwargs.get("business_size"),
|
||||
user_role=kwargs.get("user_role"),
|
||||
key_workflows=kwargs.get("key_workflows"),
|
||||
daily_activities=kwargs.get("daily_activities"),
|
||||
pain_points=kwargs.get("pain_points"),
|
||||
bottlenecks=kwargs.get("bottlenecks"),
|
||||
manual_tasks=kwargs.get("manual_tasks"),
|
||||
automation_goals=kwargs.get("automation_goals"),
|
||||
current_software=kwargs.get("current_software"),
|
||||
existing_automation=kwargs.get("existing_automation"),
|
||||
additional_notes=kwargs.get("additional_notes"),
|
||||
)
|
||||
|
||||
# Track which fields were updated
|
||||
updated_fields = [k for k, v in kwargs.items() if v is not None]
|
||||
|
||||
# Upsert with merge
|
||||
understanding = await upsert_business_understanding(user_id, input_data)
|
||||
|
||||
# Build current understanding summary for the response
|
||||
current_understanding = {
|
||||
"user_name": understanding.user_name,
|
||||
"job_title": understanding.job_title,
|
||||
"business_name": understanding.business_name,
|
||||
"industry": understanding.industry,
|
||||
"business_size": understanding.business_size,
|
||||
"user_role": understanding.user_role,
|
||||
"key_workflows": understanding.key_workflows,
|
||||
"daily_activities": understanding.daily_activities,
|
||||
"pain_points": understanding.pain_points,
|
||||
"bottlenecks": understanding.bottlenecks,
|
||||
"manual_tasks": understanding.manual_tasks,
|
||||
"automation_goals": understanding.automation_goals,
|
||||
"current_software": understanding.current_software,
|
||||
"existing_automation": understanding.existing_automation,
|
||||
"additional_notes": understanding.additional_notes,
|
||||
}
|
||||
|
||||
# Filter out empty values for cleaner response
|
||||
current_understanding = {
|
||||
k: v
|
||||
for k, v in current_understanding.items()
|
||||
if v is not None and v != [] and v != ""
|
||||
}
|
||||
|
||||
return UnderstandingUpdatedResponse(
|
||||
message=f"Updated understanding with: {', '.join(updated_fields)}. "
|
||||
"I now have a better picture of your business context.",
|
||||
session_id=session_id,
|
||||
updated_fields=updated_fields,
|
||||
current_understanding=current_understanding,
|
||||
)
|
||||
@@ -1,29 +0,0 @@
|
||||
"""Agent generator package - Creates agents from natural language."""
|
||||
|
||||
from .core import (
|
||||
apply_agent_patch,
|
||||
decompose_goal,
|
||||
generate_agent,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .fixer import apply_all_fixes
|
||||
from .utils import get_blocks_info
|
||||
from .validator import validate_agent
|
||||
|
||||
__all__ = [
|
||||
# Core functions
|
||||
"decompose_goal",
|
||||
"generate_agent",
|
||||
"generate_agent_patch",
|
||||
"apply_agent_patch",
|
||||
"save_agent_to_library",
|
||||
"get_agent_as_json",
|
||||
# Fixer
|
||||
"apply_all_fixes",
|
||||
# Validator
|
||||
"validate_agent",
|
||||
# Utils
|
||||
"get_blocks_info",
|
||||
]
|
||||
@@ -1,25 +0,0 @@
|
||||
"""OpenRouter client configuration for agent generation."""
|
||||
|
||||
import os
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Configuration - use OPEN_ROUTER_API_KEY for consistency with chat/config.py
|
||||
OPENROUTER_API_KEY = os.getenv("OPEN_ROUTER_API_KEY") or os.getenv("OPENROUTER_API_KEY")
|
||||
AGENT_GENERATOR_MODEL = os.getenv("AGENT_GENERATOR_MODEL", "anthropic/claude-opus-4.5")
|
||||
|
||||
# OpenRouter client (OpenAI-compatible API)
|
||||
_client: AsyncOpenAI | None = None
|
||||
|
||||
|
||||
def get_client() -> AsyncOpenAI:
|
||||
"""Get or create the OpenRouter client."""
|
||||
global _client
|
||||
if _client is None:
|
||||
if not OPENROUTER_API_KEY:
|
||||
raise ValueError("OPENROUTER_API_KEY environment variable is required")
|
||||
_client = AsyncOpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=OPENROUTER_API_KEY,
|
||||
)
|
||||
return _client
|
||||
@@ -1,390 +0,0 @@
|
||||
"""Core agent generation functions."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
|
||||
from .client import AGENT_GENERATOR_MODEL, get_client
|
||||
from .prompts import DECOMPOSITION_PROMPT, GENERATION_PROMPT, PATCH_PROMPT
|
||||
from .utils import get_block_summaries, parse_json_from_llm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
|
||||
"""Break down a goal into steps or return clarifying questions.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
|
||||
Returns:
|
||||
Dict with either:
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
Or None on error
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = DECOMPOSITION_PROMPT.format(block_summaries=get_block_summaries())
|
||||
|
||||
full_description = description
|
||||
if context:
|
||||
full_description = f"{description}\n\nAdditional context:\n{context}"
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": full_description},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for decomposition")
|
||||
return None
|
||||
|
||||
result = parse_json_from_llm(content)
|
||||
|
||||
if result is None:
|
||||
logger.error(f"Failed to parse decomposition response: {content[:200]}")
|
||||
return None
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error decomposing goal: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Generate agent JSON from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
|
||||
Returns:
|
||||
Agent JSON dict or None on error
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = GENERATION_PROMPT.format(block_summaries=get_block_summaries())
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": json.dumps(instructions, indent=2)},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for agent generation")
|
||||
return None
|
||||
|
||||
result = parse_json_from_llm(content)
|
||||
|
||||
if result is None:
|
||||
logger.error(f"Failed to parse agent JSON: {content[:200]}")
|
||||
return None
|
||||
|
||||
# Ensure required fields
|
||||
if "id" not in result:
|
||||
result["id"] = str(uuid.uuid4())
|
||||
if "version" not in result:
|
||||
result["version"] = 1
|
||||
if "is_active" not in result:
|
||||
result["is_active"] = True
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating agent: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
"""Convert agent JSON dict to Graph model.
|
||||
|
||||
Args:
|
||||
agent_json: Agent JSON with nodes and links
|
||||
|
||||
Returns:
|
||||
Graph ready for saving
|
||||
"""
|
||||
nodes = []
|
||||
for n in agent_json.get("nodes", []):
|
||||
node = Node(
|
||||
id=n.get("id", str(uuid.uuid4())),
|
||||
block_id=n["block_id"],
|
||||
input_default=n.get("input_default", {}),
|
||||
metadata=n.get("metadata", {}),
|
||||
)
|
||||
nodes.append(node)
|
||||
|
||||
links = []
|
||||
for link_data in agent_json.get("links", []):
|
||||
link = Link(
|
||||
id=link_data.get("id", str(uuid.uuid4())),
|
||||
source_id=link_data["source_id"],
|
||||
sink_id=link_data["sink_id"],
|
||||
source_name=link_data["source_name"],
|
||||
sink_name=link_data["sink_name"],
|
||||
is_static=link_data.get("is_static", False),
|
||||
)
|
||||
links.append(link)
|
||||
|
||||
return Graph(
|
||||
id=agent_json.get("id", str(uuid.uuid4())),
|
||||
version=agent_json.get("version", 1),
|
||||
is_active=agent_json.get("is_active", True),
|
||||
name=agent_json.get("name", "Generated Agent"),
|
||||
description=agent_json.get("description", ""),
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
|
||||
|
||||
def _reassign_node_ids(graph: Graph) -> None:
|
||||
"""Reassign all node and link IDs to new UUIDs.
|
||||
|
||||
This is needed when creating a new version to avoid unique constraint violations.
|
||||
"""
|
||||
# Create mapping from old node IDs to new UUIDs
|
||||
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
||||
|
||||
# Reassign node IDs
|
||||
for node in graph.nodes:
|
||||
node.id = id_map[node.id]
|
||||
|
||||
# Update link references to use new node IDs
|
||||
for link in graph.links:
|
||||
link.id = str(uuid.uuid4()) # Also give links new IDs
|
||||
if link.source_id in id_map:
|
||||
link.source_id = id_map[link.source_id]
|
||||
if link.sink_id in id_map:
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
|
||||
|
||||
async def save_agent_to_library(
|
||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||
) -> tuple[Graph, Any]:
|
||||
"""Save agent to database and user's library.
|
||||
|
||||
Args:
|
||||
agent_json: Agent JSON dict
|
||||
user_id: User ID
|
||||
is_update: Whether this is an update to an existing agent
|
||||
|
||||
Returns:
|
||||
Tuple of (created Graph, LibraryAgent)
|
||||
"""
|
||||
from backend.data.graph import get_graph_all_versions
|
||||
|
||||
graph = json_to_graph(agent_json)
|
||||
|
||||
if is_update:
|
||||
# For updates, keep the same graph ID but increment version
|
||||
# and reassign node/link IDs to avoid conflicts
|
||||
if graph.id:
|
||||
existing_versions = await get_graph_all_versions(graph.id, user_id)
|
||||
if existing_versions:
|
||||
latest_version = max(v.version for v in existing_versions)
|
||||
graph.version = latest_version + 1
|
||||
# Reassign node IDs (but keep graph ID the same)
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Updating agent {graph.id} to version {graph.version}")
|
||||
else:
|
||||
# For new agents, always generate a fresh UUID to avoid collisions
|
||||
graph.id = str(uuid.uuid4())
|
||||
graph.version = 1
|
||||
# Reassign all node IDs as well
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Creating new agent with ID {graph.id}")
|
||||
|
||||
# Save to database
|
||||
created_graph = await create_graph(graph, user_id)
|
||||
|
||||
# Add to user's library (or update existing library agent)
|
||||
library_agents = await library_db.create_library_agent(
|
||||
graph=created_graph,
|
||||
user_id=user_id,
|
||||
create_library_agents_for_sub_graphs=False,
|
||||
)
|
||||
|
||||
return created_graph, library_agents[0]
|
||||
|
||||
|
||||
async def get_agent_as_json(
|
||||
graph_id: str, user_id: str | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch an agent and convert to JSON format for editing.
|
||||
|
||||
Args:
|
||||
graph_id: Graph ID or library agent ID
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Agent as JSON dict or None if not found
|
||||
"""
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
# Try to get the graph (version=None gets the active version)
|
||||
graph = await get_graph(graph_id, version=None, user_id=user_id)
|
||||
if not graph:
|
||||
return None
|
||||
|
||||
# Convert to JSON format
|
||||
nodes = []
|
||||
for node in graph.nodes:
|
||||
nodes.append(
|
||||
{
|
||||
"id": node.id,
|
||||
"block_id": node.block_id,
|
||||
"input_default": node.input_default,
|
||||
"metadata": node.metadata,
|
||||
}
|
||||
)
|
||||
|
||||
links = []
|
||||
for node in graph.nodes:
|
||||
for link in node.output_links:
|
||||
links.append(
|
||||
{
|
||||
"id": link.id,
|
||||
"source_id": link.source_id,
|
||||
"sink_id": link.sink_id,
|
||||
"source_name": link.source_name,
|
||||
"sink_name": link.sink_name,
|
||||
"is_static": link.is_static,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"id": graph.id,
|
||||
"name": graph.name,
|
||||
"description": graph.description,
|
||||
"version": graph.version,
|
||||
"is_active": graph.is_active,
|
||||
"nodes": nodes,
|
||||
"links": links,
|
||||
}
|
||||
|
||||
|
||||
async def generate_agent_patch(
|
||||
update_request: str, current_agent: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Generate a patch to update an existing agent.
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
|
||||
Returns:
|
||||
Patch dict or clarifying questions, or None on error
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = PATCH_PROMPT.format(
|
||||
current_agent=json.dumps(current_agent, indent=2),
|
||||
block_summaries=get_block_summaries(),
|
||||
)
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": update_request},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for patch generation")
|
||||
return None
|
||||
|
||||
return parse_json_from_llm(content)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating patch: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def apply_agent_patch(
|
||||
current_agent: dict[str, Any], patch: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Apply a patch to an existing agent.
|
||||
|
||||
Args:
|
||||
current_agent: Current agent JSON
|
||||
patch: Patch dict with operations
|
||||
|
||||
Returns:
|
||||
Updated agent JSON
|
||||
"""
|
||||
agent = copy.deepcopy(current_agent)
|
||||
patches = patch.get("patches", [])
|
||||
|
||||
for p in patches:
|
||||
patch_type = p.get("type")
|
||||
|
||||
if patch_type == "modify":
|
||||
node_id = p.get("node_id")
|
||||
changes = p.get("changes", {})
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
if node["id"] == node_id:
|
||||
_deep_update(node, changes)
|
||||
logger.debug(f"Modified node {node_id}")
|
||||
break
|
||||
|
||||
elif patch_type == "add":
|
||||
new_nodes = p.get("new_nodes", [])
|
||||
new_links = p.get("new_links", [])
|
||||
|
||||
agent["nodes"] = agent.get("nodes", []) + new_nodes
|
||||
agent["links"] = agent.get("links", []) + new_links
|
||||
logger.debug(f"Added {len(new_nodes)} nodes, {len(new_links)} links")
|
||||
|
||||
elif patch_type == "remove":
|
||||
node_ids_to_remove = set(p.get("node_ids", []))
|
||||
link_ids_to_remove = set(p.get("link_ids", []))
|
||||
|
||||
# Remove nodes
|
||||
agent["nodes"] = [
|
||||
n for n in agent.get("nodes", []) if n["id"] not in node_ids_to_remove
|
||||
]
|
||||
|
||||
# Remove links (both explicit and those referencing removed nodes)
|
||||
agent["links"] = [
|
||||
link
|
||||
for link in agent.get("links", [])
|
||||
if link["id"] not in link_ids_to_remove
|
||||
and link["source_id"] not in node_ids_to_remove
|
||||
and link["sink_id"] not in node_ids_to_remove
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
f"Removed {len(node_ids_to_remove)} nodes, {len(link_ids_to_remove)} links"
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def _deep_update(target: dict, source: dict) -> None:
|
||||
"""Recursively update a dict with another dict."""
|
||||
for key, value in source.items():
|
||||
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
|
||||
_deep_update(target[key], value)
|
||||
else:
|
||||
target[key] = value
|
||||
@@ -1,606 +0,0 @@
|
||||
"""Agent fixer - Fixes common LLM generation errors."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from .utils import (
|
||||
ADDTODICTIONARY_BLOCK_ID,
|
||||
ADDTOLIST_BLOCK_ID,
|
||||
CODE_EXECUTION_BLOCK_ID,
|
||||
CONDITION_BLOCK_ID,
|
||||
CREATEDICT_BLOCK_ID,
|
||||
CREATELIST_BLOCK_ID,
|
||||
DATA_SAMPLING_BLOCK_ID,
|
||||
DOUBLE_CURLY_BRACES_BLOCK_IDS,
|
||||
GET_CURRENT_DATE_BLOCK_ID,
|
||||
STORE_VALUE_BLOCK_ID,
|
||||
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
||||
get_blocks_info,
|
||||
is_valid_uuid,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def fix_agent_ids(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix invalid UUIDs in agent and link IDs."""
|
||||
# Fix agent ID
|
||||
if not is_valid_uuid(agent.get("id", "")):
|
||||
agent["id"] = str(uuid.uuid4())
|
||||
logger.debug(f"Fixed agent ID: {agent['id']}")
|
||||
|
||||
# Fix node IDs
|
||||
id_mapping = {} # Old ID -> New ID
|
||||
for node in agent.get("nodes", []):
|
||||
if not is_valid_uuid(node.get("id", "")):
|
||||
old_id = node.get("id", "")
|
||||
new_id = str(uuid.uuid4())
|
||||
id_mapping[old_id] = new_id
|
||||
node["id"] = new_id
|
||||
logger.debug(f"Fixed node ID: {old_id} -> {new_id}")
|
||||
|
||||
# Fix link IDs and update references
|
||||
for link in agent.get("links", []):
|
||||
if not is_valid_uuid(link.get("id", "")):
|
||||
link["id"] = str(uuid.uuid4())
|
||||
logger.debug(f"Fixed link ID: {link['id']}")
|
||||
|
||||
# Update source/sink IDs if they were remapped
|
||||
if link.get("source_id") in id_mapping:
|
||||
link["source_id"] = id_mapping[link["source_id"]]
|
||||
if link.get("sink_id") in id_mapping:
|
||||
link["sink_id"] = id_mapping[link["sink_id"]]
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_double_curly_braces(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix single curly braces to double in template blocks."""
|
||||
for node in agent.get("nodes", []):
|
||||
if node.get("block_id") not in DOUBLE_CURLY_BRACES_BLOCK_IDS:
|
||||
continue
|
||||
|
||||
input_data = node.get("input_default", {})
|
||||
for key in ("prompt", "format"):
|
||||
if key in input_data and isinstance(input_data[key], str):
|
||||
original = input_data[key]
|
||||
# Fix simple variable references: {var} -> {{var}}
|
||||
fixed = re.sub(
|
||||
r"(?<!\{)\{([a-zA-Z_][a-zA-Z0-9_]*)\}(?!\})",
|
||||
r"{{\1}}",
|
||||
original,
|
||||
)
|
||||
if fixed != original:
|
||||
input_data[key] = fixed
|
||||
logger.debug(f"Fixed curly braces in {key}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_storevalue_before_condition(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Add StoreValueBlock before ConditionBlock if needed for value2."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
|
||||
# Find all ConditionBlock nodes
|
||||
condition_node_ids = {
|
||||
node["id"] for node in nodes if node.get("block_id") == CONDITION_BLOCK_ID
|
||||
}
|
||||
|
||||
if not condition_node_ids:
|
||||
return agent
|
||||
|
||||
new_nodes = []
|
||||
new_links = []
|
||||
processed_conditions = set()
|
||||
|
||||
for link in links:
|
||||
sink_id = link.get("sink_id")
|
||||
sink_name = link.get("sink_name")
|
||||
|
||||
# Check if this link goes to a ConditionBlock's value2
|
||||
if sink_id in condition_node_ids and sink_name == "value2":
|
||||
source_node = next(
|
||||
(n for n in nodes if n["id"] == link.get("source_id")), None
|
||||
)
|
||||
|
||||
# Skip if source is already a StoreValueBlock
|
||||
if source_node and source_node.get("block_id") == STORE_VALUE_BLOCK_ID:
|
||||
continue
|
||||
|
||||
# Skip if we already processed this condition
|
||||
if sink_id in processed_conditions:
|
||||
continue
|
||||
|
||||
processed_conditions.add(sink_id)
|
||||
|
||||
# Create StoreValueBlock
|
||||
store_node_id = str(uuid.uuid4())
|
||||
store_node = {
|
||||
"id": store_node_id,
|
||||
"block_id": STORE_VALUE_BLOCK_ID,
|
||||
"input_default": {"data": None},
|
||||
"metadata": {"position": {"x": 0, "y": -100}},
|
||||
}
|
||||
new_nodes.append(store_node)
|
||||
|
||||
# Create link: original source -> StoreValueBlock
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": link["source_id"],
|
||||
"source_name": link["source_name"],
|
||||
"sink_id": store_node_id,
|
||||
"sink_name": "input",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
# Update original link: StoreValueBlock -> ConditionBlock
|
||||
link["source_id"] = store_node_id
|
||||
link["source_name"] = "output"
|
||||
|
||||
logger.debug(f"Added StoreValueBlock before ConditionBlock {sink_id}")
|
||||
|
||||
if new_nodes:
|
||||
agent["nodes"] = nodes + new_nodes
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_addtolist_blocks(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix AddToList blocks by adding prerequisite empty AddToList block.
|
||||
|
||||
When an AddToList block is found:
|
||||
1. Checks if there's a CreateListBlock before it
|
||||
2. Removes CreateListBlock if linked directly to AddToList
|
||||
3. Adds an empty AddToList block before the original
|
||||
4. Ensures the original has a self-referencing link
|
||||
"""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
new_nodes = []
|
||||
original_addtolist_ids = set()
|
||||
nodes_to_remove = set()
|
||||
links_to_remove = []
|
||||
|
||||
# First pass: identify CreateListBlock nodes to remove
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
|
||||
|
||||
if (
|
||||
source_node
|
||||
and sink_node
|
||||
and source_node.get("block_id") == CREATELIST_BLOCK_ID
|
||||
and sink_node.get("block_id") == ADDTOLIST_BLOCK_ID
|
||||
):
|
||||
nodes_to_remove.add(source_node.get("id"))
|
||||
links_to_remove.append(link)
|
||||
logger.debug(f"Removing CreateListBlock {source_node.get('id')}")
|
||||
|
||||
# Second pass: process AddToList blocks
|
||||
filtered_nodes = []
|
||||
for node in nodes:
|
||||
if node.get("id") in nodes_to_remove:
|
||||
continue
|
||||
|
||||
if node.get("block_id") == ADDTOLIST_BLOCK_ID:
|
||||
original_addtolist_ids.add(node.get("id"))
|
||||
node_id = node.get("id")
|
||||
pos = node.get("metadata", {}).get("position", {"x": 0, "y": 0})
|
||||
|
||||
# Check if already has prerequisite
|
||||
has_prereq = any(
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "list"
|
||||
and link.get("source_name") == "updated_list"
|
||||
for link in links
|
||||
)
|
||||
|
||||
if not has_prereq:
|
||||
# Remove links to "list" input (except self-reference)
|
||||
for link in links:
|
||||
if (
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "list"
|
||||
and link.get("source_id") != node_id
|
||||
and link not in links_to_remove
|
||||
):
|
||||
links_to_remove.append(link)
|
||||
|
||||
# Create prerequisite AddToList block
|
||||
prereq_id = str(uuid.uuid4())
|
||||
prereq_node = {
|
||||
"id": prereq_id,
|
||||
"block_id": ADDTOLIST_BLOCK_ID,
|
||||
"input_default": {"list": [], "entry": None, "entries": []},
|
||||
"metadata": {
|
||||
"position": {"x": pos.get("x", 0) - 800, "y": pos.get("y", 0)}
|
||||
},
|
||||
}
|
||||
new_nodes.append(prereq_node)
|
||||
|
||||
# Link prerequisite to original
|
||||
links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": prereq_id,
|
||||
"source_name": "updated_list",
|
||||
"sink_id": node_id,
|
||||
"sink_name": "list",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
logger.debug(f"Added prerequisite AddToList block for {node_id}")
|
||||
|
||||
filtered_nodes.append(node)
|
||||
|
||||
# Remove marked links
|
||||
filtered_links = [link for link in links if link not in links_to_remove]
|
||||
|
||||
# Add self-referencing links for original AddToList blocks
|
||||
for node in filtered_nodes + new_nodes:
|
||||
if (
|
||||
node.get("block_id") == ADDTOLIST_BLOCK_ID
|
||||
and node.get("id") in original_addtolist_ids
|
||||
):
|
||||
node_id = node.get("id")
|
||||
has_self_ref = any(
|
||||
link["source_id"] == node_id
|
||||
and link["sink_id"] == node_id
|
||||
and link["source_name"] == "updated_list"
|
||||
and link["sink_name"] == "list"
|
||||
for link in filtered_links
|
||||
)
|
||||
if not has_self_ref:
|
||||
filtered_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": node_id,
|
||||
"source_name": "updated_list",
|
||||
"sink_id": node_id,
|
||||
"sink_name": "list",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
logger.debug(f"Added self-reference for AddToList {node_id}")
|
||||
|
||||
agent["nodes"] = filtered_nodes + new_nodes
|
||||
agent["links"] = filtered_links
|
||||
return agent
|
||||
|
||||
|
||||
def fix_addtodictionary_blocks(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix AddToDictionary blocks by removing empty CreateDictionary nodes."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
nodes_to_remove = set()
|
||||
links_to_remove = []
|
||||
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
|
||||
|
||||
if (
|
||||
source_node
|
||||
and sink_node
|
||||
and source_node.get("block_id") == CREATEDICT_BLOCK_ID
|
||||
and sink_node.get("block_id") == ADDTODICTIONARY_BLOCK_ID
|
||||
):
|
||||
nodes_to_remove.add(source_node.get("id"))
|
||||
links_to_remove.append(link)
|
||||
logger.debug(f"Removing CreateDictionary {source_node.get('id')}")
|
||||
|
||||
agent["nodes"] = [n for n in nodes if n.get("id") not in nodes_to_remove]
|
||||
agent["links"] = [link for link in links if link not in links_to_remove]
|
||||
return agent
|
||||
|
||||
|
||||
def fix_code_execution_output(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix CodeExecutionBlock output: change 'response' to 'stdout_logs'."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
if (
|
||||
source_node
|
||||
and source_node.get("block_id") == CODE_EXECUTION_BLOCK_ID
|
||||
and link.get("source_name") == "response"
|
||||
):
|
||||
link["source_name"] = "stdout_logs"
|
||||
logger.debug("Fixed CodeExecutionBlock output: response -> stdout_logs")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_data_sampling_sample_size(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix DataSamplingBlock by setting sample_size to 1 as default."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
links_to_remove = []
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") == DATA_SAMPLING_BLOCK_ID:
|
||||
node_id = node.get("id")
|
||||
input_default = node.get("input_default", {})
|
||||
|
||||
# Remove links to sample_size
|
||||
for link in links:
|
||||
if (
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "sample_size"
|
||||
):
|
||||
links_to_remove.append(link)
|
||||
|
||||
# Set default
|
||||
input_default["sample_size"] = 1
|
||||
node["input_default"] = input_default
|
||||
logger.debug(f"Fixed DataSamplingBlock {node_id} sample_size to 1")
|
||||
|
||||
if links_to_remove:
|
||||
agent["links"] = [link for link in links if link not in links_to_remove]
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_node_x_coordinates(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix node x-coordinates to ensure 800+ unit spacing between linked nodes."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
node_lookup = {n.get("id"): n for n in nodes}
|
||||
|
||||
for link in links:
|
||||
source_id = link.get("source_id")
|
||||
sink_id = link.get("sink_id")
|
||||
|
||||
source_node = node_lookup.get(source_id)
|
||||
sink_node = node_lookup.get(sink_id)
|
||||
|
||||
if not source_node or not sink_node:
|
||||
continue
|
||||
|
||||
source_pos = source_node.get("metadata", {}).get("position", {})
|
||||
sink_pos = sink_node.get("metadata", {}).get("position", {})
|
||||
|
||||
source_x = source_pos.get("x", 0)
|
||||
sink_x = sink_pos.get("x", 0)
|
||||
|
||||
if abs(sink_x - source_x) < 800:
|
||||
new_x = source_x + 800
|
||||
if "metadata" not in sink_node:
|
||||
sink_node["metadata"] = {}
|
||||
if "position" not in sink_node["metadata"]:
|
||||
sink_node["metadata"]["position"] = {}
|
||||
sink_node["metadata"]["position"]["x"] = new_x
|
||||
logger.debug(f"Fixed node {sink_id} x: {sink_x} -> {new_x}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_getcurrentdate_offset(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix GetCurrentDateBlock offset to ensure it's positive."""
|
||||
for node in agent.get("nodes", []):
|
||||
if node.get("block_id") == GET_CURRENT_DATE_BLOCK_ID:
|
||||
input_default = node.get("input_default", {})
|
||||
if "offset" in input_default:
|
||||
offset = input_default["offset"]
|
||||
if isinstance(offset, (int, float)) and offset < 0:
|
||||
input_default["offset"] = abs(offset)
|
||||
logger.debug(f"Fixed offset: {offset} -> {abs(offset)}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_ai_model_parameter(
|
||||
agent: dict[str, Any],
|
||||
blocks_info: list[dict[str, Any]],
|
||||
default_model: str = "gpt-4o",
|
||||
) -> dict[str, Any]:
|
||||
"""Add default model parameter to AI blocks if missing."""
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
block = block_map.get(block_id)
|
||||
|
||||
if not block:
|
||||
continue
|
||||
|
||||
# Check if block has AI category
|
||||
categories = block.get("categories", [])
|
||||
is_ai_block = any(
|
||||
cat.get("category") == "AI" for cat in categories if isinstance(cat, dict)
|
||||
)
|
||||
|
||||
if is_ai_block:
|
||||
input_default = node.get("input_default", {})
|
||||
if "model" not in input_default:
|
||||
input_default["model"] = default_model
|
||||
node["input_default"] = input_default
|
||||
logger.debug(
|
||||
f"Added model '{default_model}' to AI block {node.get('id')}"
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_link_static_properties(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Fix is_static property based on source block's staticOutput."""
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
if not source_node:
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
if not source_block:
|
||||
continue
|
||||
|
||||
static_output = source_block.get("staticOutput", False)
|
||||
if link.get("is_static") != static_output:
|
||||
link["is_static"] = static_output
|
||||
logger.debug(f"Fixed link {link.get('id')} is_static to {static_output}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_data_type_mismatch(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Fix data type mismatches by inserting UniversalTypeConverterBlock."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in nodes}
|
||||
|
||||
def get_property_type(schema: dict, name: str) -> str | None:
|
||||
if "_#_" in name:
|
||||
parent, child = name.split("_#_", 1)
|
||||
parent_schema = schema.get(parent, {})
|
||||
if "properties" in parent_schema:
|
||||
return parent_schema["properties"].get(child, {}).get("type")
|
||||
return None
|
||||
return schema.get(name, {}).get("type")
|
||||
|
||||
def are_types_compatible(src: str, sink: str) -> bool:
|
||||
if {src, sink} <= {"integer", "number"}:
|
||||
return True
|
||||
return src == sink
|
||||
|
||||
type_mapping = {
|
||||
"string": "string",
|
||||
"text": "string",
|
||||
"integer": "number",
|
||||
"number": "number",
|
||||
"float": "number",
|
||||
"boolean": "boolean",
|
||||
"bool": "boolean",
|
||||
"array": "list",
|
||||
"list": "list",
|
||||
"object": "dictionary",
|
||||
"dict": "dictionary",
|
||||
"dictionary": "dictionary",
|
||||
}
|
||||
|
||||
new_links = []
|
||||
nodes_to_add = []
|
||||
|
||||
for link in links:
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
|
||||
if not source_node or not sink_node:
|
||||
new_links.append(link)
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
sink_block = block_map.get(sink_node.get("block_id"))
|
||||
|
||||
if not source_block or not sink_block:
|
||||
new_links.append(link)
|
||||
continue
|
||||
|
||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
||||
|
||||
source_type = get_property_type(source_outputs, link.get("source_name", ""))
|
||||
sink_type = get_property_type(sink_inputs, link.get("sink_name", ""))
|
||||
|
||||
if (
|
||||
source_type
|
||||
and sink_type
|
||||
and not are_types_compatible(source_type, sink_type)
|
||||
):
|
||||
# Insert type converter
|
||||
converter_id = str(uuid.uuid4())
|
||||
target_type = type_mapping.get(sink_type, sink_type)
|
||||
|
||||
converter_node = {
|
||||
"id": converter_id,
|
||||
"block_id": UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
||||
"input_default": {"type": target_type},
|
||||
"metadata": {"position": {"x": 0, "y": 100}},
|
||||
}
|
||||
nodes_to_add.append(converter_node)
|
||||
|
||||
# source -> converter
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": link["source_id"],
|
||||
"source_name": link["source_name"],
|
||||
"sink_id": converter_id,
|
||||
"sink_name": "value",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
# converter -> sink
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": converter_id,
|
||||
"source_name": "value",
|
||||
"sink_id": link["sink_id"],
|
||||
"sink_name": link["sink_name"],
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(f"Inserted type converter: {source_type} -> {target_type}")
|
||||
else:
|
||||
new_links.append(link)
|
||||
|
||||
if nodes_to_add:
|
||||
agent["nodes"] = nodes + nodes_to_add
|
||||
agent["links"] = new_links
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def apply_all_fixes(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Apply all fixes to an agent JSON.
|
||||
|
||||
Args:
|
||||
agent: Agent JSON dict
|
||||
blocks_info: Optional list of block info dicts for advanced fixes
|
||||
|
||||
Returns:
|
||||
Fixed agent JSON
|
||||
"""
|
||||
# Basic fixes (no block info needed)
|
||||
agent = fix_agent_ids(agent)
|
||||
agent = fix_double_curly_braces(agent)
|
||||
agent = fix_storevalue_before_condition(agent)
|
||||
agent = fix_addtolist_blocks(agent)
|
||||
agent = fix_addtodictionary_blocks(agent)
|
||||
agent = fix_code_execution_output(agent)
|
||||
agent = fix_data_sampling_sample_size(agent)
|
||||
agent = fix_node_x_coordinates(agent)
|
||||
agent = fix_getcurrentdate_offset(agent)
|
||||
|
||||
# Advanced fixes (require block info)
|
||||
if blocks_info is None:
|
||||
blocks_info = get_blocks_info()
|
||||
|
||||
agent = fix_ai_model_parameter(agent, blocks_info)
|
||||
agent = fix_link_static_properties(agent, blocks_info)
|
||||
agent = fix_data_type_mismatch(agent, blocks_info)
|
||||
|
||||
return agent
|
||||
@@ -1,225 +0,0 @@
|
||||
"""Prompt templates for agent generation."""
|
||||
|
||||
DECOMPOSITION_PROMPT = """
|
||||
You are an expert AutoGPT Workflow Decomposer. Your task is to analyze a user's high-level goal and break it down into a clear, step-by-step plan using the available blocks.
|
||||
|
||||
Each step should represent a distinct, automatable action suitable for execution by an AI automation system.
|
||||
|
||||
---
|
||||
|
||||
FIRST: Analyze the user's goal and determine:
|
||||
1) Design-time configuration (fixed settings that won't change per run)
|
||||
2) Runtime inputs (values the agent's end-user will provide each time it runs)
|
||||
|
||||
For anything that can vary per run (email addresses, names, dates, search terms, etc.):
|
||||
- DO NOT ask for the actual value
|
||||
- Instead, define it as an Agent Input with a clear name, type, and description
|
||||
|
||||
Only ask clarifying questions about design-time config that affects how you build the workflow:
|
||||
- Which external service to use (e.g., "Gmail vs Outlook", "Notion vs Google Docs")
|
||||
- Required formats or structures (e.g., "CSV, JSON, or PDF output?")
|
||||
- Business rules that must be hard-coded
|
||||
|
||||
IMPORTANT CLARIFICATIONS POLICY:
|
||||
- Ask no more than five essential questions
|
||||
- Do not ask for concrete values that can be provided at runtime as Agent Inputs
|
||||
- Do not ask for API keys or credentials; the platform handles those directly
|
||||
- If there is enough information to infer reasonable defaults, prefer to propose defaults
|
||||
|
||||
---
|
||||
|
||||
GUIDELINES:
|
||||
1. List each step as a numbered item
|
||||
2. Describe the action clearly and specify inputs/outputs
|
||||
3. Ensure steps are in logical, sequential order
|
||||
4. Mention block names naturally (e.g., "Use GetWeatherByLocationBlock to...")
|
||||
5. Help the user reach their goal efficiently
|
||||
|
||||
---
|
||||
|
||||
RULES:
|
||||
1. OUTPUT FORMAT: Only output either clarifying questions OR step-by-step instructions, not both
|
||||
2. USE ONLY THE BLOCKS PROVIDED
|
||||
3. ALL required_input fields must be provided
|
||||
4. Data types of linked properties must match
|
||||
5. Write expert-level prompts for AI-related blocks
|
||||
|
||||
---
|
||||
|
||||
CRITICAL BLOCK RESTRICTIONS:
|
||||
1. AddToListBlock: Outputs updated list EVERY addition, not after all additions
|
||||
2. SendEmailBlock: Draft the email for user review; set SMTP config based on email type
|
||||
3. ConditionBlock: value2 is reference, value1 is contrast
|
||||
4. CodeExecutionBlock: DO NOT USE - use AI blocks instead
|
||||
5. ReadCsvBlock: Only use the 'rows' output, not 'row'
|
||||
|
||||
---
|
||||
|
||||
OUTPUT FORMAT:
|
||||
|
||||
If more information is needed:
|
||||
```json
|
||||
{{
|
||||
"type": "clarifying_questions",
|
||||
"questions": [
|
||||
{{
|
||||
"question": "Which email provider should be used? (Gmail, Outlook, custom SMTP)",
|
||||
"keyword": "email_provider",
|
||||
"example": "Gmail"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
If ready to proceed:
|
||||
```json
|
||||
{{
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{{
|
||||
"step_number": 1,
|
||||
"block_name": "AgentShortTextInputBlock",
|
||||
"description": "Get the URL of the content to analyze.",
|
||||
"inputs": [{{"name": "name", "value": "URL"}}],
|
||||
"outputs": [{{"name": "result", "description": "The URL entered by user"}}]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
AVAILABLE BLOCKS:
|
||||
{block_summaries}
|
||||
"""
|
||||
|
||||
GENERATION_PROMPT = """
|
||||
You are an expert AI workflow builder. Generate a valid agent JSON from the given instructions.
|
||||
|
||||
---
|
||||
|
||||
NODES:
|
||||
Each node must include:
|
||||
- `id`: Unique UUID v4 (e.g. `a8f5b1e2-c3d4-4e5f-8a9b-0c1d2e3f4a5b`)
|
||||
- `block_id`: The block identifier (must match an Allowed Block)
|
||||
- `input_default`: Dict of inputs (can be empty if no static inputs needed)
|
||||
- `metadata`: Must contain:
|
||||
- `position`: {{"x": number, "y": number}} - adjacent nodes should differ by 800+ in X
|
||||
- `customized_name`: Clear name describing this block's purpose in the workflow
|
||||
|
||||
---
|
||||
|
||||
LINKS:
|
||||
Each link connects a source node's output to a sink node's input:
|
||||
- `id`: MUST be UUID v4 (NOT "link-1", "link-2", etc.)
|
||||
- `source_id`: ID of the source node
|
||||
- `source_name`: Output field name from the source block
|
||||
- `sink_id`: ID of the sink node
|
||||
- `sink_name`: Input field name on the sink block
|
||||
- `is_static`: true only if source block has static_output: true
|
||||
|
||||
CRITICAL: All IDs must be valid UUID v4 format!
|
||||
|
||||
---
|
||||
|
||||
AGENT (GRAPH):
|
||||
Wrap nodes and links in:
|
||||
- `id`: UUID of the agent
|
||||
- `name`: Short, generic name (avoid specific company names, URLs)
|
||||
- `description`: Short, generic description
|
||||
- `nodes`: List of all nodes
|
||||
- `links`: List of all links
|
||||
- `version`: 1
|
||||
- `is_active`: true
|
||||
|
||||
---
|
||||
|
||||
TIPS:
|
||||
- All required_input fields must be provided via input_default or a valid link
|
||||
- Ensure consistent source_id and sink_id references
|
||||
- Avoid dangling links
|
||||
- Input/output pins must match block schemas
|
||||
- Do not invent unknown block_ids
|
||||
|
||||
---
|
||||
|
||||
ALLOWED BLOCKS:
|
||||
{block_summaries}
|
||||
|
||||
---
|
||||
|
||||
Generate the complete agent JSON. Output ONLY valid JSON, no explanation.
|
||||
"""
|
||||
|
||||
PATCH_PROMPT = """
|
||||
You are an expert at modifying AutoGPT agent workflows. Given the current agent and a modification request, generate a JSON patch to update the agent.
|
||||
|
||||
CURRENT AGENT:
|
||||
{current_agent}
|
||||
|
||||
AVAILABLE BLOCKS:
|
||||
{block_summaries}
|
||||
|
||||
---
|
||||
|
||||
PATCH FORMAT:
|
||||
Return a JSON object with the following structure:
|
||||
|
||||
```json
|
||||
{{
|
||||
"type": "patch",
|
||||
"intent": "Brief description of what the patch does",
|
||||
"patches": [
|
||||
{{
|
||||
"type": "modify",
|
||||
"node_id": "uuid-of-node-to-modify",
|
||||
"changes": {{
|
||||
"input_default": {{"field": "new_value"}},
|
||||
"metadata": {{"customized_name": "New Name"}}
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"type": "add",
|
||||
"new_nodes": [
|
||||
{{
|
||||
"id": "new-uuid",
|
||||
"block_id": "block-uuid",
|
||||
"input_default": {{}},
|
||||
"metadata": {{"position": {{"x": 0, "y": 0}}, "customized_name": "Name"}}
|
||||
}}
|
||||
],
|
||||
"new_links": [
|
||||
{{
|
||||
"id": "link-uuid",
|
||||
"source_id": "source-node-id",
|
||||
"source_name": "output_field",
|
||||
"sink_id": "sink-node-id",
|
||||
"sink_name": "input_field"
|
||||
}}
|
||||
]
|
||||
}},
|
||||
{{
|
||||
"type": "remove",
|
||||
"node_ids": ["uuid-of-node-to-remove"],
|
||||
"link_ids": ["uuid-of-link-to-remove"]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
If you need more information, return:
|
||||
```json
|
||||
{{
|
||||
"type": "clarifying_questions",
|
||||
"questions": [
|
||||
{{
|
||||
"question": "What specific change do you want?",
|
||||
"keyword": "change_type",
|
||||
"example": "Add error handling"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
Generate the minimal patch needed. Output ONLY valid JSON.
|
||||
"""
|
||||
@@ -1,213 +0,0 @@
|
||||
"""Utilities for agent generation."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
# UUID validation regex
|
||||
UUID_REGEX = re.compile(
|
||||
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$"
|
||||
)
|
||||
|
||||
# Block IDs for various fixes
|
||||
STORE_VALUE_BLOCK_ID = "1ff065e9-88e8-4358-9d82-8dc91f622ba9"
|
||||
CONDITION_BLOCK_ID = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
|
||||
ADDTOLIST_BLOCK_ID = "aeb08fc1-2fc1-4141-bc8e-f758f183a822"
|
||||
ADDTODICTIONARY_BLOCK_ID = "31d1064e-7446-4693-a7d4-65e5ca1180d1"
|
||||
CREATELIST_BLOCK_ID = "a912d5c7-6e00-4542-b2a9-8034136930e4"
|
||||
CREATEDICT_BLOCK_ID = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
|
||||
CODE_EXECUTION_BLOCK_ID = "0b02b072-abe7-11ef-8372-fb5d162dd712"
|
||||
DATA_SAMPLING_BLOCK_ID = "4a448883-71fa-49cf-91cf-70d793bd7d87"
|
||||
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID = "95d1b990-ce13-4d88-9737-ba5c2070c97b"
|
||||
GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
|
||||
|
||||
DOUBLE_CURLY_BRACES_BLOCK_IDS = [
|
||||
"44f6c8ad-d75c-4ae1-8209-aad1c0326928", # FillTextTemplateBlock
|
||||
"6ab085e2-20b3-4055-bc3e-08036e01eca6",
|
||||
"90f8c45e-e983-4644-aa0b-b4ebe2f531bc",
|
||||
"363ae599-353e-4804-937e-b2ee3cef3da4", # AgentOutputBlock
|
||||
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||
"db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
|
||||
"3a7c4b8d-6e2f-4a5d-b9c1-f8d23c5a9b0e",
|
||||
"ed1ae7a0-b770-4089-b520-1f0005fad19a",
|
||||
"a892b8d9-3e4e-4e9c-9c1e-75f8efcf1bfa",
|
||||
"b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1",
|
||||
"716a67b3-6760-42e7-86dc-18645c6e00fc",
|
||||
"530cf046-2ce0-4854-ae2c-659db17c7a46",
|
||||
"ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
"1f292d4a-41a4-4977-9684-7c8d560b9f91", # LLM blocks
|
||||
"32a87eab-381e-4dd4-bdb8-4c47151be35a",
|
||||
]
|
||||
|
||||
|
||||
def is_valid_uuid(value: str) -> bool:
|
||||
"""Check if a string is a valid UUID v4."""
|
||||
return isinstance(value, str) and UUID_REGEX.match(value) is not None
|
||||
|
||||
|
||||
def _compact_schema(schema: dict) -> dict[str, str]:
|
||||
"""Extract compact type info from a JSON schema properties dict.
|
||||
|
||||
Returns a dict of {field_name: type_string} for essential info only.
|
||||
"""
|
||||
props = schema.get("properties", {})
|
||||
result = {}
|
||||
|
||||
for name, prop in props.items():
|
||||
# Skip internal/complex fields
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
# Get type string
|
||||
type_str = prop.get("type", "any")
|
||||
|
||||
# Handle anyOf/oneOf (optional types)
|
||||
if "anyOf" in prop:
|
||||
types = [t.get("type", "?") for t in prop["anyOf"] if t.get("type")]
|
||||
type_str = "|".join(types) if types else "any"
|
||||
elif "allOf" in prop:
|
||||
type_str = "object"
|
||||
|
||||
# Add array item type if present
|
||||
if type_str == "array" and "items" in prop:
|
||||
items = prop["items"]
|
||||
if isinstance(items, dict):
|
||||
item_type = items.get("type", "any")
|
||||
type_str = f"array[{item_type}]"
|
||||
|
||||
result[name] = type_str
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_block_summaries(include_schemas: bool = True) -> str:
|
||||
"""Generate compact block summaries for prompts.
|
||||
|
||||
Args:
|
||||
include_schemas: Whether to include input/output type info
|
||||
|
||||
Returns:
|
||||
Formatted string of block summaries (compact format)
|
||||
"""
|
||||
blocks = get_blocks()
|
||||
summaries = []
|
||||
|
||||
for block_id, block_cls in blocks.items():
|
||||
block = block_cls()
|
||||
name = block.name
|
||||
desc = getattr(block, "description", "") or ""
|
||||
|
||||
# Truncate description
|
||||
if len(desc) > 150:
|
||||
desc = desc[:147] + "..."
|
||||
|
||||
if not include_schemas:
|
||||
summaries.append(f"- {name} (id: {block_id}): {desc}")
|
||||
else:
|
||||
# Compact format with type info only
|
||||
inputs = {}
|
||||
outputs = {}
|
||||
required = []
|
||||
|
||||
if hasattr(block, "input_schema"):
|
||||
try:
|
||||
schema = block.input_schema.jsonschema()
|
||||
inputs = _compact_schema(schema)
|
||||
required = schema.get("required", [])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if hasattr(block, "output_schema"):
|
||||
try:
|
||||
schema = block.output_schema.jsonschema()
|
||||
outputs = _compact_schema(schema)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build compact line format
|
||||
# Format: NAME (id): desc | in: {field:type, ...} [required] | out: {field:type}
|
||||
in_str = ", ".join(f"{k}:{v}" for k, v in inputs.items())
|
||||
out_str = ", ".join(f"{k}:{v}" for k, v in outputs.items())
|
||||
req_str = f" req=[{','.join(required)}]" if required else ""
|
||||
|
||||
static = " [static]" if getattr(block, "static_output", False) else ""
|
||||
|
||||
line = f"- {name} (id: {block_id}): {desc}"
|
||||
if in_str:
|
||||
line += f"\n in: {{{in_str}}}{req_str}"
|
||||
if out_str:
|
||||
line += f"\n out: {{{out_str}}}{static}"
|
||||
|
||||
summaries.append(line)
|
||||
|
||||
return "\n".join(summaries)
|
||||
|
||||
|
||||
def get_blocks_info() -> list[dict[str, Any]]:
|
||||
"""Get block information with schemas for validation and fixing."""
|
||||
blocks = get_blocks()
|
||||
blocks_info = []
|
||||
for block_id, block_cls in blocks.items():
|
||||
block = block_cls()
|
||||
blocks_info.append(
|
||||
{
|
||||
"id": block_id,
|
||||
"name": block.name,
|
||||
"description": getattr(block, "description", ""),
|
||||
"categories": getattr(block, "categories", []),
|
||||
"staticOutput": getattr(block, "static_output", False),
|
||||
"inputSchema": (
|
||||
block.input_schema.jsonschema()
|
||||
if hasattr(block, "input_schema")
|
||||
else {}
|
||||
),
|
||||
"outputSchema": (
|
||||
block.output_schema.jsonschema()
|
||||
if hasattr(block, "output_schema")
|
||||
else {}
|
||||
),
|
||||
}
|
||||
)
|
||||
return blocks_info
|
||||
|
||||
|
||||
def parse_json_from_llm(text: str) -> dict[str, Any] | None:
|
||||
"""Extract JSON from LLM response (handles markdown code blocks)."""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# Try fenced code block
|
||||
match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(match.group(1).strip())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try raw text
|
||||
try:
|
||||
return json.loads(text.strip())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try finding {...} span
|
||||
start = text.find("{")
|
||||
end = text.rfind("}")
|
||||
if start != -1 and end > start:
|
||||
try:
|
||||
return json.loads(text[start : end + 1])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try finding [...] span
|
||||
start = text.find("[")
|
||||
end = text.rfind("]")
|
||||
if start != -1 and end > start:
|
||||
try:
|
||||
return json.loads(text[start : end + 1])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return None
|
||||
@@ -1,279 +0,0 @@
|
||||
"""Agent validator - Validates agent structure and connections."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from .utils import get_blocks_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentValidator:
|
||||
"""Validator for AutoGPT agents with detailed error reporting."""
|
||||
|
||||
def __init__(self):
|
||||
self.errors: list[str] = []
|
||||
|
||||
def add_error(self, error: str) -> None:
|
||||
"""Add an error message."""
|
||||
self.errors.append(error)
|
||||
|
||||
def validate_block_existence(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate all block IDs exist in the blocks library."""
|
||||
valid = True
|
||||
valid_block_ids = {b.get("id") for b in blocks_info if b.get("id")}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
node_id = node.get("id")
|
||||
|
||||
if not block_id:
|
||||
self.add_error(f"Node '{node_id}' is missing 'block_id' field.")
|
||||
valid = False
|
||||
continue
|
||||
|
||||
if block_id not in valid_block_ids:
|
||||
self.add_error(
|
||||
f"Node '{node_id}' references block_id '{block_id}' which does not exist."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_link_node_references(self, agent: dict[str, Any]) -> bool:
|
||||
"""Validate all node IDs referenced in links exist."""
|
||||
valid = True
|
||||
valid_node_ids = {n.get("id") for n in agent.get("nodes", []) if n.get("id")}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
link_id = link.get("id", "Unknown")
|
||||
source_id = link.get("source_id")
|
||||
sink_id = link.get("sink_id")
|
||||
|
||||
if not source_id:
|
||||
self.add_error(f"Link '{link_id}' is missing 'source_id'.")
|
||||
valid = False
|
||||
elif source_id not in valid_node_ids:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' references non-existent source_id '{source_id}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
if not sink_id:
|
||||
self.add_error(f"Link '{link_id}' is missing 'sink_id'.")
|
||||
valid = False
|
||||
elif sink_id not in valid_node_ids:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' references non-existent sink_id '{sink_id}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_required_inputs(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate required inputs are provided."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
block = block_map.get(block_id)
|
||||
|
||||
if not block:
|
||||
continue
|
||||
|
||||
required_inputs = block.get("inputSchema", {}).get("required", [])
|
||||
input_defaults = node.get("input_default", {})
|
||||
node_id = node.get("id")
|
||||
|
||||
# Get linked inputs
|
||||
linked_inputs = {
|
||||
link["sink_name"]
|
||||
for link in agent.get("links", [])
|
||||
if link.get("sink_id") == node_id
|
||||
}
|
||||
|
||||
for req_input in required_inputs:
|
||||
if (
|
||||
req_input not in input_defaults
|
||||
and req_input not in linked_inputs
|
||||
and req_input != "credentials"
|
||||
):
|
||||
block_name = block.get("name", "Unknown Block")
|
||||
self.add_error(
|
||||
f"Node '{node_id}' ({block_name}) is missing required input '{req_input}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_data_type_compatibility(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate linked data types are compatible."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
def get_type(schema: dict, name: str) -> str | None:
|
||||
if "_#_" in name:
|
||||
parent, child = name.split("_#_", 1)
|
||||
parent_schema = schema.get(parent, {})
|
||||
if "properties" in parent_schema:
|
||||
return parent_schema["properties"].get(child, {}).get("type")
|
||||
return None
|
||||
return schema.get(name, {}).get("type")
|
||||
|
||||
def are_compatible(src: str, sink: str) -> bool:
|
||||
if {src, sink} <= {"integer", "number"}:
|
||||
return True
|
||||
return src == sink
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
|
||||
if not source_node or not sink_node:
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
sink_block = block_map.get(sink_node.get("block_id"))
|
||||
|
||||
if not source_block or not sink_block:
|
||||
continue
|
||||
|
||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
||||
|
||||
source_type = get_type(source_outputs, link.get("source_name", ""))
|
||||
sink_type = get_type(sink_inputs, link.get("sink_name", ""))
|
||||
|
||||
if source_type and sink_type and not are_compatible(source_type, sink_type):
|
||||
self.add_error(
|
||||
f"Type mismatch: {source_block.get('name')} output '{link['source_name']}' "
|
||||
f"({source_type}) -> {sink_block.get('name')} input '{link['sink_name']}' ({sink_type})."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_nested_sink_links(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate nested sink links (with _#_ notation)."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
sink_name = link.get("sink_name", "")
|
||||
|
||||
if "_#_" in sink_name:
|
||||
parent, child = sink_name.split("_#_", 1)
|
||||
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
if not sink_node:
|
||||
continue
|
||||
|
||||
block = block_map.get(sink_node.get("block_id"))
|
||||
if not block:
|
||||
continue
|
||||
|
||||
input_props = block.get("inputSchema", {}).get("properties", {})
|
||||
parent_schema = input_props.get(parent)
|
||||
|
||||
if not parent_schema:
|
||||
self.add_error(
|
||||
f"Invalid nested link '{sink_name}': parent '{parent}' not found."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
if not parent_schema.get("additionalProperties"):
|
||||
if not (
|
||||
isinstance(parent_schema, dict)
|
||||
and "properties" in parent_schema
|
||||
and child in parent_schema.get("properties", {})
|
||||
):
|
||||
self.add_error(
|
||||
f"Invalid nested link '{sink_name}': child '{child}' not found in '{parent}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_prompt_spaces(self, agent: dict[str, Any]) -> bool:
|
||||
"""Validate prompts don't have spaces in template variables."""
|
||||
valid = True
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
input_default = node.get("input_default", {})
|
||||
prompt = input_default.get("prompt", "")
|
||||
|
||||
if not isinstance(prompt, str):
|
||||
continue
|
||||
|
||||
# Find {{...}} with spaces
|
||||
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt)
|
||||
for match in matches:
|
||||
content = match.group(1)
|
||||
if " " in content:
|
||||
self.add_error(
|
||||
f"Node '{node.get('id')}' has spaces in template variable: "
|
||||
f"'{{{{{content}}}}}' should be '{{{{{content.replace(' ', '_')}}}}}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Run all validations.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
self.errors = []
|
||||
|
||||
if blocks_info is None:
|
||||
blocks_info = get_blocks_info()
|
||||
|
||||
checks = [
|
||||
self.validate_block_existence(agent, blocks_info),
|
||||
self.validate_link_node_references(agent),
|
||||
self.validate_required_inputs(agent, blocks_info),
|
||||
self.validate_data_type_compatibility(agent, blocks_info),
|
||||
self.validate_nested_sink_links(agent, blocks_info),
|
||||
self.validate_prompt_spaces(agent),
|
||||
]
|
||||
|
||||
all_passed = all(checks)
|
||||
|
||||
if all_passed:
|
||||
logger.info("Agent validation successful")
|
||||
return True, None
|
||||
|
||||
error_message = "Agent validation failed:\n"
|
||||
for i, error in enumerate(self.errors, 1):
|
||||
error_message += f"{i}. {error}\n"
|
||||
|
||||
logger.warning(f"Agent validation failed with {len(self.errors)} errors")
|
||||
return False, error_message
|
||||
|
||||
|
||||
def validate_agent(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Convenience function to validate an agent.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
validator = AgentValidator()
|
||||
return validator.validate(agent, blocks_info)
|
||||
@@ -1,455 +0,0 @@
|
||||
"""Tool for retrieving agent execution outputs from user's library."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentOutputResponse,
|
||||
ErrorResponse,
|
||||
ExecutionOutputInfo,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .utils import fetch_graph_from_store_slug
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentOutputInput(BaseModel):
|
||||
"""Input parameters for the agent_output tool."""
|
||||
|
||||
agent_name: str = ""
|
||||
library_agent_id: str = ""
|
||||
store_slug: str = ""
|
||||
execution_id: str = ""
|
||||
run_time: str = "latest"
|
||||
|
||||
@field_validator(
|
||||
"agent_name",
|
||||
"library_agent_id",
|
||||
"store_slug",
|
||||
"execution_id",
|
||||
"run_time",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def strip_strings(cls, v: Any) -> Any:
|
||||
"""Strip whitespace from string fields."""
|
||||
return v.strip() if isinstance(v, str) else v
|
||||
|
||||
|
||||
def parse_time_expression(
|
||||
time_expr: str | None,
|
||||
) -> tuple[datetime | None, datetime | None]:
|
||||
"""
|
||||
Parse time expression into datetime range (start, end).
|
||||
|
||||
Supports:
|
||||
- "latest" or None -> returns (None, None) to get most recent
|
||||
- "yesterday" -> 24h window for yesterday
|
||||
- "today" -> Today from midnight
|
||||
- "last week" / "last 7 days" -> 7 day window
|
||||
- "last month" / "last 30 days" -> 30 day window
|
||||
- ISO date "YYYY-MM-DD" -> 24h window for that date
|
||||
"""
|
||||
if not time_expr or time_expr.lower() == "latest":
|
||||
return None, None
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
expr = time_expr.lower().strip()
|
||||
|
||||
# Relative expressions
|
||||
if expr == "yesterday":
|
||||
end = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
start = end - timedelta(days=1)
|
||||
return start, end
|
||||
|
||||
if expr in ("last week", "last 7 days"):
|
||||
return now - timedelta(days=7), now
|
||||
|
||||
if expr in ("last month", "last 30 days"):
|
||||
return now - timedelta(days=30), now
|
||||
|
||||
if expr == "today":
|
||||
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
return start, now
|
||||
|
||||
# Try ISO date format (YYYY-MM-DD)
|
||||
date_match = re.match(r"^(\d{4})-(\d{2})-(\d{2})$", expr)
|
||||
if date_match:
|
||||
year, month, day = map(int, date_match.groups())
|
||||
start = datetime(year, month, day, 0, 0, 0, tzinfo=timezone.utc)
|
||||
end = start + timedelta(days=1)
|
||||
return start, end
|
||||
|
||||
# Try ISO datetime
|
||||
try:
|
||||
parsed = datetime.fromisoformat(expr.replace("Z", "+00:00"))
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
# Return +/- 1 hour window around the specified time
|
||||
return parsed - timedelta(hours=1), parsed + timedelta(hours=1)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Fallback: treat as "latest"
|
||||
return None, None
|
||||
|
||||
|
||||
class AgentOutputTool(BaseTool):
|
||||
"""Tool for retrieving execution outputs from user's library agents."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "agent_output"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Retrieve execution outputs from agents in the user's library.
|
||||
|
||||
Identify the agent using one of:
|
||||
- agent_name: Fuzzy search in user's library
|
||||
- library_agent_id: Exact library agent ID
|
||||
- store_slug: Marketplace format 'username/agent-name'
|
||||
|
||||
Select which run to retrieve using:
|
||||
- execution_id: Specific execution ID
|
||||
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
|
||||
"""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_name": {
|
||||
"type": "string",
|
||||
"description": "Agent name to search for in user's library (fuzzy match)",
|
||||
},
|
||||
"library_agent_id": {
|
||||
"type": "string",
|
||||
"description": "Exact library agent ID",
|
||||
},
|
||||
"store_slug": {
|
||||
"type": "string",
|
||||
"description": "Marketplace identifier: 'username/agent-slug'",
|
||||
},
|
||||
"execution_id": {
|
||||
"type": "string",
|
||||
"description": "Specific execution ID to retrieve",
|
||||
},
|
||||
"run_time": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _resolve_agent(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_name: str | None,
|
||||
library_agent_id: str | None,
|
||||
store_slug: str | None,
|
||||
) -> tuple[LibraryAgent | None, str | None]:
|
||||
"""
|
||||
Resolve agent from provided identifiers.
|
||||
Returns (library_agent, error_message).
|
||||
"""
|
||||
# Priority 1: Exact library agent ID
|
||||
if library_agent_id:
|
||||
try:
|
||||
agent = await library_db.get_library_agent(library_agent_id, user_id)
|
||||
return agent, None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get library agent by ID: {e}")
|
||||
return None, f"Library agent '{library_agent_id}' not found"
|
||||
|
||||
# Priority 2: Store slug (username/agent-name)
|
||||
if store_slug and "/" in store_slug:
|
||||
username, agent_slug = store_slug.split("/", 1)
|
||||
graph, _ = await fetch_graph_from_store_slug(username, agent_slug)
|
||||
if not graph:
|
||||
return None, f"Agent '{store_slug}' not found in marketplace"
|
||||
|
||||
# Find in user's library by graph_id
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
|
||||
if not agent:
|
||||
return (
|
||||
None,
|
||||
f"Agent '{store_slug}' is not in your library. "
|
||||
"Add it first to see outputs.",
|
||||
)
|
||||
return agent, None
|
||||
|
||||
# Priority 3: Fuzzy name search in library
|
||||
if agent_name:
|
||||
try:
|
||||
response = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=agent_name,
|
||||
page_size=5,
|
||||
)
|
||||
if not response.agents:
|
||||
return (
|
||||
None,
|
||||
f"No agents matching '{agent_name}' found in your library",
|
||||
)
|
||||
|
||||
# Return best match (first result from search)
|
||||
return response.agents[0], None
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching library agents: {e}")
|
||||
return None, f"Error searching for agent: {e}"
|
||||
|
||||
return (
|
||||
None,
|
||||
"Please specify an agent name, library_agent_id, or store_slug",
|
||||
)
|
||||
|
||||
async def _get_execution(
|
||||
self,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
execution_id: str | None,
|
||||
time_start: datetime | None,
|
||||
time_end: datetime | None,
|
||||
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
|
||||
"""
|
||||
Fetch execution(s) based on filters.
|
||||
Returns (single_execution, available_executions_meta, error_message).
|
||||
"""
|
||||
# If specific execution_id provided, fetch it directly
|
||||
if execution_id:
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
if not execution:
|
||||
return None, [], f"Execution '{execution_id}' not found"
|
||||
return execution, [], None
|
||||
|
||||
# Get completed executions with time filters
|
||||
executions = await execution_db.get_graph_executions(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
statuses=[ExecutionStatus.COMPLETED],
|
||||
created_time_gte=time_start,
|
||||
created_time_lte=time_end,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
if not executions:
|
||||
return None, [], None # No error, just no executions
|
||||
|
||||
# If only one execution, fetch full details
|
||||
if len(executions) == 1:
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
return full_execution, [], None
|
||||
|
||||
# Multiple executions - return latest with full details, plus list of available
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
return full_execution, executions, None
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
agent: LibraryAgent,
|
||||
execution: GraphExecution | None,
|
||||
available_executions: list[GraphExecutionMeta],
|
||||
session_id: str | None,
|
||||
) -> AgentOutputResponse:
|
||||
"""Build the response based on execution data."""
|
||||
library_agent_link = f"/library/agents/{agent.id}"
|
||||
|
||||
if not execution:
|
||||
return AgentOutputResponse(
|
||||
message=f"No completed executions found for agent '{agent.name}'",
|
||||
session_id=session_id,
|
||||
agent_name=agent.name,
|
||||
agent_id=agent.graph_id,
|
||||
library_agent_id=agent.id,
|
||||
library_agent_link=library_agent_link,
|
||||
total_executions=0,
|
||||
)
|
||||
|
||||
execution_info = ExecutionOutputInfo(
|
||||
execution_id=execution.id,
|
||||
status=execution.status.value,
|
||||
started_at=execution.started_at,
|
||||
ended_at=execution.ended_at,
|
||||
outputs=dict(execution.outputs),
|
||||
inputs_summary=execution.inputs if execution.inputs else None,
|
||||
)
|
||||
|
||||
available_list = None
|
||||
if len(available_executions) > 1:
|
||||
available_list = [
|
||||
{
|
||||
"id": e.id,
|
||||
"status": e.status.value,
|
||||
"started_at": e.started_at.isoformat() if e.started_at else None,
|
||||
}
|
||||
for e in available_executions[:5]
|
||||
]
|
||||
|
||||
message = f"Found execution outputs for agent '{agent.name}'"
|
||||
if len(available_executions) > 1:
|
||||
message += (
|
||||
f". Showing latest of {len(available_executions)} matching executions."
|
||||
)
|
||||
|
||||
return AgentOutputResponse(
|
||||
message=message,
|
||||
session_id=session_id,
|
||||
agent_name=agent.name,
|
||||
agent_id=agent.graph_id,
|
||||
library_agent_id=agent.id,
|
||||
library_agent_link=library_agent_link,
|
||||
execution=execution_info,
|
||||
available_executions=available_list,
|
||||
total_executions=len(available_executions) if available_executions else 1,
|
||||
)
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the agent_output tool."""
|
||||
session_id = session.session_id
|
||||
|
||||
# Parse and validate input
|
||||
try:
|
||||
input_data = AgentOutputInput(**kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid input: {e}")
|
||||
return ErrorResponse(
|
||||
message="Invalid input parameters",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Ensure user_id is present (should be guaranteed by requires_auth)
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if at least one identifier is provided
|
||||
if not any(
|
||||
[
|
||||
input_data.agent_name,
|
||||
input_data.library_agent_id,
|
||||
input_data.store_slug,
|
||||
input_data.execution_id,
|
||||
]
|
||||
):
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Please specify at least one of: agent_name, "
|
||||
"library_agent_id, store_slug, or execution_id"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# If only execution_id provided, we need to find the agent differently
|
||||
if (
|
||||
input_data.execution_id
|
||||
and not input_data.agent_name
|
||||
and not input_data.library_agent_id
|
||||
and not input_data.store_slug
|
||||
):
|
||||
# Fetch execution directly to get graph_id
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=input_data.execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
if not execution:
|
||||
return ErrorResponse(
|
||||
message=f"Execution '{input_data.execution_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Find library agent by graph_id
|
||||
agent = await library_db.get_library_agent_by_graph_id(
|
||||
user_id, execution.graph_id
|
||||
)
|
||||
if not agent:
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"Execution found but agent not in your library. "
|
||||
f"Graph ID: {execution.graph_id}"
|
||||
),
|
||||
session_id=session_id,
|
||||
suggestions=["Add the agent to your library to see more details"],
|
||||
)
|
||||
|
||||
return self._build_response(agent, execution, [], session_id)
|
||||
|
||||
# Resolve agent from identifiers
|
||||
agent, error = await self._resolve_agent(
|
||||
user_id=user_id,
|
||||
agent_name=input_data.agent_name or None,
|
||||
library_agent_id=input_data.library_agent_id or None,
|
||||
store_slug=input_data.store_slug or None,
|
||||
)
|
||||
|
||||
if error or not agent:
|
||||
return NoResultsResponse(
|
||||
message=error or "Agent not found",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Check the agent name or ID",
|
||||
"Make sure the agent is in your library",
|
||||
],
|
||||
)
|
||||
|
||||
# Parse time expression
|
||||
time_start, time_end = parse_time_expression(input_data.run_time)
|
||||
|
||||
# Fetch execution(s)
|
||||
execution, available_executions, exec_error = await self._get_execution(
|
||||
user_id=user_id,
|
||||
graph_id=agent.graph_id,
|
||||
execution_id=input_data.execution_id or None,
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
)
|
||||
|
||||
if exec_error:
|
||||
return ErrorResponse(
|
||||
message=exec_error,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return self._build_response(agent, execution, available_executions, session_id)
|
||||
File diff suppressed because one or more lines are too long
@@ -1,279 +0,0 @@
|
||||
"""CreateAgentTool - Creates agents from natural language descriptions."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
apply_all_fixes,
|
||||
decompose_goal,
|
||||
generate_agent,
|
||||
get_blocks_info,
|
||||
save_agent_to_library,
|
||||
validate_agent,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum retries for agent generation with validation feedback
|
||||
MAX_GENERATION_RETRIES = 2
|
||||
|
||||
|
||||
class CreateAgentTool(BaseTool):
|
||||
"""Tool for creating agents from natural language descriptions."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "create_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Create a new agent workflow from a natural language description. "
|
||||
"First generates a preview, then saves to library if save=true."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Natural language description of what the agent should do. "
|
||||
"Be specific about inputs, outputs, and the workflow steps."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions. "
|
||||
"Include any preferences or constraints mentioned by the user."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the agent to the user's library. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["description"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the create_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Decompose the description into steps (may return clarifying questions)
|
||||
2. Generate agent JSON from the steps
|
||||
3. Apply fixes to correct common LLM errors
|
||||
4. Preview or save based on the save parameter
|
||||
"""
|
||||
description = kwargs.get("description", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not description:
|
||||
return ErrorResponse(
|
||||
message="Please provide a description of what the agent should do.",
|
||||
error="Missing description parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 1: Decompose goal into steps
|
||||
try:
|
||||
decomposition_result = await decompose_goal(description, context)
|
||||
except ValueError as e:
|
||||
# Handle missing API key or configuration errors
|
||||
return ErrorResponse(
|
||||
message=f"Agent generation is not configured: {str(e)}",
|
||||
error="configuration_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to analyze the goal. Please try rephrasing.",
|
||||
error="Decomposition failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if LLM returned clarifying questions
|
||||
if decomposition_result.get("type") == "clarifying_questions":
|
||||
questions = decomposition_result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information to create this agent. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check for unachievable/vague goals
|
||||
if decomposition_result.get("type") == "unachievable_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
reason = decomposition_result.get("reason", "")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"This goal cannot be accomplished with the available blocks. "
|
||||
f"{reason} "
|
||||
f"Suggestion: {suggested}"
|
||||
),
|
||||
error="unachievable_goal",
|
||||
details={"suggested_goal": suggested, "reason": reason},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result.get("type") == "vague_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The goal is too vague to create a specific workflow. "
|
||||
f"Suggestion: {suggested}"
|
||||
),
|
||||
error="vague_goal",
|
||||
details={"suggested_goal": suggested},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 2: Generate agent JSON with retry on validation failure
|
||||
blocks_info = get_blocks_info()
|
||||
agent_json = None
|
||||
validation_errors = None
|
||||
|
||||
for attempt in range(MAX_GENERATION_RETRIES + 1):
|
||||
# Generate agent (include validation errors from previous attempt)
|
||||
if attempt == 0:
|
||||
agent_json = await generate_agent(decomposition_result)
|
||||
else:
|
||||
# Retry with validation error feedback
|
||||
logger.info(
|
||||
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
|
||||
)
|
||||
retry_instructions = {
|
||||
**decomposition_result,
|
||||
"previous_errors": validation_errors,
|
||||
"retry_instructions": (
|
||||
"The previous generation had validation errors. "
|
||||
"Please fix these issues in the new generation:\n"
|
||||
f"{validation_errors}"
|
||||
),
|
||||
}
|
||||
agent_json = await generate_agent(retry_instructions)
|
||||
|
||||
if agent_json is None:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate the agent. Please try again.",
|
||||
error="Generation failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# Step 3: Apply fixes to correct common errors
|
||||
agent_json = apply_all_fixes(agent_json, blocks_info)
|
||||
|
||||
# Step 4: Validate the agent
|
||||
is_valid, validation_errors = validate_agent(agent_json, blocks_info)
|
||||
|
||||
if is_valid:
|
||||
logger.info(f"Agent generated successfully on attempt {attempt + 1}")
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
|
||||
)
|
||||
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
# Return error with validation details
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Generated agent has validation errors after {MAX_GENERATION_RETRIES + 1} attempts. "
|
||||
f"Please try rephrasing your request or simplify the workflow."
|
||||
),
|
||||
error="validation_failed",
|
||||
details={"validation_errors": validation_errors},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", "Generated Agent")
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
link_count = len(agent_json.get("links", []))
|
||||
|
||||
# Step 4: Preview or save
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
|
||||
f"Review it and call create_agent with save=true to save it to your library."
|
||||
),
|
||||
agent_json=agent_json,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to library
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the agent: {str(e)}",
|
||||
error="save_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
File diff suppressed because one or more lines are too long
@@ -1,294 +0,0 @@
|
||||
"""EditAgentTool - Edits existing agents using natural language."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
apply_agent_patch,
|
||||
apply_all_fixes,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_blocks_info,
|
||||
save_agent_to_library,
|
||||
validate_agent,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum retries for patch generation with validation feedback
|
||||
MAX_GENERATION_RETRIES = 2
|
||||
|
||||
|
||||
class EditAgentTool(BaseTool):
|
||||
"""Tool for editing existing agents using natural language."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "edit_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit an existing agent from the user's library using natural language. "
|
||||
"Generates a patch to update the agent while preserving unchanged parts."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The ID of the agent to edit. "
|
||||
"Can be a graph ID or library agent ID."
|
||||
),
|
||||
},
|
||||
"changes": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Natural language description of what changes to make. "
|
||||
"Be specific about what to add, remove, or modify."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the changes. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "changes"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the edit_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Fetch the current agent
|
||||
2. Generate a patch based on the requested changes
|
||||
3. Apply the patch to create an updated agent
|
||||
4. Preview or save based on the save parameter
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
changes = kwargs.get("changes", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the agent ID to edit.",
|
||||
error="Missing agent_id parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not changes:
|
||||
return ErrorResponse(
|
||||
message="Please describe what changes you want to make.",
|
||||
error="Missing changes parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 1: Fetch current agent
|
||||
current_agent = await get_agent_as_json(agent_id, user_id)
|
||||
|
||||
if current_agent is None:
|
||||
return ErrorResponse(
|
||||
message=f"Could not find agent with ID '{agent_id}' in your library.",
|
||||
error="agent_not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Build the update request with context
|
||||
update_request = changes
|
||||
if context:
|
||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
||||
|
||||
# Step 2: Generate patch with retry on validation failure
|
||||
blocks_info = get_blocks_info()
|
||||
updated_agent = None
|
||||
validation_errors = None
|
||||
intent = "Applied requested changes"
|
||||
|
||||
for attempt in range(MAX_GENERATION_RETRIES + 1):
|
||||
# Generate patch (include validation errors from previous attempt)
|
||||
try:
|
||||
if attempt == 0:
|
||||
patch_result = await generate_agent_patch(
|
||||
update_request, current_agent
|
||||
)
|
||||
else:
|
||||
# Retry with validation error feedback
|
||||
logger.info(
|
||||
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
|
||||
)
|
||||
retry_request = (
|
||||
f"{update_request}\n\n"
|
||||
f"IMPORTANT: The previous edit had validation errors. "
|
||||
f"Please fix these issues:\n{validation_errors}"
|
||||
)
|
||||
patch_result = await generate_agent_patch(
|
||||
retry_request, current_agent
|
||||
)
|
||||
except ValueError as e:
|
||||
# Handle missing API key or configuration errors
|
||||
return ErrorResponse(
|
||||
message=f"Agent generation is not configured: {str(e)}",
|
||||
error="configuration_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if patch_result is None:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate changes. Please try rephrasing.",
|
||||
error="Patch generation failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# Check if LLM returned clarifying questions
|
||||
if patch_result.get("type") == "clarifying_questions":
|
||||
questions = patch_result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information about the changes. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 3: Apply patch and fixes
|
||||
try:
|
||||
updated_agent = apply_agent_patch(current_agent, patch_result)
|
||||
updated_agent = apply_all_fixes(updated_agent, blocks_info)
|
||||
except Exception as e:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to apply changes: {str(e)}",
|
||||
error="patch_apply_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
validation_errors = str(e)
|
||||
continue
|
||||
|
||||
# Step 4: Validate the updated agent
|
||||
is_valid, validation_errors = validate_agent(updated_agent, blocks_info)
|
||||
|
||||
if is_valid:
|
||||
logger.info(f"Agent edited successfully on attempt {attempt + 1}")
|
||||
intent = patch_result.get("intent", "Applied requested changes")
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
|
||||
)
|
||||
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
# Return error with validation details
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Updated agent has validation errors after "
|
||||
f"{MAX_GENERATION_RETRIES + 1} attempts. "
|
||||
f"Please try rephrasing your request or simplify the changes."
|
||||
),
|
||||
error="validation_failed",
|
||||
details={"validation_errors": validation_errors},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# At this point, updated_agent is guaranteed to be set (we return on all failure paths)
|
||||
assert updated_agent is not None
|
||||
|
||||
agent_name = updated_agent.get("name", "Updated Agent")
|
||||
agent_description = updated_agent.get("description", "")
|
||||
node_count = len(updated_agent.get("nodes", []))
|
||||
link_count = len(updated_agent.get("links", []))
|
||||
|
||||
# Step 5: Preview or save
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've updated the agent. Changes: {intent}. "
|
||||
f"The agent now has {node_count} blocks. "
|
||||
f"Review it and call edit_agent with save=true to save the changes."
|
||||
),
|
||||
agent_json=updated_agent,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to library (creates a new version)
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
updated_agent, user_id, is_update=True
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=(
|
||||
f"Updated agent '{created_graph.name}' has been saved to your library! "
|
||||
f"Changes: {intent}"
|
||||
),
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the updated agent: {str(e)}",
|
||||
error="save_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -1,253 +0,0 @@
|
||||
"""Tool for searching available blocks using hybrid search."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
BlockInfoSummary,
|
||||
BlockListResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .search_blocks import get_block_search_index
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FindBlockTool(BaseTool):
|
||||
"""Tool for searching available blocks."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "find_block"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search for available blocks by name or description. "
|
||||
"Blocks are reusable components that perform specific tasks like "
|
||||
"sending emails, making API calls, processing text, etc. "
|
||||
"Use this to find blocks that can be executed directly."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query to find blocks by name or description. "
|
||||
"Use keywords like 'email', 'http', 'text', 'ai', etc."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
def _matches_query(self, block, query: str) -> tuple[int, bool]:
|
||||
"""
|
||||
Check if a block matches the query and return a priority score.
|
||||
|
||||
Returns (priority, matches) where:
|
||||
- priority 0: exact name match
|
||||
- priority 1: name contains query
|
||||
- priority 2: description contains query
|
||||
- priority 3: category contains query
|
||||
"""
|
||||
query_lower = query.lower()
|
||||
name_lower = block.name.lower()
|
||||
desc_lower = block.description.lower()
|
||||
|
||||
# Exact name match
|
||||
if query_lower == name_lower:
|
||||
return 0, True
|
||||
|
||||
# Name contains query
|
||||
if query_lower in name_lower:
|
||||
return 1, True
|
||||
|
||||
# Description contains query
|
||||
if query_lower in desc_lower:
|
||||
return 2, True
|
||||
|
||||
# Category contains query
|
||||
for category in block.categories:
|
||||
if query_lower in category.name.lower():
|
||||
return 3, True
|
||||
|
||||
return 4, False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search for blocks matching the query.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
BlockListResponse: List of matching blocks
|
||||
NoResultsResponse: No blocks found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Try hybrid search first
|
||||
search_results = self._hybrid_search(query)
|
||||
|
||||
if search_results is not None:
|
||||
# Hybrid search succeeded
|
||||
if not search_results:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found matching '{query}'",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Search by category: ai, text, social, search, etc.",
|
||||
"Check block names like 'SendEmail', 'HttpRequest', etc.",
|
||||
],
|
||||
)
|
||||
|
||||
# Get full block info for each result
|
||||
all_blocks = load_all_blocks()
|
||||
blocks = []
|
||||
for result in search_results:
|
||||
block_cls = all_blocks.get(result.block_id)
|
||||
if block_cls:
|
||||
block = block_cls()
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
description=block.description,
|
||||
categories=[cat.name for cat in block.categories],
|
||||
input_schema=block.input_schema.jsonschema(),
|
||||
output_schema=block.output_schema.jsonschema(),
|
||||
)
|
||||
)
|
||||
|
||||
return BlockListResponse(
|
||||
message=(
|
||||
f"Found {len(blocks)} block{'s' if len(blocks) != 1 else ''} "
|
||||
f"matching '{query}'. Use run_block to execute a block with "
|
||||
"the required inputs."
|
||||
),
|
||||
blocks=blocks,
|
||||
count=len(blocks),
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Fallback to simple search if hybrid search failed
|
||||
return self._simple_search(query, session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching blocks: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search blocks. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _hybrid_search(self, query: str) -> list | None:
|
||||
"""
|
||||
Perform hybrid search using embeddings and BM25.
|
||||
|
||||
Returns:
|
||||
List of BlockSearchResult if successful, None if index not available
|
||||
"""
|
||||
try:
|
||||
index = get_block_search_index()
|
||||
if not index.load():
|
||||
logger.info(
|
||||
"Block search index not available, falling back to simple search"
|
||||
)
|
||||
return None
|
||||
|
||||
results = index.search(query, top_k=10)
|
||||
logger.info(f"Hybrid search found {len(results)} blocks for: {query}")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Hybrid search failed, falling back to simple: {e}")
|
||||
return None
|
||||
|
||||
def _simple_search(self, query: str, session_id: str) -> ToolResponseBase:
|
||||
"""Fallback simple search using substring matching."""
|
||||
all_blocks = load_all_blocks()
|
||||
logger.info(f"Simple searching {len(all_blocks)} blocks for: {query}")
|
||||
|
||||
# Find matching blocks with priority scores
|
||||
matches: list[tuple[int, Any]] = []
|
||||
for block_id, block_cls in all_blocks.items():
|
||||
block = block_cls()
|
||||
priority, is_match = self._matches_query(block, query)
|
||||
if is_match:
|
||||
matches.append((priority, block))
|
||||
|
||||
# Sort by priority (lower is better)
|
||||
matches.sort(key=lambda x: x[0])
|
||||
|
||||
# Take top 10 results
|
||||
top_matches = [block for _, block in matches[:10]]
|
||||
|
||||
if not top_matches:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found matching '{query}'",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Search by category: ai, text, social, search, etc.",
|
||||
"Check block names like 'SendEmail', 'HttpRequest', etc.",
|
||||
],
|
||||
)
|
||||
|
||||
# Build response
|
||||
blocks = []
|
||||
for block in top_matches:
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
description=block.description,
|
||||
categories=[cat.name for cat in block.categories],
|
||||
input_schema=block.input_schema.jsonschema(),
|
||||
output_schema=block.output_schema.jsonschema(),
|
||||
)
|
||||
)
|
||||
|
||||
return BlockListResponse(
|
||||
message=(
|
||||
f"Found {len(blocks)} block{'s' if len(blocks) != 1 else ''} "
|
||||
f"matching '{query}'. Use run_block to execute a block with "
|
||||
"the required inputs."
|
||||
),
|
||||
blocks=blocks,
|
||||
count=len(blocks),
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -1,157 +0,0 @@
|
||||
"""Tool for searching agents in the user's library."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.util.exceptions import DatabaseError
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentCarouselResponse,
|
||||
AgentInfo,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FindLibraryAgentTool(BaseTool):
|
||||
"""Tool for searching agents in the user's library."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "find_library_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search for agents in the user's library. Use this to find agents "
|
||||
"the user has already added to their library, including agents they "
|
||||
"created or added from the marketplace."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query to find agents by name or description. "
|
||||
"Use keywords for best results."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search for agents in the user's library.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
AgentCarouselResponse: List of agents found in the library
|
||||
NoResultsResponse: No agents found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required to search library",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agents = []
|
||||
try:
|
||||
logger.info(f"Searching user library for: {query}")
|
||||
library_results = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Find library agents tool found {len(library_results.agents)} agents"
|
||||
)
|
||||
|
||||
for agent in library_results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
),
|
||||
)
|
||||
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching library agents: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search library. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agents:
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"No agents found matching '{query}' in your library. "
|
||||
"Try different keywords or use find_agent to search the marketplace."
|
||||
),
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Use find_agent to search the marketplace",
|
||||
"Check your library at /library",
|
||||
],
|
||||
)
|
||||
|
||||
title = (
|
||||
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
|
||||
f"in your library for '{query}'"
|
||||
)
|
||||
|
||||
return AgentCarouselResponse(
|
||||
message=(
|
||||
"Found agents in the user's library. You can provide a link to "
|
||||
"view an agent at: /library/agents/{agent_id}. "
|
||||
"Use agent_output to get execution results, or run_agent to execute."
|
||||
),
|
||||
title=title,
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -1,483 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Block Indexer for Hybrid Search
|
||||
|
||||
Creates a hybrid search index from blocks:
|
||||
- OpenAI embeddings (text-embedding-3-small)
|
||||
- BM25 index for lexical search
|
||||
- Name index for title matching boost
|
||||
|
||||
Supports incremental updates by tracking content hashes.
|
||||
|
||||
Usage:
|
||||
python -m backend.server.v2.chat.tools.index_blocks [--force]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Check for OpenAI availability
|
||||
try:
|
||||
import openai # noqa: F401
|
||||
|
||||
HAS_OPENAI = True
|
||||
except ImportError:
|
||||
HAS_OPENAI = False
|
||||
print("Warning: openai not installed. Run: pip install openai")
|
||||
|
||||
# Default embedding model (OpenAI)
|
||||
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_EMBEDDING_DIM = 1536
|
||||
|
||||
# Output path (relative to this file)
|
||||
INDEX_PATH = Path(__file__).parent / "blocks_index.json"
|
||||
|
||||
# Stopwords for tokenization
|
||||
STOPWORDS = {
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"is",
|
||||
"are",
|
||||
"was",
|
||||
"were",
|
||||
"be",
|
||||
"been",
|
||||
"being",
|
||||
"have",
|
||||
"has",
|
||||
"had",
|
||||
"do",
|
||||
"does",
|
||||
"did",
|
||||
"will",
|
||||
"would",
|
||||
"could",
|
||||
"should",
|
||||
"may",
|
||||
"might",
|
||||
"must",
|
||||
"shall",
|
||||
"can",
|
||||
"need",
|
||||
"dare",
|
||||
"ought",
|
||||
"used",
|
||||
"to",
|
||||
"of",
|
||||
"in",
|
||||
"for",
|
||||
"on",
|
||||
"with",
|
||||
"at",
|
||||
"by",
|
||||
"from",
|
||||
"as",
|
||||
"into",
|
||||
"through",
|
||||
"during",
|
||||
"before",
|
||||
"after",
|
||||
"above",
|
||||
"below",
|
||||
"between",
|
||||
"under",
|
||||
"again",
|
||||
"further",
|
||||
"then",
|
||||
"once",
|
||||
"and",
|
||||
"but",
|
||||
"or",
|
||||
"nor",
|
||||
"so",
|
||||
"yet",
|
||||
"both",
|
||||
"either",
|
||||
"neither",
|
||||
"not",
|
||||
"only",
|
||||
"own",
|
||||
"same",
|
||||
"than",
|
||||
"too",
|
||||
"very",
|
||||
"just",
|
||||
"also",
|
||||
"now",
|
||||
"here",
|
||||
"there",
|
||||
"when",
|
||||
"where",
|
||||
"why",
|
||||
"how",
|
||||
"all",
|
||||
"each",
|
||||
"every",
|
||||
"few",
|
||||
"more",
|
||||
"most",
|
||||
"other",
|
||||
"some",
|
||||
"such",
|
||||
"no",
|
||||
"any",
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
"it",
|
||||
"its",
|
||||
"block", # Too common in block context
|
||||
}
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Simple tokenizer for BM25."""
|
||||
text = text.lower()
|
||||
# Remove code blocks if any
|
||||
text = re.sub(r"```[\s\S]*?```", "", text)
|
||||
text = re.sub(r"`[^`]+`", "", text)
|
||||
# Extract words (including camelCase split)
|
||||
# First, split camelCase
|
||||
text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
|
||||
# Extract words
|
||||
words = re.findall(r"\b[a-z][a-z0-9_-]*\b", text)
|
||||
# Remove very short words and stopwords
|
||||
return [w for w in words if len(w) > 2 and w not in STOPWORDS]
|
||||
|
||||
|
||||
def build_searchable_text(block: Any) -> str:
|
||||
"""Build searchable text from block attributes."""
|
||||
parts = []
|
||||
|
||||
# Block name (split camelCase for better tokenization)
|
||||
name = block.name
|
||||
# Split camelCase: GetCurrentTimeBlock -> Get Current Time Block
|
||||
name_split = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
|
||||
parts.append(name_split)
|
||||
|
||||
# Description
|
||||
if block.description:
|
||||
parts.append(block.description)
|
||||
|
||||
# Categories
|
||||
for category in block.categories:
|
||||
parts.append(category.name)
|
||||
|
||||
# Input schema field names and descriptions
|
||||
try:
|
||||
input_schema = block.input_schema.jsonschema()
|
||||
if "properties" in input_schema:
|
||||
for field_name, field_info in input_schema["properties"].items():
|
||||
parts.append(field_name)
|
||||
if "description" in field_info:
|
||||
parts.append(field_info["description"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Output schema field names
|
||||
try:
|
||||
output_schema = block.output_schema.jsonschema()
|
||||
if "properties" in output_schema:
|
||||
for field_name in output_schema["properties"]:
|
||||
parts.append(field_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def compute_content_hash(text: str) -> str:
|
||||
"""Compute MD5 hash of text for change detection."""
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
|
||||
def load_existing_index(index_path: Path) -> dict[str, Any] | None:
|
||||
"""Load existing index if present."""
|
||||
if not index_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(index_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load existing index: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def create_embeddings(
|
||||
texts: list[str],
|
||||
model_name: str = DEFAULT_EMBEDDING_MODEL,
|
||||
batch_size: int = 100,
|
||||
) -> np.ndarray:
|
||||
"""Create embeddings using OpenAI API."""
|
||||
if not HAS_OPENAI:
|
||||
raise RuntimeError("openai not installed. Run: pip install openai")
|
||||
|
||||
# Import here to satisfy type checker
|
||||
from openai import OpenAI
|
||||
|
||||
# Check for API key
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
client = OpenAI(api_key=api_key)
|
||||
embeddings = []
|
||||
|
||||
print(f"Creating embeddings for {len(texts)} texts using {model_name}...")
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
# Truncate texts to max token limit (8191 tokens for text-embedding-3-small)
|
||||
# Roughly 4 chars per token, so ~32000 chars max
|
||||
batch = [text[:32000] for text in batch]
|
||||
|
||||
response = client.embeddings.create(
|
||||
model=model_name,
|
||||
input=batch,
|
||||
)
|
||||
|
||||
for embedding_data in response.data:
|
||||
embeddings.append(embedding_data.embedding)
|
||||
|
||||
print(f" Processed {min(i + batch_size, len(texts))}/{len(texts)} texts")
|
||||
|
||||
return np.array(embeddings, dtype=np.float32)
|
||||
|
||||
|
||||
def build_bm25_data(
|
||||
blocks_data: list[dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
"""Build BM25 metadata from block data."""
|
||||
# Tokenize all searchable texts
|
||||
tokenized_docs = []
|
||||
for block in blocks_data:
|
||||
tokens = tokenize(block["searchable_text"])
|
||||
tokenized_docs.append(tokens)
|
||||
|
||||
# Calculate document frequencies
|
||||
doc_freq: dict[str, int] = {}
|
||||
for tokens in tokenized_docs:
|
||||
seen = set()
|
||||
for token in tokens:
|
||||
if token not in seen:
|
||||
doc_freq[token] = doc_freq.get(token, 0) + 1
|
||||
seen.add(token)
|
||||
|
||||
n_docs = len(tokenized_docs)
|
||||
doc_lens = [len(d) for d in tokenized_docs]
|
||||
avgdl = sum(doc_lens) / max(n_docs, 1)
|
||||
|
||||
return {
|
||||
"n_docs": n_docs,
|
||||
"avgdl": avgdl,
|
||||
"df": doc_freq,
|
||||
"doc_lens": doc_lens,
|
||||
}
|
||||
|
||||
|
||||
def build_name_index(
|
||||
blocks_data: list[dict[str, Any]],
|
||||
) -> dict[str, list[list[int | float]]]:
|
||||
"""Build inverted index for name search boost."""
|
||||
index: dict[str, list[list[int | float]]] = defaultdict(list)
|
||||
|
||||
for idx, block in enumerate(blocks_data):
|
||||
# Tokenize block name
|
||||
name_tokens = tokenize(block["name"])
|
||||
seen = set()
|
||||
|
||||
for i, token in enumerate(name_tokens):
|
||||
if token in seen:
|
||||
continue
|
||||
seen.add(token)
|
||||
|
||||
# Score: first token gets higher weight
|
||||
score = 1.5 if i == 0 else 1.0
|
||||
index[token].append([idx, score])
|
||||
|
||||
return dict(index)
|
||||
|
||||
|
||||
def build_block_index(
|
||||
force_rebuild: bool = False,
|
||||
output_path: Path = INDEX_PATH,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Build the block search index.
|
||||
|
||||
Args:
|
||||
force_rebuild: If True, rebuild all embeddings even if unchanged
|
||||
output_path: Path to save the index
|
||||
|
||||
Returns:
|
||||
The generated index dictionary
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
print("Loading all blocks...")
|
||||
all_blocks = load_all_blocks()
|
||||
print(f"Found {len(all_blocks)} blocks")
|
||||
|
||||
# Load existing index for incremental updates
|
||||
existing_index = None if force_rebuild else load_existing_index(output_path)
|
||||
existing_blocks: dict[str, dict[str, Any]] = {}
|
||||
|
||||
if existing_index:
|
||||
print(
|
||||
f"Loaded existing index with {len(existing_index.get('blocks', []))} blocks"
|
||||
)
|
||||
for block in existing_index.get("blocks", []):
|
||||
existing_blocks[block["id"]] = block
|
||||
|
||||
# Process each block
|
||||
blocks_data: list[dict[str, Any]] = []
|
||||
blocks_needing_embedding: list[tuple[int, str]] = [] # (index, searchable_text)
|
||||
|
||||
for block_id, block_cls in all_blocks.items():
|
||||
try:
|
||||
block = block_cls()
|
||||
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
searchable_text = build_searchable_text(block)
|
||||
content_hash = compute_content_hash(searchable_text)
|
||||
|
||||
block_data = {
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"description": block.description,
|
||||
"categories": [cat.name for cat in block.categories],
|
||||
"searchable_text": searchable_text,
|
||||
"content_hash": content_hash,
|
||||
"emb": None, # Will be filled later
|
||||
}
|
||||
|
||||
# Check if we can reuse existing embedding
|
||||
if (
|
||||
block.id in existing_blocks
|
||||
and existing_blocks[block.id].get("content_hash") == content_hash
|
||||
and existing_blocks[block.id].get("emb")
|
||||
):
|
||||
# Reuse existing embedding
|
||||
block_data["emb"] = existing_blocks[block.id]["emb"]
|
||||
else:
|
||||
# Need new embedding
|
||||
blocks_needing_embedding.append((len(blocks_data), searchable_text))
|
||||
|
||||
blocks_data.append(block_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process block {block_id}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Processed {len(blocks_data)} blocks")
|
||||
print(f"Blocks needing new embeddings: {len(blocks_needing_embedding)}")
|
||||
|
||||
# Create embeddings for new/changed blocks
|
||||
if blocks_needing_embedding and HAS_OPENAI:
|
||||
texts_to_embed = [text for _, text in blocks_needing_embedding]
|
||||
try:
|
||||
embeddings = create_embeddings(texts_to_embed)
|
||||
|
||||
# Assign embeddings to blocks
|
||||
for i, (block_idx, _) in enumerate(blocks_needing_embedding):
|
||||
emb = embeddings[i].astype(np.float32)
|
||||
# Encode as base64
|
||||
blocks_data[block_idx]["emb"] = base64.b64encode(emb.tobytes()).decode(
|
||||
"ascii"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to create embeddings: {e}")
|
||||
elif blocks_needing_embedding:
|
||||
print(
|
||||
"Warning: Cannot create embeddings (openai not installed or OPENAI_API_KEY not set)"
|
||||
)
|
||||
|
||||
# Build BM25 data
|
||||
print("Building BM25 index...")
|
||||
bm25_data = build_bm25_data(blocks_data)
|
||||
|
||||
# Build name index
|
||||
print("Building name index...")
|
||||
name_index = build_name_index(blocks_data)
|
||||
|
||||
# Build final index
|
||||
index = {
|
||||
"version": "1.0.0",
|
||||
"embedding_model": DEFAULT_EMBEDDING_MODEL,
|
||||
"embedding_dim": DEFAULT_EMBEDDING_DIM,
|
||||
"generated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"blocks": blocks_data,
|
||||
"bm25": bm25_data,
|
||||
"name_index": name_index,
|
||||
}
|
||||
|
||||
# Save index
|
||||
print(f"Saving index to {output_path}...")
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, separators=(",", ":"))
|
||||
|
||||
size_kb = output_path.stat().st_size / 1024
|
||||
print(f"Index saved ({size_kb:.1f} KB)")
|
||||
|
||||
# Print statistics
|
||||
print("\nIndex Statistics:")
|
||||
print(f" Blocks indexed: {len(blocks_data)}")
|
||||
print(f" BM25 vocabulary size: {len(bm25_data['df'])}")
|
||||
print(f" Name index terms: {len(name_index)}")
|
||||
print(f" Embeddings: {'Yes' if any(b.get('emb') for b in blocks_data) else 'No'}")
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Build hybrid search index for blocks")
|
||||
parser.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Force rebuild all embeddings even if unchanged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
default=INDEX_PATH,
|
||||
help=f"Output index file path (default: {INDEX_PATH})",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
build_block_index(
|
||||
force_rebuild=args.force,
|
||||
output_path=args.output,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error building index: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,287 +0,0 @@
|
||||
"""Tool for executing blocks directly."""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.block import get_block
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunBlockTool(BaseTool):
|
||||
"""Tool for executing a block and returning its outputs."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "run_block"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Execute a specific block with the provided input data. "
|
||||
"Use find_block to discover available blocks and their input schemas. "
|
||||
"The block will run and return its outputs once complete."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"block_id": {
|
||||
"type": "string",
|
||||
"description": "The UUID of the block to execute",
|
||||
},
|
||||
"input_data": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"Input values for the block. Must match the block's input schema. "
|
||||
"Check the block's input_schema from find_block for required fields."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["block_id", "input_data"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _check_block_credentials(
|
||||
self,
|
||||
user_id: str,
|
||||
block: Any,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Check if user has required credentials for a block.
|
||||
|
||||
Returns:
|
||||
tuple[matched_credentials, missing_credentials]
|
||||
"""
|
||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
||||
missing_credentials: list[CredentialsMetaInput] = []
|
||||
|
||||
# Get credential field info from block's input schema
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
|
||||
if not credentials_fields_info:
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
# Get user's available credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
# field_info.provider is a frozenset of acceptable providers
|
||||
# field_info.supported_types is a frozenset of acceptable types
|
||||
matching_cred = next(
|
||||
(
|
||||
cred
|
||||
for cred in available_creds
|
||||
if cred.provider in field_info.provider
|
||||
and cred.type in field_info.supported_types
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if matching_cred:
|
||||
matched_credentials[field_name] = CredentialsMetaInput(
|
||||
id=matching_cred.id,
|
||||
provider=matching_cred.provider, # type: ignore
|
||||
type=matching_cred.type,
|
||||
title=matching_cred.title,
|
||||
)
|
||||
else:
|
||||
# Create a placeholder for the missing credential
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||
missing_credentials.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=field_name.replace("_", " ").title(),
|
||||
)
|
||||
)
|
||||
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute a block with the given input data.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
block_id: Block UUID to execute
|
||||
input_data: Input values for the block
|
||||
|
||||
Returns:
|
||||
BlockOutputResponse: Block execution outputs
|
||||
SetupRequirementsResponse: Missing credentials
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
block_id = kwargs.get("block_id", "").strip()
|
||||
input_data = kwargs.get("input_data", {})
|
||||
session_id = session.session_id
|
||||
|
||||
if not block_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a block_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not isinstance(input_data, dict):
|
||||
return ErrorResponse(
|
||||
message="input_data must be an object",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Get the block
|
||||
block = get_block(block_id)
|
||||
if not block:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||
|
||||
# Check credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
matched_credentials, missing_credentials = await self._check_block_credentials(
|
||||
user_id, block
|
||||
)
|
||||
|
||||
if missing_credentials:
|
||||
# Return setup requirements response with missing credentials
|
||||
missing_creds_dict = {c.id: c.model_dump() for c in missing_credentials}
|
||||
|
||||
return SetupRequirementsResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' requires credentials that are not configured. "
|
||||
"Please set up the required credentials before running this block."
|
||||
),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=block_id,
|
||||
agent_name=block.name,
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_creds_dict,
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": [c.model_dump() for c in missing_credentials],
|
||||
"inputs": self._get_inputs_list(block),
|
||||
"execution_modes": ["immediate"],
|
||||
},
|
||||
),
|
||||
graph_id=None,
|
||||
graph_version=None,
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch actual credentials and prepare kwargs for block execution
|
||||
exec_kwargs: dict[str, Any] = {"user_id": user_id}
|
||||
|
||||
for field_name, cred_meta in matched_credentials.items():
|
||||
# Inject metadata into input_data (for validation)
|
||||
if field_name not in input_data:
|
||||
input_data[field_name] = cred_meta.model_dump()
|
||||
|
||||
# Fetch actual credentials and pass as kwargs (for execution)
|
||||
actual_credentials = await creds_manager.get(
|
||||
user_id, cred_meta.id, lock=False
|
||||
)
|
||||
if actual_credentials:
|
||||
exec_kwargs[field_name] = actual_credentials
|
||||
else:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to retrieve credentials for {field_name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
input_data,
|
||||
**exec_kwargs,
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
outputs=dict(outputs),
|
||||
success=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
||||
"""Extract non-credential inputs from block schema."""
|
||||
inputs_list = []
|
||||
schema = block.input_schema.jsonschema()
|
||||
properties = schema.get("properties", {})
|
||||
required_fields = set(schema.get("required", []))
|
||||
|
||||
# Get credential field names to exclude
|
||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
|
||||
inputs_list.append(
|
||||
{
|
||||
"name": field_name,
|
||||
"title": field_schema.get("title", field_name),
|
||||
"type": field_schema.get("type", "string"),
|
||||
"description": field_schema.get("description", ""),
|
||||
"required": field_name in required_fields,
|
||||
}
|
||||
)
|
||||
|
||||
return inputs_list
|
||||
@@ -1,460 +0,0 @@
|
||||
"""
|
||||
Block Hybrid Search
|
||||
|
||||
Combines multiple ranking signals for block search:
|
||||
- Semantic search (OpenAI embeddings + cosine similarity)
|
||||
- Lexical search (BM25)
|
||||
- Name matching (boost for block name matches)
|
||||
- Category matching (boost for category matches)
|
||||
|
||||
Based on the docs search implementation.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OpenAI embedding model
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
|
||||
# Path to the JSON index file
|
||||
INDEX_PATH = Path(__file__).parent / "blocks_index.json"
|
||||
|
||||
# Stopwords for tokenization (same as index_blocks.py)
|
||||
STOPWORDS = {
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"is",
|
||||
"are",
|
||||
"was",
|
||||
"were",
|
||||
"be",
|
||||
"been",
|
||||
"being",
|
||||
"have",
|
||||
"has",
|
||||
"had",
|
||||
"do",
|
||||
"does",
|
||||
"did",
|
||||
"will",
|
||||
"would",
|
||||
"could",
|
||||
"should",
|
||||
"may",
|
||||
"might",
|
||||
"must",
|
||||
"shall",
|
||||
"can",
|
||||
"need",
|
||||
"dare",
|
||||
"ought",
|
||||
"used",
|
||||
"to",
|
||||
"of",
|
||||
"in",
|
||||
"for",
|
||||
"on",
|
||||
"with",
|
||||
"at",
|
||||
"by",
|
||||
"from",
|
||||
"as",
|
||||
"into",
|
||||
"through",
|
||||
"during",
|
||||
"before",
|
||||
"after",
|
||||
"above",
|
||||
"below",
|
||||
"between",
|
||||
"under",
|
||||
"again",
|
||||
"further",
|
||||
"then",
|
||||
"once",
|
||||
"and",
|
||||
"but",
|
||||
"or",
|
||||
"nor",
|
||||
"so",
|
||||
"yet",
|
||||
"both",
|
||||
"either",
|
||||
"neither",
|
||||
"not",
|
||||
"only",
|
||||
"own",
|
||||
"same",
|
||||
"than",
|
||||
"too",
|
||||
"very",
|
||||
"just",
|
||||
"also",
|
||||
"now",
|
||||
"here",
|
||||
"there",
|
||||
"when",
|
||||
"where",
|
||||
"why",
|
||||
"how",
|
||||
"all",
|
||||
"each",
|
||||
"every",
|
||||
"few",
|
||||
"more",
|
||||
"most",
|
||||
"other",
|
||||
"some",
|
||||
"such",
|
||||
"no",
|
||||
"any",
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
"it",
|
||||
"its",
|
||||
"block",
|
||||
}
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Simple tokenizer for search."""
|
||||
text = text.lower()
|
||||
# Remove code blocks if any
|
||||
text = re.sub(r"```[\s\S]*?```", "", text)
|
||||
text = re.sub(r"`[^`]+`", "", text)
|
||||
# Split camelCase
|
||||
text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
|
||||
# Extract words
|
||||
words = re.findall(r"\b[a-z][a-z0-9_-]*\b", text)
|
||||
# Remove very short words and stopwords
|
||||
return [w for w in words if len(w) > 2 and w not in STOPWORDS]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchWeights:
|
||||
"""Configuration for hybrid search signal weights."""
|
||||
|
||||
semantic: float = 0.40 # Embedding similarity
|
||||
bm25: float = 0.25 # Lexical matching
|
||||
name_match: float = 0.25 # Block name matches
|
||||
category_match: float = 0.10 # Category matches
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockSearchResult:
|
||||
"""A single block search result."""
|
||||
|
||||
block_id: str
|
||||
name: str
|
||||
description: str
|
||||
categories: list[str]
|
||||
score: float
|
||||
|
||||
# Individual signal scores (for debugging)
|
||||
semantic_score: float = 0.0
|
||||
bm25_score: float = 0.0
|
||||
name_score: float = 0.0
|
||||
category_score: float = 0.0
|
||||
|
||||
|
||||
class BlockSearchIndex:
|
||||
"""Hybrid search index for blocks combining BM25 + embeddings."""
|
||||
|
||||
def __init__(self, index_path: Path = INDEX_PATH):
|
||||
self.blocks: list[dict[str, Any]] = []
|
||||
self.bm25_data: dict[str, Any] = {}
|
||||
self.name_index: dict[str, list[list[int | float]]] = {}
|
||||
self.embeddings: Optional[np.ndarray] = None
|
||||
self.normalized_embeddings: Optional[np.ndarray] = None
|
||||
self._loaded = False
|
||||
self._index_path = index_path
|
||||
self._embedding_model: Any = None
|
||||
|
||||
def load(self) -> bool:
|
||||
"""Load the index from JSON file."""
|
||||
if self._loaded:
|
||||
return True
|
||||
|
||||
if not self._index_path.exists():
|
||||
logger.warning(f"Block index not found at {self._index_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(self._index_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
self.blocks = data.get("blocks", [])
|
||||
self.bm25_data = data.get("bm25", {})
|
||||
self.name_index = data.get("name_index", {})
|
||||
|
||||
# Decode embeddings from base64
|
||||
embeddings_list = []
|
||||
for block in self.blocks:
|
||||
if block.get("emb"):
|
||||
emb_bytes = base64.b64decode(block["emb"])
|
||||
emb = np.frombuffer(emb_bytes, dtype=np.float32)
|
||||
embeddings_list.append(emb)
|
||||
else:
|
||||
# No embedding, use zeros
|
||||
dim = data.get("embedding_dim", 384)
|
||||
embeddings_list.append(np.zeros(dim, dtype=np.float32))
|
||||
|
||||
if embeddings_list:
|
||||
self.embeddings = np.stack(embeddings_list)
|
||||
# Precompute normalized embeddings for cosine similarity
|
||||
norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True)
|
||||
self.normalized_embeddings = self.embeddings / (norms + 1e-10)
|
||||
|
||||
self._loaded = True
|
||||
logger.info(f"Loaded block index with {len(self.blocks)} blocks")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load block index: {e}")
|
||||
return False
|
||||
|
||||
def _get_openai_client(self) -> Any:
|
||||
"""Get OpenAI client for query embedding."""
|
||||
if self._embedding_model is None:
|
||||
try:
|
||||
from openai import OpenAI
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
logger.warning("OPENAI_API_KEY not set")
|
||||
return None
|
||||
self._embedding_model = OpenAI(api_key=api_key)
|
||||
except ImportError:
|
||||
logger.warning("openai not installed")
|
||||
return None
|
||||
return self._embedding_model
|
||||
|
||||
def _embed_query(self, query: str) -> Optional[np.ndarray]:
|
||||
"""Embed the search query using OpenAI."""
|
||||
client = self._get_openai_client()
|
||||
if client is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
response = client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=query,
|
||||
)
|
||||
embedding = response.data[0].embedding
|
||||
return np.array(embedding, dtype=np.float32)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to embed query: {e}")
|
||||
return None
|
||||
|
||||
def _compute_semantic_scores(self, query_embedding: np.ndarray) -> np.ndarray:
|
||||
"""Compute cosine similarity between query and all blocks."""
|
||||
if self.normalized_embeddings is None:
|
||||
return np.zeros(len(self.blocks))
|
||||
|
||||
# Normalize query embedding
|
||||
query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)
|
||||
|
||||
# Cosine similarity via dot product
|
||||
similarities = self.normalized_embeddings @ query_norm
|
||||
|
||||
# Scale to [0, 1] (cosine ranges from -1 to 1)
|
||||
return (similarities + 1) / 2
|
||||
|
||||
def _compute_bm25_scores(self, query_tokens: list[str]) -> np.ndarray:
|
||||
"""Compute BM25 scores for all blocks."""
|
||||
scores = np.zeros(len(self.blocks))
|
||||
|
||||
if not self.bm25_data or not query_tokens:
|
||||
return scores
|
||||
|
||||
# BM25 parameters
|
||||
k1 = 1.5
|
||||
b = 0.75
|
||||
n_docs = self.bm25_data.get("n_docs", len(self.blocks))
|
||||
avgdl = self.bm25_data.get("avgdl", 100)
|
||||
df = self.bm25_data.get("df", {})
|
||||
doc_lens = self.bm25_data.get("doc_lens", [100] * len(self.blocks))
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
# Tokenize block's searchable text
|
||||
block_tokens = tokenize(block.get("searchable_text", ""))
|
||||
doc_len = doc_lens[i] if i < len(doc_lens) else len(block_tokens)
|
||||
|
||||
# Calculate BM25 score
|
||||
score = 0.0
|
||||
for token in query_tokens:
|
||||
if token not in df:
|
||||
continue
|
||||
|
||||
# Term frequency in this document
|
||||
tf = block_tokens.count(token)
|
||||
if tf == 0:
|
||||
continue
|
||||
|
||||
# IDF
|
||||
doc_freq = df.get(token, 0)
|
||||
idf = math.log((n_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
|
||||
|
||||
# BM25 score component
|
||||
numerator = tf * (k1 + 1)
|
||||
denominator = tf + k1 * (1 - b + b * doc_len / avgdl)
|
||||
score += idf * numerator / denominator
|
||||
|
||||
scores[i] = score
|
||||
|
||||
# Normalize to [0, 1]
|
||||
max_score = scores.max()
|
||||
if max_score > 0:
|
||||
scores = scores / max_score
|
||||
|
||||
return scores
|
||||
|
||||
def _compute_name_scores(self, query_tokens: list[str]) -> np.ndarray:
|
||||
"""Compute name match scores using the name index."""
|
||||
scores = np.zeros(len(self.blocks))
|
||||
|
||||
if not self.name_index or not query_tokens:
|
||||
return scores
|
||||
|
||||
for token in query_tokens:
|
||||
if token in self.name_index:
|
||||
for block_idx, weight in self.name_index[token]:
|
||||
if block_idx < len(scores):
|
||||
scores[int(block_idx)] += weight
|
||||
|
||||
# Also check for partial matches in block names
|
||||
for i, block in enumerate(self.blocks):
|
||||
name_lower = block.get("name", "").lower()
|
||||
for token in query_tokens:
|
||||
if token in name_lower:
|
||||
scores[i] += 0.5
|
||||
|
||||
# Normalize to [0, 1]
|
||||
max_score = scores.max()
|
||||
if max_score > 0:
|
||||
scores = scores / max_score
|
||||
|
||||
return scores
|
||||
|
||||
def _compute_category_scores(self, query_tokens: list[str]) -> np.ndarray:
|
||||
"""Compute category match scores."""
|
||||
scores = np.zeros(len(self.blocks))
|
||||
|
||||
if not query_tokens:
|
||||
return scores
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
categories = block.get("categories", [])
|
||||
category_text = " ".join(categories).lower()
|
||||
|
||||
for token in query_tokens:
|
||||
if token in category_text:
|
||||
scores[i] += 1.0
|
||||
|
||||
# Normalize to [0, 1]
|
||||
max_score = scores.max()
|
||||
if max_score > 0:
|
||||
scores = scores / max_score
|
||||
|
||||
return scores
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 10,
|
||||
weights: Optional[SearchWeights] = None,
|
||||
) -> list[BlockSearchResult]:
|
||||
"""
|
||||
Perform hybrid search combining multiple signals.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
top_k: Number of results to return
|
||||
weights: Optional custom weights for signals
|
||||
|
||||
Returns:
|
||||
List of BlockSearchResult sorted by score
|
||||
"""
|
||||
if not self._loaded and not self.load():
|
||||
return []
|
||||
|
||||
if weights is None:
|
||||
weights = SearchWeights()
|
||||
|
||||
# Tokenize query
|
||||
query_tokens = tokenize(query)
|
||||
if not query_tokens:
|
||||
# Fallback: try raw query words
|
||||
query_tokens = query.lower().split()
|
||||
|
||||
# Compute semantic scores
|
||||
semantic_scores = np.zeros(len(self.blocks))
|
||||
if self.normalized_embeddings is not None:
|
||||
query_embedding = self._embed_query(query)
|
||||
if query_embedding is not None:
|
||||
semantic_scores = self._compute_semantic_scores(query_embedding)
|
||||
|
||||
# Compute other scores
|
||||
bm25_scores = self._compute_bm25_scores(query_tokens)
|
||||
name_scores = self._compute_name_scores(query_tokens)
|
||||
category_scores = self._compute_category_scores(query_tokens)
|
||||
|
||||
# Combine scores using weights
|
||||
combined_scores = (
|
||||
weights.semantic * semantic_scores
|
||||
+ weights.bm25 * bm25_scores
|
||||
+ weights.name_match * name_scores
|
||||
+ weights.category_match * category_scores
|
||||
)
|
||||
|
||||
# Get top-k indices
|
||||
top_indices = np.argsort(combined_scores)[::-1][:top_k]
|
||||
|
||||
# Build results
|
||||
results = []
|
||||
for idx in top_indices:
|
||||
if combined_scores[idx] <= 0:
|
||||
continue
|
||||
|
||||
block = self.blocks[idx]
|
||||
results.append(
|
||||
BlockSearchResult(
|
||||
block_id=block["id"],
|
||||
name=block["name"],
|
||||
description=block["description"],
|
||||
categories=block.get("categories", []),
|
||||
score=float(combined_scores[idx]),
|
||||
semantic_score=float(semantic_scores[idx]),
|
||||
bm25_score=float(bm25_scores[idx]),
|
||||
name_score=float(name_scores[idx]),
|
||||
category_score=float(category_scores[idx]),
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# Global index instance (lazy loaded)
|
||||
_block_search_index: Optional[BlockSearchIndex] = None
|
||||
|
||||
|
||||
def get_block_search_index() -> BlockSearchIndex:
|
||||
"""Get or create the block search index singleton."""
|
||||
global _block_search_index
|
||||
if _block_search_index is None:
|
||||
_block_search_index = BlockSearchIndex(INDEX_PATH)
|
||||
return _block_search_index
|
||||
@@ -1,386 +0,0 @@
|
||||
"""Tool for searching platform documentation."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
DocSearchResult,
|
||||
DocSearchResultsResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Documentation base URL
|
||||
DOCS_BASE_URL = "https://docs.agpt.co/platform"
|
||||
|
||||
# Path to the JSON index file (relative to this file)
|
||||
INDEX_PATH = Path(__file__).parent / "docs_index.json"
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Simple tokenizer for BM25."""
|
||||
text = text.lower()
|
||||
# Remove code blocks
|
||||
text = re.sub(r"```[\s\S]*?```", "", text)
|
||||
text = re.sub(r"`[^`]+`", "", text)
|
||||
# Extract words
|
||||
words = re.findall(r"\b[a-z][a-z0-9_-]*\b", text)
|
||||
# Remove very short words and stopwords
|
||||
stopwords = {
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"is",
|
||||
"are",
|
||||
"was",
|
||||
"were",
|
||||
"be",
|
||||
"been",
|
||||
"being",
|
||||
"have",
|
||||
"has",
|
||||
"had",
|
||||
"do",
|
||||
"does",
|
||||
"did",
|
||||
"will",
|
||||
"would",
|
||||
"could",
|
||||
"should",
|
||||
"may",
|
||||
"might",
|
||||
"must",
|
||||
"shall",
|
||||
"can",
|
||||
"need",
|
||||
"dare",
|
||||
"ought",
|
||||
"used",
|
||||
"to",
|
||||
"of",
|
||||
"in",
|
||||
"for",
|
||||
"on",
|
||||
"with",
|
||||
"at",
|
||||
"by",
|
||||
"from",
|
||||
"as",
|
||||
"into",
|
||||
"through",
|
||||
"during",
|
||||
"before",
|
||||
"after",
|
||||
"above",
|
||||
"below",
|
||||
"between",
|
||||
"under",
|
||||
"again",
|
||||
"further",
|
||||
"then",
|
||||
"once",
|
||||
"and",
|
||||
"but",
|
||||
"or",
|
||||
"nor",
|
||||
"so",
|
||||
"yet",
|
||||
"both",
|
||||
"either",
|
||||
"neither",
|
||||
"not",
|
||||
"only",
|
||||
"own",
|
||||
"same",
|
||||
"than",
|
||||
"too",
|
||||
"very",
|
||||
"just",
|
||||
"also",
|
||||
"now",
|
||||
"here",
|
||||
"there",
|
||||
"when",
|
||||
"where",
|
||||
"why",
|
||||
"how",
|
||||
"all",
|
||||
"each",
|
||||
"every",
|
||||
"both",
|
||||
"few",
|
||||
"more",
|
||||
"most",
|
||||
"other",
|
||||
"some",
|
||||
"such",
|
||||
"no",
|
||||
"any",
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
"it",
|
||||
"its",
|
||||
}
|
||||
return [w for w in words if len(w) > 2 and w not in stopwords]
|
||||
|
||||
|
||||
class DocSearchIndex:
|
||||
"""Lightweight documentation search index using BM25."""
|
||||
|
||||
def __init__(self, index_path: Path):
|
||||
self.chunks: list[dict] = []
|
||||
self.bm25_data: dict = {}
|
||||
self._loaded = False
|
||||
self._index_path = index_path
|
||||
|
||||
def load(self) -> bool:
|
||||
"""Load the index from JSON file."""
|
||||
if self._loaded:
|
||||
return True
|
||||
|
||||
if not self._index_path.exists():
|
||||
logger.warning(f"Documentation index not found at {self._index_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(self._index_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
self.chunks = data.get("chunks", [])
|
||||
self.bm25_data = data.get("bm25", {})
|
||||
self._loaded = True
|
||||
logger.info(f"Loaded documentation index with {len(self.chunks)} chunks")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load documentation index: {e}")
|
||||
return False
|
||||
|
||||
def search(self, query: str, top_k: int = 5) -> list[dict]:
|
||||
"""Search the index using BM25."""
|
||||
if not self._loaded and not self.load():
|
||||
return []
|
||||
|
||||
query_tokens = tokenize(query)
|
||||
if not query_tokens:
|
||||
return []
|
||||
|
||||
# BM25 parameters
|
||||
k1 = 1.5
|
||||
b = 0.75
|
||||
n_docs = self.bm25_data.get("n_docs", len(self.chunks))
|
||||
avgdl = self.bm25_data.get("avgdl", 100)
|
||||
df = self.bm25_data.get("df", {})
|
||||
doc_lens = self.bm25_data.get("doc_lens", [100] * len(self.chunks))
|
||||
|
||||
scores = []
|
||||
for i, chunk in enumerate(self.chunks):
|
||||
# Tokenize chunk text
|
||||
chunk_tokens = tokenize(chunk.get("text", ""))
|
||||
doc_len = doc_lens[i] if i < len(doc_lens) else len(chunk_tokens)
|
||||
|
||||
# Calculate BM25 score
|
||||
score = 0.0
|
||||
for token in query_tokens:
|
||||
if token not in df:
|
||||
continue
|
||||
|
||||
# Term frequency in this document
|
||||
tf = chunk_tokens.count(token)
|
||||
if tf == 0:
|
||||
continue
|
||||
|
||||
# IDF
|
||||
doc_freq = df.get(token, 0)
|
||||
idf = math.log((n_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
|
||||
|
||||
# BM25 score component
|
||||
numerator = tf * (k1 + 1)
|
||||
denominator = tf + k1 * (1 - b + b * doc_len / avgdl)
|
||||
score += idf * numerator / denominator
|
||||
|
||||
# Boost for title/heading matches
|
||||
title = chunk.get("title", "").lower()
|
||||
heading = chunk.get("heading", "").lower()
|
||||
for token in query_tokens:
|
||||
if token in title:
|
||||
score *= 1.5
|
||||
if token in heading:
|
||||
score *= 1.2
|
||||
|
||||
scores.append((i, score))
|
||||
|
||||
# Sort by score and return top_k
|
||||
scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
results = []
|
||||
seen_sections = set()
|
||||
for idx, score in scores:
|
||||
if score <= 0:
|
||||
continue
|
||||
|
||||
chunk = self.chunks[idx]
|
||||
section_key = (chunk.get("doc", ""), chunk.get("heading", ""))
|
||||
|
||||
# Deduplicate by section
|
||||
if section_key in seen_sections:
|
||||
continue
|
||||
seen_sections.add(section_key)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"title": chunk.get("title", ""),
|
||||
"path": chunk.get("doc", ""),
|
||||
"heading": chunk.get("heading", ""),
|
||||
"text": chunk.get("text", ""), # Full text for LLM comprehension
|
||||
"score": score,
|
||||
}
|
||||
)
|
||||
|
||||
if len(results) >= top_k:
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# Global index instance (lazy loaded)
|
||||
_search_index: DocSearchIndex | None = None
|
||||
|
||||
|
||||
def get_search_index() -> DocSearchIndex:
|
||||
"""Get or create the search index singleton."""
|
||||
global _search_index
|
||||
if _search_index is None:
|
||||
_search_index = DocSearchIndex(INDEX_PATH)
|
||||
return _search_index
|
||||
|
||||
|
||||
class SearchDocsTool(BaseTool):
|
||||
"""Tool for searching AutoGPT platform documentation."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "search_platform_docs"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search the AutoGPT platform documentation and support Q&A for information about "
|
||||
"how to use the platform, create agents, configure blocks, "
|
||||
"set up integrations, troubleshoot issues, and more. Use this when users ask "
|
||||
"support questions or want to learn how to do something with AutoGPT."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query describing what the user wants to learn about. "
|
||||
"Use keywords like 'blocks', 'agents', 'credentials', 'API', etc."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search documentation for the query.
|
||||
|
||||
Args:
|
||||
user_id: User ID (may be anonymous)
|
||||
session: Chat session
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
DocSearchResultsResponse: List of matching documentation sections
|
||||
NoResultsResponse: No results found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
index = get_search_index()
|
||||
results = index.search(query, top_k=5)
|
||||
|
||||
if not results:
|
||||
return NoResultsResponse(
|
||||
message=f"No documentation found for '{query}'. Try different keywords.",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms like 'blocks', 'agents', 'setup'",
|
||||
"Check the documentation at docs.agpt.co",
|
||||
],
|
||||
)
|
||||
|
||||
# Convert to response format
|
||||
doc_results = []
|
||||
for r in results:
|
||||
# Build documentation URL
|
||||
path = r["path"]
|
||||
if path.endswith(".md"):
|
||||
path = path[:-3] # Remove .md extension
|
||||
doc_url = f"{DOCS_BASE_URL}/{path}"
|
||||
|
||||
full_text = r["text"]
|
||||
doc_results.append(
|
||||
DocSearchResult(
|
||||
title=r["title"],
|
||||
path=r["path"],
|
||||
section=r["heading"],
|
||||
snippet=(
|
||||
full_text[:300] + "..."
|
||||
if len(full_text) > 300
|
||||
else full_text
|
||||
),
|
||||
content=full_text, # Full text for LLM to read and understand
|
||||
score=round(r["score"], 3),
|
||||
doc_url=doc_url,
|
||||
)
|
||||
)
|
||||
|
||||
return DocSearchResultsResponse(
|
||||
message=(
|
||||
f"Found {len(doc_results)} relevant documentation sections. "
|
||||
"Use these to help answer the user's question. "
|
||||
"Include links to the documentation when helpful."
|
||||
),
|
||||
results=doc_results,
|
||||
count=len(doc_results),
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching documentation: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search documentation. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -1,833 +0,0 @@
|
||||
"""
|
||||
OAuth 2.0 Provider Endpoints
|
||||
|
||||
Implements OAuth 2.0 Authorization Code flow with PKCE support.
|
||||
|
||||
Flow:
|
||||
1. User clicks "Login with AutoGPT" in 3rd party app
|
||||
2. App redirects user to /auth/authorize with client_id, redirect_uri, scope, state
|
||||
3. User sees consent screen (if not already logged in, redirects to login first)
|
||||
4. User approves → backend creates authorization code
|
||||
5. User redirected back to app with code
|
||||
6. App exchanges code for access/refresh tokens at /api/oauth/token
|
||||
7. App uses access token to call external API endpoints
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import APIRouter, Body, HTTPException, Security, UploadFile, status
|
||||
from gcloud.aio import storage as async_storage
|
||||
from PIL import Image
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.auth.oauth import (
|
||||
InvalidClientError,
|
||||
InvalidGrantError,
|
||||
OAuthApplicationInfo,
|
||||
TokenIntrospectionResult,
|
||||
consume_authorization_code,
|
||||
create_access_token,
|
||||
create_authorization_code,
|
||||
create_refresh_token,
|
||||
get_oauth_application,
|
||||
get_oauth_application_by_id,
|
||||
introspect_token,
|
||||
list_user_oauth_applications,
|
||||
refresh_tokens,
|
||||
revoke_access_token,
|
||||
revoke_refresh_token,
|
||||
update_oauth_application,
|
||||
validate_client_credentials,
|
||||
validate_redirect_uri,
|
||||
validate_scopes,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""OAuth 2.0 token response"""
|
||||
|
||||
token_type: Literal["Bearer"] = "Bearer"
|
||||
access_token: str
|
||||
access_token_expires_at: datetime
|
||||
refresh_token: str
|
||||
refresh_token_expires_at: datetime
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""OAuth 2.0 error response"""
|
||||
|
||||
error: str
|
||||
error_description: Optional[str] = None
|
||||
|
||||
|
||||
class OAuthApplicationPublicInfo(BaseModel):
|
||||
"""Public information about an OAuth application (for consent screen)"""
|
||||
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Application Info Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/app/{client_id}",
|
||||
responses={
|
||||
404: {"description": "Application not found or disabled"},
|
||||
},
|
||||
)
|
||||
async def get_oauth_app_info(
|
||||
client_id: str, user_id: str = Security(get_user_id)
|
||||
) -> OAuthApplicationPublicInfo:
|
||||
"""
|
||||
Get public information about an OAuth application.
|
||||
|
||||
This endpoint is used by the consent screen to display application details
|
||||
to the user before they authorize access.
|
||||
|
||||
Returns:
|
||||
- name: Application name
|
||||
- description: Application description (if provided)
|
||||
- scopes: List of scopes the application is allowed to request
|
||||
"""
|
||||
app = await get_oauth_application(client_id)
|
||||
if not app or not app.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found",
|
||||
)
|
||||
|
||||
return OAuthApplicationPublicInfo(
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
logo_url=app.logo_url,
|
||||
scopes=[s.value for s in app.scopes],
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authorization Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AuthorizeRequest(BaseModel):
|
||||
"""OAuth 2.0 authorization request"""
|
||||
|
||||
client_id: str = Field(description="Client identifier")
|
||||
redirect_uri: str = Field(description="Redirect URI")
|
||||
scopes: list[str] = Field(description="List of scopes")
|
||||
state: str = Field(description="Anti-CSRF token from client")
|
||||
response_type: str = Field(
|
||||
default="code", description="Must be 'code' for authorization code flow"
|
||||
)
|
||||
code_challenge: str = Field(description="PKCE code challenge (required)")
|
||||
code_challenge_method: Literal["S256", "plain"] = Field(
|
||||
default="S256", description="PKCE code challenge method (S256 recommended)"
|
||||
)
|
||||
|
||||
|
||||
class AuthorizeResponse(BaseModel):
|
||||
"""OAuth 2.0 authorization response with redirect URL"""
|
||||
|
||||
redirect_url: str = Field(description="URL to redirect the user to")
|
||||
|
||||
|
||||
@router.post("/authorize")
|
||||
async def authorize(
|
||||
request: AuthorizeRequest = Body(),
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> AuthorizeResponse:
|
||||
"""
|
||||
OAuth 2.0 Authorization Endpoint
|
||||
|
||||
User must be logged in (authenticated with Supabase JWT).
|
||||
This endpoint creates an authorization code and returns a redirect URL.
|
||||
|
||||
PKCE (Proof Key for Code Exchange) is REQUIRED for all authorization requests.
|
||||
|
||||
The frontend consent screen should call this endpoint after the user approves,
|
||||
then redirect the user to the returned `redirect_url`.
|
||||
|
||||
Request Body:
|
||||
- client_id: The OAuth application's client ID
|
||||
- redirect_uri: Where to redirect after authorization (must match registered URI)
|
||||
- scopes: List of permissions (e.g., "EXECUTE_GRAPH READ_GRAPH")
|
||||
- state: Anti-CSRF token provided by client (will be returned in redirect)
|
||||
- response_type: Must be "code" (for authorization code flow)
|
||||
- code_challenge: PKCE code challenge (required)
|
||||
- code_challenge_method: "S256" (recommended) or "plain"
|
||||
|
||||
Returns:
|
||||
- redirect_url: The URL to redirect the user to (includes authorization code)
|
||||
|
||||
Error cases return a redirect_url with error parameters, or raise HTTPException
|
||||
for critical errors (like invalid redirect_uri).
|
||||
"""
|
||||
try:
|
||||
# Validate response_type
|
||||
if request.response_type != "code":
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"unsupported_response_type",
|
||||
"Only 'code' response type is supported",
|
||||
)
|
||||
|
||||
# Get application
|
||||
app = await get_oauth_application(request.client_id)
|
||||
if not app:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_client",
|
||||
"Unknown client_id",
|
||||
)
|
||||
|
||||
if not app.is_active:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_client",
|
||||
"Application is not active",
|
||||
)
|
||||
|
||||
# Validate redirect URI
|
||||
if not validate_redirect_uri(app, request.redirect_uri):
|
||||
# For invalid redirect_uri, we can't redirect safely
|
||||
# Must return error instead
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
"Invalid redirect_uri. "
|
||||
f"Must be one of: {', '.join(app.redirect_uris)}"
|
||||
),
|
||||
)
|
||||
|
||||
# Parse and validate scopes
|
||||
try:
|
||||
requested_scopes = [APIKeyPermission(s.strip()) for s in request.scopes]
|
||||
except ValueError as e:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
f"Invalid scope: {e}",
|
||||
)
|
||||
|
||||
if not requested_scopes:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
"At least one scope is required",
|
||||
)
|
||||
|
||||
if not validate_scopes(app, requested_scopes):
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
"Application is not authorized for all requested scopes. "
|
||||
f"Allowed: {', '.join(s.value for s in app.scopes)}",
|
||||
)
|
||||
|
||||
# Create authorization code
|
||||
auth_code = await create_authorization_code(
|
||||
application_id=app.id,
|
||||
user_id=user_id,
|
||||
scopes=requested_scopes,
|
||||
redirect_uri=request.redirect_uri,
|
||||
code_challenge=request.code_challenge,
|
||||
code_challenge_method=request.code_challenge_method,
|
||||
)
|
||||
|
||||
# Build redirect URL with authorization code
|
||||
params = {
|
||||
"code": auth_code.code,
|
||||
"state": request.state,
|
||||
}
|
||||
redirect_url = f"{request.redirect_uri}?{urlencode(params)}"
|
||||
|
||||
logger.info(
|
||||
f"Authorization code issued for user #{user_id} "
|
||||
f"and app {app.name} (#{app.id})"
|
||||
)
|
||||
|
||||
return AuthorizeResponse(redirect_url=redirect_url)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in authorization endpoint: {e}", exc_info=True)
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"server_error",
|
||||
"An unexpected error occurred",
|
||||
)
|
||||
|
||||
|
||||
def _error_redirect_url(
|
||||
redirect_uri: str,
|
||||
state: str,
|
||||
error: str,
|
||||
error_description: Optional[str] = None,
|
||||
) -> AuthorizeResponse:
|
||||
"""Helper to build redirect URL with OAuth error parameters"""
|
||||
params = {
|
||||
"error": error,
|
||||
"state": state,
|
||||
}
|
||||
if error_description:
|
||||
params["error_description"] = error_description
|
||||
|
||||
redirect_url = f"{redirect_uri}?{urlencode(params)}"
|
||||
return AuthorizeResponse(redirect_url=redirect_url)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TokenRequestByCode(BaseModel):
|
||||
grant_type: Literal["authorization_code"]
|
||||
code: str = Field(description="Authorization code")
|
||||
redirect_uri: str = Field(
|
||||
description="Redirect URI (must match authorization request)"
|
||||
)
|
||||
client_id: str
|
||||
client_secret: str
|
||||
code_verifier: str = Field(description="PKCE code verifier")
|
||||
|
||||
|
||||
class TokenRequestByRefreshToken(BaseModel):
|
||||
grant_type: Literal["refresh_token"]
|
||||
refresh_token: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
|
||||
|
||||
@router.post("/token")
|
||||
async def token(
|
||||
request: TokenRequestByCode | TokenRequestByRefreshToken = Body(),
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
OAuth 2.0 Token Endpoint
|
||||
|
||||
Exchanges authorization code or refresh token for access token.
|
||||
|
||||
Grant Types:
|
||||
1. authorization_code: Exchange authorization code for tokens
|
||||
- Required: grant_type, code, redirect_uri, client_id, client_secret
|
||||
- Optional: code_verifier (required if PKCE was used)
|
||||
|
||||
2. refresh_token: Exchange refresh token for new access token
|
||||
- Required: grant_type, refresh_token, client_id, client_secret
|
||||
|
||||
Returns:
|
||||
- access_token: Bearer token for API access (1 hour TTL)
|
||||
- token_type: "Bearer"
|
||||
- expires_in: Seconds until access token expires
|
||||
- refresh_token: Token for refreshing access (30 days TTL)
|
||||
- scopes: List of scopes
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
app = await validate_client_credentials(
|
||||
request.client_id, request.client_secret
|
||||
)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Handle authorization_code grant
|
||||
if request.grant_type == "authorization_code":
|
||||
# Consume authorization code
|
||||
try:
|
||||
user_id, scopes = await consume_authorization_code(
|
||||
code=request.code,
|
||||
application_id=app.id,
|
||||
redirect_uri=request.redirect_uri,
|
||||
code_verifier=request.code_verifier,
|
||||
)
|
||||
except InvalidGrantError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Create access and refresh tokens
|
||||
access_token = await create_access_token(app.id, user_id, scopes)
|
||||
refresh_token = await create_refresh_token(app.id, user_id, scopes)
|
||||
|
||||
logger.info(
|
||||
f"Access token issued for user #{user_id} and app {app.name} (#{app.id})"
|
||||
"via authorization code"
|
||||
)
|
||||
|
||||
if not access_token.token or not refresh_token.token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate tokens",
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
token_type="Bearer",
|
||||
access_token=access_token.token.get_secret_value(),
|
||||
access_token_expires_at=access_token.expires_at,
|
||||
refresh_token=refresh_token.token.get_secret_value(),
|
||||
refresh_token_expires_at=refresh_token.expires_at,
|
||||
scopes=list(s.value for s in scopes),
|
||||
)
|
||||
|
||||
# Handle refresh_token grant
|
||||
elif request.grant_type == "refresh_token":
|
||||
# Refresh access token
|
||||
try:
|
||||
new_access_token, new_refresh_token = await refresh_tokens(
|
||||
request.refresh_token, app.id
|
||||
)
|
||||
except InvalidGrantError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Tokens refreshed for user #{new_access_token.user_id} "
|
||||
f"by app {app.name} (#{app.id})"
|
||||
)
|
||||
|
||||
if not new_access_token.token or not new_refresh_token.token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate tokens",
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
token_type="Bearer",
|
||||
access_token=new_access_token.token.get_secret_value(),
|
||||
access_token_expires_at=new_access_token.expires_at,
|
||||
refresh_token=new_refresh_token.token.get_secret_value(),
|
||||
refresh_token_expires_at=new_refresh_token.expires_at,
|
||||
scopes=list(s.value for s in new_access_token.scopes),
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported grant_type: {request.grant_type}. "
|
||||
"Must be 'authorization_code' or 'refresh_token'",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Introspection Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post("/introspect")
|
||||
async def introspect(
|
||||
token: str = Body(description="Token to introspect"),
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
|
||||
None, description="Hint about token type ('access_token' or 'refresh_token')"
|
||||
),
|
||||
client_id: str = Body(description="Client identifier"),
|
||||
client_secret: str = Body(description="Client secret"),
|
||||
) -> TokenIntrospectionResult:
|
||||
"""
|
||||
OAuth 2.0 Token Introspection Endpoint (RFC 7662)
|
||||
|
||||
Allows clients to check if a token is valid and get its metadata.
|
||||
|
||||
Returns:
|
||||
- active: Whether the token is currently active
|
||||
- scopes: List of authorized scopes (if active)
|
||||
- client_id: The client the token was issued to (if active)
|
||||
- user_id: The user the token represents (if active)
|
||||
- exp: Expiration timestamp (if active)
|
||||
- token_type: "access_token" or "refresh_token" (if active)
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
await validate_client_credentials(client_id, client_secret)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Introspect the token
|
||||
return await introspect_token(token, token_type_hint)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Revocation Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post("/revoke")
|
||||
async def revoke(
|
||||
token: str = Body(description="Token to revoke"),
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
|
||||
None, description="Hint about token type ('access_token' or 'refresh_token')"
|
||||
),
|
||||
client_id: str = Body(description="Client identifier"),
|
||||
client_secret: str = Body(description="Client secret"),
|
||||
):
|
||||
"""
|
||||
OAuth 2.0 Token Revocation Endpoint (RFC 7009)
|
||||
|
||||
Allows clients to revoke an access or refresh token.
|
||||
|
||||
Note: Revoking a refresh token does NOT revoke associated access tokens.
|
||||
Revoking an access token does NOT revoke the associated refresh token.
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
app = await validate_client_credentials(client_id, client_secret)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Try to revoke as access token first
|
||||
# Note: We pass app.id to ensure the token belongs to the authenticated app
|
||||
if token_type_hint != "refresh_token":
|
||||
revoked = await revoke_access_token(token, app.id)
|
||||
if revoked:
|
||||
logger.info(
|
||||
f"Access token revoked for app {app.name} (#{app.id}); "
|
||||
f"user #{revoked.user_id}"
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
# Try to revoke as refresh token
|
||||
revoked = await revoke_refresh_token(token, app.id)
|
||||
if revoked:
|
||||
logger.info(
|
||||
f"Refresh token revoked for app {app.name} (#{app.id}); "
|
||||
f"user #{revoked.user_id}"
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
# Per RFC 7009, revocation endpoint returns 200 even if token not found
|
||||
# or if token belongs to a different application.
|
||||
# This prevents token scanning attacks.
|
||||
logger.warning(f"Unsuccessful token revocation attempt by app {app.name} #{app.id}")
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Application Management Endpoints (for app owners)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get("/apps/mine")
|
||||
async def list_my_oauth_apps(
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> list[OAuthApplicationInfo]:
|
||||
"""
|
||||
List all OAuth applications owned by the current user.
|
||||
|
||||
Returns a list of OAuth applications with their details including:
|
||||
- id, name, description, logo_url
|
||||
- client_id (public identifier)
|
||||
- redirect_uris, grant_types, scopes
|
||||
- is_active status
|
||||
- created_at, updated_at timestamps
|
||||
|
||||
Note: client_secret is never returned for security reasons.
|
||||
"""
|
||||
return await list_user_oauth_applications(user_id)
|
||||
|
||||
|
||||
@router.patch("/apps/{app_id}/status")
|
||||
async def update_app_status(
|
||||
app_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
is_active: bool = Body(description="Whether the app should be active", embed=True),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Enable or disable an OAuth application.
|
||||
|
||||
Only the application owner can update the status.
|
||||
When disabled, the application cannot be used for new authorizations
|
||||
and existing access tokens will fail validation.
|
||||
|
||||
Returns the updated application info.
|
||||
"""
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
action = "enabled" if is_active else "disabled"
|
||||
logger.info(f"OAuth app {updated_app.name} (#{app_id}) {action} by user #{user_id}")
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
class UpdateAppLogoRequest(BaseModel):
|
||||
logo_url: str = Field(description="URL of the uploaded logo image")
|
||||
|
||||
|
||||
@router.patch("/apps/{app_id}/logo")
|
||||
async def update_app_logo(
|
||||
app_id: str,
|
||||
request: UpdateAppLogoRequest = Body(),
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Update the logo URL for an OAuth application.
|
||||
|
||||
Only the application owner can update the logo.
|
||||
The logo should be uploaded first using the media upload endpoint,
|
||||
then this endpoint is called with the resulting URL.
|
||||
|
||||
Logo requirements:
|
||||
- Must be square (1:1 aspect ratio)
|
||||
- Minimum 512x512 pixels
|
||||
- Maximum 2048x2048 pixels
|
||||
|
||||
Returns the updated application info.
|
||||
"""
|
||||
if (
|
||||
not (app := await get_oauth_application_by_id(app_id))
|
||||
or app.owner_id != user_id
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth App not found",
|
||||
)
|
||||
|
||||
# Delete the current app logo file (if any and it's in our cloud storage)
|
||||
await _delete_app_current_logo_file(app)
|
||||
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
logo_url=request.logo_url,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth app {updated_app.name} (#{app_id}) logo updated by user #{user_id}"
|
||||
)
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
# Logo upload constraints
|
||||
LOGO_MIN_SIZE = 512
|
||||
LOGO_MAX_SIZE = 2048
|
||||
LOGO_ALLOWED_TYPES = {"image/jpeg", "image/png", "image/webp"}
|
||||
LOGO_MAX_FILE_SIZE = 3 * 1024 * 1024 # 3MB
|
||||
|
||||
|
||||
@router.post("/apps/{app_id}/logo/upload")
|
||||
async def upload_app_logo(
|
||||
app_id: str,
|
||||
file: UploadFile,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Upload a logo image for an OAuth application.
|
||||
|
||||
Requirements:
|
||||
- Image must be square (1:1 aspect ratio)
|
||||
- Minimum 512x512 pixels
|
||||
- Maximum 2048x2048 pixels
|
||||
- Allowed formats: JPEG, PNG, WebP
|
||||
- Maximum file size: 3MB
|
||||
|
||||
The image is uploaded to cloud storage and the app's logoUrl is updated.
|
||||
Returns the updated application info.
|
||||
"""
|
||||
# Verify ownership to reduce vulnerability to DoS(torage) or DoM(oney) attacks
|
||||
if (
|
||||
not (app := await get_oauth_application_by_id(app_id))
|
||||
or app.owner_id != user_id
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth App not found",
|
||||
)
|
||||
|
||||
# Check GCS configuration
|
||||
if not settings.config.media_gcs_bucket_name:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Media storage is not configured",
|
||||
)
|
||||
|
||||
# Validate content type
|
||||
content_type = file.content_type
|
||||
if content_type not in LOGO_ALLOWED_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid file type. Allowed: JPEG, PNG, WebP. Got: {content_type}",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
file_bytes = await file.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading logo file: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to read uploaded file",
|
||||
)
|
||||
|
||||
# Check file size
|
||||
if len(file_bytes) > LOGO_MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
"File too large. "
|
||||
f"Maximum size is {LOGO_MAX_FILE_SIZE // 1024 // 1024}MB"
|
||||
),
|
||||
)
|
||||
|
||||
# Validate image dimensions
|
||||
try:
|
||||
image = Image.open(io.BytesIO(file_bytes))
|
||||
width, height = image.size
|
||||
|
||||
if width != height:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo must be square. Got {width}x{height}",
|
||||
)
|
||||
|
||||
if width < LOGO_MIN_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo too small. Minimum {LOGO_MIN_SIZE}x{LOGO_MIN_SIZE}. "
|
||||
f"Got {width}x{height}",
|
||||
)
|
||||
|
||||
if width > LOGO_MAX_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo too large. Maximum {LOGO_MAX_SIZE}x{LOGO_MAX_SIZE}. "
|
||||
f"Got {width}x{height}",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating logo image: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid image file",
|
||||
)
|
||||
|
||||
# Scan for viruses
|
||||
filename = file.filename or "logo"
|
||||
await scan_content_safe(file_bytes, filename=filename)
|
||||
|
||||
# Generate unique filename
|
||||
file_ext = os.path.splitext(filename)[1].lower() or ".png"
|
||||
unique_filename = f"{uuid.uuid4()}{file_ext}"
|
||||
storage_path = f"oauth-apps/{app_id}/logo/{unique_filename}"
|
||||
|
||||
# Upload to GCS
|
||||
try:
|
||||
async with async_storage.Storage() as async_client:
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
|
||||
await async_client.upload(
|
||||
bucket_name, storage_path, file_bytes, content_type=content_type
|
||||
)
|
||||
|
||||
logo_url = f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading logo to GCS: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to upload logo",
|
||||
)
|
||||
|
||||
# Delete the current app logo file (if any and it's in our cloud storage)
|
||||
await _delete_app_current_logo_file(app)
|
||||
|
||||
# Update the app with the new logo URL
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
logo_url=logo_url,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth app {updated_app.name} (#{app_id}) logo uploaded by user #{user_id}"
|
||||
)
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
async def _delete_app_current_logo_file(app: OAuthApplicationInfo):
|
||||
"""
|
||||
Delete the current logo file for the given app, if there is one in our cloud storage
|
||||
"""
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
storage_base_url = f"https://storage.googleapis.com/{bucket_name}/"
|
||||
|
||||
if app.logo_url and app.logo_url.startswith(storage_base_url):
|
||||
# Parse blob path from URL: https://storage.googleapis.com/{bucket}/{path}
|
||||
old_path = app.logo_url.replace(storage_base_url, "")
|
||||
try:
|
||||
async with async_storage.Storage() as async_client:
|
||||
await async_client.delete(bucket_name, old_path)
|
||||
logger.info(f"Deleted old logo for OAuth app #{app.id}: {old_path}")
|
||||
except Exception as e:
|
||||
# Log but don't fail - the new logo was uploaded successfully
|
||||
logger.warning(
|
||||
f"Failed to delete old logo for OAuth app #{app.id}: {e}", exc_info=e
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,72 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
CLI script to backfill embeddings for store agents.
|
||||
|
||||
Usage:
|
||||
poetry run python -m backend.server.v2.store.backfill_embeddings [--batch-size N]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import prisma
|
||||
|
||||
|
||||
async def main(batch_size: int = 100) -> int:
|
||||
"""Run the backfill process."""
|
||||
# Initialize Prisma client
|
||||
client = prisma.Prisma()
|
||||
await client.connect()
|
||||
prisma.register(client)
|
||||
|
||||
try:
|
||||
from backend.api.features.store.embeddings import (
|
||||
backfill_missing_embeddings,
|
||||
get_embedding_stats,
|
||||
)
|
||||
|
||||
# Get current stats
|
||||
print("Current embedding stats:")
|
||||
stats = await get_embedding_stats()
|
||||
print(f" Total approved: {stats['total_approved']}")
|
||||
print(f" With embeddings: {stats['with_embeddings']}")
|
||||
print(f" Without embeddings: {stats['without_embeddings']}")
|
||||
print(f" Coverage: {stats['coverage_percent']}%")
|
||||
|
||||
if stats["without_embeddings"] == 0:
|
||||
print("\nAll agents already have embeddings. Nothing to do.")
|
||||
return 0
|
||||
|
||||
# Run backfill
|
||||
print(f"\nBackfilling up to {batch_size} embeddings...")
|
||||
result = await backfill_missing_embeddings(batch_size=batch_size)
|
||||
print(f" Processed: {result['processed']}")
|
||||
print(f" Success: {result['success']}")
|
||||
print(f" Failed: {result['failed']}")
|
||||
|
||||
# Get final stats
|
||||
print("\nFinal embedding stats:")
|
||||
stats = await get_embedding_stats()
|
||||
print(f" Total approved: {stats['total_approved']}")
|
||||
print(f" With embeddings: {stats['with_embeddings']}")
|
||||
print(f" Without embeddings: {stats['without_embeddings']}")
|
||||
print(f" Coverage: {stats['coverage_percent']}%")
|
||||
|
||||
return 0 if result["failed"] == 0 else 1
|
||||
|
||||
finally:
|
||||
await client.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Backfill embeddings for store agents")
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of embeddings to generate (default: 100)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sys.exit(asyncio.run(main(batch_size=args.batch_size)))
|
||||
@@ -1,408 +0,0 @@
|
||||
"""
|
||||
Store Listing Embeddings Service
|
||||
|
||||
Handles generation and storage of OpenAI embeddings for store listings
|
||||
to enable semantic/hybrid search.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import prisma
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OpenAI embedding model configuration
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
EMBEDDING_DIM = 1536
|
||||
|
||||
|
||||
def build_searchable_text(
|
||||
name: str,
|
||||
description: str,
|
||||
sub_heading: str,
|
||||
categories: list[str],
|
||||
) -> str:
|
||||
"""
|
||||
Build searchable text from listing version fields.
|
||||
|
||||
Combines relevant fields into a single string for embedding.
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# Name is important - include it
|
||||
if name:
|
||||
parts.append(name)
|
||||
|
||||
# Sub-heading provides context
|
||||
if sub_heading:
|
||||
parts.append(sub_heading)
|
||||
|
||||
# Description is the main content
|
||||
if description:
|
||||
parts.append(description)
|
||||
|
||||
# Categories help with semantic matching
|
||||
if categories:
|
||||
parts.append(" ".join(categories))
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def compute_content_hash(text: str) -> str:
|
||||
"""Compute MD5 hash of text for change detection."""
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
|
||||
async def generate_embedding(text: str) -> list[float] | None:
|
||||
"""
|
||||
Generate embedding for text using OpenAI API.
|
||||
|
||||
Returns None if embedding generation fails.
|
||||
"""
|
||||
try:
|
||||
from openai import OpenAI
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
logger.warning("OPENAI_API_KEY not set, cannot generate embedding")
|
||||
return None
|
||||
|
||||
client = OpenAI(api_key=api_key)
|
||||
|
||||
# Truncate text to avoid token limits (~32k chars for safety)
|
||||
truncated_text = text[:32000]
|
||||
|
||||
response = client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=truncated_text,
|
||||
)
|
||||
|
||||
embedding = response.data[0].embedding
|
||||
logger.debug(f"Generated embedding with {len(embedding)} dimensions")
|
||||
return embedding
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate embedding: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def store_embedding(
|
||||
version_id: str,
|
||||
embedding: list[float],
|
||||
searchable_text: str,
|
||||
content_hash: str,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Store embedding in the database.
|
||||
|
||||
Uses raw SQL since Prisma doesn't natively support pgvector.
|
||||
"""
|
||||
try:
|
||||
client = tx if tx else prisma.get_client()
|
||||
|
||||
# Convert embedding to PostgreSQL vector format
|
||||
embedding_str = "[" + ",".join(str(x) for x in embedding) + "]"
|
||||
|
||||
# Upsert the embedding
|
||||
# Set search_path to include public for vector type visibility
|
||||
await client.execute_raw(
|
||||
"""
|
||||
SET LOCAL search_path TO platform, public;
|
||||
INSERT INTO platform."StoreListingEmbedding" (
|
||||
"id", "storeListingVersionId", "embedding",
|
||||
"searchableText", "contentHash", "createdAt", "updatedAt"
|
||||
)
|
||||
VALUES (
|
||||
gen_random_uuid(), $1, $2::vector,
|
||||
$3, $4, NOW(), NOW()
|
||||
)
|
||||
ON CONFLICT ("storeListingVersionId")
|
||||
DO UPDATE SET
|
||||
"embedding" = $2::vector,
|
||||
"searchableText" = $3,
|
||||
"contentHash" = $4,
|
||||
"updatedAt" = NOW()
|
||||
""",
|
||||
version_id,
|
||||
embedding_str,
|
||||
searchable_text,
|
||||
content_hash,
|
||||
)
|
||||
|
||||
logger.info(f"Stored embedding for version {version_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store embedding for version {version_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding(version_id: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve embedding record for a listing version.
|
||||
|
||||
Returns dict with embedding, searchableText, contentHash or None if not found.
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
result = await client.query_raw(
|
||||
"""
|
||||
SELECT
|
||||
"id",
|
||||
"storeListingVersionId",
|
||||
"embedding"::text as "embedding",
|
||||
"searchableText",
|
||||
"contentHash",
|
||||
"createdAt",
|
||||
"updatedAt"
|
||||
FROM platform."StoreListingEmbedding"
|
||||
WHERE "storeListingVersionId" = $1
|
||||
""",
|
||||
version_id,
|
||||
)
|
||||
|
||||
if result and len(result) > 0:
|
||||
return result[0]
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embedding for version {version_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def ensure_embedding(
|
||||
version_id: str,
|
||||
name: str,
|
||||
description: str,
|
||||
sub_heading: str,
|
||||
categories: list[str],
|
||||
force: bool = False,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Ensure an embedding exists for the listing version.
|
||||
|
||||
Creates embedding if missing or if content has changed.
|
||||
Skips if content hash matches existing embedding.
|
||||
|
||||
Args:
|
||||
version_id: The StoreListingVersion ID
|
||||
name: Agent name
|
||||
description: Agent description
|
||||
sub_heading: Agent sub-heading
|
||||
categories: Agent categories
|
||||
force: Force regeneration even if hash matches
|
||||
tx: Optional transaction client
|
||||
|
||||
Returns:
|
||||
True if embedding exists/was created, False on failure
|
||||
"""
|
||||
try:
|
||||
# Build searchable text and compute hash
|
||||
searchable_text = build_searchable_text(
|
||||
name, description, sub_heading, categories
|
||||
)
|
||||
content_hash = compute_content_hash(searchable_text)
|
||||
|
||||
# Check if embedding already exists with same hash
|
||||
if not force:
|
||||
existing = await get_embedding(version_id)
|
||||
if existing and existing.get("contentHash") == content_hash:
|
||||
logger.debug(
|
||||
f"Embedding for version {version_id} is up to date (hash match)"
|
||||
)
|
||||
return True
|
||||
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(f"Could not generate embedding for version {version_id}")
|
||||
return False
|
||||
|
||||
# Store the embedding
|
||||
return await store_embedding(
|
||||
version_id=version_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
content_hash=content_hash,
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def delete_embedding(version_id: str) -> bool:
|
||||
"""
|
||||
Delete embedding for a listing version.
|
||||
|
||||
Note: This is usually handled automatically by CASCADE delete,
|
||||
but provided for manual cleanup if needed.
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
await client.execute_raw(
|
||||
"""
|
||||
DELETE FROM platform."StoreListingEmbedding"
|
||||
WHERE "storeListingVersionId" = $1
|
||||
""",
|
||||
version_id,
|
||||
)
|
||||
|
||||
logger.info(f"Deleted embedding for version {version_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete embedding for version {version_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding_stats() -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics about embedding coverage.
|
||||
|
||||
Returns counts of:
|
||||
- Total approved listing versions
|
||||
- Versions with embeddings
|
||||
- Versions without embeddings
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
# Count approved versions
|
||||
approved_result = await client.query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM platform."StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
AND "isDeleted" = false
|
||||
"""
|
||||
)
|
||||
total_approved = approved_result[0]["count"] if approved_result else 0
|
||||
|
||||
# Count versions with embeddings
|
||||
embedded_result = await client.query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM platform."StoreListingVersion" slv
|
||||
JOIN platform."StoreListingEmbedding" sle ON slv.id = sle."storeListingVersionId"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
"""
|
||||
)
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total_approved": total_approved,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_approved - with_embeddings,
|
||||
"coverage_percent": (
|
||||
round(with_embeddings / total_approved * 100, 1)
|
||||
if total_approved > 0
|
||||
else 0
|
||||
),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embedding stats: {e}")
|
||||
return {
|
||||
"total_approved": 0,
|
||||
"with_embeddings": 0,
|
||||
"without_embeddings": 0,
|
||||
"coverage_percent": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]:
|
||||
"""
|
||||
Generate embeddings for approved listings that don't have them.
|
||||
|
||||
Args:
|
||||
batch_size: Number of embeddings to generate in one call
|
||||
|
||||
Returns:
|
||||
Dict with success/failure counts
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
# Find approved versions without embeddings
|
||||
missing = await client.query_raw(
|
||||
"""
|
||||
SELECT
|
||||
slv.id,
|
||||
slv.name,
|
||||
slv.description,
|
||||
slv."subHeading",
|
||||
slv.categories
|
||||
FROM platform."StoreListingVersion" slv
|
||||
LEFT JOIN platform."StoreListingEmbedding" sle ON slv.id = sle."storeListingVersionId"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
AND sle.id IS NULL
|
||||
LIMIT $1
|
||||
""",
|
||||
batch_size,
|
||||
)
|
||||
|
||||
if not missing:
|
||||
return {
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"message": "No missing embeddings",
|
||||
}
|
||||
|
||||
success = 0
|
||||
failed = 0
|
||||
|
||||
for row in missing:
|
||||
result = await ensure_embedding(
|
||||
version_id=row["id"],
|
||||
name=row["name"],
|
||||
description=row["description"],
|
||||
sub_heading=row["subHeading"],
|
||||
categories=row["categories"] or [],
|
||||
)
|
||||
if result:
|
||||
success += 1
|
||||
else:
|
||||
failed += 1
|
||||
|
||||
return {
|
||||
"processed": len(missing),
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"message": f"Backfilled {success} embeddings, {failed} failed",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to backfill embeddings: {e}")
|
||||
return {
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
async def embed_query(query: str) -> list[float] | None:
|
||||
"""
|
||||
Generate embedding for a search query.
|
||||
|
||||
Same as generate_embedding but with clearer intent.
|
||||
"""
|
||||
return await generate_embedding(query)
|
||||
|
||||
|
||||
def embedding_to_vector_string(embedding: list[float]) -> str:
|
||||
"""Convert embedding list to PostgreSQL vector string format."""
|
||||
return "[" + ",".join(str(x) for x in embedding) + "]"
|
||||
@@ -1,440 +0,0 @@
|
||||
"""
|
||||
Hybrid Search for Store Agents
|
||||
|
||||
Combines semantic (embedding) search with lexical (tsvector) search
|
||||
for improved relevance in marketplace agent discovery.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
import prisma
|
||||
|
||||
from backend.api.features.store.embeddings import (
|
||||
embed_query,
|
||||
embedding_to_vector_string,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HybridSearchWeights:
|
||||
"""Weights for combining search signals."""
|
||||
|
||||
semantic: float = 0.35 # Embedding cosine similarity
|
||||
lexical: float = 0.35 # tsvector ts_rank_cd score
|
||||
category: float = 0.20 # Category match boost
|
||||
recency: float = 0.10 # Newer agents ranked higher
|
||||
|
||||
|
||||
DEFAULT_WEIGHTS = HybridSearchWeights()
|
||||
|
||||
# Minimum relevance score threshold - agents below this are filtered out
|
||||
# With weights (0.35 semantic + 0.35 lexical + 0.20 category + 0.10 recency):
|
||||
# - 0.20 means at least ~50% semantic match OR strong lexical match required
|
||||
# - Ensures only genuinely relevant results are returned
|
||||
# - Recency alone (0.10 max) won't pass the threshold
|
||||
DEFAULT_MIN_SCORE = 0.20
|
||||
|
||||
|
||||
@dataclass
|
||||
class HybridSearchResult:
|
||||
"""A single search result with score breakdown."""
|
||||
|
||||
slug: str
|
||||
agent_name: str
|
||||
agent_image: str
|
||||
creator_username: str
|
||||
creator_avatar: str
|
||||
sub_heading: str
|
||||
description: str
|
||||
runs: int
|
||||
rating: float
|
||||
categories: list[str]
|
||||
featured: bool
|
||||
is_available: bool
|
||||
updated_at: datetime
|
||||
|
||||
# Score breakdown (for debugging/tuning)
|
||||
combined_score: float
|
||||
semantic_score: float = 0.0
|
||||
lexical_score: float = 0.0
|
||||
category_score: float = 0.0
|
||||
recency_score: float = 0.0
|
||||
|
||||
|
||||
async def hybrid_search(
|
||||
query: str,
|
||||
featured: bool = False,
|
||||
creators: list[str] | None = None,
|
||||
category: str | None = None,
|
||||
sorted_by: (
|
||||
Literal["relevance", "rating", "runs", "name", "updated_at"] | None
|
||||
) = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
weights: HybridSearchWeights | None = None,
|
||||
min_score: float | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Perform hybrid search combining semantic and lexical signals.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
featured: Filter for featured agents only
|
||||
creators: Filter by creator usernames
|
||||
category: Filter by category
|
||||
sorted_by: Sort order (relevance uses hybrid scoring)
|
||||
page: Page number (1-indexed)
|
||||
page_size: Results per page
|
||||
weights: Custom weights for search signals
|
||||
min_score: Minimum relevance score threshold (0-1). Results below
|
||||
this score are filtered out. Defaults to DEFAULT_MIN_SCORE.
|
||||
|
||||
Returns:
|
||||
Tuple of (results list, total count). Returns empty list if no
|
||||
results meet the minimum relevance threshold.
|
||||
"""
|
||||
if weights is None:
|
||||
weights = DEFAULT_WEIGHTS
|
||||
if min_score is None:
|
||||
min_score = DEFAULT_MIN_SCORE
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
client = prisma.get_client()
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await embed_query(query)
|
||||
|
||||
# Build WHERE clause conditions
|
||||
where_parts: list[str] = ["sa.is_available = true"]
|
||||
params: list[Any] = []
|
||||
param_index = 1
|
||||
|
||||
# Add search query for lexical matching
|
||||
params.append(query)
|
||||
query_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
if featured:
|
||||
where_parts.append("sa.featured = true")
|
||||
|
||||
if creators:
|
||||
where_parts.append(f"sa.creator_username = ANY(${param_index})")
|
||||
params.append(creators)
|
||||
param_index += 1
|
||||
|
||||
if category:
|
||||
where_parts.append(f"${param_index} = ANY(sa.categories)")
|
||||
params.append(category)
|
||||
param_index += 1
|
||||
|
||||
where_clause = " AND ".join(where_parts)
|
||||
|
||||
# Determine if we can use hybrid search (have query embedding)
|
||||
use_hybrid = query_embedding is not None
|
||||
|
||||
if use_hybrid:
|
||||
# Add embedding parameter
|
||||
embedding_str = embedding_to_vector_string(query_embedding)
|
||||
params.append(embedding_str)
|
||||
embedding_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
# Build hybrid search query with weighted scoring
|
||||
# The semantic score is (1 - cosine_distance), normalized to [0,1]
|
||||
# The lexical score is ts_rank_cd, normalized by max value
|
||||
# Set search_path to include public for vector type visibility
|
||||
sql_query = f"""
|
||||
SET LOCAL search_path TO platform, public;
|
||||
WITH search_scores AS (
|
||||
SELECT
|
||||
sa.*,
|
||||
-- Semantic score: cosine similarity (1 - distance)
|
||||
COALESCE(1 - (sle.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
||||
-- Lexical score: ts_rank_cd normalized
|
||||
COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||
-- Category match: 1 if query term appears in categories, else 0
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM unnest(sa.categories) cat
|
||||
WHERE LOWER(cat) LIKE '%' || LOWER({query_param}) || '%'
|
||||
) THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
-- Recency score: exponential decay over 90 days
|
||||
EXP(-EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score
|
||||
FROM platform."StoreAgent" sa
|
||||
LEFT JOIN platform."StoreListing" sl ON sa.slug = sl.slug
|
||||
LEFT JOIN platform."StoreListingVersion" slv ON sl."activeVersionId" = slv.id
|
||||
LEFT JOIN platform."StoreListingEmbedding" sle ON slv.id = sle."storeListingVersionId"
|
||||
WHERE {where_clause}
|
||||
AND (
|
||||
sa.search @@ plainto_tsquery('english', {query_param})
|
||||
OR sle.embedding IS NOT NULL
|
||||
)
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
*,
|
||||
-- Normalize lexical score by max in result set
|
||||
CASE
|
||||
WHEN MAX(lexical_raw) OVER () > 0
|
||||
THEN lexical_raw / MAX(lexical_raw) OVER ()
|
||||
ELSE 0
|
||||
END as lexical_score
|
||||
FROM search_scores
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
category_score,
|
||||
recency_score,
|
||||
(
|
||||
{weights.semantic} * semantic_score +
|
||||
{weights.lexical} * lexical_score +
|
||||
{weights.category} * category_score +
|
||||
{weights.recency} * recency_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
)
|
||||
SELECT * FROM scored
|
||||
WHERE combined_score >= {min_score}
|
||||
ORDER BY combined_score DESC
|
||||
LIMIT ${param_index} OFFSET ${param_index + 1}
|
||||
"""
|
||||
|
||||
# Add pagination params
|
||||
params.extend([page_size, offset])
|
||||
|
||||
# Count query - must also filter by min_score
|
||||
count_query = f"""
|
||||
SET LOCAL search_path TO platform, public;
|
||||
WITH search_scores AS (
|
||||
SELECT
|
||||
sa.slug,
|
||||
COALESCE(1 - (sle.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
||||
COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM unnest(sa.categories) cat
|
||||
WHERE LOWER(cat) LIKE '%' || LOWER({query_param}) || '%'
|
||||
) THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
EXP(-EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score
|
||||
FROM platform."StoreAgent" sa
|
||||
LEFT JOIN platform."StoreListing" sl ON sa.slug = sl.slug
|
||||
LEFT JOIN platform."StoreListingVersion" slv ON sl."activeVersionId" = slv.id
|
||||
LEFT JOIN platform."StoreListingEmbedding" sle ON slv.id = sle."storeListingVersionId"
|
||||
WHERE {where_clause}
|
||||
AND (
|
||||
sa.search @@ plainto_tsquery('english', {query_param})
|
||||
OR sle.embedding IS NOT NULL
|
||||
)
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
slug,
|
||||
semantic_score,
|
||||
category_score,
|
||||
recency_score,
|
||||
CASE
|
||||
WHEN MAX(lexical_raw) OVER () > 0
|
||||
THEN lexical_raw / MAX(lexical_raw) OVER ()
|
||||
ELSE 0
|
||||
END as lexical_score
|
||||
FROM search_scores
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
(
|
||||
{weights.semantic} * semantic_score +
|
||||
{weights.lexical} * lexical_score +
|
||||
{weights.category} * category_score +
|
||||
{weights.recency} * recency_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
)
|
||||
SELECT COUNT(*) as count FROM scored
|
||||
WHERE combined_score >= {min_score}
|
||||
"""
|
||||
|
||||
else:
|
||||
# Fallback to lexical-only search (existing behavior)
|
||||
# Note: For lexical-only, we still require tsvector match but don't
|
||||
# apply min_score since ts_rank_cd isn't normalized to [0,1]
|
||||
logger.warning("Falling back to lexical-only search (no query embedding)")
|
||||
|
||||
sql_query = f"""
|
||||
WITH lexical_scores AS (
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
0.0 as semantic_score,
|
||||
ts_rank_cd(search, plainto_tsquery('english', {query_param})) as lexical_raw,
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM unnest(categories) cat
|
||||
WHERE LOWER(cat) LIKE '%' || LOWER({query_param}) || '%'
|
||||
) THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
EXP(-EXTRACT(EPOCH FROM (NOW() - updated_at)) / (90 * 24 * 3600)) as recency_score
|
||||
FROM platform."StoreAgent" sa
|
||||
WHERE {where_clause}
|
||||
AND search @@ plainto_tsquery('english', {query_param})
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
*,
|
||||
CASE
|
||||
WHEN MAX(lexical_raw) OVER () > 0
|
||||
THEN lexical_raw / MAX(lexical_raw) OVER ()
|
||||
ELSE 0
|
||||
END as lexical_score
|
||||
FROM lexical_scores
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
category_score,
|
||||
recency_score,
|
||||
(
|
||||
{weights.lexical} * lexical_score +
|
||||
{weights.category} * category_score +
|
||||
{weights.recency} * recency_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
)
|
||||
SELECT * FROM scored
|
||||
WHERE combined_score >= {min_score}
|
||||
ORDER BY combined_score DESC
|
||||
LIMIT ${param_index} OFFSET ${param_index + 1}
|
||||
"""
|
||||
|
||||
params.extend([page_size, offset])
|
||||
|
||||
count_query = f"""
|
||||
WITH lexical_scores AS (
|
||||
SELECT
|
||||
slug,
|
||||
ts_rank_cd(search, plainto_tsquery('english', {query_param})) as lexical_raw,
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM unnest(categories) cat
|
||||
WHERE LOWER(cat) LIKE '%' || LOWER({query_param}) || '%'
|
||||
) THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
EXP(-EXTRACT(EPOCH FROM (NOW() - updated_at)) / (90 * 24 * 3600)) as recency_score
|
||||
FROM platform."StoreAgent" sa
|
||||
WHERE {where_clause}
|
||||
AND search @@ plainto_tsquery('english', {query_param})
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
slug,
|
||||
category_score,
|
||||
recency_score,
|
||||
CASE
|
||||
WHEN MAX(lexical_raw) OVER () > 0
|
||||
THEN lexical_raw / MAX(lexical_raw) OVER ()
|
||||
ELSE 0
|
||||
END as lexical_score
|
||||
FROM lexical_scores
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
(
|
||||
{weights.lexical} * lexical_score +
|
||||
{weights.category} * category_score +
|
||||
{weights.recency} * recency_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
)
|
||||
SELECT COUNT(*) as count FROM scored
|
||||
WHERE combined_score >= {min_score}
|
||||
"""
|
||||
|
||||
try:
|
||||
# Execute search query
|
||||
# Dynamic SQL is safe here - all user inputs are parameterized ($1, $2, etc.)
|
||||
results = await client.query_raw(sql_query, *params) # type: ignore[arg-type]
|
||||
|
||||
# Execute count query (without pagination params)
|
||||
count_params = params[:-2] # Remove LIMIT and OFFSET params
|
||||
count_result = await client.query_raw(count_query, *count_params) # type: ignore[arg-type]
|
||||
total = count_result[0]["count"] if count_result else 0
|
||||
|
||||
logger.info(
|
||||
f"Hybrid search for '{query}': {len(results)} results, {total} total "
|
||||
f"(hybrid={use_hybrid})"
|
||||
)
|
||||
|
||||
return results, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Hybrid search failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def hybrid_search_simple(
|
||||
query: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Simplified hybrid search for common use cases.
|
||||
|
||||
Uses default weights and no filters.
|
||||
"""
|
||||
return await hybrid_search(
|
||||
query=query,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
@@ -1,41 +0,0 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
def sort_openapi(app: FastAPI) -> None:
|
||||
"""
|
||||
Patch a FastAPI instance's `openapi()` method to sort the endpoints,
|
||||
schemas, and responses.
|
||||
"""
|
||||
wrapped_openapi = app.openapi
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = wrapped_openapi()
|
||||
|
||||
# Sort endpoints
|
||||
openapi_schema["paths"] = dict(sorted(openapi_schema["paths"].items()))
|
||||
|
||||
# Sort endpoints -> methods
|
||||
for p in openapi_schema["paths"].keys():
|
||||
openapi_schema["paths"][p] = dict(
|
||||
sorted(openapi_schema["paths"][p].items())
|
||||
)
|
||||
|
||||
# Sort endpoints -> methods -> responses
|
||||
for m in openapi_schema["paths"][p].keys():
|
||||
openapi_schema["paths"][p][m]["responses"] = dict(
|
||||
sorted(openapi_schema["paths"][p][m]["responses"].items())
|
||||
)
|
||||
|
||||
# Sort schemas and responses as well
|
||||
for k in openapi_schema["components"].keys():
|
||||
openapi_schema["components"][k] = dict(
|
||||
sorted(openapi_schema["components"][k].items())
|
||||
)
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return openapi_schema
|
||||
|
||||
app.openapi = custom_openapi
|
||||
@@ -36,10 +36,10 @@ def main(**kwargs):
|
||||
Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
|
||||
"""
|
||||
|
||||
from backend.api.rest_api import AgentServer
|
||||
from backend.api.ws_api import WebsocketServer
|
||||
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
|
||||
from backend.notifications import NotificationManager
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.server.ws_api import WebsocketServer
|
||||
|
||||
run_processes(
|
||||
DatabaseManager().set_log_level("warning"),
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks.llm import (
|
||||
DEFAULT_LLM_MODEL,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AIBlockBase,
|
||||
@@ -50,7 +49,7 @@ class AIConditionBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for evaluating the condition.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -82,7 +81,7 @@ class AIConditionBlock(AIBlockBase):
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "Valid email",
|
||||
"no_value": "Not an email",
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
|
||||
@@ -20,7 +20,6 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.request import Requests
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
@@ -247,11 +246,7 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise BlockExecutionError(
|
||||
message="Video creation timed out",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -427,11 +422,7 @@ class AIAdMakerVideoCreatorBlock(Block):
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise BlockExecutionError(
|
||||
message="Video creation timed out",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -608,11 +599,7 @@ class AIScreenshotToVideoAdBlock(Block):
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise BlockExecutionError(
|
||||
message="Video creation timed out",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -106,10 +106,7 @@ class ConditionBlock(Block):
|
||||
ComparisonOperator.LESS_THAN_OR_EQUAL: lambda a, b: a <= b,
|
||||
}
|
||||
|
||||
try:
|
||||
result = comparison_funcs[operator](value1, value2)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Comparison failed: {e}") from e
|
||||
result = comparison_funcs[operator](value1, value2)
|
||||
|
||||
yield "result", result
|
||||
|
||||
|
||||
@@ -182,10 +182,13 @@ class DataForSeoRelatedKeywordsBlock(Block):
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
# Handle missing key, null value, or valid list value
|
||||
if isinstance(first_result, dict):
|
||||
items = first_result.get("items") or []
|
||||
else:
|
||||
items = (
|
||||
first_result.get("items", [])
|
||||
if isinstance(first_result, dict)
|
||||
else []
|
||||
)
|
||||
# Ensure items is never None
|
||||
if items is None:
|
||||
items = []
|
||||
for item in items:
|
||||
# Extract keyword_data from the item
|
||||
|
||||
@@ -15,7 +15,6 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
from ._config import firecrawl
|
||||
|
||||
@@ -60,18 +59,11 @@ class FirecrawlExtractBlock(Block):
|
||||
) -> BlockOutput:
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
try:
|
||||
extract_result = app.extract(
|
||||
urls=input_data.urls,
|
||||
prompt=input_data.prompt,
|
||||
schema=input_data.output_schema,
|
||||
enable_web_search=input_data.enable_web_search,
|
||||
)
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Extract failed: {e}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
) from e
|
||||
extract_result = app.extract(
|
||||
urls=input_data.urls,
|
||||
prompt=input_data.prompt,
|
||||
schema=input_data.output_schema,
|
||||
enable_web_search=input_data.enable_web_search,
|
||||
)
|
||||
|
||||
yield "data", extract_result.data
|
||||
|
||||
@@ -19,7 +19,6 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.exceptions import ModerationError
|
||||
from backend.util.file import MediaFileType, store_media_file
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
@@ -154,8 +153,6 @@ class AIImageEditorBlock(Block):
|
||||
),
|
||||
aspect_ratio=input_data.aspect_ratio.value,
|
||||
seed=input_data.seed,
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
yield "output_image", result
|
||||
|
||||
@@ -167,8 +164,6 @@ class AIImageEditorBlock(Block):
|
||||
input_image_b64: Optional[str],
|
||||
aspect_ratio: str,
|
||||
seed: Optional[int],
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
input_params = {
|
||||
@@ -178,21 +173,11 @@ class AIImageEditorBlock(Block):
|
||||
**({"seed": seed} if seed is not None else {}),
|
||||
}
|
||||
|
||||
try:
|
||||
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
|
||||
model_name,
|
||||
input=input_params,
|
||||
wait=False,
|
||||
)
|
||||
except Exception as e:
|
||||
if "flagged as sensitive" in str(e).lower():
|
||||
raise ModerationError(
|
||||
message="Content was flagged as sensitive by the model provider",
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
moderation_type="model_provider",
|
||||
)
|
||||
raise ValueError(f"Model execution failed: {e}") from e
|
||||
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
|
||||
model_name,
|
||||
input=input_params,
|
||||
wait=False,
|
||||
)
|
||||
|
||||
if isinstance(output, list) and output:
|
||||
output = output[0]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
@@ -45,11 +45,11 @@ class HumanInTheLoopBlock(Block):
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
approved_data: Any = SchemaField(
|
||||
description="The data when approved (may be modified by reviewer)"
|
||||
reviewed_data: Any = SchemaField(
|
||||
description="The data after human review (may be modified)"
|
||||
)
|
||||
rejected_data: Any = SchemaField(
|
||||
description="The data when rejected (may be modified by reviewer)"
|
||||
status: Literal["approved", "rejected"] = SchemaField(
|
||||
description="Status of the review: 'approved' or 'rejected'"
|
||||
)
|
||||
review_message: str = SchemaField(
|
||||
description="Any message provided by the reviewer", default=""
|
||||
@@ -69,7 +69,8 @@ class HumanInTheLoopBlock(Block):
|
||||
"editable": True,
|
||||
},
|
||||
test_output=[
|
||||
("approved_data", {"name": "John Doe", "age": 30}),
|
||||
("status", "approved"),
|
||||
("reviewed_data", {"name": "John Doe", "age": 30}),
|
||||
],
|
||||
test_mock={
|
||||
"get_or_create_human_review": lambda *_args, **_kwargs: ReviewResult(
|
||||
@@ -115,7 +116,8 @@ class HumanInTheLoopBlock(Block):
|
||||
logger.info(
|
||||
f"HITL block skipping review for node {node_exec_id} - safe mode disabled"
|
||||
)
|
||||
yield "approved_data", input_data.data
|
||||
yield "status", "approved"
|
||||
yield "reviewed_data", input_data.data
|
||||
yield "review_message", "Auto-approved (safe mode disabled)"
|
||||
return
|
||||
|
||||
@@ -156,11 +158,12 @@ class HumanInTheLoopBlock(Block):
|
||||
)
|
||||
|
||||
if result.status == ReviewStatus.APPROVED:
|
||||
yield "approved_data", result.data
|
||||
yield "status", "approved"
|
||||
yield "reviewed_data", result.data
|
||||
if result.message:
|
||||
yield "review_message", result.message
|
||||
|
||||
elif result.status == ReviewStatus.REJECTED:
|
||||
yield "rejected_data", result.data
|
||||
yield "status", "rejected"
|
||||
if result.message:
|
||||
yield "review_message", result.message
|
||||
|
||||
@@ -2,6 +2,7 @@ from enum import Enum
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
@@ -331,8 +332,8 @@ class IdeogramModelBlock(Block):
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=data)
|
||||
return response.json()["data"][0]["url"]
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch image with V3 endpoint: {e}") from e
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to fetch image with V3 endpoint: {str(e)}")
|
||||
|
||||
async def _run_model_legacy(
|
||||
self,
|
||||
@@ -384,8 +385,8 @@ class IdeogramModelBlock(Block):
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=data)
|
||||
return response.json()["data"][0]["url"]
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch image with legacy endpoint: {e}") from e
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to fetch image with legacy endpoint: {str(e)}")
|
||||
|
||||
async def upscale_image(self, api_key: SecretStr, image_url: str):
|
||||
url = "https://api.ideogram.ai/upscale"
|
||||
@@ -412,5 +413,5 @@ class IdeogramModelBlock(Block):
|
||||
|
||||
return (response.json())["data"][0]["url"]
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to upscale image: {e}") from e
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to upscale image: {str(e)}")
|
||||
|
||||
@@ -16,7 +16,6 @@ from backend.data.block import (
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
|
||||
class SearchTheWebBlock(Block, GetRequest):
|
||||
@@ -57,17 +56,7 @@ class SearchTheWebBlock(Block, GetRequest):
|
||||
|
||||
# Prepend the Jina Search URL to the encoded query
|
||||
jina_search_url = f"https://s.jina.ai/{encoded_query}"
|
||||
|
||||
try:
|
||||
results = await self.get_request(
|
||||
jina_search_url, headers=headers, json=False
|
||||
)
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Search failed: {e}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
) from e
|
||||
results = await self.get_request(jina_search_url, headers=headers, json=False)
|
||||
|
||||
# Output the search results
|
||||
yield "results", results
|
||||
|
||||
@@ -92,9 +92,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
O1 = "o1"
|
||||
O1_MINI = "o1-mini"
|
||||
# GPT-5 models
|
||||
GPT5_2 = "gpt-5.2-2025-12-11"
|
||||
GPT5_1 = "gpt-5.1-2025-11-13"
|
||||
GPT5 = "gpt-5-2025-08-07"
|
||||
GPT5_1 = "gpt-5.1-2025-11-13"
|
||||
GPT5_MINI = "gpt-5-mini-2025-08-07"
|
||||
GPT5_NANO = "gpt-5-nano-2025-08-07"
|
||||
GPT5_CHAT = "gpt-5-chat-latest"
|
||||
@@ -195,9 +194,8 @@ MODEL_METADATA = {
|
||||
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
|
||||
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5_2: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
|
||||
@@ -305,8 +303,6 @@ MODEL_METADATA = {
|
||||
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
|
||||
}
|
||||
|
||||
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
|
||||
|
||||
for model in LlmModel:
|
||||
if model not in MODEL_METADATA:
|
||||
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
|
||||
@@ -794,7 +790,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -859,7 +855,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
input_schema=AIStructuredResponseGeneratorBlock.Input,
|
||||
output_schema=AIStructuredResponseGeneratorBlock.Output,
|
||||
test_input={
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expected_format": {
|
||||
"key1": "value1",
|
||||
@@ -1225,7 +1221,7 @@ class AITextGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -1321,7 +1317,7 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for summarizing the text.",
|
||||
)
|
||||
focus: str = SchemaField(
|
||||
@@ -1538,7 +1534,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for the conversation.",
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
@@ -1576,7 +1572,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
},
|
||||
{"role": "user", "content": "Where was it played?"},
|
||||
],
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
@@ -1639,7 +1635,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for generating the list.",
|
||||
advanced=True,
|
||||
)
|
||||
@@ -1696,7 +1692,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
|
||||
"fictional worlds."
|
||||
),
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"max_retries": 3,
|
||||
"force_json_output": False,
|
||||
|
||||
@@ -18,7 +18,6 @@ from backend.data.block import (
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError, BlockInputError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -112,27 +111,9 @@ class ReplicateModelBlock(Block):
|
||||
yield "status", "succeeded"
|
||||
yield "model_name", input_data.model_name
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"Error running Replicate model: {error_msg}")
|
||||
|
||||
# Input validation errors (422, 400) → BlockInputError
|
||||
if (
|
||||
"422" in error_msg
|
||||
or "Input validation failed" in error_msg
|
||||
or "400" in error_msg
|
||||
):
|
||||
raise BlockInputError(
|
||||
message=f"Invalid model inputs: {error_msg}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
) from e
|
||||
# Everything else → BlockExecutionError
|
||||
else:
|
||||
raise BlockExecutionError(
|
||||
message=f"Replicate model error: {error_msg}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
) from e
|
||||
error_msg = f"Unexpected error running Replicate model: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
async def run_model(self, model_ref: str, model_inputs: dict, api_key: SecretStr):
|
||||
"""
|
||||
|
||||
@@ -45,16 +45,10 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
topic = input_data.topic
|
||||
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
|
||||
|
||||
# Note: User-Agent is now automatically set by the request library
|
||||
# to comply with Wikimedia's robot policy (https://w.wiki/4wJS)
|
||||
try:
|
||||
response = await self.get_request(url, json=True)
|
||||
if "extract" not in response:
|
||||
raise ValueError(f"Unable to parse Wikipedia response: {response}")
|
||||
yield "summary", response["extract"]
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch Wikipedia summary: {e}") from e
|
||||
response = await self.get_request(url, json=True)
|
||||
if "extract" not in response:
|
||||
raise RuntimeError(f"Unable to parse Wikipedia response: {response}")
|
||||
yield "summary", response["extract"]
|
||||
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import logging
|
||||
import re
|
||||
from collections import Counter
|
||||
from concurrent.futures import Future
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data.block import (
|
||||
@@ -23,41 +20,16 @@ from backend.data.dynamic_fields import (
|
||||
is_dynamic_field,
|
||||
is_tool_pin,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util import json
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import Link, Node
|
||||
from backend.executor.manager import ExecutionProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolInfo(BaseModel):
|
||||
"""Processed tool call information."""
|
||||
|
||||
tool_call: Any # The original tool call object from LLM response
|
||||
tool_name: str # The function name
|
||||
tool_def: dict[str, Any] # The tool definition from tool_functions
|
||||
input_data: dict[str, Any] # Processed input data ready for tool execution
|
||||
field_mapping: dict[str, str] # Field name mapping for the tool
|
||||
|
||||
|
||||
class ExecutionParams(BaseModel):
|
||||
"""Tool execution parameters."""
|
||||
|
||||
user_id: str
|
||||
graph_id: str
|
||||
node_id: str
|
||||
graph_version: int
|
||||
graph_exec_id: str
|
||||
node_exec_id: str
|
||||
execution_context: "ExecutionContext"
|
||||
|
||||
|
||||
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Return a list of tool_call_ids if the entry is a tool request.
|
||||
@@ -133,50 +105,6 @@ def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
|
||||
return {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
|
||||
|
||||
def _combine_tool_responses(tool_outputs: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Combine multiple Anthropic tool responses into a single user message.
|
||||
For non-Anthropic formats, returns the original list unchanged.
|
||||
"""
|
||||
if len(tool_outputs) <= 1:
|
||||
return tool_outputs
|
||||
|
||||
# Anthropic responses have role="user", type="message", and content is a list with tool_result items
|
||||
anthropic_responses = [
|
||||
output
|
||||
for output in tool_outputs
|
||||
if (
|
||||
output.get("role") == "user"
|
||||
and output.get("type") == "message"
|
||||
and isinstance(output.get("content"), list)
|
||||
and any(
|
||||
item.get("type") == "tool_result"
|
||||
for item in output.get("content", [])
|
||||
if isinstance(item, dict)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
if len(anthropic_responses) > 1:
|
||||
combined_content = [
|
||||
item for response in anthropic_responses for item in response["content"]
|
||||
]
|
||||
|
||||
combined_response = {
|
||||
"role": "user",
|
||||
"type": "message",
|
||||
"content": combined_content,
|
||||
}
|
||||
|
||||
non_anthropic_responses = [
|
||||
output for output in tool_outputs if output not in anthropic_responses
|
||||
]
|
||||
|
||||
return [combined_response] + non_anthropic_responses
|
||||
|
||||
return tool_outputs
|
||||
|
||||
|
||||
def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
|
||||
"""
|
||||
Safely convert raw_response to dictionary format for conversation history.
|
||||
@@ -226,7 +154,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
model: llm.LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=llm.DEFAULT_LLM_MODEL,
|
||||
default=llm.LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -276,17 +204,6 @@ class SmartDecisionMakerBlock(Block):
|
||||
default="localhost:11434",
|
||||
description="Ollama host for local models",
|
||||
)
|
||||
agent_mode_max_iterations: int = SchemaField(
|
||||
title="Agent Mode Max Iterations",
|
||||
description="Maximum iterations for agent mode. 0 = traditional mode (single LLM call, yield tool calls for external execution), -1 = infinite agent mode (loop until finished), 1+ = agent mode with max iterations limit.",
|
||||
advanced=True,
|
||||
default=0,
|
||||
)
|
||||
conversation_compaction: bool = SchemaField(
|
||||
default=True,
|
||||
title="Context window auto-compaction",
|
||||
description="Automatically compact the context window once it hits the limit",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
|
||||
@@ -589,7 +506,6 @@ class SmartDecisionMakerBlock(Block):
|
||||
Returns the response if successful, raises ValueError if validation fails.
|
||||
"""
|
||||
resp = await llm.llm_call(
|
||||
compress_prompt_to_fit=input_data.conversation_compaction,
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=current_prompt,
|
||||
@@ -677,291 +593,6 @@ class SmartDecisionMakerBlock(Block):
|
||||
|
||||
return resp
|
||||
|
||||
def _process_tool_calls(
|
||||
self, response, tool_functions: list[dict[str, Any]]
|
||||
) -> list[ToolInfo]:
|
||||
"""Process tool calls and extract tool definitions, arguments, and input data.
|
||||
|
||||
Returns a list of tool info dicts with:
|
||||
- tool_call: The original tool call object
|
||||
- tool_name: The function name
|
||||
- tool_def: The tool definition from tool_functions
|
||||
- input_data: Processed input data dict (includes None values)
|
||||
- field_mapping: Field name mapping for the tool
|
||||
"""
|
||||
if not response.tool_calls:
|
||||
return []
|
||||
|
||||
processed_tools = []
|
||||
for tool_call in response.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
tool_def = next(
|
||||
(
|
||||
tool
|
||||
for tool in tool_functions
|
||||
if tool["function"]["name"] == tool_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not tool_def:
|
||||
if len(tool_functions) == 1:
|
||||
tool_def = tool_functions[0]
|
||||
else:
|
||||
continue
|
||||
|
||||
# Build input data for the tool
|
||||
input_data = {}
|
||||
field_mapping = tool_def["function"].get("_field_mapping", {})
|
||||
if "function" in tool_def and "parameters" in tool_def["function"]:
|
||||
expected_args = tool_def["function"]["parameters"].get("properties", {})
|
||||
for clean_arg_name in expected_args:
|
||||
original_field_name = field_mapping.get(
|
||||
clean_arg_name, clean_arg_name
|
||||
)
|
||||
arg_value = tool_args.get(clean_arg_name)
|
||||
# Include all expected parameters, even if None (for backward compatibility with tests)
|
||||
input_data[original_field_name] = arg_value
|
||||
|
||||
processed_tools.append(
|
||||
ToolInfo(
|
||||
tool_call=tool_call,
|
||||
tool_name=tool_name,
|
||||
tool_def=tool_def,
|
||||
input_data=input_data,
|
||||
field_mapping=field_mapping,
|
||||
)
|
||||
)
|
||||
|
||||
return processed_tools
|
||||
|
||||
def _update_conversation(
|
||||
self, prompt: list[dict], response, tool_outputs: list | None = None
|
||||
):
|
||||
"""Update conversation history with response and tool outputs."""
|
||||
# Don't add separate reasoning message with tool calls (breaks Anthropic's tool_use->tool_result pairing)
|
||||
assistant_message = _convert_raw_response_to_dict(response.raw_response)
|
||||
has_tool_calls = isinstance(assistant_message.get("content"), list) and any(
|
||||
item.get("type") == "tool_use"
|
||||
for item in assistant_message.get("content", [])
|
||||
)
|
||||
|
||||
if response.reasoning and not has_tool_calls:
|
||||
prompt.append(
|
||||
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
|
||||
)
|
||||
|
||||
prompt.append(assistant_message)
|
||||
|
||||
if tool_outputs:
|
||||
prompt.extend(tool_outputs)
|
||||
|
||||
async def _execute_single_tool_with_manager(
|
||||
self,
|
||||
tool_info: ToolInfo,
|
||||
execution_params: ExecutionParams,
|
||||
execution_processor: "ExecutionProcessor",
|
||||
) -> dict:
|
||||
"""Execute a single tool using the execution manager for proper integration."""
|
||||
# Lazy imports to avoid circular dependencies
|
||||
from backend.data.execution import NodeExecutionEntry
|
||||
|
||||
tool_call = tool_info.tool_call
|
||||
tool_def = tool_info.tool_def
|
||||
raw_input_data = tool_info.input_data
|
||||
|
||||
# Get sink node and field mapping
|
||||
sink_node_id = tool_def["function"]["_sink_node_id"]
|
||||
|
||||
# Use proper database operations for tool execution
|
||||
db_client = get_database_manager_async_client()
|
||||
|
||||
# Get target node
|
||||
target_node = await db_client.get_node(sink_node_id)
|
||||
if not target_node:
|
||||
raise ValueError(f"Target node {sink_node_id} not found")
|
||||
|
||||
# Create proper node execution using upsert_execution_input
|
||||
node_exec_result = None
|
||||
final_input_data = None
|
||||
|
||||
# Add all inputs to the execution
|
||||
if not raw_input_data:
|
||||
raise ValueError(f"Tool call has no input data: {tool_call}")
|
||||
|
||||
for input_name, input_value in raw_input_data.items():
|
||||
node_exec_result, final_input_data = await db_client.upsert_execution_input(
|
||||
node_id=sink_node_id,
|
||||
graph_exec_id=execution_params.graph_exec_id,
|
||||
input_name=input_name,
|
||||
input_data=input_value,
|
||||
)
|
||||
|
||||
assert node_exec_result is not None, "node_exec_result should not be None"
|
||||
|
||||
# Create NodeExecutionEntry for execution manager
|
||||
node_exec_entry = NodeExecutionEntry(
|
||||
user_id=execution_params.user_id,
|
||||
graph_exec_id=execution_params.graph_exec_id,
|
||||
graph_id=execution_params.graph_id,
|
||||
graph_version=execution_params.graph_version,
|
||||
node_exec_id=node_exec_result.node_exec_id,
|
||||
node_id=sink_node_id,
|
||||
block_id=target_node.block_id,
|
||||
inputs=final_input_data or {},
|
||||
execution_context=execution_params.execution_context,
|
||||
)
|
||||
|
||||
# Use the execution manager to execute the tool node
|
||||
try:
|
||||
# Get NodeExecutionProgress from the execution manager's running nodes
|
||||
node_exec_progress = execution_processor.running_node_execution[
|
||||
sink_node_id
|
||||
]
|
||||
|
||||
# Use the execution manager's own graph stats
|
||||
graph_stats_pair = (
|
||||
execution_processor.execution_stats,
|
||||
execution_processor.execution_stats_lock,
|
||||
)
|
||||
|
||||
# Create a completed future for the task tracking system
|
||||
node_exec_future = Future()
|
||||
node_exec_progress.add_task(
|
||||
node_exec_id=node_exec_result.node_exec_id,
|
||||
task=node_exec_future,
|
||||
)
|
||||
|
||||
# Execute the node directly since we're in the SmartDecisionMaker context
|
||||
node_exec_future.set_result(
|
||||
await execution_processor.on_node_execution(
|
||||
node_exec=node_exec_entry,
|
||||
node_exec_progress=node_exec_progress,
|
||||
nodes_input_masks=None,
|
||||
graph_stats_pair=graph_stats_pair,
|
||||
)
|
||||
)
|
||||
|
||||
# Get outputs from database after execution completes using database manager client
|
||||
node_outputs = await db_client.get_execution_outputs_by_node_exec_id(
|
||||
node_exec_result.node_exec_id
|
||||
)
|
||||
|
||||
# Create tool response
|
||||
tool_response_content = (
|
||||
json.dumps(node_outputs)
|
||||
if node_outputs
|
||||
else "Tool executed successfully"
|
||||
)
|
||||
return _create_tool_response(tool_call.id, tool_response_content)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution with manager failed: {e}")
|
||||
# Return error response
|
||||
return _create_tool_response(
|
||||
tool_call.id, f"Tool execution failed: {str(e)}"
|
||||
)
|
||||
|
||||
async def _execute_tools_agent_mode(
|
||||
self,
|
||||
input_data,
|
||||
credentials,
|
||||
tool_functions: list[dict[str, Any]],
|
||||
prompt: list[dict],
|
||||
graph_exec_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
execution_processor: "ExecutionProcessor",
|
||||
):
|
||||
"""Execute tools in agent mode with a loop until finished."""
|
||||
max_iterations = input_data.agent_mode_max_iterations
|
||||
iteration = 0
|
||||
|
||||
# Execution parameters for tool execution
|
||||
execution_params = ExecutionParams(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
node_id=node_id,
|
||||
graph_version=graph_version,
|
||||
graph_exec_id=graph_exec_id,
|
||||
node_exec_id=node_exec_id,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
|
||||
current_prompt = list(prompt)
|
||||
|
||||
while max_iterations < 0 or iteration < max_iterations:
|
||||
iteration += 1
|
||||
logger.debug(f"Agent mode iteration {iteration}")
|
||||
|
||||
# Prepare prompt for this iteration
|
||||
iteration_prompt = list(current_prompt)
|
||||
|
||||
# On the last iteration, add a special system message to encourage completion
|
||||
if max_iterations > 0 and iteration == max_iterations:
|
||||
last_iteration_message = {
|
||||
"role": "system",
|
||||
"content": f"{MAIN_OBJECTIVE_PREFIX}This is your last iteration ({iteration}/{max_iterations}). "
|
||||
"Try to complete the task with the information you have. If you cannot fully complete it, "
|
||||
"provide a summary of what you've accomplished and what remains to be done. "
|
||||
"Prefer finishing with a clear response rather than making additional tool calls.",
|
||||
}
|
||||
iteration_prompt.append(last_iteration_message)
|
||||
|
||||
# Get LLM response
|
||||
try:
|
||||
response = await self._attempt_llm_call_with_validation(
|
||||
credentials, input_data, iteration_prompt, tool_functions
|
||||
)
|
||||
except Exception as e:
|
||||
yield "error", f"LLM call failed in agent mode iteration {iteration}: {str(e)}"
|
||||
return
|
||||
|
||||
# Process tool calls
|
||||
processed_tools = self._process_tool_calls(response, tool_functions)
|
||||
|
||||
# If no tool calls, we're done
|
||||
if not processed_tools:
|
||||
yield "finished", response.response
|
||||
self._update_conversation(current_prompt, response)
|
||||
yield "conversations", current_prompt
|
||||
return
|
||||
|
||||
# Execute tools and collect responses
|
||||
tool_outputs = []
|
||||
for tool_info in processed_tools:
|
||||
try:
|
||||
tool_response = await self._execute_single_tool_with_manager(
|
||||
tool_info, execution_params, execution_processor
|
||||
)
|
||||
tool_outputs.append(tool_response)
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution failed: {e}")
|
||||
# Create error response for the tool
|
||||
error_response = _create_tool_response(
|
||||
tool_info.tool_call.id, f"Error: {str(e)}"
|
||||
)
|
||||
tool_outputs.append(error_response)
|
||||
|
||||
tool_outputs = _combine_tool_responses(tool_outputs)
|
||||
|
||||
self._update_conversation(current_prompt, response, tool_outputs)
|
||||
|
||||
# Yield intermediate conversation state
|
||||
yield "conversations", current_prompt
|
||||
|
||||
# If we reach max iterations, yield the current state
|
||||
if max_iterations < 0:
|
||||
yield "finished", f"Agent mode completed after {iteration} iterations"
|
||||
else:
|
||||
yield "finished", f"Agent mode completed after {max_iterations} iterations (limit reached)"
|
||||
yield "conversations", current_prompt
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
@@ -972,12 +603,8 @@ class SmartDecisionMakerBlock(Block):
|
||||
graph_exec_id: str,
|
||||
node_exec_id: str,
|
||||
user_id: str,
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
execution_processor: "ExecutionProcessor",
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
tool_functions = await self._create_tool_node_signatures(node_id)
|
||||
yield "tool_functions", json.dumps(tool_functions)
|
||||
|
||||
@@ -1021,52 +648,24 @@ class SmartDecisionMakerBlock(Block):
|
||||
input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
|
||||
|
||||
prefix = "[Main Objective Prompt]: "
|
||||
|
||||
if input_data.sys_prompt and not any(
|
||||
p["role"] == "system" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
|
||||
for p in prompt
|
||||
p["role"] == "system" and p["content"].startswith(prefix) for p in prompt
|
||||
):
|
||||
prompt.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": MAIN_OBJECTIVE_PREFIX + input_data.sys_prompt,
|
||||
}
|
||||
)
|
||||
prompt.append({"role": "system", "content": prefix + input_data.sys_prompt})
|
||||
|
||||
if input_data.prompt and not any(
|
||||
p["role"] == "user" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
|
||||
for p in prompt
|
||||
p["role"] == "user" and p["content"].startswith(prefix) for p in prompt
|
||||
):
|
||||
prompt.append(
|
||||
{"role": "user", "content": MAIN_OBJECTIVE_PREFIX + input_data.prompt}
|
||||
)
|
||||
prompt.append({"role": "user", "content": prefix + input_data.prompt})
|
||||
|
||||
# Execute tools based on the selected mode
|
||||
if input_data.agent_mode_max_iterations != 0:
|
||||
# In agent mode, execute tools directly in a loop until finished
|
||||
async for result in self._execute_tools_agent_mode(
|
||||
input_data=input_data,
|
||||
credentials=credentials,
|
||||
tool_functions=tool_functions,
|
||||
prompt=prompt,
|
||||
graph_exec_id=graph_exec_id,
|
||||
node_id=node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
execution_processor=execution_processor,
|
||||
):
|
||||
yield result
|
||||
return
|
||||
|
||||
# One-off mode: single LLM call and yield tool calls for external execution
|
||||
current_prompt = list(prompt)
|
||||
max_attempts = max(1, int(input_data.retry))
|
||||
response = None
|
||||
|
||||
last_error = None
|
||||
for _ in range(max_attempts):
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
response = await self._attempt_llm_call_with_validation(
|
||||
credentials, input_data, current_prompt, tool_functions
|
||||
|
||||
@@ -196,15 +196,6 @@ class TestXMLParserBlockSecurity:
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml=large_xml)):
|
||||
pass
|
||||
|
||||
async def test_rejects_text_outside_root(self):
|
||||
"""Ensure parser surfaces readable errors for invalid root text."""
|
||||
block = XMLParserBlock()
|
||||
invalid_xml = "<root><child>value</child></root> trailing"
|
||||
|
||||
with pytest.raises(ValueError, match="text outside the root element"):
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml=invalid_xml)):
|
||||
pass
|
||||
|
||||
|
||||
class TestStoreMediaFileSecurity:
|
||||
"""Test file storage security limits."""
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestLLMStatsTracking:
|
||||
|
||||
response = await llm.llm_call(
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
llm_model=llm.DEFAULT_LLM_MODEL,
|
||||
llm_model=llm.LlmModel.GPT4O,
|
||||
prompt=[{"role": "user", "content": "Hello"}],
|
||||
max_tokens=100,
|
||||
)
|
||||
@@ -65,7 +65,7 @@ class TestLLMStatsTracking:
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
expected_format={"key1": "desc1", "key2": "desc2"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore # type: ignore
|
||||
)
|
||||
|
||||
@@ -109,7 +109,7 @@ class TestLLMStatsTracking:
|
||||
# Run the block
|
||||
input_data = llm.AITextGeneratorBlock.Input(
|
||||
prompt="Generate text",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
@@ -170,7 +170,7 @@ class TestLLMStatsTracking:
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
expected_format={"key1": "desc1", "key2": "desc2"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2,
|
||||
)
|
||||
@@ -228,7 +228,7 @@ class TestLLMStatsTracking:
|
||||
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text=long_text,
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
max_tokens=100, # Small chunks
|
||||
chunk_overlap=10,
|
||||
@@ -299,7 +299,7 @@ class TestLLMStatsTracking:
|
||||
# Test with very short text (should only need 1 chunk + 1 final summary)
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="This is a short text.",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
max_tokens=1000, # Large enough to avoid chunking
|
||||
)
|
||||
@@ -346,7 +346,7 @@ class TestLLMStatsTracking:
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
],
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
@@ -387,7 +387,7 @@ class TestLLMStatsTracking:
|
||||
# Run the block
|
||||
input_data = llm.AIListGeneratorBlock.Input(
|
||||
focus="test items",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
max_retries=3,
|
||||
)
|
||||
@@ -469,7 +469,7 @@ class TestLLMStatsTracking:
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"result": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
@@ -513,7 +513,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
style=llm.SummaryStyle.BULLET_POINTS,
|
||||
)
|
||||
@@ -558,7 +558,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
style=llm.SummaryStyle.BULLET_POINTS,
|
||||
max_tokens=1000,
|
||||
@@ -593,7 +593,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
@@ -623,7 +623,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
max_tokens=1000,
|
||||
)
|
||||
@@ -654,7 +654,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
import logging
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.model import CreateGraph
|
||||
from backend.api.rest_api import AgentServer
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import ProviderName, User
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
@@ -21,10 +17,10 @@ async def create_graph(s: SpinTestServer, g, u: User):
|
||||
|
||||
|
||||
async def create_credentials(s: SpinTestServer, u: User):
|
||||
import backend.blocks.llm as llm_module
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
provider = ProviderName.OPENAI
|
||||
credentials = llm_module.TEST_CREDENTIALS
|
||||
credentials = llm.TEST_CREDENTIALS
|
||||
return await s.agent_server.test_create_credentials(u.id, provider, credentials)
|
||||
|
||||
|
||||
@@ -200,6 +196,8 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_tracks_llm_stats():
|
||||
"""Test that SmartDecisionMakerBlock correctly tracks LLM usage stats."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
|
||||
@@ -218,6 +216,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
}
|
||||
|
||||
# Mock the _create_tool_node_signatures method to avoid database calls
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
@@ -233,21 +232,12 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
# Create test input
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Should I continue with this task?",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
|
||||
# Execute the block
|
||||
outputs = {}
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
@@ -256,9 +246,6 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
@@ -276,6 +263,8 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_parameter_validation():
|
||||
"""Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
|
||||
@@ -322,6 +311,8 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
mock_response_with_typo.reasoning = None
|
||||
mock_response_with_typo.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
@@ -335,20 +326,11 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2, # Set retry to 2 for testing
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
# Should raise ValueError after retries due to typo'd parameter name
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
outputs = {}
|
||||
@@ -360,9 +342,6 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
@@ -389,6 +368,8 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
mock_response_missing_required.reasoning = None
|
||||
mock_response_missing_required.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
@@ -402,19 +383,10 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
# Should raise ValueError due to missing required parameter
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
outputs = {}
|
||||
@@ -426,9 +398,6 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
@@ -449,6 +418,8 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
mock_response_valid.reasoning = None
|
||||
mock_response_valid.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
@@ -462,21 +433,12 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
|
||||
# Should succeed - optional parameter missing is OK
|
||||
outputs = {}
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
@@ -485,9 +447,6 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
@@ -513,6 +472,8 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
mock_response_all_params.reasoning = None
|
||||
mock_response_all_params.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
@@ -526,21 +487,12 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
|
||||
# Should succeed with all parameters
|
||||
outputs = {}
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
@@ -549,9 +501,6 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
@@ -564,6 +513,8 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_raw_response_conversion():
|
||||
"""Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
|
||||
@@ -633,6 +584,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
)
|
||||
|
||||
# Mock llm_call to return different responses on different calls
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call", new_callable=AsyncMock
|
||||
@@ -648,22 +600,13 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2,
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
|
||||
# Should succeed after retry, demonstrating our helper function works
|
||||
outputs = {}
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
@@ -672,9 +615,6 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
@@ -710,6 +650,8 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
"I'll help you with that." # Ollama returns string
|
||||
)
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
@@ -722,20 +664,11 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Simple prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
@@ -744,9 +677,6 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
@@ -766,6 +696,8 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
"content": "Test response",
|
||||
} # Dict format
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
@@ -778,20 +710,11 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Another test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
@@ -800,260 +723,8 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
assert "finished" in outputs
|
||||
assert outputs["finished"] == "Test response"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_agent_mode():
|
||||
"""Test that agent mode executes tools directly and loops until finished."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Mock tool call that requires multiple iterations
|
||||
mock_tool_call_1 = MagicMock()
|
||||
mock_tool_call_1.id = "call_1"
|
||||
mock_tool_call_1.function.name = "search_keywords"
|
||||
mock_tool_call_1.function.arguments = (
|
||||
'{"query": "test", "max_keyword_difficulty": 50}'
|
||||
)
|
||||
|
||||
mock_response_1 = MagicMock()
|
||||
mock_response_1.response = None
|
||||
mock_response_1.tool_calls = [mock_tool_call_1]
|
||||
mock_response_1.prompt_tokens = 50
|
||||
mock_response_1.completion_tokens = 25
|
||||
mock_response_1.reasoning = "Using search tool"
|
||||
mock_response_1.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"id": "call_1", "type": "function"}],
|
||||
}
|
||||
|
||||
# Final response with no tool calls (finished)
|
||||
mock_response_2 = MagicMock()
|
||||
mock_response_2.response = "Task completed successfully"
|
||||
mock_response_2.tool_calls = []
|
||||
mock_response_2.prompt_tokens = 30
|
||||
mock_response_2.completion_tokens = 15
|
||||
mock_response_2.reasoning = None
|
||||
mock_response_2.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": "Task completed successfully",
|
||||
}
|
||||
|
||||
# Mock the LLM call to return different responses on each iteration
|
||||
llm_call_mock = AsyncMock()
|
||||
llm_call_mock.side_effect = [mock_response_1, mock_response_2]
|
||||
|
||||
# Mock tool node signatures
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_keywords",
|
||||
"_sink_node_id": "test-sink-node-id",
|
||||
"_field_mapping": {},
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"max_keyword_difficulty": {"type": "integer"},
|
||||
},
|
||||
"required": ["query", "max_keyword_difficulty"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Mock database and execution components
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block-id"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
|
||||
# Mock upsert_execution_input to return proper NodeExecutionResult and input data
|
||||
mock_node_exec_result = MagicMock()
|
||||
mock_node_exec_result.node_exec_id = "test-tool-exec-id"
|
||||
mock_input_data = {"query": "test", "max_keyword_difficulty": 50}
|
||||
mock_db_client.upsert_execution_input.return_value = (
|
||||
mock_node_exec_result,
|
||||
mock_input_data,
|
||||
)
|
||||
|
||||
# No longer need mock_execute_node since we use execution_processor.on_node_execution
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
|
||||
), patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
), patch(
|
||||
"backend.executor.manager.async_update_node_execution_status",
|
||||
new_callable=AsyncMock,
|
||||
), patch(
|
||||
"backend.integrations.creds_manager.IntegrationCredentialsManager"
|
||||
):
|
||||
|
||||
# Create a mock execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(
|
||||
safe_mode=False,
|
||||
)
|
||||
|
||||
# Create a mock execution processor for agent mode tests
|
||||
|
||||
mock_execution_processor = AsyncMock()
|
||||
# Configure the execution processor mock with required attributes
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
|
||||
# Mock the on_node_execution method to return successful stats
|
||||
mock_node_stats = MagicMock()
|
||||
mock_node_stats.error = None # No error
|
||||
mock_execution_processor.on_node_execution = AsyncMock(
|
||||
return_value=mock_node_stats
|
||||
)
|
||||
|
||||
# Mock the get_execution_outputs_by_node_exec_id method
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
|
||||
"result": {"status": "success", "data": "search completed"}
|
||||
}
|
||||
|
||||
# Test agent mode with max_iterations = 3
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Complete this task using tools",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=3, # Enable agent mode with 3 max iterations
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify agent mode behavior
|
||||
assert "tool_functions" in outputs # tool_functions is yielded in both modes
|
||||
assert "finished" in outputs
|
||||
assert outputs["finished"] == "Task completed successfully"
|
||||
assert "conversations" in outputs
|
||||
|
||||
# Verify the conversation includes tool responses
|
||||
conversations = outputs["conversations"]
|
||||
assert len(conversations) > 2 # Should have multiple conversation entries
|
||||
|
||||
# Verify LLM was called twice (once for tool call, once for finish)
|
||||
assert llm_call_mock.call_count == 2
|
||||
|
||||
# Verify tool was executed via execution processor
|
||||
assert mock_execution_processor.on_node_execution.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_traditional_mode_default():
|
||||
"""Test that default behavior (agent_mode_max_iterations=0) works as traditional mode."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Mock tool call
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.function.name = "search_keywords"
|
||||
mock_tool_call.function.arguments = (
|
||||
'{"query": "test", "max_keyword_difficulty": 50}'
|
||||
)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = None
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_keywords",
|
||||
"_sink_node_id": "test-sink-node-id",
|
||||
"_field_mapping": {},
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"max_keyword_difficulty": {"type": "integer"},
|
||||
},
|
||||
"required": ["query", "max_keyword_difficulty"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
|
||||
):
|
||||
|
||||
# Test default behavior (traditional mode)
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0, # Traditional mode
|
||||
)
|
||||
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify traditional mode behavior
|
||||
assert (
|
||||
"tool_functions" in outputs
|
||||
) # Should yield tool_functions in traditional mode
|
||||
assert (
|
||||
"tools_^_test-sink-node-id_~_query" in outputs
|
||||
) # Should yield individual tool parameters
|
||||
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
|
||||
assert "conversations" in outputs
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -308,47 +308,10 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
) as mock_llm:
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
# Mock the database manager to avoid HTTP calls during tool execution
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
|
||||
) as mock_db_manager, patch.object(
|
||||
# Mock the function signature creation
|
||||
with patch.object(
|
||||
block, "_create_tool_node_signatures", new_callable=AsyncMock
|
||||
) as mock_sig:
|
||||
# Set up the mock database manager
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_manager.return_value = mock_db_client
|
||||
|
||||
# Mock the node retrieval
|
||||
mock_target_node = Mock()
|
||||
mock_target_node.id = "test-sink-node-id"
|
||||
mock_target_node.block_id = "CreateDictionaryBlock"
|
||||
mock_target_node.block = Mock()
|
||||
mock_target_node.block.name = "Create Dictionary"
|
||||
mock_db_client.get_node.return_value = mock_target_node
|
||||
|
||||
# Mock the execution result creation
|
||||
mock_node_exec_result = Mock()
|
||||
mock_node_exec_result.node_exec_id = "mock-node-exec-id"
|
||||
mock_final_input_data = {
|
||||
"values_#_name": "Alice",
|
||||
"values_#_age": 30,
|
||||
"values_#_email": "alice@example.com",
|
||||
}
|
||||
mock_db_client.upsert_execution_input.return_value = (
|
||||
mock_node_exec_result,
|
||||
mock_final_input_data,
|
||||
)
|
||||
|
||||
# Mock the output retrieval
|
||||
mock_outputs = {
|
||||
"values_#_name": "Alice",
|
||||
"values_#_age": 30,
|
||||
"values_#_email": "alice@example.com",
|
||||
}
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = (
|
||||
mock_outputs
|
||||
)
|
||||
|
||||
mock_sig.return_value = [
|
||||
{
|
||||
"type": "function",
|
||||
@@ -373,17 +336,11 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
input_data = block.input_schema(
|
||||
prompt="Create a user dictionary",
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT,
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
agent_mode_max_iterations=0, # Use traditional mode to test output yielding
|
||||
model=llm.LlmModel.GPT4O,
|
||||
)
|
||||
|
||||
# Run the block
|
||||
outputs = {}
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
async for output_name, output_value in block.run(
|
||||
input_data,
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
@@ -392,9 +349,6 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
graph_exec_id="test_exec",
|
||||
node_exec_id="test_node_exec",
|
||||
user_id="test_user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_value
|
||||
|
||||
@@ -557,108 +511,45 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
}
|
||||
]
|
||||
|
||||
# Mock the database manager to avoid HTTP calls during tool execution
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
|
||||
) as mock_db_manager:
|
||||
# Set up the mock database manager for agent mode
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_manager.return_value = mock_db_client
|
||||
# Create input data
|
||||
from backend.blocks import llm
|
||||
|
||||
# Mock the node retrieval
|
||||
mock_target_node = Mock()
|
||||
mock_target_node.id = "test-sink-node-id"
|
||||
mock_target_node.block_id = "TestBlock"
|
||||
mock_target_node.block = Mock()
|
||||
mock_target_node.block.name = "Test Block"
|
||||
mock_db_client.get_node.return_value = mock_target_node
|
||||
input_data = block.input_schema(
|
||||
prompt="Test prompt",
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
retry=3, # Allow retries
|
||||
)
|
||||
|
||||
# Mock the execution result creation
|
||||
mock_node_exec_result = Mock()
|
||||
mock_node_exec_result.node_exec_id = "mock-node-exec-id"
|
||||
mock_final_input_data = {"correct_param": "value"}
|
||||
mock_db_client.upsert_execution_input.return_value = (
|
||||
mock_node_exec_result,
|
||||
mock_final_input_data,
|
||||
)
|
||||
# Run the block
|
||||
outputs = {}
|
||||
async for output_name, output_value in block.run(
|
||||
input_data,
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
graph_id="test_graph",
|
||||
node_id="test_node",
|
||||
graph_exec_id="test_exec",
|
||||
node_exec_id="test_node_exec",
|
||||
user_id="test_user",
|
||||
):
|
||||
outputs[output_name] = output_value
|
||||
|
||||
# Mock the output retrieval
|
||||
mock_outputs = {"correct_param": "value"}
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = (
|
||||
mock_outputs
|
||||
)
|
||||
# Verify we had 2 LLM calls (initial + retry)
|
||||
assert call_count == 2
|
||||
|
||||
# Create input data
|
||||
from backend.blocks import llm
|
||||
# Check the final conversation output
|
||||
final_conversation = outputs.get("conversations", [])
|
||||
|
||||
input_data = block.input_schema(
|
||||
prompt="Test prompt",
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT,
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
retry=3, # Allow retries
|
||||
agent_mode_max_iterations=1,
|
||||
)
|
||||
# The final conversation should NOT contain the validation error message
|
||||
error_messages = [
|
||||
msg
|
||||
for msg in final_conversation
|
||||
if msg.get("role") == "user"
|
||||
and "parameter errors" in msg.get("content", "")
|
||||
]
|
||||
assert (
|
||||
len(error_messages) == 0
|
||||
), "Validation error leaked into final conversation"
|
||||
|
||||
# Run the block
|
||||
outputs = {}
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a proper mock execution processor for agent mode
|
||||
from collections import defaultdict
|
||||
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = MagicMock()
|
||||
|
||||
# Create a mock NodeExecutionProgress for the sink node
|
||||
mock_node_exec_progress = MagicMock()
|
||||
mock_node_exec_progress.add_task = MagicMock()
|
||||
mock_node_exec_progress.pop_output = MagicMock(
|
||||
return_value=None
|
||||
) # No outputs to process
|
||||
|
||||
# Set up running_node_execution as a defaultdict that returns our mock for any key
|
||||
mock_execution_processor.running_node_execution = defaultdict(
|
||||
lambda: mock_node_exec_progress
|
||||
)
|
||||
|
||||
# Mock the on_node_execution method that gets called during tool execution
|
||||
mock_node_stats = MagicMock()
|
||||
mock_node_stats.error = None
|
||||
mock_execution_processor.on_node_execution.return_value = (
|
||||
mock_node_stats
|
||||
)
|
||||
|
||||
async for output_name, output_value in block.run(
|
||||
input_data,
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
graph_id="test_graph",
|
||||
node_id="test_node",
|
||||
graph_exec_id="test_exec",
|
||||
node_exec_id="test_node_exec",
|
||||
user_id="test_user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_value
|
||||
|
||||
# Verify we had at least 1 LLM call
|
||||
assert call_count >= 1
|
||||
|
||||
# Check the final conversation output
|
||||
final_conversation = outputs.get("conversations", [])
|
||||
|
||||
# The final conversation should NOT contain validation error messages
|
||||
# Even if retries don't happen in agent mode, we should not leak errors
|
||||
error_messages = [
|
||||
msg
|
||||
for msg in final_conversation
|
||||
if msg.get("role") == "user"
|
||||
and "parameter errors" in msg.get("content", "")
|
||||
]
|
||||
assert (
|
||||
len(error_messages) == 0
|
||||
), "Validation error leaked into final conversation"
|
||||
# The final conversation should only have the successful response
|
||||
assert final_conversation[-1]["content"] == "valid"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from gravitasml.parser import Parser
|
||||
from gravitasml.token import Token, tokenize
|
||||
from gravitasml.token import tokenize
|
||||
|
||||
from backend.data.block import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
|
||||
from backend.data.model import SchemaField
|
||||
@@ -25,38 +25,6 @@ class XMLParserBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_tokens(tokens: list[Token]) -> None:
|
||||
"""Ensure the XML has a single root element and no stray text."""
|
||||
if not tokens:
|
||||
raise ValueError("XML input is empty.")
|
||||
|
||||
depth = 0
|
||||
root_seen = False
|
||||
|
||||
for token in tokens:
|
||||
if token.type == "TAG_OPEN":
|
||||
if depth == 0 and root_seen:
|
||||
raise ValueError("XML must have a single root element.")
|
||||
depth += 1
|
||||
if depth == 1:
|
||||
root_seen = True
|
||||
elif token.type == "TAG_CLOSE":
|
||||
depth -= 1
|
||||
if depth < 0:
|
||||
raise SyntaxError("Unexpected closing tag in XML input.")
|
||||
elif token.type in {"TEXT", "ESCAPE"}:
|
||||
if depth == 0 and token.value:
|
||||
raise ValueError(
|
||||
"XML contains text outside the root element; "
|
||||
"wrap content in a single root tag."
|
||||
)
|
||||
|
||||
if depth != 0:
|
||||
raise SyntaxError("Unclosed tag detected in XML input.")
|
||||
if not root_seen:
|
||||
raise ValueError("XML must include a root element.")
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Security fix: Add size limits to prevent XML bomb attacks
|
||||
MAX_XML_SIZE = 10 * 1024 * 1024 # 10MB limit for XML input
|
||||
@@ -67,9 +35,7 @@ class XMLParserBlock(Block):
|
||||
)
|
||||
|
||||
try:
|
||||
tokens = list(tokenize(input_data.input_xml))
|
||||
self._validate_tokens(tokens)
|
||||
|
||||
tokens = tokenize(input_data.input_xml)
|
||||
parser = Parser(tokens)
|
||||
parsed_result = parser.parse()
|
||||
yield "parsed_xml", parsed_result
|
||||
|
||||
@@ -111,8 +111,6 @@ class TranscribeYoutubeVideoBlock(Block):
|
||||
return parsed_url.path.split("/")[2]
|
||||
if parsed_url.path[:3] == "/v/":
|
||||
return parsed_url.path.split("/")[2]
|
||||
if parsed_url.path.startswith("/shorts/"):
|
||||
return parsed_url.path.split("/")[2]
|
||||
raise ValueError(f"Invalid YouTube URL: {url}")
|
||||
|
||||
def get_transcript(
|
||||
|
||||
@@ -244,7 +244,11 @@ def websocket(server_address: str, graph_exec_id: str):
|
||||
|
||||
import websockets.asyncio.client
|
||||
|
||||
from backend.api.ws_api import WSMessage, WSMethod, WSSubscribeGraphExecutionRequest
|
||||
from backend.server.ws_api import (
|
||||
WSMessage,
|
||||
WSMethod,
|
||||
WSSubscribeGraphExecutionRequest,
|
||||
)
|
||||
|
||||
async def send_message(server_address: str):
|
||||
uri = f"ws://{server_address}"
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""CLI utilities for backend development & administration"""
|
||||
@@ -1,57 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to generate OpenAPI JSON specification for the FastAPI app.
|
||||
|
||||
This script imports the FastAPI app from backend.api.rest_api and outputs
|
||||
the OpenAPI specification as JSON to stdout or a specified file.
|
||||
|
||||
Usage:
|
||||
`poetry run python generate_openapi_json.py`
|
||||
`poetry run python generate_openapi_json.py --output openapi.json`
|
||||
`poetry run python generate_openapi_json.py --indent 4 --output openapi.json`
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--output",
|
||||
type=click.Path(dir_okay=False, path_type=Path),
|
||||
help="Output file path (default: stdout)",
|
||||
)
|
||||
@click.option(
|
||||
"--pretty",
|
||||
type=click.BOOL,
|
||||
default=False,
|
||||
help="Pretty-print JSON output (indented 2 spaces)",
|
||||
)
|
||||
def main(output: Path, pretty: bool):
|
||||
"""Generate and output the OpenAPI JSON specification."""
|
||||
openapi_schema = get_openapi_schema()
|
||||
|
||||
json_output = json.dumps(openapi_schema, indent=2 if pretty else None)
|
||||
|
||||
if output:
|
||||
output.write_text(json_output)
|
||||
click.echo(f"✅ OpenAPI specification written to {output}\n\nPreview:")
|
||||
click.echo(f"\n{json_output[:500]} ...")
|
||||
else:
|
||||
print(json_output)
|
||||
|
||||
|
||||
def get_openapi_schema():
|
||||
"""Get the OpenAPI schema from the FastAPI app"""
|
||||
from backend.api.rest_api import app
|
||||
|
||||
return app.openapi()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ["LOG_LEVEL"] = "ERROR" # disable stdout log output
|
||||
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
from backend.api.features.library.model import LibraryAgentPreset
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
from .graph import NodeModel
|
||||
from .integrations import Webhook # noqa: F401
|
||||
|
||||
@@ -1,24 +1,22 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission, APIKeyStatus
|
||||
from prisma.models import APIKey as PrismaAPIKey
|
||||
from prisma.types import APIKeyWhereUniqueInput
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.includes import MAX_USER_API_KEYS_FETCH
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
|
||||
from .base import APIAuthorizationInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
keysmith = APIKeySmith()
|
||||
|
||||
|
||||
class APIKeyInfo(APIAuthorizationInfo):
|
||||
class APIKeyInfo(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
head: str = Field(
|
||||
@@ -28,9 +26,12 @@ class APIKeyInfo(APIAuthorizationInfo):
|
||||
description=f"The last {APIKeySmith.TAIL_LENGTH} characters of the key"
|
||||
)
|
||||
status: APIKeyStatus
|
||||
permissions: list[APIKeyPermission]
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime] = None
|
||||
revoked_at: Optional[datetime] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
type: Literal["api_key"] = "api_key" # type: ignore
|
||||
user_id: str
|
||||
|
||||
@staticmethod
|
||||
def from_db(api_key: PrismaAPIKey):
|
||||
@@ -40,7 +41,7 @@ class APIKeyInfo(APIAuthorizationInfo):
|
||||
head=api_key.head,
|
||||
tail=api_key.tail,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
scopes=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
@@ -210,7 +211,7 @@ async def suspend_api_key(key_id: str, user_id: str) -> APIKeyInfo:
|
||||
|
||||
|
||||
def has_permission(api_key: APIKeyInfo, required_permission: APIKeyPermission) -> bool:
|
||||
return required_permission in api_key.scopes
|
||||
return required_permission in api_key.permissions
|
||||
|
||||
|
||||
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyInfo]:
|
||||
@@ -1,15 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional
|
||||
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class APIAuthorizationInfo(BaseModel):
|
||||
user_id: str
|
||||
scopes: list[APIKeyPermission]
|
||||
type: Literal["oauth", "api_key"]
|
||||
created_at: datetime
|
||||
expires_at: Optional[datetime] = None
|
||||
last_used_at: Optional[datetime] = None
|
||||
revoked_at: Optional[datetime] = None
|
||||
@@ -1,872 +0,0 @@
|
||||
"""
|
||||
OAuth 2.0 Provider Data Layer
|
||||
|
||||
Handles management of OAuth applications, authorization codes,
|
||||
access tokens, and refresh tokens.
|
||||
|
||||
Hashing strategy:
|
||||
- Access tokens & Refresh tokens: SHA256 (deterministic, allows direct lookup by hash)
|
||||
- Client secrets: Scrypt with salt (lookup by client_id, then verify with salt)
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Literal, Optional
|
||||
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission as APIPermission
|
||||
from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
|
||||
from prisma.models import OAuthApplication as PrismaOAuthApplication
|
||||
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
|
||||
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
|
||||
from prisma.types import OAuthApplicationUpdateInput
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from .base import APIAuthorizationInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
keysmith = APIKeySmith() # Only used for client secret hashing (Scrypt)
|
||||
|
||||
|
||||
def _generate_token() -> str:
|
||||
"""Generate a cryptographically secure random token."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def _hash_token(token: str) -> str:
|
||||
"""Hash a token using SHA256 (deterministic, for direct lookup)."""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
# Token TTLs
|
||||
AUTHORIZATION_CODE_TTL = timedelta(minutes=10)
|
||||
ACCESS_TOKEN_TTL = timedelta(hours=1)
|
||||
REFRESH_TOKEN_TTL = timedelta(days=30)
|
||||
|
||||
ACCESS_TOKEN_PREFIX = "agpt_xt_"
|
||||
REFRESH_TOKEN_PREFIX = "agpt_rt_"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Exception Classes
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthError(Exception):
|
||||
"""Base OAuth error"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidClientError(OAuthError):
|
||||
"""Invalid client_id or client_secret"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidGrantError(OAuthError):
|
||||
"""Invalid or expired authorization code/refresh token"""
|
||||
|
||||
def __init__(self, reason: str):
|
||||
self.reason = reason
|
||||
super().__init__(f"Invalid grant: {reason}")
|
||||
|
||||
|
||||
class InvalidTokenError(OAuthError):
|
||||
"""Invalid, expired, or revoked token"""
|
||||
|
||||
def __init__(self, reason: str):
|
||||
self.reason = reason
|
||||
super().__init__(f"Invalid token: {reason}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Data Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthApplicationInfo(BaseModel):
|
||||
"""OAuth application information (without client secret hash)"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
client_id: str
|
||||
redirect_uris: list[str]
|
||||
grant_types: list[str]
|
||||
scopes: list[APIPermission]
|
||||
owner_id: str
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@staticmethod
|
||||
def from_db(app: PrismaOAuthApplication):
|
||||
return OAuthApplicationInfo(
|
||||
id=app.id,
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
logo_url=app.logoUrl,
|
||||
client_id=app.clientId,
|
||||
redirect_uris=app.redirectUris,
|
||||
grant_types=app.grantTypes,
|
||||
scopes=[APIPermission(s) for s in app.scopes],
|
||||
owner_id=app.ownerId,
|
||||
is_active=app.isActive,
|
||||
created_at=app.createdAt,
|
||||
updated_at=app.updatedAt,
|
||||
)
|
||||
|
||||
|
||||
class OAuthApplicationInfoWithSecret(OAuthApplicationInfo):
|
||||
"""OAuth application with client secret hash (for validation)"""
|
||||
|
||||
client_secret_hash: str
|
||||
client_secret_salt: str
|
||||
|
||||
@staticmethod
|
||||
def from_db(app: PrismaOAuthApplication):
|
||||
return OAuthApplicationInfoWithSecret(
|
||||
**OAuthApplicationInfo.from_db(app).model_dump(),
|
||||
client_secret_hash=app.clientSecret,
|
||||
client_secret_salt=app.clientSecretSalt,
|
||||
)
|
||||
|
||||
def verify_secret(self, plaintext_secret: str) -> bool:
|
||||
"""Verify a plaintext client secret against the stored hash"""
|
||||
# Use keysmith.verify_key() with stored salt
|
||||
return keysmith.verify_key(
|
||||
plaintext_secret, self.client_secret_hash, self.client_secret_salt
|
||||
)
|
||||
|
||||
|
||||
class OAuthAuthorizationCodeInfo(BaseModel):
|
||||
"""Authorization code information"""
|
||||
|
||||
id: str
|
||||
code: str
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
application_id: str
|
||||
user_id: str
|
||||
scopes: list[APIPermission]
|
||||
redirect_uri: str
|
||||
code_challenge: Optional[str] = None
|
||||
code_challenge_method: Optional[str] = None
|
||||
used_at: Optional[datetime] = None
|
||||
|
||||
@property
|
||||
def is_used(self) -> bool:
|
||||
return self.used_at is not None
|
||||
|
||||
@staticmethod
|
||||
def from_db(code: PrismaOAuthAuthorizationCode):
|
||||
return OAuthAuthorizationCodeInfo(
|
||||
id=code.id,
|
||||
code=code.code,
|
||||
created_at=code.createdAt,
|
||||
expires_at=code.expiresAt,
|
||||
application_id=code.applicationId,
|
||||
user_id=code.userId,
|
||||
scopes=[APIPermission(s) for s in code.scopes],
|
||||
redirect_uri=code.redirectUri,
|
||||
code_challenge=code.codeChallenge,
|
||||
code_challenge_method=code.codeChallengeMethod,
|
||||
used_at=code.usedAt,
|
||||
)
|
||||
|
||||
|
||||
class OAuthAccessTokenInfo(APIAuthorizationInfo):
|
||||
"""Access token information"""
|
||||
|
||||
id: str
|
||||
expires_at: datetime # type: ignore
|
||||
application_id: str
|
||||
|
||||
type: Literal["oauth"] = "oauth" # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def from_db(token: PrismaOAuthAccessToken):
|
||||
return OAuthAccessTokenInfo(
|
||||
id=token.id,
|
||||
user_id=token.userId,
|
||||
scopes=[APIPermission(s) for s in token.scopes],
|
||||
created_at=token.createdAt,
|
||||
expires_at=token.expiresAt,
|
||||
last_used_at=None,
|
||||
revoked_at=token.revokedAt,
|
||||
application_id=token.applicationId,
|
||||
)
|
||||
|
||||
|
||||
class OAuthAccessToken(OAuthAccessTokenInfo):
|
||||
"""Access token with plaintext token included (sensitive)"""
|
||||
|
||||
token: SecretStr = Field(description="Plaintext token (sensitive)")
|
||||
|
||||
@staticmethod
|
||||
def from_db(token: PrismaOAuthAccessToken, plaintext_token: str): # type: ignore
|
||||
return OAuthAccessToken(
|
||||
**OAuthAccessTokenInfo.from_db(token).model_dump(),
|
||||
token=SecretStr(plaintext_token),
|
||||
)
|
||||
|
||||
|
||||
class OAuthRefreshTokenInfo(BaseModel):
|
||||
"""Refresh token information"""
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
scopes: list[APIPermission]
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
application_id: str
|
||||
revoked_at: Optional[datetime] = None
|
||||
|
||||
@property
|
||||
def is_revoked(self) -> bool:
|
||||
return self.revoked_at is not None
|
||||
|
||||
@staticmethod
|
||||
def from_db(token: PrismaOAuthRefreshToken):
|
||||
return OAuthRefreshTokenInfo(
|
||||
id=token.id,
|
||||
user_id=token.userId,
|
||||
scopes=[APIPermission(s) for s in token.scopes],
|
||||
created_at=token.createdAt,
|
||||
expires_at=token.expiresAt,
|
||||
application_id=token.applicationId,
|
||||
revoked_at=token.revokedAt,
|
||||
)
|
||||
|
||||
|
||||
class OAuthRefreshToken(OAuthRefreshTokenInfo):
|
||||
"""Refresh token with plaintext token included (sensitive)"""
|
||||
|
||||
token: SecretStr = Field(description="Plaintext token (sensitive)")
|
||||
|
||||
@staticmethod
|
||||
def from_db(token: PrismaOAuthRefreshToken, plaintext_token: str): # type: ignore
|
||||
return OAuthRefreshToken(
|
||||
**OAuthRefreshTokenInfo.from_db(token).model_dump(),
|
||||
token=SecretStr(plaintext_token),
|
||||
)
|
||||
|
||||
|
||||
class TokenIntrospectionResult(BaseModel):
|
||||
"""Result of token introspection (RFC 7662)"""
|
||||
|
||||
active: bool
|
||||
scopes: Optional[list[str]] = None
|
||||
client_id: Optional[str] = None
|
||||
user_id: Optional[str] = None
|
||||
exp: Optional[int] = None # Unix timestamp
|
||||
token_type: Optional[Literal["access_token", "refresh_token"]] = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Application Management
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def get_oauth_application(client_id: str) -> Optional[OAuthApplicationInfo]:
|
||||
"""Get OAuth application by client ID (without secret)"""
|
||||
app = await PrismaOAuthApplication.prisma().find_unique(
|
||||
where={"clientId": client_id}
|
||||
)
|
||||
if not app:
|
||||
return None
|
||||
return OAuthApplicationInfo.from_db(app)
|
||||
|
||||
|
||||
async def get_oauth_application_with_secret(
|
||||
client_id: str,
|
||||
) -> Optional[OAuthApplicationInfoWithSecret]:
|
||||
"""Get OAuth application by client ID (with secret hash for validation)"""
|
||||
app = await PrismaOAuthApplication.prisma().find_unique(
|
||||
where={"clientId": client_id}
|
||||
)
|
||||
if not app:
|
||||
return None
|
||||
return OAuthApplicationInfoWithSecret.from_db(app)
|
||||
|
||||
|
||||
async def validate_client_credentials(
|
||||
client_id: str, client_secret: str
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Validate client credentials and return application info.
|
||||
|
||||
Raises:
|
||||
InvalidClientError: If client_id or client_secret is invalid, or app is inactive
|
||||
"""
|
||||
app = await get_oauth_application_with_secret(client_id)
|
||||
if not app:
|
||||
raise InvalidClientError("Invalid client_id")
|
||||
|
||||
if not app.is_active:
|
||||
raise InvalidClientError("Application is not active")
|
||||
|
||||
# Verify client secret
|
||||
if not app.verify_secret(client_secret):
|
||||
raise InvalidClientError("Invalid client_secret")
|
||||
|
||||
# Return without secret hash
|
||||
return OAuthApplicationInfo(**app.model_dump(exclude={"client_secret_hash"}))
|
||||
|
||||
|
||||
def validate_redirect_uri(app: OAuthApplicationInfo, redirect_uri: str) -> bool:
|
||||
"""Validate that redirect URI is registered for the application"""
|
||||
return redirect_uri in app.redirect_uris
|
||||
|
||||
|
||||
def validate_scopes(
|
||||
app: OAuthApplicationInfo, requested_scopes: list[APIPermission]
|
||||
) -> bool:
|
||||
"""Validate that all requested scopes are allowed for the application"""
|
||||
return all(scope in app.scopes for scope in requested_scopes)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authorization Code Flow
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _generate_authorization_code() -> str:
|
||||
"""Generate a cryptographically secure authorization code"""
|
||||
# 32 bytes = 256 bits of entropy
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
async def create_authorization_code(
|
||||
application_id: str,
|
||||
user_id: str,
|
||||
scopes: list[APIPermission],
|
||||
redirect_uri: str,
|
||||
code_challenge: Optional[str] = None,
|
||||
code_challenge_method: Optional[Literal["S256", "plain"]] = None,
|
||||
) -> OAuthAuthorizationCodeInfo:
|
||||
"""
|
||||
Create a new authorization code.
|
||||
Expires in 10 minutes and can only be used once.
|
||||
"""
|
||||
code = _generate_authorization_code()
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + AUTHORIZATION_CODE_TTL
|
||||
|
||||
saved_code = await PrismaOAuthAuthorizationCode.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"code": code,
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
"redirectUri": redirect_uri,
|
||||
"codeChallenge": code_challenge,
|
||||
"codeChallengeMethod": code_challenge_method,
|
||||
}
|
||||
)
|
||||
|
||||
return OAuthAuthorizationCodeInfo.from_db(saved_code)
|
||||
|
||||
|
||||
async def consume_authorization_code(
|
||||
code: str,
|
||||
application_id: str,
|
||||
redirect_uri: str,
|
||||
code_verifier: Optional[str] = None,
|
||||
) -> tuple[str, list[APIPermission]]:
|
||||
"""
|
||||
Consume an authorization code and return (user_id, scopes).
|
||||
|
||||
This marks the code as used and validates:
|
||||
- Code exists and matches application
|
||||
- Code is not expired
|
||||
- Code has not been used
|
||||
- Redirect URI matches
|
||||
- PKCE code verifier matches (if code challenge was provided)
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: If code is invalid, expired, used, or PKCE fails
|
||||
"""
|
||||
auth_code = await PrismaOAuthAuthorizationCode.prisma().find_unique(
|
||||
where={"code": code}
|
||||
)
|
||||
|
||||
if not auth_code:
|
||||
raise InvalidGrantError("authorization code not found")
|
||||
|
||||
# Validate application
|
||||
if auth_code.applicationId != application_id:
|
||||
raise InvalidGrantError(
|
||||
"authorization code does not belong to this application"
|
||||
)
|
||||
|
||||
# Check if already used
|
||||
if auth_code.usedAt is not None:
|
||||
raise InvalidGrantError(
|
||||
f"authorization code already used at {auth_code.usedAt}"
|
||||
)
|
||||
|
||||
# Check expiration
|
||||
now = datetime.now(timezone.utc)
|
||||
if auth_code.expiresAt < now:
|
||||
raise InvalidGrantError("authorization code expired")
|
||||
|
||||
# Validate redirect URI
|
||||
if auth_code.redirectUri != redirect_uri:
|
||||
raise InvalidGrantError("redirect_uri mismatch")
|
||||
|
||||
# Validate PKCE if code challenge was provided
|
||||
if auth_code.codeChallenge:
|
||||
if not code_verifier:
|
||||
raise InvalidGrantError("code_verifier required but not provided")
|
||||
|
||||
if not _verify_pkce(
|
||||
code_verifier, auth_code.codeChallenge, auth_code.codeChallengeMethod
|
||||
):
|
||||
raise InvalidGrantError("PKCE verification failed")
|
||||
|
||||
# Mark code as used
|
||||
await PrismaOAuthAuthorizationCode.prisma().update(
|
||||
where={"code": code},
|
||||
data={"usedAt": now},
|
||||
)
|
||||
|
||||
return auth_code.userId, [APIPermission(s) for s in auth_code.scopes]
|
||||
|
||||
|
||||
def _verify_pkce(
|
||||
code_verifier: str, code_challenge: str, code_challenge_method: Optional[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Verify PKCE code verifier against code challenge.
|
||||
|
||||
Supports:
|
||||
- S256: SHA256(code_verifier) == code_challenge
|
||||
- plain: code_verifier == code_challenge
|
||||
"""
|
||||
if code_challenge_method == "S256":
|
||||
# Hash the verifier with SHA256 and base64url encode
|
||||
hashed = hashlib.sha256(code_verifier.encode("ascii")).digest()
|
||||
computed_challenge = (
|
||||
secrets.token_urlsafe(len(hashed)).encode("ascii").decode("ascii")
|
||||
)
|
||||
# For proper base64url encoding
|
||||
import base64
|
||||
|
||||
computed_challenge = (
|
||||
base64.urlsafe_b64encode(hashed).decode("ascii").rstrip("=")
|
||||
)
|
||||
return secrets.compare_digest(computed_challenge, code_challenge)
|
||||
elif code_challenge_method == "plain" or code_challenge_method is None:
|
||||
# Plain comparison
|
||||
return secrets.compare_digest(code_verifier, code_challenge)
|
||||
else:
|
||||
logger.warning(f"Unsupported code challenge method: {code_challenge_method}")
|
||||
return False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Access Token Management
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def create_access_token(
|
||||
application_id: str, user_id: str, scopes: list[APIPermission]
|
||||
) -> OAuthAccessToken:
|
||||
"""
|
||||
Create a new access token.
|
||||
Returns OAuthAccessToken (with plaintext token).
|
||||
"""
|
||||
plaintext_token = ACCESS_TOKEN_PREFIX + _generate_token()
|
||||
token_hash = _hash_token(plaintext_token)
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + ACCESS_TOKEN_TTL
|
||||
|
||||
saved_token = await PrismaOAuthAccessToken.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": token_hash, # SHA256 hash for direct lookup
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
}
|
||||
)
|
||||
|
||||
return OAuthAccessToken.from_db(saved_token, plaintext_token=plaintext_token)
|
||||
|
||||
|
||||
async def validate_access_token(
|
||||
token: str,
|
||||
) -> tuple[OAuthAccessTokenInfo, OAuthApplicationInfo]:
|
||||
"""
|
||||
Validate an access token and return token info.
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If token is invalid, expired, or revoked
|
||||
InvalidClientError: If the client application is not marked as active
|
||||
"""
|
||||
token_hash = _hash_token(token)
|
||||
|
||||
# Direct lookup by hash
|
||||
access_token = await PrismaOAuthAccessToken.prisma().find_unique(
|
||||
where={"token": token_hash}, include={"Application": True}
|
||||
)
|
||||
|
||||
if not access_token:
|
||||
raise InvalidTokenError("access token not found")
|
||||
|
||||
if not access_token.Application: # should be impossible
|
||||
raise InvalidClientError("Client application not found")
|
||||
|
||||
if not access_token.Application.isActive:
|
||||
raise InvalidClientError("Client application is disabled")
|
||||
|
||||
if access_token.revokedAt is not None:
|
||||
raise InvalidTokenError("access token has been revoked")
|
||||
|
||||
# Check expiration
|
||||
now = datetime.now(timezone.utc)
|
||||
if access_token.expiresAt < now:
|
||||
raise InvalidTokenError("access token expired")
|
||||
|
||||
return (
|
||||
OAuthAccessTokenInfo.from_db(access_token),
|
||||
OAuthApplicationInfo.from_db(access_token.Application),
|
||||
)
|
||||
|
||||
|
||||
async def revoke_access_token(
|
||||
token: str, application_id: str
|
||||
) -> OAuthAccessTokenInfo | None:
|
||||
"""
|
||||
Revoke an access token.
|
||||
|
||||
Args:
|
||||
token: The plaintext access token to revoke
|
||||
application_id: The application ID making the revocation request.
|
||||
Only tokens belonging to this application will be revoked.
|
||||
|
||||
Returns:
|
||||
OAuthAccessTokenInfo if token was found and revoked, None otherwise.
|
||||
|
||||
Note:
|
||||
Always performs exactly 2 DB queries regardless of outcome to prevent
|
||||
timing side-channel attacks that could reveal token existence.
|
||||
"""
|
||||
try:
|
||||
token_hash = _hash_token(token)
|
||||
|
||||
# Use update_many to filter by both token and applicationId
|
||||
updated_count = await PrismaOAuthAccessToken.prisma().update_many(
|
||||
where={
|
||||
"token": token_hash,
|
||||
"applicationId": application_id,
|
||||
"revokedAt": None,
|
||||
},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
# Always perform second query to ensure constant time
|
||||
result = await PrismaOAuthAccessToken.prisma().find_unique(
|
||||
where={"token": token_hash}
|
||||
)
|
||||
|
||||
# Only return result if we actually revoked something
|
||||
if updated_count == 0:
|
||||
return None
|
||||
|
||||
return OAuthAccessTokenInfo.from_db(result) if result else None
|
||||
except Exception as e:
|
||||
logger.exception(f"Error revoking access token: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Refresh Token Management
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def create_refresh_token(
|
||||
application_id: str, user_id: str, scopes: list[APIPermission]
|
||||
) -> OAuthRefreshToken:
|
||||
"""
|
||||
Create a new refresh token.
|
||||
Returns OAuthRefreshToken (with plaintext token).
|
||||
"""
|
||||
plaintext_token = REFRESH_TOKEN_PREFIX + _generate_token()
|
||||
token_hash = _hash_token(plaintext_token)
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + REFRESH_TOKEN_TTL
|
||||
|
||||
saved_token = await PrismaOAuthRefreshToken.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": token_hash, # SHA256 hash for direct lookup
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
}
|
||||
)
|
||||
|
||||
return OAuthRefreshToken.from_db(saved_token, plaintext_token=plaintext_token)
|
||||
|
||||
|
||||
async def refresh_tokens(
|
||||
refresh_token: str, application_id: str
|
||||
) -> tuple[OAuthAccessToken, OAuthRefreshToken]:
|
||||
"""
|
||||
Use a refresh token to create new access and refresh tokens.
|
||||
Returns (new_access_token, new_refresh_token) both with plaintext tokens included.
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: If refresh token is invalid, expired, or revoked
|
||||
"""
|
||||
token_hash = _hash_token(refresh_token)
|
||||
|
||||
# Direct lookup by hash
|
||||
rt = await PrismaOAuthRefreshToken.prisma().find_unique(where={"token": token_hash})
|
||||
|
||||
if not rt:
|
||||
raise InvalidGrantError("refresh token not found")
|
||||
|
||||
# NOTE: no need to check Application.isActive, this is checked by the token endpoint
|
||||
|
||||
if rt.revokedAt is not None:
|
||||
raise InvalidGrantError("refresh token has been revoked")
|
||||
|
||||
# Validate application
|
||||
if rt.applicationId != application_id:
|
||||
raise InvalidGrantError("refresh token does not belong to this application")
|
||||
|
||||
# Check expiration
|
||||
now = datetime.now(timezone.utc)
|
||||
if rt.expiresAt < now:
|
||||
raise InvalidGrantError("refresh token expired")
|
||||
|
||||
# Revoke old refresh token
|
||||
await PrismaOAuthRefreshToken.prisma().update(
|
||||
where={"token": token_hash},
|
||||
data={"revokedAt": now},
|
||||
)
|
||||
|
||||
# Create new access and refresh tokens with same scopes
|
||||
scopes = [APIPermission(s) for s in rt.scopes]
|
||||
new_access_token = await create_access_token(
|
||||
rt.applicationId,
|
||||
rt.userId,
|
||||
scopes,
|
||||
)
|
||||
new_refresh_token = await create_refresh_token(
|
||||
rt.applicationId,
|
||||
rt.userId,
|
||||
scopes,
|
||||
)
|
||||
|
||||
return new_access_token, new_refresh_token
|
||||
|
||||
|
||||
async def revoke_refresh_token(
|
||||
token: str, application_id: str
|
||||
) -> OAuthRefreshTokenInfo | None:
|
||||
"""
|
||||
Revoke a refresh token.
|
||||
|
||||
Args:
|
||||
token: The plaintext refresh token to revoke
|
||||
application_id: The application ID making the revocation request.
|
||||
Only tokens belonging to this application will be revoked.
|
||||
|
||||
Returns:
|
||||
OAuthRefreshTokenInfo if token was found and revoked, None otherwise.
|
||||
|
||||
Note:
|
||||
Always performs exactly 2 DB queries regardless of outcome to prevent
|
||||
timing side-channel attacks that could reveal token existence.
|
||||
"""
|
||||
try:
|
||||
token_hash = _hash_token(token)
|
||||
|
||||
# Use update_many to filter by both token and applicationId
|
||||
updated_count = await PrismaOAuthRefreshToken.prisma().update_many(
|
||||
where={
|
||||
"token": token_hash,
|
||||
"applicationId": application_id,
|
||||
"revokedAt": None,
|
||||
},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
# Always perform second query to ensure constant time
|
||||
result = await PrismaOAuthRefreshToken.prisma().find_unique(
|
||||
where={"token": token_hash}
|
||||
)
|
||||
|
||||
# Only return result if we actually revoked something
|
||||
if updated_count == 0:
|
||||
return None
|
||||
|
||||
return OAuthRefreshTokenInfo.from_db(result) if result else None
|
||||
except Exception as e:
|
||||
logger.exception(f"Error revoking refresh token: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Introspection
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def introspect_token(
|
||||
token: str,
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = None,
|
||||
) -> TokenIntrospectionResult:
|
||||
"""
|
||||
Introspect a token and return its metadata (RFC 7662).
|
||||
|
||||
Returns TokenIntrospectionResult with active=True and metadata if valid,
|
||||
or active=False if the token is invalid/expired/revoked.
|
||||
"""
|
||||
# Try as access token first (or if hint says "access_token")
|
||||
if token_type_hint != "refresh_token":
|
||||
try:
|
||||
token_info, app = await validate_access_token(token)
|
||||
return TokenIntrospectionResult(
|
||||
active=True,
|
||||
scopes=list(s.value for s in token_info.scopes),
|
||||
client_id=app.client_id if app else None,
|
||||
user_id=token_info.user_id,
|
||||
exp=int(token_info.expires_at.timestamp()),
|
||||
token_type="access_token",
|
||||
)
|
||||
except InvalidTokenError:
|
||||
pass # Try as refresh token
|
||||
|
||||
# Try as refresh token
|
||||
token_hash = _hash_token(token)
|
||||
refresh_token = await PrismaOAuthRefreshToken.prisma().find_unique(
|
||||
where={"token": token_hash}
|
||||
)
|
||||
|
||||
if refresh_token and refresh_token.revokedAt is None:
|
||||
# Check if valid (not expired)
|
||||
now = datetime.now(timezone.utc)
|
||||
if refresh_token.expiresAt > now:
|
||||
app = await get_oauth_application_by_id(refresh_token.applicationId)
|
||||
return TokenIntrospectionResult(
|
||||
active=True,
|
||||
scopes=list(s for s in refresh_token.scopes),
|
||||
client_id=app.client_id if app else None,
|
||||
user_id=refresh_token.userId,
|
||||
exp=int(refresh_token.expiresAt.timestamp()),
|
||||
token_type="refresh_token",
|
||||
)
|
||||
|
||||
# Token not found or inactive
|
||||
return TokenIntrospectionResult(active=False)
|
||||
|
||||
|
||||
async def get_oauth_application_by_id(app_id: str) -> Optional[OAuthApplicationInfo]:
|
||||
"""Get OAuth application by ID"""
|
||||
app = await PrismaOAuthApplication.prisma().find_unique(where={"id": app_id})
|
||||
if not app:
|
||||
return None
|
||||
return OAuthApplicationInfo.from_db(app)
|
||||
|
||||
|
||||
async def list_user_oauth_applications(user_id: str) -> list[OAuthApplicationInfo]:
|
||||
"""Get all OAuth applications owned by a user"""
|
||||
apps = await PrismaOAuthApplication.prisma().find_many(
|
||||
where={"ownerId": user_id},
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
return [OAuthApplicationInfo.from_db(app) for app in apps]
|
||||
|
||||
|
||||
async def update_oauth_application(
|
||||
app_id: str,
|
||||
*,
|
||||
owner_id: str,
|
||||
is_active: Optional[bool] = None,
|
||||
logo_url: Optional[str] = None,
|
||||
) -> Optional[OAuthApplicationInfo]:
|
||||
"""
|
||||
Update OAuth application active status.
|
||||
Only the owner can update their app's status.
|
||||
|
||||
Returns the updated app info, or None if app not found or not owned by user.
|
||||
"""
|
||||
# First verify ownership
|
||||
app = await PrismaOAuthApplication.prisma().find_first(
|
||||
where={"id": app_id, "ownerId": owner_id}
|
||||
)
|
||||
if not app:
|
||||
return None
|
||||
|
||||
patch: OAuthApplicationUpdateInput = {}
|
||||
if is_active is not None:
|
||||
patch["isActive"] = is_active
|
||||
if logo_url:
|
||||
patch["logoUrl"] = logo_url
|
||||
if not patch:
|
||||
return OAuthApplicationInfo.from_db(app) # return unchanged
|
||||
|
||||
updated_app = await PrismaOAuthApplication.prisma().update(
|
||||
where={"id": app_id},
|
||||
data=patch,
|
||||
)
|
||||
return OAuthApplicationInfo.from_db(updated_app) if updated_app else None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Cleanup
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def cleanup_expired_oauth_tokens() -> dict[str, int]:
|
||||
"""
|
||||
Delete expired OAuth tokens from the database.
|
||||
|
||||
This removes:
|
||||
- Expired authorization codes (10 min TTL)
|
||||
- Expired access tokens (1 hour TTL)
|
||||
- Expired refresh tokens (30 day TTL)
|
||||
|
||||
Returns a dict with counts of deleted tokens by type.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Delete expired authorization codes
|
||||
codes_result = await PrismaOAuthAuthorizationCode.prisma().delete_many(
|
||||
where={"expiresAt": {"lt": now}}
|
||||
)
|
||||
|
||||
# Delete expired access tokens
|
||||
access_result = await PrismaOAuthAccessToken.prisma().delete_many(
|
||||
where={"expiresAt": {"lt": now}}
|
||||
)
|
||||
|
||||
# Delete expired refresh tokens
|
||||
refresh_result = await PrismaOAuthRefreshToken.prisma().delete_many(
|
||||
where={"expiresAt": {"lt": now}}
|
||||
)
|
||||
|
||||
deleted = {
|
||||
"authorization_codes": codes_result,
|
||||
"access_tokens": access_result,
|
||||
"refresh_tokens": refresh_result,
|
||||
}
|
||||
|
||||
total = sum(deleted.values())
|
||||
if total > 0:
|
||||
logger.info(f"Cleaned up {total} expired OAuth tokens: {deleted}")
|
||||
|
||||
return deleted
|
||||
@@ -59,13 +59,12 @@ from backend.integrations.credentials_store import (
|
||||
|
||||
MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.O3: 4,
|
||||
LlmModel.O3_MINI: 2,
|
||||
LlmModel.O1: 16,
|
||||
LlmModel.O3_MINI: 2, # $1.10 / $4.40
|
||||
LlmModel.O1: 16, # $15 / $60
|
||||
LlmModel.O1_MINI: 4,
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5_2: 6,
|
||||
LlmModel.GPT5_1: 5,
|
||||
LlmModel.GPT5: 2,
|
||||
LlmModel.GPT5_1: 5,
|
||||
LlmModel.GPT5_MINI: 1,
|
||||
LlmModel.GPT5_NANO: 1,
|
||||
LlmModel.GPT5_CHAT: 5,
|
||||
@@ -88,7 +87,7 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.AIML_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
|
||||
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
|
||||
LlmModel.LLAMA3_3_70B: 1,
|
||||
LlmModel.LLAMA3_3_70B: 1, # $0.59 / $0.79
|
||||
LlmModel.LLAMA3_1_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_3: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_2: 1,
|
||||
|
||||
@@ -16,7 +16,6 @@ from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBala
|
||||
from prisma.types import CreditRefundRequestCreateInput, CreditTransactionWhereInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.admin.model import UserHistoryResponse
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
|
||||
@@ -30,6 +29,7 @@ from backend.data.model import (
|
||||
from backend.data.notifications import NotificationEventModel, RefundRequestData
|
||||
from backend.data.user import get_user_by_id, get_user_email_by_id
|
||||
from backend.notifications.notifications import queue_notification_async
|
||||
from backend.server.v2.admin.model import UserHistoryResponse
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.json import SafeJson, dumps
|
||||
@@ -341,19 +341,6 @@ class UserCreditBase(ABC):
|
||||
|
||||
if result:
|
||||
# UserBalance is already updated by the CTE
|
||||
|
||||
# Clear insufficient funds notification flags when credits are added
|
||||
# so user can receive alerts again if they run out in the future.
|
||||
if transaction.amount > 0 and transaction.type in [
|
||||
CreditTransactionType.GRANT,
|
||||
CreditTransactionType.TOP_UP,
|
||||
]:
|
||||
from backend.executor.manager import (
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
|
||||
await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
return result[0]["balance"]
|
||||
|
||||
async def _add_transaction(
|
||||
@@ -543,22 +530,6 @@ class UserCreditBase(ABC):
|
||||
if result:
|
||||
new_balance, tx_key = result[0]["balance"], result[0]["transactionKey"]
|
||||
# UserBalance is already updated by the CTE
|
||||
|
||||
# Clear insufficient funds notification flags when credits are added
|
||||
# so user can receive alerts again if they run out in the future.
|
||||
if (
|
||||
amount > 0
|
||||
and is_active
|
||||
and transaction_type
|
||||
in [CreditTransactionType.GRANT, CreditTransactionType.TOP_UP]
|
||||
):
|
||||
# Lazy import to avoid circular dependency with executor.manager
|
||||
from backend.executor.manager import (
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
|
||||
await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
return new_balance, tx_key
|
||||
|
||||
# If no result, either user doesn't exist or insufficient balance
|
||||
|
||||
@@ -111,7 +111,7 @@ def get_database_schema() -> str:
|
||||
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
|
||||
"""Execute raw SQL query with proper schema handling."""
|
||||
schema = get_database_schema()
|
||||
schema_prefix = f'"{schema}".' if schema != "public" else ""
|
||||
schema_prefix = f"{schema}." if schema != "public" else ""
|
||||
formatted_query = query_template.format(schema_prefix=schema_prefix)
|
||||
|
||||
import prisma as prisma_module
|
||||
|
||||
@@ -5,7 +5,6 @@ from enum import Enum
|
||||
from multiprocessing import Manager
|
||||
from queue import Empty
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
@@ -66,9 +65,6 @@ from .includes import (
|
||||
)
|
||||
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -840,30 +836,6 @@ async def upsert_execution_output(
|
||||
await AgentNodeExecutionInputOutput.prisma().create(data=data)
|
||||
|
||||
|
||||
async def get_execution_outputs_by_node_exec_id(
|
||||
node_exec_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Get all execution outputs for a specific node execution ID.
|
||||
|
||||
Args:
|
||||
node_exec_id: The node execution ID to get outputs for
|
||||
|
||||
Returns:
|
||||
Dictionary mapping output names to their data values
|
||||
"""
|
||||
outputs = await AgentNodeExecutionInputOutput.prisma().find_many(
|
||||
where={"referencedByOutputExecId": node_exec_id}
|
||||
)
|
||||
|
||||
result = {}
|
||||
for output in outputs:
|
||||
if output.data is not None:
|
||||
result[output.name] = type_utils.convert(output.data, JsonValue)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def update_graph_execution_start_time(
|
||||
graph_exec_id: str,
|
||||
) -> GraphExecution | None:
|
||||
|
||||
@@ -6,14 +6,14 @@ import fastapi.exceptions
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.api.features.store.model as store
|
||||
from backend.api.model import CreateGraph
|
||||
import backend.server.v2.store.model as store
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.data.block import BlockSchema, BlockSchemaInput
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.usecases.sample import create_test_user
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from prisma.models import PendingHumanReview
|
||||
from prisma.types import PendingHumanReviewUpdateInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.executions.review.model import (
|
||||
from backend.server.v2.executions.review.model import (
|
||||
PendingHumanReviewModel,
|
||||
SafeJsonData,
|
||||
)
|
||||
@@ -100,7 +100,7 @@ async def get_or_create_human_review(
|
||||
return None
|
||||
else:
|
||||
return ReviewResult(
|
||||
data=review.payload,
|
||||
data=review.payload if review.status == ReviewStatus.APPROVED else None,
|
||||
status=review.status,
|
||||
message=review.reviewMessage or "",
|
||||
processed=review.processed,
|
||||
|
||||
@@ -23,7 +23,7 @@ from backend.util.exceptions import NotFoundError
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.library.model import LibraryAgentPreset
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
from .db import BaseDbModel
|
||||
from .graph import NodeModel
|
||||
@@ -79,7 +79,7 @@ class WebhookWithRelations(Webhook):
|
||||
# integrations.py → library/model.py → integrations.py (for Webhook)
|
||||
# Runtime import is used in WebhookWithRelations.from_db() method instead
|
||||
# Import at runtime to avoid circular dependency
|
||||
from backend.api.features.library.model import LibraryAgentPreset
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
return WebhookWithRelations(
|
||||
**Webhook.from_db(webhook).model_dump(),
|
||||
@@ -285,8 +285,8 @@ async def unlink_webhook_from_graph(
|
||||
user_id: The ID of the user (for authorization)
|
||||
"""
|
||||
# Avoid circular imports
|
||||
from backend.api.features.library.db import set_preset_webhook
|
||||
from backend.data.graph import set_node_webhook
|
||||
from backend.server.v2.library.db import set_preset_webhook
|
||||
|
||||
# Find all nodes in this graph that use this webhook
|
||||
nodes = await AgentNode.prisma().find_many(
|
||||
|
||||
@@ -4,8 +4,8 @@ from typing import AsyncGenerator
|
||||
|
||||
from pydantic import BaseModel, field_serializer
|
||||
|
||||
from backend.api.model import NotificationPayload
|
||||
from backend.data.event_bus import AsyncRedisEventBus
|
||||
from backend.server.model import NotificationPayload
|
||||
from backend.util.settings import Settings
|
||||
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@ from prisma.enums import OnboardingStep
|
||||
from prisma.models import UserOnboarding
|
||||
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
|
||||
|
||||
from backend.api.features.store.model import StoreAgentDetails
|
||||
from backend.api.model import OnboardingNotificationPayload
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.notification_bus import (
|
||||
@@ -18,6 +16,8 @@ from backend.data.notification_bus import (
|
||||
NotificationEvent,
|
||||
)
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.server.model import OnboardingNotificationPayload
|
||||
from backend.server.v2.store.model import StoreAgentDetails
|
||||
from backend.util.cache import cached
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.timezone_utils import get_user_timezone_or_utc
|
||||
@@ -442,8 +442,6 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
versions=agent.versions,
|
||||
agentGraphVersions=agent.agentGraphVersions,
|
||||
agentGraphId=agent.agentGraphId,
|
||||
last_updated=agent.updated_at,
|
||||
)
|
||||
for agent in recommended_agents
|
||||
|
||||
@@ -1,429 +0,0 @@
|
||||
"""Data models and access layer for user business understanding."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import pydantic
|
||||
from prisma.models import UserBusinessUnderstanding
|
||||
from prisma.types import (
|
||||
UserBusinessUnderstandingCreateInput,
|
||||
UserBusinessUnderstandingUpdateInput,
|
||||
)
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache configuration
|
||||
CACHE_KEY_PREFIX = "understanding"
|
||||
CACHE_TTL_SECONDS = 48 * 60 * 60 # 48 hours
|
||||
|
||||
|
||||
def _cache_key(user_id: str) -> str:
|
||||
"""Generate cache key for user business understanding."""
|
||||
return f"{CACHE_KEY_PREFIX}:{user_id}"
|
||||
|
||||
|
||||
def _json_to_list(value: Any) -> list[str]:
|
||||
"""Convert Json field to list[str], handling None."""
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return cast(list[str], value)
|
||||
return []
|
||||
|
||||
|
||||
class BusinessUnderstandingInput(pydantic.BaseModel):
|
||||
"""Input model for updating business understanding - all fields optional for incremental updates."""
|
||||
|
||||
# User info
|
||||
user_name: Optional[str] = pydantic.Field(None, description="The user's name")
|
||||
job_title: Optional[str] = pydantic.Field(None, description="The user's job title")
|
||||
|
||||
# Business basics
|
||||
business_name: Optional[str] = pydantic.Field(
|
||||
None, description="Name of the user's business"
|
||||
)
|
||||
industry: Optional[str] = pydantic.Field(None, description="Industry or sector")
|
||||
business_size: Optional[str] = pydantic.Field(
|
||||
None, description="Company size (e.g., '1-10', '11-50')"
|
||||
)
|
||||
user_role: Optional[str] = pydantic.Field(
|
||||
None,
|
||||
description="User's role in the organization (e.g., 'decision maker', 'implementer')",
|
||||
)
|
||||
|
||||
# Processes & activities
|
||||
key_workflows: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Key business workflows"
|
||||
)
|
||||
daily_activities: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Daily activities performed"
|
||||
)
|
||||
|
||||
# Pain points & goals
|
||||
pain_points: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Current pain points"
|
||||
)
|
||||
bottlenecks: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Process bottlenecks"
|
||||
)
|
||||
manual_tasks: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Manual/repetitive tasks"
|
||||
)
|
||||
automation_goals: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Desired automation goals"
|
||||
)
|
||||
|
||||
# Current tools
|
||||
current_software: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Software/tools currently used"
|
||||
)
|
||||
existing_automation: Optional[list[str]] = pydantic.Field(
|
||||
None, description="Existing automations"
|
||||
)
|
||||
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = pydantic.Field(
|
||||
None, description="Any additional context"
|
||||
)
|
||||
|
||||
|
||||
class BusinessUnderstanding(pydantic.BaseModel):
|
||||
"""Full business understanding model returned from database."""
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# User info
|
||||
user_name: Optional[str] = None
|
||||
job_title: Optional[str] = None
|
||||
|
||||
# Business basics
|
||||
business_name: Optional[str] = None
|
||||
industry: Optional[str] = None
|
||||
business_size: Optional[str] = None
|
||||
user_role: Optional[str] = None
|
||||
|
||||
# Processes & activities
|
||||
key_workflows: list[str] = pydantic.Field(default_factory=list)
|
||||
daily_activities: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Pain points & goals
|
||||
pain_points: list[str] = pydantic.Field(default_factory=list)
|
||||
bottlenecks: list[str] = pydantic.Field(default_factory=list)
|
||||
manual_tasks: list[str] = pydantic.Field(default_factory=list)
|
||||
automation_goals: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Current tools
|
||||
current_software: list[str] = pydantic.Field(default_factory=list)
|
||||
existing_automation: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_record: UserBusinessUnderstanding) -> "BusinessUnderstanding":
|
||||
"""Convert database record to Pydantic model."""
|
||||
return cls(
|
||||
id=db_record.id,
|
||||
user_id=db_record.userId,
|
||||
created_at=db_record.createdAt,
|
||||
updated_at=db_record.updatedAt,
|
||||
user_name=db_record.userName,
|
||||
job_title=db_record.jobTitle,
|
||||
business_name=db_record.businessName,
|
||||
industry=db_record.industry,
|
||||
business_size=db_record.businessSize,
|
||||
user_role=db_record.userRole,
|
||||
key_workflows=_json_to_list(db_record.keyWorkflows),
|
||||
daily_activities=_json_to_list(db_record.dailyActivities),
|
||||
pain_points=_json_to_list(db_record.painPoints),
|
||||
bottlenecks=_json_to_list(db_record.bottlenecks),
|
||||
manual_tasks=_json_to_list(db_record.manualTasks),
|
||||
automation_goals=_json_to_list(db_record.automationGoals),
|
||||
current_software=_json_to_list(db_record.currentSoftware),
|
||||
existing_automation=_json_to_list(db_record.existingAutomation),
|
||||
additional_notes=db_record.additionalNotes,
|
||||
)
|
||||
|
||||
|
||||
def _merge_lists(existing: list | None, new: list | None) -> list | None:
|
||||
"""Merge two lists, removing duplicates while preserving order."""
|
||||
if new is None:
|
||||
return existing
|
||||
if existing is None:
|
||||
return new
|
||||
# Preserve order, add new items that don't exist
|
||||
merged = list(existing)
|
||||
for item in new:
|
||||
if item not in merged:
|
||||
merged.append(item)
|
||||
return merged
|
||||
|
||||
|
||||
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
|
||||
"""Get business understanding from Redis cache."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
cached_data = await redis.get(_cache_key(user_id))
|
||||
if cached_data:
|
||||
return BusinessUnderstanding.model_validate_json(cached_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get understanding from cache: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _set_cache(user_id: str, understanding: BusinessUnderstanding) -> None:
|
||||
"""Set business understanding in Redis cache with TTL."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.setex(
|
||||
_cache_key(user_id),
|
||||
CACHE_TTL_SECONDS,
|
||||
understanding.model_dump_json(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to set understanding in cache: {e}")
|
||||
|
||||
|
||||
async def _delete_cache(user_id: str) -> None:
|
||||
"""Delete business understanding from Redis cache."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(_cache_key(user_id))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete understanding from cache: {e}")
|
||||
|
||||
|
||||
async def get_business_understanding(
|
||||
user_id: str,
|
||||
) -> Optional[BusinessUnderstanding]:
|
||||
"""Get the business understanding for a user.
|
||||
|
||||
Checks cache first, falls back to database if not cached.
|
||||
Results are cached for 48 hours.
|
||||
"""
|
||||
# Try cache first
|
||||
cached = await _get_from_cache(user_id)
|
||||
if cached:
|
||||
logger.debug(f"Business understanding cache hit for user {user_id}")
|
||||
return cached
|
||||
|
||||
# Cache miss - load from database
|
||||
logger.debug(f"Business understanding cache miss for user {user_id}")
|
||||
record = await UserBusinessUnderstanding.prisma().find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
if record is None:
|
||||
return None
|
||||
|
||||
understanding = BusinessUnderstanding.from_db(record)
|
||||
|
||||
# Store in cache for next time
|
||||
await _set_cache(user_id, understanding)
|
||||
|
||||
return understanding
|
||||
|
||||
|
||||
async def upsert_business_understanding(
|
||||
user_id: str,
|
||||
data: BusinessUnderstandingInput,
|
||||
) -> BusinessUnderstanding:
|
||||
"""
|
||||
Create or update business understanding with incremental merge strategy.
|
||||
|
||||
- String fields: new value overwrites if provided (not None)
|
||||
- List fields: new items are appended to existing (deduplicated)
|
||||
"""
|
||||
# Get existing record for merge
|
||||
existing = await UserBusinessUnderstanding.prisma().find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
|
||||
# Build update data with merge strategy
|
||||
update_data: UserBusinessUnderstandingUpdateInput = {}
|
||||
create_data: dict[str, Any] = {"userId": user_id}
|
||||
|
||||
# String fields - overwrite if provided
|
||||
if data.user_name is not None:
|
||||
update_data["userName"] = data.user_name
|
||||
create_data["userName"] = data.user_name
|
||||
if data.job_title is not None:
|
||||
update_data["jobTitle"] = data.job_title
|
||||
create_data["jobTitle"] = data.job_title
|
||||
if data.business_name is not None:
|
||||
update_data["businessName"] = data.business_name
|
||||
create_data["businessName"] = data.business_name
|
||||
if data.industry is not None:
|
||||
update_data["industry"] = data.industry
|
||||
create_data["industry"] = data.industry
|
||||
if data.business_size is not None:
|
||||
update_data["businessSize"] = data.business_size
|
||||
create_data["businessSize"] = data.business_size
|
||||
if data.user_role is not None:
|
||||
update_data["userRole"] = data.user_role
|
||||
create_data["userRole"] = data.user_role
|
||||
if data.additional_notes is not None:
|
||||
update_data["additionalNotes"] = data.additional_notes
|
||||
create_data["additionalNotes"] = data.additional_notes
|
||||
|
||||
# List fields - merge with existing
|
||||
if data.key_workflows is not None:
|
||||
existing_list = _json_to_list(existing.keyWorkflows) if existing else None
|
||||
merged = _merge_lists(existing_list, data.key_workflows)
|
||||
update_data["keyWorkflows"] = SafeJson(merged)
|
||||
create_data["keyWorkflows"] = SafeJson(merged)
|
||||
|
||||
if data.daily_activities is not None:
|
||||
existing_list = _json_to_list(existing.dailyActivities) if existing else None
|
||||
merged = _merge_lists(existing_list, data.daily_activities)
|
||||
update_data["dailyActivities"] = SafeJson(merged)
|
||||
create_data["dailyActivities"] = SafeJson(merged)
|
||||
|
||||
if data.pain_points is not None:
|
||||
existing_list = _json_to_list(existing.painPoints) if existing else None
|
||||
merged = _merge_lists(existing_list, data.pain_points)
|
||||
update_data["painPoints"] = SafeJson(merged)
|
||||
create_data["painPoints"] = SafeJson(merged)
|
||||
|
||||
if data.bottlenecks is not None:
|
||||
existing_list = _json_to_list(existing.bottlenecks) if existing else None
|
||||
merged = _merge_lists(existing_list, data.bottlenecks)
|
||||
update_data["bottlenecks"] = SafeJson(merged)
|
||||
create_data["bottlenecks"] = SafeJson(merged)
|
||||
|
||||
if data.manual_tasks is not None:
|
||||
existing_list = _json_to_list(existing.manualTasks) if existing else None
|
||||
merged = _merge_lists(existing_list, data.manual_tasks)
|
||||
update_data["manualTasks"] = SafeJson(merged)
|
||||
create_data["manualTasks"] = SafeJson(merged)
|
||||
|
||||
if data.automation_goals is not None:
|
||||
existing_list = _json_to_list(existing.automationGoals) if existing else None
|
||||
merged = _merge_lists(existing_list, data.automation_goals)
|
||||
update_data["automationGoals"] = SafeJson(merged)
|
||||
create_data["automationGoals"] = SafeJson(merged)
|
||||
|
||||
if data.current_software is not None:
|
||||
existing_list = _json_to_list(existing.currentSoftware) if existing else None
|
||||
merged = _merge_lists(existing_list, data.current_software)
|
||||
update_data["currentSoftware"] = SafeJson(merged)
|
||||
create_data["currentSoftware"] = SafeJson(merged)
|
||||
|
||||
if data.existing_automation is not None:
|
||||
existing_list = _json_to_list(existing.existingAutomation) if existing else None
|
||||
merged = _merge_lists(existing_list, data.existing_automation)
|
||||
update_data["existingAutomation"] = SafeJson(merged)
|
||||
create_data["existingAutomation"] = SafeJson(merged)
|
||||
|
||||
# Upsert
|
||||
record = await UserBusinessUnderstanding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": UserBusinessUnderstandingCreateInput(**create_data),
|
||||
"update": update_data,
|
||||
},
|
||||
)
|
||||
|
||||
understanding = BusinessUnderstanding.from_db(record)
|
||||
|
||||
# Update cache with new understanding
|
||||
await _set_cache(user_id, understanding)
|
||||
|
||||
return understanding
|
||||
|
||||
|
||||
async def clear_business_understanding(user_id: str) -> bool:
|
||||
"""Clear/delete business understanding for a user from both DB and cache."""
|
||||
# Delete from cache first
|
||||
await _delete_cache(user_id)
|
||||
|
||||
try:
|
||||
await UserBusinessUnderstanding.prisma().delete(where={"userId": user_id})
|
||||
return True
|
||||
except Exception:
|
||||
# Record might not exist
|
||||
return False
|
||||
|
||||
|
||||
def format_understanding_for_prompt(understanding: BusinessUnderstanding) -> str:
|
||||
"""Format business understanding as text for system prompt injection."""
|
||||
sections = []
|
||||
|
||||
# User info section
|
||||
user_info = []
|
||||
if understanding.user_name:
|
||||
user_info.append(f"Name: {understanding.user_name}")
|
||||
if understanding.job_title:
|
||||
user_info.append(f"Job Title: {understanding.job_title}")
|
||||
if user_info:
|
||||
sections.append("## User\n" + "\n".join(user_info))
|
||||
|
||||
# Business section
|
||||
business_info = []
|
||||
if understanding.business_name:
|
||||
business_info.append(f"Company: {understanding.business_name}")
|
||||
if understanding.industry:
|
||||
business_info.append(f"Industry: {understanding.industry}")
|
||||
if understanding.business_size:
|
||||
business_info.append(f"Size: {understanding.business_size}")
|
||||
if understanding.user_role:
|
||||
business_info.append(f"Role Context: {understanding.user_role}")
|
||||
if business_info:
|
||||
sections.append("## Business\n" + "\n".join(business_info))
|
||||
|
||||
# Processes section
|
||||
processes = []
|
||||
if understanding.key_workflows:
|
||||
processes.append(f"Key Workflows: {', '.join(understanding.key_workflows)}")
|
||||
if understanding.daily_activities:
|
||||
processes.append(
|
||||
f"Daily Activities: {', '.join(understanding.daily_activities)}"
|
||||
)
|
||||
if processes:
|
||||
sections.append("## Processes\n" + "\n".join(processes))
|
||||
|
||||
# Pain points section
|
||||
pain_points = []
|
||||
if understanding.pain_points:
|
||||
pain_points.append(f"Pain Points: {', '.join(understanding.pain_points)}")
|
||||
if understanding.bottlenecks:
|
||||
pain_points.append(f"Bottlenecks: {', '.join(understanding.bottlenecks)}")
|
||||
if understanding.manual_tasks:
|
||||
pain_points.append(f"Manual Tasks: {', '.join(understanding.manual_tasks)}")
|
||||
if pain_points:
|
||||
sections.append("## Pain Points\n" + "\n".join(pain_points))
|
||||
|
||||
# Goals section
|
||||
if understanding.automation_goals:
|
||||
sections.append(
|
||||
"## Automation Goals\n"
|
||||
+ "\n".join(f"- {goal}" for goal in understanding.automation_goals)
|
||||
)
|
||||
|
||||
# Current tools section
|
||||
tools_info = []
|
||||
if understanding.current_software:
|
||||
tools_info.append(
|
||||
f"Current Software: {', '.join(understanding.current_software)}"
|
||||
)
|
||||
if understanding.existing_automation:
|
||||
tools_info.append(
|
||||
f"Existing Automation: {', '.join(understanding.existing_automation)}"
|
||||
)
|
||||
if tools_info:
|
||||
sections.append("## Current Tools\n" + "\n".join(tools_info))
|
||||
|
||||
# Additional notes
|
||||
if understanding.additional_notes:
|
||||
sections.append(f"## Additional Context\n{understanding.additional_notes}")
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
return "# User Business Context\n\n" + "\n\n".join(sections)
|
||||
@@ -2,11 +2,6 @@ import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast
|
||||
|
||||
from backend.api.features.library.db import (
|
||||
add_store_agent_to_library,
|
||||
list_library_agents,
|
||||
)
|
||||
from backend.api.features.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.data import db
|
||||
from backend.data.analytics import (
|
||||
get_accuracy_trends_and_alerts,
|
||||
@@ -18,7 +13,6 @@ from backend.data.execution import (
|
||||
get_block_error_stats,
|
||||
get_child_graph_executions,
|
||||
get_execution_kv_data,
|
||||
get_execution_outputs_by_node_exec_id,
|
||||
get_frequently_executed_graphs,
|
||||
get_graph_execution_meta,
|
||||
get_graph_executions,
|
||||
@@ -66,6 +60,8 @@ from backend.data.user import (
|
||||
get_user_notification_preference,
|
||||
update_user_integrations,
|
||||
)
|
||||
from backend.server.v2.library.db import add_store_agent_to_library, list_library_agents
|
||||
from backend.server.v2.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
@@ -151,7 +147,6 @@ class DatabaseManager(AppService):
|
||||
update_graph_execution_stats = _(update_graph_execution_stats)
|
||||
upsert_execution_input = _(upsert_execution_input)
|
||||
upsert_execution_output = _(upsert_execution_output)
|
||||
get_execution_outputs_by_node_exec_id = _(get_execution_outputs_by_node_exec_id)
|
||||
get_execution_kv_data = _(get_execution_kv_data)
|
||||
set_execution_kv_data = _(set_execution_kv_data)
|
||||
get_block_error_stats = _(get_block_error_stats)
|
||||
@@ -282,7 +277,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
get_user_integrations = d.get_user_integrations
|
||||
upsert_execution_input = d.upsert_execution_input
|
||||
upsert_execution_output = d.upsert_execution_output
|
||||
get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
|
||||
update_graph_execution_stats = d.update_graph_execution_stats
|
||||
update_node_execution_status = d.update_node_execution_status
|
||||
update_node_execution_status_batch = d.update_node_execution_status_batch
|
||||
|
||||
@@ -48,8 +48,27 @@ from backend.data.notifications import (
|
||||
ZeroBalanceData,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.activity_status_generator import (
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
from backend.executor.utils import (
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_EXCHANGE,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_ROUTING_KEY,
|
||||
CancelExecutionEvent,
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
execution_usage_cost,
|
||||
validate_exec,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.server.v2.AutoMod.manager import automod_manager
|
||||
from backend.util import json
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
@@ -76,24 +95,7 @@ from backend.util.retry import (
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .activity_status_generator import generate_activity_status_for_execution
|
||||
from .automod.manager import automod_manager
|
||||
from .cluster_lock import ClusterLock
|
||||
from .utils import (
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_EXCHANGE,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_ROUTING_KEY,
|
||||
CancelExecutionEvent,
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
execution_usage_cost,
|
||||
validate_exec,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
@@ -114,40 +116,6 @@ utilization_gauge = Gauge(
|
||||
"Ratio of active graph runs to max graph workers",
|
||||
)
|
||||
|
||||
# Redis key prefix for tracking insufficient funds Discord notifications.
|
||||
# We only send one notification per user per agent until they top up credits.
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
|
||||
# TTL for the notification flag (30 days) - acts as a fallback cleanup
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60
|
||||
|
||||
|
||||
async def clear_insufficient_funds_notifications(user_id: str) -> int:
|
||||
"""
|
||||
Clear all insufficient funds notification flags for a user.
|
||||
|
||||
This should be called when a user tops up their credits, allowing
|
||||
Discord notifications to be sent again if they run out of funds.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to clear notifications for.
|
||||
|
||||
Returns:
|
||||
The number of keys that were deleted.
|
||||
"""
|
||||
try:
|
||||
redis_client = await redis.get_redis_async()
|
||||
pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
keys = [key async for key in redis_client.scan_iter(match=pattern)]
|
||||
if keys:
|
||||
return await redis_client.delete(*keys)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to clear insufficient funds notification flags for user "
|
||||
f"{user_id}: {e}"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
# Thread-local storage for ExecutionProcessor instances
|
||||
_tls = threading.local()
|
||||
@@ -165,8 +133,9 @@ def execute_graph(
|
||||
cluster_lock: ClusterLock,
|
||||
):
|
||||
"""Execute graph using thread-local ExecutionProcessor instance"""
|
||||
processor: ExecutionProcessor = _tls.processor
|
||||
return processor.on_graph_execution(graph_exec_entry, cancel_event, cluster_lock)
|
||||
return _tls.processor.on_graph_execution(
|
||||
graph_exec_entry, cancel_event, cluster_lock
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -174,8 +143,8 @@ T = TypeVar("T")
|
||||
|
||||
async def execute_node(
|
||||
node: Node,
|
||||
creds_manager: IntegrationCredentialsManager,
|
||||
data: NodeExecutionEntry,
|
||||
execution_processor: "ExecutionProcessor",
|
||||
execution_stats: NodeExecutionStats | None = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> BlockOutput:
|
||||
@@ -200,7 +169,6 @@ async def execute_node(
|
||||
node_id = data.node_id
|
||||
node_block = node.block
|
||||
execution_context = data.execution_context
|
||||
creds_manager = execution_processor.creds_manager
|
||||
|
||||
log_metadata = LogMetadata(
|
||||
logger=_logger,
|
||||
@@ -244,7 +212,6 @@ async def execute_node(
|
||||
"node_exec_id": node_exec_id,
|
||||
"user_id": user_id,
|
||||
"execution_context": execution_context,
|
||||
"execution_processor": execution_processor,
|
||||
}
|
||||
|
||||
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
|
||||
@@ -641,8 +608,8 @@ class ExecutionProcessor:
|
||||
|
||||
async for output_name, output_data in execute_node(
|
||||
node=node,
|
||||
creds_manager=self.creds_manager,
|
||||
data=node_exec,
|
||||
execution_processor=self,
|
||||
execution_stats=stats,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
):
|
||||
@@ -893,17 +860,12 @@ class ExecutionProcessor:
|
||||
execution_stats_lock = threading.Lock()
|
||||
|
||||
# State holders ----------------------------------------------------
|
||||
self.running_node_execution: dict[str, NodeExecutionProgress] = defaultdict(
|
||||
running_node_execution: dict[str, NodeExecutionProgress] = defaultdict(
|
||||
NodeExecutionProgress
|
||||
)
|
||||
self.running_node_evaluation: dict[str, Future] = {}
|
||||
self.execution_stats = execution_stats
|
||||
self.execution_stats_lock = execution_stats_lock
|
||||
running_node_evaluation: dict[str, Future] = {}
|
||||
execution_queue = ExecutionQueue[NodeExecutionEntry]()
|
||||
|
||||
running_node_execution = self.running_node_execution
|
||||
running_node_evaluation = self.running_node_evaluation
|
||||
|
||||
try:
|
||||
if db_client.get_credits(graph_exec.user_id) <= 0:
|
||||
raise InsufficientBalanceError(
|
||||
@@ -1295,40 +1257,12 @@ class ExecutionProcessor:
|
||||
graph_id: str,
|
||||
e: InsufficientBalanceError,
|
||||
):
|
||||
# Check if we've already sent a notification for this user+agent combo.
|
||||
# We only send one notification per user per agent until they top up credits.
|
||||
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
|
||||
try:
|
||||
redis_client = redis.get_redis()
|
||||
# SET NX returns True only if the key was newly set (didn't exist)
|
||||
is_new_notification = redis_client.set(
|
||||
redis_key,
|
||||
"1",
|
||||
nx=True,
|
||||
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
|
||||
)
|
||||
if not is_new_notification:
|
||||
# Already notified for this user+agent, skip all notifications
|
||||
logger.debug(
|
||||
f"Skipping duplicate insufficient funds notification for "
|
||||
f"user={user_id}, graph={graph_id}"
|
||||
)
|
||||
return
|
||||
except Exception as redis_error:
|
||||
# If Redis fails, log and continue to send the notification
|
||||
# (better to occasionally duplicate than to never notify)
|
||||
logger.warning(
|
||||
f"Failed to check/set insufficient funds notification flag in Redis: "
|
||||
f"{redis_error}"
|
||||
)
|
||||
|
||||
shortfall = abs(e.amount) - e.balance
|
||||
metadata = db_client.get_graph_metadata(graph_id)
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
|
||||
# Queue user email notification
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=user_id,
|
||||
@@ -1342,7 +1276,6 @@ class ExecutionProcessor:
|
||||
)
|
||||
)
|
||||
|
||||
# Send Discord system alert
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
|
||||
|
||||
@@ -1,560 +0,0 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import NotificationType
|
||||
|
||||
from backend.data.notifications import ZeroBalanceData
|
||||
from backend.executor.manager import (
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX,
|
||||
ExecutionProcessor,
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
async def async_iter(items):
|
||||
"""Helper to create an async iterator from a list."""
|
||||
for item in items:
|
||||
yield item
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_sends_discord_alert_first_time(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that the first insufficient funds notification sends a Discord alert."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72, # $0.72
|
||||
amount=-714, # Attempting to spend $7.14
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to simulate first-time notification (set returns True)
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.return_value = True # Key was newly set
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify notification was queued
|
||||
mock_queue_notif.assert_called_once()
|
||||
notification_call = mock_queue_notif.call_args[0][0]
|
||||
assert notification_call.type == NotificationType.ZERO_BALANCE
|
||||
assert notification_call.user_id == user_id
|
||||
assert isinstance(notification_call.data, ZeroBalanceData)
|
||||
assert notification_call.data.current_balance == 72
|
||||
|
||||
# Verify Redis was checked with correct key pattern
|
||||
expected_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
|
||||
mock_redis_client.set.assert_called_once()
|
||||
call_args = mock_redis_client.set.call_args
|
||||
assert call_args[0][0] == expected_key
|
||||
assert call_args[1]["nx"] is True
|
||||
|
||||
# Verify Discord alert was sent
|
||||
mock_client.discord_system_alert.assert_called_once()
|
||||
discord_message = mock_client.discord_system_alert.call_args[0][0]
|
||||
assert "Insufficient Funds Alert" in discord_message
|
||||
assert "test@example.com" in discord_message
|
||||
assert "Test Agent" in discord_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_skips_duplicate_notifications(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that duplicate insufficient funds notifications skip both email and Discord."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to simulate duplicate notification (set returns False/None)
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.return_value = None # Key already existed
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
mock_db_client.get_graph_metadata.return_value = MagicMock(name="Test Agent")
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify email notification was NOT queued (deduplication worked)
|
||||
mock_queue_notif.assert_not_called()
|
||||
|
||||
# Verify Discord alert was NOT sent (deduplication worked)
|
||||
mock_client.discord_system_alert.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that different agents for the same user get separate Discord alerts."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id_1 = "test-graph-111"
|
||||
graph_id_2 = "test-graph-222"
|
||||
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch("backend.executor.manager.queue_notification"), patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
# Both calls return True (first time for each agent)
|
||||
mock_redis_client.set.return_value = True
|
||||
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# First agent notification
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id_1,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Second agent notification
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id_2,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify Discord alerts were sent for both agents
|
||||
assert mock_client.discord_system_alert.call_count == 2
|
||||
|
||||
# Verify Redis was called with different keys
|
||||
assert mock_redis_client.set.call_count == 2
|
||||
calls = mock_redis_client.set.call_args_list
|
||||
assert (
|
||||
calls[0][0][0]
|
||||
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_1}"
|
||||
)
|
||||
assert (
|
||||
calls[1][0][0]
|
||||
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_2}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications(server: SpinTestServer):
|
||||
"""Test that clearing notifications removes all keys for a user."""
|
||||
|
||||
user_id = "test-user-123"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
# get_redis_async is an async function, so we need AsyncMock for it
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
# Mock scan_iter to return some keys as an async iterator
|
||||
mock_keys = [
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1",
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-2",
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-3",
|
||||
]
|
||||
mock_redis_client.scan_iter.return_value = async_iter(mock_keys)
|
||||
# delete is awaited, so use AsyncMock
|
||||
mock_redis_client.delete = AsyncMock(return_value=3)
|
||||
|
||||
# Clear notifications
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify correct pattern was used
|
||||
expected_pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
mock_redis_client.scan_iter.assert_called_once_with(match=expected_pattern)
|
||||
|
||||
# Verify delete was called with all keys
|
||||
mock_redis_client.delete.assert_called_once_with(*mock_keys)
|
||||
|
||||
# Verify return value
|
||||
assert result == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications_no_keys(server: SpinTestServer):
|
||||
"""Test clearing notifications when there are no keys to clear."""
|
||||
|
||||
user_id = "test-user-no-notifications"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
# get_redis_async is an async function, so we need AsyncMock for it
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
# Mock scan_iter to return no keys as an async iterator
|
||||
mock_redis_client.scan_iter.return_value = async_iter([])
|
||||
|
||||
# Clear notifications
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify delete was not called
|
||||
mock_redis_client.delete.assert_not_called()
|
||||
|
||||
# Verify return value
|
||||
assert result == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications_handles_redis_error(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that clearing notifications handles Redis errors gracefully."""
|
||||
|
||||
user_id = "test-user-redis-error"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
# Mock get_redis_async to raise an error
|
||||
mock_redis_module.get_redis_async = AsyncMock(
|
||||
side_effect=Exception("Redis connection failed")
|
||||
)
|
||||
|
||||
# Clear notifications should not raise, just return 0
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify it returned 0 (graceful failure)
|
||||
assert result == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_continues_on_redis_error(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that both email and Discord notifications are still sent when Redis fails."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to raise an error
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.side_effect = Exception("Redis connection error")
|
||||
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify email notification was still queued despite Redis error
|
||||
mock_queue_notif.assert_called_once()
|
||||
|
||||
# Verify Discord alert was still sent despite Redis error
|
||||
mock_client.discord_system_alert.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_clears_notifications_on_grant(server: SpinTestServer):
|
||||
"""Test that _add_transaction clears notification flags when adding GRANT credits."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-grant-clear"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 1000, "transactionKey": "test-tx-key"}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter(
|
||||
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
|
||||
)
|
||||
mock_redis_client.delete = AsyncMock(return_value=1)
|
||||
|
||||
# Create a concrete instance
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with GRANT type (should clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=500, # Positive amount
|
||||
transaction_type=CreditTransactionType.GRANT,
|
||||
is_active=True, # Active transaction
|
||||
)
|
||||
|
||||
# Verify notification clearing was called
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
mock_redis_client.scan_iter.assert_called_once_with(
|
||||
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_clears_notifications_on_top_up(server: SpinTestServer):
|
||||
"""Test that _add_transaction clears notification flags when adding TOP_UP credits."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-topup-clear"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 2000, "transactionKey": "test-tx-key-2"}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter([])
|
||||
mock_redis_client.delete = AsyncMock(return_value=0)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with TOP_UP type (should clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=1000, # Positive amount
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Verify notification clearing was attempted
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_skips_clearing_for_inactive_transaction(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that _add_transaction does NOT clear notifications for inactive transactions."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-inactive"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 500, "transactionKey": "test-tx-key-3"}]
|
||||
|
||||
# Mock async Redis
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with is_active=False (should NOT clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=500,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
is_active=False, # Inactive - pending Stripe payment
|
||||
)
|
||||
|
||||
# Verify notification clearing was NOT called
|
||||
mock_redis_module.get_redis_async.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_skips_clearing_for_usage_transaction(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that _add_transaction does NOT clear notifications for USAGE transactions."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-usage"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 400, "transactionKey": "test-tx-key-4"}]
|
||||
|
||||
# Mock async Redis
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with USAGE type (spending, should NOT clear)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=-100, # Negative - spending credits
|
||||
transaction_type=CreditTransactionType.USAGE,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Verify notification clearing was NOT called
|
||||
mock_redis_module.get_redis_async.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_enable_transaction_clears_notifications(server: SpinTestServer):
|
||||
"""Test that _enable_transaction clears notification flags when enabling a TOP_UP."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-enable"
|
||||
|
||||
with patch("backend.data.credit.CreditTransaction") as mock_credit_tx, patch(
|
||||
"backend.data.credit.query_raw_with_schema"
|
||||
) as mock_query, patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
# Mock finding the pending transaction
|
||||
mock_transaction = MagicMock()
|
||||
mock_transaction.amount = 1000
|
||||
mock_transaction.type = CreditTransactionType.TOP_UP
|
||||
mock_credit_tx.prisma.return_value.find_first = AsyncMock(
|
||||
return_value=mock_transaction
|
||||
)
|
||||
|
||||
# Mock the query to return updated balance
|
||||
mock_query.return_value = [{"balance": 1500}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter(
|
||||
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
|
||||
)
|
||||
mock_redis_client.delete = AsyncMock(return_value=1)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _enable_transaction (simulates Stripe checkout completion)
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
await credit_model._enable_transaction(
|
||||
transaction_key="cs_test_123",
|
||||
user_id=user_id,
|
||||
metadata=SafeJson({"payment": "completed"}),
|
||||
)
|
||||
|
||||
# Verify notification clearing was called
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
mock_redis_client.scan_iter.assert_called_once_with(
|
||||
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
)
|
||||
@@ -3,16 +3,16 @@ import logging
|
||||
import fastapi.responses
|
||||
import pytest
|
||||
|
||||
import backend.api.features.library.model
|
||||
import backend.api.features.store.model
|
||||
from backend.api.model import CreateGraph
|
||||
from backend.api.rest_api import AgentServer
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.store.model
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.data_manipulation import FindInDictionaryBlock
|
||||
from backend.blocks.io import AgentInputBlock
|
||||
from backend.blocks.maths import CalculatorBlock, Operation
|
||||
from backend.data import execution, graph
|
||||
from backend.data.model import User
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
@@ -356,7 +356,7 @@ async def test_execute_preset(server: SpinTestServer):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
# Create preset with initial values
|
||||
preset = backend.api.features.library.model.LibraryAgentPresetCreatable(
|
||||
preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
|
||||
name="Test Preset With Clash",
|
||||
description="Test preset with clashing input values",
|
||||
graph_id=test_graph.id,
|
||||
@@ -444,7 +444,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
# Create preset with initial values
|
||||
preset = backend.api.features.library.model.LibraryAgentPresetCreatable(
|
||||
preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
|
||||
name="Test Preset With Clash",
|
||||
description="Test preset with clashing input values",
|
||||
graph_id=test_graph.id,
|
||||
@@ -485,7 +485,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
|
||||
store_submission_request = backend.api.features.store.model.StoreSubmissionRequest(
|
||||
store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
|
||||
agent_id=test_graph.id,
|
||||
agent_version=test_graph.version,
|
||||
slug=test_graph.id,
|
||||
@@ -514,7 +514,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
|
||||
admin_user = await create_test_user(alt_user=True)
|
||||
await server.agent_server.test_review_store_listing(
|
||||
backend.api.features.store.model.ReviewSubmissionRequest(
|
||||
backend.server.v2.store.model.ReviewSubmissionRequest(
|
||||
store_listing_version_id=slv_id,
|
||||
is_approved=True,
|
||||
comments="Test comments",
|
||||
@@ -523,7 +523,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
)
|
||||
|
||||
# Add the approved store listing to the admin user's library so they can execute it
|
||||
from backend.api.features.library.db import add_store_agent_to_library
|
||||
from backend.server.v2.library.db import add_store_agent_to_library
|
||||
|
||||
await add_store_agent_to_library(
|
||||
store_listing_version_id=slv_id, user_id=admin_user.id
|
||||
|
||||
@@ -23,7 +23,6 @@ from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
|
||||
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
@@ -243,12 +242,6 @@ def cleanup_expired_files():
|
||||
run_async(cleanup_expired_files_async())
|
||||
|
||||
|
||||
def cleanup_oauth_tokens():
|
||||
"""Clean up expired OAuth tokens from the database."""
|
||||
# Wait for completion
|
||||
run_async(cleanup_expired_oauth_tokens())
|
||||
|
||||
|
||||
def execution_accuracy_alerts():
|
||||
"""Check execution accuracy and send alerts if drops are detected."""
|
||||
return report_execution_accuracy_alerts()
|
||||
@@ -453,17 +446,6 @@ class Scheduler(AppService):
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
# OAuth Token Cleanup - configurable interval
|
||||
self.scheduler.add_job(
|
||||
cleanup_oauth_tokens,
|
||||
id="cleanup_oauth_tokens",
|
||||
trigger="interval",
|
||||
replace_existing=True,
|
||||
seconds=config.oauth_token_cleanup_interval_hours
|
||||
* 3600, # Convert hours to seconds
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
# Execution Accuracy Monitoring - configurable interval
|
||||
self.scheduler.add_job(
|
||||
execution_accuracy_alerts,
|
||||
@@ -622,11 +604,6 @@ class Scheduler(AppService):
|
||||
"""Manually trigger cleanup of expired cloud storage files."""
|
||||
return cleanup_expired_files()
|
||||
|
||||
@expose
|
||||
def execute_cleanup_oauth_tokens(self):
|
||||
"""Manually trigger cleanup of expired OAuth tokens."""
|
||||
return cleanup_oauth_tokens()
|
||||
|
||||
@expose
|
||||
def execute_report_execution_accuracy_alerts(self):
|
||||
"""Manually trigger execution accuracy alert checking."""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from backend.api.model import CreateGraph
|
||||
from backend.data import db
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
@@ -149,10 +149,10 @@ async def setup_webhook_for_block(
|
||||
async def migrate_legacy_triggered_graphs():
|
||||
from prisma.models import AgentGraph
|
||||
|
||||
from backend.api.features.library.db import create_preset
|
||||
from backend.api.features.library.model import LibraryAgentPresetCreatable
|
||||
from backend.data.graph import AGENT_GRAPH_INCLUDE, GraphModel, set_node_webhook
|
||||
from backend.data.model import is_credentials_field_name
|
||||
from backend.server.v2.library.db import create_preset
|
||||
from backend.server.v2.library.model import LibraryAgentPresetCreatable
|
||||
|
||||
triggered_graphs = [
|
||||
GraphModel.from_db(_graph)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from backend.api.rest_api import AgentServer
|
||||
from backend.app import run_processes
|
||||
from backend.server.rest_api import AgentServer
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user