Compare commits

..

5 Commits

Author SHA1 Message Date
Otto
62c9e840b8 fix: preserve non-dict error bodies as truncated strings
Addresses CodeRabbit feedback: non-dict error bodies (e.g., HTML error
pages from proxies) were silently discarded, losing diagnostic info.
Now returns truncated string representation instead of None.
2026-02-09 07:53:44 +00:00
Otto
495d01b09b fix: use getattr for pyright type checking on dynamic error attributes
Fixes pyright errors:
- Cannot access attribute 'code' for class 'Exception'
- Cannot access attribute 'param' for class 'Exception'

Using getattr() instead of direct attribute access satisfies pyright's
type checker while maintaining the same runtime behavior.
2026-02-09 07:52:45 +00:00
Otto
b334f1a843 fix: defensive error body extraction to handle None and non-dict cases
Addresses review feedback: body.get('error', {}).get('message') is unsafe
when body['error'] is None or not a dict. Now properly checks isinstance()
before accessing nested fields, and falls back to top-level message field.
2026-02-09 07:42:06 +00:00
Otto
5fd1482944 Merge branch 'dev' into fix/copilot-error-logging
Resolved conflict in service.py by keeping both:
- Error logging functions (_log_api_error, _extract_api_error_details, _sanitize_error_body)
- Streaming continuation function (_generate_llm_continuation_with_streaming)
2026-02-09 07:41:02 +00:00
Otto
efd1e96235 feat(copilot): Add detailed API error logging for debugging
Adds comprehensive error logging for OpenRouter/OpenAI API errors to help
diagnose issues like provider routing failures, context length exceeded,
rate limits, etc.

Changes:
- Add _extract_api_error_details() to extract rich info from API errors
  including status code, response body, OpenRouter headers, etc.
- Add _log_api_error() helper to log errors with context (session ID,
  message count, model, retry count)
- Update error handling in _stream_chat_chunks() to use new logging
- Update error handling in _generate_llm_continuation() to use new logging
- Extract provider's error message from response body for better user feedback

This helps debug issues like SECRT-1859 where OpenRouter returns
'provider returned error' with provider_name='unknown' without
capturing the actual error details.

Refs: SECRT-1859
2026-02-03 12:36:54 +00:00
357 changed files with 13661 additions and 14903 deletions

View File

@@ -1,5 +1,29 @@
version: 2
updates:
# autogpt_libs (Poetry project)
- package-ecosystem: "pip"
directory: "autogpt_platform/autogpt_libs"
schedule:
interval: "weekly"
open-pull-requests-limit: 10
target-branch: "dev"
commit-message:
prefix: "chore(libs/deps)"
prefix-development: "chore(libs/deps-dev)"
ignore:
- dependency-name: "poetry"
groups:
production-dependencies:
dependency-type: "production"
update-types:
- "minor"
- "patch"
development-dependencies:
dependency-type: "development"
update-types:
- "minor"
- "patch"
# backend (Poetry project)
- package-ecosystem: "pip"
directory: "autogpt_platform/backend"

View File

@@ -22,7 +22,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
ref: ${{ github.event.workflow_run.head_branch }}
fetch-depth: 0

View File

@@ -30,7 +30,7 @@ jobs:
actions: read # Required for CI access
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 1
@@ -78,7 +78,7 @@ jobs:
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22"

View File

@@ -40,7 +40,7 @@ jobs:
actions: read # Required for CI access
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 1
@@ -94,7 +94,7 @@ jobs:
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22"

View File

@@ -58,7 +58,7 @@ jobs:
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL

View File

@@ -27,7 +27,7 @@ jobs:
# If you do not check out your code, Copilot will do this for you.
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
@@ -76,7 +76,7 @@ jobs:
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22"

View File

@@ -23,7 +23,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 1

View File

@@ -23,7 +23,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0

View File

@@ -28,7 +28,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 1

View File

@@ -25,7 +25,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
ref: ${{ github.event.inputs.git_ref || github.ref_name }}
@@ -52,7 +52,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Trigger deploy workflow
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DEPLOY_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

View File

@@ -17,7 +17,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name || 'master' }}
@@ -45,7 +45,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Trigger deploy workflow
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DEPLOY_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

View File

@@ -6,11 +6,13 @@ on:
paths:
- ".github/workflows/platform-backend-ci.yml"
- "autogpt_platform/backend/**"
- "autogpt_platform/autogpt_libs/**"
pull_request:
branches: [master, dev, release-*]
paths:
- ".github/workflows/platform-backend-ci.yml"
- "autogpt_platform/backend/**"
- "autogpt_platform/autogpt_libs/**"
merge_group:
concurrency:
@@ -66,7 +68,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true

View File

@@ -82,7 +82,7 @@ jobs:
- name: Dispatch Deploy Event
if: steps.check_status.outputs.should_deploy == 'true'
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -110,7 +110,7 @@ jobs:
- name: Dispatch Undeploy Event (from comment)
if: steps.check_status.outputs.should_undeploy == 'true'
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -168,7 +168,7 @@ jobs:
github.event_name == 'pull_request' &&
github.event.action == 'closed' &&
steps.check_pr_close.outputs.should_undeploy == 'true'
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

View File

@@ -31,7 +31,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Check for component changes
uses: dorny/paths-filter@v3
@@ -42,7 +42,7 @@ jobs:
- 'autogpt_platform/frontend/src/components/**'
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -71,10 +71,10 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -107,12 +107,12 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -148,12 +148,12 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -277,12 +277,12 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"

View File

@@ -29,10 +29,10 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -56,19 +56,19 @@ jobs:
run: pnpm install --frozen-lockfile
types:
runs-on: big-boi
runs-on: ubuntu-latest
needs: setup
strategy:
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -85,7 +85,7 @@ jobs:
- name: Run docker compose
run: |
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
- name: Restore dependencies cache
uses: actions/cache@v5

View File

@@ -11,7 +11,7 @@ jobs:
steps:
# - name: Wait some time for all actions to start
# run: sleep 30
- uses: actions/checkout@v6
- uses: actions/checkout@v4
# with:
# fetch-depth: 0
- name: Set up Python

View File

@@ -8,7 +8,7 @@ AutoGPT Platform is a monorepo containing:
- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`backend/api/auth`, `backend/logging`): Auth, logging, and common utilities integrated into backend
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
## Component Documentation

View File

@@ -0,0 +1,3 @@
# AutoGPT Libs
This is a new project to store shared functionality across different services in the AutoGPT Platform (e.g. authentication)

View File

@@ -1,6 +1,6 @@
import hashlib
from backend.api.auth.api_key.keysmith import APIKeySmith
from autogpt_libs.api_key.keysmith import APIKeySmith
def test_generate_api_key():

View File

