mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-11 16:18:07 -05:00
Compare commits
60 Commits
fix/chat-f
...
hackathon-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5aaf07fbaf | ||
|
|
0d2996e501 | ||
|
|
843c487500 | ||
|
|
47a3a5ef41 | ||
|
|
ec00aa951a | ||
|
|
9e37a66bca | ||
|
|
429a074848 | ||
|
|
36fb1ea004 | ||
|
|
a81ac150da | ||
|
|
49ee087496 | ||
|
|
fc25e008b3 | ||
|
|
b0855e8cf2 | ||
|
|
5e2146dd76 | ||
|
|
103a62c9da | ||
|
|
fc8434fb30 | ||
|
|
7f1245dc42 | ||
|
|
3ae08cd48e | ||
|
|
4db13837b9 | ||
|
|
df87867625 | ||
|
|
e503126170 | ||
|
|
7ee28197a3 | ||
|
|
818de26d24 | ||
|
|
cb08def96c | ||
|
|
ac2daee5f8 | ||
|
|
266e0d79d4 | ||
|
|
01f443190e | ||
|
|
bdba0033de | ||
|
|
b87c64ce38 | ||
|
|
003affca43 | ||
|
|
290d0d9a9b | ||
|
|
fba61c72ed | ||
|
|
79d45a15d0 | ||
|
|
66f0d97ca2 | ||
|
|
5894a8fcdf | ||
|
|
dff8efa35d | ||
|
|
e26822998f | ||
|
|
88731b1f76 | ||
|
|
c3e407ef09 | ||
|
|
08a60dcb9b | ||
|
|
de78d062a9 | ||
|
|
217e3718d7 | ||
|
|
3dbc03e488 | ||
|
|
b76b5a37c5 | ||
|
|
eed07b173a | ||
|
|
c5e8b0b08f | ||
|
|
cd3e35df9e | ||
|
|
4a7bc006a8 | ||
|
|
4c474417bc | ||
|
|
99e2261254 | ||
|
|
cab498fa8c | ||
|
|
22078671df | ||
|
|
0082a72657 | ||
|
|
9a1d940677 | ||
|
|
e640d36265 | ||
|
|
cc9179178f | ||
|
|
e8d37ab116 | ||
|
|
7f7ef6a271 | ||
|
|
aefac541d9 | ||
|
|
ff5c8f324b | ||
|
|
71157bddd7 |
@@ -16,6 +16,7 @@
|
||||
!autogpt_platform/backend/poetry.lock
|
||||
!autogpt_platform/backend/README.md
|
||||
!autogpt_platform/backend/.env
|
||||
!autogpt_platform/backend/gen_prisma_types_stub.py
|
||||
|
||||
# Platform - Market
|
||||
!autogpt_platform/market/market/
|
||||
|
||||
2
.github/workflows/claude-dependabot.yml
vendored
2
.github/workflows/claude-dependabot.yml
vendored
@@ -74,7 +74,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
|
||||
2
.github/workflows/claude.yml
vendored
2
.github/workflows/claude.yml
vendored
@@ -90,7 +90,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
|
||||
12
.github/workflows/copilot-setup-steps.yml
vendored
12
.github/workflows/copilot-setup-steps.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
@@ -108,6 +108,16 @@ jobs:
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
# Remove large unused tools to free disk space for Docker builds
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
df -h
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -134,7 +134,7 @@ jobs:
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
run: poetry run prisma generate
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
- id: supabase
|
||||
name: Start Supabase
|
||||
|
||||
@@ -11,7 +11,7 @@ jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
# operations-per-run: 5000
|
||||
stale-issue-message: >
|
||||
|
||||
2
.github/workflows/repo-pr-label.yml
vendored
2
.github/workflows/repo-pr-label.yml
vendored
@@ -61,6 +61,6 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v5
|
||||
- uses: actions/labeler@v6
|
||||
with:
|
||||
sync-labels: true
|
||||
|
||||
@@ -12,6 +12,7 @@ reset-db:
|
||||
rm -rf db/docker/volumes/db/data
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
cd backend && poetry run gen-prisma-stub
|
||||
|
||||
# View logs for core services
|
||||
logs-core:
|
||||
@@ -33,6 +34,7 @@ init-env:
|
||||
migrate:
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
cd backend && poetry run gen-prisma-stub
|
||||
|
||||
run-backend:
|
||||
cd backend && poetry run app
|
||||
|
||||
@@ -57,6 +57,9 @@ class APIKeySmith:
|
||||
|
||||
def hash_key(self, raw_key: str) -> tuple[str, str]:
|
||||
"""Migrate a legacy hash to secure hash format."""
|
||||
if not raw_key.startswith(self.PREFIX):
|
||||
raise ValueError("Key without 'agpt_' prefix would fail validation")
|
||||
|
||||
salt = self._generate_salt()
|
||||
hash = self._hash_key_with_salt(raw_key, salt)
|
||||
return hash, salt.hex()
|
||||
|
||||
@@ -1,29 +1,25 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from .jwt_utils import bearer_jwt_auth
|
||||
|
||||
|
||||
def add_auth_responses_to_openapi(app: FastAPI) -> None:
|
||||
"""
|
||||
Set up custom OpenAPI schema generation that adds 401 responses
|
||||
Patch a FastAPI instance's `openapi()` method to add 401 responses
|
||||
to all authenticated endpoints.
|
||||
|
||||
This is needed when using HTTPBearer with auto_error=False to get proper
|
||||
401 responses instead of 403, but FastAPI only automatically adds security
|
||||
responses when auto_error=True.
|
||||
"""
|
||||
# Wrap current method to allow stacking OpenAPI schema modifiers like this
|
||||
wrapped_openapi = app.openapi
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
openapi_schema = wrapped_openapi()
|
||||
|
||||
# Add 401 response to all endpoints that have security requirements
|
||||
for path, methods in openapi_schema["paths"].items():
|
||||
|
||||
@@ -48,7 +48,8 @@ RUN poetry install --no-ansi --no-root
|
||||
# Generate Prisma client
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
RUN poetry run prisma generate
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
FROM debian:13-slim AS server_dependencies
|
||||
|
||||
|
||||
@@ -108,7 +108,7 @@ import fastapi.testclient
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.server.v2.myroute import router
|
||||
from backend.api.features.myroute import router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
@@ -149,7 +149,7 @@ These provide the easiest way to set up authentication mocking in test modules:
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
from backend.server.v2.myroute import router
|
||||
from backend.api.features.myroute import router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
@@ -3,12 +3,12 @@ from typing import Dict, Set
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from backend.api.model import NotificationPayload, WSMessage, WSMethod
|
||||
from backend.data.execution import (
|
||||
ExecutionEventType,
|
||||
GraphExecutionEvent,
|
||||
NodeExecutionEvent,
|
||||
)
|
||||
from backend.server.model import NotificationPayload, WSMessage, WSMethod
|
||||
|
||||
_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
|
||||
ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
|
||||
@@ -4,13 +4,13 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
from fastapi import WebSocket
|
||||
|
||||
from backend.api.conn_manager import ConnectionManager
|
||||
from backend.api.model import NotificationPayload, WSMessage, WSMethod
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionEvent,
|
||||
NodeExecutionEvent,
|
||||
)
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.model import NotificationPayload, WSMessage, WSMethod
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
25
autogpt_platform/backend/backend/api/external/fastapi_app.py
vendored
Normal file
25
autogpt_platform/backend/backend/api/external/fastapi_app.py
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.api.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
|
||||
from .v1.routes import v1_router
|
||||
|
||||
external_api = FastAPI(
|
||||
title="AutoGPT External API",
|
||||
description="External API for AutoGPT integrations",
|
||||
docs_url="/docs",
|
||||
version="1.0",
|
||||
)
|
||||
|
||||
external_api.add_middleware(SecurityHeadersMiddleware)
|
||||
external_api.include_router(v1_router, prefix="/v1")
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
external_api,
|
||||
service_name="external-api",
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=True,
|
||||
)
|
||||
107
autogpt_platform/backend/backend/api/external/middleware.py
vendored
Normal file
107
autogpt_platform/backend/backend/api/external/middleware.py
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
from fastapi import HTTPException, Security, status
|
||||
from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer
|
||||
from prisma.enums import APIKeyPermission
|
||||
|
||||
from backend.data.auth.api_key import APIKeyInfo, validate_api_key
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.auth.oauth import (
|
||||
InvalidClientError,
|
||||
InvalidTokenError,
|
||||
OAuthAccessTokenInfo,
|
||||
validate_access_token,
|
||||
)
|
||||
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
bearer_auth = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
|
||||
"""Middleware for API key authentication only"""
|
||||
if api_key is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing API key"
|
||||
)
|
||||
|
||||
api_key_obj = await validate_api_key(api_key)
|
||||
|
||||
if not api_key_obj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
||||
)
|
||||
|
||||
return api_key_obj
|
||||
|
||||
|
||||
async def require_access_token(
|
||||
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
|
||||
) -> OAuthAccessTokenInfo:
|
||||
"""Middleware for OAuth access token authentication only"""
|
||||
if bearer is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing Authorization header",
|
||||
)
|
||||
|
||||
try:
|
||||
token_info, _ = await validate_access_token(bearer.credentials)
|
||||
except (InvalidClientError, InvalidTokenError) as e:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
|
||||
|
||||
return token_info
|
||||
|
||||
|
||||
async def require_auth(
|
||||
api_key: str | None = Security(api_key_header),
|
||||
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
|
||||
) -> APIAuthorizationInfo:
|
||||
"""
|
||||
Unified authentication middleware supporting both API keys and OAuth tokens.
|
||||
|
||||
Supports two authentication methods, which are checked in order:
|
||||
1. X-API-Key header (existing API key authentication)
|
||||
2. Authorization: Bearer <token> header (OAuth access token)
|
||||
|
||||
Returns:
|
||||
APIAuthorizationInfo: base class of both APIKeyInfo and OAuthAccessTokenInfo.
|
||||
"""
|
||||
# Try API key first
|
||||
if api_key is not None:
|
||||
api_key_info = await validate_api_key(api_key)
|
||||
if api_key_info:
|
||||
return api_key_info
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
||||
)
|
||||
|
||||
# Try OAuth bearer token
|
||||
if bearer is not None:
|
||||
try:
|
||||
token_info, _ = await validate_access_token(bearer.credentials)
|
||||
return token_info
|
||||
except (InvalidClientError, InvalidTokenError) as e:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
|
||||
|
||||
# No credentials provided
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing authentication. Provide API key or access token.",
|
||||
)
|
||||
|
||||
|
||||
def require_permission(permission: APIKeyPermission):
|
||||
"""
|
||||
Dependency function for checking specific permissions
|
||||
(works with API keys and OAuth tokens)
|
||||
"""
|
||||
|
||||
async def check_permission(
|
||||
auth: APIAuthorizationInfo = Security(require_auth),
|
||||
) -> APIAuthorizationInfo:
|
||||
if permission not in auth.scopes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing required permission: {permission.value}",
|
||||
)
|
||||
return auth
|
||||
|
||||
return check_permission
|
||||
@@ -16,7 +16,9 @@ from fastapi import APIRouter, Body, HTTPException, Path, Security, status
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.integrations.models import get_all_provider_names
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
@@ -28,8 +30,6 @@ from backend.data.model import (
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.server.integrations.models import get_all_provider_names
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -255,7 +255,7 @@ def _get_oauth_handler_for_external(
|
||||
|
||||
@integrations_router.get("/providers", response_model=list[ProviderInfo])
|
||||
async def list_providers(
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[ProviderInfo]:
|
||||
@@ -319,7 +319,7 @@ async def list_providers(
|
||||
async def initiate_oauth(
|
||||
provider: Annotated[str, Path(title="The OAuth provider")],
|
||||
request: OAuthInitiateRequest,
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||
),
|
||||
) -> OAuthInitiateResponse:
|
||||
@@ -337,7 +337,10 @@ async def initiate_oauth(
|
||||
if not validate_callback_url(request.callback_url):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Callback URL origin is not allowed. Allowed origins: {settings.config.external_oauth_callback_origins}",
|
||||
detail=(
|
||||
f"Callback URL origin is not allowed. "
|
||||
f"Allowed origins: {settings.config.external_oauth_callback_origins}",
|
||||
),
|
||||
)
|
||||
|
||||
# Validate provider
|
||||
@@ -359,13 +362,15 @@ async def initiate_oauth(
|
||||
)
|
||||
|
||||
# Store state token with external flow metadata
|
||||
# Note: initiated_by_api_key_id is only available for API key auth, not OAuth
|
||||
api_key_id = getattr(auth, "id", None) if auth.type == "api_key" else None
|
||||
state_token, code_challenge = await creds_manager.store.store_state_token(
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
provider=provider if isinstance(provider_name, str) else provider_name.value,
|
||||
scopes=request.scopes,
|
||||
callback_url=request.callback_url,
|
||||
state_metadata=request.state_metadata,
|
||||
initiated_by_api_key_id=api_key.id,
|
||||
initiated_by_api_key_id=api_key_id,
|
||||
)
|
||||
|
||||
# Build login URL
|
||||
@@ -393,7 +398,7 @@ async def initiate_oauth(
|
||||
async def complete_oauth(
|
||||
provider: Annotated[str, Path(title="The OAuth provider")],
|
||||
request: OAuthCompleteRequest,
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||
),
|
||||
) -> OAuthCompleteResponse:
|
||||
@@ -406,7 +411,7 @@ async def complete_oauth(
|
||||
"""
|
||||
# Verify state token
|
||||
valid_state = await creds_manager.store.verify_state_token(
|
||||
api_key.user_id, request.state_token, provider
|
||||
auth.user_id, request.state_token, provider
|
||||
)
|
||||
|
||||
if not valid_state:
|
||||
@@ -453,7 +458,7 @@ async def complete_oauth(
|
||||
)
|
||||
|
||||
# Store credentials
|
||||
await creds_manager.create(api_key.user_id, credentials)
|
||||
await creds_manager.create(auth.user_id, credentials)
|
||||
|
||||
logger.info(f"Successfully completed external OAuth for provider {provider}")
|
||||
|
||||
@@ -470,7 +475,7 @@ async def complete_oauth(
|
||||
|
||||
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
|
||||
async def list_credentials(
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
@@ -479,7 +484,7 @@ async def list_credentials(
|
||||
|
||||
Returns metadata about each credential without exposing sensitive tokens.
|
||||
"""
|
||||
credentials = await creds_manager.store.get_all_creds(api_key.user_id)
|
||||
credentials = await creds_manager.store.get_all_creds(auth.user_id)
|
||||
return [
|
||||
CredentialSummary(
|
||||
id=cred.id,
|
||||
@@ -499,7 +504,7 @@ async def list_credentials(
|
||||
)
|
||||
async def list_credentials_by_provider(
|
||||
provider: Annotated[str, Path(title="The provider to list credentials for")],
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
@@ -507,7 +512,7 @@ async def list_credentials_by_provider(
|
||||
List credentials for a specific provider.
|
||||
"""
|
||||
credentials = await creds_manager.store.get_creds_by_provider(
|
||||
api_key.user_id, provider
|
||||
auth.user_id, provider
|
||||
)
|
||||
return [
|
||||
CredentialSummary(
|
||||
@@ -536,7 +541,7 @@ async def create_credential(
|
||||
CreateUserPasswordCredentialRequest,
|
||||
CreateHostScopedCredentialRequest,
|
||||
] = Body(..., discriminator="type"),
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||
),
|
||||
) -> CreateCredentialResponse:
|
||||
@@ -591,7 +596,7 @@ async def create_credential(
|
||||
|
||||
# Store credentials
|
||||
try:
|
||||
await creds_manager.create(api_key.user_id, credentials)
|
||||
await creds_manager.create(auth.user_id, credentials)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store credentials: {e}")
|
||||
raise HTTPException(
|
||||
@@ -623,7 +628,7 @@ class DeleteCredentialResponse(BaseModel):
|
||||
async def delete_credential(
|
||||
provider: Annotated[str, Path(title="The provider")],
|
||||
cred_id: Annotated[str, Path(title="The credential ID to delete")],
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.DELETE_INTEGRATIONS)
|
||||
),
|
||||
) -> DeleteCredentialResponse:
|
||||
@@ -634,7 +639,7 @@ async def delete_credential(
|
||||
use the main API's delete endpoint which handles webhook cleanup and
|
||||
token revocation.
|
||||
"""
|
||||
creds = await creds_manager.store.get_creds_by_id(api_key.user_id, cred_id)
|
||||
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
@@ -645,6 +650,6 @@ async def delete_credential(
|
||||
detail="Credentials do not match the specified provider",
|
||||
)
|
||||
|
||||
await creds_manager.delete(api_key.user_id, cred_id)
|
||||
await creds_manager.delete(auth.user_id, cred_id)
|
||||
|
||||
return DeleteCredentialResponse(deleted=True, credentials_id=cred_id)
|
||||
@@ -5,46 +5,60 @@ from typing import Annotated, Any, Literal, Optional, Sequence
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.data.block
|
||||
import backend.server.v2.store.cache as store_cache
|
||||
import backend.server.v2.store.model as store_model
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.data import user as user_db
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .integrations import integrations_router
|
||||
from .tools import tools_router
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
v1_router = APIRouter()
|
||||
|
||||
|
||||
class NodeOutput(TypedDict):
|
||||
key: str
|
||||
value: Any
|
||||
v1_router.include_router(integrations_router)
|
||||
v1_router.include_router(tools_router)
|
||||
|
||||
|
||||
class ExecutionNode(TypedDict):
|
||||
node_id: str
|
||||
input: Any
|
||||
output: dict[str, Any]
|
||||
class UserInfoResponse(BaseModel):
|
||||
id: str
|
||||
name: Optional[str]
|
||||
email: str
|
||||
timezone: str = Field(
|
||||
description="The user's last known timezone (e.g. 'Europe/Amsterdam'), "
|
||||
"or 'not-set' if not set"
|
||||
)
|
||||
|
||||
|
||||
class ExecutionNodeOutput(TypedDict):
|
||||
node_id: str
|
||||
outputs: list[NodeOutput]
|
||||
@v1_router.get(
|
||||
path="/me",
|
||||
tags=["user", "meta"],
|
||||
)
|
||||
async def get_user_info(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.IDENTITY)
|
||||
),
|
||||
) -> UserInfoResponse:
|
||||
user = await user_db.get_user_by_id(auth.user_id)
|
||||
|
||||
|
||||
class GraphExecutionResult(TypedDict):
|
||||
execution_id: str
|
||||
status: str
|
||||
nodes: list[ExecutionNode]
|
||||
output: Optional[list[dict[str, str]]]
|
||||
return UserInfoResponse(
|
||||
id=user.id,
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
timezone=user.timezone,
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -65,7 +79,9 @@ async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
async def execute_graph_block(
|
||||
block_id: str,
|
||||
data: BlockInput,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.EXECUTE_BLOCK)
|
||||
),
|
||||
) -> CompletedBlockOutput:
|
||||
obj = backend.data.block.get_block(block_id)
|
||||
if not obj:
|
||||
@@ -85,12 +101,14 @@ async def execute_graph(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.EXECUTE_GRAPH)
|
||||
),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
graph_exec = await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
inputs=node_input,
|
||||
graph_version=graph_version,
|
||||
)
|
||||
@@ -100,6 +118,19 @@ async def execute_graph(
|
||||
raise HTTPException(status_code=400, detail=msg)
|
||||
|
||||
|
||||
class ExecutionNode(TypedDict):
|
||||
node_id: str
|
||||
input: Any
|
||||
output: dict[str, Any]
|
||||
|
||||
|
||||
class GraphExecutionResult(TypedDict):
|
||||
execution_id: str
|
||||
status: str
|
||||
nodes: list[ExecutionNode]
|
||||
output: Optional[list[dict[str, str]]]
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}/results",
|
||||
tags=["graphs"],
|
||||
@@ -107,10 +138,12 @@ async def execute_graph(
|
||||
async def get_graph_execution_results(
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.READ_GRAPH)),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_GRAPH)
|
||||
),
|
||||
) -> GraphExecutionResult:
|
||||
graph_exec = await execution_db.get_graph_execution(
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
execution_id=graph_exec_id,
|
||||
include_node_executions=True,
|
||||
)
|
||||
@@ -122,7 +155,7 @@ async def get_graph_execution_results(
|
||||
if not await graph_db.get_graph(
|
||||
graph_id=graph_exec.graph_id,
|
||||
version=graph_exec.graph_version,
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
):
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
@@ -14,19 +14,19 @@ from fastapi import APIRouter, Security
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools import find_agent_tool, run_agent_tool
|
||||
from backend.server.v2.chat.tools.models import ToolResponseBase
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
|
||||
from backend.api.features.chat.tools.models import ToolResponseBase
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
tools_router = APIRouter(prefix="/tools", tags=["tools"])
|
||||
|
||||
# Note: We use Security() as a function parameter dependency (api_key: APIKeyInfo = Security(...))
|
||||
# Note: We use Security() as a function parameter dependency (auth: APIAuthorizationInfo = Security(...))
|
||||
# rather than in the decorator's dependencies= list. This avoids duplicate permission checks
|
||||
# while still enforcing auth AND giving us access to the api_key for extracting user_id.
|
||||
# while still enforcing auth AND giving us access to auth for extracting user_id.
|
||||
|
||||
|
||||
# Request models
|
||||
@@ -80,7 +80,9 @@ def _create_ephemeral_session(user_id: str | None) -> ChatSession:
|
||||
)
|
||||
async def find_agent(
|
||||
request: FindAgentRequest,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.USE_TOOLS)
|
||||
),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Search for agents in the marketplace based on capabilities and user needs.
|
||||
@@ -91,9 +93,9 @@ async def find_agent(
|
||||
Returns:
|
||||
List of matching agents or no results response
|
||||
"""
|
||||
session = _create_ephemeral_session(api_key.user_id)
|
||||
session = _create_ephemeral_session(auth.user_id)
|
||||
result = await find_agent_tool._execute(
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
session=session,
|
||||
query=request.query,
|
||||
)
|
||||
@@ -105,7 +107,9 @@ async def find_agent(
|
||||
)
|
||||
async def run_agent(
|
||||
request: RunAgentRequest,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.USE_TOOLS)
|
||||
),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run or schedule an agent from the marketplace.
|
||||
@@ -129,9 +133,9 @@ async def run_agent(
|
||||
- execution_started: If agent was run or scheduled successfully
|
||||
- error: If something went wrong
|
||||
"""
|
||||
session = _create_ephemeral_session(api_key.user_id)
|
||||
session = _create_ephemeral_session(auth.user_id)
|
||||
result = await run_agent_tool._execute(
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
session=session,
|
||||
username_agent_slug=request.username_agent_slug,
|
||||
inputs=request.inputs,
|
||||
@@ -6,9 +6,10 @@ from fastapi import APIRouter, Body, Security
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import admin_get_user_history, get_user_credit_model
|
||||
from backend.server.v2.admin.model import AddUserCreditsResponse, UserHistoryResponse
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
from .model import AddUserCreditsResponse, UserHistoryResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -9,14 +9,15 @@ import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.v2.admin.credit_admin_routes as credit_admin_routes
|
||||
import backend.server.v2.admin.model as admin_model
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .credit_admin_routes import router as credit_admin_router
|
||||
from .model import UserHistoryResponse
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(credit_admin_routes.router)
|
||||
app.include_router(credit_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
@@ -30,7 +31,7 @@ def setup_app_admin_auth(mock_jwt_admin):
|
||||
|
||||
|
||||
def test_add_user_credits_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
admin_user_id: str,
|
||||
target_user_id: str,
|
||||
@@ -42,7 +43,7 @@ def test_add_user_credits_success(
|
||||
return_value=(1500, "transaction-123-uuid")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
|
||||
"backend.api.features.admin.credit_admin_routes.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
@@ -84,7 +85,7 @@ def test_add_user_credits_success(
|
||||
|
||||
|
||||
def test_add_user_credits_negative_amount(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test credit deduction by admin (negative amount)"""
|
||||
@@ -94,7 +95,7 @@ def test_add_user_credits_negative_amount(
|
||||
return_value=(200, "transaction-456-uuid")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
|
||||
"backend.api.features.admin.credit_admin_routes.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
@@ -119,12 +120,12 @@ def test_add_user_credits_negative_amount(
|
||||
|
||||
|
||||
def test_get_user_history_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful retrieval of user credit history"""
|
||||
# Mock the admin_get_user_history function
|
||||
mock_history_response = admin_model.UserHistoryResponse(
|
||||
mock_history_response = UserHistoryResponse(
|
||||
history=[
|
||||
UserTransaction(
|
||||
user_id="user-1",
|
||||
@@ -150,7 +151,7 @@ def test_get_user_history_success(
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.admin_get_user_history",
|
||||
"backend.api.features.admin.credit_admin_routes.admin_get_user_history",
|
||||
return_value=mock_history_response,
|
||||
)
|
||||
|
||||
@@ -170,12 +171,12 @@ def test_get_user_history_success(
|
||||
|
||||
|
||||
def test_get_user_history_with_filters(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test user credit history with search and filter parameters"""
|
||||
# Mock the admin_get_user_history function
|
||||
mock_history_response = admin_model.UserHistoryResponse(
|
||||
mock_history_response = UserHistoryResponse(
|
||||
history=[
|
||||
UserTransaction(
|
||||
user_id="user-3",
|
||||
@@ -194,7 +195,7 @@ def test_get_user_history_with_filters(
|
||||
)
|
||||
|
||||
mock_get_history = mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.admin_get_user_history",
|
||||
"backend.api.features.admin.credit_admin_routes.admin_get_user_history",
|
||||
return_value=mock_history_response,
|
||||
)
|
||||
|
||||
@@ -230,12 +231,12 @@ def test_get_user_history_with_filters(
|
||||
|
||||
|
||||
def test_get_user_history_empty_results(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test user credit history with no results"""
|
||||
# Mock empty history response
|
||||
mock_history_response = admin_model.UserHistoryResponse(
|
||||
mock_history_response = UserHistoryResponse(
|
||||
history=[],
|
||||
pagination=Pagination(
|
||||
total_items=0,
|
||||
@@ -246,7 +247,7 @@ def test_get_user_history_empty_results(
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.admin_get_user_history",
|
||||
"backend.api.features.admin.credit_admin_routes.admin_get_user_history",
|
||||
return_value=mock_history_response,
|
||||
)
|
||||
|
||||
@@ -7,9 +7,9 @@ import fastapi
|
||||
import fastapi.responses
|
||||
import prisma.enums
|
||||
|
||||
import backend.server.v2.store.cache as store_cache
|
||||
import backend.server.v2.store.db
|
||||
import backend.server.v2.store.model
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.util.json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -24,7 +24,7 @@ router = fastapi.APIRouter(
|
||||
@router.get(
|
||||
"/listings",
|
||||
summary="Get Admin Listings History",
|
||||
response_model=backend.server.v2.store.model.StoreListingsWithVersionsResponse,
|
||||
response_model=store_model.StoreListingsWithVersionsResponse,
|
||||
)
|
||||
async def get_admin_listings_with_versions(
|
||||
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
|
||||
@@ -48,7 +48,7 @@ async def get_admin_listings_with_versions(
|
||||
StoreListingsWithVersionsResponse with listings and their versions
|
||||
"""
|
||||
try:
|
||||
listings = await backend.server.v2.store.db.get_admin_listings_with_versions(
|
||||
listings = await store_db.get_admin_listings_with_versions(
|
||||
status=status,
|
||||
search_query=search,
|
||||
page=page,
|
||||
@@ -68,11 +68,11 @@ async def get_admin_listings_with_versions(
|
||||
@router.post(
|
||||
"/submissions/{store_listing_version_id}/review",
|
||||
summary="Review Store Submission",
|
||||
response_model=backend.server.v2.store.model.StoreSubmission,
|
||||
response_model=store_model.StoreSubmission,
|
||||
)
|
||||
async def review_submission(
|
||||
store_listing_version_id: str,
|
||||
request: backend.server.v2.store.model.ReviewSubmissionRequest,
|
||||
request: store_model.ReviewSubmissionRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
@@ -87,12 +87,10 @@ async def review_submission(
|
||||
StoreSubmission with updated review information
|
||||
"""
|
||||
try:
|
||||
already_approved = (
|
||||
await backend.server.v2.store.db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
already_approved = await store_db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
submission = await backend.server.v2.store.db.review_store_submission(
|
||||
submission = await store_db.review_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
external_comments=request.comments,
|
||||
@@ -136,7 +134,7 @@ async def admin_download_agent_file(
|
||||
Raises:
|
||||
HTTPException: If the agent is not found or an unexpected error occurs.
|
||||
"""
|
||||
graph_data = await backend.server.v2.store.db.get_agent_as_admin(
|
||||
graph_data = await store_db.get_agent_as_admin(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
@@ -6,10 +6,11 @@ from typing import Annotated
|
||||
import fastapi
|
||||
import pydantic
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from autogpt_libs.auth.dependencies import requires_user
|
||||
|
||||
import backend.data.analytics
|
||||
|
||||
router = fastapi.APIRouter()
|
||||
router = fastapi.APIRouter(dependencies=[fastapi.Security(requires_user)])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
340
autogpt_platform/backend/backend/api/features/analytics_test.py
Normal file
340
autogpt_platform/backend/backend/api/features/analytics_test.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""Tests for analytics API endpoints."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from .analytics import router as analytics_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(analytics_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module."""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# /log_raw_metric endpoint tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_log_raw_metric_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw metric logging."""
|
||||
mock_result = Mock(id="metric-123-uuid")
|
||||
mock_log_metric = mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "page_load_time",
|
||||
"metric_value": 2.5,
|
||||
"data_string": "/dashboard",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||
assert response.json() == "metric-123-uuid"
|
||||
|
||||
mock_log_metric.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
metric_name="page_load_time",
|
||||
metric_value=2.5,
|
||||
data_string="/dashboard",
|
||||
)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"metric_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_metric_success",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"metric_value,metric_name,data_string,test_id",
|
||||
[
|
||||
(100, "api_calls_count", "external_api", "integer_value"),
|
||||
(0, "error_count", "no_errors", "zero_value"),
|
||||
(-5.2, "temperature_delta", "cooling", "negative_value"),
|
||||
(1.23456789, "precision_test", "float_precision", "float_precision"),
|
||||
(999999999, "large_number", "max_value", "large_number"),
|
||||
(0.0000001, "tiny_number", "min_value", "tiny_number"),
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_various_values(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
metric_value: float,
|
||||
metric_name: str,
|
||||
data_string: str,
|
||||
test_id: str,
|
||||
) -> None:
|
||||
"""Test raw metric logging with various metric values."""
|
||||
mock_result = Mock(id=f"metric-{test_id}-uuid")
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": metric_name,
|
||||
"metric_value": metric_value,
|
||||
"data_string": data_string,
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Failed for {test_id}: {response.text}"
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"metric_id": response.json(), "test_case": test_id},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
f"analytics_metric_{test_id}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"),
|
||||
({"metric_name": "test"}, "Field required"),
|
||||
(
|
||||
{"metric_name": "test", "metric_value": "not_a_number", "data_string": "x"},
|
||||
"Input should be a valid number",
|
||||
),
|
||||
(
|
||||
{"metric_name": "", "metric_value": 1.0, "data_string": "test"},
|
||||
"String should have at least 1 character",
|
||||
),
|
||||
(
|
||||
{"metric_name": "test", "metric_value": 1.0, "data_string": ""},
|
||||
"String should have at least 1 character",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"empty_request",
|
||||
"missing_metric_value_and_data_string",
|
||||
"invalid_metric_value_type",
|
||||
"empty_metric_name",
|
||||
"empty_data_string",
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_validation_errors(
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test validation errors for invalid metric requests."""
|
||||
response = client.post("/log_raw_metric", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||
|
||||
error_text = json.dumps(error_detail)
|
||||
assert (
|
||||
expected_error in error_text
|
||||
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||
|
||||
|
||||
def test_log_raw_metric_service_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error handling when analytics service fails."""
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Database connection failed"),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "test_metric",
|
||||
"metric_value": 1.0,
|
||||
"data_string": "test",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_detail = response.json()["detail"]
|
||||
assert "Database connection failed" in error_detail["message"]
|
||||
assert "hint" in error_detail
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# /log_raw_analytics endpoint tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_log_raw_analytics_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw analytics logging."""
|
||||
mock_result = Mock(id="analytics-789-uuid")
|
||||
mock_log_analytics = mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "user_action",
|
||||
"data": {
|
||||
"action": "button_click",
|
||||
"button_id": "submit_form",
|
||||
"timestamp": "2023-01-01T00:00:00Z",
|
||||
"metadata": {"form_type": "registration", "fields_filled": 5},
|
||||
},
|
||||
"data_index": "button_click_submit_form",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||
assert response.json() == "analytics-789-uuid"
|
||||
|
||||
mock_log_analytics.assert_called_once_with(
|
||||
test_user_id,
|
||||
"user_action",
|
||||
request_data["data"],
|
||||
"button_click_submit_form",
|
||||
)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"analytics_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_analytics_success",
|
||||
)
|
||||
|
||||
|
||||
def test_log_raw_analytics_complex_data(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test raw analytics logging with complex nested data structures."""
|
||||
mock_result = Mock(id="analytics-complex-uuid")
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "agent_execution",
|
||||
"data": {
|
||||
"agent_id": "agent_123",
|
||||
"execution_id": "exec_456",
|
||||
"status": "completed",
|
||||
"duration_ms": 3500,
|
||||
"nodes_executed": 15,
|
||||
"blocks_used": [
|
||||
{"block_id": "llm_block", "count": 3},
|
||||
{"block_id": "http_block", "count": 5},
|
||||
{"block_id": "code_block", "count": 2},
|
||||
],
|
||||
"errors": [],
|
||||
"metadata": {
|
||||
"trigger": "manual",
|
||||
"user_tier": "premium",
|
||||
"environment": "production",
|
||||
},
|
||||
},
|
||||
"data_index": "agent_123_exec_456",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"analytics_id": response.json(), "logged_data": request_data["data"]},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
"analytics_log_analytics_complex_data",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"),
|
||||
({"type": "test"}, "Field required"),
|
||||
(
|
||||
{"type": "test", "data": "not_a_dict", "data_index": "test"},
|
||||
"Input should be a valid dictionary",
|
||||
),
|
||||
({"type": "test", "data": {"key": "value"}}, "Field required"),
|
||||
],
|
||||
ids=[
|
||||
"empty_request",
|
||||
"missing_data_and_data_index",
|
||||
"invalid_data_type",
|
||||
"missing_data_index",
|
||||
],
|
||||
)
|
||||
def test_log_raw_analytics_validation_errors(
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test validation errors for invalid analytics requests."""
|
||||
response = client.post("/log_raw_analytics", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||
|
||||
error_text = json.dumps(error_detail)
|
||||
assert (
|
||||
expected_error in error_text
|
||||
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||
|
||||
|
||||
def test_log_raw_analytics_service_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error handling when analytics service fails."""
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Analytics DB unreachable"),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "test_event",
|
||||
"data": {"key": "value"},
|
||||
"data_index": "test_index",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_detail = response.json()["detail"]
|
||||
assert "Analytics DB unreachable" in error_detail["message"]
|
||||
assert "hint" in error_detail
|
||||
@@ -6,17 +6,20 @@ from typing import Sequence
|
||||
|
||||
import prisma
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.data.block
|
||||
import backend.server.v2.library.db as library_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.db as store_db
|
||||
import backend.server.v2.store.model as store_model
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.v2.builder.model import (
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .model import (
|
||||
BlockCategoryResponse,
|
||||
BlockResponse,
|
||||
BlockType,
|
||||
@@ -26,8 +29,6 @@ from backend.server.v2.builder.model import (
|
||||
ProviderResponse,
|
||||
SearchEntry,
|
||||
)
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
|
||||
@@ -2,8 +2,8 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.model as store_model
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.model as store_model
|
||||
from backend.data.block import BlockInfo
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.models import Pagination
|
||||
@@ -4,11 +4,12 @@ from typing import Annotated, Sequence
|
||||
import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
|
||||
import backend.server.v2.builder.db as builder_db
|
||||
import backend.server.v2.builder.model as builder_model
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import db as builder_db
|
||||
from . import model as builder_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
@@ -19,9 +19,10 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.server.v2.chat.config import ChatConfig
|
||||
from backend.util.exceptions import RedisError
|
||||
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.chat.model import (
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
@@ -9,10 +9,11 @@ from fastapi import APIRouter, Depends, Query, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
import backend.server.v2.chat.service as chat_service
|
||||
from backend.server.v2.chat.config import ChatConfig
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from .config import ChatConfig
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
@@ -7,15 +7,17 @@ import orjson
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
||||
|
||||
import backend.server.v2.chat.config
|
||||
from backend.server.v2.chat.model import (
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .config import ChatConfig
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.server.v2.chat.response_model import (
|
||||
from .response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamEnd,
|
||||
StreamError,
|
||||
@@ -26,12 +28,11 @@ from backend.server.v2.chat.response_model import (
|
||||
StreamToolExecutionResult,
|
||||
StreamUsage,
|
||||
)
|
||||
from backend.server.v2.chat.tools import execute_tool, tools
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from .tools import execute_tool, tools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = backend.server.v2.chat.config.ChatConfig()
|
||||
config = ChatConfig()
|
||||
client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
import backend.server.v2.chat.service as chat_service
|
||||
from backend.server.v2.chat.response_model import (
|
||||
from . import service as chat_service
|
||||
from .response_model import (
|
||||
StreamEnd,
|
||||
StreamError,
|
||||
StreamTextChunk,
|
||||
@@ -2,14 +2,14 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .find_agent import FindAgentTool
|
||||
from .run_agent import RunAgentTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.server.v2.chat.response_model import StreamToolExecutionResult
|
||||
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
||||
|
||||
# Initialize tool instances
|
||||
find_agent_tool = FindAgentTool()
|
||||
@@ -5,6 +5,8 @@ from os import getenv
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
@@ -13,8 +15,6 @@ from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.model import APIKeyCredentials
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.store import db as store_db
|
||||
|
||||
|
||||
def make_session(user_id: str | None = None):
|
||||
@@ -5,8 +5,8 @@ from typing import Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.response_model import StreamToolExecutionResult
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
||||
|
||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||
|
||||
@@ -3,17 +3,18 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools.base import BaseTool
|
||||
from backend.server.v2.chat.tools.models import (
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentCarouselResponse,
|
||||
AgentInfo,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from backend.server.v2.store import db as store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -5,14 +5,21 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.api.features.chat.config import ChatConfig
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.server.v2.chat.config import ChatConfig
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools.base import BaseTool
|
||||
from backend.server.v2.chat.tools.models import (
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
from backend.util.timezone_utils import (
|
||||
convert_utc_time_to_user_timezone,
|
||||
get_user_timezone_or_utc,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentDetails,
|
||||
AgentDetailsResponse,
|
||||
ErrorResponse,
|
||||
@@ -23,19 +30,13 @@ from backend.server.v2.chat.tools.models import (
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
from backend.server.v2.chat.tools.utils import (
|
||||
from .utils import (
|
||||
check_user_has_required_credentials,
|
||||
extract_credentials_from_schema,
|
||||
fetch_graph_from_store_slug,
|
||||
get_or_create_library_agent,
|
||||
match_user_credentials_to_graph,
|
||||
)
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
from backend.util.timezone_utils import (
|
||||
convert_utc_time_to_user_timezone,
|
||||
get_user_timezone_or_utc,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
@@ -3,13 +3,13 @@ import uuid
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.chat.tools._test_data import (
|
||||
from ._test_data import (
|
||||
make_session,
|
||||
setup_firecrawl_test_data,
|
||||
setup_llm_test_data,
|
||||
setup_test_data,
|
||||
)
|
||||
from backend.server.v2.chat.tools.run_agent import RunAgentTool
|
||||
from .run_agent import RunAgentTool
|
||||
|
||||
# This is so the formatter doesn't remove the fixture imports
|
||||
setup_llm_test_data = setup_llm_test_data
|
||||
@@ -3,13 +3,13 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library import model as library_model
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.server.v2.library import db as library_db
|
||||
from backend.server.v2.library import model as library_model
|
||||
from backend.server.v2.store import db as store_db
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -7,9 +7,10 @@ import pytest_mock
|
||||
from prisma.enums import ReviewStatus
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.server.rest_api import handle_internal_http_error
|
||||
from backend.server.v2.executions.review.model import PendingHumanReviewModel
|
||||
from backend.server.v2.executions.review.routes import router
|
||||
from backend.api.rest_api import handle_internal_http_error
|
||||
|
||||
from .model import PendingHumanReviewModel
|
||||
from .routes import router
|
||||
|
||||
# Using a fixed timestamp for reproducible tests
|
||||
FIXED_NOW = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
|
||||
@@ -54,13 +55,13 @@ def sample_pending_review(test_user_id: str) -> PendingHumanReviewModel:
|
||||
|
||||
|
||||
def test_get_pending_reviews_empty(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting pending reviews when none exist"""
|
||||
mock_get_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_user"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_user"
|
||||
)
|
||||
mock_get_reviews.return_value = []
|
||||
|
||||
@@ -72,14 +73,14 @@ def test_get_pending_reviews_empty(
|
||||
|
||||
|
||||
def test_get_pending_reviews_with_data(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting pending reviews with data"""
|
||||
mock_get_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_user"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_user"
|
||||
)
|
||||
mock_get_reviews.return_value = [sample_pending_review]
|
||||
|
||||
@@ -94,14 +95,14 @@ def test_get_pending_reviews_with_data(
|
||||
|
||||
|
||||
def test_get_pending_reviews_for_execution_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting pending reviews for specific execution"""
|
||||
mock_get_graph_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_graph_execution_meta"
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_get_graph_execution.return_value = {
|
||||
"id": "test_graph_exec_456",
|
||||
@@ -109,7 +110,7 @@ def test_get_pending_reviews_for_execution_success(
|
||||
}
|
||||
|
||||
mock_get_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews.return_value = [sample_pending_review]
|
||||
|
||||
@@ -121,24 +122,23 @@ def test_get_pending_reviews_for_execution_success(
|
||||
assert data[0]["graph_exec_id"] == "test_graph_exec_456"
|
||||
|
||||
|
||||
def test_get_pending_reviews_for_execution_access_denied(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
def test_get_pending_reviews_for_execution_not_available(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Test access denied when user doesn't own the execution"""
|
||||
mock_get_graph_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_graph_execution_meta"
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_get_graph_execution.return_value = None
|
||||
|
||||
response = client.get("/api/review/execution/test_graph_exec_456")
|
||||
|
||||
assert response.status_code == 403
|
||||
assert "Access denied" in response.json()["detail"]
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_process_review_action_approve_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
@@ -146,12 +146,12 @@ def test_process_review_action_approve_success(
|
||||
# Mock the route functions
|
||||
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
# Create approved review for return
|
||||
approved_review = PendingHumanReviewModel(
|
||||
@@ -174,11 +174,11 @@ def test_process_review_action_approve_success(
|
||||
mock_process_all_reviews.return_value = {"test_node_123": approved_review}
|
||||
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
mocker.patch("backend.server.v2.executions.review.routes.add_graph_execution")
|
||||
mocker.patch("backend.api.features.executions.review.routes.add_graph_execution")
|
||||
|
||||
request_data = {
|
||||
"reviews": [
|
||||
@@ -202,7 +202,7 @@ def test_process_review_action_approve_success(
|
||||
|
||||
|
||||
def test_process_review_action_reject_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
@@ -210,12 +210,12 @@ def test_process_review_action_reject_success(
|
||||
# Mock the route functions
|
||||
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
rejected_review = PendingHumanReviewModel(
|
||||
node_exec_id="test_node_123",
|
||||
@@ -237,7 +237,7 @@ def test_process_review_action_reject_success(
|
||||
mock_process_all_reviews.return_value = {"test_node_123": rejected_review}
|
||||
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
@@ -262,7 +262,7 @@ def test_process_review_action_reject_success(
|
||||
|
||||
|
||||
def test_process_review_action_mixed_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
@@ -289,12 +289,12 @@ def test_process_review_action_mixed_success(
|
||||
# Mock the route functions
|
||||
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review, second_review]
|
||||
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
# Create approved version of first review
|
||||
approved_review = PendingHumanReviewModel(
|
||||
@@ -338,7 +338,7 @@ def test_process_review_action_mixed_success(
|
||||
}
|
||||
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
@@ -369,7 +369,7 @@ def test_process_review_action_mixed_success(
|
||||
|
||||
|
||||
def test_process_review_action_empty_request(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error when no reviews provided"""
|
||||
@@ -386,19 +386,19 @@ def test_process_review_action_empty_request(
|
||||
|
||||
|
||||
def test_process_review_action_review_not_found(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error when review is not found"""
|
||||
# Mock the functions that extract graph execution ID from the request
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [] # No reviews found
|
||||
|
||||
# Mock process_all_reviews to simulate not finding reviews
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
# This should raise a ValueError with "Reviews not found" message based on the data/human_review.py logic
|
||||
mock_process_all_reviews.side_effect = ValueError(
|
||||
@@ -422,20 +422,20 @@ def test_process_review_action_review_not_found(
|
||||
|
||||
|
||||
def test_process_review_action_partial_failure(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test handling of partial failures in review processing"""
|
||||
# Mock the route functions
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||
|
||||
# Mock partial failure in processing
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
mock_process_all_reviews.side_effect = ValueError("Some reviews failed validation")
|
||||
|
||||
@@ -456,20 +456,20 @@ def test_process_review_action_partial_failure(
|
||||
|
||||
|
||||
def test_process_review_action_invalid_node_exec_id(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test failure when trying to process review with invalid node execution ID"""
|
||||
# Mock the route functions
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||
|
||||
# Mock validation failure - this should return 400, not 500
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
mock_process_all_reviews.side_effect = ValueError(
|
||||
"Invalid node execution ID format"
|
||||
@@ -13,11 +13,8 @@ from backend.data.human_review import (
|
||||
process_all_reviews_for_execution,
|
||||
)
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.server.v2.executions.review.model import (
|
||||
PendingHumanReviewModel,
|
||||
ReviewRequest,
|
||||
ReviewResponse,
|
||||
)
|
||||
|
||||
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -70,8 +67,7 @@ async def list_pending_reviews(
|
||||
response_model=List[PendingHumanReviewModel],
|
||||
responses={
|
||||
200: {"description": "List of pending reviews for the execution"},
|
||||
400: {"description": "Invalid graph execution ID"},
|
||||
403: {"description": "Access denied to graph execution"},
|
||||
404: {"description": "Graph execution not found"},
|
||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
||||
},
|
||||
)
|
||||
@@ -94,7 +90,7 @@ async def list_pending_reviews_for_execution(
|
||||
|
||||
Raises:
|
||||
HTTPException:
|
||||
- 403: If user doesn't own the graph execution
|
||||
- 404: If the graph execution doesn't exist or isn't owned by this user
|
||||
- 500: If authentication fails or database error occurs
|
||||
|
||||
Note:
|
||||
@@ -108,8 +104,8 @@ async def list_pending_reviews_for_execution(
|
||||
)
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to graph execution",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
|
||||
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
|
||||
@@ -134,18 +130,14 @@ async def process_review_action(
|
||||
# Build review decisions map
|
||||
review_decisions = {}
|
||||
for review in request.reviews:
|
||||
if review.approved:
|
||||
review_decisions[review.node_exec_id] = (
|
||||
ReviewStatus.APPROVED,
|
||||
review.reviewed_data,
|
||||
review.message,
|
||||
)
|
||||
else:
|
||||
review_decisions[review.node_exec_id] = (
|
||||
ReviewStatus.REJECTED,
|
||||
None,
|
||||
review.message,
|
||||
)
|
||||
review_status = (
|
||||
ReviewStatus.APPROVED if review.approved else ReviewStatus.REJECTED
|
||||
)
|
||||
review_decisions[review.node_exec_id] = (
|
||||
review_status,
|
||||
review.reviewed_data,
|
||||
review.message,
|
||||
)
|
||||
|
||||
# Process all reviews
|
||||
updated_reviews = await process_all_reviews_for_execution(
|
||||
@@ -17,6 +17,8 @@ from fastapi import (
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
||||
|
||||
from backend.api.features.library.db import set_preset_webhook, update_preset
|
||||
from backend.api.features.library.model import LibraryAgentPreset
|
||||
from backend.data.graph import NodeModel, get_graph, set_node_webhook
|
||||
from backend.data.integrations import (
|
||||
WebhookEvent,
|
||||
@@ -45,13 +47,6 @@ from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.server.integrations.models import (
|
||||
ProviderConstants,
|
||||
ProviderNamesResponse,
|
||||
get_all_provider_names,
|
||||
)
|
||||
from backend.server.v2.library.db import set_preset_webhook, update_preset
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
from backend.util.exceptions import (
|
||||
GraphNotInLibraryError,
|
||||
MissingConfigError,
|
||||
@@ -60,6 +55,8 @@ from backend.util.exceptions import (
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .models import ProviderConstants, ProviderNamesResponse, get_all_provider_names
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.integrations.oauth import BaseOAuthHandler
|
||||
|
||||
@@ -4,16 +4,14 @@ from typing import Literal, Optional
|
||||
|
||||
import fastapi
|
||||
import prisma.errors
|
||||
import prisma.fields
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
import backend.api.features.store.exceptions as store_exceptions
|
||||
import backend.api.features.store.image_gen as store_image_gen
|
||||
import backend.api.features.store.media as store_media
|
||||
import backend.data.graph as graph_db
|
||||
import backend.data.integrations as integrations_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
import backend.server.v2.store.image_gen as store_image_gen
|
||||
import backend.server.v2.store.media as store_media
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.db import transaction
|
||||
from backend.data.execution import get_graph_execution
|
||||
@@ -28,6 +26,8 @@ from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.settings import Config
|
||||
|
||||
from . import model as library_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = Config()
|
||||
integration_creds_manager = IntegrationCredentialsManager()
|
||||
@@ -538,6 +538,7 @@ async def update_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str,
|
||||
auto_update_version: Optional[bool] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
is_favorite: Optional[bool] = None,
|
||||
is_archived: Optional[bool] = None,
|
||||
is_deleted: Optional[Literal[False]] = None,
|
||||
@@ -550,6 +551,7 @@ async def update_library_agent(
|
||||
library_agent_id: The ID of the LibraryAgent to update.
|
||||
user_id: The owner of this LibraryAgent.
|
||||
auto_update_version: Whether the agent should auto-update to active version.
|
||||
graph_version: Specific graph version to update to.
|
||||
is_favorite: Whether this agent is marked as a favorite.
|
||||
is_archived: Whether this agent is archived.
|
||||
settings: User-specific settings for this library agent.
|
||||
@@ -563,8 +565,8 @@ async def update_library_agent(
|
||||
"""
|
||||
logger.debug(
|
||||
f"Updating library agent {library_agent_id} for user {user_id} with "
|
||||
f"auto_update_version={auto_update_version}, is_favorite={is_favorite}, "
|
||||
f"is_archived={is_archived}, settings={settings}"
|
||||
f"auto_update_version={auto_update_version}, graph_version={graph_version}, "
|
||||
f"is_favorite={is_favorite}, is_archived={is_archived}, settings={settings}"
|
||||
)
|
||||
update_fields: prisma.types.LibraryAgentUpdateManyMutationInput = {}
|
||||
if auto_update_version is not None:
|
||||
@@ -581,10 +583,23 @@ async def update_library_agent(
|
||||
update_fields["isDeleted"] = is_deleted
|
||||
if settings is not None:
|
||||
update_fields["settings"] = SafeJson(settings.model_dump())
|
||||
if not update_fields:
|
||||
raise ValueError("No values were passed to update")
|
||||
|
||||
try:
|
||||
# If graph_version is provided, update to that specific version
|
||||
if graph_version is not None:
|
||||
# Get the current agent to find its graph_id
|
||||
agent = await get_library_agent(id=library_agent_id, user_id=user_id)
|
||||
# Update to the specified version using existing function
|
||||
return await update_agent_version_in_library(
|
||||
user_id=user_id,
|
||||
agent_graph_id=agent.graph_id,
|
||||
agent_graph_version=graph_version,
|
||||
)
|
||||
|
||||
# Otherwise, just update the simple fields
|
||||
if not update_fields:
|
||||
raise ValueError("No values were passed to update")
|
||||
|
||||
n_updated = await prisma.models.LibraryAgent.prisma().update_many(
|
||||
where={"id": library_agent_id, "userId": user_id},
|
||||
data=update_fields,
|
||||
@@ -1,16 +1,15 @@
|
||||
from datetime import datetime
|
||||
|
||||
import prisma.enums
|
||||
import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
import pytest
|
||||
|
||||
import backend.server.v2.library.db as db
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.api.features.store.exceptions
|
||||
from backend.data.db import connect
|
||||
from backend.data.includes import library_agent_include
|
||||
|
||||
from . import db
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_library_agents(mocker):
|
||||
@@ -88,7 +87,7 @@ async def test_add_agent_to_library(mocker):
|
||||
await connect()
|
||||
|
||||
# Mock the transaction context
|
||||
mock_transaction = mocker.patch("backend.server.v2.library.db.transaction")
|
||||
mock_transaction = mocker.patch("backend.api.features.library.db.transaction")
|
||||
mock_transaction.return_value.__aenter__ = mocker.AsyncMock(return_value=None)
|
||||
mock_transaction.return_value.__aexit__ = mocker.AsyncMock(return_value=None)
|
||||
# Mock data
|
||||
@@ -151,7 +150,7 @@ async def test_add_agent_to_library(mocker):
|
||||
)
|
||||
|
||||
# Mock graph_db.get_graph function that's called to check for HITL blocks
|
||||
mock_graph_db = mocker.patch("backend.server.v2.library.db.graph_db")
|
||||
mock_graph_db = mocker.patch("backend.api.features.library.db.graph_db")
|
||||
mock_graph_model = mocker.Mock()
|
||||
mock_graph_model.nodes = (
|
||||
[]
|
||||
@@ -159,7 +158,9 @@ async def test_add_agent_to_library(mocker):
|
||||
mock_graph_db.get_graph = mocker.AsyncMock(return_value=mock_graph_model)
|
||||
|
||||
# Mock the model conversion
|
||||
mock_from_db = mocker.patch("backend.server.v2.library.model.LibraryAgent.from_db")
|
||||
mock_from_db = mocker.patch(
|
||||
"backend.api.features.library.model.LibraryAgent.from_db"
|
||||
)
|
||||
mock_from_db.return_value = mocker.Mock()
|
||||
|
||||
# Call function
|
||||
@@ -217,7 +218,7 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
)
|
||||
|
||||
# Call function and verify exception
|
||||
with pytest.raises(backend.server.v2.store.exceptions.AgentNotFoundError):
|
||||
with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
|
||||
await db.add_store_agent_to_library("version123", "test-user")
|
||||
|
||||
# Verify mock called correctly
|
||||
@@ -48,6 +48,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
owner_user_id: str # ID of user who owns/created this agent graph
|
||||
|
||||
image_url: str | None
|
||||
|
||||
@@ -163,6 +164,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id=agent.id,
|
||||
graph_id=agent.agentGraphId,
|
||||
graph_version=agent.agentGraphVersion,
|
||||
owner_user_id=agent.userId,
|
||||
image_url=agent.imageUrl,
|
||||
creator_name=creator_name,
|
||||
creator_image_url=creator_image_url,
|
||||
@@ -385,6 +387,9 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
|
||||
auto_update_version: Optional[bool] = pydantic.Field(
|
||||
default=None, description="Auto-update the agent version"
|
||||
)
|
||||
graph_version: Optional[int] = pydantic.Field(
|
||||
default=None, description="Specific graph version to update to"
|
||||
)
|
||||
is_favorite: Optional[bool] = pydantic.Field(
|
||||
default=None, description="Mark the agent as a favorite"
|
||||
)
|
||||
@@ -3,7 +3,7 @@ import datetime
|
||||
import prisma.models
|
||||
import pytest
|
||||
|
||||
import backend.server.v2.library.model as library_model
|
||||
from . import model as library_model
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -6,12 +6,13 @@ from fastapi import APIRouter, Body, HTTPException, Query, Security, status
|
||||
from fastapi.responses import Response
|
||||
from prisma.enums import OnboardingStep
|
||||
|
||||
import backend.server.v2.library.db as library_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
import backend.api.features.store.exceptions as store_exceptions
|
||||
from backend.data.onboarding import complete_onboarding_step
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .. import db as library_db
|
||||
from .. import model as library_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
@@ -284,6 +285,7 @@ async def update_library_agent(
|
||||
library_agent_id=library_agent_id,
|
||||
user_id=user_id,
|
||||
auto_update_version=payload.auto_update_version,
|
||||
graph_version=payload.graph_version,
|
||||
is_favorite=payload.is_favorite,
|
||||
is_archived=payload.is_archived,
|
||||
settings=payload.settings,
|
||||
@@ -4,8 +4,6 @@ from typing import Any, Optional
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, Body, HTTPException, Query, Security, status
|
||||
|
||||
import backend.server.v2.library.db as db
|
||||
import backend.server.v2.library.model as models
|
||||
from backend.data.execution import GraphExecutionMeta
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.integrations import get_webhook
|
||||
@@ -17,6 +15,9 @@ from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.integrations.webhooks.utils import setup_webhook_for_block
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .. import db
|
||||
from .. import model as models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
credentials_manager = IntegrationCredentialsManager()
|
||||
@@ -7,10 +7,11 @@ import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.v2.library.model as library_model
|
||||
from backend.server.v2.library.routes import router as library_router
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import model as library_model
|
||||
from .routes import router as library_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(library_router)
|
||||
|
||||
@@ -41,6 +42,7 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
@@ -63,6 +65,7 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-2",
|
||||
graph_id="test-agent-2",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
image_url=None,
|
||||
@@ -86,7 +89,7 @@ async def test_get_library_agents_success(
|
||||
total_items=2, total_pages=1, current_page=1, page_size=50
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.list_library_agents")
|
||||
mock_db_call = mocker.patch("backend.api.features.library.db.list_library_agents")
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/agents?search_term=test")
|
||||
@@ -112,7 +115,7 @@ async def test_get_library_agents_success(
|
||||
|
||||
|
||||
def test_get_library_agents_error(mocker: pytest_mock.MockFixture, test_user_id: str):
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.list_library_agents")
|
||||
mock_db_call = mocker.patch("backend.api.features.library.db.list_library_agents")
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
response = client.get("/agents?search_term=test")
|
||||
@@ -137,6 +140,7 @@ async def test_get_favorite_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Favorite Agent 1",
|
||||
description="Test Favorite Description 1",
|
||||
image_url=None,
|
||||
@@ -161,7 +165,7 @@ async def test_get_favorite_library_agents_success(
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.list_favorite_library_agents"
|
||||
"backend.api.features.library.db.list_favorite_library_agents"
|
||||
)
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
@@ -184,7 +188,7 @@ def test_get_favorite_library_agents_error(
|
||||
mocker: pytest_mock.MockFixture, test_user_id: str
|
||||
):
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.list_favorite_library_agents"
|
||||
"backend.api.features.library.db.list_favorite_library_agents"
|
||||
)
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
@@ -204,6 +208,7 @@ def test_add_agent_to_library_success(
|
||||
id="test-library-agent-id",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
@@ -223,11 +228,11 @@ def test_add_agent_to_library_success(
|
||||
)
|
||||
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.add_store_agent_to_library"
|
||||
"backend.api.features.library.db.add_store_agent_to_library"
|
||||
)
|
||||
mock_db_call.return_value = mock_library_agent
|
||||
mock_complete_onboarding = mocker.patch(
|
||||
"backend.server.v2.library.routes.agents.complete_onboarding_step",
|
||||
"backend.api.features.library.routes.agents.complete_onboarding_step",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
@@ -249,7 +254,7 @@ def test_add_agent_to_library_success(
|
||||
|
||||
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture, test_user_id: str):
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.add_store_agent_to_library"
|
||||
"backend.api.features.library.db.add_store_agent_to_library"
|
||||
)
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
833
autogpt_platform/backend/backend/api/features/oauth.py
Normal file
833
autogpt_platform/backend/backend/api/features/oauth.py
Normal file
@@ -0,0 +1,833 @@
|
||||
"""
|
||||
OAuth 2.0 Provider Endpoints
|
||||
|
||||
Implements OAuth 2.0 Authorization Code flow with PKCE support.
|
||||
|
||||
Flow:
|
||||
1. User clicks "Login with AutoGPT" in 3rd party app
|
||||
2. App redirects user to /auth/authorize with client_id, redirect_uri, scope, state
|
||||
3. User sees consent screen (if not already logged in, redirects to login first)
|
||||
4. User approves → backend creates authorization code
|
||||
5. User redirected back to app with code
|
||||
6. App exchanges code for access/refresh tokens at /api/oauth/token
|
||||
7. App uses access token to call external API endpoints
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import APIRouter, Body, HTTPException, Security, UploadFile, status
|
||||
from gcloud.aio import storage as async_storage
|
||||
from PIL import Image
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.auth.oauth import (
|
||||
InvalidClientError,
|
||||
InvalidGrantError,
|
||||
OAuthApplicationInfo,
|
||||
TokenIntrospectionResult,
|
||||
consume_authorization_code,
|
||||
create_access_token,
|
||||
create_authorization_code,
|
||||
create_refresh_token,
|
||||
get_oauth_application,
|
||||
get_oauth_application_by_id,
|
||||
introspect_token,
|
||||
list_user_oauth_applications,
|
||||
refresh_tokens,
|
||||
revoke_access_token,
|
||||
revoke_refresh_token,
|
||||
update_oauth_application,
|
||||
validate_client_credentials,
|
||||
validate_redirect_uri,
|
||||
validate_scopes,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""OAuth 2.0 token response"""
|
||||
|
||||
token_type: Literal["Bearer"] = "Bearer"
|
||||
access_token: str
|
||||
access_token_expires_at: datetime
|
||||
refresh_token: str
|
||||
refresh_token_expires_at: datetime
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""OAuth 2.0 error response"""
|
||||
|
||||
error: str
|
||||
error_description: Optional[str] = None
|
||||
|
||||
|
||||
class OAuthApplicationPublicInfo(BaseModel):
|
||||
"""Public information about an OAuth application (for consent screen)"""
|
||||
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Application Info Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/app/{client_id}",
|
||||
responses={
|
||||
404: {"description": "Application not found or disabled"},
|
||||
},
|
||||
)
|
||||
async def get_oauth_app_info(
|
||||
client_id: str, user_id: str = Security(get_user_id)
|
||||
) -> OAuthApplicationPublicInfo:
|
||||
"""
|
||||
Get public information about an OAuth application.
|
||||
|
||||
This endpoint is used by the consent screen to display application details
|
||||
to the user before they authorize access.
|
||||
|
||||
Returns:
|
||||
- name: Application name
|
||||
- description: Application description (if provided)
|
||||
- scopes: List of scopes the application is allowed to request
|
||||
"""
|
||||
app = await get_oauth_application(client_id)
|
||||
if not app or not app.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found",
|
||||
)
|
||||
|
||||
return OAuthApplicationPublicInfo(
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
logo_url=app.logo_url,
|
||||
scopes=[s.value for s in app.scopes],
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authorization Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AuthorizeRequest(BaseModel):
|
||||
"""OAuth 2.0 authorization request"""
|
||||
|
||||
client_id: str = Field(description="Client identifier")
|
||||
redirect_uri: str = Field(description="Redirect URI")
|
||||
scopes: list[str] = Field(description="List of scopes")
|
||||
state: str = Field(description="Anti-CSRF token from client")
|
||||
response_type: str = Field(
|
||||
default="code", description="Must be 'code' for authorization code flow"
|
||||
)
|
||||
code_challenge: str = Field(description="PKCE code challenge (required)")
|
||||
code_challenge_method: Literal["S256", "plain"] = Field(
|
||||
default="S256", description="PKCE code challenge method (S256 recommended)"
|
||||
)
|
||||
|
||||
|
||||
class AuthorizeResponse(BaseModel):
|
||||
"""OAuth 2.0 authorization response with redirect URL"""
|
||||
|
||||
redirect_url: str = Field(description="URL to redirect the user to")
|
||||
|
||||
|
||||
@router.post("/authorize")
|
||||
async def authorize(
|
||||
request: AuthorizeRequest = Body(),
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> AuthorizeResponse:
|
||||
"""
|
||||
OAuth 2.0 Authorization Endpoint
|
||||
|
||||
User must be logged in (authenticated with Supabase JWT).
|
||||
This endpoint creates an authorization code and returns a redirect URL.
|
||||
|
||||
PKCE (Proof Key for Code Exchange) is REQUIRED for all authorization requests.
|
||||
|
||||
The frontend consent screen should call this endpoint after the user approves,
|
||||
then redirect the user to the returned `redirect_url`.
|
||||
|
||||
Request Body:
|
||||
- client_id: The OAuth application's client ID
|
||||
- redirect_uri: Where to redirect after authorization (must match registered URI)
|
||||
- scopes: List of permissions (e.g., "EXECUTE_GRAPH READ_GRAPH")
|
||||
- state: Anti-CSRF token provided by client (will be returned in redirect)
|
||||
- response_type: Must be "code" (for authorization code flow)
|
||||
- code_challenge: PKCE code challenge (required)
|
||||
- code_challenge_method: "S256" (recommended) or "plain"
|
||||
|
||||
Returns:
|
||||
- redirect_url: The URL to redirect the user to (includes authorization code)
|
||||
|
||||
Error cases return a redirect_url with error parameters, or raise HTTPException
|
||||
for critical errors (like invalid redirect_uri).
|
||||
"""
|
||||
try:
|
||||
# Validate response_type
|
||||
if request.response_type != "code":
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"unsupported_response_type",
|
||||
"Only 'code' response type is supported",
|
||||
)
|
||||
|
||||
# Get application
|
||||
app = await get_oauth_application(request.client_id)
|
||||
if not app:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_client",
|
||||
"Unknown client_id",
|
||||
)
|
||||
|
||||
if not app.is_active:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_client",
|
||||
"Application is not active",
|
||||
)
|
||||
|
||||
# Validate redirect URI
|
||||
if not validate_redirect_uri(app, request.redirect_uri):
|
||||
# For invalid redirect_uri, we can't redirect safely
|
||||
# Must return error instead
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
"Invalid redirect_uri. "
|
||||
f"Must be one of: {', '.join(app.redirect_uris)}"
|
||||
),
|
||||
)
|
||||
|
||||
# Parse and validate scopes
|
||||
try:
|
||||
requested_scopes = [APIKeyPermission(s.strip()) for s in request.scopes]
|
||||
except ValueError as e:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
f"Invalid scope: {e}",
|
||||
)
|
||||
|
||||
if not requested_scopes:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
"At least one scope is required",
|
||||
)
|
||||
|
||||
if not validate_scopes(app, requested_scopes):
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
"Application is not authorized for all requested scopes. "
|
||||
f"Allowed: {', '.join(s.value for s in app.scopes)}",
|
||||
)
|
||||
|
||||
# Create authorization code
|
||||
auth_code = await create_authorization_code(
|
||||
application_id=app.id,
|
||||
user_id=user_id,
|
||||
scopes=requested_scopes,
|
||||
redirect_uri=request.redirect_uri,
|
||||
code_challenge=request.code_challenge,
|
||||
code_challenge_method=request.code_challenge_method,
|
||||
)
|
||||
|
||||
# Build redirect URL with authorization code
|
||||
params = {
|
||||
"code": auth_code.code,
|
||||
"state": request.state,
|
||||
}
|
||||
redirect_url = f"{request.redirect_uri}?{urlencode(params)}"
|
||||
|
||||
logger.info(
|
||||
f"Authorization code issued for user #{user_id} "
|
||||
f"and app {app.name} (#{app.id})"
|
||||
)
|
||||
|
||||
return AuthorizeResponse(redirect_url=redirect_url)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in authorization endpoint: {e}", exc_info=True)
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"server_error",
|
||||
"An unexpected error occurred",
|
||||
)
|
||||
|
||||
|
||||
def _error_redirect_url(
|
||||
redirect_uri: str,
|
||||
state: str,
|
||||
error: str,
|
||||
error_description: Optional[str] = None,
|
||||
) -> AuthorizeResponse:
|
||||
"""Helper to build redirect URL with OAuth error parameters"""
|
||||
params = {
|
||||
"error": error,
|
||||
"state": state,
|
||||
}
|
||||
if error_description:
|
||||
params["error_description"] = error_description
|
||||
|
||||
redirect_url = f"{redirect_uri}?{urlencode(params)}"
|
||||
return AuthorizeResponse(redirect_url=redirect_url)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TokenRequestByCode(BaseModel):
|
||||
grant_type: Literal["authorization_code"]
|
||||
code: str = Field(description="Authorization code")
|
||||
redirect_uri: str = Field(
|
||||
description="Redirect URI (must match authorization request)"
|
||||
)
|
||||
client_id: str
|
||||
client_secret: str
|
||||
code_verifier: str = Field(description="PKCE code verifier")
|
||||
|
||||
|
||||
class TokenRequestByRefreshToken(BaseModel):
|
||||
grant_type: Literal["refresh_token"]
|
||||
refresh_token: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
|
||||
|
||||
@router.post("/token")
|
||||
async def token(
|
||||
request: TokenRequestByCode | TokenRequestByRefreshToken = Body(),
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
OAuth 2.0 Token Endpoint
|
||||
|
||||
Exchanges authorization code or refresh token for access token.
|
||||
|
||||
Grant Types:
|
||||
1. authorization_code: Exchange authorization code for tokens
|
||||
- Required: grant_type, code, redirect_uri, client_id, client_secret
|
||||
- Optional: code_verifier (required if PKCE was used)
|
||||
|
||||
2. refresh_token: Exchange refresh token for new access token
|
||||
- Required: grant_type, refresh_token, client_id, client_secret
|
||||
|
||||
Returns:
|
||||
- access_token: Bearer token for API access (1 hour TTL)
|
||||
- token_type: "Bearer"
|
||||
- expires_in: Seconds until access token expires
|
||||
- refresh_token: Token for refreshing access (30 days TTL)
|
||||
- scopes: List of scopes
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
app = await validate_client_credentials(
|
||||
request.client_id, request.client_secret
|
||||
)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Handle authorization_code grant
|
||||
if request.grant_type == "authorization_code":
|
||||
# Consume authorization code
|
||||
try:
|
||||
user_id, scopes = await consume_authorization_code(
|
||||
code=request.code,
|
||||
application_id=app.id,
|
||||
redirect_uri=request.redirect_uri,
|
||||
code_verifier=request.code_verifier,
|
||||
)
|
||||
except InvalidGrantError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Create access and refresh tokens
|
||||
access_token = await create_access_token(app.id, user_id, scopes)
|
||||
refresh_token = await create_refresh_token(app.id, user_id, scopes)
|
||||
|
||||
logger.info(
|
||||
f"Access token issued for user #{user_id} and app {app.name} (#{app.id})"
|
||||
"via authorization code"
|
||||
)
|
||||
|
||||
if not access_token.token or not refresh_token.token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate tokens",
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
token_type="Bearer",
|
||||
access_token=access_token.token.get_secret_value(),
|
||||
access_token_expires_at=access_token.expires_at,
|
||||
refresh_token=refresh_token.token.get_secret_value(),
|
||||
refresh_token_expires_at=refresh_token.expires_at,
|
||||
scopes=list(s.value for s in scopes),
|
||||
)
|
||||
|
||||
# Handle refresh_token grant
|
||||
elif request.grant_type == "refresh_token":
|
||||
# Refresh access token
|
||||
try:
|
||||
new_access_token, new_refresh_token = await refresh_tokens(
|
||||
request.refresh_token, app.id
|
||||
)
|
||||
except InvalidGrantError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Tokens refreshed for user #{new_access_token.user_id} "
|
||||
f"by app {app.name} (#{app.id})"
|
||||
)
|
||||
|
||||
if not new_access_token.token or not new_refresh_token.token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate tokens",
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
token_type="Bearer",
|
||||
access_token=new_access_token.token.get_secret_value(),
|
||||
access_token_expires_at=new_access_token.expires_at,
|
||||
refresh_token=new_refresh_token.token.get_secret_value(),
|
||||
refresh_token_expires_at=new_refresh_token.expires_at,
|
||||
scopes=list(s.value for s in new_access_token.scopes),
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported grant_type: {request.grant_type}. "
|
||||
"Must be 'authorization_code' or 'refresh_token'",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Introspection Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post("/introspect")
|
||||
async def introspect(
|
||||
token: str = Body(description="Token to introspect"),
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
|
||||
None, description="Hint about token type ('access_token' or 'refresh_token')"
|
||||
),
|
||||
client_id: str = Body(description="Client identifier"),
|
||||
client_secret: str = Body(description="Client secret"),
|
||||
) -> TokenIntrospectionResult:
|
||||
"""
|
||||
OAuth 2.0 Token Introspection Endpoint (RFC 7662)
|
||||
|
||||
Allows clients to check if a token is valid and get its metadata.
|
||||
|
||||
Returns:
|
||||
- active: Whether the token is currently active
|
||||
- scopes: List of authorized scopes (if active)
|
||||
- client_id: The client the token was issued to (if active)
|
||||
- user_id: The user the token represents (if active)
|
||||
- exp: Expiration timestamp (if active)
|
||||
- token_type: "access_token" or "refresh_token" (if active)
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
await validate_client_credentials(client_id, client_secret)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Introspect the token
|
||||
return await introspect_token(token, token_type_hint)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Revocation Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post("/revoke")
|
||||
async def revoke(
|
||||
token: str = Body(description="Token to revoke"),
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
|
||||
None, description="Hint about token type ('access_token' or 'refresh_token')"
|
||||
),
|
||||
client_id: str = Body(description="Client identifier"),
|
||||
client_secret: str = Body(description="Client secret"),
|
||||
):
|
||||
"""
|
||||
OAuth 2.0 Token Revocation Endpoint (RFC 7009)
|
||||
|
||||
Allows clients to revoke an access or refresh token.
|
||||
|
||||
Note: Revoking a refresh token does NOT revoke associated access tokens.
|
||||
Revoking an access token does NOT revoke the associated refresh token.
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
app = await validate_client_credentials(client_id, client_secret)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Try to revoke as access token first
|
||||
# Note: We pass app.id to ensure the token belongs to the authenticated app
|
||||
if token_type_hint != "refresh_token":
|
||||
revoked = await revoke_access_token(token, app.id)
|
||||
if revoked:
|
||||
logger.info(
|
||||
f"Access token revoked for app {app.name} (#{app.id}); "
|
||||
f"user #{revoked.user_id}"
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
# Try to revoke as refresh token
|
||||
revoked = await revoke_refresh_token(token, app.id)
|
||||
if revoked:
|
||||
logger.info(
|
||||
f"Refresh token revoked for app {app.name} (#{app.id}); "
|
||||
f"user #{revoked.user_id}"
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
# Per RFC 7009, revocation endpoint returns 200 even if token not found
|
||||
# or if token belongs to a different application.
|
||||
# This prevents token scanning attacks.
|
||||
logger.warning(f"Unsuccessful token revocation attempt by app {app.name} #{app.id}")
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Application Management Endpoints (for app owners)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get("/apps/mine")
|
||||
async def list_my_oauth_apps(
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> list[OAuthApplicationInfo]:
|
||||
"""
|
||||
List all OAuth applications owned by the current user.
|
||||
|
||||
Returns a list of OAuth applications with their details including:
|
||||
- id, name, description, logo_url
|
||||
- client_id (public identifier)
|
||||
- redirect_uris, grant_types, scopes
|
||||
- is_active status
|
||||
- created_at, updated_at timestamps
|
||||
|
||||
Note: client_secret is never returned for security reasons.
|
||||
"""
|
||||
return await list_user_oauth_applications(user_id)
|
||||
|
||||
|
||||
@router.patch("/apps/{app_id}/status")
|
||||
async def update_app_status(
|
||||
app_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
is_active: bool = Body(description="Whether the app should be active", embed=True),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Enable or disable an OAuth application.
|
||||
|
||||
Only the application owner can update the status.
|
||||
When disabled, the application cannot be used for new authorizations
|
||||
and existing access tokens will fail validation.
|
||||
|
||||
Returns the updated application info.
|
||||
"""
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
action = "enabled" if is_active else "disabled"
|
||||
logger.info(f"OAuth app {updated_app.name} (#{app_id}) {action} by user #{user_id}")
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
class UpdateAppLogoRequest(BaseModel):
|
||||
logo_url: str = Field(description="URL of the uploaded logo image")
|
||||
|
||||
|
||||
@router.patch("/apps/{app_id}/logo")
|
||||
async def update_app_logo(
|
||||
app_id: str,
|
||||
request: UpdateAppLogoRequest = Body(),
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Update the logo URL for an OAuth application.
|
||||
|
||||
Only the application owner can update the logo.
|
||||
The logo should be uploaded first using the media upload endpoint,
|
||||
then this endpoint is called with the resulting URL.
|
||||
|
||||
Logo requirements:
|
||||
- Must be square (1:1 aspect ratio)
|
||||
- Minimum 512x512 pixels
|
||||
- Maximum 2048x2048 pixels
|
||||
|
||||
Returns the updated application info.
|
||||
"""
|
||||
if (
|
||||
not (app := await get_oauth_application_by_id(app_id))
|
||||
or app.owner_id != user_id
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth App not found",
|
||||
)
|
||||
|
||||
# Delete the current app logo file (if any and it's in our cloud storage)
|
||||
await _delete_app_current_logo_file(app)
|
||||
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
logo_url=request.logo_url,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth app {updated_app.name} (#{app_id}) logo updated by user #{user_id}"
|
||||
)
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
# Logo upload constraints
|
||||
LOGO_MIN_SIZE = 512
|
||||
LOGO_MAX_SIZE = 2048
|
||||
LOGO_ALLOWED_TYPES = {"image/jpeg", "image/png", "image/webp"}
|
||||
LOGO_MAX_FILE_SIZE = 3 * 1024 * 1024 # 3MB
|
||||
|
||||
|
||||
@router.post("/apps/{app_id}/logo/upload")
|
||||
async def upload_app_logo(
|
||||
app_id: str,
|
||||
file: UploadFile,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Upload a logo image for an OAuth application.
|
||||
|
||||
Requirements:
|
||||
- Image must be square (1:1 aspect ratio)
|
||||
- Minimum 512x512 pixels
|
||||
- Maximum 2048x2048 pixels
|
||||
- Allowed formats: JPEG, PNG, WebP
|
||||
- Maximum file size: 3MB
|
||||
|
||||
The image is uploaded to cloud storage and the app's logoUrl is updated.
|
||||
Returns the updated application info.
|
||||
"""
|
||||
# Verify ownership to reduce vulnerability to DoS(torage) or DoM(oney) attacks
|
||||
if (
|
||||
not (app := await get_oauth_application_by_id(app_id))
|
||||
or app.owner_id != user_id
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth App not found",
|
||||
)
|
||||
|
||||
# Check GCS configuration
|
||||
if not settings.config.media_gcs_bucket_name:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Media storage is not configured",
|
||||
)
|
||||
|
||||
# Validate content type
|
||||
content_type = file.content_type
|
||||
if content_type not in LOGO_ALLOWED_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid file type. Allowed: JPEG, PNG, WebP. Got: {content_type}",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
file_bytes = await file.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading logo file: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to read uploaded file",
|
||||
)
|
||||
|
||||
# Check file size
|
||||
if len(file_bytes) > LOGO_MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
"File too large. "
|
||||
f"Maximum size is {LOGO_MAX_FILE_SIZE // 1024 // 1024}MB"
|
||||
),
|
||||
)
|
||||
|
||||
# Validate image dimensions
|
||||
try:
|
||||
image = Image.open(io.BytesIO(file_bytes))
|
||||
width, height = image.size
|
||||
|
||||
if width != height:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo must be square. Got {width}x{height}",
|
||||
)
|
||||
|
||||
if width < LOGO_MIN_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo too small. Minimum {LOGO_MIN_SIZE}x{LOGO_MIN_SIZE}. "
|
||||
f"Got {width}x{height}",
|
||||
)
|
||||
|
||||
if width > LOGO_MAX_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo too large. Maximum {LOGO_MAX_SIZE}x{LOGO_MAX_SIZE}. "
|
||||
f"Got {width}x{height}",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating logo image: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid image file",
|
||||
)
|
||||
|
||||
# Scan for viruses
|
||||
filename = file.filename or "logo"
|
||||
await scan_content_safe(file_bytes, filename=filename)
|
||||
|
||||
# Generate unique filename
|
||||
file_ext = os.path.splitext(filename)[1].lower() or ".png"
|
||||
unique_filename = f"{uuid.uuid4()}{file_ext}"
|
||||
storage_path = f"oauth-apps/{app_id}/logo/{unique_filename}"
|
||||
|
||||
# Upload to GCS
|
||||
try:
|
||||
async with async_storage.Storage() as async_client:
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
|
||||
await async_client.upload(
|
||||
bucket_name, storage_path, file_bytes, content_type=content_type
|
||||
)
|
||||
|
||||
logo_url = f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading logo to GCS: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to upload logo",
|
||||
)
|
||||
|
||||
# Delete the current app logo file (if any and it's in our cloud storage)
|
||||
await _delete_app_current_logo_file(app)
|
||||
|
||||
# Update the app with the new logo URL
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
logo_url=logo_url,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth app {updated_app.name} (#{app_id}) logo uploaded by user #{user_id}"
|
||||
)
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
async def _delete_app_current_logo_file(app: OAuthApplicationInfo):
|
||||
"""
|
||||
Delete the current logo file for the given app, if there is one in our cloud storage
|
||||
"""
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
storage_base_url = f"https://storage.googleapis.com/{bucket_name}/"
|
||||
|
||||
if app.logo_url and app.logo_url.startswith(storage_base_url):
|
||||
# Parse blob path from URL: https://storage.googleapis.com/{bucket}/{path}
|
||||
old_path = app.logo_url.replace(storage_base_url, "")
|
||||
try:
|
||||
async with async_storage.Storage() as async_client:
|
||||
await async_client.delete(bucket_name, old_path)
|
||||
logger.info(f"Deleted old logo for OAuth app #{app.id}: {old_path}")
|
||||
except Exception as e:
|
||||
# Log but don't fail - the new logo was uploaded successfully
|
||||
logger.warning(
|
||||
f"Failed to delete old logo for OAuth app #{app.id}: {e}", exc_info=e
|
||||
)
|
||||
1784
autogpt_platform/backend/backend/api/features/oauth_test.py
Normal file
1784
autogpt_platform/backend/backend/api/features/oauth_test.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -6,9 +6,9 @@ import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.v2.otto.models as otto_models
|
||||
import backend.server.v2.otto.routes as otto_routes
|
||||
from backend.server.v2.otto.service import OttoService
|
||||
from . import models as otto_models
|
||||
from . import routes as otto_routes
|
||||
from .service import OttoService
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(otto_routes.router)
|
||||
@@ -4,12 +4,15 @@ from typing import Annotated
|
||||
from fastapi import APIRouter, Body, HTTPException, Query, Security
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from backend.api.utils.api_key_auth import APIKeyAuthenticator
|
||||
from backend.data.user import (
|
||||
get_user_by_email,
|
||||
set_user_email_verification,
|
||||
unsubscribe_user_by_token,
|
||||
)
|
||||
from backend.server.routers.postmark.models import (
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .models import (
|
||||
PostmarkBounceEnum,
|
||||
PostmarkBounceWebhook,
|
||||
PostmarkClickWebhook,
|
||||
@@ -19,8 +22,6 @@ from backend.server.routers.postmark.models import (
|
||||
PostmarkSubscriptionChangeWebhook,
|
||||
PostmarkWebhook,
|
||||
)
|
||||
from backend.server.utils.api_key_auth import APIKeyAuthenticator
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
@@ -0,0 +1,72 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
CLI script to backfill embeddings for store agents.
|
||||
|
||||
Usage:
|
||||
poetry run python -m backend.server.v2.store.backfill_embeddings [--batch-size N]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import prisma
|
||||
|
||||
from backend.api.features.store.embeddings import (
|
||||
backfill_missing_embeddings,
|
||||
get_embedding_stats,
|
||||
)
|
||||
|
||||
|
||||
async def main(batch_size: int = 100) -> int:
|
||||
"""Run the backfill process."""
|
||||
# Initialize Prisma client
|
||||
client = prisma.Prisma()
|
||||
await client.connect()
|
||||
prisma.register(client)
|
||||
|
||||
try:
|
||||
# Get current stats
|
||||
print("Current embedding stats:")
|
||||
stats = await get_embedding_stats()
|
||||
print(f" Total approved: {stats['total_approved']}")
|
||||
print(f" With embeddings: {stats['with_embeddings']}")
|
||||
print(f" Without embeddings: {stats['without_embeddings']}")
|
||||
print(f" Coverage: {stats['coverage_percent']}%")
|
||||
|
||||
if stats["without_embeddings"] == 0:
|
||||
print("\nAll agents already have embeddings. Nothing to do.")
|
||||
return 0
|
||||
|
||||
# Run backfill
|
||||
print(f"\nBackfilling up to {batch_size} embeddings...")
|
||||
result = await backfill_missing_embeddings(batch_size=batch_size)
|
||||
print(f" Processed: {result['processed']}")
|
||||
print(f" Success: {result['success']}")
|
||||
print(f" Failed: {result['failed']}")
|
||||
|
||||
# Get final stats
|
||||
print("\nFinal embedding stats:")
|
||||
stats = await get_embedding_stats()
|
||||
print(f" Total approved: {stats['total_approved']}")
|
||||
print(f" With embeddings: {stats['with_embeddings']}")
|
||||
print(f" Without embeddings: {stats['without_embeddings']}")
|
||||
print(f" Coverage: {stats['coverage_percent']}%")
|
||||
|
||||
return 0 if result["failed"] == 0 else 1
|
||||
|
||||
finally:
|
||||
await client.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Backfill embeddings for store agents")
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of embeddings to generate (default: 100)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sys.exit(asyncio.run(main(batch_size=args.batch_size)))
|
||||
@@ -1,8 +1,9 @@
|
||||
from typing import Literal
|
||||
|
||||
import backend.server.v2.store.db
|
||||
from backend.util.cache import cached
|
||||
|
||||
from . import db as store_db
|
||||
|
||||
##############################################
|
||||
############### Caches #######################
|
||||
##############################################
|
||||
@@ -29,7 +30,7 @@ async def _get_cached_store_agents(
|
||||
page_size: int,
|
||||
):
|
||||
"""Cached helper to get store agents."""
|
||||
return await backend.server.v2.store.db.get_store_agents(
|
||||
return await store_db.get_store_agents(
|
||||
featured=featured,
|
||||
creators=[creator] if creator else None,
|
||||
sorted_by=sorted_by,
|
||||
@@ -42,10 +43,12 @@ async def _get_cached_store_agents(
|
||||
|
||||
# Cache individual agent details for 15 minutes
|
||||
@cached(maxsize=200, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_agent_details(username: str, agent_name: str):
|
||||
async def _get_cached_agent_details(
|
||||
username: str, agent_name: str, include_changelog: bool = False
|
||||
):
|
||||
"""Cached helper to get agent details."""
|
||||
return await backend.server.v2.store.db.get_store_agent_details(
|
||||
username=username, agent_name=agent_name
|
||||
return await store_db.get_store_agent_details(
|
||||
username=username, agent_name=agent_name, include_changelog=include_changelog
|
||||
)
|
||||
|
||||
|
||||
@@ -59,7 +62,7 @@ async def _get_cached_store_creators(
|
||||
page_size: int,
|
||||
):
|
||||
"""Cached helper to get store creators."""
|
||||
return await backend.server.v2.store.db.get_store_creators(
|
||||
return await store_db.get_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
sorted_by=sorted_by,
|
||||
@@ -72,6 +75,4 @@ async def _get_cached_store_creators(
|
||||
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_creator_details(username: str):
|
||||
"""Cached helper to get creator details."""
|
||||
return await backend.server.v2.store.db.get_store_creator_details(
|
||||
username=username.lower()
|
||||
)
|
||||
return await store_db.get_store_creator_details(username=username.lower())
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,8 +6,8 @@ import prisma.models
|
||||
import pytest
|
||||
from prisma import Prisma
|
||||
|
||||
import backend.server.v2.store.db as db
|
||||
from backend.server.v2.store.model import Profile
|
||||
from . import db
|
||||
from .model import Profile
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -40,6 +40,8 @@ async def test_get_store_agents(mocker):
|
||||
runs=10,
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
agentGraphVersions=["1"],
|
||||
agentGraphId="test-graph-id",
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
useForOnboarding=False,
|
||||
@@ -83,6 +85,8 @@ async def test_get_store_agent_details(mocker):
|
||||
runs=10,
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
agentGraphVersions=["1"],
|
||||
agentGraphId="test-graph-id",
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
useForOnboarding=False,
|
||||
@@ -105,6 +109,8 @@ async def test_get_store_agent_details(mocker):
|
||||
runs=15,
|
||||
rating=4.8,
|
||||
versions=["1.0", "2.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id-active",
|
||||
updated_at=datetime.now(),
|
||||
is_available=True,
|
||||
useForOnboarding=False,
|
||||
@@ -0,0 +1,533 @@
|
||||
"""
|
||||
Unified Content Embeddings Service
|
||||
|
||||
Handles generation and storage of OpenAI embeddings for all content types
|
||||
(store listings, blocks, documentation, library agents) to enable semantic/hybrid search.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import prisma
|
||||
from openai import OpenAI
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.util.json import dumps
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OpenAI embedding model configuration
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
EMBEDDING_DIM = 1536
|
||||
|
||||
|
||||
def build_searchable_text(
|
||||
name: str,
|
||||
description: str,
|
||||
sub_heading: str,
|
||||
categories: list[str],
|
||||
) -> str:
|
||||
"""
|
||||
Build searchable text from listing version fields.
|
||||
|
||||
Combines relevant fields into a single string for embedding.
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# Name is important - include it
|
||||
if name:
|
||||
parts.append(name)
|
||||
|
||||
# Sub-heading provides context
|
||||
if sub_heading:
|
||||
parts.append(sub_heading)
|
||||
|
||||
# Description is the main content
|
||||
if description:
|
||||
parts.append(description)
|
||||
|
||||
# Categories help with semantic matching
|
||||
if categories:
|
||||
parts.append(" ".join(categories))
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
async def generate_embedding(text: str) -> list[float] | None:
|
||||
"""
|
||||
Generate embedding for text using OpenAI API.
|
||||
|
||||
Returns None if embedding generation fails.
|
||||
"""
|
||||
try:
|
||||
settings = Settings()
|
||||
api_key = settings.secrets.openai_internal_api_key
|
||||
if not api_key:
|
||||
logger.warning("openai_internal_api_key not set, cannot generate embedding")
|
||||
return None
|
||||
|
||||
client = OpenAI(api_key=api_key)
|
||||
|
||||
# Truncate text to avoid token limits (~32k chars for safety)
|
||||
truncated_text = text[:32000]
|
||||
|
||||
response = client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=truncated_text,
|
||||
)
|
||||
|
||||
embedding = response.data[0].embedding
|
||||
logger.debug(f"Generated embedding with {len(embedding)} dimensions")
|
||||
return embedding
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate embedding: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def store_embedding(
|
||||
version_id: str,
|
||||
embedding: list[float],
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Store embedding in the database.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||
Uses raw SQL since Prisma doesn't natively support pgvector.
|
||||
"""
|
||||
return await store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id=version_id,
|
||||
embedding=embedding,
|
||||
searchable_text="", # Will be populated from existing data
|
||||
metadata=None,
|
||||
user_id=None, # Store agents are public
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
|
||||
async def store_content_embedding(
|
||||
content_type: ContentType,
|
||||
content_id: str,
|
||||
embedding: list[float],
|
||||
searchable_text: str,
|
||||
metadata: dict | None = None,
|
||||
user_id: str | None = None,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Store embedding in the unified content embeddings table.
|
||||
|
||||
New function for unified content embedding storage.
|
||||
Uses raw SQL since Prisma doesn't natively support pgvector.
|
||||
"""
|
||||
try:
|
||||
client = tx if tx else prisma.get_client()
|
||||
|
||||
# Convert embedding to PostgreSQL vector format
|
||||
embedding_str = "[" + ",".join(str(x) for x in embedding) + "]"
|
||||
metadata_json = dumps(metadata or {})
|
||||
|
||||
# Upsert the embedding
|
||||
await client.execute_raw(
|
||||
"""
|
||||
INSERT INTO platform."UnifiedContentEmbedding" (
|
||||
"contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
|
||||
)
|
||||
VALUES ($1, $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
|
||||
ON CONFLICT ("contentType", "contentId", "userId")
|
||||
DO UPDATE SET
|
||||
"embedding" = $4::vector,
|
||||
"searchableText" = $5,
|
||||
"metadata" = $6::jsonb,
|
||||
"updatedAt" = NOW()
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
embedding_str,
|
||||
searchable_text,
|
||||
metadata_json,
|
||||
)
|
||||
|
||||
logger.info(f"Stored embedding for {content_type}:{content_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding(version_id: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve embedding record for a listing version.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||
Returns dict with storeListingVersionId, embedding, timestamps or None if not found.
|
||||
"""
|
||||
result = await get_content_embedding(
|
||||
ContentType.STORE_AGENT, version_id, user_id=None
|
||||
)
|
||||
if result:
|
||||
# Transform to old format for backward compatibility
|
||||
return {
|
||||
"storeListingVersionId": result["contentId"],
|
||||
"embedding": result["embedding"],
|
||||
"createdAt": result["createdAt"],
|
||||
"updatedAt": result["updatedAt"],
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def get_content_embedding(
|
||||
content_type: ContentType, content_id: str, user_id: str | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve embedding record for any content type.
|
||||
|
||||
New function for unified content embedding retrieval.
|
||||
Returns dict with contentType, contentId, embedding, timestamps or None if not found.
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
result = await client.query_raw(
|
||||
"""
|
||||
SELECT
|
||||
"contentType",
|
||||
"contentId",
|
||||
"userId",
|
||||
"embedding"::text as "embedding",
|
||||
"searchableText",
|
||||
"metadata",
|
||||
"createdAt",
|
||||
"updatedAt"
|
||||
FROM platform."UnifiedContentEmbedding"
|
||||
WHERE "contentType" = $1 AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
if result and len(result) > 0:
|
||||
return result[0]
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def ensure_embedding(
|
||||
version_id: str,
|
||||
name: str,
|
||||
description: str,
|
||||
sub_heading: str,
|
||||
categories: list[str],
|
||||
force: bool = False,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Ensure an embedding exists for the listing version.
|
||||
|
||||
Creates embedding if missing. Use force=True to regenerate.
|
||||
Backward-compatible wrapper for store listings.
|
||||
|
||||
Args:
|
||||
version_id: The StoreListingVersion ID
|
||||
name: Agent name
|
||||
description: Agent description
|
||||
sub_heading: Agent sub-heading
|
||||
categories: Agent categories
|
||||
force: Force regeneration even if embedding exists
|
||||
tx: Optional transaction client
|
||||
|
||||
Returns:
|
||||
True if embedding exists/was created, False on failure
|
||||
"""
|
||||
try:
|
||||
# Check if embedding already exists
|
||||
if not force:
|
||||
existing = await get_embedding(version_id)
|
||||
if existing and existing.get("embedding"):
|
||||
logger.debug(f"Embedding for version {version_id} already exists")
|
||||
return True
|
||||
|
||||
# Build searchable text for embedding
|
||||
searchable_text = build_searchable_text(
|
||||
name, description, sub_heading, categories
|
||||
)
|
||||
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(f"Could not generate embedding for version {version_id}")
|
||||
return False
|
||||
|
||||
# Store the embedding with metadata using new function
|
||||
metadata = {
|
||||
"name": name,
|
||||
"subHeading": sub_heading,
|
||||
"categories": categories,
|
||||
}
|
||||
return await store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id=version_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
metadata=metadata,
|
||||
user_id=None, # Store agents are public
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def delete_embedding(version_id: str) -> bool:
|
||||
"""
|
||||
Delete embedding for a listing version.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||
Note: This is usually handled automatically by CASCADE delete,
|
||||
but provided for manual cleanup if needed.
|
||||
"""
|
||||
return await delete_content_embedding(ContentType.STORE_AGENT, version_id)
|
||||
|
||||
|
||||
async def delete_content_embedding(content_type: ContentType, content_id: str) -> bool:
|
||||
"""
|
||||
Delete embedding for any content type.
|
||||
|
||||
New function for unified content embedding deletion.
|
||||
Note: This is usually handled automatically by CASCADE delete,
|
||||
but provided for manual cleanup if needed.
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
await client.execute_raw(
|
||||
"""
|
||||
DELETE FROM platform."UnifiedContentEmbedding"
|
||||
WHERE "contentType" = $1 AND "contentId" = $2
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
)
|
||||
|
||||
logger.info(f"Deleted embedding for {content_type}:{content_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete embedding for {content_type}:{content_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding_stats() -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics about embedding coverage.
|
||||
|
||||
Returns counts of:
|
||||
- Total approved listing versions
|
||||
- Versions with embeddings
|
||||
- Versions without embeddings
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
# Count approved versions
|
||||
approved_result = await client.query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM platform."StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
AND "isDeleted" = false
|
||||
"""
|
||||
)
|
||||
total_approved = approved_result[0]["count"] if approved_result else 0
|
||||
|
||||
# Count versions with embeddings
|
||||
embedded_result = await client.query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM platform."StoreListingVersion" slv
|
||||
JOIN platform."UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
"""
|
||||
)
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total_approved": total_approved,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_approved - with_embeddings,
|
||||
"coverage_percent": (
|
||||
round(with_embeddings / total_approved * 100, 1)
|
||||
if total_approved > 0
|
||||
else 0
|
||||
),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embedding stats: {e}")
|
||||
return {
|
||||
"total_approved": 0,
|
||||
"with_embeddings": 0,
|
||||
"without_embeddings": 0,
|
||||
"coverage_percent": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]:
|
||||
"""
|
||||
Generate embeddings for approved listings that don't have them.
|
||||
|
||||
Args:
|
||||
batch_size: Number of embeddings to generate in one call
|
||||
|
||||
Returns:
|
||||
Dict with success/failure counts
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
# Find approved versions without embeddings
|
||||
missing = await client.query_raw(
|
||||
"""
|
||||
SELECT
|
||||
slv.id,
|
||||
slv.name,
|
||||
slv.description,
|
||||
slv."subHeading",
|
||||
slv.categories
|
||||
FROM platform."StoreListingVersion" slv
|
||||
LEFT JOIN platform."UnifiedContentEmbedding" uce
|
||||
ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
AND uce."contentId" IS NULL
|
||||
LIMIT $1
|
||||
""",
|
||||
batch_size,
|
||||
)
|
||||
|
||||
if not missing:
|
||||
return {
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"message": "No missing embeddings",
|
||||
}
|
||||
|
||||
# Process embeddings concurrently for better performance
|
||||
embedding_tasks = [
|
||||
ensure_embedding(
|
||||
version_id=row["id"],
|
||||
name=row["name"],
|
||||
description=row["description"],
|
||||
sub_heading=row["subHeading"],
|
||||
categories=row["categories"] or [],
|
||||
)
|
||||
for row in missing
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*embedding_tasks, return_exceptions=True)
|
||||
|
||||
success = sum(1 for result in results if result is True)
|
||||
failed = len(results) - success
|
||||
|
||||
return {
|
||||
"processed": len(missing),
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"message": f"Backfilled {success} embeddings, {failed} failed",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to backfill embeddings: {e}")
|
||||
return {
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
async def embed_query(query: str) -> list[float] | None:
|
||||
"""
|
||||
Generate embedding for a search query.
|
||||
|
||||
Same as generate_embedding but with clearer intent.
|
||||
"""
|
||||
return await generate_embedding(query)
|
||||
|
||||
|
||||
def embedding_to_vector_string(embedding: list[float]) -> str:
|
||||
"""Convert embedding list to PostgreSQL vector string format."""
|
||||
return "[" + ",".join(str(x) for x in embedding) + "]"
|
||||
|
||||
|
||||
async def ensure_content_embedding(
|
||||
content_type: ContentType,
|
||||
content_id: str,
|
||||
searchable_text: str,
|
||||
metadata: dict | None = None,
|
||||
user_id: str | None = None,
|
||||
force: bool = False,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Ensure an embedding exists for any content type.
|
||||
|
||||
Generic function for creating embeddings for store agents, blocks, docs, etc.
|
||||
|
||||
Args:
|
||||
content_type: ContentType enum value (STORE_AGENT, BLOCK, etc.)
|
||||
content_id: Unique identifier for the content
|
||||
searchable_text: Combined text for embedding generation
|
||||
metadata: Optional metadata to store with embedding
|
||||
force: Force regeneration even if embedding exists
|
||||
tx: Optional transaction client
|
||||
|
||||
Returns:
|
||||
True if embedding exists/was created, False on failure
|
||||
"""
|
||||
try:
|
||||
# Check if embedding already exists
|
||||
if not force:
|
||||
existing = await get_content_embedding(content_type, content_id, user_id)
|
||||
if existing and existing.get("embedding"):
|
||||
logger.debug(
|
||||
f"Embedding for {content_type}:{content_id} already exists"
|
||||
)
|
||||
return True
|
||||
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(
|
||||
f"Could not generate embedding for {content_type}:{content_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Store the embedding
|
||||
return await store_content_embedding(
|
||||
content_type=content_type,
|
||||
content_id=content_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
metadata=metadata or {},
|
||||
user_id=user_id,
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
|
||||
return False
|
||||
@@ -0,0 +1,359 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import prisma
|
||||
import pytest
|
||||
from prisma import Prisma
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store import embeddings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_prisma():
|
||||
"""Setup Prisma client for tests."""
|
||||
try:
|
||||
Prisma()
|
||||
except prisma.errors.ClientAlreadyRegisteredError:
|
||||
pass
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_build_searchable_text():
|
||||
"""Test searchable text building from listing fields."""
|
||||
result = embeddings.build_searchable_text(
|
||||
name="AI Assistant",
|
||||
description="A helpful AI assistant for productivity",
|
||||
sub_heading="Boost your productivity",
|
||||
categories=["AI", "Productivity"],
|
||||
)
|
||||
|
||||
expected = "AI Assistant Boost your productivity A helpful AI assistant for productivity AI Productivity"
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_build_searchable_text_empty_fields():
|
||||
"""Test searchable text building with empty fields."""
|
||||
result = embeddings.build_searchable_text(
|
||||
name="", description="Test description", sub_heading="", categories=[]
|
||||
)
|
||||
|
||||
assert result == "Test description"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.OpenAI")
|
||||
async def test_generate_embedding_success(mock_openai_class):
|
||||
"""Test successful embedding generation."""
|
||||
# Mock OpenAI response
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [MagicMock()]
|
||||
mock_response.data[0].embedding = [0.1, 0.2, 0.3] * 512 # 1536 dimensions
|
||||
mock_client.embeddings.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_client
|
||||
|
||||
with patch("backend.api.features.store.embeddings.Settings") as mock_settings:
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test-key"
|
||||
|
||||
result = await embeddings.generate_embedding("test text")
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 1536
|
||||
assert result[0] == 0.1
|
||||
|
||||
mock_client.embeddings.create.assert_called_once_with(
|
||||
model="text-embedding-3-small", input="test text"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.OpenAI")
|
||||
async def test_generate_embedding_no_api_key(mock_openai_class):
|
||||
"""Test embedding generation without API key."""
|
||||
with patch("backend.api.features.store.embeddings.Settings") as mock_settings:
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = ""
|
||||
|
||||
result = await embeddings.generate_embedding("test text")
|
||||
|
||||
assert result is None
|
||||
mock_openai_class.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.OpenAI")
|
||||
async def test_generate_embedding_api_error(mock_openai_class):
|
||||
"""Test embedding generation with API error."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.embeddings.create.side_effect = Exception("API Error")
|
||||
mock_openai_class.return_value = mock_client
|
||||
|
||||
with patch("backend.api.features.store.embeddings.Settings") as mock_settings:
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test-key"
|
||||
|
||||
result = await embeddings.generate_embedding("test text")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.OpenAI")
|
||||
async def test_generate_embedding_text_truncation(mock_openai_class):
|
||||
"""Test that long text is properly truncated."""
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [MagicMock()]
|
||||
mock_response.data[0].embedding = [0.1] * 1536
|
||||
mock_client.embeddings.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_client
|
||||
|
||||
# Create text longer than 32k chars
|
||||
long_text = "a" * 35000
|
||||
|
||||
with patch("backend.api.features.store.embeddings.Settings") as mock_settings:
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test-key"
|
||||
|
||||
await embeddings.generate_embedding(long_text)
|
||||
|
||||
# Verify truncated text was sent to API
|
||||
call_args = mock_client.embeddings.create.call_args
|
||||
assert len(call_args.kwargs["input"]) == 32000
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_embedding_success(mocker):
|
||||
"""Test successful embedding storage."""
|
||||
mock_client = mocker.AsyncMock()
|
||||
mock_client.execute_raw = mocker.AsyncMock()
|
||||
|
||||
embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
result = await embeddings.store_embedding(
|
||||
version_id="test-version-id", embedding=embedding, tx=mock_client
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_client.execute_raw.assert_called_once()
|
||||
call_args = mock_client.execute_raw.call_args[0]
|
||||
assert "test-version-id" in call_args
|
||||
assert "[0.1,0.2,0.3]" in call_args
|
||||
assert None in call_args # userId should be None for store agents
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_embedding_database_error(mocker):
|
||||
"""Test embedding storage with database error."""
|
||||
mock_client = mocker.AsyncMock()
|
||||
mock_client.execute_raw.side_effect = Exception("Database error")
|
||||
|
||||
embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
result = await embeddings.store_embedding(
|
||||
version_id="test-version-id", embedding=embedding, tx=mock_client
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_success(mocker):
|
||||
"""Test successful embedding retrieval."""
|
||||
mock_client = mocker.AsyncMock()
|
||||
mock_result = [
|
||||
{
|
||||
"contentType": "STORE_AGENT",
|
||||
"contentId": "test-version-id",
|
||||
"embedding": "[0.1,0.2,0.3]",
|
||||
"searchableText": "Test text",
|
||||
"metadata": {},
|
||||
"createdAt": "2024-01-01T00:00:00Z",
|
||||
"updatedAt": "2024-01-01T00:00:00Z",
|
||||
}
|
||||
]
|
||||
mock_client.query_raw.return_value = mock_result
|
||||
|
||||
with patch("prisma.get_client", return_value=mock_client):
|
||||
result = await embeddings.get_embedding("test-version-id")
|
||||
|
||||
assert result is not None
|
||||
assert result["storeListingVersionId"] == "test-version-id"
|
||||
assert result["embedding"] == "[0.1,0.2,0.3]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_not_found(mocker):
|
||||
"""Test embedding retrieval when not found."""
|
||||
mock_client = mocker.AsyncMock()
|
||||
mock_client.query_raw.return_value = []
|
||||
|
||||
with patch("prisma.get_client", return_value=mock_client):
|
||||
result = await embeddings.get_embedding("test-version-id")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
@patch("backend.api.features.store.embeddings.store_embedding")
|
||||
@patch("backend.api.features.store.embeddings.get_embedding")
|
||||
async def test_ensure_embedding_already_exists(mock_get, mock_store, mock_generate):
|
||||
"""Test ensure_embedding when embedding already exists."""
|
||||
mock_get.return_value = {"embedding": "[0.1,0.2,0.3]"}
|
||||
|
||||
result = await embeddings.ensure_embedding(
|
||||
version_id="test-id",
|
||||
name="Test",
|
||||
description="Test description",
|
||||
sub_heading="Test heading",
|
||||
categories=["test"],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_generate.assert_not_called()
|
||||
mock_store.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
@patch("backend.api.features.store.embeddings.store_content_embedding")
|
||||
@patch("backend.api.features.store.embeddings.get_embedding")
|
||||
async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
|
||||
"""Test ensure_embedding creating new embedding."""
|
||||
mock_get.return_value = None
|
||||
mock_generate.return_value = [0.1, 0.2, 0.3]
|
||||
mock_store.return_value = True
|
||||
|
||||
result = await embeddings.ensure_embedding(
|
||||
version_id="test-id",
|
||||
name="Test",
|
||||
description="Test description",
|
||||
sub_heading="Test heading",
|
||||
categories=["test"],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_generate.assert_called_once_with("Test Test heading Test description test")
|
||||
mock_store.assert_called_once_with(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id="test-id",
|
||||
embedding=[0.1, 0.2, 0.3],
|
||||
searchable_text="Test Test heading Test description test",
|
||||
metadata={"name": "Test", "subHeading": "Test heading", "categories": ["test"]},
|
||||
user_id=None,
|
||||
tx=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
@patch("backend.api.features.store.embeddings.get_embedding")
|
||||
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
|
||||
"""Test ensure_embedding when generation fails."""
|
||||
mock_get.return_value = None
|
||||
mock_generate.return_value = None
|
||||
|
||||
result = await embeddings.ensure_embedding(
|
||||
version_id="test-id",
|
||||
name="Test",
|
||||
description="Test description",
|
||||
sub_heading="Test heading",
|
||||
categories=["test"],
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_stats(mocker):
|
||||
"""Test embedding statistics retrieval."""
|
||||
mock_client = mocker.AsyncMock()
|
||||
|
||||
# Mock approved count query
|
||||
mock_approved_result = [{"count": 100}]
|
||||
# Mock embedded count query
|
||||
mock_embedded_result = [{"count": 75}]
|
||||
|
||||
mock_client.query_raw.side_effect = [mock_approved_result, mock_embedded_result]
|
||||
|
||||
with patch("prisma.get_client", return_value=mock_client):
|
||||
result = await embeddings.get_embedding_stats()
|
||||
|
||||
assert result["total_approved"] == 100
|
||||
assert result["with_embeddings"] == 75
|
||||
assert result["without_embeddings"] == 25
|
||||
assert result["coverage_percent"] == 75.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.ensure_embedding")
|
||||
async def test_backfill_missing_embeddings_success(mock_ensure, mocker):
|
||||
"""Test backfill with successful embedding generation."""
|
||||
mock_client = mocker.AsyncMock()
|
||||
|
||||
# Mock missing embeddings query
|
||||
mock_missing = [
|
||||
{
|
||||
"id": "version-1",
|
||||
"name": "Agent 1",
|
||||
"description": "Description 1",
|
||||
"subHeading": "Heading 1",
|
||||
"categories": ["AI"],
|
||||
},
|
||||
{
|
||||
"id": "version-2",
|
||||
"name": "Agent 2",
|
||||
"description": "Description 2",
|
||||
"subHeading": "Heading 2",
|
||||
"categories": ["Productivity"],
|
||||
},
|
||||
]
|
||||
mock_client.query_raw.return_value = mock_missing
|
||||
|
||||
# Mock ensure_embedding to succeed for first, fail for second
|
||||
mock_ensure.side_effect = [True, False]
|
||||
|
||||
with patch("prisma.get_client", return_value=mock_client):
|
||||
result = await embeddings.backfill_missing_embeddings(batch_size=5)
|
||||
|
||||
assert result["processed"] == 2
|
||||
assert result["success"] == 1
|
||||
assert result["failed"] == 1
|
||||
assert mock_ensure.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_backfill_missing_embeddings_no_missing(mocker):
|
||||
"""Test backfill when no embeddings are missing."""
|
||||
mock_client = mocker.AsyncMock()
|
||||
mock_client.query_raw.return_value = []
|
||||
|
||||
with patch("prisma.get_client", return_value=mock_client):
|
||||
result = await embeddings.backfill_missing_embeddings(batch_size=5)
|
||||
|
||||
assert result["processed"] == 0
|
||||
assert result["success"] == 0
|
||||
assert result["failed"] == 0
|
||||
assert result["message"] == "No missing embeddings"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_embedding_to_vector_string():
|
||||
"""Test embedding to PostgreSQL vector string conversion."""
|
||||
embedding = [0.1, 0.2, 0.3, -0.4]
|
||||
result = embeddings.embedding_to_vector_string(embedding)
|
||||
assert result == "[0.1,0.2,0.3,-0.4]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_embed_query():
|
||||
"""Test embed_query function (alias for generate_embedding)."""
|
||||
with patch(
|
||||
"backend.api.features.store.embeddings.generate_embedding"
|
||||
) as mock_generate:
|
||||
mock_generate.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
result = await embeddings.embed_query("test query")
|
||||
|
||||
assert result == [0.1, 0.2, 0.3]
|
||||
mock_generate.assert_called_once_with("test query")
|
||||
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
Hybrid Search for Store Agents
|
||||
|
||||
Combines semantic (embedding) search with lexical (tsvector) search
|
||||
for improved relevance in marketplace agent discovery.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from backend.api.features.store.embeddings import (
|
||||
embed_query,
|
||||
embedding_to_vector_string,
|
||||
)
|
||||
from backend.data.db import query_raw_with_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HybridSearchWeights:
|
||||
"""Weights for combining search signals."""
|
||||
|
||||
semantic: float = 0.35 # Embedding cosine similarity
|
||||
lexical: float = 0.35 # tsvector ts_rank_cd score
|
||||
category: float = 0.20 # Category match boost
|
||||
recency: float = 0.10 # Newer agents ranked higher
|
||||
|
||||
|
||||
DEFAULT_WEIGHTS = HybridSearchWeights()
|
||||
|
||||
# Minimum relevance score threshold - agents below this are filtered out
|
||||
# With weights (0.35 semantic + 0.35 lexical + 0.20 category + 0.10 recency):
|
||||
# - 0.20 means at least ~50% semantic match OR strong lexical match required
|
||||
# - Ensures only genuinely relevant results are returned
|
||||
# - Recency alone (0.10 max) won't pass the threshold
|
||||
DEFAULT_MIN_SCORE = 0.20
|
||||
|
||||
|
||||
@dataclass
|
||||
class HybridSearchResult:
|
||||
"""A single search result with score breakdown."""
|
||||
|
||||
slug: str
|
||||
agent_name: str
|
||||
agent_image: str
|
||||
creator_username: str
|
||||
creator_avatar: str
|
||||
sub_heading: str
|
||||
description: str
|
||||
runs: int
|
||||
rating: float
|
||||
categories: list[str]
|
||||
featured: bool
|
||||
is_available: bool
|
||||
updated_at: datetime
|
||||
|
||||
# Score breakdown (for debugging/tuning)
|
||||
combined_score: float
|
||||
semantic_score: float = 0.0
|
||||
lexical_score: float = 0.0
|
||||
category_score: float = 0.0
|
||||
recency_score: float = 0.0
|
||||
|
||||
|
||||
async def hybrid_search(
|
||||
query: str,
|
||||
featured: bool = False,
|
||||
creators: list[str] | None = None,
|
||||
category: str | None = None,
|
||||
sorted_by: (
|
||||
Literal["relevance", "rating", "runs", "name", "updated_at"] | None
|
||||
) = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
weights: HybridSearchWeights | None = None,
|
||||
min_score: float | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Perform hybrid search combining semantic and lexical signals.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
featured: Filter for featured agents only
|
||||
creators: Filter by creator usernames
|
||||
category: Filter by category
|
||||
sorted_by: Sort order (relevance uses hybrid scoring)
|
||||
page: Page number (1-indexed)
|
||||
page_size: Results per page
|
||||
weights: Custom weights for search signals
|
||||
min_score: Minimum relevance score threshold (0-1). Results below
|
||||
this score are filtered out. Defaults to DEFAULT_MIN_SCORE.
|
||||
|
||||
Returns:
|
||||
Tuple of (results list, total count). Returns empty list if no
|
||||
results meet the minimum relevance threshold.
|
||||
"""
|
||||
if weights is None:
|
||||
weights = DEFAULT_WEIGHTS
|
||||
if min_score is None:
|
||||
min_score = DEFAULT_MIN_SCORE
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await embed_query(query)
|
||||
|
||||
# Build WHERE clause conditions
|
||||
where_parts: list[str] = ["sa.is_available = true"]
|
||||
params: list[Any] = []
|
||||
param_index = 1
|
||||
|
||||
# Add search query for lexical matching
|
||||
params.append(query)
|
||||
query_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
# Add lowercased query for category matching
|
||||
params.append(query.lower())
|
||||
query_lower_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
if featured:
|
||||
where_parts.append("sa.featured = true")
|
||||
|
||||
if creators:
|
||||
where_parts.append(f"sa.creator_username = ANY(${param_index})")
|
||||
params.append(creators)
|
||||
param_index += 1
|
||||
|
||||
if category:
|
||||
where_parts.append(f"${param_index} = ANY(sa.categories)")
|
||||
params.append(category)
|
||||
param_index += 1
|
||||
|
||||
where_clause = " AND ".join(where_parts)
|
||||
|
||||
# Determine if we can use hybrid search (have query embedding)
|
||||
use_hybrid = query_embedding is not None
|
||||
|
||||
if use_hybrid:
|
||||
# Add embedding parameter
|
||||
embedding_str = embedding_to_vector_string(query_embedding)
|
||||
params.append(embedding_str)
|
||||
embedding_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
# Optimized hybrid search query:
|
||||
# 1. Direct join to UnifiedContentEmbedding via contentId=storeListingVersionId (no redundant JOINs)
|
||||
# 2. UNION ALL approach to enable index usage for both lexical and semantic branches
|
||||
# 3. COUNT(*) OVER() to get total count in single query
|
||||
# 4. Simplified category matching with array_to_string
|
||||
sql_query = f"""
|
||||
WITH candidates AS (
|
||||
-- Lexical matches (uses GIN index on search column)
|
||||
SELECT DISTINCT sa."storeListingVersionId"
|
||||
FROM {{schema_prefix}}"StoreAgent" sa
|
||||
WHERE {where_clause}
|
||||
AND sa.search @@ plainto_tsquery('english', {query_param})
|
||||
|
||||
UNION
|
||||
|
||||
-- Semantic matches (uses HNSW index on embedding)
|
||||
SELECT DISTINCT sa."storeListingVersionId"
|
||||
FROM {{schema_prefix}}"StoreAgent" sa
|
||||
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'
|
||||
WHERE {where_clause}
|
||||
),
|
||||
search_scores AS (
|
||||
SELECT
|
||||
sa.slug,
|
||||
sa.agent_name,
|
||||
sa.agent_image,
|
||||
sa.creator_username,
|
||||
sa.creator_avatar,
|
||||
sa.sub_heading,
|
||||
sa.description,
|
||||
sa.runs,
|
||||
sa.rating,
|
||||
sa.categories,
|
||||
sa.featured,
|
||||
sa.is_available,
|
||||
sa.updated_at,
|
||||
-- Semantic score: cosine similarity (1 - distance)
|
||||
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
||||
-- Lexical score: ts_rank_cd (will be normalized later)
|
||||
COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||
-- Category match: check if query appears in any category
|
||||
CASE
|
||||
WHEN LOWER(array_to_string(sa.categories, ' ')) LIKE '%' || {query_lower_param} || '%'
|
||||
THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
-- Recency score: exponential decay over 90 days
|
||||
EXP(-EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score
|
||||
FROM candidates c
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON c."storeListingVersionId" = sa."storeListingVersionId"
|
||||
LEFT JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
*,
|
||||
-- Normalize lexical score by max in result set
|
||||
CASE
|
||||
WHEN MAX(lexical_raw) OVER () > 0
|
||||
THEN lexical_raw / MAX(lexical_raw) OVER ()
|
||||
ELSE 0
|
||||
END as lexical_score
|
||||
FROM search_scores
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
category_score,
|
||||
recency_score,
|
||||
(
|
||||
{weights.semantic} * semantic_score +
|
||||
{weights.lexical} * lexical_score +
|
||||
{weights.category} * category_score +
|
||||
{weights.recency} * recency_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
),
|
||||
filtered AS (
|
||||
SELECT
|
||||
*,
|
||||
COUNT(*) OVER () as total_count
|
||||
FROM scored
|
||||
WHERE combined_score >= {min_score}
|
||||
)
|
||||
SELECT * FROM filtered
|
||||
ORDER BY combined_score DESC
|
||||
LIMIT ${param_index} OFFSET ${param_index + 1}
|
||||
"""
|
||||
|
||||
# Add pagination params
|
||||
params.extend([page_size, offset])
|
||||
|
||||
else:
|
||||
# Fallback to lexical-only search (existing behavior)
|
||||
logger.warning("Falling back to lexical-only search (no query embedding)")
|
||||
|
||||
sql_query = f"""
|
||||
WITH lexical_scores AS (
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
0.0 as semantic_score,
|
||||
ts_rank_cd(search, plainto_tsquery('english', {query_param})) as lexical_raw,
|
||||
CASE
|
||||
WHEN LOWER(array_to_string(categories, ' ')) LIKE '%' || {query_lower_param} || '%'
|
||||
THEN 1.0
|
||||
ELSE 0.0
|
||||
END as category_score,
|
||||
EXP(-EXTRACT(EPOCH FROM (NOW() - updated_at)) / (90 * 24 * 3600)) as recency_score
|
||||
FROM {{schema_prefix}}"StoreAgent" sa
|
||||
WHERE {where_clause}
|
||||
AND search @@ plainto_tsquery('english', {query_param})
|
||||
),
|
||||
normalized AS (
|
||||
SELECT
|
||||
*,
|
||||
CASE
|
||||
WHEN MAX(lexical_raw) OVER () > 0
|
||||
THEN lexical_raw / MAX(lexical_raw) OVER ()
|
||||
ELSE 0
|
||||
END as lexical_score
|
||||
FROM lexical_scores
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
category_score,
|
||||
recency_score,
|
||||
(
|
||||
{weights.lexical} * lexical_score +
|
||||
{weights.category} * category_score +
|
||||
{weights.recency} * recency_score
|
||||
) as combined_score
|
||||
FROM normalized
|
||||
),
|
||||
filtered AS (
|
||||
SELECT
|
||||
*,
|
||||
COUNT(*) OVER () as total_count
|
||||
FROM scored
|
||||
WHERE combined_score >= {min_score}
|
||||
)
|
||||
SELECT * FROM filtered
|
||||
ORDER BY combined_score DESC
|
||||
LIMIT ${param_index} OFFSET ${param_index + 1}
|
||||
"""
|
||||
|
||||
params.extend([page_size, offset])
|
||||
|
||||
try:
|
||||
# Execute search query - includes total_count via window function
|
||||
results = await query_raw_with_schema(sql_query, *params)
|
||||
|
||||
# Extract total count from first result (all rows have same count)
|
||||
total = results[0]["total_count"] if results else 0
|
||||
|
||||
# Remove total_count from results before returning
|
||||
for result in results:
|
||||
result.pop("total_count", None)
|
||||
|
||||
logger.info(
|
||||
f"Hybrid search for '{query}': {len(results)} results, {total} total "
|
||||
f"(hybrid={use_hybrid})"
|
||||
)
|
||||
|
||||
return results, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Hybrid search failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def hybrid_search_simple(
|
||||
query: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Simplified hybrid search for common use cases.
|
||||
|
||||
Uses default weights and no filters.
|
||||
"""
|
||||
return await hybrid_search(
|
||||
query=query,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
@@ -5,11 +5,12 @@ import uuid
|
||||
import fastapi
|
||||
from gcloud.aio import storage as async_storage
|
||||
|
||||
import backend.server.v2.store.exceptions
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
from . import exceptions as store_exceptions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALLOWED_IMAGE_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp"}
|
||||
@@ -68,61 +69,55 @@ async def upload_media(
|
||||
await file.seek(0) # Reset file pointer
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading file content: {str(e)}")
|
||||
raise backend.server.v2.store.exceptions.FileReadError(
|
||||
"Failed to read file content"
|
||||
) from e
|
||||
raise store_exceptions.FileReadError("Failed to read file content") from e
|
||||
|
||||
# Validate file signature/magic bytes
|
||||
if file.content_type in ALLOWED_IMAGE_TYPES:
|
||||
# Check image file signatures
|
||||
if content.startswith(b"\xff\xd8\xff"): # JPEG
|
||||
if file.content_type != "image/jpeg":
|
||||
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
|
||||
raise store_exceptions.InvalidFileTypeError(
|
||||
"File signature does not match content type"
|
||||
)
|
||||
elif content.startswith(b"\x89PNG\r\n\x1a\n"): # PNG
|
||||
if file.content_type != "image/png":
|
||||
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
|
||||
raise store_exceptions.InvalidFileTypeError(
|
||||
"File signature does not match content type"
|
||||
)
|
||||
elif content.startswith(b"GIF87a") or content.startswith(b"GIF89a"): # GIF
|
||||
if file.content_type != "image/gif":
|
||||
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
|
||||
raise store_exceptions.InvalidFileTypeError(
|
||||
"File signature does not match content type"
|
||||
)
|
||||
elif content.startswith(b"RIFF") and content[8:12] == b"WEBP": # WebP
|
||||
if file.content_type != "image/webp":
|
||||
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
|
||||
raise store_exceptions.InvalidFileTypeError(
|
||||
"File signature does not match content type"
|
||||
)
|
||||
else:
|
||||
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
|
||||
"Invalid image file signature"
|
||||
)
|
||||
raise store_exceptions.InvalidFileTypeError("Invalid image file signature")
|
||||
|
||||
elif file.content_type in ALLOWED_VIDEO_TYPES:
|
||||
# Check video file signatures
|
||||
if content.startswith(b"\x00\x00\x00") and (content[4:8] == b"ftyp"): # MP4
|
||||
if file.content_type != "video/mp4":
|
||||
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
|
||||
raise store_exceptions.InvalidFileTypeError(
|
||||
"File signature does not match content type"
|
||||
)
|
||||
elif content.startswith(b"\x1a\x45\xdf\xa3"): # WebM
|
||||
if file.content_type != "video/webm":
|
||||
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
|
||||
raise store_exceptions.InvalidFileTypeError(
|
||||
"File signature does not match content type"
|
||||
)
|
||||
else:
|
||||
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
|
||||
"Invalid video file signature"
|
||||
)
|
||||
raise store_exceptions.InvalidFileTypeError("Invalid video file signature")
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# Check required settings first before doing any file processing
|
||||
if not settings.config.media_gcs_bucket_name:
|
||||
logger.error("Missing GCS bucket name setting")
|
||||
raise backend.server.v2.store.exceptions.StorageConfigError(
|
||||
raise store_exceptions.StorageConfigError(
|
||||
"Missing storage bucket configuration"
|
||||
)
|
||||
|
||||
@@ -137,7 +132,7 @@ async def upload_media(
|
||||
and content_type not in ALLOWED_VIDEO_TYPES
|
||||
):
|
||||
logger.warning(f"Invalid file type attempted: {content_type}")
|
||||
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
|
||||
raise store_exceptions.InvalidFileTypeError(
|
||||
f"File type not supported. Must be jpeg, png, gif, webp, mp4 or webm. Content type: {content_type}"
|
||||
)
|
||||
|
||||
@@ -150,16 +145,14 @@ async def upload_media(
|
||||
file_size += len(chunk)
|
||||
if file_size > MAX_FILE_SIZE:
|
||||
logger.warning(f"File size too large: {file_size} bytes")
|
||||
raise backend.server.v2.store.exceptions.FileSizeTooLargeError(
|
||||
raise store_exceptions.FileSizeTooLargeError(
|
||||
"File too large. Maximum size is 50MB"
|
||||
)
|
||||
except backend.server.v2.store.exceptions.FileSizeTooLargeError:
|
||||
except store_exceptions.FileSizeTooLargeError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading file chunks: {str(e)}")
|
||||
raise backend.server.v2.store.exceptions.FileReadError(
|
||||
"Failed to read uploaded file"
|
||||
) from e
|
||||
raise store_exceptions.FileReadError("Failed to read uploaded file") from e
|
||||
|
||||
# Reset file pointer
|
||||
await file.seek(0)
|
||||
@@ -198,14 +191,14 @@ async def upload_media(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"GCS storage error: {str(e)}")
|
||||
raise backend.server.v2.store.exceptions.StorageUploadError(
|
||||
raise store_exceptions.StorageUploadError(
|
||||
"Failed to upload file to storage"
|
||||
) from e
|
||||
|
||||
except backend.server.v2.store.exceptions.MediaUploadError:
|
||||
except store_exceptions.MediaUploadError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Unexpected error in upload_media")
|
||||
raise backend.server.v2.store.exceptions.MediaUploadError(
|
||||
raise store_exceptions.MediaUploadError(
|
||||
"Unexpected error during media upload"
|
||||
) from e
|
||||
@@ -6,17 +6,18 @@ import fastapi
|
||||
import pytest
|
||||
import starlette.datastructures
|
||||
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.server.v2.store.media
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from . import exceptions as store_exceptions
|
||||
from . import media as store_media
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings(monkeypatch):
|
||||
settings = Settings()
|
||||
settings.config.media_gcs_bucket_name = "test-bucket"
|
||||
settings.config.google_application_credentials = "test-credentials"
|
||||
monkeypatch.setattr("backend.server.v2.store.media.Settings", lambda: settings)
|
||||
monkeypatch.setattr("backend.api.features.store.media.Settings", lambda: settings)
|
||||
return settings
|
||||
|
||||
|
||||
@@ -32,12 +33,13 @@ def mock_storage_client(mocker):
|
||||
|
||||
# Mock the constructor to return our mock client
|
||||
mocker.patch(
|
||||
"backend.server.v2.store.media.async_storage.Storage", return_value=mock_client
|
||||
"backend.api.features.store.media.async_storage.Storage",
|
||||
return_value=mock_client,
|
||||
)
|
||||
|
||||
# Mock virus scanner to avoid actual scanning
|
||||
mocker.patch(
|
||||
"backend.server.v2.store.media.scan_content_safe", new_callable=AsyncMock
|
||||
"backend.api.features.store.media.scan_content_safe", new_callable=AsyncMock
|
||||
)
|
||||
|
||||
return mock_client
|
||||
@@ -53,7 +55,7 @@ async def test_upload_media_success(mock_settings, mock_storage_client):
|
||||
headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
|
||||
)
|
||||
|
||||
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
result = await store_media.upload_media("test-user", test_file)
|
||||
|
||||
assert result.startswith(
|
||||
"https://storage.googleapis.com/test-bucket/users/test-user/images/"
|
||||
@@ -69,8 +71,8 @@ async def test_upload_media_invalid_type(mock_settings, mock_storage_client):
|
||||
headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
|
||||
)
|
||||
|
||||
with pytest.raises(backend.server.v2.store.exceptions.InvalidFileTypeError):
|
||||
await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
with pytest.raises(store_exceptions.InvalidFileTypeError):
|
||||
await store_media.upload_media("test-user", test_file)
|
||||
|
||||
mock_storage_client.upload.assert_not_called()
|
||||
|
||||
@@ -79,7 +81,7 @@ async def test_upload_media_missing_credentials(monkeypatch):
|
||||
settings = Settings()
|
||||
settings.config.media_gcs_bucket_name = ""
|
||||
settings.config.google_application_credentials = ""
|
||||
monkeypatch.setattr("backend.server.v2.store.media.Settings", lambda: settings)
|
||||
monkeypatch.setattr("backend.api.features.store.media.Settings", lambda: settings)
|
||||
|
||||
test_file = fastapi.UploadFile(
|
||||
filename="laptop.jpeg",
|
||||
@@ -87,8 +89,8 @@ async def test_upload_media_missing_credentials(monkeypatch):
|
||||
headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
|
||||
)
|
||||
|
||||
with pytest.raises(backend.server.v2.store.exceptions.StorageConfigError):
|
||||
await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
with pytest.raises(store_exceptions.StorageConfigError):
|
||||
await store_media.upload_media("test-user", test_file)
|
||||
|
||||
|
||||
async def test_upload_media_video_type(mock_settings, mock_storage_client):
|
||||
@@ -98,7 +100,7 @@ async def test_upload_media_video_type(mock_settings, mock_storage_client):
|
||||
headers=starlette.datastructures.Headers({"content-type": "video/mp4"}),
|
||||
)
|
||||
|
||||
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
result = await store_media.upload_media("test-user", test_file)
|
||||
|
||||
assert result.startswith(
|
||||
"https://storage.googleapis.com/test-bucket/users/test-user/videos/"
|
||||
@@ -117,8 +119,8 @@ async def test_upload_media_file_too_large(mock_settings, mock_storage_client):
|
||||
headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
|
||||
)
|
||||
|
||||
with pytest.raises(backend.server.v2.store.exceptions.FileSizeTooLargeError):
|
||||
await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
with pytest.raises(store_exceptions.FileSizeTooLargeError):
|
||||
await store_media.upload_media("test-user", test_file)
|
||||
|
||||
|
||||
async def test_upload_media_file_read_error(mock_settings, mock_storage_client):
|
||||
@@ -129,8 +131,8 @@ async def test_upload_media_file_read_error(mock_settings, mock_storage_client):
|
||||
)
|
||||
test_file.read = unittest.mock.AsyncMock(side_effect=Exception("Read error"))
|
||||
|
||||
with pytest.raises(backend.server.v2.store.exceptions.FileReadError):
|
||||
await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
with pytest.raises(store_exceptions.FileReadError):
|
||||
await store_media.upload_media("test-user", test_file)
|
||||
|
||||
|
||||
async def test_upload_media_png_success(mock_settings, mock_storage_client):
|
||||
@@ -140,7 +142,7 @@ async def test_upload_media_png_success(mock_settings, mock_storage_client):
|
||||
headers=starlette.datastructures.Headers({"content-type": "image/png"}),
|
||||
)
|
||||
|
||||
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
result = await store_media.upload_media("test-user", test_file)
|
||||
assert result.startswith(
|
||||
"https://storage.googleapis.com/test-bucket/users/test-user/images/"
|
||||
)
|
||||
@@ -154,7 +156,7 @@ async def test_upload_media_gif_success(mock_settings, mock_storage_client):
|
||||
headers=starlette.datastructures.Headers({"content-type": "image/gif"}),
|
||||
)
|
||||
|
||||
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
result = await store_media.upload_media("test-user", test_file)
|
||||
assert result.startswith(
|
||||
"https://storage.googleapis.com/test-bucket/users/test-user/images/"
|
||||
)
|
||||
@@ -168,7 +170,7 @@ async def test_upload_media_webp_success(mock_settings, mock_storage_client):
|
||||
headers=starlette.datastructures.Headers({"content-type": "image/webp"}),
|
||||
)
|
||||
|
||||
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
result = await store_media.upload_media("test-user", test_file)
|
||||
assert result.startswith(
|
||||
"https://storage.googleapis.com/test-bucket/users/test-user/images/"
|
||||
)
|
||||
@@ -182,7 +184,7 @@ async def test_upload_media_webm_success(mock_settings, mock_storage_client):
|
||||
headers=starlette.datastructures.Headers({"content-type": "video/webm"}),
|
||||
)
|
||||
|
||||
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
result = await store_media.upload_media("test-user", test_file)
|
||||
assert result.startswith(
|
||||
"https://storage.googleapis.com/test-bucket/users/test-user/videos/"
|
||||
)
|
||||
@@ -196,8 +198,8 @@ async def test_upload_media_mismatched_signature(mock_settings, mock_storage_cli
|
||||
headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
|
||||
)
|
||||
|
||||
with pytest.raises(backend.server.v2.store.exceptions.InvalidFileTypeError):
|
||||
await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
with pytest.raises(store_exceptions.InvalidFileTypeError):
|
||||
await store_media.upload_media("test-user", test_file)
|
||||
|
||||
|
||||
async def test_upload_media_invalid_signature(mock_settings, mock_storage_client):
|
||||
@@ -207,5 +209,5 @@ async def test_upload_media_invalid_signature(mock_settings, mock_storage_client
|
||||
headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
|
||||
)
|
||||
|
||||
with pytest.raises(backend.server.v2.store.exceptions.InvalidFileTypeError):
|
||||
await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
with pytest.raises(store_exceptions.InvalidFileTypeError):
|
||||
await store_media.upload_media("test-user", test_file)
|
||||
@@ -7,6 +7,12 @@ import pydantic
|
||||
from backend.util.models import Pagination
|
||||
|
||||
|
||||
class ChangelogEntry(pydantic.BaseModel):
|
||||
version: str
|
||||
changes_summary: str
|
||||
date: datetime.datetime
|
||||
|
||||
|
||||
class MyAgent(pydantic.BaseModel):
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
@@ -55,12 +61,17 @@ class StoreAgentDetails(pydantic.BaseModel):
|
||||
runs: int
|
||||
rating: float
|
||||
versions: list[str]
|
||||
agentGraphVersions: list[str]
|
||||
agentGraphId: str
|
||||
last_updated: datetime.datetime
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
active_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
|
||||
# Optional changelog data when include_changelog=True
|
||||
changelog: list[ChangelogEntry] | None = None
|
||||
|
||||
|
||||
class Creator(pydantic.BaseModel):
|
||||
name: str
|
||||
@@ -99,6 +110,7 @@ class Profile(pydantic.BaseModel):
|
||||
|
||||
|
||||
class StoreSubmission(pydantic.BaseModel):
|
||||
listing_id: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
name: str
|
||||
@@ -153,8 +165,12 @@ class StoreListingsWithVersionsResponse(pydantic.BaseModel):
|
||||
|
||||
|
||||
class StoreSubmissionRequest(pydantic.BaseModel):
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
agent_id: str = pydantic.Field(
|
||||
..., min_length=1, description="Agent ID cannot be empty"
|
||||
)
|
||||
agent_version: int = pydantic.Field(
|
||||
..., gt=0, description="Agent version must be greater than 0"
|
||||
)
|
||||
slug: str
|
||||
name: str
|
||||
sub_heading: str
|
||||
@@ -2,11 +2,11 @@ import datetime
|
||||
|
||||
import prisma.enums
|
||||
|
||||
import backend.server.v2.store.model
|
||||
from . import model as store_model
|
||||
|
||||
|
||||
def test_pagination():
|
||||
pagination = backend.server.v2.store.model.Pagination(
|
||||
pagination = store_model.Pagination(
|
||||
total_items=100, total_pages=5, current_page=2, page_size=20
|
||||
)
|
||||
assert pagination.total_items == 100
|
||||
@@ -16,7 +16,7 @@ def test_pagination():
|
||||
|
||||
|
||||
def test_store_agent():
|
||||
agent = backend.server.v2.store.model.StoreAgent(
|
||||
agent = store_model.StoreAgent(
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_image="test.jpg",
|
||||
@@ -34,9 +34,9 @@ def test_store_agent():
|
||||
|
||||
|
||||
def test_store_agents_response():
|
||||
response = backend.server.v2.store.model.StoreAgentsResponse(
|
||||
response = store_model.StoreAgentsResponse(
|
||||
agents=[
|
||||
backend.server.v2.store.model.StoreAgent(
|
||||
store_model.StoreAgent(
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_image="test.jpg",
|
||||
@@ -48,7 +48,7 @@ def test_store_agents_response():
|
||||
rating=4.5,
|
||||
)
|
||||
],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
@@ -57,7 +57,7 @@ def test_store_agents_response():
|
||||
|
||||
|
||||
def test_store_agent_details():
|
||||
details = backend.server.v2.store.model.StoreAgentDetails(
|
||||
details = store_model.StoreAgentDetails(
|
||||
store_listing_version_id="version123",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
@@ -72,6 +72,8 @@ def test_store_agent_details():
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
versions=["1.0", "2.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id",
|
||||
last_updated=datetime.datetime.now(),
|
||||
)
|
||||
assert details.slug == "test-agent"
|
||||
@@ -81,7 +83,7 @@ def test_store_agent_details():
|
||||
|
||||
|
||||
def test_creator():
|
||||
creator = backend.server.v2.store.model.Creator(
|
||||
creator = store_model.Creator(
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
name="Test Creator",
|
||||
@@ -96,9 +98,9 @@ def test_creator():
|
||||
|
||||
|
||||
def test_creators_response():
|
||||
response = backend.server.v2.store.model.CreatorsResponse(
|
||||
response = store_model.CreatorsResponse(
|
||||
creators=[
|
||||
backend.server.v2.store.model.Creator(
|
||||
store_model.Creator(
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
name="Test Creator",
|
||||
@@ -109,7 +111,7 @@ def test_creators_response():
|
||||
is_featured=False,
|
||||
)
|
||||
],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
@@ -118,7 +120,7 @@ def test_creators_response():
|
||||
|
||||
|
||||
def test_creator_details():
|
||||
details = backend.server.v2.store.model.CreatorDetails(
|
||||
details = store_model.CreatorDetails(
|
||||
name="Test Creator",
|
||||
username="creator1",
|
||||
description="Test description",
|
||||
@@ -135,7 +137,8 @@ def test_creator_details():
|
||||
|
||||
|
||||
def test_store_submission():
|
||||
submission = backend.server.v2.store.model.StoreSubmission(
|
||||
submission = store_model.StoreSubmission(
|
||||
listing_id="listing123",
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
sub_heading="Test subheading",
|
||||
@@ -154,9 +157,10 @@ def test_store_submission():
|
||||
|
||||
|
||||
def test_store_submissions_response():
|
||||
response = backend.server.v2.store.model.StoreSubmissionsResponse(
|
||||
response = store_model.StoreSubmissionsResponse(
|
||||
submissions=[
|
||||
backend.server.v2.store.model.StoreSubmission(
|
||||
store_model.StoreSubmission(
|
||||
listing_id="listing123",
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
sub_heading="Test subheading",
|
||||
@@ -170,7 +174,7 @@ def test_store_submissions_response():
|
||||
rating=4.5,
|
||||
)
|
||||
],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
@@ -179,7 +183,7 @@ def test_store_submissions_response():
|
||||
|
||||
|
||||
def test_store_submission_request():
|
||||
request = backend.server.v2.store.model.StoreSubmissionRequest(
|
||||
request = store_model.StoreSubmissionRequest(
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
slug="test-agent",
|
||||
@@ -9,14 +9,14 @@ import fastapi
|
||||
import fastapi.responses
|
||||
|
||||
import backend.data.graph
|
||||
import backend.server.v2.store.cache as store_cache
|
||||
import backend.server.v2.store.db
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.server.v2.store.image_gen
|
||||
import backend.server.v2.store.media
|
||||
import backend.server.v2.store.model
|
||||
import backend.util.json
|
||||
|
||||
from . import cache as store_cache
|
||||
from . import db as store_db
|
||||
from . import image_gen as store_image_gen
|
||||
from . import media as store_media
|
||||
from . import model as store_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter()
|
||||
@@ -32,7 +32,7 @@ router = fastapi.APIRouter()
|
||||
summary="Get user profile",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=backend.server.v2.store.model.ProfileDetails,
|
||||
response_model=store_model.ProfileDetails,
|
||||
)
|
||||
async def get_profile(
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
@@ -41,7 +41,7 @@ async def get_profile(
|
||||
Get the profile details for the authenticated user.
|
||||
Cached for 1 hour per user.
|
||||
"""
|
||||
profile = await backend.server.v2.store.db.get_user_profile(user_id)
|
||||
profile = await store_db.get_user_profile(user_id)
|
||||
if profile is None:
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=404,
|
||||
@@ -55,10 +55,10 @@ async def get_profile(
|
||||
summary="Update user profile",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=backend.server.v2.store.model.CreatorDetails,
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
async def update_or_create_profile(
|
||||
profile: backend.server.v2.store.model.Profile,
|
||||
profile: store_model.Profile,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
@@ -74,9 +74,7 @@ async def update_or_create_profile(
|
||||
Raises:
|
||||
HTTPException: If there is an error updating the profile
|
||||
"""
|
||||
updated_profile = await backend.server.v2.store.db.update_profile(
|
||||
user_id=user_id, profile=profile
|
||||
)
|
||||
updated_profile = await store_db.update_profile(user_id=user_id, profile=profile)
|
||||
return updated_profile
|
||||
|
||||
|
||||
@@ -89,7 +87,7 @@ async def update_or_create_profile(
|
||||
"/agents",
|
||||
summary="List store agents",
|
||||
tags=["store", "public"],
|
||||
response_model=backend.server.v2.store.model.StoreAgentsResponse,
|
||||
response_model=store_model.StoreAgentsResponse,
|
||||
)
|
||||
async def get_agents(
|
||||
featured: bool = False,
|
||||
@@ -152,9 +150,13 @@ async def get_agents(
|
||||
"/agents/{username}/{agent_name}",
|
||||
summary="Get specific agent",
|
||||
tags=["store", "public"],
|
||||
response_model=backend.server.v2.store.model.StoreAgentDetails,
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_agent(username: str, agent_name: str):
|
||||
async def get_agent(
|
||||
username: str,
|
||||
agent_name: str,
|
||||
include_changelog: bool = fastapi.Query(default=False),
|
||||
):
|
||||
"""
|
||||
This is only used on the AgentDetails Page.
|
||||
|
||||
@@ -164,7 +166,7 @@ async def get_agent(username: str, agent_name: str):
|
||||
# URL decode the agent name since it comes from the URL path
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
agent = await store_cache._get_cached_agent_details(
|
||||
username=username, agent_name=agent_name
|
||||
username=username, agent_name=agent_name, include_changelog=include_changelog
|
||||
)
|
||||
return agent
|
||||
|
||||
@@ -175,13 +177,13 @@ async def get_agent(username: str, agent_name: str):
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_graph_meta_by_store_listing_version_id(store_listing_version_id: str):
|
||||
async def get_graph_meta_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> backend.data.graph.GraphMeta:
|
||||
"""
|
||||
Get Agent Graph from Store Listing Version ID.
|
||||
"""
|
||||
graph = await backend.server.v2.store.db.get_available_graph(
|
||||
store_listing_version_id
|
||||
)
|
||||
graph = await store_db.get_available_graph(store_listing_version_id)
|
||||
return graph
|
||||
|
||||
|
||||
@@ -190,15 +192,13 @@ async def get_graph_meta_by_store_listing_version_id(store_listing_version_id: s
|
||||
summary="Get agent by version",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=backend.server.v2.store.model.StoreAgentDetails,
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_store_agent(store_listing_version_id: str):
|
||||
"""
|
||||
Get Store Agent Details from Store Listing Version ID.
|
||||
"""
|
||||
agent = await backend.server.v2.store.db.get_store_agent_by_version_id(
|
||||
store_listing_version_id
|
||||
)
|
||||
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
|
||||
|
||||
return agent
|
||||
|
||||
@@ -208,12 +208,12 @@ async def get_store_agent(store_listing_version_id: str):
|
||||
summary="Create agent review",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=backend.server.v2.store.model.StoreReview,
|
||||
response_model=store_model.StoreReview,
|
||||
)
|
||||
async def create_review(
|
||||
username: str,
|
||||
agent_name: str,
|
||||
review: backend.server.v2.store.model.StoreReviewCreate,
|
||||
review: store_model.StoreReviewCreate,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
@@ -231,7 +231,7 @@ async def create_review(
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
# Create the review
|
||||
created_review = await backend.server.v2.store.db.create_store_review(
|
||||
created_review = await store_db.create_store_review(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=review.store_listing_version_id,
|
||||
score=review.score,
|
||||
@@ -250,7 +250,7 @@ async def create_review(
|
||||
"/creators",
|
||||
summary="List store creators",
|
||||
tags=["store", "public"],
|
||||
response_model=backend.server.v2.store.model.CreatorsResponse,
|
||||
response_model=store_model.CreatorsResponse,
|
||||
)
|
||||
async def get_creators(
|
||||
featured: bool = False,
|
||||
@@ -295,7 +295,7 @@ async def get_creators(
|
||||
"/creator/{username}",
|
||||
summary="Get creator details",
|
||||
tags=["store", "public"],
|
||||
response_model=backend.server.v2.store.model.CreatorDetails,
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
async def get_creator(
|
||||
username: str,
|
||||
@@ -319,7 +319,7 @@ async def get_creator(
|
||||
summary="Get my agents",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=backend.server.v2.store.model.MyAgentsResponse,
|
||||
response_model=store_model.MyAgentsResponse,
|
||||
)
|
||||
async def get_my_agents(
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
@@ -329,9 +329,7 @@ async def get_my_agents(
|
||||
"""
|
||||
Get user's own agents.
|
||||
"""
|
||||
agents = await backend.server.v2.store.db.get_my_agents(
|
||||
user_id, page=page, page_size=page_size
|
||||
)
|
||||
agents = await store_db.get_my_agents(user_id, page=page, page_size=page_size)
|
||||
return agents
|
||||
|
||||
|
||||
@@ -356,7 +354,7 @@ async def delete_submission(
|
||||
Returns:
|
||||
bool: True if the submission was successfully deleted, False otherwise
|
||||
"""
|
||||
result = await backend.server.v2.store.db.delete_store_submission(
|
||||
result = await store_db.delete_store_submission(
|
||||
user_id=user_id,
|
||||
submission_id=submission_id,
|
||||
)
|
||||
@@ -369,7 +367,7 @@ async def delete_submission(
|
||||
summary="List my submissions",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=backend.server.v2.store.model.StoreSubmissionsResponse,
|
||||
response_model=store_model.StoreSubmissionsResponse,
|
||||
)
|
||||
async def get_submissions(
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
@@ -399,7 +397,7 @@ async def get_submissions(
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
listings = await backend.server.v2.store.db.get_store_submissions(
|
||||
listings = await store_db.get_store_submissions(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
@@ -412,10 +410,10 @@ async def get_submissions(
|
||||
summary="Create store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=backend.server.v2.store.model.StoreSubmission,
|
||||
response_model=store_model.StoreSubmission,
|
||||
)
|
||||
async def create_submission(
|
||||
submission_request: backend.server.v2.store.model.StoreSubmissionRequest,
|
||||
submission_request: store_model.StoreSubmissionRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
@@ -431,7 +429,7 @@ async def create_submission(
|
||||
Raises:
|
||||
HTTPException: If there is an error creating the submission
|
||||
"""
|
||||
result = await backend.server.v2.store.db.create_store_submission(
|
||||
result = await store_db.create_store_submission(
|
||||
user_id=user_id,
|
||||
agent_id=submission_request.agent_id,
|
||||
agent_version=submission_request.agent_version,
|
||||
@@ -456,11 +454,11 @@ async def create_submission(
|
||||
summary="Edit store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=backend.server.v2.store.model.StoreSubmission,
|
||||
response_model=store_model.StoreSubmission,
|
||||
)
|
||||
async def edit_submission(
|
||||
store_listing_version_id: str,
|
||||
submission_request: backend.server.v2.store.model.StoreSubmissionEditRequest,
|
||||
submission_request: store_model.StoreSubmissionEditRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
@@ -477,7 +475,7 @@ async def edit_submission(
|
||||
Raises:
|
||||
HTTPException: If there is an error editing the submission
|
||||
"""
|
||||
result = await backend.server.v2.store.db.edit_store_submission(
|
||||
result = await store_db.edit_store_submission(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
name=submission_request.name,
|
||||
@@ -518,9 +516,7 @@ async def upload_submission_media(
|
||||
Raises:
|
||||
HTTPException: If there is an error uploading the media
|
||||
"""
|
||||
media_url = await backend.server.v2.store.media.upload_media(
|
||||
user_id=user_id, file=file
|
||||
)
|
||||
media_url = await store_media.upload_media(user_id=user_id, file=file)
|
||||
return media_url
|
||||
|
||||
|
||||
@@ -555,14 +551,12 @@ async def generate_image(
|
||||
# Use .jpeg here since we are generating JPEG images
|
||||
filename = f"agent_{agent_id}.jpeg"
|
||||
|
||||
existing_url = await backend.server.v2.store.media.check_media_exists(
|
||||
user_id, filename
|
||||
)
|
||||
existing_url = await store_media.check_media_exists(user_id, filename)
|
||||
if existing_url:
|
||||
logger.info(f"Using existing image for agent {agent_id}")
|
||||
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
|
||||
# Generate agent image as JPEG
|
||||
image = await backend.server.v2.store.image_gen.generate_agent_image(agent=agent)
|
||||
image = await store_image_gen.generate_agent_image(agent=agent)
|
||||
|
||||
# Create UploadFile with the correct filename and content_type
|
||||
image_file = fastapi.UploadFile(
|
||||
@@ -570,7 +564,7 @@ async def generate_image(
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
image_url = await backend.server.v2.store.media.upload_media(
|
||||
image_url = await store_media.upload_media(
|
||||
user_id=user_id, file=image_file, use_file_name=True
|
||||
)
|
||||
|
||||
@@ -599,7 +593,7 @@ async def download_agent_file(
|
||||
Raises:
|
||||
HTTPException: If the agent is not found or an unexpected error occurs.
|
||||
"""
|
||||
graph_data = await backend.server.v2.store.db.get_agent(store_listing_version_id)
|
||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
@@ -8,15 +8,15 @@ import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.v2.store.model
|
||||
import backend.server.v2.store.routes
|
||||
from . import model as store_model
|
||||
from . import routes as store_routes
|
||||
|
||||
# Using a fixed timestamp for reproducible tests
|
||||
# 2023 date is intentionally used to ensure tests work regardless of current year
|
||||
FIXED_NOW = datetime.datetime(2023, 1, 1, 0, 0, 0)
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(backend.server.v2.store.routes.router)
|
||||
app.include_router(store_routes.router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
@@ -35,23 +35,21 @@ def test_get_agents_defaults(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
|
||||
mocked_value = store_model.StoreAgentsResponse(
|
||||
agents=[],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=0,
|
||||
total_items=0,
|
||||
total_pages=0,
|
||||
page_size=10,
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agents")
|
||||
mock_db_call.return_value = mocked_value
|
||||
response = client.get("/agents")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
|
||||
response.json()
|
||||
)
|
||||
data = store_model.StoreAgentsResponse.model_validate(response.json())
|
||||
assert data.pagination.total_pages == 0
|
||||
assert data.agents == []
|
||||
|
||||
@@ -72,9 +70,9 @@ def test_get_agents_featured(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
|
||||
mocked_value = store_model.StoreAgentsResponse(
|
||||
agents=[
|
||||
backend.server.v2.store.model.StoreAgent(
|
||||
store_model.StoreAgent(
|
||||
slug="featured-agent",
|
||||
agent_name="Featured Agent",
|
||||
agent_image="featured.jpg",
|
||||
@@ -86,20 +84,18 @@ def test_get_agents_featured(
|
||||
rating=4.5,
|
||||
)
|
||||
],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=1,
|
||||
total_items=1,
|
||||
total_pages=1,
|
||||
page_size=20,
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agents")
|
||||
mock_db_call.return_value = mocked_value
|
||||
response = client.get("/agents?featured=true")
|
||||
assert response.status_code == 200
|
||||
data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
|
||||
response.json()
|
||||
)
|
||||
data = store_model.StoreAgentsResponse.model_validate(response.json())
|
||||
assert len(data.agents) == 1
|
||||
assert data.agents[0].slug == "featured-agent"
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
@@ -119,9 +115,9 @@ def test_get_agents_by_creator(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
|
||||
mocked_value = store_model.StoreAgentsResponse(
|
||||
agents=[
|
||||
backend.server.v2.store.model.StoreAgent(
|
||||
store_model.StoreAgent(
|
||||
slug="creator-agent",
|
||||
agent_name="Creator Agent",
|
||||
agent_image="agent.jpg",
|
||||
@@ -133,20 +129,18 @@ def test_get_agents_by_creator(
|
||||
rating=4.0,
|
||||
)
|
||||
],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=1,
|
||||
total_items=1,
|
||||
total_pages=1,
|
||||
page_size=20,
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agents")
|
||||
mock_db_call.return_value = mocked_value
|
||||
response = client.get("/agents?creator=specific-creator")
|
||||
assert response.status_code == 200
|
||||
data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
|
||||
response.json()
|
||||
)
|
||||
data = store_model.StoreAgentsResponse.model_validate(response.json())
|
||||
assert len(data.agents) == 1
|
||||
assert data.agents[0].creator == "specific-creator"
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
@@ -166,9 +160,9 @@ def test_get_agents_sorted(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
|
||||
mocked_value = store_model.StoreAgentsResponse(
|
||||
agents=[
|
||||
backend.server.v2.store.model.StoreAgent(
|
||||
store_model.StoreAgent(
|
||||
slug="top-agent",
|
||||
agent_name="Top Agent",
|
||||
agent_image="top.jpg",
|
||||
@@ -180,20 +174,18 @@ def test_get_agents_sorted(
|
||||
rating=5.0,
|
||||
)
|
||||
],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=1,
|
||||
total_items=1,
|
||||
total_pages=1,
|
||||
page_size=20,
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agents")
|
||||
mock_db_call.return_value = mocked_value
|
||||
response = client.get("/agents?sorted_by=runs")
|
||||
assert response.status_code == 200
|
||||
data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
|
||||
response.json()
|
||||
)
|
||||
data = store_model.StoreAgentsResponse.model_validate(response.json())
|
||||
assert len(data.agents) == 1
|
||||
assert data.agents[0].runs == 1000
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
@@ -213,9 +205,9 @@ def test_get_agents_search(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
|
||||
mocked_value = store_model.StoreAgentsResponse(
|
||||
agents=[
|
||||
backend.server.v2.store.model.StoreAgent(
|
||||
store_model.StoreAgent(
|
||||
slug="search-agent",
|
||||
agent_name="Search Agent",
|
||||
agent_image="search.jpg",
|
||||
@@ -227,20 +219,18 @@ def test_get_agents_search(
|
||||
rating=4.2,
|
||||
)
|
||||
],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=1,
|
||||
total_items=1,
|
||||
total_pages=1,
|
||||
page_size=20,
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agents")
|
||||
mock_db_call.return_value = mocked_value
|
||||
response = client.get("/agents?search_query=specific")
|
||||
assert response.status_code == 200
|
||||
data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
|
||||
response.json()
|
||||
)
|
||||
data = store_model.StoreAgentsResponse.model_validate(response.json())
|
||||
assert len(data.agents) == 1
|
||||
assert "specific" in data.agents[0].description.lower()
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
@@ -260,9 +250,9 @@ def test_get_agents_category(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
|
||||
mocked_value = store_model.StoreAgentsResponse(
|
||||
agents=[
|
||||
backend.server.v2.store.model.StoreAgent(
|
||||
store_model.StoreAgent(
|
||||
slug="category-agent",
|
||||
agent_name="Category Agent",
|
||||
agent_image="category.jpg",
|
||||
@@ -274,20 +264,18 @@ def test_get_agents_category(
|
||||
rating=4.1,
|
||||
)
|
||||
],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=1,
|
||||
total_items=1,
|
||||
total_pages=1,
|
||||
page_size=20,
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agents")
|
||||
mock_db_call.return_value = mocked_value
|
||||
response = client.get("/agents?category=test-category")
|
||||
assert response.status_code == 200
|
||||
data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
|
||||
response.json()
|
||||
)
|
||||
data = store_model.StoreAgentsResponse.model_validate(response.json())
|
||||
assert len(data.agents) == 1
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
snapshot.assert_match(json.dumps(response.json(), indent=2), "agts_category")
|
||||
@@ -306,9 +294,9 @@ def test_get_agents_pagination(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
|
||||
mocked_value = store_model.StoreAgentsResponse(
|
||||
agents=[
|
||||
backend.server.v2.store.model.StoreAgent(
|
||||
store_model.StoreAgent(
|
||||
slug=f"agent-{i}",
|
||||
agent_name=f"Agent {i}",
|
||||
agent_image=f"agent{i}.jpg",
|
||||
@@ -321,20 +309,18 @@ def test_get_agents_pagination(
|
||||
)
|
||||
for i in range(5)
|
||||
],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=2,
|
||||
total_items=15,
|
||||
total_pages=3,
|
||||
page_size=5,
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agents")
|
||||
mock_db_call.return_value = mocked_value
|
||||
response = client.get("/agents?page=2&page_size=5")
|
||||
assert response.status_code == 200
|
||||
data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
|
||||
response.json()
|
||||
)
|
||||
data = store_model.StoreAgentsResponse.model_validate(response.json())
|
||||
assert len(data.agents) == 5
|
||||
assert data.pagination.current_page == 2
|
||||
assert data.pagination.page_size == 5
|
||||
@@ -365,7 +351,7 @@ def test_get_agents_malformed_request(mocker: pytest_mock.MockFixture):
|
||||
assert response.status_code == 422
|
||||
|
||||
# Verify no DB calls were made
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agents")
|
||||
mock_db_call.assert_not_called()
|
||||
|
||||
|
||||
@@ -373,7 +359,7 @@ def test_get_agent_details(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
mocked_value = backend.server.v2.store.model.StoreAgentDetails(
|
||||
mocked_value = store_model.StoreAgentDetails(
|
||||
store_listing_version_id="test-version-id",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
@@ -388,46 +374,46 @@ def test_get_agent_details(
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
versions=["1.0.0", "1.1.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id",
|
||||
last_updated=FIXED_NOW,
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agent_details")
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agent_details")
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/agents/creator1/test-agent")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = backend.server.v2.store.model.StoreAgentDetails.model_validate(
|
||||
response.json()
|
||||
)
|
||||
data = store_model.StoreAgentDetails.model_validate(response.json())
|
||||
assert data.agent_name == "Test Agent"
|
||||
assert data.creator == "creator1"
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
snapshot.assert_match(json.dumps(response.json(), indent=2), "agt_details")
|
||||
mock_db_call.assert_called_once_with(username="creator1", agent_name="test-agent")
|
||||
mock_db_call.assert_called_once_with(
|
||||
username="creator1", agent_name="test-agent", include_changelog=False
|
||||
)
|
||||
|
||||
|
||||
def test_get_creators_defaults(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
mocked_value = backend.server.v2.store.model.CreatorsResponse(
|
||||
mocked_value = store_model.CreatorsResponse(
|
||||
creators=[],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=0,
|
||||
total_items=0,
|
||||
total_pages=0,
|
||||
page_size=10,
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_creators")
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_creators")
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/creators")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = backend.server.v2.store.model.CreatorsResponse.model_validate(
|
||||
response.json()
|
||||
)
|
||||
data = store_model.CreatorsResponse.model_validate(response.json())
|
||||
assert data.pagination.total_pages == 0
|
||||
assert data.creators == []
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
@@ -441,9 +427,9 @@ def test_get_creators_pagination(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
mocked_value = backend.server.v2.store.model.CreatorsResponse(
|
||||
mocked_value = store_model.CreatorsResponse(
|
||||
creators=[
|
||||
backend.server.v2.store.model.Creator(
|
||||
store_model.Creator(
|
||||
name=f"Creator {i}",
|
||||
username=f"creator{i}",
|
||||
description=f"Creator {i} description",
|
||||
@@ -455,22 +441,20 @@ def test_get_creators_pagination(
|
||||
)
|
||||
for i in range(5)
|
||||
],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=2,
|
||||
total_items=15,
|
||||
total_pages=3,
|
||||
page_size=5,
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_creators")
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_creators")
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/creators?page=2&page_size=5")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = backend.server.v2.store.model.CreatorsResponse.model_validate(
|
||||
response.json()
|
||||
)
|
||||
data = store_model.CreatorsResponse.model_validate(response.json())
|
||||
assert len(data.creators) == 5
|
||||
assert data.pagination.current_page == 2
|
||||
assert data.pagination.page_size == 5
|
||||
@@ -495,7 +479,7 @@ def test_get_creators_malformed_request(mocker: pytest_mock.MockFixture):
|
||||
assert response.status_code == 422
|
||||
|
||||
# Verify no DB calls were made
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_creators")
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_creators")
|
||||
mock_db_call.assert_not_called()
|
||||
|
||||
|
||||
@@ -503,7 +487,7 @@ def test_get_creator_details(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
mocked_value = backend.server.v2.store.model.CreatorDetails(
|
||||
mocked_value = store_model.CreatorDetails(
|
||||
name="Test User",
|
||||
username="creator1",
|
||||
description="Test creator description",
|
||||
@@ -513,13 +497,15 @@ def test_get_creator_details(
|
||||
agent_runs=1000,
|
||||
top_categories=["category1", "category2"],
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_creator_details")
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.api.features.store.db.get_store_creator_details"
|
||||
)
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/creator/creator1")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = backend.server.v2.store.model.CreatorDetails.model_validate(response.json())
|
||||
data = store_model.CreatorDetails.model_validate(response.json())
|
||||
assert data.username == "creator1"
|
||||
assert data.name == "Test User"
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
@@ -532,9 +518,10 @@ def test_get_submissions_success(
|
||||
snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mocked_value = backend.server.v2.store.model.StoreSubmissionsResponse(
|
||||
mocked_value = store_model.StoreSubmissionsResponse(
|
||||
submissions=[
|
||||
backend.server.v2.store.model.StoreSubmission(
|
||||
store_model.StoreSubmission(
|
||||
listing_id="test-listing-id",
|
||||
name="Test Agent",
|
||||
description="Test agent description",
|
||||
image_urls=["test.jpg"],
|
||||
@@ -550,22 +537,20 @@ def test_get_submissions_success(
|
||||
categories=["test-category"],
|
||||
)
|
||||
],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=1,
|
||||
total_items=1,
|
||||
total_pages=1,
|
||||
page_size=20,
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_submissions")
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_submissions")
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/submissions")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = backend.server.v2.store.model.StoreSubmissionsResponse.model_validate(
|
||||
response.json()
|
||||
)
|
||||
data = store_model.StoreSubmissionsResponse.model_validate(response.json())
|
||||
assert len(data.submissions) == 1
|
||||
assert data.submissions[0].name == "Test Agent"
|
||||
assert data.pagination.current_page == 1
|
||||
@@ -579,24 +564,22 @@ def test_get_submissions_pagination(
|
||||
snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mocked_value = backend.server.v2.store.model.StoreSubmissionsResponse(
|
||||
mocked_value = store_model.StoreSubmissionsResponse(
|
||||
submissions=[],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=2,
|
||||
total_items=10,
|
||||
total_pages=2,
|
||||
page_size=5,
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_submissions")
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_submissions")
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/submissions?page=2&page_size=5")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = backend.server.v2.store.model.StoreSubmissionsResponse.model_validate(
|
||||
response.json()
|
||||
)
|
||||
data = store_model.StoreSubmissionsResponse.model_validate(response.json())
|
||||
assert data.pagination.current_page == 2
|
||||
assert data.pagination.page_size == 5
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
@@ -618,5 +601,5 @@ def test_get_submissions_malformed_request(mocker: pytest_mock.MockFixture):
|
||||
assert response.status_code == 422
|
||||
|
||||
# Verify no DB calls were made
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_submissions")
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_submissions")
|
||||
mock_db_call.assert_not_called()
|
||||
@@ -8,10 +8,11 @@ from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.store import cache as store_cache
|
||||
from backend.server.v2.store.model import StoreAgent, StoreAgentsResponse
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import cache as store_cache
|
||||
from .model import StoreAgent, StoreAgentsResponse
|
||||
|
||||
|
||||
class TestCacheDeletion:
|
||||
"""Test cache deletion functionality for store routes."""
|
||||
@@ -43,7 +44,7 @@ class TestCacheDeletion:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.server.v2.store.db.get_store_agents",
|
||||
"backend.api.features.store.db.get_store_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_db:
|
||||
@@ -152,7 +153,7 @@ class TestCacheDeletion:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.server.v2.store.db.get_store_agents",
|
||||
"backend.api.features.store.db.get_store_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
):
|
||||
@@ -203,7 +204,7 @@ class TestCacheDeletion:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.server.v2.store.db.get_store_agents",
|
||||
"backend.api.features.store.db.get_store_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_db:
|
||||
@@ -28,12 +28,21 @@ from pydantic import BaseModel
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.analytics
|
||||
import backend.server.v2.library.db as library_db
|
||||
from backend.data import api_key as api_key_db
|
||||
from backend.api.model import (
|
||||
CreateAPIKeyRequest,
|
||||
CreateAPIKeyResponse,
|
||||
CreateGraph,
|
||||
GraphExecutionSource,
|
||||
RequestTopUp,
|
||||
SetGraphActiveVersion,
|
||||
TimezoneResponse,
|
||||
UpdatePermissionsRequest,
|
||||
UpdateTimezoneRequest,
|
||||
UploadFileResponse,
|
||||
)
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.auth import api_key as api_key_db
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
|
||||
from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
@@ -79,19 +88,6 @@ from backend.monitoring.instrumentation import (
|
||||
record_graph_execution,
|
||||
record_graph_operation,
|
||||
)
|
||||
from backend.server.model import (
|
||||
CreateAPIKeyRequest,
|
||||
CreateAPIKeyResponse,
|
||||
CreateGraph,
|
||||
GraphExecutionSource,
|
||||
RequestTopUp,
|
||||
SetGraphActiveVersion,
|
||||
TimezoneResponse,
|
||||
UpdatePermissionsRequest,
|
||||
UpdateTimezoneRequest,
|
||||
UploadFileResponse,
|
||||
)
|
||||
from backend.server.v2.store.model import StoreAgentDetails
|
||||
from backend.util.cache import cached
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
@@ -105,6 +101,10 @@ from backend.util.timezone_utils import (
|
||||
)
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
from .library import db as library_db
|
||||
from .library import model as library_model
|
||||
from .store.model import StoreAgentDetails
|
||||
|
||||
|
||||
def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
|
||||
"""Create standardized file size error response."""
|
||||
@@ -118,76 +118,9 @@ settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def hide_activity_summaries_if_disabled(
|
||||
executions: list[execution_db.GraphExecutionMeta], user_id: str
|
||||
) -> list[execution_db.GraphExecutionMeta]:
|
||||
"""Hide activity summaries and scores if AI_ACTIVITY_STATUS feature is disabled."""
|
||||
if await is_feature_enabled(Flag.AI_ACTIVITY_STATUS, user_id):
|
||||
return executions # Return as-is if feature is enabled
|
||||
|
||||
# Filter out activity features if disabled
|
||||
filtered_executions = []
|
||||
for execution in executions:
|
||||
if execution.stats:
|
||||
filtered_stats = execution.stats.without_activity_features()
|
||||
execution = execution.model_copy(update={"stats": filtered_stats})
|
||||
filtered_executions.append(execution)
|
||||
return filtered_executions
|
||||
|
||||
|
||||
async def hide_activity_summary_if_disabled(
|
||||
execution: execution_db.GraphExecution | execution_db.GraphExecutionWithNodes,
|
||||
user_id: str,
|
||||
) -> execution_db.GraphExecution | execution_db.GraphExecutionWithNodes:
|
||||
"""Hide activity summary and score for a single execution if AI_ACTIVITY_STATUS feature is disabled."""
|
||||
if await is_feature_enabled(Flag.AI_ACTIVITY_STATUS, user_id):
|
||||
return execution # Return as-is if feature is enabled
|
||||
|
||||
# Filter out activity features if disabled
|
||||
if execution.stats:
|
||||
filtered_stats = execution.stats.without_activity_features()
|
||||
return execution.model_copy(update={"stats": filtered_stats})
|
||||
return execution
|
||||
|
||||
|
||||
async def _update_library_agent_version_and_settings(
|
||||
user_id: str, agent_graph: graph_db.GraphModel
|
||||
) -> library_db.library_model.LibraryAgent:
|
||||
# Keep the library agent up to date with the new active version
|
||||
library = await library_db.update_agent_version_in_library(
|
||||
user_id, agent_graph.id, agent_graph.version
|
||||
)
|
||||
# If the graph has HITL node, initialize the setting if it's not already set.
|
||||
if (
|
||||
agent_graph.has_human_in_the_loop
|
||||
and library.settings.human_in_the_loop_safe_mode is None
|
||||
):
|
||||
await library_db.update_library_agent_settings(
|
||||
user_id=user_id,
|
||||
agent_id=library.id,
|
||||
settings=library.settings.model_copy(
|
||||
update={"human_in_the_loop_safe_mode": True}
|
||||
),
|
||||
)
|
||||
return library
|
||||
|
||||
|
||||
# Define the API routes
|
||||
v1_router = APIRouter()
|
||||
|
||||
v1_router.include_router(
|
||||
backend.server.integrations.router.router,
|
||||
prefix="/integrations",
|
||||
tags=["integrations"],
|
||||
)
|
||||
|
||||
v1_router.include_router(
|
||||
backend.server.routers.analytics.router,
|
||||
prefix="/analytics",
|
||||
tags=["analytics"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Auth #############################
|
||||
@@ -953,6 +886,28 @@ async def set_graph_active_version(
|
||||
await on_graph_deactivate(current_active_graph, user_id=user_id)
|
||||
|
||||
|
||||
async def _update_library_agent_version_and_settings(
|
||||
user_id: str, agent_graph: graph_db.GraphModel
|
||||
) -> library_model.LibraryAgent:
|
||||
# Keep the library agent up to date with the new active version
|
||||
library = await library_db.update_agent_version_in_library(
|
||||
user_id, agent_graph.id, agent_graph.version
|
||||
)
|
||||
# If the graph has HITL node, initialize the setting if it's not already set.
|
||||
if (
|
||||
agent_graph.has_human_in_the_loop
|
||||
and library.settings.human_in_the_loop_safe_mode is None
|
||||
):
|
||||
await library_db.update_library_agent_settings(
|
||||
user_id=user_id,
|
||||
agent_id=library.id,
|
||||
settings=library.settings.model_copy(
|
||||
update={"human_in_the_loop_safe_mode": True}
|
||||
),
|
||||
)
|
||||
return library
|
||||
|
||||
|
||||
@v1_router.patch(
|
||||
path="/graphs/{graph_id}/settings",
|
||||
summary="Update graph settings",
|
||||
@@ -1155,6 +1110,23 @@ async def list_graph_executions(
|
||||
)
|
||||
|
||||
|
||||
async def hide_activity_summaries_if_disabled(
|
||||
executions: list[execution_db.GraphExecutionMeta], user_id: str
|
||||
) -> list[execution_db.GraphExecutionMeta]:
|
||||
"""Hide activity summaries and scores if AI_ACTIVITY_STATUS feature is disabled."""
|
||||
if await is_feature_enabled(Flag.AI_ACTIVITY_STATUS, user_id):
|
||||
return executions # Return as-is if feature is enabled
|
||||
|
||||
# Filter out activity features if disabled
|
||||
filtered_executions = []
|
||||
for execution in executions:
|
||||
if execution.stats:
|
||||
filtered_stats = execution.stats.without_activity_features()
|
||||
execution = execution.model_copy(update={"stats": filtered_stats})
|
||||
filtered_executions.append(execution)
|
||||
return filtered_executions
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}",
|
||||
summary="Get execution details",
|
||||
@@ -1197,6 +1169,21 @@ async def get_graph_execution(
|
||||
return result
|
||||
|
||||
|
||||
async def hide_activity_summary_if_disabled(
|
||||
execution: execution_db.GraphExecution | execution_db.GraphExecutionWithNodes,
|
||||
user_id: str,
|
||||
) -> execution_db.GraphExecution | execution_db.GraphExecutionWithNodes:
|
||||
"""Hide activity summary and score for a single execution if AI_ACTIVITY_STATUS feature is disabled."""
|
||||
if await is_feature_enabled(Flag.AI_ACTIVITY_STATUS, user_id):
|
||||
return execution # Return as-is if feature is enabled
|
||||
|
||||
# Filter out activity features if disabled
|
||||
if execution.stats:
|
||||
filtered_stats = execution.stats.without_activity_features()
|
||||
return execution.model_copy(update={"stats": filtered_stats})
|
||||
return execution
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
path="/executions/{graph_exec_id}",
|
||||
summary="Delete graph execution",
|
||||
@@ -1257,7 +1244,7 @@ async def enable_execution_sharing(
|
||||
)
|
||||
|
||||
# Return the share URL
|
||||
frontend_url = Settings().config.frontend_base_url or "http://localhost:3000"
|
||||
frontend_url = settings.config.frontend_base_url or "http://localhost:3000"
|
||||
share_url = f"{frontend_url}/share/{share_token}"
|
||||
|
||||
return ShareResponse(share_url=share_url, share_token=share_token)
|
||||
@@ -11,13 +11,13 @@ import starlette.datastructures
|
||||
from fastapi import HTTPException, UploadFile
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.routers.v1 as v1_routes
|
||||
from backend.data.credit import AutoTopUpConfig
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.server.routers.v1 import upload_file
|
||||
|
||||
from .v1 import upload_file, v1_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(v1_routes.v1_router)
|
||||
app.include_router(v1_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
@@ -50,7 +50,7 @@ def test_get_or_create_user_route(
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_or_create_user",
|
||||
"backend.api.features.v1.get_or_create_user",
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
@@ -71,7 +71,7 @@ def test_update_user_email_route(
|
||||
) -> None:
|
||||
"""Test update user email endpoint"""
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.update_user_email",
|
||||
"backend.api.features.v1.update_user_email",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
@@ -107,7 +107,7 @@ def test_get_graph_blocks(
|
||||
|
||||
# Mock get_blocks
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_blocks",
|
||||
"backend.api.features.v1.get_blocks",
|
||||
return_value={"test-block": lambda: mock_block},
|
||||
)
|
||||
|
||||
@@ -146,7 +146,7 @@ def test_execute_graph_block(
|
||||
mock_block.execute = mock_execute
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_block",
|
||||
"backend.api.features.v1.get_block",
|
||||
return_value=mock_block,
|
||||
)
|
||||
|
||||
@@ -155,7 +155,7 @@ def test_execute_graph_block(
|
||||
mock_user.timezone = "UTC"
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_user_by_id",
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
@@ -181,7 +181,7 @@ def test_execute_graph_block_not_found(
|
||||
) -> None:
|
||||
"""Test execute block with non-existent block"""
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_block",
|
||||
"backend.api.features.v1.get_block",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
@@ -200,7 +200,7 @@ def test_get_user_credits(
|
||||
mock_credit_model = Mock()
|
||||
mock_credit_model.get_credits = AsyncMock(return_value=1000)
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_user_credit_model",
|
||||
"backend.api.features.v1.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
@@ -227,7 +227,7 @@ def test_request_top_up(
|
||||
return_value="https://checkout.example.com/session123"
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_user_credit_model",
|
||||
"backend.api.features.v1.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
@@ -254,7 +254,7 @@ def test_get_auto_top_up(
|
||||
mock_config = AutoTopUpConfig(threshold=100, amount=500)
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_auto_top_up",
|
||||
"backend.api.features.v1.get_auto_top_up",
|
||||
return_value=mock_config,
|
||||
)
|
||||
|
||||
@@ -279,7 +279,7 @@ def test_configure_auto_top_up(
|
||||
"""Test configure auto top-up endpoint - this test would have caught the enum casting bug"""
|
||||
# Mock the set_auto_top_up function to avoid database operations
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.set_auto_top_up",
|
||||
"backend.api.features.v1.set_auto_top_up",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
@@ -289,7 +289,7 @@ def test_configure_auto_top_up(
|
||||
mock_credit_model.top_up_credits.return_value = None
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_user_credit_model",
|
||||
"backend.api.features.v1.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
@@ -311,7 +311,7 @@ def test_configure_auto_top_up_validation_errors(
|
||||
) -> None:
|
||||
"""Test configure auto top-up endpoint validation"""
|
||||
# Mock set_auto_top_up to avoid database operations for successful case
|
||||
mocker.patch("backend.server.routers.v1.set_auto_top_up")
|
||||
mocker.patch("backend.api.features.v1.set_auto_top_up")
|
||||
|
||||
# Mock credit model to avoid Stripe API calls for the successful case
|
||||
mock_credit_model = mocker.AsyncMock()
|
||||
@@ -319,7 +319,7 @@ def test_configure_auto_top_up_validation_errors(
|
||||
mock_credit_model.top_up_credits.return_value = None
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_user_credit_model",
|
||||
"backend.api.features.v1.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
@@ -393,7 +393,7 @@ def test_get_graph(
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.graph_db.get_graph",
|
||||
"backend.api.features.v1.graph_db.get_graph",
|
||||
return_value=mock_graph,
|
||||
)
|
||||
|
||||
@@ -415,7 +415,7 @@ def test_get_graph_not_found(
|
||||
) -> None:
|
||||
"""Test get graph with non-existent ID"""
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.graph_db.get_graph",
|
||||
"backend.api.features.v1.graph_db.get_graph",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
@@ -443,15 +443,15 @@ def test_delete_graph(
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.graph_db.get_graph",
|
||||
"backend.api.features.v1.graph_db.get_graph",
|
||||
return_value=mock_graph,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.on_graph_deactivate",
|
||||
"backend.api.features.v1.on_graph_deactivate",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.graph_db.delete_graph",
|
||||
"backend.api.features.v1.graph_db.delete_graph",
|
||||
return_value=3, # Number of versions deleted
|
||||
)
|
||||
|
||||
@@ -498,8 +498,8 @@ async def test_upload_file_success(test_user_id: str):
|
||||
)
|
||||
|
||||
# Mock dependencies
|
||||
with patch("backend.server.routers.v1.scan_content_safe") as mock_scan, patch(
|
||||
"backend.server.routers.v1.get_cloud_storage_handler"
|
||||
with patch("backend.api.features.v1.scan_content_safe") as mock_scan, patch(
|
||||
"backend.api.features.v1.get_cloud_storage_handler"
|
||||
) as mock_handler_getter:
|
||||
|
||||
mock_scan.return_value = None
|
||||
@@ -550,8 +550,8 @@ async def test_upload_file_no_filename(test_user_id: str):
|
||||
),
|
||||
)
|
||||
|
||||
with patch("backend.server.routers.v1.scan_content_safe") as mock_scan, patch(
|
||||
"backend.server.routers.v1.get_cloud_storage_handler"
|
||||
with patch("backend.api.features.v1.scan_content_safe") as mock_scan, patch(
|
||||
"backend.api.features.v1.get_cloud_storage_handler"
|
||||
) as mock_handler_getter:
|
||||
|
||||
mock_scan.return_value = None
|
||||
@@ -610,7 +610,7 @@ async def test_upload_file_virus_scan_failure(test_user_id: str):
|
||||
headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
|
||||
)
|
||||
|
||||
with patch("backend.server.routers.v1.scan_content_safe") as mock_scan:
|
||||
with patch("backend.api.features.v1.scan_content_safe") as mock_scan:
|
||||
# Mock virus scan to raise exception
|
||||
mock_scan.side_effect = RuntimeError("Virus detected!")
|
||||
|
||||
@@ -631,8 +631,8 @@ async def test_upload_file_cloud_storage_failure(test_user_id: str):
|
||||
headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
|
||||
)
|
||||
|
||||
with patch("backend.server.routers.v1.scan_content_safe") as mock_scan, patch(
|
||||
"backend.server.routers.v1.get_cloud_storage_handler"
|
||||
with patch("backend.api.features.v1.scan_content_safe") as mock_scan, patch(
|
||||
"backend.api.features.v1.get_cloud_storage_handler"
|
||||
) as mock_handler_getter:
|
||||
|
||||
mock_scan.return_value = None
|
||||
@@ -678,8 +678,8 @@ async def test_upload_file_gcs_not_configured_fallback(test_user_id: str):
|
||||
headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
|
||||
)
|
||||
|
||||
with patch("backend.server.routers.v1.scan_content_safe") as mock_scan, patch(
|
||||
"backend.server.routers.v1.get_cloud_storage_handler"
|
||||
with patch("backend.api.features.v1.scan_content_safe") as mock_scan, patch(
|
||||
"backend.api.features.v1.get_cloud_storage_handler"
|
||||
) as mock_handler_getter:
|
||||
|
||||
mock_scan.return_value = None
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user