mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-06 22:03:59 -05:00
Merge origin/dev into hackathon/copilot
Resolved conflicts from api restructuring: - backend/server/v2/* -> backend/api/features/* - Updated imports to use new paths - Kept chat/copilot functionality with new structure - Accepted openapi.json from dev (regenerate after merge) - Resolved useCredentialsInput naming conflict (singular) - NavbarView.tsx merged into Navbar.tsx
This commit is contained in:
@@ -11,7 +11,7 @@ jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
# operations-per-run: 5000
|
||||
stale-issue-message: >
|
||||
|
||||
2
.github/workflows/repo-pr-label.yml
vendored
2
.github/workflows/repo-pr-label.yml
vendored
@@ -61,6 +61,6 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v5
|
||||
- uses: actions/labeler@v6
|
||||
with:
|
||||
sync-labels: true
|
||||
|
||||
@@ -57,6 +57,9 @@ class APIKeySmith:
|
||||
|
||||
def hash_key(self, raw_key: str) -> tuple[str, str]:
|
||||
"""Migrate a legacy hash to secure hash format."""
|
||||
if not raw_key.startswith(self.PREFIX):
|
||||
raise ValueError("Key without 'agpt_' prefix would fail validation")
|
||||
|
||||
salt = self._generate_salt()
|
||||
hash = self._hash_key_with_salt(raw_key, salt)
|
||||
return hash, salt.hex()
|
||||
|
||||
@@ -1,29 +1,25 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from .jwt_utils import bearer_jwt_auth
|
||||
|
||||
|
||||
def add_auth_responses_to_openapi(app: FastAPI) -> None:
|
||||
"""
|
||||
Set up custom OpenAPI schema generation that adds 401 responses
|
||||
Patch a FastAPI instance's `openapi()` method to add 401 responses
|
||||
to all authenticated endpoints.
|
||||
|
||||
This is needed when using HTTPBearer with auto_error=False to get proper
|
||||
401 responses instead of 403, but FastAPI only automatically adds security
|
||||
responses when auto_error=True.
|
||||
"""
|
||||
# Wrap current method to allow stacking OpenAPI schema modifiers like this
|
||||
wrapped_openapi = app.openapi
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
openapi_schema = wrapped_openapi()
|
||||
|
||||
# Add 401 response to all endpoints that have security requirements
|
||||
for path, methods in openapi_schema["paths"].items():
|
||||
|
||||
@@ -108,7 +108,7 @@ import fastapi.testclient
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.server.v2.myroute import router
|
||||
from backend.api.features.myroute import router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
@@ -149,7 +149,7 @@ These provide the easiest way to set up authentication mocking in test modules:
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
from backend.server.v2.myroute import router
|
||||
from backend.api.features.myroute import router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
@@ -3,12 +3,12 @@ from typing import Dict, Set
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from backend.api.model import NotificationPayload, WSMessage, WSMethod
|
||||
from backend.data.execution import (
|
||||
ExecutionEventType,
|
||||
GraphExecutionEvent,
|
||||
NodeExecutionEvent,
|
||||
)
|
||||
from backend.server.model import NotificationPayload, WSMessage, WSMethod
|
||||
|
||||
_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
|
||||
ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
|
||||
@@ -4,13 +4,13 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
from fastapi import WebSocket
|
||||
|
||||
from backend.api.conn_manager import ConnectionManager
|
||||
from backend.api.model import NotificationPayload, WSMessage, WSMethod
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionEvent,
|
||||
NodeExecutionEvent,
|
||||
)
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.model import NotificationPayload, WSMessage, WSMethod
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
25
autogpt_platform/backend/backend/api/external/fastapi_app.py
vendored
Normal file
25
autogpt_platform/backend/backend/api/external/fastapi_app.py
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.api.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
|
||||
from .v1.routes import v1_router
|
||||
|
||||
external_api = FastAPI(
|
||||
title="AutoGPT External API",
|
||||
description="External API for AutoGPT integrations",
|
||||
docs_url="/docs",
|
||||
version="1.0",
|
||||
)
|
||||
|
||||
external_api.add_middleware(SecurityHeadersMiddleware)
|
||||
external_api.include_router(v1_router, prefix="/v1")
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
external_api,
|
||||
service_name="external-api",
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=True,
|
||||
)
|
||||
107
autogpt_platform/backend/backend/api/external/middleware.py
vendored
Normal file
107
autogpt_platform/backend/backend/api/external/middleware.py
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
from fastapi import HTTPException, Security, status
|
||||
from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer
|
||||
from prisma.enums import APIKeyPermission
|
||||
|
||||
from backend.data.auth.api_key import APIKeyInfo, validate_api_key
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.auth.oauth import (
|
||||
InvalidClientError,
|
||||
InvalidTokenError,
|
||||
OAuthAccessTokenInfo,
|
||||
validate_access_token,
|
||||
)
|
||||
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
bearer_auth = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
|
||||
"""Middleware for API key authentication only"""
|
||||
if api_key is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing API key"
|
||||
)
|
||||
|
||||
api_key_obj = await validate_api_key(api_key)
|
||||
|
||||
if not api_key_obj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
||||
)
|
||||
|
||||
return api_key_obj
|
||||
|
||||
|
||||
async def require_access_token(
|
||||
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
|
||||
) -> OAuthAccessTokenInfo:
|
||||
"""Middleware for OAuth access token authentication only"""
|
||||
if bearer is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing Authorization header",
|
||||
)
|
||||
|
||||
try:
|
||||
token_info, _ = await validate_access_token(bearer.credentials)
|
||||
except (InvalidClientError, InvalidTokenError) as e:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
|
||||
|
||||
return token_info
|
||||
|
||||
|
||||
async def require_auth(
|
||||
api_key: str | None = Security(api_key_header),
|
||||
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
|
||||
) -> APIAuthorizationInfo:
|
||||
"""
|
||||
Unified authentication middleware supporting both API keys and OAuth tokens.
|
||||
|
||||
Supports two authentication methods, which are checked in order:
|
||||
1. X-API-Key header (existing API key authentication)
|
||||
2. Authorization: Bearer <token> header (OAuth access token)
|
||||
|
||||
Returns:
|
||||
APIAuthorizationInfo: base class of both APIKeyInfo and OAuthAccessTokenInfo.
|
||||
"""
|
||||
# Try API key first
|
||||
if api_key is not None:
|
||||
api_key_info = await validate_api_key(api_key)
|
||||
if api_key_info:
|
||||
return api_key_info
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
||||
)
|
||||
|
||||
# Try OAuth bearer token
|
||||
if bearer is not None:
|
||||
try:
|
||||
token_info, _ = await validate_access_token(bearer.credentials)
|
||||
return token_info
|
||||
except (InvalidClientError, InvalidTokenError) as e:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
|
||||
|
||||
# No credentials provided
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing authentication. Provide API key or access token.",
|
||||
)
|
||||
|
||||
|
||||
def require_permission(permission: APIKeyPermission):
|
||||
"""
|
||||
Dependency function for checking specific permissions
|
||||
(works with API keys and OAuth tokens)
|
||||
"""
|
||||
|
||||
async def check_permission(
|
||||
auth: APIAuthorizationInfo = Security(require_auth),
|
||||
) -> APIAuthorizationInfo:
|
||||
if permission not in auth.scopes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing required permission: {permission.value}",
|
||||
)
|
||||
return auth
|
||||
|
||||
return check_permission
|
||||
@@ -16,7 +16,9 @@ from fastapi import APIRouter, Body, HTTPException, Path, Security, status
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.integrations.models import get_all_provider_names
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
@@ -28,8 +30,6 @@ from backend.data.model import (
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.server.integrations.models import get_all_provider_names
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -255,7 +255,7 @@ def _get_oauth_handler_for_external(
|
||||
|
||||
@integrations_router.get("/providers", response_model=list[ProviderInfo])
|
||||
async def list_providers(
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[ProviderInfo]:
|
||||
@@ -319,7 +319,7 @@ async def list_providers(
|
||||
async def initiate_oauth(
|
||||
provider: Annotated[str, Path(title="The OAuth provider")],
|
||||
request: OAuthInitiateRequest,
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||
),
|
||||
) -> OAuthInitiateResponse:
|
||||
@@ -337,7 +337,10 @@ async def initiate_oauth(
|
||||
if not validate_callback_url(request.callback_url):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Callback URL origin is not allowed. Allowed origins: {settings.config.external_oauth_callback_origins}",
|
||||
detail=(
|
||||
f"Callback URL origin is not allowed. "
|
||||
f"Allowed origins: {settings.config.external_oauth_callback_origins}",
|
||||
),
|
||||
)
|
||||
|
||||
# Validate provider
|
||||
@@ -359,13 +362,15 @@ async def initiate_oauth(
|
||||
)
|
||||
|
||||
# Store state token with external flow metadata
|
||||
# Note: initiated_by_api_key_id is only available for API key auth, not OAuth
|
||||
api_key_id = getattr(auth, "id", None) if auth.type == "api_key" else None
|
||||
state_token, code_challenge = await creds_manager.store.store_state_token(
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
provider=provider if isinstance(provider_name, str) else provider_name.value,
|
||||
scopes=request.scopes,
|
||||
callback_url=request.callback_url,
|
||||
state_metadata=request.state_metadata,
|
||||
initiated_by_api_key_id=api_key.id,
|
||||
initiated_by_api_key_id=api_key_id,
|
||||
)
|
||||
|
||||
# Build login URL
|
||||
@@ -393,7 +398,7 @@ async def initiate_oauth(
|
||||
async def complete_oauth(
|
||||
provider: Annotated[str, Path(title="The OAuth provider")],
|
||||
request: OAuthCompleteRequest,
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||
),
|
||||
) -> OAuthCompleteResponse:
|
||||
@@ -406,7 +411,7 @@ async def complete_oauth(
|
||||
"""
|
||||
# Verify state token
|
||||
valid_state = await creds_manager.store.verify_state_token(
|
||||
api_key.user_id, request.state_token, provider
|
||||
auth.user_id, request.state_token, provider
|
||||
)
|
||||
|
||||
if not valid_state:
|
||||
@@ -453,7 +458,7 @@ async def complete_oauth(
|
||||
)
|
||||
|
||||
# Store credentials
|
||||
await creds_manager.create(api_key.user_id, credentials)
|
||||
await creds_manager.create(auth.user_id, credentials)
|
||||
|
||||
logger.info(f"Successfully completed external OAuth for provider {provider}")
|
||||
|
||||
@@ -470,7 +475,7 @@ async def complete_oauth(
|
||||
|
||||
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
|
||||
async def list_credentials(
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
@@ -479,7 +484,7 @@ async def list_credentials(
|
||||
|
||||
Returns metadata about each credential without exposing sensitive tokens.
|
||||
"""
|
||||
credentials = await creds_manager.store.get_all_creds(api_key.user_id)
|
||||
credentials = await creds_manager.store.get_all_creds(auth.user_id)
|
||||
return [
|
||||
CredentialSummary(
|
||||
id=cred.id,
|
||||
@@ -499,7 +504,7 @@ async def list_credentials(
|
||||
)
|
||||
async def list_credentials_by_provider(
|
||||
provider: Annotated[str, Path(title="The provider to list credentials for")],
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
@@ -507,7 +512,7 @@ async def list_credentials_by_provider(
|
||||
List credentials for a specific provider.
|
||||
"""
|
||||
credentials = await creds_manager.store.get_creds_by_provider(
|
||||
api_key.user_id, provider
|
||||
auth.user_id, provider
|
||||
)
|
||||
return [
|
||||
CredentialSummary(
|
||||
@@ -536,7 +541,7 @@ async def create_credential(
|
||||
CreateUserPasswordCredentialRequest,
|
||||
CreateHostScopedCredentialRequest,
|
||||
] = Body(..., discriminator="type"),
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||
),
|
||||
) -> CreateCredentialResponse:
|
||||
@@ -591,7 +596,7 @@ async def create_credential(
|
||||
|
||||
# Store credentials
|
||||
try:
|
||||
await creds_manager.create(api_key.user_id, credentials)
|
||||
await creds_manager.create(auth.user_id, credentials)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store credentials: {e}")
|
||||
raise HTTPException(
|
||||
@@ -623,7 +628,7 @@ class DeleteCredentialResponse(BaseModel):
|
||||
async def delete_credential(
|
||||
provider: Annotated[str, Path(title="The provider")],
|
||||
cred_id: Annotated[str, Path(title="The credential ID to delete")],
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.DELETE_INTEGRATIONS)
|
||||
),
|
||||
) -> DeleteCredentialResponse:
|
||||
@@ -634,7 +639,7 @@ async def delete_credential(
|
||||
use the main API's delete endpoint which handles webhook cleanup and
|
||||
token revocation.
|
||||
"""
|
||||
creds = await creds_manager.store.get_creds_by_id(api_key.user_id, cred_id)
|
||||
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
@@ -645,6 +650,6 @@ async def delete_credential(
|
||||
detail="Credentials do not match the specified provider",
|
||||
)
|
||||
|
||||
await creds_manager.delete(api_key.user_id, cred_id)
|
||||
await creds_manager.delete(auth.user_id, cred_id)
|
||||
|
||||
return DeleteCredentialResponse(deleted=True, credentials_id=cred_id)
|
||||
@@ -5,46 +5,60 @@ from typing import Annotated, Any, Literal, Optional, Sequence
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.data.block
|
||||
import backend.server.v2.store.cache as store_cache
|
||||
import backend.server.v2.store.model as store_model
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.data import user as user_db
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .integrations import integrations_router
|
||||
from .tools import tools_router
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
v1_router = APIRouter()
|
||||
|
||||
|
||||
class NodeOutput(TypedDict):
|
||||
key: str
|
||||
value: Any
|
||||
v1_router.include_router(integrations_router)
|
||||
v1_router.include_router(tools_router)
|
||||
|
||||
|
||||
class ExecutionNode(TypedDict):
|
||||
node_id: str
|
||||
input: Any
|
||||
output: dict[str, Any]
|
||||
class UserInfoResponse(BaseModel):
|
||||
id: str
|
||||
name: Optional[str]
|
||||
email: str
|
||||
timezone: str = Field(
|
||||
description="The user's last known timezone (e.g. 'Europe/Amsterdam'), "
|
||||
"or 'not-set' if not set"
|
||||
)
|
||||
|
||||
|
||||
class ExecutionNodeOutput(TypedDict):
|
||||
node_id: str
|
||||
outputs: list[NodeOutput]
|
||||
@v1_router.get(
|
||||
path="/me",
|
||||
tags=["user", "meta"],
|
||||
)
|
||||
async def get_user_info(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.IDENTITY)
|
||||
),
|
||||
) -> UserInfoResponse:
|
||||
user = await user_db.get_user_by_id(auth.user_id)
|
||||
|
||||
|
||||
class GraphExecutionResult(TypedDict):
|
||||
execution_id: str
|
||||
status: str
|
||||
nodes: list[ExecutionNode]
|
||||
output: Optional[list[dict[str, str]]]
|
||||
return UserInfoResponse(
|
||||
id=user.id,
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
timezone=user.timezone,
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -65,7 +79,9 @@ async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
async def execute_graph_block(
|
||||
block_id: str,
|
||||
data: BlockInput,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.EXECUTE_BLOCK)
|
||||
),
|
||||
) -> CompletedBlockOutput:
|
||||
obj = backend.data.block.get_block(block_id)
|
||||
if not obj:
|
||||
@@ -85,12 +101,14 @@ async def execute_graph(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.EXECUTE_GRAPH)
|
||||
),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
graph_exec = await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
inputs=node_input,
|
||||
graph_version=graph_version,
|
||||
)
|
||||
@@ -100,6 +118,19 @@ async def execute_graph(
|
||||
raise HTTPException(status_code=400, detail=msg)
|
||||
|
||||
|
||||
class ExecutionNode(TypedDict):
|
||||
node_id: str
|
||||
input: Any
|
||||
output: dict[str, Any]
|
||||
|
||||
|
||||
class GraphExecutionResult(TypedDict):
|
||||
execution_id: str
|
||||
status: str
|
||||
nodes: list[ExecutionNode]
|
||||
output: Optional[list[dict[str, str]]]
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}/results",
|
||||
tags=["graphs"],
|
||||
@@ -107,10 +138,12 @@ async def execute_graph(
|
||||
async def get_graph_execution_results(
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.READ_GRAPH)),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_GRAPH)
|
||||
),
|
||||
) -> GraphExecutionResult:
|
||||
graph_exec = await execution_db.get_graph_execution(
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
execution_id=graph_exec_id,
|
||||
include_node_executions=True,
|
||||
)
|
||||
@@ -122,7 +155,7 @@ async def get_graph_execution_results(
|
||||
if not await graph_db.get_graph(
|
||||
graph_id=graph_exec.graph_id,
|
||||
version=graph_exec.graph_version,
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
):
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
@@ -14,19 +14,19 @@ from fastapi import APIRouter, Security
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools import find_agent_tool, run_agent_tool
|
||||
from backend.server.v2.chat.tools.models import ToolResponseBase
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
|
||||
from backend.api.features.chat.tools.models import ToolResponseBase
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
tools_router = APIRouter(prefix="/tools", tags=["tools"])
|
||||
|
||||
# Note: We use Security() as a function parameter dependency (api_key: APIKeyInfo = Security(...))
|
||||
# Note: We use Security() as a function parameter dependency (auth: APIAuthorizationInfo = Security(...))
|
||||
# rather than in the decorator's dependencies= list. This avoids duplicate permission checks
|
||||
# while still enforcing auth AND giving us access to the api_key for extracting user_id.
|
||||
# while still enforcing auth AND giving us access to auth for extracting user_id.
|
||||
|
||||
|
||||
# Request models
|
||||
@@ -80,7 +80,9 @@ def _create_ephemeral_session(user_id: str | None) -> ChatSession:
|
||||
)
|
||||
async def find_agent(
|
||||
request: FindAgentRequest,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.USE_TOOLS)
|
||||
),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Search for agents in the marketplace based on capabilities and user needs.
|
||||
@@ -91,9 +93,9 @@ async def find_agent(
|
||||
Returns:
|
||||
List of matching agents or no results response
|
||||
"""
|
||||
session = _create_ephemeral_session(api_key.user_id)
|
||||
session = _create_ephemeral_session(auth.user_id)
|
||||
result = await find_agent_tool._execute(
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
session=session,
|
||||
query=request.query,
|
||||
)
|
||||
@@ -105,7 +107,9 @@ async def find_agent(
|
||||
)
|
||||
async def run_agent(
|
||||
request: RunAgentRequest,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.USE_TOOLS)
|
||||
),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run or schedule an agent from the marketplace.
|
||||
@@ -129,9 +133,9 @@ async def run_agent(
|
||||
- execution_started: If agent was run or scheduled successfully
|
||||
- error: If something went wrong
|
||||
"""
|
||||
session = _create_ephemeral_session(api_key.user_id)
|
||||
session = _create_ephemeral_session(auth.user_id)
|
||||
result = await run_agent_tool._execute(
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
session=session,
|
||||
username_agent_slug=request.username_agent_slug,
|
||||
inputs=request.inputs,
|
||||
@@ -6,9 +6,10 @@ from fastapi import APIRouter, Body, Security
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import admin_get_user_history, get_user_credit_model
|
||||
from backend.server.v2.admin.model import AddUserCreditsResponse, UserHistoryResponse
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
from .model import AddUserCreditsResponse, UserHistoryResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -9,14 +9,15 @@ import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.v2.admin.credit_admin_routes as credit_admin_routes
|
||||
import backend.server.v2.admin.model as admin_model
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .credit_admin_routes import router as credit_admin_router
|
||||
from .model import UserHistoryResponse
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(credit_admin_routes.router)
|
||||
app.include_router(credit_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
@@ -30,7 +31,7 @@ def setup_app_admin_auth(mock_jwt_admin):
|
||||
|
||||
|
||||
def test_add_user_credits_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
admin_user_id: str,
|
||||
target_user_id: str,
|
||||
@@ -42,7 +43,7 @@ def test_add_user_credits_success(
|
||||
return_value=(1500, "transaction-123-uuid")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
|
||||
"backend.api.features.admin.credit_admin_routes.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
@@ -84,7 +85,7 @@ def test_add_user_credits_success(
|
||||
|
||||
|
||||
def test_add_user_credits_negative_amount(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test credit deduction by admin (negative amount)"""
|
||||
@@ -94,7 +95,7 @@ def test_add_user_credits_negative_amount(
|
||||
return_value=(200, "transaction-456-uuid")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
|
||||
"backend.api.features.admin.credit_admin_routes.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
@@ -119,12 +120,12 @@ def test_add_user_credits_negative_amount(
|
||||
|
||||
|
||||
def test_get_user_history_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful retrieval of user credit history"""
|
||||
# Mock the admin_get_user_history function
|
||||
mock_history_response = admin_model.UserHistoryResponse(
|
||||
mock_history_response = UserHistoryResponse(
|
||||
history=[
|
||||
UserTransaction(
|
||||
user_id="user-1",
|
||||
@@ -150,7 +151,7 @@ def test_get_user_history_success(
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.admin_get_user_history",
|
||||
"backend.api.features.admin.credit_admin_routes.admin_get_user_history",
|
||||
return_value=mock_history_response,
|
||||
)
|
||||
|
||||
@@ -170,12 +171,12 @@ def test_get_user_history_success(
|
||||
|
||||
|
||||
def test_get_user_history_with_filters(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test user credit history with search and filter parameters"""
|
||||
# Mock the admin_get_user_history function
|
||||
mock_history_response = admin_model.UserHistoryResponse(
|
||||
mock_history_response = UserHistoryResponse(
|
||||
history=[
|
||||
UserTransaction(
|
||||
user_id="user-3",
|
||||
@@ -194,7 +195,7 @@ def test_get_user_history_with_filters(
|
||||
)
|
||||
|
||||
mock_get_history = mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.admin_get_user_history",
|
||||
"backend.api.features.admin.credit_admin_routes.admin_get_user_history",
|
||||
return_value=mock_history_response,
|
||||
)
|
||||
|
||||
@@ -230,12 +231,12 @@ def test_get_user_history_with_filters(
|
||||
|
||||
|
||||
def test_get_user_history_empty_results(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test user credit history with no results"""
|
||||
# Mock empty history response
|
||||
mock_history_response = admin_model.UserHistoryResponse(
|
||||
mock_history_response = UserHistoryResponse(
|
||||
history=[],
|
||||
pagination=Pagination(
|
||||
total_items=0,
|
||||
@@ -246,7 +247,7 @@ def test_get_user_history_empty_results(
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.admin_get_user_history",
|
||||
"backend.api.features.admin.credit_admin_routes.admin_get_user_history",
|
||||
return_value=mock_history_response,
|
||||
)
|
||||
|
||||
@@ -7,10 +7,10 @@ import fastapi
|
||||
import fastapi.responses
|
||||
import prisma.enums
|
||||
|
||||
import backend.server.v2.store.cache as store_cache
|
||||
import backend.server.v2.store.db
|
||||
import backend.server.v2.store.embeddings as store_embeddings
|
||||
import backend.server.v2.store.model
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.embeddings as store_embeddings
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.util.json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,7 +25,7 @@ router = fastapi.APIRouter(
|
||||
@router.get(
|
||||
"/listings",
|
||||
summary="Get Admin Listings History",
|
||||
response_model=backend.server.v2.store.model.StoreListingsWithVersionsResponse,
|
||||
response_model=store_model.StoreListingsWithVersionsResponse,
|
||||
)
|
||||
async def get_admin_listings_with_versions(
|
||||
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
|
||||
@@ -49,7 +49,7 @@ async def get_admin_listings_with_versions(
|
||||
StoreListingsWithVersionsResponse with listings and their versions
|
||||
"""
|
||||
try:
|
||||
listings = await backend.server.v2.store.db.get_admin_listings_with_versions(
|
||||
listings = await store_db.get_admin_listings_with_versions(
|
||||
status=status,
|
||||
search_query=search,
|
||||
page=page,
|
||||
@@ -69,11 +69,11 @@ async def get_admin_listings_with_versions(
|
||||
@router.post(
|
||||
"/submissions/{store_listing_version_id}/review",
|
||||
summary="Review Store Submission",
|
||||
response_model=backend.server.v2.store.model.StoreSubmission,
|
||||
response_model=store_model.StoreSubmission,
|
||||
)
|
||||
async def review_submission(
|
||||
store_listing_version_id: str,
|
||||
request: backend.server.v2.store.model.ReviewSubmissionRequest,
|
||||
request: store_model.ReviewSubmissionRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
@@ -88,12 +88,10 @@ async def review_submission(
|
||||
StoreSubmission with updated review information
|
||||
"""
|
||||
try:
|
||||
already_approved = (
|
||||
await backend.server.v2.store.db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
already_approved = await store_db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
submission = await backend.server.v2.store.db.review_store_submission(
|
||||
submission = await store_db.review_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
external_comments=request.comments,
|
||||
@@ -137,7 +135,7 @@ async def admin_download_agent_file(
|
||||
Raises:
|
||||
HTTPException: If the agent is not found or an unexpected error occurs.
|
||||
"""
|
||||
graph_data = await backend.server.v2.store.db.get_agent_as_admin(
|
||||
graph_data = await store_db.get_agent_as_admin(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
@@ -6,10 +6,11 @@ from typing import Annotated
|
||||
import fastapi
|
||||
import pydantic
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from autogpt_libs.auth.dependencies import requires_user
|
||||
|
||||
import backend.data.analytics
|
||||
|
||||
router = fastapi.APIRouter()
|
||||
router = fastapi.APIRouter(dependencies=[fastapi.Security(requires_user)])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
340
autogpt_platform/backend/backend/api/features/analytics_test.py
Normal file
340
autogpt_platform/backend/backend/api/features/analytics_test.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""Tests for analytics API endpoints."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from .analytics import router as analytics_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(analytics_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module."""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# /log_raw_metric endpoint tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_log_raw_metric_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw metric logging."""
|
||||
mock_result = Mock(id="metric-123-uuid")
|
||||
mock_log_metric = mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "page_load_time",
|
||||
"metric_value": 2.5,
|
||||
"data_string": "/dashboard",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||
assert response.json() == "metric-123-uuid"
|
||||
|
||||
mock_log_metric.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
metric_name="page_load_time",
|
||||
metric_value=2.5,
|
||||
data_string="/dashboard",
|
||||
)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"metric_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_metric_success",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"metric_value,metric_name,data_string,test_id",
|
||||
[
|
||||
(100, "api_calls_count", "external_api", "integer_value"),
|
||||
(0, "error_count", "no_errors", "zero_value"),
|
||||
(-5.2, "temperature_delta", "cooling", "negative_value"),
|
||||
(1.23456789, "precision_test", "float_precision", "float_precision"),
|
||||
(999999999, "large_number", "max_value", "large_number"),
|
||||
(0.0000001, "tiny_number", "min_value", "tiny_number"),
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_various_values(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
metric_value: float,
|
||||
metric_name: str,
|
||||
data_string: str,
|
||||
test_id: str,
|
||||
) -> None:
|
||||
"""Test raw metric logging with various metric values."""
|
||||
mock_result = Mock(id=f"metric-{test_id}-uuid")
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": metric_name,
|
||||
"metric_value": metric_value,
|
||||
"data_string": data_string,
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Failed for {test_id}: {response.text}"
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"metric_id": response.json(), "test_case": test_id},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
f"analytics_metric_{test_id}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"),
|
||||
({"metric_name": "test"}, "Field required"),
|
||||
(
|
||||
{"metric_name": "test", "metric_value": "not_a_number", "data_string": "x"},
|
||||
"Input should be a valid number",
|
||||
),
|
||||
(
|
||||
{"metric_name": "", "metric_value": 1.0, "data_string": "test"},
|
||||
"String should have at least 1 character",
|
||||
),
|
||||
(
|
||||
{"metric_name": "test", "metric_value": 1.0, "data_string": ""},
|
||||
"String should have at least 1 character",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"empty_request",
|
||||
"missing_metric_value_and_data_string",
|
||||
"invalid_metric_value_type",
|
||||
"empty_metric_name",
|
||||
"empty_data_string",
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_validation_errors(
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test validation errors for invalid metric requests."""
|
||||
response = client.post("/log_raw_metric", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||
|
||||
error_text = json.dumps(error_detail)
|
||||
assert (
|
||||
expected_error in error_text
|
||||
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||
|
||||
|
||||
def test_log_raw_metric_service_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error handling when analytics service fails."""
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Database connection failed"),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "test_metric",
|
||||
"metric_value": 1.0,
|
||||
"data_string": "test",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_detail = response.json()["detail"]
|
||||
assert "Database connection failed" in error_detail["message"]
|
||||
assert "hint" in error_detail
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# /log_raw_analytics endpoint tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_log_raw_analytics_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw analytics logging."""
|
||||
mock_result = Mock(id="analytics-789-uuid")
|
||||
mock_log_analytics = mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "user_action",
|
||||
"data": {
|
||||
"action": "button_click",
|
||||
"button_id": "submit_form",
|
||||
"timestamp": "2023-01-01T00:00:00Z",
|
||||
"metadata": {"form_type": "registration", "fields_filled": 5},
|
||||
},
|
||||
"data_index": "button_click_submit_form",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||
assert response.json() == "analytics-789-uuid"
|
||||
|
||||
mock_log_analytics.assert_called_once_with(
|
||||
test_user_id,
|
||||
"user_action",
|
||||
request_data["data"],
|
||||
"button_click_submit_form",
|
||||
)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"analytics_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_analytics_success",
|
||||
)
|
||||
|
||||
|
||||
def test_log_raw_analytics_complex_data(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test raw analytics logging with complex nested data structures."""
|
||||
mock_result = Mock(id="analytics-complex-uuid")
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "agent_execution",
|
||||
"data": {
|
||||
"agent_id": "agent_123",
|
||||
"execution_id": "exec_456",
|
||||
"status": "completed",
|
||||
"duration_ms": 3500,
|
||||
"nodes_executed": 15,
|
||||
"blocks_used": [
|
||||
{"block_id": "llm_block", "count": 3},
|
||||
{"block_id": "http_block", "count": 5},
|
||||
{"block_id": "code_block", "count": 2},
|
||||
],
|
||||
"errors": [],
|
||||
"metadata": {
|
||||
"trigger": "manual",
|
||||
"user_tier": "premium",
|
||||
"environment": "production",
|
||||
},
|
||||
},
|
||||
"data_index": "agent_123_exec_456",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"analytics_id": response.json(), "logged_data": request_data["data"]},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
"analytics_log_analytics_complex_data",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"),
|
||||
({"type": "test"}, "Field required"),
|
||||
(
|
||||
{"type": "test", "data": "not_a_dict", "data_index": "test"},
|
||||
"Input should be a valid dictionary",
|
||||
),
|
||||
({"type": "test", "data": {"key": "value"}}, "Field required"),
|
||||
],
|
||||
ids=[
|
||||
"empty_request",
|
||||
"missing_data_and_data_index",
|
||||
"invalid_data_type",
|
||||
"missing_data_index",
|
||||
],
|
||||
)
|
||||
def test_log_raw_analytics_validation_errors(
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test validation errors for invalid analytics requests."""
|
||||
response = client.post("/log_raw_analytics", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||
|
||||
error_text = json.dumps(error_detail)
|
||||
assert (
|
||||
expected_error in error_text
|
||||
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||
|
||||
|
||||
def test_log_raw_analytics_service_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error handling when analytics service fails."""
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Analytics DB unreachable"),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "test_event",
|
||||
"data": {"key": "value"},
|
||||
"data_index": "test_index",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_detail = response.json()["detail"]
|
||||
assert "Analytics DB unreachable" in error_detail["message"]
|
||||
assert "hint" in error_detail
|
||||
@@ -6,17 +6,20 @@ from typing import Sequence
|
||||
|
||||
import prisma
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.data.block
|
||||
import backend.server.v2.library.db as library_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.db as store_db
|
||||
import backend.server.v2.store.model as store_model
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.v2.builder.model import (
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .model import (
|
||||
BlockCategoryResponse,
|
||||
BlockResponse,
|
||||
BlockType,
|
||||
@@ -26,8 +29,6 @@ from backend.server.v2.builder.model import (
|
||||
ProviderResponse,
|
||||
SearchEntry,
|
||||
)
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
|
||||
@@ -2,8 +2,8 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.model as store_model
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.model as store_model
|
||||
from backend.data.block import BlockInfo
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.models import Pagination
|
||||
@@ -4,11 +4,12 @@ from typing import Annotated, Sequence
|
||||
import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
|
||||
import backend.server.v2.builder.db as builder_db
|
||||
import backend.server.v2.builder.model as builder_model
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import db as builder_db
|
||||
from . import model as builder_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
@@ -21,11 +21,12 @@ from prisma.models import ChatSession as PrismaChatSession
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.server.v2.chat import db as chat_db
|
||||
from backend.server.v2.chat.config import ChatConfig
|
||||
from backend.util import json
|
||||
from backend.util.exceptions import RedisError
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.chat.model import (
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
@@ -9,10 +9,11 @@ from fastapi import APIRouter, Depends, Query, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
import backend.server.v2.chat.service as chat_service
|
||||
from backend.server.v2.chat.config import ChatConfig
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from .config import ChatConfig
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
@@ -7,18 +7,23 @@ import orjson
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
||||
|
||||
import backend.server.v2.chat.config
|
||||
import backend.server.v2.chat.db as chat_db
|
||||
from backend.data.understanding import (
|
||||
format_understanding_for_prompt,
|
||||
get_business_understanding,
|
||||
)
|
||||
from backend.server.v2.chat.model import ChatMessage, ChatSession, Usage
|
||||
from backend.server.v2.chat.model import (
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
create_chat_session as model_create_chat_session,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.server.v2.chat.model import get_chat_session, upsert_chat_session
|
||||
from backend.server.v2.chat.response_model import (
|
||||
from .response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamEnd,
|
||||
StreamError,
|
||||
@@ -29,12 +34,11 @@ from backend.server.v2.chat.response_model import (
|
||||
StreamToolExecutionResult,
|
||||
StreamUsage,
|
||||
)
|
||||
from backend.server.v2.chat.tools import execute_tool, tools
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from .tools import execute_tool, tools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = backend.server.v2.chat.config.ChatConfig()
|
||||
config = ChatConfig()
|
||||
client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
|
||||
|
||||
@@ -161,9 +165,7 @@ async def get_user_sessions(
|
||||
"""
|
||||
Get all chat sessions for a user.
|
||||
"""
|
||||
from backend.server.v2.chat.model import (
|
||||
get_user_sessions as model_get_user_sessions,
|
||||
)
|
||||
from .model import get_user_sessions as model_get_user_sessions
|
||||
|
||||
return await model_get_user_sessions(user_id, limit, offset)
|
||||
|
||||
@@ -3,8 +3,8 @@ from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
import backend.server.v2.chat.service as chat_service
|
||||
from backend.server.v2.chat.response_model import (
|
||||
from . import service as chat_service
|
||||
from .response_model import (
|
||||
StreamEnd,
|
||||
StreamError,
|
||||
StreamTextChunk,
|
||||
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .add_understanding import AddUnderstandingTool
|
||||
from .agent_output import AgentOutputTool
|
||||
@@ -17,7 +17,7 @@ from .run_block import RunBlockTool
|
||||
from .search_docs import SearchDocsTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.server.v2.chat.response_model import StreamToolExecutionResult
|
||||
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
||||
|
||||
# Initialize tool instances
|
||||
add_understanding_tool = AddUnderstandingTool()
|
||||
@@ -5,6 +5,8 @@ from os import getenv
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
@@ -13,8 +15,6 @@ from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.model import APIKeyCredentials
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.store import db as store_db
|
||||
|
||||
|
||||
def make_session(user_id: str | None = None):
|
||||
@@ -5,8 +5,8 @@ from typing import Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.response_model import StreamToolExecutionResult
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
||||
|
||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||
|
||||
@@ -3,17 +3,18 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools.base import BaseTool
|
||||
from backend.server.v2.chat.tools.models import (
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentCarouselResponse,
|
||||
AgentInfo,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from backend.server.v2.store import db as store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -5,14 +5,21 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.api.features.chat.config import ChatConfig
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.server.v2.chat.config import ChatConfig
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools.base import BaseTool
|
||||
from backend.server.v2.chat.tools.models import (
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
from backend.util.timezone_utils import (
|
||||
convert_utc_time_to_user_timezone,
|
||||
get_user_timezone_or_utc,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentDetails,
|
||||
AgentDetailsResponse,
|
||||
ErrorResponse,
|
||||
@@ -23,20 +30,15 @@ from backend.server.v2.chat.tools.models import (
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
from backend.server.v2.chat.tools.utils import (
|
||||
from .utils import (
|
||||
check_user_has_required_credentials,
|
||||
extract_credentials_from_schema,
|
||||
fetch_graph_from_store_slug,
|
||||
get_or_create_library_agent,
|
||||
match_user_credentials_to_graph,
|
||||
)
|
||||
from backend.server.v2.library import db as library_db
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
from backend.util.timezone_utils import (
|
||||
convert_utc_time_to_user_timezone,
|
||||
get_user_timezone_or_utc,
|
||||
)
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
@@ -3,13 +3,13 @@ import uuid
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.chat.tools._test_data import (
|
||||
from ._test_data import (
|
||||
make_session,
|
||||
setup_firecrawl_test_data,
|
||||
setup_llm_test_data,
|
||||
setup_test_data,
|
||||
)
|
||||
from backend.server.v2.chat.tools.run_agent import RunAgentTool
|
||||
from .run_agent import RunAgentTool
|
||||
|
||||
# This is so the formatter doesn't remove the fixture imports
|
||||
setup_llm_test_data = setup_llm_test_data
|
||||
@@ -3,13 +3,13 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library import model as library_model
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.server.v2.library import db as library_db
|
||||
from backend.server.v2.library import model as library_model
|
||||
from backend.server.v2.store import db as store_db
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -7,9 +7,10 @@ import pytest_mock
|
||||
from prisma.enums import ReviewStatus
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.server.rest_api import handle_internal_http_error
|
||||
from backend.server.v2.executions.review.model import PendingHumanReviewModel
|
||||
from backend.server.v2.executions.review.routes import router
|
||||
from backend.api.rest_api import handle_internal_http_error
|
||||
|
||||
from .model import PendingHumanReviewModel
|
||||
from .routes import router
|
||||
|
||||
# Using a fixed timestamp for reproducible tests
|
||||
FIXED_NOW = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
|
||||
@@ -54,13 +55,13 @@ def sample_pending_review(test_user_id: str) -> PendingHumanReviewModel:
|
||||
|
||||
|
||||
def test_get_pending_reviews_empty(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting pending reviews when none exist"""
|
||||
mock_get_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_user"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_user"
|
||||
)
|
||||
mock_get_reviews.return_value = []
|
||||
|
||||
@@ -72,14 +73,14 @@ def test_get_pending_reviews_empty(
|
||||
|
||||
|
||||
def test_get_pending_reviews_with_data(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting pending reviews with data"""
|
||||
mock_get_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_user"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_user"
|
||||
)
|
||||
mock_get_reviews.return_value = [sample_pending_review]
|
||||
|
||||
@@ -94,14 +95,14 @@ def test_get_pending_reviews_with_data(
|
||||
|
||||
|
||||
def test_get_pending_reviews_for_execution_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting pending reviews for specific execution"""
|
||||
mock_get_graph_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_graph_execution_meta"
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_get_graph_execution.return_value = {
|
||||
"id": "test_graph_exec_456",
|
||||
@@ -109,7 +110,7 @@ def test_get_pending_reviews_for_execution_success(
|
||||
}
|
||||
|
||||
mock_get_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews.return_value = [sample_pending_review]
|
||||
|
||||
@@ -121,24 +122,23 @@ def test_get_pending_reviews_for_execution_success(
|
||||
assert data[0]["graph_exec_id"] == "test_graph_exec_456"
|
||||
|
||||
|
||||
def test_get_pending_reviews_for_execution_access_denied(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
def test_get_pending_reviews_for_execution_not_available(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Test access denied when user doesn't own the execution"""
|
||||
mock_get_graph_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_graph_execution_meta"
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_get_graph_execution.return_value = None
|
||||
|
||||
response = client.get("/api/review/execution/test_graph_exec_456")
|
||||
|
||||
assert response.status_code == 403
|
||||
assert "Access denied" in response.json()["detail"]
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_process_review_action_approve_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
@@ -146,12 +146,12 @@ def test_process_review_action_approve_success(
|
||||
# Mock the route functions
|
||||
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
# Create approved review for return
|
||||
approved_review = PendingHumanReviewModel(
|
||||
@@ -174,11 +174,11 @@ def test_process_review_action_approve_success(
|
||||
mock_process_all_reviews.return_value = {"test_node_123": approved_review}
|
||||
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
mocker.patch("backend.server.v2.executions.review.routes.add_graph_execution")
|
||||
mocker.patch("backend.api.features.executions.review.routes.add_graph_execution")
|
||||
|
||||
request_data = {
|
||||
"reviews": [
|
||||
@@ -202,7 +202,7 @@ def test_process_review_action_approve_success(
|
||||
|
||||
|
||||
def test_process_review_action_reject_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
@@ -210,12 +210,12 @@ def test_process_review_action_reject_success(
|
||||
# Mock the route functions
|
||||
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
rejected_review = PendingHumanReviewModel(
|
||||
node_exec_id="test_node_123",
|
||||
@@ -237,7 +237,7 @@ def test_process_review_action_reject_success(
|
||||
mock_process_all_reviews.return_value = {"test_node_123": rejected_review}
|
||||
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
@@ -262,7 +262,7 @@ def test_process_review_action_reject_success(
|
||||
|
||||
|
||||
def test_process_review_action_mixed_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
@@ -289,12 +289,12 @@ def test_process_review_action_mixed_success(
|
||||
# Mock the route functions
|
||||
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review, second_review]
|
||||
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
# Create approved version of first review
|
||||
approved_review = PendingHumanReviewModel(
|
||||
@@ -338,7 +338,7 @@ def test_process_review_action_mixed_success(
|
||||
}
|
||||
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
@@ -369,7 +369,7 @@ def test_process_review_action_mixed_success(
|
||||
|
||||
|
||||
def test_process_review_action_empty_request(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error when no reviews provided"""
|
||||
@@ -386,19 +386,19 @@ def test_process_review_action_empty_request(
|
||||
|
||||
|
||||
def test_process_review_action_review_not_found(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error when review is not found"""
|
||||
# Mock the functions that extract graph execution ID from the request
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [] # No reviews found
|
||||
|
||||
# Mock process_all_reviews to simulate not finding reviews
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
# This should raise a ValueError with "Reviews not found" message based on the data/human_review.py logic
|
||||
mock_process_all_reviews.side_effect = ValueError(
|
||||
@@ -422,20 +422,20 @@ def test_process_review_action_review_not_found(
|
||||
|
||||
|
||||
def test_process_review_action_partial_failure(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test handling of partial failures in review processing"""
|
||||
# Mock the route functions
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||
|
||||
# Mock partial failure in processing
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
mock_process_all_reviews.side_effect = ValueError("Some reviews failed validation")
|
||||
|
||||
@@ -456,20 +456,20 @@ def test_process_review_action_partial_failure(
|
||||
|
||||
|
||||
def test_process_review_action_invalid_node_exec_id(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test failure when trying to process review with invalid node execution ID"""
|
||||
# Mock the route functions
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||
|
||||
# Mock validation failure - this should return 400, not 500
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
mock_process_all_reviews.side_effect = ValueError(
|
||||
"Invalid node execution ID format"
|
||||
@@ -13,11 +13,8 @@ from backend.data.human_review import (
|
||||
process_all_reviews_for_execution,
|
||||
)
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.server.v2.executions.review.model import (
|
||||
PendingHumanReviewModel,
|
||||
ReviewRequest,
|
||||
ReviewResponse,
|
||||
)
|
||||
|
||||
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -70,8 +67,7 @@ async def list_pending_reviews(
|
||||
response_model=List[PendingHumanReviewModel],
|
||||
responses={
|
||||
200: {"description": "List of pending reviews for the execution"},
|
||||
400: {"description": "Invalid graph execution ID"},
|
||||
403: {"description": "Access denied to graph execution"},
|
||||
404: {"description": "Graph execution not found"},
|
||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
||||
},
|
||||
)
|
||||
@@ -94,7 +90,7 @@ async def list_pending_reviews_for_execution(
|
||||
|
||||
Raises:
|
||||
HTTPException:
|
||||
- 403: If user doesn't own the graph execution
|
||||
- 404: If the graph execution doesn't exist or isn't owned by this user
|
||||
- 500: If authentication fails or database error occurs
|
||||
|
||||
Note:
|
||||
@@ -108,8 +104,8 @@ async def list_pending_reviews_for_execution(
|
||||
)
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to graph execution",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
|
||||
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
|
||||
@@ -17,6 +17,8 @@ from fastapi import (
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
||||
|
||||
from backend.api.features.library.db import set_preset_webhook, update_preset
|
||||
from backend.api.features.library.model import LibraryAgentPreset
|
||||
from backend.data.graph import NodeModel, get_graph, set_node_webhook
|
||||
from backend.data.integrations import (
|
||||
WebhookEvent,
|
||||
@@ -45,13 +47,6 @@ from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.server.integrations.models import (
|
||||
ProviderConstants,
|
||||
ProviderNamesResponse,
|
||||
get_all_provider_names,
|
||||
)
|
||||
from backend.server.v2.library.db import set_preset_webhook, update_preset
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
from backend.util.exceptions import (
|
||||
GraphNotInLibraryError,
|
||||
MissingConfigError,
|
||||
@@ -60,6 +55,8 @@ from backend.util.exceptions import (
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .models import ProviderConstants, ProviderNamesResponse, get_all_provider_names
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.integrations.oauth import BaseOAuthHandler
|
||||
|
||||
@@ -4,16 +4,14 @@ from typing import Literal, Optional
|
||||
|
||||
import fastapi
|
||||
import prisma.errors
|
||||
import prisma.fields
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
import backend.api.features.store.exceptions as store_exceptions
|
||||
import backend.api.features.store.image_gen as store_image_gen
|
||||
import backend.api.features.store.media as store_media
|
||||
import backend.data.graph as graph_db
|
||||
import backend.data.integrations as integrations_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
import backend.server.v2.store.image_gen as store_image_gen
|
||||
import backend.server.v2.store.media as store_media
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.db import transaction
|
||||
from backend.data.execution import get_graph_execution
|
||||
@@ -28,6 +26,8 @@ from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.settings import Config
|
||||
|
||||
from . import model as library_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = Config()
|
||||
integration_creds_manager = IntegrationCredentialsManager()
|
||||
@@ -538,6 +538,7 @@ async def update_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str,
|
||||
auto_update_version: Optional[bool] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
is_favorite: Optional[bool] = None,
|
||||
is_archived: Optional[bool] = None,
|
||||
is_deleted: Optional[Literal[False]] = None,
|
||||
@@ -550,6 +551,7 @@ async def update_library_agent(
|
||||
library_agent_id: The ID of the LibraryAgent to update.
|
||||
user_id: The owner of this LibraryAgent.
|
||||
auto_update_version: Whether the agent should auto-update to active version.
|
||||
graph_version: Specific graph version to update to.
|
||||
is_favorite: Whether this agent is marked as a favorite.
|
||||
is_archived: Whether this agent is archived.
|
||||
settings: User-specific settings for this library agent.
|
||||
@@ -563,8 +565,8 @@ async def update_library_agent(
|
||||
"""
|
||||
logger.debug(
|
||||
f"Updating library agent {library_agent_id} for user {user_id} with "
|
||||
f"auto_update_version={auto_update_version}, is_favorite={is_favorite}, "
|
||||
f"is_archived={is_archived}, settings={settings}"
|
||||
f"auto_update_version={auto_update_version}, graph_version={graph_version}, "
|
||||
f"is_favorite={is_favorite}, is_archived={is_archived}, settings={settings}"
|
||||
)
|
||||
update_fields: prisma.types.LibraryAgentUpdateManyMutationInput = {}
|
||||
if auto_update_version is not None:
|
||||
@@ -581,10 +583,23 @@ async def update_library_agent(
|
||||
update_fields["isDeleted"] = is_deleted
|
||||
if settings is not None:
|
||||
update_fields["settings"] = SafeJson(settings.model_dump())
|
||||
if not update_fields:
|
||||
raise ValueError("No values were passed to update")
|
||||
|
||||
try:
|
||||
# If graph_version is provided, update to that specific version
|
||||
if graph_version is not None:
|
||||
# Get the current agent to find its graph_id
|
||||
agent = await get_library_agent(id=library_agent_id, user_id=user_id)
|
||||
# Update to the specified version using existing function
|
||||
return await update_agent_version_in_library(
|
||||
user_id=user_id,
|
||||
agent_graph_id=agent.graph_id,
|
||||
agent_graph_version=graph_version,
|
||||
)
|
||||
|
||||
# Otherwise, just update the simple fields
|
||||
if not update_fields:
|
||||
raise ValueError("No values were passed to update")
|
||||
|
||||
n_updated = await prisma.models.LibraryAgent.prisma().update_many(
|
||||
where={"id": library_agent_id, "userId": user_id},
|
||||
data=update_fields,
|
||||
@@ -1,16 +1,15 @@
|
||||
from datetime import datetime
|
||||
|
||||
import prisma.enums
|
||||
import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
import pytest
|
||||
|
||||
import backend.server.v2.library.db as db
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.api.features.store.exceptions
|
||||
from backend.data.db import connect
|
||||
from backend.data.includes import library_agent_include
|
||||
|
||||
from . import db
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_library_agents(mocker):
|
||||
@@ -88,7 +87,7 @@ async def test_add_agent_to_library(mocker):
|
||||
await connect()
|
||||
|
||||
# Mock the transaction context
|
||||
mock_transaction = mocker.patch("backend.server.v2.library.db.transaction")
|
||||
mock_transaction = mocker.patch("backend.api.features.library.db.transaction")
|
||||
mock_transaction.return_value.__aenter__ = mocker.AsyncMock(return_value=None)
|
||||
mock_transaction.return_value.__aexit__ = mocker.AsyncMock(return_value=None)
|
||||
# Mock data
|
||||
@@ -151,7 +150,7 @@ async def test_add_agent_to_library(mocker):
|
||||
)
|
||||
|
||||
# Mock graph_db.get_graph function that's called to check for HITL blocks
|
||||
mock_graph_db = mocker.patch("backend.server.v2.library.db.graph_db")
|
||||
mock_graph_db = mocker.patch("backend.api.features.library.db.graph_db")
|
||||
mock_graph_model = mocker.Mock()
|
||||
mock_graph_model.nodes = (
|
||||
[]
|
||||
@@ -159,7 +158,9 @@ async def test_add_agent_to_library(mocker):
|
||||
mock_graph_db.get_graph = mocker.AsyncMock(return_value=mock_graph_model)
|
||||
|
||||
# Mock the model conversion
|
||||
mock_from_db = mocker.patch("backend.server.v2.library.model.LibraryAgent.from_db")
|
||||
mock_from_db = mocker.patch(
|
||||
"backend.api.features.library.model.LibraryAgent.from_db"
|
||||
)
|
||||
mock_from_db.return_value = mocker.Mock()
|
||||
|
||||
# Call function
|
||||
@@ -217,7 +218,7 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
)
|
||||
|
||||
# Call function and verify exception
|
||||
with pytest.raises(backend.server.v2.store.exceptions.AgentNotFoundError):
|
||||
with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
|
||||
await db.add_store_agent_to_library("version123", "test-user")
|
||||
|
||||
# Verify mock called correctly
|
||||
@@ -385,6 +385,9 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
|
||||
auto_update_version: Optional[bool] = pydantic.Field(
|
||||
default=None, description="Auto-update the agent version"
|
||||
)
|
||||
graph_version: Optional[int] = pydantic.Field(
|
||||
default=None, description="Specific graph version to update to"
|
||||
)
|
||||
is_favorite: Optional[bool] = pydantic.Field(
|
||||
default=None, description="Mark the agent as a favorite"
|
||||
)
|
||||
@@ -3,7 +3,7 @@ import datetime
|
||||
import prisma.models
|
||||
import pytest
|
||||
|
||||
import backend.server.v2.library.model as library_model
|
||||
from . import model as library_model
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -6,12 +6,13 @@ from fastapi import APIRouter, Body, HTTPException, Query, Security, status
|
||||
from fastapi.responses import Response
|
||||
from prisma.enums import OnboardingStep
|
||||
|
||||
import backend.server.v2.library.db as library_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
import backend.api.features.store.exceptions as store_exceptions
|
||||
from backend.data.onboarding import complete_onboarding_step
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .. import db as library_db
|
||||
from .. import model as library_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
@@ -284,6 +285,7 @@ async def update_library_agent(
|
||||
library_agent_id=library_agent_id,
|
||||
user_id=user_id,
|
||||
auto_update_version=payload.auto_update_version,
|
||||
graph_version=payload.graph_version,
|
||||
is_favorite=payload.is_favorite,
|
||||
is_archived=payload.is_archived,
|
||||
settings=payload.settings,
|
||||
@@ -4,8 +4,6 @@ from typing import Any, Optional
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, Body, HTTPException, Query, Security, status
|
||||
|
||||
import backend.server.v2.library.db as db
|
||||
import backend.server.v2.library.model as models
|
||||
from backend.data.execution import GraphExecutionMeta
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.integrations import get_webhook
|
||||
@@ -17,6 +15,9 @@ from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.integrations.webhooks.utils import setup_webhook_for_block
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .. import db
|
||||
from .. import model as models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
credentials_manager = IntegrationCredentialsManager()
|
||||
@@ -7,10 +7,11 @@ import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.v2.library.model as library_model
|
||||
from backend.server.v2.library.routes import router as library_router
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import model as library_model
|
||||
from .routes import router as library_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(library_router)
|
||||
|
||||
@@ -86,7 +87,7 @@ async def test_get_library_agents_success(
|
||||
total_items=2, total_pages=1, current_page=1, page_size=50
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.list_library_agents")
|
||||
mock_db_call = mocker.patch("backend.api.features.library.db.list_library_agents")
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/agents?search_term=test")
|
||||
@@ -112,7 +113,7 @@ async def test_get_library_agents_success(
|
||||
|
||||
|
||||
def test_get_library_agents_error(mocker: pytest_mock.MockFixture, test_user_id: str):
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.list_library_agents")
|
||||
mock_db_call = mocker.patch("backend.api.features.library.db.list_library_agents")
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
response = client.get("/agents?search_term=test")
|
||||
@@ -161,7 +162,7 @@ async def test_get_favorite_library_agents_success(
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.list_favorite_library_agents"
|
||||
"backend.api.features.library.db.list_favorite_library_agents"
|
||||
)
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
@@ -184,7 +185,7 @@ def test_get_favorite_library_agents_error(
|
||||
mocker: pytest_mock.MockFixture, test_user_id: str
|
||||
):
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.list_favorite_library_agents"
|
||||
"backend.api.features.library.db.list_favorite_library_agents"
|
||||
)
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
@@ -223,11 +224,11 @@ def test_add_agent_to_library_success(
|
||||
)
|
||||
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.add_store_agent_to_library"
|
||||
"backend.api.features.library.db.add_store_agent_to_library"
|
||||
)
|
||||
mock_db_call.return_value = mock_library_agent
|
||||
mock_complete_onboarding = mocker.patch(
|
||||
"backend.server.v2.library.routes.agents.complete_onboarding_step",
|
||||
"backend.api.features.library.routes.agents.complete_onboarding_step",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
@@ -249,7 +250,7 @@ def test_add_agent_to_library_success(
|
||||
|
||||
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture, test_user_id: str):
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.add_store_agent_to_library"
|
||||
"backend.api.features.library.db.add_store_agent_to_library"
|
||||
)
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
833
autogpt_platform/backend/backend/api/features/oauth.py
Normal file
833
autogpt_platform/backend/backend/api/features/oauth.py
Normal file
@@ -0,0 +1,833 @@
|
||||
"""
|
||||
OAuth 2.0 Provider Endpoints
|
||||
|
||||
Implements OAuth 2.0 Authorization Code flow with PKCE support.
|
||||
|
||||
Flow:
|
||||
1. User clicks "Login with AutoGPT" in 3rd party app
|
||||
2. App redirects user to /auth/authorize with client_id, redirect_uri, scope, state
|
||||
3. User sees consent screen (if not already logged in, redirects to login first)
|
||||
4. User approves → backend creates authorization code
|
||||
5. User redirected back to app with code
|
||||
6. App exchanges code for access/refresh tokens at /api/oauth/token
|
||||
7. App uses access token to call external API endpoints
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import APIRouter, Body, HTTPException, Security, UploadFile, status
|
||||
from gcloud.aio import storage as async_storage
|
||||
from PIL import Image
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.auth.oauth import (
|
||||
InvalidClientError,
|
||||
InvalidGrantError,
|
||||
OAuthApplicationInfo,
|
||||
TokenIntrospectionResult,
|
||||
consume_authorization_code,
|
||||
create_access_token,
|
||||
create_authorization_code,
|
||||
create_refresh_token,
|
||||
get_oauth_application,
|
||||
get_oauth_application_by_id,
|
||||
introspect_token,
|
||||
list_user_oauth_applications,
|
||||
refresh_tokens,
|
||||
revoke_access_token,
|
||||
revoke_refresh_token,
|
||||
update_oauth_application,
|
||||
validate_client_credentials,
|
||||
validate_redirect_uri,
|
||||
validate_scopes,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""OAuth 2.0 token response"""
|
||||
|
||||
token_type: Literal["Bearer"] = "Bearer"
|
||||
access_token: str
|
||||
access_token_expires_at: datetime
|
||||
refresh_token: str
|
||||
refresh_token_expires_at: datetime
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""OAuth 2.0 error response"""
|
||||
|
||||
error: str
|
||||
error_description: Optional[str] = None
|
||||
|
||||
|
||||
class OAuthApplicationPublicInfo(BaseModel):
|
||||
"""Public information about an OAuth application (for consent screen)"""
|
||||
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Application Info Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/app/{client_id}",
|
||||
responses={
|
||||
404: {"description": "Application not found or disabled"},
|
||||
},
|
||||
)
|
||||
async def get_oauth_app_info(
|
||||
client_id: str, user_id: str = Security(get_user_id)
|
||||
) -> OAuthApplicationPublicInfo:
|
||||
"""
|
||||
Get public information about an OAuth application.
|
||||
|
||||
This endpoint is used by the consent screen to display application details
|
||||
to the user before they authorize access.
|
||||
|
||||
Returns:
|
||||
- name: Application name
|
||||
- description: Application description (if provided)
|
||||
- scopes: List of scopes the application is allowed to request
|
||||
"""
|
||||
app = await get_oauth_application(client_id)
|
||||
if not app or not app.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found",
|
||||
)
|
||||
|
||||
return OAuthApplicationPublicInfo(
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
logo_url=app.logo_url,
|
||||
scopes=[s.value for s in app.scopes],
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authorization Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AuthorizeRequest(BaseModel):
|
||||
"""OAuth 2.0 authorization request"""
|
||||
|
||||
client_id: str = Field(description="Client identifier")
|
||||
redirect_uri: str = Field(description="Redirect URI")
|
||||
scopes: list[str] = Field(description="List of scopes")
|
||||
state: str = Field(description="Anti-CSRF token from client")
|
||||
response_type: str = Field(
|
||||
default="code", description="Must be 'code' for authorization code flow"
|
||||
)
|
||||
code_challenge: str = Field(description="PKCE code challenge (required)")
|
||||
code_challenge_method: Literal["S256", "plain"] = Field(
|
||||
default="S256", description="PKCE code challenge method (S256 recommended)"
|
||||
)
|
||||
|
||||
|
||||
class AuthorizeResponse(BaseModel):
|
||||
"""OAuth 2.0 authorization response with redirect URL"""
|
||||
|
||||
redirect_url: str = Field(description="URL to redirect the user to")
|
||||
|
||||
|
||||
@router.post("/authorize")
|
||||
async def authorize(
|
||||
request: AuthorizeRequest = Body(),
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> AuthorizeResponse:
|
||||
"""
|
||||
OAuth 2.0 Authorization Endpoint
|
||||
|
||||
User must be logged in (authenticated with Supabase JWT).
|
||||
This endpoint creates an authorization code and returns a redirect URL.
|
||||
|
||||
PKCE (Proof Key for Code Exchange) is REQUIRED for all authorization requests.
|
||||
|
||||
The frontend consent screen should call this endpoint after the user approves,
|
||||
then redirect the user to the returned `redirect_url`.
|
||||
|
||||
Request Body:
|
||||
- client_id: The OAuth application's client ID
|
||||
- redirect_uri: Where to redirect after authorization (must match registered URI)
|
||||
- scopes: List of permissions (e.g., "EXECUTE_GRAPH READ_GRAPH")
|
||||
- state: Anti-CSRF token provided by client (will be returned in redirect)
|
||||
- response_type: Must be "code" (for authorization code flow)
|
||||
- code_challenge: PKCE code challenge (required)
|
||||
- code_challenge_method: "S256" (recommended) or "plain"
|
||||
|
||||
Returns:
|
||||
- redirect_url: The URL to redirect the user to (includes authorization code)
|
||||
|
||||
Error cases return a redirect_url with error parameters, or raise HTTPException
|
||||
for critical errors (like invalid redirect_uri).
|
||||
"""
|
||||
try:
|
||||
# Validate response_type
|
||||
if request.response_type != "code":
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"unsupported_response_type",
|
||||
"Only 'code' response type is supported",
|
||||
)
|
||||
|
||||
# Get application
|
||||
app = await get_oauth_application(request.client_id)
|
||||
if not app:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_client",
|
||||
"Unknown client_id",
|
||||
)
|
||||
|
||||
if not app.is_active:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_client",
|
||||
"Application is not active",
|
||||
)
|
||||
|
||||
# Validate redirect URI
|
||||
if not validate_redirect_uri(app, request.redirect_uri):
|
||||
# For invalid redirect_uri, we can't redirect safely
|
||||
# Must return error instead
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
"Invalid redirect_uri. "
|
||||
f"Must be one of: {', '.join(app.redirect_uris)}"
|
||||
),
|
||||
)
|
||||
|
||||
# Parse and validate scopes
|
||||
try:
|
||||
requested_scopes = [APIKeyPermission(s.strip()) for s in request.scopes]
|
||||
except ValueError as e:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
f"Invalid scope: {e}",
|
||||
)
|
||||
|
||||
if not requested_scopes:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
"At least one scope is required",
|
||||
)
|
||||
|
||||
if not validate_scopes(app, requested_scopes):
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
"Application is not authorized for all requested scopes. "
|
||||
f"Allowed: {', '.join(s.value for s in app.scopes)}",
|
||||
)
|
||||
|
||||
# Create authorization code
|
||||
auth_code = await create_authorization_code(
|
||||
application_id=app.id,
|
||||
user_id=user_id,
|
||||
scopes=requested_scopes,
|
||||
redirect_uri=request.redirect_uri,
|
||||
code_challenge=request.code_challenge,
|
||||
code_challenge_method=request.code_challenge_method,
|
||||
)
|
||||
|
||||
# Build redirect URL with authorization code
|
||||
params = {
|
||||
"code": auth_code.code,
|
||||
"state": request.state,
|
||||
}
|
||||
redirect_url = f"{request.redirect_uri}?{urlencode(params)}"
|
||||
|
||||
logger.info(
|
||||
f"Authorization code issued for user #{user_id} "
|
||||
f"and app {app.name} (#{app.id})"
|
||||
)
|
||||
|
||||
return AuthorizeResponse(redirect_url=redirect_url)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in authorization endpoint: {e}", exc_info=True)
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"server_error",
|
||||
"An unexpected error occurred",
|
||||
)
|
||||
|
||||
|
||||
def _error_redirect_url(
|
||||
redirect_uri: str,
|
||||
state: str,
|
||||
error: str,
|
||||
error_description: Optional[str] = None,
|
||||
) -> AuthorizeResponse:
|
||||
"""Helper to build redirect URL with OAuth error parameters"""
|
||||
params = {
|
||||
"error": error,
|
||||
"state": state,
|
||||
}
|
||||
if error_description:
|
||||
params["error_description"] = error_description
|
||||
|
||||
redirect_url = f"{redirect_uri}?{urlencode(params)}"
|
||||
return AuthorizeResponse(redirect_url=redirect_url)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TokenRequestByCode(BaseModel):
|
||||
grant_type: Literal["authorization_code"]
|
||||
code: str = Field(description="Authorization code")
|
||||
redirect_uri: str = Field(
|
||||
description="Redirect URI (must match authorization request)"
|
||||
)
|
||||
client_id: str
|
||||
client_secret: str
|
||||
code_verifier: str = Field(description="PKCE code verifier")
|
||||
|
||||
|
||||
class TokenRequestByRefreshToken(BaseModel):
|
||||
grant_type: Literal["refresh_token"]
|
||||
refresh_token: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
|
||||
|
||||
@router.post("/token")
|
||||
async def token(
|
||||
request: TokenRequestByCode | TokenRequestByRefreshToken = Body(),
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
OAuth 2.0 Token Endpoint
|
||||
|
||||
Exchanges authorization code or refresh token for access token.
|
||||
|
||||
Grant Types:
|
||||
1. authorization_code: Exchange authorization code for tokens
|
||||
- Required: grant_type, code, redirect_uri, client_id, client_secret
|
||||
- Optional: code_verifier (required if PKCE was used)
|
||||
|
||||
2. refresh_token: Exchange refresh token for new access token
|
||||
- Required: grant_type, refresh_token, client_id, client_secret
|
||||
|
||||
Returns:
|
||||
- access_token: Bearer token for API access (1 hour TTL)
|
||||
- token_type: "Bearer"
|
||||
- expires_in: Seconds until access token expires
|
||||
- refresh_token: Token for refreshing access (30 days TTL)
|
||||
- scopes: List of scopes
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
app = await validate_client_credentials(
|
||||
request.client_id, request.client_secret
|
||||
)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Handle authorization_code grant
|
||||
if request.grant_type == "authorization_code":
|
||||
# Consume authorization code
|
||||
try:
|
||||
user_id, scopes = await consume_authorization_code(
|
||||
code=request.code,
|
||||
application_id=app.id,
|
||||
redirect_uri=request.redirect_uri,
|
||||
code_verifier=request.code_verifier,
|
||||
)
|
||||
except InvalidGrantError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Create access and refresh tokens
|
||||
access_token = await create_access_token(app.id, user_id, scopes)
|
||||
refresh_token = await create_refresh_token(app.id, user_id, scopes)
|
||||
|
||||
logger.info(
|
||||
f"Access token issued for user #{user_id} and app {app.name} (#{app.id})"
|
||||
"via authorization code"
|
||||
)
|
||||
|
||||
if not access_token.token or not refresh_token.token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate tokens",
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
token_type="Bearer",
|
||||
access_token=access_token.token.get_secret_value(),
|
||||
access_token_expires_at=access_token.expires_at,
|
||||
refresh_token=refresh_token.token.get_secret_value(),
|
||||
refresh_token_expires_at=refresh_token.expires_at,
|
||||
scopes=list(s.value for s in scopes),
|
||||
)
|
||||
|
||||
# Handle refresh_token grant
|
||||
elif request.grant_type == "refresh_token":
|
||||
# Refresh access token
|
||||
try:
|
||||
new_access_token, new_refresh_token = await refresh_tokens(
|
||||
request.refresh_token, app.id
|
||||
)
|
||||
except InvalidGrantError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Tokens refreshed for user #{new_access_token.user_id} "
|
||||
f"by app {app.name} (#{app.id})"
|
||||
)
|
||||
|
||||
if not new_access_token.token or not new_refresh_token.token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate tokens",
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
token_type="Bearer",
|
||||
access_token=new_access_token.token.get_secret_value(),
|
||||
access_token_expires_at=new_access_token.expires_at,
|
||||
refresh_token=new_refresh_token.token.get_secret_value(),
|
||||
refresh_token_expires_at=new_refresh_token.expires_at,
|
||||
scopes=list(s.value for s in new_access_token.scopes),
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported grant_type: {request.grant_type}. "
|
||||
"Must be 'authorization_code' or 'refresh_token'",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Introspection Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post("/introspect")
|
||||
async def introspect(
|
||||
token: str = Body(description="Token to introspect"),
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
|
||||
None, description="Hint about token type ('access_token' or 'refresh_token')"
|
||||
),
|
||||
client_id: str = Body(description="Client identifier"),
|
||||
client_secret: str = Body(description="Client secret"),
|
||||
) -> TokenIntrospectionResult:
|
||||
"""
|
||||
OAuth 2.0 Token Introspection Endpoint (RFC 7662)
|
||||
|
||||
Allows clients to check if a token is valid and get its metadata.
|
||||
|
||||
Returns:
|
||||
- active: Whether the token is currently active
|
||||
- scopes: List of authorized scopes (if active)
|
||||
- client_id: The client the token was issued to (if active)
|
||||
- user_id: The user the token represents (if active)
|
||||
- exp: Expiration timestamp (if active)
|
||||
- token_type: "access_token" or "refresh_token" (if active)
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
await validate_client_credentials(client_id, client_secret)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Introspect the token
|
||||
return await introspect_token(token, token_type_hint)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Revocation Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post("/revoke")
|
||||
async def revoke(
|
||||
token: str = Body(description="Token to revoke"),
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
|
||||
None, description="Hint about token type ('access_token' or 'refresh_token')"
|
||||
),
|
||||
client_id: str = Body(description="Client identifier"),
|
||||
client_secret: str = Body(description="Client secret"),
|
||||
):
|
||||
"""
|
||||
OAuth 2.0 Token Revocation Endpoint (RFC 7009)
|
||||
|
||||
Allows clients to revoke an access or refresh token.
|
||||
|
||||
Note: Revoking a refresh token does NOT revoke associated access tokens.
|
||||
Revoking an access token does NOT revoke the associated refresh token.
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
app = await validate_client_credentials(client_id, client_secret)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Try to revoke as access token first
|
||||
# Note: We pass app.id to ensure the token belongs to the authenticated app
|
||||
if token_type_hint != "refresh_token":
|
||||
revoked = await revoke_access_token(token, app.id)
|
||||
if revoked:
|
||||
logger.info(
|
||||
f"Access token revoked for app {app.name} (#{app.id}); "
|
||||
f"user #{revoked.user_id}"
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
# Try to revoke as refresh token
|
||||
revoked = await revoke_refresh_token(token, app.id)
|
||||
if revoked:
|
||||
logger.info(
|
||||
f"Refresh token revoked for app {app.name} (#{app.id}); "
|
||||
f"user #{revoked.user_id}"
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
# Per RFC 7009, revocation endpoint returns 200 even if token not found
|
||||
# or if token belongs to a different application.
|
||||
# This prevents token scanning attacks.
|
||||
logger.warning(f"Unsuccessful token revocation attempt by app {app.name} #{app.id}")
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Application Management Endpoints (for app owners)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get("/apps/mine")
|
||||
async def list_my_oauth_apps(
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> list[OAuthApplicationInfo]:
|
||||
"""
|
||||
List all OAuth applications owned by the current user.
|
||||
|
||||
Returns a list of OAuth applications with their details including:
|
||||
- id, name, description, logo_url
|
||||
- client_id (public identifier)
|
||||
- redirect_uris, grant_types, scopes
|
||||
- is_active status
|
||||
- created_at, updated_at timestamps
|
||||
|
||||
Note: client_secret is never returned for security reasons.
|
||||
"""
|
||||
return await list_user_oauth_applications(user_id)
|
||||
|
||||
|
||||
@router.patch("/apps/{app_id}/status")
|
||||
async def update_app_status(
|
||||
app_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
is_active: bool = Body(description="Whether the app should be active", embed=True),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Enable or disable an OAuth application.
|
||||
|
||||
Only the application owner can update the status.
|
||||
When disabled, the application cannot be used for new authorizations
|
||||
and existing access tokens will fail validation.
|
||||
|
||||
Returns the updated application info.
|
||||
"""
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
action = "enabled" if is_active else "disabled"
|
||||
logger.info(f"OAuth app {updated_app.name} (#{app_id}) {action} by user #{user_id}")
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
class UpdateAppLogoRequest(BaseModel):
|
||||
logo_url: str = Field(description="URL of the uploaded logo image")
|
||||
|
||||
|
||||
@router.patch("/apps/{app_id}/logo")
|
||||
async def update_app_logo(
|
||||
app_id: str,
|
||||
request: UpdateAppLogoRequest = Body(),
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Update the logo URL for an OAuth application.
|
||||
|
||||
Only the application owner can update the logo.
|
||||
The logo should be uploaded first using the media upload endpoint,
|
||||
then this endpoint is called with the resulting URL.
|
||||
|
||||
Logo requirements:
|
||||
- Must be square (1:1 aspect ratio)
|
||||
- Minimum 512x512 pixels
|
||||
- Maximum 2048x2048 pixels
|
||||
|
||||
Returns the updated application info.
|
||||
"""
|
||||
if (
|
||||
not (app := await get_oauth_application_by_id(app_id))
|
||||
or app.owner_id != user_id
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth App not found",
|
||||
)
|
||||
|
||||
# Delete the current app logo file (if any and it's in our cloud storage)
|
||||
await _delete_app_current_logo_file(app)
|
||||
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
logo_url=request.logo_url,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth app {updated_app.name} (#{app_id}) logo updated by user #{user_id}"
|
||||
)
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
# Logo upload constraints
|
||||
LOGO_MIN_SIZE = 512
|
||||
LOGO_MAX_SIZE = 2048
|
||||
LOGO_ALLOWED_TYPES = {"image/jpeg", "image/png", "image/webp"}
|
||||
LOGO_MAX_FILE_SIZE = 3 * 1024 * 1024 # 3MB
|
||||
|
||||
|
||||
@router.post("/apps/{app_id}/logo/upload")
|
||||
async def upload_app_logo(
|
||||
app_id: str,
|
||||
file: UploadFile,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Upload a logo image for an OAuth application.
|
||||
|
||||
Requirements:
|
||||
- Image must be square (1:1 aspect ratio)
|
||||
- Minimum 512x512 pixels
|
||||
- Maximum 2048x2048 pixels
|
||||
- Allowed formats: JPEG, PNG, WebP
|
||||
- Maximum file size: 3MB
|
||||
|
||||
The image is uploaded to cloud storage and the app's logoUrl is updated.
|
||||
Returns the updated application info.
|
||||
"""
|
||||
# Verify ownership to reduce vulnerability to DoS(torage) or DoM(oney) attacks
|
||||
if (
|
||||
not (app := await get_oauth_application_by_id(app_id))
|
||||
or app.owner_id != user_id
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth App not found",
|
||||
)
|
||||
|
||||
# Check GCS configuration
|
||||
if not settings.config.media_gcs_bucket_name:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Media storage is not configured",
|
||||
)
|
||||
|
||||
# Validate content type
|
||||
content_type = file.content_type
|
||||
if content_type not in LOGO_ALLOWED_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid file type. Allowed: JPEG, PNG, WebP. Got: {content_type}",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
file_bytes = await file.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading logo file: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to read uploaded file",
|
||||
)
|
||||
|
||||
# Check file size
|
||||
if len(file_bytes) > LOGO_MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
"File too large. "
|
||||
f"Maximum size is {LOGO_MAX_FILE_SIZE // 1024 // 1024}MB"
|
||||
),
|
||||
)
|
||||
|
||||
# Validate image dimensions
|
||||
try:
|
||||
image = Image.open(io.BytesIO(file_bytes))
|
||||
width, height = image.size
|
||||
|
||||
if width != height:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo must be square. Got {width}x{height}",
|
||||
)
|
||||
|
||||
if width < LOGO_MIN_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo too small. Minimum {LOGO_MIN_SIZE}x{LOGO_MIN_SIZE}. "
|
||||
f"Got {width}x{height}",
|
||||
)
|
||||
|
||||
if width > LOGO_MAX_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo too large. Maximum {LOGO_MAX_SIZE}x{LOGO_MAX_SIZE}. "
|
||||
f"Got {width}x{height}",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating logo image: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid image file",
|
||||
)
|
||||
|
||||
# Scan for viruses
|
||||
filename = file.filename or "logo"
|
||||
await scan_content_safe(file_bytes, filename=filename)
|
||||
|
||||
# Generate unique filename
|
||||
file_ext = os.path.splitext(filename)[1].lower() or ".png"
|
||||
unique_filename = f"{uuid.uuid4()}{file_ext}"
|
||||
storage_path = f"oauth-apps/{app_id}/logo/{unique_filename}"
|
||||
|
||||
# Upload to GCS
|
||||
try:
|
||||
async with async_storage.Storage() as async_client:
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
|
||||
await async_client.upload(
|
||||
bucket_name, storage_path, file_bytes, content_type=content_type
|
||||
)
|
||||
|
||||
logo_url = f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading logo to GCS: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to upload logo",
|
||||
)
|
||||
|
||||
# Delete the current app logo file (if any and it's in our cloud storage)
|
||||
await _delete_app_current_logo_file(app)
|
||||
|
||||
# Update the app with the new logo URL
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
logo_url=logo_url,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth app {updated_app.name} (#{app_id}) logo uploaded by user #{user_id}"
|
||||
)
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
async def _delete_app_current_logo_file(app: OAuthApplicationInfo):
|
||||
"""
|
||||
Delete the current logo file for the given app, if there is one in our cloud storage
|
||||
"""
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
storage_base_url = f"https://storage.googleapis.com/{bucket_name}/"
|
||||
|
||||
if app.logo_url and app.logo_url.startswith(storage_base_url):
|
||||
# Parse blob path from URL: https://storage.googleapis.com/{bucket}/{path}
|
||||
old_path = app.logo_url.replace(storage_base_url, "")
|
||||
try:
|
||||
async with async_storage.Storage() as async_client:
|
||||
await async_client.delete(bucket_name, old_path)
|
||||
logger.info(f"Deleted old logo for OAuth app #{app.id}: {old_path}")
|
||||
except Exception as e:
|
||||
# Log but don't fail - the new logo was uploaded successfully
|
||||
logger.warning(
|
||||
f"Failed to delete old logo for OAuth app #{app.id}: {e}", exc_info=e
|
||||
)
|
||||
1784
autogpt_platform/backend/backend/api/features/oauth_test.py
Normal file
1784
autogpt_platform/backend/backend/api/features/oauth_test.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -6,9 +6,9 @@ import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.v2.otto.models as otto_models
|
||||
import backend.server.v2.otto.routes as otto_routes
|
||||
from backend.server.v2.otto.service import OttoService
|
||||
from . import models as otto_models
|
||||
from . import routes as otto_routes
|
||||
from .service import OttoService
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(otto_routes.router)
|
||||
@@ -4,12 +4,15 @@ from typing import Annotated
|
||||
from fastapi import APIRouter, Body, HTTPException, Query, Security
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from backend.api.utils.api_key_auth import APIKeyAuthenticator
|
||||
from backend.data.user import (
|
||||
get_user_by_email,
|
||||
set_user_email_verification,
|
||||
unsubscribe_user_by_token,
|
||||
)
|
||||
from backend.server.routers.postmark.models import (
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .models import (
|
||||
PostmarkBounceEnum,
|
||||
PostmarkBounceWebhook,
|
||||
PostmarkClickWebhook,
|
||||
@@ -19,8 +22,6 @@ from backend.server.routers.postmark.models import (
|
||||
PostmarkSubscriptionChangeWebhook,
|
||||
PostmarkWebhook,
|
||||
)
|
||||
from backend.server.utils.api_key_auth import APIKeyAuthenticator
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
@@ -1,8 +1,9 @@
|
||||
from typing import Literal
|
||||
|
||||
import backend.server.v2.store.db
|
||||
from backend.util.cache import cached
|
||||
|
||||
from . import db as store_db
|
||||
|
||||
##############################################
|
||||
############### Caches #######################
|
||||
##############################################
|
||||
@@ -29,7 +30,7 @@ async def _get_cached_store_agents(
|
||||
page_size: int,
|
||||
):
|
||||
"""Cached helper to get store agents."""
|
||||
return await backend.server.v2.store.db.get_store_agents(
|
||||
return await store_db.get_store_agents(
|
||||
featured=featured,
|
||||
creators=[creator] if creator else None,
|
||||
sorted_by=sorted_by,
|
||||
@@ -42,10 +43,12 @@ async def _get_cached_store_agents(
|
||||
|
||||
# Cache individual agent details for 15 minutes
|
||||
@cached(maxsize=200, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_agent_details(username: str, agent_name: str):
|
||||
async def _get_cached_agent_details(
|
||||
username: str, agent_name: str, include_changelog: bool = False
|
||||
):
|
||||
"""Cached helper to get agent details."""
|
||||
return await backend.server.v2.store.db.get_store_agent_details(
|
||||
username=username, agent_name=agent_name
|
||||
return await store_db.get_store_agent_details(
|
||||
username=username, agent_name=agent_name, include_changelog=include_changelog
|
||||
)
|
||||
|
||||
|
||||
@@ -59,7 +62,7 @@ async def _get_cached_store_creators(
|
||||
page_size: int,
|
||||
):
|
||||
"""Cached helper to get store creators."""
|
||||
return await backend.server.v2.store.db.get_store_creators(
|
||||
return await store_db.get_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
sorted_by=sorted_by,
|
||||
@@ -72,6 +75,4 @@ async def _get_cached_store_creators(
|
||||
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_creator_details(username: str):
|
||||
"""Cached helper to get creator details."""
|
||||
return await backend.server.v2.store.db.get_store_creator_details(
|
||||
username=username.lower()
|
||||
)
|
||||
return await store_db.get_store_creator_details(username=username.lower())
|
||||
@@ -9,9 +9,7 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.server.v2.store.model
|
||||
from backend.data.db import transaction
|
||||
from backend.data.db import query_raw_with_schema, transaction
|
||||
from backend.data.graph import (
|
||||
GraphMeta,
|
||||
GraphModel,
|
||||
@@ -29,6 +27,9 @@ from backend.notifications.notifications import queue_notification_async
|
||||
from backend.util.exceptions import DatabaseError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from . import exceptions as store_exceptions
|
||||
from . import model as store_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
@@ -46,7 +47,7 @@ async def get_store_agents(
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> backend.server.v2.store.model.StoreAgentsResponse:
|
||||
) -> store_model.StoreAgentsResponse:
|
||||
"""
|
||||
Get PUBLIC store agents from the StoreAgent view
|
||||
"""
|
||||
@@ -73,10 +74,10 @@ async def get_store_agents(
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
# Convert raw results to StoreAgent models
|
||||
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
try:
|
||||
store_agent = backend.server.v2.store.model.StoreAgent(
|
||||
store_agent = store_model.StoreAgent(
|
||||
slug=agent["slug"],
|
||||
agent_name=agent["agent_name"],
|
||||
agent_image=(
|
||||
@@ -122,11 +123,11 @@ async def get_store_agents(
|
||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
try:
|
||||
# Create the StoreAgent object safely
|
||||
store_agent = backend.server.v2.store.model.StoreAgent(
|
||||
store_agent = store_model.StoreAgent(
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_image=agent.agent_image[0] if agent.agent_image else "",
|
||||
@@ -148,9 +149,9 @@ async def get_store_agents(
|
||||
continue
|
||||
|
||||
logger.debug(f"Found {len(store_agents)} agents")
|
||||
return backend.server.v2.store.model.StoreAgentsResponse(
|
||||
return store_model.StoreAgentsResponse(
|
||||
agents=store_agents,
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=page,
|
||||
total_items=total,
|
||||
total_pages=total_pages,
|
||||
@@ -181,8 +182,8 @@ async def log_search_term(search_query: str):
|
||||
|
||||
|
||||
async def get_store_agent_details(
|
||||
username: str, agent_name: str
|
||||
) -> backend.server.v2.store.model.StoreAgentDetails:
|
||||
username: str, agent_name: str, include_changelog: bool = False
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""Get PUBLIC store agent details from the StoreAgent view"""
|
||||
logger.debug(f"Getting store agent details for {username}/{agent_name}")
|
||||
|
||||
@@ -193,7 +194,7 @@ async def get_store_agent_details(
|
||||
|
||||
if not agent:
|
||||
logger.warning(f"Agent not found: {username}/{agent_name}")
|
||||
raise backend.server.v2.store.exceptions.AgentNotFoundError(
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
f"Agent {username}/{agent_name} not found"
|
||||
)
|
||||
|
||||
@@ -246,8 +247,29 @@ async def get_store_agent_details(
|
||||
else:
|
||||
recommended_schedule_cron = None
|
||||
|
||||
# Fetch changelog data if requested
|
||||
changelog_data = None
|
||||
if include_changelog and store_listing:
|
||||
changelog_versions = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_many(
|
||||
where={
|
||||
"storeListingId": store_listing.id,
|
||||
"submissionStatus": prisma.enums.SubmissionStatus.APPROVED,
|
||||
},
|
||||
order=[{"version": "desc"}],
|
||||
)
|
||||
)
|
||||
changelog_data = [
|
||||
store_model.ChangelogEntry(
|
||||
version=str(version.version),
|
||||
changes_summary=version.changesSummary or "No changes recorded",
|
||||
date=version.createdAt,
|
||||
)
|
||||
for version in changelog_versions
|
||||
]
|
||||
|
||||
logger.debug(f"Found agent details for {username}/{agent_name}")
|
||||
return backend.server.v2.store.model.StoreAgentDetails(
|
||||
return store_model.StoreAgentDetails(
|
||||
store_listing_version_id=agent.storeListingVersionId,
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
@@ -262,12 +284,15 @@ async def get_store_agent_details(
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
versions=agent.versions,
|
||||
agentGraphVersions=agent.agentGraphVersions,
|
||||
agentGraphId=agent.agentGraphId,
|
||||
last_updated=agent.updated_at,
|
||||
active_version_id=active_version_id,
|
||||
has_approved_version=has_approved_version,
|
||||
recommended_schedule_cron=recommended_schedule_cron,
|
||||
changelog=changelog_data,
|
||||
)
|
||||
except backend.server.v2.store.exceptions.AgentNotFoundError:
|
||||
except store_exceptions.AgentNotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting store agent details: {e}")
|
||||
@@ -303,7 +328,7 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
||||
|
||||
async def get_store_agent_by_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> backend.server.v2.store.model.StoreAgentDetails:
|
||||
) -> store_model.StoreAgentDetails:
|
||||
logger.debug(f"Getting store agent details for {store_listing_version_id}")
|
||||
|
||||
try:
|
||||
@@ -313,12 +338,12 @@ async def get_store_agent_by_version_id(
|
||||
|
||||
if not agent:
|
||||
logger.warning(f"Agent not found: {store_listing_version_id}")
|
||||
raise backend.server.v2.store.exceptions.AgentNotFoundError(
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
f"Agent {store_listing_version_id} not found"
|
||||
)
|
||||
|
||||
logger.debug(f"Found agent details for {store_listing_version_id}")
|
||||
return backend.server.v2.store.model.StoreAgentDetails(
|
||||
return store_model.StoreAgentDetails(
|
||||
store_listing_version_id=agent.storeListingVersionId,
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
@@ -333,9 +358,11 @@ async def get_store_agent_by_version_id(
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
versions=agent.versions,
|
||||
agentGraphVersions=agent.agentGraphVersions,
|
||||
agentGraphId=agent.agentGraphId,
|
||||
last_updated=agent.updated_at,
|
||||
)
|
||||
except backend.server.v2.store.exceptions.AgentNotFoundError:
|
||||
except store_exceptions.AgentNotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting store agent details: {e}")
|
||||
@@ -348,7 +375,7 @@ async def get_store_creators(
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> backend.server.v2.store.model.CreatorsResponse:
|
||||
) -> store_model.CreatorsResponse:
|
||||
"""Get PUBLIC store creators from the Creator view"""
|
||||
logger.debug(
|
||||
f"Getting store creators. featured={featured}, search={search_query}, sorted_by={sorted_by}, page={page}"
|
||||
@@ -423,7 +450,7 @@ async def get_store_creators(
|
||||
|
||||
# Convert to response model
|
||||
creator_models = [
|
||||
backend.server.v2.store.model.Creator(
|
||||
store_model.Creator(
|
||||
username=creator.username,
|
||||
name=creator.name,
|
||||
description=creator.description,
|
||||
@@ -437,9 +464,9 @@ async def get_store_creators(
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(creator_models)} creators")
|
||||
return backend.server.v2.store.model.CreatorsResponse(
|
||||
return store_model.CreatorsResponse(
|
||||
creators=creator_models,
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=page,
|
||||
total_items=total,
|
||||
total_pages=total_pages,
|
||||
@@ -453,7 +480,7 @@ async def get_store_creators(
|
||||
|
||||
async def get_store_creator_details(
|
||||
username: str,
|
||||
) -> backend.server.v2.store.model.CreatorDetails:
|
||||
) -> store_model.CreatorDetails:
|
||||
logger.debug(f"Getting store creator details for {username}")
|
||||
|
||||
try:
|
||||
@@ -464,12 +491,10 @@ async def get_store_creator_details(
|
||||
|
||||
if not creator:
|
||||
logger.warning(f"Creator not found: {username}")
|
||||
raise backend.server.v2.store.exceptions.CreatorNotFoundError(
|
||||
f"Creator {username} not found"
|
||||
)
|
||||
raise store_exceptions.CreatorNotFoundError(f"Creator {username} not found")
|
||||
|
||||
logger.debug(f"Found creator details for {username}")
|
||||
return backend.server.v2.store.model.CreatorDetails(
|
||||
return store_model.CreatorDetails(
|
||||
name=creator.name,
|
||||
username=creator.username,
|
||||
description=creator.description,
|
||||
@@ -479,7 +504,7 @@ async def get_store_creator_details(
|
||||
agent_runs=creator.agent_runs,
|
||||
top_categories=creator.top_categories,
|
||||
)
|
||||
except backend.server.v2.store.exceptions.CreatorNotFoundError:
|
||||
except store_exceptions.CreatorNotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting store creator details: {e}")
|
||||
@@ -488,7 +513,7 @@ async def get_store_creator_details(
|
||||
|
||||
async def get_store_submissions(
|
||||
user_id: str, page: int = 1, page_size: int = 20
|
||||
) -> backend.server.v2.store.model.StoreSubmissionsResponse:
|
||||
) -> store_model.StoreSubmissionsResponse:
|
||||
"""Get store submissions for the authenticated user -- not an admin"""
|
||||
logger.debug(f"Getting store submissions for user {user_id}, page={page}")
|
||||
|
||||
@@ -513,7 +538,7 @@ async def get_store_submissions(
|
||||
# Convert to response models
|
||||
submission_models = []
|
||||
for sub in submissions:
|
||||
submission_model = backend.server.v2.store.model.StoreSubmission(
|
||||
submission_model = store_model.StoreSubmission(
|
||||
agent_id=sub.agent_id,
|
||||
agent_version=sub.agent_version,
|
||||
name=sub.name,
|
||||
@@ -538,9 +563,9 @@ async def get_store_submissions(
|
||||
submission_models.append(submission_model)
|
||||
|
||||
logger.debug(f"Found {len(submission_models)} submissions")
|
||||
return backend.server.v2.store.model.StoreSubmissionsResponse(
|
||||
return store_model.StoreSubmissionsResponse(
|
||||
submissions=submission_models,
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=page,
|
||||
total_items=total,
|
||||
total_pages=total_pages,
|
||||
@@ -551,9 +576,9 @@ async def get_store_submissions(
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching store submissions: {e}")
|
||||
# Return empty response rather than exposing internal errors
|
||||
return backend.server.v2.store.model.StoreSubmissionsResponse(
|
||||
return store_model.StoreSubmissionsResponse(
|
||||
submissions=[],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=page,
|
||||
total_items=0,
|
||||
total_pages=0,
|
||||
@@ -586,7 +611,7 @@ async def delete_store_submission(
|
||||
|
||||
if not submission:
|
||||
logger.warning(f"Submission not found for user {user_id}: {submission_id}")
|
||||
raise backend.server.v2.store.exceptions.SubmissionNotFoundError(
|
||||
raise store_exceptions.SubmissionNotFoundError(
|
||||
f"Submission not found for this user. User ID: {user_id}, Submission ID: {submission_id}"
|
||||
)
|
||||
|
||||
@@ -618,7 +643,7 @@ async def create_store_submission(
|
||||
categories: list[str] = [],
|
||||
changes_summary: str | None = "Initial Submission",
|
||||
recommended_schedule_cron: str | None = None,
|
||||
) -> backend.server.v2.store.model.StoreSubmission:
|
||||
) -> store_model.StoreSubmission:
|
||||
"""
|
||||
Create the first (and only) store listing and thus submission as a normal user
|
||||
|
||||
@@ -659,7 +684,7 @@ async def create_store_submission(
|
||||
logger.warning(
|
||||
f"Agent not found for user {user_id}: {agent_id} v{agent_version}"
|
||||
)
|
||||
raise backend.server.v2.store.exceptions.AgentNotFoundError(
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||
)
|
||||
|
||||
@@ -732,7 +757,7 @@ async def create_store_submission(
|
||||
|
||||
logger.debug(f"Created store listing for agent {agent_id}")
|
||||
# Return submission details
|
||||
return backend.server.v2.store.model.StoreSubmission(
|
||||
return store_model.StoreSubmission(
|
||||
agent_id=agent_id,
|
||||
agent_version=agent_version,
|
||||
name=name,
|
||||
@@ -755,7 +780,7 @@ async def create_store_submission(
|
||||
logger.debug(
|
||||
f"Slug '{slug}' is already in use by another agent (agent_id: {agent_id}) for user {user_id}"
|
||||
)
|
||||
raise backend.server.v2.store.exceptions.SlugAlreadyInUseError(
|
||||
raise store_exceptions.SlugAlreadyInUseError(
|
||||
f"The URL slug '{slug}' is already in use by another one of your agents. Please choose a different slug."
|
||||
) from exc
|
||||
else:
|
||||
@@ -764,8 +789,8 @@ async def create_store_submission(
|
||||
f"Unique constraint violated (not slug): {error_str}"
|
||||
) from exc
|
||||
except (
|
||||
backend.server.v2.store.exceptions.AgentNotFoundError,
|
||||
backend.server.v2.store.exceptions.ListingExistsError,
|
||||
store_exceptions.AgentNotFoundError,
|
||||
store_exceptions.ListingExistsError,
|
||||
):
|
||||
raise
|
||||
except prisma.errors.PrismaError as e:
|
||||
@@ -786,7 +811,7 @@ async def edit_store_submission(
|
||||
changes_summary: str | None = "Update submission",
|
||||
recommended_schedule_cron: str | None = None,
|
||||
instructions: str | None = None,
|
||||
) -> backend.server.v2.store.model.StoreSubmission:
|
||||
) -> store_model.StoreSubmission:
|
||||
"""
|
||||
Edit an existing store listing submission.
|
||||
|
||||
@@ -828,7 +853,7 @@ async def edit_store_submission(
|
||||
)
|
||||
|
||||
if not current_version:
|
||||
raise backend.server.v2.store.exceptions.SubmissionNotFoundError(
|
||||
raise store_exceptions.SubmissionNotFoundError(
|
||||
f"Store listing version not found: {store_listing_version_id}"
|
||||
)
|
||||
|
||||
@@ -837,7 +862,7 @@ async def edit_store_submission(
|
||||
not current_version.StoreListing
|
||||
or current_version.StoreListing.owningUserId != user_id
|
||||
):
|
||||
raise backend.server.v2.store.exceptions.UnauthorizedError(
|
||||
raise store_exceptions.UnauthorizedError(
|
||||
f"User {user_id} does not own submission {store_listing_version_id}"
|
||||
)
|
||||
|
||||
@@ -846,7 +871,7 @@ async def edit_store_submission(
|
||||
|
||||
# Check if we can edit this submission
|
||||
if current_version.submissionStatus == prisma.enums.SubmissionStatus.REJECTED:
|
||||
raise backend.server.v2.store.exceptions.InvalidOperationError(
|
||||
raise store_exceptions.InvalidOperationError(
|
||||
"Cannot edit a rejected submission"
|
||||
)
|
||||
|
||||
@@ -895,7 +920,7 @@ async def edit_store_submission(
|
||||
|
||||
if not updated_version:
|
||||
raise DatabaseError("Failed to update store listing version")
|
||||
return backend.server.v2.store.model.StoreSubmission(
|
||||
return store_model.StoreSubmission(
|
||||
agent_id=current_version.agentGraphId,
|
||||
agent_version=current_version.agentGraphVersion,
|
||||
name=name,
|
||||
@@ -916,16 +941,16 @@ async def edit_store_submission(
|
||||
)
|
||||
|
||||
else:
|
||||
raise backend.server.v2.store.exceptions.InvalidOperationError(
|
||||
raise store_exceptions.InvalidOperationError(
|
||||
f"Cannot edit submission with status: {current_version.submissionStatus}"
|
||||
)
|
||||
|
||||
except (
|
||||
backend.server.v2.store.exceptions.SubmissionNotFoundError,
|
||||
backend.server.v2.store.exceptions.UnauthorizedError,
|
||||
backend.server.v2.store.exceptions.AgentNotFoundError,
|
||||
backend.server.v2.store.exceptions.ListingExistsError,
|
||||
backend.server.v2.store.exceptions.InvalidOperationError,
|
||||
store_exceptions.SubmissionNotFoundError,
|
||||
store_exceptions.UnauthorizedError,
|
||||
store_exceptions.AgentNotFoundError,
|
||||
store_exceptions.ListingExistsError,
|
||||
store_exceptions.InvalidOperationError,
|
||||
):
|
||||
raise
|
||||
except prisma.errors.PrismaError as e:
|
||||
@@ -948,7 +973,7 @@ async def create_store_version(
|
||||
categories: list[str] = [],
|
||||
changes_summary: str | None = "Initial submission",
|
||||
recommended_schedule_cron: str | None = None,
|
||||
) -> backend.server.v2.store.model.StoreSubmission:
|
||||
) -> store_model.StoreSubmission:
|
||||
"""
|
||||
Create a new version for an existing store listing
|
||||
|
||||
@@ -981,7 +1006,7 @@ async def create_store_version(
|
||||
)
|
||||
|
||||
if not listing:
|
||||
raise backend.server.v2.store.exceptions.ListingNotFoundError(
|
||||
raise store_exceptions.ListingNotFoundError(
|
||||
f"Store listing not found. User ID: {user_id}, Listing ID: {store_listing_id}"
|
||||
)
|
||||
|
||||
@@ -993,7 +1018,7 @@ async def create_store_version(
|
||||
)
|
||||
|
||||
if not agent:
|
||||
raise backend.server.v2.store.exceptions.AgentNotFoundError(
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||
)
|
||||
|
||||
@@ -1028,7 +1053,7 @@ async def create_store_version(
|
||||
f"Created new version for listing {store_listing_id} of agent {agent_id}"
|
||||
)
|
||||
# Return submission details
|
||||
return backend.server.v2.store.model.StoreSubmission(
|
||||
return store_model.StoreSubmission(
|
||||
agent_id=agent_id,
|
||||
agent_version=agent_version,
|
||||
name=name,
|
||||
@@ -1055,7 +1080,7 @@ async def create_store_review(
|
||||
store_listing_version_id: str,
|
||||
score: int,
|
||||
comments: str | None = None,
|
||||
) -> backend.server.v2.store.model.StoreReview:
|
||||
) -> store_model.StoreReview:
|
||||
"""Create a review for a store listing as a user to detail their experience"""
|
||||
try:
|
||||
data = prisma.types.StoreListingReviewUpsertInput(
|
||||
@@ -1080,7 +1105,7 @@ async def create_store_review(
|
||||
data=data,
|
||||
)
|
||||
|
||||
return backend.server.v2.store.model.StoreReview(
|
||||
return store_model.StoreReview(
|
||||
score=review.score,
|
||||
comments=review.comments,
|
||||
)
|
||||
@@ -1092,7 +1117,7 @@ async def create_store_review(
|
||||
|
||||
async def get_user_profile(
|
||||
user_id: str,
|
||||
) -> backend.server.v2.store.model.ProfileDetails | None:
|
||||
) -> store_model.ProfileDetails | None:
|
||||
logger.debug(f"Getting user profile for {user_id}")
|
||||
|
||||
try:
|
||||
@@ -1102,7 +1127,7 @@ async def get_user_profile(
|
||||
|
||||
if not profile:
|
||||
return None
|
||||
return backend.server.v2.store.model.ProfileDetails(
|
||||
return store_model.ProfileDetails(
|
||||
name=profile.name,
|
||||
username=profile.username,
|
||||
description=profile.description,
|
||||
@@ -1115,8 +1140,8 @@ async def get_user_profile(
|
||||
|
||||
|
||||
async def update_profile(
|
||||
user_id: str, profile: backend.server.v2.store.model.Profile
|
||||
) -> backend.server.v2.store.model.CreatorDetails:
|
||||
user_id: str, profile: store_model.Profile
|
||||
) -> store_model.CreatorDetails:
|
||||
"""
|
||||
Update the store profile for a user or create a new one if it doesn't exist.
|
||||
Args:
|
||||
@@ -1139,7 +1164,7 @@ async def update_profile(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
if not existing_profile:
|
||||
raise backend.server.v2.store.exceptions.ProfileNotFoundError(
|
||||
raise store_exceptions.ProfileNotFoundError(
|
||||
f"Profile not found for user {user_id}. This should not be possible."
|
||||
)
|
||||
|
||||
@@ -1175,7 +1200,7 @@ async def update_profile(
|
||||
logger.error(f"Failed to update profile for user {user_id}")
|
||||
raise DatabaseError("Failed to update profile")
|
||||
|
||||
return backend.server.v2.store.model.CreatorDetails(
|
||||
return store_model.CreatorDetails(
|
||||
name=updated_profile.name,
|
||||
username=updated_profile.username,
|
||||
description=updated_profile.description,
|
||||
@@ -1195,7 +1220,7 @@ async def get_my_agents(
|
||||
user_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> backend.server.v2.store.model.MyAgentsResponse:
|
||||
) -> store_model.MyAgentsResponse:
|
||||
"""Get the agents for the authenticated user"""
|
||||
logger.debug(f"Getting my agents for user {user_id}, page={page}")
|
||||
|
||||
@@ -1232,7 +1257,7 @@ async def get_my_agents(
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
my_agents = [
|
||||
backend.server.v2.store.model.MyAgent(
|
||||
store_model.MyAgent(
|
||||
agent_id=graph.id,
|
||||
agent_version=graph.version,
|
||||
agent_name=graph.name or "",
|
||||
@@ -1245,9 +1270,9 @@ async def get_my_agents(
|
||||
if (graph := library_agent.AgentGraph)
|
||||
]
|
||||
|
||||
return backend.server.v2.store.model.MyAgentsResponse(
|
||||
return store_model.MyAgentsResponse(
|
||||
agents=my_agents,
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=page,
|
||||
total_items=total,
|
||||
total_pages=total_pages,
|
||||
@@ -1394,7 +1419,7 @@ async def review_store_submission(
|
||||
external_comments: str,
|
||||
internal_comments: str,
|
||||
reviewer_id: str,
|
||||
) -> backend.server.v2.store.model.StoreSubmission:
|
||||
) -> store_model.StoreSubmission:
|
||||
"""Review a store listing submission as an admin."""
|
||||
try:
|
||||
store_listing_version = (
|
||||
@@ -1625,7 +1650,7 @@ async def review_store_submission(
|
||||
pass
|
||||
|
||||
# Convert to Pydantic model for consistency
|
||||
return backend.server.v2.store.model.StoreSubmission(
|
||||
return store_model.StoreSubmission(
|
||||
agent_id=submission.agentGraphId,
|
||||
agent_version=submission.agentGraphVersion,
|
||||
name=submission.name,
|
||||
@@ -1660,7 +1685,7 @@ async def get_admin_listings_with_versions(
|
||||
search_query: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> backend.server.v2.store.model.StoreListingsWithVersionsResponse:
|
||||
) -> store_model.StoreListingsWithVersionsResponse:
|
||||
"""
|
||||
Get store listings for admins with all their versions.
|
||||
|
||||
@@ -1759,10 +1784,10 @@ async def get_admin_listings_with_versions(
|
||||
# Convert to response models
|
||||
listings_with_versions = []
|
||||
for listing in listings:
|
||||
versions: list[backend.server.v2.store.model.StoreSubmission] = []
|
||||
versions: list[store_model.StoreSubmission] = []
|
||||
# If we have versions, turn them into StoreSubmission models
|
||||
for version in listing.Versions or []:
|
||||
version_model = backend.server.v2.store.model.StoreSubmission(
|
||||
version_model = store_model.StoreSubmission(
|
||||
agent_id=version.agentGraphId,
|
||||
agent_version=version.agentGraphVersion,
|
||||
name=version.name,
|
||||
@@ -1790,26 +1815,24 @@ async def get_admin_listings_with_versions(
|
||||
|
||||
creator_email = listing.OwningUser.email if listing.OwningUser else None
|
||||
|
||||
listing_with_versions = (
|
||||
backend.server.v2.store.model.StoreListingWithVersions(
|
||||
listing_id=listing.id,
|
||||
slug=listing.slug,
|
||||
agent_id=listing.agentGraphId,
|
||||
agent_version=listing.agentGraphVersion,
|
||||
active_version_id=listing.activeVersionId,
|
||||
has_approved_version=listing.hasApprovedVersion,
|
||||
creator_email=creator_email,
|
||||
latest_version=latest_version,
|
||||
versions=versions,
|
||||
)
|
||||
listing_with_versions = store_model.StoreListingWithVersions(
|
||||
listing_id=listing.id,
|
||||
slug=listing.slug,
|
||||
agent_id=listing.agentGraphId,
|
||||
agent_version=listing.agentGraphVersion,
|
||||
active_version_id=listing.activeVersionId,
|
||||
has_approved_version=listing.hasApprovedVersion,
|
||||
creator_email=creator_email,
|
||||
latest_version=latest_version,
|
||||
versions=versions,
|
||||
)
|
||||
|
||||
listings_with_versions.append(listing_with_versions)
|
||||
|
||||
logger.debug(f"Found {len(listings_with_versions)} listings for admin")
|
||||
return backend.server.v2.store.model.StoreListingsWithVersionsResponse(
|
||||
return store_model.StoreListingsWithVersionsResponse(
|
||||
listings=listings_with_versions,
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=page,
|
||||
total_items=total,
|
||||
total_pages=total_pages,
|
||||
@@ -1819,9 +1842,9 @@ async def get_admin_listings_with_versions(
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching admin store listings: {e}")
|
||||
# Return empty response rather than exposing internal errors
|
||||
return backend.server.v2.store.model.StoreListingsWithVersionsResponse(
|
||||
return store_model.StoreListingsWithVersionsResponse(
|
||||
listings=[],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
pagination=store_model.Pagination(
|
||||
current_page=page,
|
||||
total_items=0,
|
||||
total_pages=0,
|
||||
@@ -6,8 +6,8 @@ import prisma.models
|
||||
import pytest
|
||||
from prisma import Prisma
|
||||
|
||||
import backend.server.v2.store.db as db
|
||||
from backend.server.v2.store.model import Profile
|
||||
from . import db
|
||||
from .model import Profile
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -40,6 +40,8 @@ async def test_get_store_agents(mocker):
|
||||
runs=10,
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
agentGraphVersions=["1"],
|
||||
agentGraphId="test-graph-id",
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
useForOnboarding=False,
|
||||
@@ -83,6 +85,8 @@ async def test_get_store_agent_details(mocker):
|
||||
runs=10,
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
agentGraphVersions=["1"],
|
||||
agentGraphId="test-graph-id",
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
useForOnboarding=False,
|
||||
@@ -105,6 +109,8 @@ async def test_get_store_agent_details(mocker):
|
||||
runs=15,
|
||||
rating=4.8,
|
||||
versions=["1.0", "2.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id-active",
|
||||
updated_at=datetime.now(),
|
||||
is_available=True,
|
||||
useForOnboarding=False,
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user