@@ -9,7 +9,7 @@ import os
import pytest
from pytest_mock import MockerFixture
from backend.api.auth.config import AuthConfigError, Settings
from autogpt_libs.auth.config import AuthConfigError, Settings
def test_environment_variable_precedence(mocker: MockerFixture):
@@ -228,7 +228,7 @@ def test_no_crypto_warning(mocker: MockerFixture, caplog: pytest.LogCaptureFixtu
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
# Mock has_crypto to return False
mocker.patch("backend.api.auth.config.has_crypto", False)
mocker.patch("autogpt_libs.auth.config.has_crypto", False)
with caplog.at_level(logging.WARNING):
Settings()

View File

@@ -43,7 +43,7 @@ def get_optional_user_id(
try:
# Parse JWT token to get user ID
from backend.api.auth.jwt_utils import parse_jwt_token
from autogpt_libs.auth.jwt_utils import parse_jwt_token
payload = parse_jwt_token(credentials.credentials)
return payload.get("sub")

View File

@@ -11,12 +11,12 @@ from fastapi import FastAPI, HTTPException, Request, Security
from fastapi.testclient import TestClient
from pytest_mock import MockerFixture
from backend.api.auth.dependencies import (
from autogpt_libs.auth.dependencies import (
get_user_id,
requires_admin_user,
requires_user,
)
from backend.api.auth.models import User
from autogpt_libs.auth.models import User
class TestAuthDependencies:
@@ -53,7 +53,7 @@ class TestAuthDependencies:
# Mock get_jwt_payload to return our test payload
mocker.patch(
"backend.api.auth.dependencies.get_jwt_payload", return_value=jwt_payload
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user = await requires_user(jwt_payload)
assert isinstance(user, User)
@@ -70,7 +70,7 @@ class TestAuthDependencies:
}
mocker.patch(
"backend.api.auth.dependencies.get_jwt_payload", return_value=jwt_payload
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user = await requires_user(jwt_payload)
assert user.user_id == "admin-456"
@@ -105,7 +105,7 @@ class TestAuthDependencies:
}
mocker.patch(
"backend.api.auth.dependencies.get_jwt_payload", return_value=jwt_payload
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user = await requires_admin_user(jwt_payload)
assert user.user_id == "admin-789"
@@ -137,7 +137,7 @@ class TestAuthDependencies:
jwt_payload = {"sub": "user-id-xyz", "role": "user"}
mocker.patch(
"backend.api.auth.dependencies.get_jwt_payload", return_value=jwt_payload
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user_id = await get_user_id(request, jwt_payload)
assert user_id == "user-id-xyz"
@@ -344,7 +344,7 @@ class TestAuthDependenciesEdgeCases:
):
"""Test that errors propagate correctly through dependencies."""
# Import verify_user to test it directly since dependencies use FastAPI Security
from backend.api.auth.jwt_utils import verify_user
from autogpt_libs.auth.jwt_utils import verify_user
with pytest.raises(HTTPException) as exc_info:
verify_user(payload, admin_only=admin_only)
@@ -354,7 +354,7 @@ class TestAuthDependenciesEdgeCases:
async def test_dependency_valid_user(self):
"""Test valid user case for dependency."""
# Import verify_user to test it directly since dependencies use FastAPI Security
from backend.api.auth.jwt_utils import verify_user
from autogpt_libs.auth.jwt_utils import verify_user
# Valid case
user = verify_user({"sub": "user", "role": "user"}, admin_only=False)
@@ -376,16 +376,16 @@ class TestAdminImpersonation:
}
# Mock verify_user to return admin user data
mock_verify_user = mocker.patch("backend.api.auth.dependencies.verify_user")
mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
mock_verify_user.return_value = Mock(
user_id="admin-456", email="admin@example.com", role="admin"
)
# Mock logger to verify audit logging
mock_logger = mocker.patch("backend.api.auth.dependencies.logger")
mock_logger = mocker.patch("autogpt_libs.auth.dependencies.logger")
mocker.patch(
"backend.api.auth.dependencies.get_jwt_payload", return_value=jwt_payload
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user_id = await get_user_id(request, jwt_payload)
@@ -412,13 +412,13 @@ class TestAdminImpersonation:
}
# Mock verify_user to return regular user data
mock_verify_user = mocker.patch("backend.api.auth.dependencies.verify_user")
mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
mock_verify_user.return_value = Mock(
user_id="regular-user", email="user@example.com", role="user"
)
mocker.patch(
"backend.api.auth.dependencies.get_jwt_payload", return_value=jwt_payload
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
with pytest.raises(HTTPException) as exc_info:
@@ -439,7 +439,7 @@ class TestAdminImpersonation:
}
mocker.patch(
"backend.api.auth.dependencies.get_jwt_payload", return_value=jwt_payload
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user_id = await get_user_id(request, jwt_payload)
@@ -459,7 +459,7 @@ class TestAdminImpersonation:
}
mocker.patch(
"backend.api.auth.dependencies.get_jwt_payload", return_value=jwt_payload
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user_id = await get_user_id(request, jwt_payload)
@@ -479,16 +479,16 @@ class TestAdminImpersonation:
}
# Mock verify_user to return admin user data
mock_verify_user = mocker.patch("backend.api.auth.dependencies.verify_user")
mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
mock_verify_user.return_value = Mock(
user_id="admin-999", email="superadmin@company.com", role="admin"
)
# Mock logger to capture audit trail
mock_logger = mocker.patch("backend.api.auth.dependencies.logger")
mock_logger = mocker.patch("autogpt_libs.auth.dependencies.logger")
mocker.patch(
"backend.api.auth.dependencies.get_jwt_payload", return_value=jwt_payload
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user_id = await get_user_id(request, jwt_payload)
@@ -515,7 +515,7 @@ class TestAdminImpersonation:
}
mocker.patch(
"backend.api.auth.dependencies.get_jwt_payload", return_value=jwt_payload
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user_id = await get_user_id(request, jwt_payload)
@@ -535,16 +535,16 @@ class TestAdminImpersonation:
}
# Mock verify_user to return admin user data
mock_verify_user = mocker.patch("backend.api.auth.dependencies.verify_user")
mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
mock_verify_user.return_value = Mock(
user_id="admin-456", email="admin@example.com", role="admin"
)
# Mock logger
mock_logger = mocker.patch("backend.api.auth.dependencies.logger")
mock_logger = mocker.patch("autogpt_libs.auth.dependencies.logger")
mocker.patch(
"backend.api.auth.dependencies.get_jwt_payload", return_value=jwt_payload
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user_id = await get_user_id(request, jwt_payload)

View File

@@ -3,11 +3,13 @@ Comprehensive tests for auth helpers module to achieve 100% coverage.
Tests OpenAPI schema generation and authentication response handling.
"""
from unittest import mock
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from backend.api.auth.helpers import add_auth_responses_to_openapi
from backend.api.auth.jwt_utils import bearer_jwt_auth
from autogpt_libs.auth.helpers import add_auth_responses_to_openapi
from autogpt_libs.auth.jwt_utils import bearer_jwt_auth
def test_add_auth_responses_to_openapi_basic():
@@ -17,7 +19,7 @@ def test_add_auth_responses_to_openapi_basic():
# Add some test endpoints with authentication
from fastapi import Depends
from backend.api.auth.dependencies import requires_user
from autogpt_libs.auth.dependencies import requires_user
@app.get("/protected", dependencies=[Depends(requires_user)])
def protected_endpoint():
@@ -62,7 +64,7 @@ def test_add_auth_responses_to_openapi_with_security():
# Mock endpoint with security
from fastapi import Security
from backend.api.auth.dependencies import get_user_id
from autogpt_libs.auth.dependencies import get_user_id
@app.get("/secured")
def secured_endpoint(user_id: str = Security(get_user_id)):
@@ -128,7 +130,7 @@ def test_add_auth_responses_to_openapi_existing_responses():
from fastapi import Security
from backend.api.auth.jwt_utils import get_jwt_payload
from autogpt_libs.auth.jwt_utils import get_jwt_payload
@app.get(
"/with-responses",
@@ -195,8 +197,8 @@ def test_add_auth_responses_to_openapi_multiple_security_schemes():
from fastapi import Security
from backend.api.auth.dependencies import requires_admin_user, requires_user
from backend.api.auth.models import User
from autogpt_libs.auth.dependencies import requires_admin_user, requires_user
from autogpt_libs.auth.models import User
@app.get("/multi-auth")
def multi_auth(
@@ -225,29 +227,26 @@ def test_add_auth_responses_to_openapi_empty_components():
"""Test when OpenAPI schema has no components section initially."""
app = FastAPI()
def mock_openapi():
schema = get_openapi(
title=app.title,
version=app.version,
routes=app.routes,
)
# Remove components if it exists to test component creation
# Mock get_openapi to return schema without components
original_get_openapi = get_openapi
def mock_get_openapi(*args, **kwargs):
schema = original_get_openapi(*args, **kwargs)
# Remove components if it exists
if "components" in schema:
del schema["components"]
return schema
# Replace app's openapi method
app.openapi = mock_openapi
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
# Apply customization
add_auth_responses_to_openapi(app)
# Apply customization (this wraps our mock)
add_auth_responses_to_openapi(app)
schema = app.openapi()
schema = app.openapi()
# Components should be created
assert "components" in schema
assert "responses" in schema["components"]
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
# Components should be created
assert "components" in schema
assert "responses" in schema["components"]
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
def test_add_auth_responses_to_openapi_all_http_methods():
@@ -256,7 +255,7 @@ def test_add_auth_responses_to_openapi_all_http_methods():
from fastapi import Security
from backend.api.auth.jwt_utils import get_jwt_payload
from autogpt_libs.auth.jwt_utils import get_jwt_payload
@app.get("/resource")
def get_resource(jwt: dict = Security(get_jwt_payload)):
@@ -334,59 +333,53 @@ def test_endpoint_without_responses_section():
app = FastAPI()
from fastapi import Security
from fastapi.openapi.utils import get_openapi as original_get_openapi
from backend.api.auth.jwt_utils import get_jwt_payload
from autogpt_libs.auth.jwt_utils import get_jwt_payload
# Create endpoint
@app.get("/no-responses")
def endpoint_without_responses(jwt: dict = Security(get_jwt_payload)):
return {"data": "test"}
# Create a mock openapi method that removes responses from the endpoint
def mock_openapi():
schema = get_openapi(
title=app.title,
version=app.version,
routes=app.routes,
)
# Remove responses from our endpoint to test response creation
# Mock get_openapi to remove responses from the endpoint
def mock_get_openapi(*args, **kwargs):
schema = original_get_openapi(*args, **kwargs)
# Remove responses from our endpoint to trigger line 40
if "/no-responses" in schema.get("paths", {}):
if "get" in schema["paths"]["/no-responses"]:
# Delete responses to force the code to create it
if "responses" in schema["paths"]["/no-responses"]["get"]:
del schema["paths"]["/no-responses"]["get"]["responses"]
return schema
# Replace app's openapi method
app.openapi = mock_openapi
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
# Apply customization
add_auth_responses_to_openapi(app)
# Apply customization (this wraps our mock)
add_auth_responses_to_openapi(app)
# Get schema and verify 401 was added
schema = app.openapi()
# Get schema and verify 401 was added
schema = app.openapi()
# The endpoint should now have 401 response
if "/no-responses" in schema["paths"]:
if "get" in schema["paths"]["/no-responses"]:
responses = schema["paths"]["/no-responses"]["get"].get("responses", {})
assert "401" in responses
assert (
responses["401"]["$ref"]
== "#/components/responses/HTTP401NotAuthenticatedError"
)
# The endpoint should now have 401 response
if "/no-responses" in schema["paths"]:
if "get" in schema["paths"]["/no-responses"]:
responses = schema["paths"]["/no-responses"]["get"].get("responses", {})
assert "401" in responses
assert (
responses["401"]["$ref"]
== "#/components/responses/HTTP401NotAuthenticatedError"
)
def test_components_with_existing_responses():
"""Test when components already has a responses section."""
app = FastAPI()
# Create a mock openapi method that adds existing components/responses
def mock_openapi():
schema = get_openapi(
title=app.title,
version=app.version,
routes=app.routes,
)
# Mock get_openapi to return schema with existing components/responses
from fastapi.openapi.utils import get_openapi as original_get_openapi
def mock_get_openapi(*args, **kwargs):
schema = original_get_openapi(*args, **kwargs)
# Add existing components/responses
if "components" not in schema:
schema["components"] = {}
@@ -395,21 +388,21 @@ def test_components_with_existing_responses():
}
return schema
# Replace app's openapi method
app.openapi = mock_openapi
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
# Apply customization
add_auth_responses_to_openapi(app)
# Apply customization (this wraps our mock)
add_auth_responses_to_openapi(app)
schema = app.openapi()
schema = app.openapi()
# Both responses should exist
assert "ExistingResponse" in schema["components"]["responses"]
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
# Both responses should exist
assert "ExistingResponse" in schema["components"]["responses"]
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
# Verify our 401 response structure
error_response = schema["components"]["responses"]["HTTP401NotAuthenticatedError"]
assert error_response["description"] == "Authentication required"
# Verify our 401 response structure
error_response = schema["components"]["responses"][
"HTTP401NotAuthenticatedError"
]
assert error_response["description"] == "Authentication required"
def test_openapi_schema_persistence():
@@ -418,7 +411,7 @@ def test_openapi_schema_persistence():
from fastapi import Security
from backend.api.auth.jwt_utils import get_jwt_payload
from autogpt_libs.auth.jwt_utils import get_jwt_payload
@app.get("/test")
def test_endpoint(jwt: dict = Security(get_jwt_payload)):

View File

@@ -12,9 +12,9 @@ from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
from pytest_mock import MockerFixture
from backend.api.auth import config, jwt_utils
from backend.api.auth.config import Settings
from backend.api.auth.models import User
from autogpt_libs.auth import config, jwt_utils
from autogpt_libs.auth.config import Settings
from autogpt_libs.auth.models import User
MOCK_JWT_SECRET = "test-secret-key-with-at-least-32-characters"
TEST_USER_PAYLOAD = {

View File

@@ -0,0 +1,33 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class RateLimitSettings(BaseSettings):
redis_host: str = Field(
default="redis://localhost:6379",
description="Redis host",
validation_alias="REDIS_HOST",
)
redis_port: str = Field(
default="6379", description="Redis port", validation_alias="REDIS_PORT"
)
redis_password: Optional[str] = Field(
default=None,
description="Redis password",
validation_alias="REDIS_PASSWORD",
)
requests_per_minute: int = Field(
default=60,
description="Maximum number of requests allowed per minute per API key",
validation_alias="RATE_LIMIT_REQUESTS_PER_MINUTE",
)
model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")
RATE_LIMIT_SETTINGS = RateLimitSettings()

View File

@@ -0,0 +1,51 @@
import time
from typing import Tuple
from redis import Redis
from .config import RATE_LIMIT_SETTINGS
class RateLimiter:
def __init__(
self,
redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
redis_password: str | None = RATE_LIMIT_SETTINGS.redis_password,
requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
):
self.redis = Redis(
host=redis_host,
port=int(redis_port),
password=redis_password,
decode_responses=True,
)
self.window = 60
self.max_requests = requests_per_minute
async def check_rate_limit(self, api_key_id: str) -> Tuple[bool, int, int]:
"""
Check if request is within rate limits.
Args:
api_key_id: The API key identifier to check
Returns:
Tuple of (is_allowed, remaining_requests, reset_time)
"""
now = time.time()
window_start = now - self.window
key = f"ratelimit:{api_key_id}:1min"
pipe = self.redis.pipeline()
pipe.zremrangebyscore(key, 0, window_start)
pipe.zadd(key, {str(now): now})
pipe.zcount(key, window_start, now)
pipe.expire(key, self.window)
_, _, request_count, _ = pipe.execute()
remaining = max(0, self.max_requests - request_count)
reset_time = int(now + self.window)
return request_count <= self.max_requests, remaining, reset_time

View File

@@ -0,0 +1,32 @@
from fastapi import HTTPException, Request
from starlette.middleware.base import RequestResponseEndpoint
from .limiter import RateLimiter
async def rate_limit_middleware(request: Request, call_next: RequestResponseEndpoint):
"""FastAPI middleware for rate limiting API requests."""
limiter = RateLimiter()
if not request.url.path.startswith("/api"):
return await call_next(request)
api_key = request.headers.get("Authorization")
if not api_key:
return await call_next(request)
api_key = api_key.replace("Bearer ", "")
is_allowed, remaining, reset_time = await limiter.check_rate_limit(api_key)
if not is_allowed:
raise HTTPException(
status_code=429, detail="Rate limit exceeded. Please try again later."
)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limiter.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
response.headers["X-RateLimit-Reset"] = str(reset_time)
return response

View File

@@ -0,0 +1,76 @@
from typing import Annotated, Any, Literal, Optional, TypedDict
from uuid import uuid4
from pydantic import BaseModel, Field, SecretStr, field_serializer
class _BaseCredentials(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
provider: str
title: Optional[str]
@field_serializer("*")
def dump_secret_strings(value: Any, _info):
if isinstance(value, SecretStr):
return value.get_secret_value()
return value
class OAuth2Credentials(_BaseCredentials):
type: Literal["oauth2"] = "oauth2"
username: Optional[str]
"""Username of the third-party service user that these credentials belong to"""
access_token: SecretStr
access_token_expires_at: Optional[int]
"""Unix timestamp (seconds) indicating when the access token expires (if at all)"""
refresh_token: Optional[SecretStr]
refresh_token_expires_at: Optional[int]
"""Unix timestamp (seconds) indicating when the refresh token expires (if at all)"""
scopes: list[str]
metadata: dict[str, Any] = Field(default_factory=dict)
def bearer(self) -> str:
return f"Bearer {self.access_token.get_secret_value()}"
class APIKeyCredentials(_BaseCredentials):
type: Literal["api_key"] = "api_key"
api_key: SecretStr
expires_at: Optional[int]
"""Unix timestamp (seconds) indicating when the API key expires (if at all)"""
def bearer(self) -> str:
return f"Bearer {self.api_key.get_secret_value()}"
Credentials = Annotated[
OAuth2Credentials | APIKeyCredentials,
Field(discriminator="type"),
]
CredentialsType = Literal["api_key", "oauth2"]
class OAuthState(BaseModel):
token: str
provider: str
expires_at: int
code_verifier: Optional[str] = None
scopes: list[str]
"""Unix timestamp (seconds) indicating when this OAuth state expires"""
class UserMetadata(BaseModel):
integration_credentials: list[Credentials] = Field(default_factory=list)
integration_oauth_states: list[OAuthState] = Field(default_factory=list)
class UserMetadataRaw(TypedDict, total=False):
integration_credentials: list[dict]
integration_oauth_states: list[dict]
class UserIntegrations(BaseModel):
credentials: list[Credentials] = Field(default_factory=list)
oauth_states: list[OAuthState] = Field(default_factory=list)

2896
autogpt_platform/autogpt_libs/poetry.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,40 @@
[tool.poetry]
name = "autogpt-libs"
version = "0.2.0"
description = "Shared libraries across AutoGPT Platform"
authors = ["AutoGPT team <info@agpt.co>"]
readme = "README.md"
packages = [{ include = "autogpt_libs" }]
[tool.poetry.dependencies]
python = ">=3.10,<4.0"
colorama = "^0.4.6"
cryptography = "^46.0"
expiringdict = "^1.2.2"
fastapi = "^0.128.0"
google-cloud-logging = "^3.13.0"
launchdarkly-server-sdk = "^9.14.1"
pydantic = "^2.12.5"
pydantic-settings = "^2.12.0"
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
redis = "^6.2.0"
supabase = "^2.27.2"
uvicorn = "^0.40.0"
[tool.poetry.group.dev.dependencies]
pyright = "^1.1.408"
pytest = "^8.4.1"
pytest-asyncio = "^1.3.0"
pytest-mock = "^3.15.1"
pytest-cov = "^6.2.1"
ruff = "^0.15.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.ruff]
line-length = 88
[tool.ruff.lint]
extend-select = ["I"] # sort dependencies

View File

@@ -39,7 +39,8 @@ ENV PATH=/opt/poetry/bin:$PATH
RUN pip3 install poetry --break-system-packages
# Copy and install dependencies (autogpt_libs merged into backend - OPEN-2998)
# Copy and install dependencies
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
WORKDIR /app/autogpt_platform/backend
RUN poetry install --no-ansi --no-root
@@ -82,9 +83,11 @@ COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-pyth
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
# autogpt_libs merged into backend (OPEN-2998)
RUN mkdir -p /app/autogpt_platform/autogpt_libs
RUN mkdir -p /app/autogpt_platform/backend
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
WORKDIR /app/autogpt_platform/backend

View File

@@ -132,7 +132,7 @@ def test_endpoint_success(snapshot: Snapshot):
### Testing with Authentication
For the main API routes that use JWT authentication, auth is provided by the `backend.api.auth` module. If the test actually uses the `user_id`, the recommended approach for testing is to mock the `get_jwt_payload` function, which underpins all higher-level auth functions used in the API (`requires_user`, `requires_admin_user`, `get_user_id`).
For the main API routes that use JWT authentication, auth is provided by the `autogpt_libs.auth` module. If the test actually uses the `user_id`, the recommended approach for testing is to mock the `get_jwt_payload` function, which underpins all higher-level auth functions used in the API (`requires_user`, `requires_admin_user`, `get_user_id`).
If the test doesn't need the `user_id` specifically, mocking is not necessary as during tests auth is disabled anyway (see `conftest.py`).
@@ -158,7 +158,7 @@ client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from backend.api.auth.jwt_utils import get_jwt_payload
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user['get_jwt_payload']
yield
@@ -171,7 +171,7 @@ For admin-only endpoints, use `mock_jwt_admin` instead:
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_admin):
"""Setup auth overrides for admin tests"""
from backend.api.auth.jwt_utils import get_jwt_payload
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin['get_jwt_payload']
yield

View File

@@ -1,10 +1,10 @@
import logging
import typing
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, Body, Security
from prisma.enums import CreditTransactionType
from backend.api.auth import get_user_id, requires_admin_user
from backend.data.credit import admin_get_user_history, get_user_credit_model
from backend.util.json import SafeJson

View File

@@ -6,9 +6,9 @@ import fastapi.testclient
import prisma.enums
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from pytest_snapshot.plugin import Snapshot
from backend.api.auth.jwt_utils import get_jwt_payload
from backend.data.model import UserTransaction
from backend.util.json import SafeJson
from backend.util.models import Pagination

View File

@@ -3,10 +3,10 @@ import logging
from datetime import datetime
from typing import Optional
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, HTTPException, Security
from pydantic import BaseModel, Field
from backend.api.auth import get_user_id, requires_admin_user
from backend.blocks.llm import LlmModel
from backend.data.analytics import (
AccuracyTrendsResponse,

View File

@@ -2,11 +2,11 @@ import logging
import tempfile
import typing
import autogpt_libs.auth
import fastapi
import fastapi.responses
import prisma.enums
import backend.api.auth
import backend.api.features.store.cache as store_cache
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
prefix="/admin",
tags=["store", "admin"],
dependencies=[fastapi.Security(backend.api.auth.requires_admin_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
)
@@ -73,7 +73,7 @@ async def get_admin_listings_with_versions(
async def review_submission(
store_listing_version_id: str,
request: store_model.ReviewSubmissionRequest,
user_id: str = fastapi.Security(backend.api.auth.get_user_id),
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Review a store listing submission.
@@ -117,7 +117,7 @@ async def review_submission(
tags=["store", "admin"],
)
async def admin_download_agent_file(
user_id: str = fastapi.Security(backend.api.auth.get_user_id),
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
store_listing_version_id: str = fastapi.Path(
..., description="The ID of the agent to download"
),

View File

@@ -5,10 +5,10 @@ from typing import Annotated
import fastapi
import pydantic
from autogpt_libs.auth import get_user_id
from autogpt_libs.auth.dependencies import requires_user
import backend.data.analytics
from backend.api.auth import get_user_id
from backend.api.auth.dependencies import requires_user
router = fastapi.APIRouter(dependencies=[fastapi.Security(requires_user)])
logger = logging.getLogger(__name__)

View File

@@ -20,7 +20,7 @@ client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module."""
from backend.api.auth.jwt_utils import get_jwt_payload
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield

View File

@@ -2,8 +2,8 @@ import logging
from typing import Annotated, Sequence
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from backend.api.auth.dependencies import get_user_id, requires_user
from backend.integrations.providers import ProviderName
from backend.util.models import Pagination

View File

@@ -93,12 +93,6 @@ class ChatConfig(BaseSettings):
description="Name of the prompt in Langfuse to fetch",
)
# Extended thinking configuration for Claude models
thinking_enabled: bool = Field(
default=True,
description="Enable adaptive thinking for Claude models via OpenRouter",
)
@field_validator("api_key", mode="before")
@classmethod
def get_api_key(cls, v):

View File

@@ -45,7 +45,10 @@ async def create_chat_session(
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
)
return await PrismaChatSession.prisma().create(data=data)
return await PrismaChatSession.prisma().create(
data=data,
include={"Messages": True},
)
async def update_chat_session(

View File

@@ -10,8 +10,6 @@ from typing import Any
from pydantic import BaseModel, Field
from backend.util.json import dumps as json_dumps
class ResponseType(str, Enum):
"""Types of streaming responses following AI SDK protocol."""
@@ -20,10 +18,6 @@ class ResponseType(str, Enum):
START = "start"
FINISH = "finish"
# Step lifecycle (one LLM API call within a message)
START_STEP = "start-step"
FINISH_STEP = "finish-step"
# Text streaming
TEXT_START = "text-start"
TEXT_DELTA = "text-delta"
@@ -63,16 +57,6 @@ class StreamStart(StreamBaseResponse):
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-protocol fields like taskId."""
import json
data: dict[str, Any] = {
"type": self.type.value,
"messageId": self.messageId,
}
return f"data: {json.dumps(data)}\n\n"
class StreamFinish(StreamBaseResponse):
"""End of message/stream."""
@@ -80,26 +64,6 @@ class StreamFinish(StreamBaseResponse):
type: ResponseType = ResponseType.FINISH
class StreamStartStep(StreamBaseResponse):
"""Start of a step (one LLM API call within a message).
The AI SDK uses this to add a step-start boundary to message.parts,
enabling visual separation between multiple LLM calls in a single message.
"""
type: ResponseType = ResponseType.START_STEP
class StreamFinishStep(StreamBaseResponse):
"""End of a step (one LLM API call within a message).
The AI SDK uses this to reset activeTextParts and activeReasoningParts,
so the next LLM call in a tool-call continuation starts with clean state.
"""
type: ResponseType = ResponseType.FINISH_STEP
# ========== Text Streaming ==========
@@ -153,7 +117,7 @@ class StreamToolOutputAvailable(StreamBaseResponse):
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
toolCallId: str = Field(..., description="Tool call ID this responds to")
output: str | dict[str, Any] = Field(..., description="Tool execution output")
# Keep these for internal backend use
# Additional fields for internal use (not part of AI SDK spec but useful)
toolName: str | None = Field(
default=None, description="Name of the tool that was executed"
)
@@ -161,17 +125,6 @@ class StreamToolOutputAvailable(StreamBaseResponse):
default=True, description="Whether the tool execution succeeded"
)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-spec fields."""
import json
data = {
"type": self.type.value,
"toolCallId": self.toolCallId,
"output": self.output,
}
return f"data: {json.dumps(data)}\n\n"
# ========== Other ==========
@@ -195,18 +148,6 @@ class StreamError(StreamBaseResponse):
default=None, description="Additional error details"
)
def to_sse(self) -> str:
"""Convert to SSE format, only emitting fields required by AI SDK protocol.
The AI SDK uses z.strictObject({type, errorText}) which rejects
any extra fields like `code` or `details`.
"""
data = {
"type": self.type.value,
"errorText": self.errorText,
}
return f"data: {json_dumps(data)}\n\n"
class StreamHeartbeat(StreamBaseResponse):
"""Heartbeat to keep SSE connection alive during long-running operations.

View File

@@ -5,11 +5,11 @@ import uuid as uuid_module
from collections.abc import AsyncGenerator
from typing import Annotated
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.api import auth
from backend.util.exceptions import NotFoundError
from . import service as chat_service
@@ -17,29 +17,7 @@ from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat
from .tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
AgentPreviewResponse,
AgentSavedResponse,
AgentsFoundResponse,
BlockListResponse,
BlockOutputResponse,
ClarificationNeededResponse,
DocPageResponse,
DocSearchResultsResponse,
ErrorResponse,
ExecutionStartedResponse,
InputValidationErrorResponse,
NeedLoginResponse,
NoResultsResponse,
OperationInProgressResponse,
OperationPendingResponse,
OperationStartedResponse,
SetupRequirementsResponse,
UnderstandingUpdatedResponse,
)
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
config = ChatConfig()
@@ -288,36 +266,12 @@ async def stream_chat_post(
"""
import asyncio
import time
stream_start_time = time.perf_counter()
log_meta = {"component": "ChatStream", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
f"user={user_id}, message_len={len(request.message)}",
extra={"json_fields": log_meta},
)
session = await _validate_and_get_session(session_id, user_id)
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - stream_start_time) * 1000,
}
},
)
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
log_meta["task_id"] = task_id
task_create_start = time.perf_counter()
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
@@ -326,28 +280,14 @@ async def stream_chat_post(
tool_name="chat",
operation_id=operation_id,
)
logger.info(
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
}
},
)
# Background task that runs the AI generation independently of SSE connection
async def run_ai_generation():
import time as time_module
gen_start_time = time_module.perf_counter()
logger.info(
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)
first_chunk_time, ttfc = None, None
chunk_count = 0
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
@@ -355,79 +295,25 @@ async def stream_chat_post(
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
):
chunk_count += 1
if first_chunk_time is None:
first_chunk_time = time_module.perf_counter()
ttfc = first_chunk_time - gen_start_time
logger.info(
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"chunk_type": type(chunk).__name__,
"time_to_first_chunk_ms": ttfc * 1000,
}
},
)
# Write to Redis (subscribers will receive via XREAD)
await stream_registry.publish_chunk(task_id, chunk)
gen_end_time = time_module.perf_counter()
total_time = (gen_end_time - gen_start_time) * 1000
logger.info(
f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
f"task={task_id}, session={session_id}, "
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"time_to_first_chunk_ms": (
ttfc * 1000 if ttfc is not None else None
),
"n_chunks": chunk_count,
}
},
)
# Mark task as completed
await stream_registry.mark_task_completed(task_id, "completed")
except Exception as e:
elapsed = time_module.perf_counter() - gen_start_time
logger.error(
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed * 1000,
"error": str(e),
}
},
f"Error in background AI generation for session {session_id}: {e}"
)
await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task
bg_task = asyncio.create_task(run_ai_generation())
await stream_registry.set_task_asyncio_task(task_id, bg_task)
setup_time = (time.perf_counter() - stream_start_time) * 1000
logger.info(
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
# SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]:
import time as time_module
event_gen_start = time_module.perf_counter()
logger.info(
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
f"user={user_id}",
extra={"json_fields": log_meta},
)
subscriber_queue = None
first_chunk_yielded = False
chunks_yielded = 0
try:
# Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task(
@@ -442,70 +328,22 @@ async def stream_chat_post(
return
# Read from the subscriber queue and yield to SSE
logger.info(
"[TIMING] Starting to read from subscriber_queue",
extra={"json_fields": log_meta},
)
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
chunks_yielded += 1
if not first_chunk_yielded:
first_chunk_yielded = True
elapsed = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
f"type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"chunk_type": type(chunk).__name__,
"elapsed_ms": elapsed * 1000,
}
},
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
f"n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunks_yielded,
"total_time_ms": total_time * 1000,
}
},
)
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except GeneratorExit:
logger.info(
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunks_yielded,
"reason": "client_disconnect",
}
},
)
pass # Client disconnected - background task continues
except Exception as e:
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
logger.error(
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
extra={
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
},
)
logger.error(f"Error in SSE stream for task {task_id}: {e}")
finally:
# Unsubscribe when client disconnects or stream ends to prevent resource leak
if subscriber_queue is not None:
@@ -519,18 +357,6 @@ async def stream_chat_post(
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time * 1000,
"chunks_yielded": chunks_yielded,
}
},
)
yield "data: [DONE]\n\n"
return StreamingResponse(
@@ -548,90 +374,63 @@ async def stream_chat_post(
@router.get(
"/sessions/{session_id}/stream",
)
async def resume_session_stream(
async def stream_chat_get(
session_id: str,
message: Annotated[str, Query(min_length=1, max_length=10000)],
user_id: str | None = Depends(auth.get_user_id),
is_user_message: bool = Query(default=True),
):
"""
Resume an active stream for a session.
Stream chat responses for a session (GET - legacy endpoint).
Called by the AI SDK's ``useChat(resume: true)`` on page load.
Checks for an active (in-progress) task on the session and either replays
the full SSE stream or returns 204 No Content if nothing is running.
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
- Text fragments as they are generated
- Tool call UI elements (if invoked)
- Tool execution results
Args:
session_id: The chat session identifier.
session_id: The chat session identifier to associate with the streamed messages.
message: The user's new message to process.
user_id: Optional authenticated user ID.
is_user_message: Whether the message is a user message.
Returns:
StreamingResponse (SSE) when an active stream exists,
or 204 No Content when there is nothing to resume.
StreamingResponse: SSE-formatted response chunks.
"""
import asyncio
active_task, _last_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
if not active_task:
return Response(status_code=204)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=active_task.task_id,
user_id=user_id,
last_message_id="0-0", # Full replay so useChat rebuilds the message
)
if subscriber_queue is None:
return Response(status_code=204)
session = await _validate_and_get_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
chunk_count = 0
first_chunk_type: str | None = None
try:
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
if chunk_count < 3:
logger.info(
"Resume stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
except GeneratorExit:
pass
except Exception as e:
logger.error(f"Error in resume stream for session {session_id}: {e}")
finally:
try:
await stream_registry.unsubscribe_from_task(
active_task.task_id, subscriber_queue
async for chunk in chat_service.stream_chat_completion(
session_id,
message,
is_user_message=is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
exc_info=True,
)
logger.info(
"Resume stream completed",
extra={
"session_id": session_id,
"n_chunks": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
yield "data: [DONE]\n\n"
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@@ -639,8 +438,8 @@ async def resume_session_stream(
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
"X-Accel-Buffering": "no", # Disable nginx buffering
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
},
)
@@ -952,42 +751,3 @@ async def health_check() -> dict:
"service": "chat",
"version": "0.1.0",
}
# ========== Schema Export (for OpenAPI / Orval codegen) ==========
ToolResponseUnion = (
AgentsFoundResponse
| NoResultsResponse
| AgentDetailsResponse
| SetupRequirementsResponse
| ExecutionStartedResponse
| NeedLoginResponse
| ErrorResponse
| InputValidationErrorResponse
| AgentOutputResponse
| UnderstandingUpdatedResponse
| AgentPreviewResponse
| AgentSavedResponse
| ClarificationNeededResponse
| BlockListResponse
| BlockOutputResponse
| DocSearchResultsResponse
| DocPageResponse
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
)
@router.get(
"/schema/tool-responses",
response_model=ToolResponseUnion,
include_in_schema=True,
summary="[Dummy] Tool response type export for codegen",
description="This endpoint is not meant to be called. It exists solely to "
"expose tool response models in the OpenAPI schema for frontend codegen.",
)
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
"""Never called at runtime. Exists only so Orval generates TS types."""
raise HTTPException(status_code=501, detail="Schema-only endpoint")

View File

@@ -52,10 +52,8 @@ from .response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
@@ -353,10 +351,6 @@ async def stream_chat_completion(
retry_count: int = 0,
session: ChatSession | None = None,
context: dict[str, str] | None = None, # {url: str, content: str}
_continuation_message_id: (
str | None
) = None, # Internal: reuse message ID for tool call continuations
_task_id: str | None = None, # Internal: task ID for SSE reconnection support
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Main entry point for streaming chat completions with database handling.
@@ -377,45 +371,21 @@ async def stream_chat_completion(
ValueError: If max_context_messages is exceeded
"""
completion_start = time.monotonic()
# Build log metadata for structured logging
log_meta = {"component": "ChatService", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
extra={
"json_fields": {
**log_meta,
"message_len": len(message) if message else 0,
"is_user_message": is_user_message,
}
},
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
)
# Only fetch from Redis if session not provided (initial call)
if session is None:
fetch_start = time.monotonic()
session = await get_chat_session(session_id, user_id)
fetch_time = (time.monotonic() - fetch_start) * 1000
logger.info(
f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
f"n_messages={len(session.messages) if session else 0}",
extra={
"json_fields": {
**log_meta,
"duration_ms": fetch_time,
"n_messages": len(session.messages) if session else 0,
}
},
f"Fetched session from Redis: {session.session_id if session else 'None'}, "
f"message_count={len(session.messages) if session else 0}"
)
else:
logger.info(
f"[TIMING] Using provided session, messages={len(session.messages)}",
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
f"Using provided session object: {session.session_id}, "
f"message_count={len(session.messages)}"
)
if not session:
@@ -436,25 +406,17 @@ async def stream_chat_completion(
# Track user message in PostHog
if is_user_message:
posthog_start = time.monotonic()
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(message),
)
posthog_time = (time.monotonic() - posthog_start) * 1000
logger.info(
f"[TIMING] track_user_message took {posthog_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
)
upsert_start = time.monotonic()
session = await upsert_chat_session(session)
upsert_time = (time.monotonic() - upsert_start) * 1000
logger.info(
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
f"Upserting session: {session.session_id} with user id {session.user_id}, "
f"message_count={len(session.messages)}"
)
session = await upsert_chat_session(session)
assert session, "Session not found"
# Generate title for new sessions on first user message (non-blocking)
@@ -492,13 +454,7 @@ async def stream_chat_completion(
asyncio.create_task(_update_title())
# Build system prompt with business understanding
prompt_start = time.monotonic()
system_prompt, understanding = await _build_system_prompt(user_id)
prompt_time = (time.monotonic() - prompt_start) * 1000
logger.info(
f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
)
# Initialize variables for streaming
assistant_response = ChatMessage(
@@ -523,27 +479,13 @@ async def stream_chat_completion(
# Generate unique IDs for AI SDK protocol
import uuid as uuid_module
is_continuation = _continuation_message_id is not None
message_id = _continuation_message_id or str(uuid_module.uuid4())
message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4())
# Only yield message start for the initial call, not for continuations.
setup_time = (time.monotonic() - completion_start) * 1000
logger.info(
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
if not is_continuation:
yield StreamStart(messageId=message_id, taskId=_task_id)
# Emit start-step before each LLM call (AI SDK uses this to add step boundaries)
yield StreamStartStep()
# Yield message start
yield StreamStart(messageId=message_id)
try:
logger.info(
"[TIMING] Calling _stream_chat_chunks",
extra={"json_fields": log_meta},
)
async for chunk in _stream_chat_chunks(
session=session,
tools=tools,
@@ -643,10 +585,6 @@ async def stream_chat_completion(
)
yield chunk
elif isinstance(chunk, StreamFinish):
if has_done_tool_call:
# Tool calls happened — close the step but don't send message-level finish.
# The continuation will open a new step, and finish will come at the end.
yield StreamFinishStep()
if not has_done_tool_call:
# Emit text-end before finish if we received text but haven't closed it
if has_received_text and not text_streaming_ended:
@@ -678,8 +616,6 @@ async def stream_chat_completion(
has_saved_assistant_message = True
has_yielded_end = True
# Emit finish-step before finish (resets AI SDK text/reasoning state)
yield StreamFinishStep()
yield chunk
elif isinstance(chunk, StreamError):
has_yielded_error = True
@@ -729,10 +665,6 @@ async def stream_chat_completion(
logger.info(
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
)
# Close the current step before retrying so the recursive call's
# StreamStartStep doesn't produce unbalanced step events.
if not has_yielded_end:
yield StreamFinishStep()
should_retry = True
else:
# Non-retryable error or max retries exceeded
@@ -768,7 +700,6 @@ async def stream_chat_completion(
error_response = StreamError(errorText=error_message)
yield error_response
if not has_yielded_end:
yield StreamFinishStep()
yield StreamFinish()
return
@@ -783,8 +714,6 @@ async def stream_chat_completion(
retry_count=retry_count + 1,
session=session,
context=context,
_continuation_message_id=message_id, # Reuse message ID since start was already sent
_task_id=_task_id,
):
yield chunk
return # Exit after retry to avoid double-saving in finally block
@@ -854,8 +783,6 @@ async def stream_chat_completion(
session=session, # Pass session object to avoid Redis refetch
context=context,
tool_call_response=str(tool_response_messages),
_continuation_message_id=message_id, # Reuse message ID to avoid duplicates
_task_id=_task_id,
):
yield chunk
@@ -966,21 +893,9 @@ async def _stream_chat_chunks(
SSE formatted JSON response objects
"""
import time as time_module
stream_chunks_start = time_module.perf_counter()
model = config.model
# Build log metadata for structured logging
log_meta = {"component": "ChatService", "session_id": session.session_id}
if session.user_id:
log_meta["user_id"] = session.user_id
logger.info(
f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
f"user={session.user_id}, n_messages={len(session.messages)}",
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
)
logger.info("Starting pure chat stream")
messages = session.to_openai_messages()
if system_prompt:
@@ -991,18 +906,12 @@ async def _stream_chat_chunks(
messages = [system_message] + messages
# Apply context window management
context_start = time_module.perf_counter()
context_result = await _manage_context_window(
messages=messages,
model=model,
api_key=config.api_key,
base_url=config.base_url,
)
context_time = (time_module.perf_counter() - context_start) * 1000
logger.info(
f"[TIMING] _manage_context_window took {context_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": context_time}},
)
if context_result.error:
if "System prompt dropped" in context_result.error:
@@ -1037,19 +946,9 @@ async def _stream_chat_chunks(
while retry_count <= MAX_RETRIES:
try:
elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000
retry_info = (
f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
)
logger.info(
f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"retry_count": retry_count,
}
},
f"Creating OpenAI chat completion stream..."
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
)
# Build extra_body for OpenRouter tracing and PostHog analytics
@@ -1066,11 +965,6 @@ async def _stream_chat_chunks(
:128
] # OpenRouter limit
# Enable adaptive thinking for Anthropic models via OpenRouter
if config.thinking_enabled and "anthropic" in model.lower():
extra_body["reasoning"] = {"enabled": True}
api_call_start = time_module.perf_counter()
stream = await client.chat.completions.create(
model=model,
messages=cast(list[ChatCompletionMessageParam], messages),
@@ -1080,11 +974,6 @@ async def _stream_chat_chunks(
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
extra_body=extra_body,
)
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
logger.info(
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
)
# Variables to accumulate tool calls
tool_calls: list[dict[str, Any]] = []
@@ -1095,13 +984,10 @@ async def _stream_chat_chunks(
# Track if we've started the text block
text_started = False
first_content_chunk = True
chunk_count = 0
# Process the stream
chunk: ChatCompletionChunk
async for chunk in stream:
chunk_count += 1
if chunk.usage:
yield StreamUsage(
promptTokens=chunk.usage.prompt_tokens,
@@ -1124,23 +1010,6 @@ async def _stream_chat_chunks(
if not text_started and text_block_id:
yield StreamTextStart(id=text_block_id)
text_started = True
# Log timing for first content chunk
if first_content_chunk:
first_content_chunk = False
ttfc = (
time_module.perf_counter() - api_call_start
) * 1000
logger.info(
f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
f"(since API call), n_chunks={chunk_count}",
extra={
"json_fields": {
**log_meta,
"time_to_first_chunk_ms": ttfc,
"n_chunks": chunk_count,
}
},
)
# Stream the text delta
text_response = StreamTextDelta(
id=text_block_id or "",
@@ -1197,21 +1066,7 @@ async def _stream_chat_chunks(
toolName=tool_calls[idx]["function"]["name"],
)
emitted_start_for_idx.add(idx)
stream_duration = time_module.perf_counter() - api_call_start
logger.info(
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
f"duration={stream_duration:.2f}s, "
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
extra={
"json_fields": {
**log_meta,
"stream_duration_ms": stream_duration * 1000,
"finish_reason": finish_reason,
"n_chunks": chunk_count,
"n_tool_calls": len(tool_calls),
}
},
)
logger.info(f"Stream complete. Finish reason: {finish_reason}")
# Yield all accumulated tool calls after the stream is complete
# This ensures all tool call arguments have been fully received
@@ -1231,16 +1086,11 @@ async def _stream_chat_chunks(
# Re-raise to trigger retry logic in the parent function
raise
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
logger.info(
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
f"session={session.session_id}, user={session.user_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
yield StreamFinish()
return
except Exception as e:
last_error = e
if _is_retryable_error(e) and retry_count < MAX_RETRIES:
retry_count += 1
# Calculate delay with exponential backoff
@@ -1256,12 +1106,26 @@ async def _stream_chat_chunks(
continue # Retry the stream
else:
# Non-retryable error or max retries exceeded
logger.error(
f"Error in stream (not retrying): {e!s}",
exc_info=True,
_log_api_error(
error=e,
session_id=session.session_id if session else None,
message_count=len(messages) if messages else None,
model=model,
retry_count=retry_count,
)
error_code = None
error_text = str(e)
error_details = _extract_api_error_details(e)
if error_details.get("response_body"):
body = error_details["response_body"]
if isinstance(body, dict):
err = body.get("error")
if isinstance(err, dict) and err.get("message"):
error_text = err["message"]
elif body.get("message"):
error_text = body["message"]
if _is_region_blocked_error(e):
error_code = "MODEL_NOT_AVAILABLE_REGION"
error_text = (
@@ -1278,9 +1142,12 @@ async def _stream_chat_chunks(
# If we exit the retry loop without returning, it means we exhausted retries
if last_error:
logger.error(
f"Max retries ({MAX_RETRIES}) exceeded. Last error: {last_error!s}",
exc_info=True,
_log_api_error(
error=last_error,
session_id=session.session_id if session else None,
message_count=len(messages) if messages else None,
model=model,
retry_count=MAX_RETRIES,
)
yield StreamError(errorText=f"Max retries exceeded: {last_error!s}")
yield StreamFinish()
@@ -1716,7 +1583,6 @@ async def _execute_long_running_tool_with_streaming(
task_id,
StreamError(errorText=str(e)),
)
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish())
await _update_pending_operation(
@@ -1833,10 +1699,6 @@ async def _generate_llm_continuation(
if session_id:
extra_body["session_id"] = session_id[:128]
# Enable adaptive thinking for Anthropic models via OpenRouter
if config.thinking_enabled and "anthropic" in config.model.lower():
extra_body["reasoning"] = {"enabled": True}
retry_count = 0
last_error: Exception | None = None
response = None
@@ -1857,6 +1719,7 @@ async def _generate_llm_continuation(
break # Success, exit retry loop
except Exception as e:
last_error = e
if _is_retryable_error(e) and retry_count < MAX_RETRIES:
retry_count += 1
delay = min(
@@ -1870,17 +1733,23 @@ async def _generate_llm_continuation(
await asyncio.sleep(delay)
continue
else:
# Non-retryable error - log and exit gracefully
logger.error(
f"Non-retryable error in LLM continuation: {e!s}",
exc_info=True,
# Non-retryable error - log details and exit gracefully
_log_api_error(
error=e,
session_id=session_id,
message_count=len(messages) if messages else None,
model=config.model,
retry_count=retry_count,
)
return
if last_error:
logger.error(
f"Max retries ({MAX_RETRIES}) exceeded for LLM continuation. "
f"Last error: {last_error!s}"
_log_api_error(
error=last_error,
session_id=session_id,
message_count=len(messages) if messages else None,
model=config.model,
retry_count=MAX_RETRIES,
)
return
@@ -1920,6 +1789,89 @@ async def _generate_llm_continuation(
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
def _log_api_error(
error: Exception,
session_id: str | None = None,
message_count: int | None = None,
model: str | None = None,
retry_count: int = 0,
) -> None:
"""Log detailed API error information for debugging."""
details = _extract_api_error_details(error)
details["session_id"] = session_id
details["message_count"] = message_count
details["model"] = model
details["retry_count"] = retry_count
if isinstance(error, RateLimitError):
logger.warning(f"Rate limit error: {details}")
elif isinstance(error, APIConnectionError):
logger.warning(f"API connection error: {details}")
elif isinstance(error, APIStatusError) and error.status_code >= 500:
logger.error(f"API server error (5xx): {details}")
else:
logger.error(f"API error: {details}")
def _extract_api_error_details(error: Exception) -> dict[str, Any]:
"""Extract detailed information from OpenAI/OpenRouter API errors."""
error_msg = str(error)
details: dict[str, Any] = {
"error_type": type(error).__name__,
"error_message": error_msg[:500] + "..." if len(error_msg) > 500 else error_msg,
}
if hasattr(error, "code"):
details["code"] = getattr(error, "code", None)
if hasattr(error, "param"):
details["param"] = getattr(error, "param", None)
if isinstance(error, APIStatusError):
details["status_code"] = error.status_code
details["request_id"] = getattr(error, "request_id", None)
if hasattr(error, "body") and error.body:
details["response_body"] = _sanitize_error_body(error.body)
if hasattr(error, "response") and error.response:
headers = error.response.headers
details["openrouter_provider"] = headers.get("x-openrouter-provider")
details["openrouter_model"] = headers.get("x-openrouter-model")
details["retry_after"] = headers.get("retry-after")
details["rate_limit_remaining"] = headers.get("x-ratelimit-remaining")
return details
def _sanitize_error_body(
body: Any, max_length: int = 2000
) -> dict[str, Any] | str | None:
"""Extract only safe fields from error response body to avoid logging sensitive data."""
if not isinstance(body, dict):
# Non-dict bodies (e.g., HTML error pages) - return truncated string
if body is not None:
body_str = str(body)
if len(body_str) > max_length:
return body_str[:max_length] + "...[truncated]"
return body_str
return None
safe_fields = ("message", "type", "code", "param", "error")
sanitized: dict[str, Any] = {}
for field in safe_fields:
if field in body:
value = body[field]
if field == "error" and isinstance(value, dict):
sanitized[field] = _sanitize_error_body(value, max_length)
elif isinstance(value, str) and len(value) > max_length:
sanitized[field] = value[:max_length] + "...[truncated]"
else:
sanitized[field] = value
return sanitized if sanitized else None
async def _generate_llm_continuation_with_streaming(
session_id: str,
user_id: str | None,
@@ -1967,10 +1919,6 @@ async def _generate_llm_continuation_with_streaming(
if session_id:
extra_body["session_id"] = session_id[:128]
# Enable adaptive thinking for Anthropic models via OpenRouter
if config.thinking_enabled and "anthropic" in config.model.lower():
extra_body["reasoning"] = {"enabled": True}
# Make streaming LLM call (no tools - just text response)
from typing import cast
@@ -1982,7 +1930,6 @@ async def _generate_llm_continuation_with_streaming(
# Publish start event
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
await stream_registry.publish_chunk(task_id, StreamStartStep())
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
# Stream the response
@@ -2006,7 +1953,6 @@ async def _generate_llm_continuation_with_streaming(
# Publish end events
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
await stream_registry.publish_chunk(task_id, StreamFinishStep())
if assistant_content:
# Reload session from DB to avoid race condition with user messages
@@ -2048,5 +1994,4 @@ async def _generate_llm_continuation_with_streaming(
task_id,
StreamError(errorText=f"Failed to generate response: {e}"),
)
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish())

View File

@@ -104,24 +104,6 @@ async def create_task(
Returns:
The created ActiveTask instance (metadata only)
"""
import time
start_time = time.perf_counter()
# Build log metadata for structured logging
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"session_id": session_id,
}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)
task = ActiveTask(
task_id=task_id,
session_id=session_id,
@@ -132,18 +114,10 @@ async def create_task(
)
# Store metadata in Redis
redis_start = time.perf_counter()
redis = await get_redis_async()
redis_time = (time.perf_counter() - redis_start) * 1000
logger.info(
f"[TIMING] get_redis_async took {redis_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
)
meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)
hset_start = time.perf_counter()
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
@@ -157,22 +131,12 @@ async def create_task(
"created_at": task.created_at.isoformat(),
},
)
hset_time = (time.perf_counter() - hset_start) * 1000
logger.info(
f"[TIMING] redis.hset took {hset_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
)
await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl)
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
logger.debug(f"Created task {task_id} for session {session_id}")
return task
@@ -192,60 +156,26 @@ async def publish_chunk(
Returns:
The Redis Stream message ID
"""
import time
start_time = time.perf_counter()
chunk_type = type(chunk).__name__
chunk_json = chunk.model_dump_json()
message_id = "0-0"
# Build log metadata
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"chunk_type": chunk_type,
}
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
# Write to Redis Stream for persistence and real-time delivery
xadd_start = time.perf_counter()
raw_id = await redis.xadd(
stream_key,
{"data": chunk_json},
maxlen=config.stream_max_length,
)
xadd_time = (time.perf_counter() - xadd_start) * 1000
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Set TTL on stream to match task metadata TTL
await redis.expire(stream_key, config.stream_ttl)
total_time = (time.perf_counter() - start_time) * 1000
# Only log timing for significant chunks or slow operations
if (
chunk_type
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
or total_time > 50
):
logger.info(
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"xadd_time_ms": xadd_time,
"message_id": message_id,
}
},
)
except Exception as e:
elapsed = (time.perf_counter() - start_time) * 1000
logger.error(
f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}",
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
f"Failed to publish chunk for task {task_id}: {e}",
exc_info=True,
)
@@ -270,61 +200,24 @@ async def subscribe_to_task(
An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access
"""
import time
start_time = time.perf_counter()
# Build log metadata
log_meta = {"component": "StreamRegistry", "task_id": task_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
)
redis_start = time.perf_counter()
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
hgetall_time = (time.perf_counter() - redis_start) * 1000
logger.info(
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
)
if not meta:
elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"reason": "task_not_found",
}
},
)
logger.debug(f"Task {task_id} not found in Redis")
return None
# Note: Redis client uses decode_responses=True, so keys are strings
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
log_meta["session_id"] = meta.get("session_id", "")
# Validate ownership - if task has an owner, requester must match
if task_user_id:
if user_id != task_user_id:
logger.warning(
f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}",
extra={
"json_fields": {
**log_meta,
"task_owner": task_user_id,
"reason": "access_denied",
}
},
f"User {user_id} denied access to task {task_id} "
f"owned by {task_user_id}"
)
return None
@@ -332,19 +225,7 @@ async def subscribe_to_task(
stream_key = _get_task_stream_key(task_id)
# Step 1: Replay messages from Redis Stream
xread_start = time.perf_counter()
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
xread_time = (time.perf_counter() - xread_start) * 1000
logger.info(
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
extra={
"json_fields": {
**log_meta,
"duration_ms": xread_time,
"task_status": task_status,
}
},
)
replayed_count = 0
replay_last_id = last_message_id
@@ -363,48 +244,19 @@ async def subscribe_to_task(
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
logger.info(
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
extra={
"json_fields": {
**log_meta,
"n_messages_replayed": replayed_count,
"replay_last_id": replay_last_id,
}
},
)
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
# Step 2: If task is still running, start stream listener for live updates
if task_status == "running":
logger.info(
"[TIMING] Task still running, starting _stream_listener",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
listener_task = asyncio.create_task(
_stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
_stream_listener(task_id, subscriber_queue, replay_last_id)
)
# Track listener task for cleanup on unsubscribe
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
else:
# Task is completed/failed - add finish marker
logger.info(
f"[TIMING] Task already {task_status}, adding StreamFinish",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
await subscriber_queue.put(StreamFinish())
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
f"n_messages_replayed={replayed_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"n_messages_replayed": replayed_count,
}
},
)
return subscriber_queue
@@ -412,7 +264,6 @@ async def _stream_listener(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str,
log_meta: dict | None = None,
) -> None:
"""Listen to Redis Stream for new messages using blocking XREAD.
@@ -423,27 +274,10 @@ async def _stream_listener(
task_id: Task ID to listen for
subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here)
log_meta: Structured logging metadata
"""
import time
start_time = time.perf_counter()
# Use provided log_meta or build minimal one
if log_meta is None:
log_meta = {"component": "StreamRegistry", "task_id": task_id}
logger.info(
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
)
queue_id = id(subscriber_queue)
# Track the last successfully delivered message ID for recovery hints
last_delivered_id = last_replayed_id
messages_delivered = 0
first_message_time = None
xread_count = 0
try:
redis = await get_redis_async()
@@ -453,39 +287,9 @@ async def _stream_listener(
while True:
# Block for up to 30 seconds waiting for new messages
# This allows periodic checking if task is still running
xread_start = time.perf_counter()
xread_count += 1
messages = await redis.xread(
{stream_key: current_id}, block=30000, count=100
)
xread_time = (time.perf_counter() - xread_start) * 1000
if messages:
msg_count = sum(len(msgs) for _, msgs in messages)
logger.info(
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"xread_count": xread_count,
"n_messages": msg_count,
"duration_ms": xread_time,
}
},
)
elif xread_time > 1000:
# Only log timeouts (30s blocking)
logger.info(
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"xread_count": xread_count,
"duration_ms": xread_time,
"reason": "timeout",
}
},
)
if not messages:
# Timeout - check if task is still running
@@ -522,30 +326,10 @@ async def _stream_listener(
)
# Update last delivered ID on successful delivery
last_delivered_id = current_id
messages_delivered += 1
if first_message_time is None:
first_message_time = time.perf_counter()
elapsed = (first_message_time - start_time) * 1000
logger.info(
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"chunk_type": type(chunk).__name__,
}
},
)
except asyncio.TimeoutError:
logger.warning(
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
extra={
"json_fields": {
**log_meta,
"timeout_s": QUEUE_PUT_TIMEOUT,
"reason": "queue_full",
}
},
f"Subscriber queue full for task {task_id}, "
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
)
# Send overflow error with recovery info
try:
@@ -567,44 +351,15 @@ async def _stream_listener(
# Stop listening on finish
if isinstance(chunk, StreamFinish):
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] StreamFinish received in {total_time / 1000:.1f}s; delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"messages_delivered": messages_delivered,
}
},
)
return
except Exception as e:
logger.warning(
f"Error processing stream message: {e}",
extra={"json_fields": {**log_meta, "error": str(e)}},
)
logger.warning(f"Error processing stream message: {e}")
except asyncio.CancelledError:
elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"messages_delivered": messages_delivered,
"reason": "cancelled",
}
},
)
logger.debug(f"Stream listener cancelled for task {task_id}")
raise # Re-raise to propagate cancellation
except Exception as e:
elapsed = (time.perf_counter() - start_time) * 1000
logger.error(
f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
)
logger.error(f"Stream listener error for task {task_id}: {e}")
# On error, send finish to unblock subscriber
try:
await asyncio.wait_for(
@@ -613,24 +368,10 @@ async def _stream_listener(
)
except (asyncio.TimeoutError, asyncio.QueueFull):
logger.warning(
"Could not deliver finish event after error",
extra={"json_fields": log_meta},
f"Could not deliver finish event for task {task_id} after error"
)
finally:
# Clean up listener task mapping on exit
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] _stream_listener FINISHED in {total_time / 1000:.1f}s; task={task_id}, "
f"delivered={messages_delivered}, xread_count={xread_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"messages_delivered": messages_delivered,
"xread_count": xread_count,
}
},
)
_listener_tasks.pop(queue_id, None)
@@ -857,10 +598,8 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
ResponseType,
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
@@ -874,8 +613,6 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
type_to_class: dict[str, type[StreamBaseResponse]] = {
ResponseType.START.value: StreamStart,
ResponseType.FINISH.value: StreamFinish,
ResponseType.START_STEP.value: StreamStartStep,
ResponseType.FINISH_STEP.value: StreamFinishStep,
ResponseType.TEXT_START.value: StreamTextStart,
ResponseType.TEXT_DELTA.value: StreamTextDelta,
ResponseType.TEXT_END.value: StreamTextEnd,

View File

@@ -13,32 +13,10 @@ from backend.api.features.chat.tools.models import (
NoResultsResponse,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.data.block import BlockType, get_block
from backend.data.block import get_block
logger = logging.getLogger(__name__)
_TARGET_RESULTS = 10
# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
# 40 is 2x current removed; speed of query 10 vs 40 is minimial
_OVERFETCH_PAGE_SIZE = 40
# Block types that only work within graphs and cannot run standalone in CoPilot.
COPILOT_EXCLUDED_BLOCK_TYPES = {
BlockType.INPUT, # Graph interface definition - data enters via chat, not graph inputs
BlockType.OUTPUT, # Graph interface definition - data exits via chat, not graph outputs
BlockType.WEBHOOK, # Wait for external events - would hang forever in CoPilot
BlockType.WEBHOOK_MANUAL, # Same as WEBHOOK
BlockType.NOTE, # Visual annotation only - no runtime behavior
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
}
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
COPILOT_EXCLUDED_BLOCK_IDS = {
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
"3b191d9f-356f-482d-8238-ba04b6d18381",
}
class FindBlockTool(BaseTool):
"""Tool for searching available blocks."""
@@ -110,7 +88,7 @@ class FindBlockTool(BaseTool):
query=query,
content_types=[ContentType.BLOCK],
page=1,
page_size=_OVERFETCH_PAGE_SIZE,
page_size=10,
)
if not results:
@@ -130,90 +108,60 @@ class FindBlockTool(BaseTool):
block = get_block(block_id)
# Skip disabled blocks
if not block or block.disabled:
continue
if block and not block.disabled:
# Get input/output schemas
input_schema = {}
output_schema = {}
try:
input_schema = block.input_schema.jsonschema()
except Exception:
pass
try:
output_schema = block.output_schema.jsonschema()
except Exception:
pass
# Skip blocks excluded from CoPilot (graph-only blocks)
if (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
continue
# Get categories from block instance
categories = []
if hasattr(block, "categories") and block.categories:
categories = [cat.value for cat in block.categories]
# Get input/output schemas
input_schema = {}
output_schema = {}
try:
input_schema = block.input_schema.jsonschema()
except Exception as e:
logger.debug(
"Failed to generate input schema for block %s: %s",
block_id,
e,
)
try:
output_schema = block.output_schema.jsonschema()
except Exception as e:
logger.debug(
"Failed to generate output schema for block %s: %s",
block_id,
e,
)
# Get categories from block instance
categories = []
if hasattr(block, "categories") and block.categories:
categories = [cat.value for cat in block.categories]
# Extract required inputs for easier use
required_inputs: list[BlockInputFieldInfo] = []
if input_schema:
properties = input_schema.get("properties", {})
required_fields = set(input_schema.get("required", []))
# Get credential field names to exclude from required inputs
credentials_fields = set(
block.input_schema.get_credentials_fields().keys()
)
for field_name, field_schema in properties.items():
# Skip credential fields - they're handled separately
if field_name in credentials_fields:
continue
required_inputs.append(
BlockInputFieldInfo(
name=field_name,
type=field_schema.get("type", "string"),
description=field_schema.get("description", ""),
required=field_name in required_fields,
default=field_schema.get("default"),
)
# Extract required inputs for easier use
required_inputs: list[BlockInputFieldInfo] = []
if input_schema:
properties = input_schema.get("properties", {})
required_fields = set(input_schema.get("required", []))
# Get credential field names to exclude from required inputs
credentials_fields = set(
block.input_schema.get_credentials_fields().keys()
)
blocks.append(
BlockInfoSummary(
id=block_id,
name=block.name,
description=block.description or "",
categories=categories,
input_schema=input_schema,
output_schema=output_schema,
required_inputs=required_inputs,
for field_name, field_schema in properties.items():
# Skip credential fields - they're handled separately
if field_name in credentials_fields:
continue
required_inputs.append(
BlockInputFieldInfo(
name=field_name,
type=field_schema.get("type", "string"),
description=field_schema.get("description", ""),
required=field_name in required_fields,
default=field_schema.get("default"),
)
)
blocks.append(
BlockInfoSummary(
id=block_id,
name=block.name,
description=block.description or "",
categories=categories,
input_schema=input_schema,
output_schema=output_schema,
required_inputs=required_inputs,
)
)
)
if len(blocks) >= _TARGET_RESULTS:
break
if blocks and len(blocks) < _TARGET_RESULTS:
logger.debug(
"find_block returned %d/%d results for query '%s' "
"(filtered %d excluded/disabled blocks)",
len(blocks),
_TARGET_RESULTS,
query,
len(results) - len(blocks),
)
if not blocks:
return NoResultsResponse(

View File

@@ -1,139 +0,0 @@
"""Tests for block filtering in FindBlockTool."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.api.features.chat.tools.find_block import (
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
FindBlockTool,
)
from backend.api.features.chat.tools.models import BlockListResponse
from backend.data.block import BlockType
from ._test_data import make_session
_TEST_USER_ID = "test-user-find-block"
def make_mock_block(
block_id: str, name: str, block_type: BlockType, disabled: bool = False
):
"""Create a mock block for testing."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.description = f"{name} description"
mock.block_type = block_type
mock.disabled = disabled
mock.input_schema = MagicMock()
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
mock.input_schema.get_credentials_fields.return_value = {}
mock.output_schema = MagicMock()
mock.output_schema.jsonschema.return_value = {}
mock.categories = []
return mock
class TestFindBlockFiltering:
"""Tests for block filtering in FindBlockTool."""
def test_excluded_block_types_contains_expected_types(self):
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
def test_excluded_block_ids_contains_smart_decision_maker(self):
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_type_filtered_from_results(self):
"""Verify blocks with excluded BlockTypes are filtered from search results."""
session = make_session(user_id=_TEST_USER_ID)
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
search_results = [
{"content_id": "input-block-id", "score": 0.9},
{"content_id": "standard-block-id", "score": 0.8},
]
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
standard_block = make_mock_block(
"standard-block-id", "HTTP Request", BlockType.STANDARD
)
def mock_get_block(block_id):
return {
"input-block-id": input_block,
"standard-block-id": standard_block,
}.get(block_id)
with patch(
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
new_callable=AsyncMock,
return_value=(search_results, 2),
):
with patch(
"backend.api.features.chat.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID, session=session, query="test"
)
# Should only return the standard block, not the INPUT block
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 1
assert response.blocks[0].id == "standard-block-id"
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_id_filtered_from_results(self):
"""Verify SmartDecisionMakerBlock is filtered from search results."""
session = make_session(user_id=_TEST_USER_ID)
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
search_results = [
{"content_id": smart_decision_id, "score": 0.9},
{"content_id": "normal-block-id", "score": 0.8},
]
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
smart_block = make_mock_block(
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
)
normal_block = make_mock_block(
"normal-block-id", "Normal Block", BlockType.STANDARD
)
def mock_get_block(block_id):
return {
smart_decision_id: smart_block,
"normal-block-id": normal_block,
}.get(block_id)
with patch(
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
new_callable=AsyncMock,
return_value=(search_results, 2),
):
with patch(
"backend.api.features.chat.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID, session=session, query="decision"
)
# Should only return normal block, not SmartDecisionMakerBlock
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 1
assert response.blocks[0].id == "normal-block-id"

View File

@@ -1,29 +0,0 @@
"""Shared helpers for chat tools."""
from typing import Any
def get_inputs_from_schema(
input_schema: dict[str, Any],
exclude_fields: set[str] | None = None,
) -> list[dict[str, Any]]:
"""Extract input field info from JSON schema."""
if not isinstance(input_schema, dict):
return []
exclude = exclude_fields or set()
properties = input_schema.get("properties", {})
required = set(input_schema.get("required", []))
return [
{
"name": name,
"title": schema.get("title", name),
"type": schema.get("type", "string"),
"description": schema.get("description", ""),
"required": name in required,
"default": schema.get("default"),
}
for name, schema in properties.items()
if name not in exclude
]

View File

@@ -24,7 +24,6 @@ from backend.util.timezone_utils import (
)
from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import (
AgentDetails,
AgentDetailsResponse,
@@ -262,7 +261,7 @@ class RunAgentTool(BaseTool):
),
requirements={
"credentials": requirements_creds_list,
"inputs": get_inputs_from_schema(graph.input_schema),
"inputs": self._get_inputs_list(graph.input_schema),
"execution_modes": self._get_execution_modes(graph),
},
),
@@ -370,6 +369,22 @@ class RunAgentTool(BaseTool):
session_id=session_id,
)
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
"""Extract inputs list from schema."""
inputs_list = []
if isinstance(input_schema, dict) and "properties" in input_schema:
for field_name, field_schema in input_schema["properties"].items():
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in input_schema.get("required", []),
}
)
return inputs_list
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
"""Get available execution modes for the graph."""
trigger_info = graph.trigger_setup_info
@@ -383,7 +398,7 @@ class RunAgentTool(BaseTool):
suffix: str,
) -> str:
"""Build a message describing available inputs for an agent."""
inputs_list = get_inputs_from_schema(graph.input_schema)
inputs_list = self._get_inputs_list(graph.input_schema)
required_names = [i["name"] for i in inputs_list if i["required"]]
optional_names = [i["name"] for i in inputs_list if not i["required"]]

View File

@@ -8,19 +8,14 @@ from typing import Any
from pydantic_core import PydanticUndefined
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.find_block import (
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
)
from backend.data.block import AnyBlockSchema, get_block
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import (
BlockOutputResponse,
ErrorResponse,
@@ -29,10 +24,7 @@ from .models import (
ToolResponseBase,
UserReadiness,
)
from .utils import (
build_missing_credentials_from_field_info,
match_credentials_to_requirements,
)
from .utils import build_missing_credentials_from_field_info
logger = logging.getLogger(__name__)
@@ -81,6 +73,91 @@ class RunBlockTool(BaseTool):
def requires_auth(self) -> bool:
return True
async def _check_block_credentials(
self,
user_id: str,
block: Any,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Args:
user_id: User ID
block: Block to check credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
input_data = input_data or {}
# Get credential field info from block's input schema
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return matched_credentials, missing_credentials
# Get user's available credentials
creds_manager = IntegrationCredentialsManager()
available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
# Get discriminator from input, falling back to schema default
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in effective_field_info.provider
and cred.type in effective_field_info.supported_types
),
None,
)
if matching_cred:
matched_credentials[field_name] = CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
else:
# Create a placeholder for the missing credential
provider = next(iter(effective_field_info.provider), "unknown")
cred_type = next(iter(effective_field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched_credentials, missing_credentials
async def _execute(
self,
user_id: str | None,
@@ -135,26 +212,12 @@ class RunBlockTool(BaseTool):
session_id=session_id,
)
# Check if block is excluded from CoPilot (graph-only blocks)
if (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
return ErrorResponse(
message=(
f"Block '{block.name}' cannot be run directly in CoPilot. "
"This block is designed for use within graphs only."
),
session_id=session_id,
)
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
creds_manager = IntegrationCredentialsManager()
(
matched_credentials,
missing_credentials,
) = await self._resolve_block_credentials(user_id, block, input_data)
matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block, input_data
)
if missing_credentials:
# Return setup requirements response with missing credentials
@@ -282,75 +345,29 @@ class RunBlockTool(BaseTool):
session_id=session_id,
)
async def _resolve_block_credentials(
self,
user_id: str,
block: AnyBlockSchema,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Resolve credentials for a block by matching user's available credentials.
Args:
user_id: User ID
block: Block to resolve credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple of (matched_credentials, missing_credentials) - matched credentials
are used for block execution, missing ones indicate setup requirements.
"""
input_data = input_data or {}
requirements = self._resolve_discriminated_credentials(block, input_data)
if not requirements:
return {}, []
return await match_credentials_to_requirements(user_id, requirements)
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
"""Extract non-credential inputs from block schema."""
inputs_list = []
schema = block.input_schema.jsonschema()
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))
# Get credential field names to exclude
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
def _resolve_discriminated_credentials(
self,
block: AnyBlockSchema,
input_data: dict[str, Any],
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}
for field_name, field_schema in properties.items():
# Skip credential fields
if field_name in credentials_fields:
continue
resolved: dict[str, CredentialsFieldInfo] = {}
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in required_fields,
}
)
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
# For host-scoped credentials, add the discriminator value
# (e.g., URL) so _credential_is_for_host can match it
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
resolved[field_name] = effective_field_info
return resolved
return inputs_list

View File

@@ -1,106 +0,0 @@
"""Tests for block execution guards in RunBlockTool."""
from unittest.mock import MagicMock, patch
import pytest
from backend.api.features.chat.tools.models import ErrorResponse
from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.data.block import BlockType
from ._test_data import make_session
_TEST_USER_ID = "test-user-run-block"
def make_mock_block(
block_id: str, name: str, block_type: BlockType, disabled: bool = False
):
"""Create a mock block for testing."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.block_type = block_type
mock.disabled = disabled
mock.input_schema = MagicMock()
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
mock.input_schema.get_credentials_fields_info.return_value = []
return mock
class TestRunBlockFiltering:
"""Tests for block execution guards in RunBlockTool."""
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_type_returns_error(self):
"""Attempting to execute a block with excluded BlockType returns error."""
session = make_session(user_id=_TEST_USER_ID)
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=input_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="input-block-id",
input_data={},
)
assert isinstance(response, ErrorResponse)
assert "cannot be run directly in CoPilot" in response.message
assert "designed for use within graphs only" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_id_returns_error(self):
"""Attempting to execute SmartDecisionMakerBlock returns error."""
session = make_session(user_id=_TEST_USER_ID)
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
smart_block = make_mock_block(
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=smart_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id=smart_decision_id,
input_data={},
)
assert isinstance(response, ErrorResponse)
assert "cannot be run directly in CoPilot" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_non_excluded_block_passes_guard(self):
"""Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
session = make_session(user_id=_TEST_USER_ID)
standard_block = make_mock_block(
"standard-id", "HTTP Request", BlockType.STANDARD
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=standard_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="standard-id",
input_data={},
)
# Should NOT be an ErrorResponse about CoPilot exclusion
# (may be other errors like missing credentials, but not the exclusion guard)
if isinstance(response, ErrorResponse):
assert "cannot be run directly in CoPilot" not in response.message

View File

@@ -8,7 +8,6 @@ from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db
from backend.data.graph import GraphModel
from backend.data.model import (
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
HostScopedCredentials,
@@ -224,99 +223,6 @@ async def get_or_create_library_agent(
return library_agents[0]
async def match_credentials_to_requirements(
user_id: str,
requirements: dict[str, CredentialsFieldInfo],
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Match user's credentials against a dictionary of credential requirements.
This is the core matching logic shared by both graph and block credential matching.
"""
matched: dict[str, CredentialsMetaInput] = {}
missing: list[CredentialsMetaInput] = []
if not requirements:
return matched, missing
available_creds = await get_user_credentials(user_id)
for field_name, field_info in requirements.items():
matching_cred = find_matching_credential(available_creds, field_info)
if matching_cred:
try:
matched[field_name] = create_credential_meta_from_match(matching_cred)
except Exception as e:
logger.error(
f"Failed to create CredentialsMetaInput for field '{field_name}': "
f"provider={matching_cred.provider}, type={matching_cred.type}, "
f"credential_id={matching_cred.id}",
exc_info=True,
)
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=f"{field_name} (validation failed: {e})",
)
)
else:
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched, missing
async def get_user_credentials(user_id: str) -> list[Credentials]:
"""Get all available credentials for a user."""
creds_manager = IntegrationCredentialsManager()
return await creds_manager.store.get_all_creds(user_id)
def find_matching_credential(
available_creds: list[Credentials],
field_info: CredentialsFieldInfo,
) -> Credentials | None:
"""Find a credential that matches the required provider, type, scopes, and host."""
for cred in available_creds:
if cred.provider not in field_info.provider:
continue
if cred.type not in field_info.supported_types:
continue
if cred.type == "oauth2" and not _credential_has_required_scopes(
cred, field_info
):
continue
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
continue
return cred
return None
def create_credential_meta_from_match(
matching_cred: Credentials,
) -> CredentialsMetaInput:
"""Create a CredentialsMetaInput from a matched credential."""
return CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
async def match_user_credentials_to_graph(
user_id: str,
graph: GraphModel,
@@ -425,6 +331,8 @@ def _credential_has_required_scopes(
# If no scopes are required, any credential matches
if not requirements.required_scopes:
return True
# Check that credential scopes are a superset of required scopes
return set(credential.scopes).issuperset(requirements.required_scopes)

View File

@@ -25,7 +25,7 @@ FIXED_NOW = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
@pytest_asyncio.fixture(loop_scope="session")
async def client(server, mock_jwt_user) -> AsyncGenerator[httpx.AsyncClient, None]:
"""Create async HTTP client with auth overrides"""
from backend.api.auth.jwt_utils import get_jwt_payload
from autogpt_libs.auth.jwt_utils import get_jwt_payload
# Override get_jwt_payload dependency to return our test user
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]

View File

@@ -2,10 +2,10 @@ import asyncio
import logging
from typing import Any, List
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, HTTPException, Query, Security, status
from prisma.enums import ReviewStatus
import backend.api.auth as autogpt_auth_lib
from backend.data.execution import (
ExecutionContext,
ExecutionStatus,

View File

@@ -3,6 +3,7 @@ import logging
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated, List, Literal
from autogpt_libs.auth import get_user_id
from fastapi import (
APIRouter,
Body,
@@ -16,7 +17,6 @@ from fastapi import (
from pydantic import BaseModel, Field, SecretStr
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
from backend.api.auth import get_user_id
from backend.api.features.library.db import set_preset_webhook, update_preset
from backend.api.features.library.model import LibraryAgentPreset
from backend.data.graph import NodeModel, get_graph, set_node_webhook

View File

@@ -1,10 +1,10 @@
from typing import Literal, Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, HTTPException, Query, Security, status
from fastapi.responses import Response
from prisma.enums import OnboardingStep
import backend.api.auth as autogpt_auth_lib
from backend.data.onboarding import complete_onboarding_step
from .. import db as library_db

View File

@@ -1,9 +1,9 @@
import logging
from typing import Any, Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, HTTPException, Query, Security, status
import backend.api.auth as autogpt_auth_lib
from backend.data.execution import GraphExecutionMeta
from backend.data.graph import get_graph
from backend.data.integrations import get_webhook

View File

@@ -23,7 +23,7 @@ FIXED_NOW = datetime.datetime(2023, 1, 1, 0, 0, 0)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from backend.api.auth.jwt_utils import get_jwt_payload
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield

View File

@@ -21,13 +21,13 @@ from datetime import datetime
from typing import Literal, Optional
from urllib.parse import urlencode
from autogpt_libs.auth import get_user_id
from fastapi import APIRouter, Body, HTTPException, Security, UploadFile, status
from gcloud.aio import storage as async_storage
from PIL import Image
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.api.auth import get_user_id
from backend.data.auth.oauth import (
InvalidClientError,
InvalidGrantError,

View File

@@ -21,6 +21,7 @@ from typing import AsyncGenerator
import httpx
import pytest
import pytest_asyncio
from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission
from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
from prisma.models import OAuthApplication as PrismaOAuthApplication
@@ -28,7 +29,6 @@ from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
from prisma.models import User as PrismaUser
from backend.api.auth.api_key.keysmith import APIKeySmith
from backend.api.rest_api import app
keysmith = APIKeySmith()
@@ -134,7 +134,7 @@ async def client(server, test_user: str) -> AsyncGenerator[httpx.AsyncClient, No
Depends on `server` to ensure the DB is connected and `test_user` to ensure
the user exists in the database before running tests.
"""
from backend.api.auth import get_user_id
from autogpt_libs.auth import get_user_id
# Override get_user_id dependency to return our test user
def override_get_user_id():

View File

@@ -1,9 +1,8 @@
import logging
from autogpt_libs.auth import get_user_id, requires_user
from fastapi import APIRouter, HTTPException, Security
from backend.api.auth import get_user_id, requires_user
from .models import ApiResponse, ChatRequest
from .service import OttoService

View File

@@ -19,7 +19,7 @@ client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from backend.api.auth.jwt_utils import get_jwt_payload
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield

View File

@@ -57,7 +57,7 @@ async def postmark_webhook_handler(
webhook: Annotated[
PostmarkWebhook,
Body(discriminator="RecordType"),
],
]
):
logger.info(f"Received webhook from Postmark: {webhook}")
match webhook:

View File

@@ -164,7 +164,7 @@ class BlockHandler(ContentHandler):
block_ids = list(all_blocks.keys())
# Query for existing embeddings
placeholders = ",".join([f"${i + 1}" for i in range(len(block_ids))])
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
existing_result = await query_raw_with_schema(
f"""
SELECT "contentId"
@@ -265,7 +265,7 @@ class BlockHandler(ContentHandler):
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
block_ids = enabled_block_ids
placeholders = ",".join([f"${i + 1}" for i in range(len(block_ids))])
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
embedded_result = await query_raw_with_schema(
f"""
@@ -508,7 +508,7 @@ class DocumentationHandler(ContentHandler):
]
# Check which ones have embeddings
placeholders = ",".join([f"${i + 1}" for i in range(len(section_content_ids))])
placeholders = ",".join([f"${i+1}" for i in range(len(section_content_ids))])
existing_result = await query_raw_with_schema(
f"""
SELECT "contentId"

View File

@@ -8,7 +8,6 @@ Includes BM25 reranking for improved lexical relevance.
import logging
import re
import time
from dataclasses import dataclass
from typing import Any, Literal
@@ -363,11 +362,7 @@ async def unified_hybrid_search(
LIMIT {limit_param} OFFSET {offset_param}
"""
try:
results = await query_raw_with_schema(sql_query, *params)
except Exception as e:
await _log_vector_error_diagnostics(e)
raise
results = await query_raw_with_schema(sql_query, *params)
total = results[0]["total_count"] if results else 0
# Apply BM25 reranking
@@ -691,11 +686,7 @@ async def hybrid_search(
LIMIT {limit_param} OFFSET {offset_param}
"""
try:
results = await query_raw_with_schema(sql_query, *params)
except Exception as e:
await _log_vector_error_diagnostics(e)
raise
results = await query_raw_with_schema(sql_query, *params)
total = results[0]["total_count"] if results else 0
@@ -727,87 +718,6 @@ async def hybrid_search_simple(
return await hybrid_search(query=query, page=page, page_size=page_size)
# ============================================================================
# Diagnostics
# ============================================================================
# Rate limit: only log vector error diagnostics once per this interval
_VECTOR_DIAG_INTERVAL_SECONDS = 60
_last_vector_diag_time: float = 0
async def _log_vector_error_diagnostics(error: Exception) -> None:
"""Log diagnostic info when 'type vector does not exist' error occurs.
Note: Diagnostic queries use query_raw_with_schema which may run on a different
pooled connection than the one that failed. Session-level search_path can differ,
so these diagnostics show cluster-wide state, not necessarily the failed session.
Includes rate limiting to avoid log spam - only logs once per minute.
Caller should re-raise the error after calling this function.
"""
global _last_vector_diag_time
# Check if this is the vector type error
error_str = str(error).lower()
if not (
"type" in error_str and "vector" in error_str and "does not exist" in error_str
):
return
# Rate limit: only log once per interval
now = time.time()
if now - _last_vector_diag_time < _VECTOR_DIAG_INTERVAL_SECONDS:
return
_last_vector_diag_time = now
try:
diagnostics: dict[str, object] = {}
try:
search_path_result = await query_raw_with_schema("SHOW search_path")
diagnostics["search_path"] = search_path_result
except Exception as e:
diagnostics["search_path"] = f"Error: {e}"
try:
schema_result = await query_raw_with_schema("SELECT current_schema()")
diagnostics["current_schema"] = schema_result
except Exception as e:
diagnostics["current_schema"] = f"Error: {e}"
try:
user_result = await query_raw_with_schema(
"SELECT current_user, session_user, current_database()"
)
diagnostics["user_info"] = user_result
except Exception as e:
diagnostics["user_info"] = f"Error: {e}"
try:
# Check pgvector extension installation (cluster-wide, stable info)
ext_result = await query_raw_with_schema(
"SELECT extname, extversion, nspname as schema "
"FROM pg_extension e "
"JOIN pg_namespace n ON e.extnamespace = n.oid "
"WHERE extname = 'vector'"
)
diagnostics["pgvector_extension"] = ext_result
except Exception as e:
diagnostics["pgvector_extension"] = f"Error: {e}"
logger.error(
f"Vector type error diagnostics:\n"
f" Error: {error}\n"
f" search_path: {diagnostics.get('search_path')}\n"
f" current_schema: {diagnostics.get('current_schema')}\n"
f" user_info: {diagnostics.get('user_info')}\n"
f" pgvector_extension: {diagnostics.get('pgvector_extension')}"
)
except Exception as diag_error:
logger.error(f"Failed to collect vector error diagnostics: {diag_error}")
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
# for existing code that expects the popularity parameter
HybridSearchWeights = StoreAgentSearchWeights

View File

@@ -47,7 +47,7 @@ def mock_storage_client(mocker):
async def test_upload_media_success(mock_settings, mock_storage_client):
# Create test JPEG data with valid signature
test_data = b"\xff\xd8\xff" + b"test data"
test_data = b"\xFF\xD8\xFF" + b"test data"
test_file = fastapi.UploadFile(
filename="laptop.jpeg",
@@ -85,7 +85,7 @@ async def test_upload_media_missing_credentials(monkeypatch):
test_file = fastapi.UploadFile(
filename="laptop.jpeg",
file=io.BytesIO(b"\xff\xd8\xff" + b"test data"), # Valid JPEG signature
file=io.BytesIO(b"\xFF\xD8\xFF" + b"test data"), # Valid JPEG signature
headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
)
@@ -110,7 +110,7 @@ async def test_upload_media_video_type(mock_settings, mock_storage_client):
async def test_upload_media_file_too_large(mock_settings, mock_storage_client):
large_data = b"\xff\xd8\xff" + b"x" * (
large_data = b"\xFF\xD8\xFF" + b"x" * (
50 * 1024 * 1024 + 1
) # 50MB + 1 byte with valid JPEG signature
test_file = fastapi.UploadFile(

View File

@@ -4,11 +4,11 @@ import typing
import urllib.parse
from typing import Literal
import autogpt_libs.auth
import fastapi
import fastapi.responses
import prisma.enums
import backend.api.auth
import backend.data.graph
import backend.util.json
from backend.util.models import Pagination
@@ -34,11 +34,11 @@ router = fastapi.APIRouter()
"/profile",
summary="Get user profile",
tags=["store", "private"],
dependencies=[fastapi.Security(backend.api.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.ProfileDetails,
)
async def get_profile(
user_id: str = fastapi.Security(backend.api.auth.get_user_id),
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Get the profile details for the authenticated user.
@@ -57,12 +57,12 @@ async def get_profile(
"/profile",
summary="Update user profile",
tags=["store", "private"],
dependencies=[fastapi.Security(backend.api.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.CreatorDetails,
)
async def update_or_create_profile(
profile: store_model.Profile,
user_id: str = fastapi.Security(backend.api.auth.get_user_id),
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Update the store profile for the authenticated user.
@@ -169,7 +169,7 @@ async def unified_search(
page: int = 1,
page_size: int = 20,
user_id: str | None = fastapi.Security(
backend.api.auth.get_optional_user_id, use_cache=False
autogpt_libs.auth.get_optional_user_id, use_cache=False
),
):
"""
@@ -274,7 +274,7 @@ async def get_agent(
"/graph/{store_listing_version_id}",
summary="Get agent graph",
tags=["store"],
dependencies=[fastapi.Security(backend.api.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
async def get_graph_meta_by_store_listing_version_id(
store_listing_version_id: str,
@@ -290,7 +290,7 @@ async def get_graph_meta_by_store_listing_version_id(
"/agents/{store_listing_version_id}",
summary="Get agent by version",
tags=["store"],
dependencies=[fastapi.Security(backend.api.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreAgentDetails,
)
async def get_store_agent(store_listing_version_id: str):
@@ -306,14 +306,14 @@ async def get_store_agent(store_listing_version_id: str):
"/agents/{username}/{agent_name}/review",
summary="Create agent review",
tags=["store"],
dependencies=[fastapi.Security(backend.api.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreReview,
)
async def create_review(
username: str,
agent_name: str,
review: store_model.StoreReviewCreate,
user_id: str = fastapi.Security(backend.api.auth.get_user_id),
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Create a review for a store agent.
@@ -417,11 +417,11 @@ async def get_creator(
"/myagents",
summary="Get my agents",
tags=["store", "private"],
dependencies=[fastapi.Security(backend.api.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.MyAgentsResponse,
)
async def get_my_agents(
user_id: str = fastapi.Security(backend.api.auth.get_user_id),
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
):
@@ -436,12 +436,12 @@ async def get_my_agents(
"/submissions/{submission_id}",
summary="Delete store submission",
tags=["store", "private"],
dependencies=[fastapi.Security(backend.api.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=bool,
)
async def delete_submission(
submission_id: str,
user_id: str = fastapi.Security(backend.api.auth.get_user_id),
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Delete a store listing submission.
@@ -465,11 +465,11 @@ async def delete_submission(
"/submissions",
summary="List my submissions",
tags=["store", "private"],
dependencies=[fastapi.Security(backend.api.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmissionsResponse,
)
async def get_submissions(
user_id: str = fastapi.Security(backend.api.auth.get_user_id),
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
page: int = 1,
page_size: int = 20,
):
@@ -508,12 +508,12 @@ async def get_submissions(
"/submissions",
summary="Create store submission",
tags=["store", "private"],
dependencies=[fastapi.Security(backend.api.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmission,
)
async def create_submission(
submission_request: store_model.StoreSubmissionRequest,
user_id: str = fastapi.Security(backend.api.auth.get_user_id),
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Create a new store listing submission.
@@ -552,13 +552,13 @@ async def create_submission(
"/submissions/{store_listing_version_id}",
summary="Edit store submission",
tags=["store", "private"],
dependencies=[fastapi.Security(backend.api.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmission,
)
async def edit_submission(
store_listing_version_id: str,
submission_request: store_model.StoreSubmissionEditRequest,
user_id: str = fastapi.Security(backend.api.auth.get_user_id),
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Edit an existing store listing submission.
@@ -596,11 +596,11 @@ async def edit_submission(
"/submissions/media",
summary="Upload submission media",
tags=["store", "private"],
dependencies=[fastapi.Security(backend.api.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
async def upload_submission_media(
file: fastapi.UploadFile,
user_id: str = fastapi.Security(backend.api.auth.get_user_id),
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Upload media (images/videos) for a store listing submission.
@@ -623,11 +623,11 @@ async def upload_submission_media(
"/submissions/generate_image",
summary="Generate submission image",
tags=["store", "private"],
dependencies=[fastapi.Security(backend.api.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
async def generate_image(
agent_id: str,
user_id: str = fastapi.Security(backend.api.auth.get_user_id),
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
) -> fastapi.responses.Response:
"""
Generate an image for a store listing submission.

View File

@@ -24,7 +24,7 @@ client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from backend.api.auth.jwt_utils import get_jwt_payload
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield

View File

@@ -9,6 +9,8 @@ from typing import Annotated, Any, Sequence, get_args
import pydantic
import stripe
from autogpt_libs.auth import get_user_id, requires_user
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from fastapi import (
APIRouter,
Body,
@@ -26,8 +28,6 @@ from pydantic import BaseModel
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
from backend.api.auth import get_user_id, requires_user
from backend.api.auth.jwt_utils import get_jwt_payload
from backend.api.model import (
CreateAPIKeyRequest,
CreateAPIKeyResponse,

View File

@@ -25,7 +25,7 @@ client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user, setup_test_user):
"""Setup auth overrides for all tests in this module"""
from backend.api.auth.jwt_utils import get_jwt_payload
from autogpt_libs.auth.jwt_utils import get_jwt_payload
# setup_test_user fixture already executed and user is created in database
# It returns the user_id which we don't need to await
@@ -499,12 +499,10 @@ async def test_upload_file_success(test_user_id: str):
)
# Mock dependencies
with (
patch("backend.api.features.v1.scan_content_safe") as mock_scan,
patch(
"backend.api.features.v1.get_cloud_storage_handler"
) as mock_handler_getter,
):
with patch("backend.api.features.v1.scan_content_safe") as mock_scan, patch(
"backend.api.features.v1.get_cloud_storage_handler"
) as mock_handler_getter:
mock_scan.return_value = None
mock_handler = AsyncMock()
mock_handler.store_file.return_value = "gcs://test-bucket/uploads/123/test.txt"
@@ -553,12 +551,10 @@ async def test_upload_file_no_filename(test_user_id: str):
),
)
with (
patch("backend.api.features.v1.scan_content_safe") as mock_scan,
patch(
"backend.api.features.v1.get_cloud_storage_handler"
) as mock_handler_getter,
):
with patch("backend.api.features.v1.scan_content_safe") as mock_scan, patch(
"backend.api.features.v1.get_cloud_storage_handler"
) as mock_handler_getter:
mock_scan.return_value = None
mock_handler = AsyncMock()
mock_handler.store_file.return_value = (
@@ -636,12 +632,10 @@ async def test_upload_file_cloud_storage_failure(test_user_id: str):
headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
)
with (
patch("backend.api.features.v1.scan_content_safe") as mock_scan,
patch(
"backend.api.features.v1.get_cloud_storage_handler"
) as mock_handler_getter,
):
with patch("backend.api.features.v1.scan_content_safe") as mock_scan, patch(
"backend.api.features.v1.get_cloud_storage_handler"
) as mock_handler_getter:
mock_scan.return_value = None
mock_handler = AsyncMock()
mock_handler.store_file.side_effect = RuntimeError("Storage error!")
@@ -685,12 +679,10 @@ async def test_upload_file_gcs_not_configured_fallback(test_user_id: str):
headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
)
with (
patch("backend.api.features.v1.scan_content_safe") as mock_scan,
patch(
"backend.api.features.v1.get_cloud_storage_handler"
) as mock_handler_getter,
):
with patch("backend.api.features.v1.scan_content_safe") as mock_scan, patch(
"backend.api.features.v1.get_cloud_storage_handler"
) as mock_handler_getter:
mock_scan.return_value = None
mock_handler = AsyncMock()
mock_handler.config.gcs_bucket_name = "" # Simulate no GCS bucket configured

View File

@@ -8,9 +8,9 @@ from typing import Annotated
from urllib.parse import quote
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi.responses import Response
from backend.api.auth.dependencies import get_user_id, requires_user
from backend.data.workspace import get_workspace, get_workspace_file
from backend.util.workspace_storage import get_workspace_storage

View File

@@ -9,6 +9,8 @@ import fastapi.responses
import pydantic
import starlette.middleware.cors
import uvicorn
from autogpt_libs.auth import add_auth_responses_to_openapi
from autogpt_libs.auth import verify_settings as verify_auth_settings
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.routing import APIRoute
@@ -38,8 +40,6 @@ import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.api.auth import add_auth_responses_to_openapi
from backend.api.auth import verify_settings as verify_auth_settings
from backend.api.features.chat.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
@@ -69,7 +69,7 @@ from .utils.openapi import sort_openapi
settings = backend.util.settings.Settings()
logger = logging.getLogger(__name__)
logging.getLogger("backend.api.auth").setLevel(logging.INFO)
logging.getLogger("autogpt_libs").setLevel(logging.INFO)
@contextlib.contextmanager

View File

@@ -457,8 +457,7 @@ async def test_api_key_with_unicode_characters_normalization_attack(mock_request
"""Test that Unicode normalization doesn't bypass validation."""
# Create auth with composed Unicode character
auth = APIKeyAuthenticator(
header_name="X-API-Key",
expected_token="café", # é is composed
header_name="X-API-Key", expected_token="café" # é is composed
)
# Try with decomposed version (c + a + f + e + ´)
@@ -523,8 +522,8 @@ async def test_api_keys_with_newline_variations(mock_request):
"valid\r\ntoken", # Windows newline
"valid\rtoken", # Mac newline
"valid\x85token", # NEL (Next Line)
"valid\x0btoken", # Vertical Tab
"valid\x0ctoken", # Form Feed
"valid\x0Btoken", # Vertical Tab
"valid\x0Ctoken", # Form Feed
]
for api_key in newline_variations:

View File

@@ -5,10 +5,10 @@ from typing import Protocol
import pydantic
import uvicorn
from autogpt_libs.auth.jwt_utils import parse_jwt_token
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from starlette.middleware.cors import CORSMiddleware
from backend.api.auth.jwt_utils import parse_jwt_token
from backend.api.conn_manager import ConnectionManager
from backend.api.model import (
WSMessage,

View File

@@ -44,12 +44,9 @@ def test_websocket_server_uses_cors_helper(mocker) -> None:
"backend.api.ws_api.build_cors_params", return_value=cors_params
)
with (
override_config(
settings, "backend_cors_allow_origins", cors_params["allow_origins"]
),
override_config(settings, "app_env", AppEnvironment.LOCAL),
):
with override_config(
settings, "backend_cors_allow_origins", cors_params["allow_origins"]
), override_config(settings, "app_env", AppEnvironment.LOCAL):
WebsocketServer().run()
build_cors.assert_called_once_with(
@@ -68,12 +65,9 @@ def test_websocket_server_uses_cors_helper(mocker) -> None:
def test_websocket_server_blocks_localhost_in_production(mocker) -> None:
mocker.patch("backend.api.ws_api.uvicorn.run")
with (
override_config(
settings, "backend_cors_allow_origins", ["http://localhost:3000"]
),
override_config(settings, "app_env", AppEnvironment.PRODUCTION),
):
with override_config(
settings, "backend_cors_allow_origins", ["http://localhost:3000"]
), override_config(settings, "app_env", AppEnvironment.PRODUCTION):
with pytest.raises(ValueError):
WebsocketServer().run()

View File

@@ -174,9 +174,7 @@ class AIImageGeneratorBlock(Block):
],
test_mock={
# Return a data URI directly so store_media_file doesn't need to download
"_run_client": lambda *args, **kwargs: (
"data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
)
"_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
},
)

View File

@@ -142,9 +142,7 @@ class AIMusicGeneratorBlock(Block):
),
],
test_mock={
"run_model": lambda api_key, music_gen_model_version, prompt, duration, temperature, top_k, top_p, classifier_free_guidance, output_format, normalization_strategy: (
"https://replicate.com/output/generated-audio-url.wav"
),
"run_model": lambda api_key, music_gen_model_version, prompt, duration, temperature, top_k, top_p, classifier_free_guidance, output_format, normalization_strategy: "https://replicate.com/output/generated-audio-url.wav",
},
test_credentials=TEST_CREDENTIALS,
)

View File

@@ -69,18 +69,12 @@ class PostToBlueskyBlock(Block):
client = create_ayrshare_client()
if not client:
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate character limit for Bluesky
if len(input_data.post) > 300:
yield (
"error",
f"Post text exceeds Bluesky's 300 character limit ({len(input_data.post)} characters)",
)
yield "error", f"Post text exceeds Bluesky's 300 character limit ({len(input_data.post)} characters)"
return
# Validate media constraints for Bluesky

View File

@@ -131,10 +131,7 @@ class PostToFacebookBlock(Block):
client = create_ayrshare_client()
if not client:
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Convert datetime to ISO format if provided

View File

@@ -120,18 +120,12 @@ class PostToGMBBlock(Block):
client = create_ayrshare_client()
if not client:
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate GMB constraints
if len(input_data.media_urls) > 1:
yield (
"error",
"Google My Business supports only one image or video per post",
)
yield "error", "Google My Business supports only one image or video per post"
return
# Validate offer coupon code length

View File

@@ -123,25 +123,16 @@ class PostToInstagramBlock(Block):
client = create_ayrshare_client()
if not client:
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate Instagram constraints
if len(input_data.post) > 2200:
yield (
"error",
f"Instagram post text exceeds 2,200 character limit ({len(input_data.post)} characters)",
)
yield "error", f"Instagram post text exceeds 2,200 character limit ({len(input_data.post)} characters)"
return
if len(input_data.media_urls) > 10:
yield (
"error",
"Instagram supports a maximum of 10 images/videos in a carousel",
)
yield "error", "Instagram supports a maximum of 10 images/videos in a carousel"
return
if len(input_data.collaborators) > 3:
@@ -156,10 +147,7 @@ class PostToInstagramBlock(Block):
]
if any(reel_options) and not all(reel_options):
yield (
"error",
"When posting a reel, all reel options must be set: share_reels_feed, audio_name, and either thumbnail or thumbnail_offset",
)
yield "error", "When posting a reel, all reel options must be set: share_reels_feed, audio_name, and either thumbnail or thumbnail_offset"
return
# Count hashtags and mentions
@@ -167,17 +155,11 @@ class PostToInstagramBlock(Block):
mention_count = input_data.post.count("@")
if hashtag_count > 30:
yield (
"error",
f"Instagram allows maximum 30 hashtags ({hashtag_count} found)",
)
yield "error", f"Instagram allows maximum 30 hashtags ({hashtag_count} found)"
return
if mention_count > 3:
yield (
"error",
f"Instagram allows maximum 3 @mentions ({mention_count} found)",
)
yield "error", f"Instagram allows maximum 3 @mentions ({mention_count} found)"
return
# Convert datetime to ISO format if provided
@@ -209,10 +191,7 @@ class PostToInstagramBlock(Block):
# Validate alt text length
for i, alt in enumerate(input_data.alt_text):
if len(alt) > 1000:
yield (
"error",
f"Alt text {i + 1} exceeds 1,000 character limit ({len(alt)} characters)",
)
yield "error", f"Alt text {i+1} exceeds 1,000 character limit ({len(alt)} characters)"
return
instagram_options["altText"] = input_data.alt_text
@@ -227,19 +206,13 @@ class PostToInstagramBlock(Block):
try:
tag_obj = InstagramUserTag(**tag)
except Exception as e:
yield (
"error",
f"Invalid user tag: {e}, tages need to be a dictionary with a 3 items: username (str), x (float) and y (float)",
)
yield "error", f"Invalid user tag: {e}, tages need to be a dictionary with a 3 items: username (str), x (float) and y (float)"
return
tag_dict: dict[str, float | str] = {"username": tag_obj.username}
if tag_obj.x is not None and tag_obj.y is not None:
# Validate coordinates
if not (0.0 <= tag_obj.x <= 1.0) or not (0.0 <= tag_obj.y <= 1.0):
yield (
"error",
f"User tag coordinates must be between 0.0 and 1.0 (user: {tag_obj.username})",
)
yield "error", f"User tag coordinates must be between 0.0 and 1.0 (user: {tag_obj.username})"
return
tag_dict["x"] = tag_obj.x
tag_dict["y"] = tag_obj.y

View File

@@ -123,18 +123,12 @@ class PostToLinkedInBlock(Block):
client = create_ayrshare_client()
if not client:
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate LinkedIn constraints
if len(input_data.post) > 3000:
yield (
"error",
f"LinkedIn post text exceeds 3,000 character limit ({len(input_data.post)} characters)",
)
yield "error", f"LinkedIn post text exceeds 3,000 character limit ({len(input_data.post)} characters)"
return
if len(input_data.media_urls) > 9:
@@ -142,19 +136,13 @@ class PostToLinkedInBlock(Block):
return
if input_data.document_title and len(input_data.document_title) > 400:
yield (
"error",
f"LinkedIn document title exceeds 400 character limit ({len(input_data.document_title)} characters)",
)
yield "error", f"LinkedIn document title exceeds 400 character limit ({len(input_data.document_title)} characters)"
return
# Validate visibility option
valid_visibility = ["public", "connections", "loggedin"]
if input_data.visibility not in valid_visibility:
yield (
"error",
f"LinkedIn visibility must be one of: {', '.join(valid_visibility)}",
)
yield "error", f"LinkedIn visibility must be one of: {', '.join(valid_visibility)}"
return
# Check for document extensions

View File

@@ -103,32 +103,20 @@ class PostToPinterestBlock(Block):
client = create_ayrshare_client()
if not client:
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate Pinterest constraints
if len(input_data.post) > 500:
yield (
"error",
f"Pinterest pin description exceeds 500 character limit ({len(input_data.post)} characters)",
)
yield "error", f"Pinterest pin description exceeds 500 character limit ({len(input_data.post)} characters)"
return
if len(input_data.pin_title) > 100:
yield (
"error",
f"Pinterest pin title exceeds 100 character limit ({len(input_data.pin_title)} characters)",
)
yield "error", f"Pinterest pin title exceeds 100 character limit ({len(input_data.pin_title)} characters)"
return
if len(input_data.link) > 2048:
yield (
"error",
f"Pinterest link URL exceeds 2048 character limit ({len(input_data.link)} characters)",
)
yield "error", f"Pinterest link URL exceeds 2048 character limit ({len(input_data.link)} characters)"
return
if len(input_data.media_urls) == 0:
@@ -153,10 +141,7 @@ class PostToPinterestBlock(Block):
# Validate alt text length
for i, alt in enumerate(input_data.alt_text):
if len(alt) > 500:
yield (
"error",
f"Pinterest alt text {i + 1} exceeds 500 character limit ({len(alt)} characters)",
)
yield "error", f"Pinterest alt text {i+1} exceeds 500 character limit ({len(alt)} characters)"
return
# Convert datetime to ISO format if provided

Some files were not shown because too many files have changed in this diff Show More