mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-07 22:33:57 -05:00
Merge origin/dev into hackathon/copilot
Resolved conflicts from api restructuring: - backend/server/v2/* -> backend/api/features/* - Updated imports to use new paths - Kept chat/copilot functionality with new structure - Accepted openapi.json from dev (regenerate after merge) - Resolved useCredentialsInput naming conflict (singular) - NavbarView.tsx merged into Navbar.tsx
This commit is contained in:
@@ -11,7 +11,7 @@ jobs:
|
|||||||
stale:
|
stale:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/stale@v9
|
- uses: actions/stale@v10
|
||||||
with:
|
with:
|
||||||
# operations-per-run: 5000
|
# operations-per-run: 5000
|
||||||
stale-issue-message: >
|
stale-issue-message: >
|
||||||
|
|||||||
2
.github/workflows/repo-pr-label.yml
vendored
2
.github/workflows/repo-pr-label.yml
vendored
@@ -61,6 +61,6 @@ jobs:
|
|||||||
pull-requests: write
|
pull-requests: write
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/labeler@v5
|
- uses: actions/labeler@v6
|
||||||
with:
|
with:
|
||||||
sync-labels: true
|
sync-labels: true
|
||||||
|
|||||||
@@ -57,6 +57,9 @@ class APIKeySmith:
|
|||||||
|
|
||||||
def hash_key(self, raw_key: str) -> tuple[str, str]:
|
def hash_key(self, raw_key: str) -> tuple[str, str]:
|
||||||
"""Migrate a legacy hash to secure hash format."""
|
"""Migrate a legacy hash to secure hash format."""
|
||||||
|
if not raw_key.startswith(self.PREFIX):
|
||||||
|
raise ValueError("Key without 'agpt_' prefix would fail validation")
|
||||||
|
|
||||||
salt = self._generate_salt()
|
salt = self._generate_salt()
|
||||||
hash = self._hash_key_with_salt(raw_key, salt)
|
hash = self._hash_key_with_salt(raw_key, salt)
|
||||||
return hash, salt.hex()
|
return hash, salt.hex()
|
||||||
|
|||||||
@@ -1,29 +1,25 @@
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.openapi.utils import get_openapi
|
|
||||||
|
|
||||||
from .jwt_utils import bearer_jwt_auth
|
from .jwt_utils import bearer_jwt_auth
|
||||||
|
|
||||||
|
|
||||||
def add_auth_responses_to_openapi(app: FastAPI) -> None:
|
def add_auth_responses_to_openapi(app: FastAPI) -> None:
|
||||||
"""
|
"""
|
||||||
Set up custom OpenAPI schema generation that adds 401 responses
|
Patch a FastAPI instance's `openapi()` method to add 401 responses
|
||||||
to all authenticated endpoints.
|
to all authenticated endpoints.
|
||||||
|
|
||||||
This is needed when using HTTPBearer with auto_error=False to get proper
|
This is needed when using HTTPBearer with auto_error=False to get proper
|
||||||
401 responses instead of 403, but FastAPI only automatically adds security
|
401 responses instead of 403, but FastAPI only automatically adds security
|
||||||
responses when auto_error=True.
|
responses when auto_error=True.
|
||||||
"""
|
"""
|
||||||
|
# Wrap current method to allow stacking OpenAPI schema modifiers like this
|
||||||
|
wrapped_openapi = app.openapi
|
||||||
|
|
||||||
def custom_openapi():
|
def custom_openapi():
|
||||||
if app.openapi_schema:
|
if app.openapi_schema:
|
||||||
return app.openapi_schema
|
return app.openapi_schema
|
||||||
|
|
||||||
openapi_schema = get_openapi(
|
openapi_schema = wrapped_openapi()
|
||||||
title=app.title,
|
|
||||||
version=app.version,
|
|
||||||
description=app.description,
|
|
||||||
routes=app.routes,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add 401 response to all endpoints that have security requirements
|
# Add 401 response to all endpoints that have security requirements
|
||||||
for path, methods in openapi_schema["paths"].items():
|
for path, methods in openapi_schema["paths"].items():
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ import fastapi.testclient
|
|||||||
import pytest
|
import pytest
|
||||||
from pytest_snapshot.plugin import Snapshot
|
from pytest_snapshot.plugin import Snapshot
|
||||||
|
|
||||||
from backend.server.v2.myroute import router
|
from backend.api.features.myroute import router
|
||||||
|
|
||||||
app = fastapi.FastAPI()
|
app = fastapi.FastAPI()
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
@@ -149,7 +149,7 @@ These provide the easiest way to set up authentication mocking in test modules:
|
|||||||
import fastapi
|
import fastapi
|
||||||
import fastapi.testclient
|
import fastapi.testclient
|
||||||
import pytest
|
import pytest
|
||||||
from backend.server.v2.myroute import router
|
from backend.api.features.myroute import router
|
||||||
|
|
||||||
app = fastapi.FastAPI()
|
app = fastapi.FastAPI()
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
|
|||||||
@@ -3,12 +3,12 @@ from typing import Dict, Set
|
|||||||
|
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
|
|
||||||
|
from backend.api.model import NotificationPayload, WSMessage, WSMethod
|
||||||
from backend.data.execution import (
|
from backend.data.execution import (
|
||||||
ExecutionEventType,
|
ExecutionEventType,
|
||||||
GraphExecutionEvent,
|
GraphExecutionEvent,
|
||||||
NodeExecutionEvent,
|
NodeExecutionEvent,
|
||||||
)
|
)
|
||||||
from backend.server.model import NotificationPayload, WSMessage, WSMethod
|
|
||||||
|
|
||||||
_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
|
_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
|
||||||
ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
|
ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
|
||||||
@@ -4,13 +4,13 @@ from unittest.mock import AsyncMock
|
|||||||
import pytest
|
import pytest
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
|
|
||||||
|
from backend.api.conn_manager import ConnectionManager
|
||||||
|
from backend.api.model import NotificationPayload, WSMessage, WSMethod
|
||||||
from backend.data.execution import (
|
from backend.data.execution import (
|
||||||
ExecutionStatus,
|
ExecutionStatus,
|
||||||
GraphExecutionEvent,
|
GraphExecutionEvent,
|
||||||
NodeExecutionEvent,
|
NodeExecutionEvent,
|
||||||
)
|
)
|
||||||
from backend.server.conn_manager import ConnectionManager
|
|
||||||
from backend.server.model import NotificationPayload, WSMessage, WSMethod
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
25
autogpt_platform/backend/backend/api/external/fastapi_app.py
vendored
Normal file
25
autogpt_platform/backend/backend/api/external/fastapi_app.py
vendored
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from backend.api.middleware.security import SecurityHeadersMiddleware
|
||||||
|
from backend.monitoring.instrumentation import instrument_fastapi
|
||||||
|
|
||||||
|
from .v1.routes import v1_router
|
||||||
|
|
||||||
|
external_api = FastAPI(
|
||||||
|
title="AutoGPT External API",
|
||||||
|
description="External API for AutoGPT integrations",
|
||||||
|
docs_url="/docs",
|
||||||
|
version="1.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
external_api.add_middleware(SecurityHeadersMiddleware)
|
||||||
|
external_api.include_router(v1_router, prefix="/v1")
|
||||||
|
|
||||||
|
# Add Prometheus instrumentation
|
||||||
|
instrument_fastapi(
|
||||||
|
external_api,
|
||||||
|
service_name="external-api",
|
||||||
|
expose_endpoint=True,
|
||||||
|
endpoint="/metrics",
|
||||||
|
include_in_schema=True,
|
||||||
|
)
|
||||||
107
autogpt_platform/backend/backend/api/external/middleware.py
vendored
Normal file
107
autogpt_platform/backend/backend/api/external/middleware.py
vendored
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
from fastapi import HTTPException, Security, status
|
||||||
|
from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
from prisma.enums import APIKeyPermission
|
||||||
|
|
||||||
|
from backend.data.auth.api_key import APIKeyInfo, validate_api_key
|
||||||
|
from backend.data.auth.base import APIAuthorizationInfo
|
||||||
|
from backend.data.auth.oauth import (
|
||||||
|
InvalidClientError,
|
||||||
|
InvalidTokenError,
|
||||||
|
OAuthAccessTokenInfo,
|
||||||
|
validate_access_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||||
|
bearer_auth = HTTPBearer(auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
|
||||||
|
"""Middleware for API key authentication only"""
|
||||||
|
if api_key is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing API key"
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key_obj = await validate_api_key(api_key)
|
||||||
|
|
||||||
|
if not api_key_obj:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
||||||
|
)
|
||||||
|
|
||||||
|
return api_key_obj
|
||||||
|
|
||||||
|
|
||||||
|
async def require_access_token(
|
||||||
|
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
|
||||||
|
) -> OAuthAccessTokenInfo:
|
||||||
|
"""Middleware for OAuth access token authentication only"""
|
||||||
|
if bearer is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Missing Authorization header",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
token_info, _ = await validate_access_token(bearer.credentials)
|
||||||
|
except (InvalidClientError, InvalidTokenError) as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
|
||||||
|
|
||||||
|
return token_info
|
||||||
|
|
||||||
|
|
||||||
|
async def require_auth(
|
||||||
|
api_key: str | None = Security(api_key_header),
|
||||||
|
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
|
||||||
|
) -> APIAuthorizationInfo:
|
||||||
|
"""
|
||||||
|
Unified authentication middleware supporting both API keys and OAuth tokens.
|
||||||
|
|
||||||
|
Supports two authentication methods, which are checked in order:
|
||||||
|
1. X-API-Key header (existing API key authentication)
|
||||||
|
2. Authorization: Bearer <token> header (OAuth access token)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
APIAuthorizationInfo: base class of both APIKeyInfo and OAuthAccessTokenInfo.
|
||||||
|
"""
|
||||||
|
# Try API key first
|
||||||
|
if api_key is not None:
|
||||||
|
api_key_info = await validate_api_key(api_key)
|
||||||
|
if api_key_info:
|
||||||
|
return api_key_info
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try OAuth bearer token
|
||||||
|
if bearer is not None:
|
||||||
|
try:
|
||||||
|
token_info, _ = await validate_access_token(bearer.credentials)
|
||||||
|
return token_info
|
||||||
|
except (InvalidClientError, InvalidTokenError) as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
|
||||||
|
|
||||||
|
# No credentials provided
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Missing authentication. Provide API key or access token.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def require_permission(permission: APIKeyPermission):
|
||||||
|
"""
|
||||||
|
Dependency function for checking specific permissions
|
||||||
|
(works with API keys and OAuth tokens)
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def check_permission(
|
||||||
|
auth: APIAuthorizationInfo = Security(require_auth),
|
||||||
|
) -> APIAuthorizationInfo:
|
||||||
|
if permission not in auth.scopes:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"Missing required permission: {permission.value}",
|
||||||
|
)
|
||||||
|
return auth
|
||||||
|
|
||||||
|
return check_permission
|
||||||
@@ -16,7 +16,9 @@ from fastapi import APIRouter, Body, HTTPException, Path, Security, status
|
|||||||
from prisma.enums import APIKeyPermission
|
from prisma.enums import APIKeyPermission
|
||||||
from pydantic import BaseModel, Field, SecretStr
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
from backend.data.api_key import APIKeyInfo
|
from backend.api.external.middleware import require_permission
|
||||||
|
from backend.api.features.integrations.models import get_all_provider_names
|
||||||
|
from backend.data.auth.base import APIAuthorizationInfo
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
Credentials,
|
Credentials,
|
||||||
@@ -28,8 +30,6 @@ from backend.data.model import (
|
|||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.server.external.middleware import require_permission
|
|
||||||
from backend.server.integrations.models import get_all_provider_names
|
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -255,7 +255,7 @@ def _get_oauth_handler_for_external(
|
|||||||
|
|
||||||
@integrations_router.get("/providers", response_model=list[ProviderInfo])
|
@integrations_router.get("/providers", response_model=list[ProviderInfo])
|
||||||
async def list_providers(
|
async def list_providers(
|
||||||
api_key: APIKeyInfo = Security(
|
auth: APIAuthorizationInfo = Security(
|
||||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||||
),
|
),
|
||||||
) -> list[ProviderInfo]:
|
) -> list[ProviderInfo]:
|
||||||
@@ -319,7 +319,7 @@ async def list_providers(
|
|||||||
async def initiate_oauth(
|
async def initiate_oauth(
|
||||||
provider: Annotated[str, Path(title="The OAuth provider")],
|
provider: Annotated[str, Path(title="The OAuth provider")],
|
||||||
request: OAuthInitiateRequest,
|
request: OAuthInitiateRequest,
|
||||||
api_key: APIKeyInfo = Security(
|
auth: APIAuthorizationInfo = Security(
|
||||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||||
),
|
),
|
||||||
) -> OAuthInitiateResponse:
|
) -> OAuthInitiateResponse:
|
||||||
@@ -337,7 +337,10 @@ async def initiate_oauth(
|
|||||||
if not validate_callback_url(request.callback_url):
|
if not validate_callback_url(request.callback_url):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"Callback URL origin is not allowed. Allowed origins: {settings.config.external_oauth_callback_origins}",
|
detail=(
|
||||||
|
f"Callback URL origin is not allowed. "
|
||||||
|
f"Allowed origins: {settings.config.external_oauth_callback_origins}",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate provider
|
# Validate provider
|
||||||
@@ -359,13 +362,15 @@ async def initiate_oauth(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Store state token with external flow metadata
|
# Store state token with external flow metadata
|
||||||
|
# Note: initiated_by_api_key_id is only available for API key auth, not OAuth
|
||||||
|
api_key_id = getattr(auth, "id", None) if auth.type == "api_key" else None
|
||||||
state_token, code_challenge = await creds_manager.store.store_state_token(
|
state_token, code_challenge = await creds_manager.store.store_state_token(
|
||||||
user_id=api_key.user_id,
|
user_id=auth.user_id,
|
||||||
provider=provider if isinstance(provider_name, str) else provider_name.value,
|
provider=provider if isinstance(provider_name, str) else provider_name.value,
|
||||||
scopes=request.scopes,
|
scopes=request.scopes,
|
||||||
callback_url=request.callback_url,
|
callback_url=request.callback_url,
|
||||||
state_metadata=request.state_metadata,
|
state_metadata=request.state_metadata,
|
||||||
initiated_by_api_key_id=api_key.id,
|
initiated_by_api_key_id=api_key_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build login URL
|
# Build login URL
|
||||||
@@ -393,7 +398,7 @@ async def initiate_oauth(
|
|||||||
async def complete_oauth(
|
async def complete_oauth(
|
||||||
provider: Annotated[str, Path(title="The OAuth provider")],
|
provider: Annotated[str, Path(title="The OAuth provider")],
|
||||||
request: OAuthCompleteRequest,
|
request: OAuthCompleteRequest,
|
||||||
api_key: APIKeyInfo = Security(
|
auth: APIAuthorizationInfo = Security(
|
||||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||||
),
|
),
|
||||||
) -> OAuthCompleteResponse:
|
) -> OAuthCompleteResponse:
|
||||||
@@ -406,7 +411,7 @@ async def complete_oauth(
|
|||||||
"""
|
"""
|
||||||
# Verify state token
|
# Verify state token
|
||||||
valid_state = await creds_manager.store.verify_state_token(
|
valid_state = await creds_manager.store.verify_state_token(
|
||||||
api_key.user_id, request.state_token, provider
|
auth.user_id, request.state_token, provider
|
||||||
)
|
)
|
||||||
|
|
||||||
if not valid_state:
|
if not valid_state:
|
||||||
@@ -453,7 +458,7 @@ async def complete_oauth(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Store credentials
|
# Store credentials
|
||||||
await creds_manager.create(api_key.user_id, credentials)
|
await creds_manager.create(auth.user_id, credentials)
|
||||||
|
|
||||||
logger.info(f"Successfully completed external OAuth for provider {provider}")
|
logger.info(f"Successfully completed external OAuth for provider {provider}")
|
||||||
|
|
||||||
@@ -470,7 +475,7 @@ async def complete_oauth(
|
|||||||
|
|
||||||
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
|
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
|
||||||
async def list_credentials(
|
async def list_credentials(
|
||||||
api_key: APIKeyInfo = Security(
|
auth: APIAuthorizationInfo = Security(
|
||||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||||
),
|
),
|
||||||
) -> list[CredentialSummary]:
|
) -> list[CredentialSummary]:
|
||||||
@@ -479,7 +484,7 @@ async def list_credentials(
|
|||||||
|
|
||||||
Returns metadata about each credential without exposing sensitive tokens.
|
Returns metadata about each credential without exposing sensitive tokens.
|
||||||
"""
|
"""
|
||||||
credentials = await creds_manager.store.get_all_creds(api_key.user_id)
|
credentials = await creds_manager.store.get_all_creds(auth.user_id)
|
||||||
return [
|
return [
|
||||||
CredentialSummary(
|
CredentialSummary(
|
||||||
id=cred.id,
|
id=cred.id,
|
||||||
@@ -499,7 +504,7 @@ async def list_credentials(
|
|||||||
)
|
)
|
||||||
async def list_credentials_by_provider(
|
async def list_credentials_by_provider(
|
||||||
provider: Annotated[str, Path(title="The provider to list credentials for")],
|
provider: Annotated[str, Path(title="The provider to list credentials for")],
|
||||||
api_key: APIKeyInfo = Security(
|
auth: APIAuthorizationInfo = Security(
|
||||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||||
),
|
),
|
||||||
) -> list[CredentialSummary]:
|
) -> list[CredentialSummary]:
|
||||||
@@ -507,7 +512,7 @@ async def list_credentials_by_provider(
|
|||||||
List credentials for a specific provider.
|
List credentials for a specific provider.
|
||||||
"""
|
"""
|
||||||
credentials = await creds_manager.store.get_creds_by_provider(
|
credentials = await creds_manager.store.get_creds_by_provider(
|
||||||
api_key.user_id, provider
|
auth.user_id, provider
|
||||||
)
|
)
|
||||||
return [
|
return [
|
||||||
CredentialSummary(
|
CredentialSummary(
|
||||||
@@ -536,7 +541,7 @@ async def create_credential(
|
|||||||
CreateUserPasswordCredentialRequest,
|
CreateUserPasswordCredentialRequest,
|
||||||
CreateHostScopedCredentialRequest,
|
CreateHostScopedCredentialRequest,
|
||||||
] = Body(..., discriminator="type"),
|
] = Body(..., discriminator="type"),
|
||||||
api_key: APIKeyInfo = Security(
|
auth: APIAuthorizationInfo = Security(
|
||||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||||
),
|
),
|
||||||
) -> CreateCredentialResponse:
|
) -> CreateCredentialResponse:
|
||||||
@@ -591,7 +596,7 @@ async def create_credential(
|
|||||||
|
|
||||||
# Store credentials
|
# Store credentials
|
||||||
try:
|
try:
|
||||||
await creds_manager.create(api_key.user_id, credentials)
|
await creds_manager.create(auth.user_id, credentials)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to store credentials: {e}")
|
logger.error(f"Failed to store credentials: {e}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -623,7 +628,7 @@ class DeleteCredentialResponse(BaseModel):
|
|||||||
async def delete_credential(
|
async def delete_credential(
|
||||||
provider: Annotated[str, Path(title="The provider")],
|
provider: Annotated[str, Path(title="The provider")],
|
||||||
cred_id: Annotated[str, Path(title="The credential ID to delete")],
|
cred_id: Annotated[str, Path(title="The credential ID to delete")],
|
||||||
api_key: APIKeyInfo = Security(
|
auth: APIAuthorizationInfo = Security(
|
||||||
require_permission(APIKeyPermission.DELETE_INTEGRATIONS)
|
require_permission(APIKeyPermission.DELETE_INTEGRATIONS)
|
||||||
),
|
),
|
||||||
) -> DeleteCredentialResponse:
|
) -> DeleteCredentialResponse:
|
||||||
@@ -634,7 +639,7 @@ async def delete_credential(
|
|||||||
use the main API's delete endpoint which handles webhook cleanup and
|
use the main API's delete endpoint which handles webhook cleanup and
|
||||||
token revocation.
|
token revocation.
|
||||||
"""
|
"""
|
||||||
creds = await creds_manager.store.get_creds_by_id(api_key.user_id, cred_id)
|
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
|
||||||
if not creds:
|
if not creds:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||||
@@ -645,6 +650,6 @@ async def delete_credential(
|
|||||||
detail="Credentials do not match the specified provider",
|
detail="Credentials do not match the specified provider",
|
||||||
)
|
)
|
||||||
|
|
||||||
await creds_manager.delete(api_key.user_id, cred_id)
|
await creds_manager.delete(auth.user_id, cred_id)
|
||||||
|
|
||||||
return DeleteCredentialResponse(deleted=True, credentials_id=cred_id)
|
return DeleteCredentialResponse(deleted=True, credentials_id=cred_id)
|
||||||
@@ -5,46 +5,60 @@ from typing import Annotated, Any, Literal, Optional, Sequence
|
|||||||
|
|
||||||
from fastapi import APIRouter, Body, HTTPException, Security
|
from fastapi import APIRouter, Body, HTTPException, Security
|
||||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
import backend.api.features.store.cache as store_cache
|
||||||
|
import backend.api.features.store.model as store_model
|
||||||
import backend.data.block
|
import backend.data.block
|
||||||
import backend.server.v2.store.cache as store_cache
|
from backend.api.external.middleware import require_permission
|
||||||
import backend.server.v2.store.model as store_model
|
|
||||||
from backend.data import execution as execution_db
|
from backend.data import execution as execution_db
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
from backend.data.api_key import APIKeyInfo
|
from backend.data import user as user_db
|
||||||
|
from backend.data.auth.base import APIAuthorizationInfo
|
||||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||||
from backend.executor.utils import add_graph_execution
|
from backend.executor.utils import add_graph_execution
|
||||||
from backend.server.external.middleware import require_permission
|
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
from .integrations import integrations_router
|
||||||
|
from .tools import tools_router
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
v1_router = APIRouter()
|
v1_router = APIRouter()
|
||||||
|
|
||||||
|
v1_router.include_router(integrations_router)
|
||||||
class NodeOutput(TypedDict):
|
v1_router.include_router(tools_router)
|
||||||
key: str
|
|
||||||
value: Any
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionNode(TypedDict):
|
class UserInfoResponse(BaseModel):
|
||||||
node_id: str
|
id: str
|
||||||
input: Any
|
name: Optional[str]
|
||||||
output: dict[str, Any]
|
email: str
|
||||||
|
timezone: str = Field(
|
||||||
|
description="The user's last known timezone (e.g. 'Europe/Amsterdam'), "
|
||||||
|
"or 'not-set' if not set"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExecutionNodeOutput(TypedDict):
|
@v1_router.get(
|
||||||
node_id: str
|
path="/me",
|
||||||
outputs: list[NodeOutput]
|
tags=["user", "meta"],
|
||||||
|
)
|
||||||
|
async def get_user_info(
|
||||||
|
auth: APIAuthorizationInfo = Security(
|
||||||
|
require_permission(APIKeyPermission.IDENTITY)
|
||||||
|
),
|
||||||
|
) -> UserInfoResponse:
|
||||||
|
user = await user_db.get_user_by_id(auth.user_id)
|
||||||
|
|
||||||
|
return UserInfoResponse(
|
||||||
class GraphExecutionResult(TypedDict):
|
id=user.id,
|
||||||
execution_id: str
|
name=user.name,
|
||||||
status: str
|
email=user.email,
|
||||||
nodes: list[ExecutionNode]
|
timezone=user.timezone,
|
||||||
output: Optional[list[dict[str, str]]]
|
)
|
||||||
|
|
||||||
|
|
||||||
@v1_router.get(
|
@v1_router.get(
|
||||||
@@ -65,7 +79,9 @@ async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
|||||||
async def execute_graph_block(
|
async def execute_graph_block(
|
||||||
block_id: str,
|
block_id: str,
|
||||||
data: BlockInput,
|
data: BlockInput,
|
||||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
|
auth: APIAuthorizationInfo = Security(
|
||||||
|
require_permission(APIKeyPermission.EXECUTE_BLOCK)
|
||||||
|
),
|
||||||
) -> CompletedBlockOutput:
|
) -> CompletedBlockOutput:
|
||||||
obj = backend.data.block.get_block(block_id)
|
obj = backend.data.block.get_block(block_id)
|
||||||
if not obj:
|
if not obj:
|
||||||
@@ -85,12 +101,14 @@ async def execute_graph(
|
|||||||
graph_id: str,
|
graph_id: str,
|
||||||
graph_version: int,
|
graph_version: int,
|
||||||
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
|
auth: APIAuthorizationInfo = Security(
|
||||||
|
require_permission(APIKeyPermission.EXECUTE_GRAPH)
|
||||||
|
),
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
graph_exec = await add_graph_execution(
|
graph_exec = await add_graph_execution(
|
||||||
graph_id=graph_id,
|
graph_id=graph_id,
|
||||||
user_id=api_key.user_id,
|
user_id=auth.user_id,
|
||||||
inputs=node_input,
|
inputs=node_input,
|
||||||
graph_version=graph_version,
|
graph_version=graph_version,
|
||||||
)
|
)
|
||||||
@@ -100,6 +118,19 @@ async def execute_graph(
|
|||||||
raise HTTPException(status_code=400, detail=msg)
|
raise HTTPException(status_code=400, detail=msg)
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionNode(TypedDict):
|
||||||
|
node_id: str
|
||||||
|
input: Any
|
||||||
|
output: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class GraphExecutionResult(TypedDict):
|
||||||
|
execution_id: str
|
||||||
|
status: str
|
||||||
|
nodes: list[ExecutionNode]
|
||||||
|
output: Optional[list[dict[str, str]]]
|
||||||
|
|
||||||
|
|
||||||
@v1_router.get(
|
@v1_router.get(
|
||||||
path="/graphs/{graph_id}/executions/{graph_exec_id}/results",
|
path="/graphs/{graph_id}/executions/{graph_exec_id}/results",
|
||||||
tags=["graphs"],
|
tags=["graphs"],
|
||||||
@@ -107,10 +138,12 @@ async def execute_graph(
|
|||||||
async def get_graph_execution_results(
|
async def get_graph_execution_results(
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
graph_exec_id: str,
|
graph_exec_id: str,
|
||||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.READ_GRAPH)),
|
auth: APIAuthorizationInfo = Security(
|
||||||
|
require_permission(APIKeyPermission.READ_GRAPH)
|
||||||
|
),
|
||||||
) -> GraphExecutionResult:
|
) -> GraphExecutionResult:
|
||||||
graph_exec = await execution_db.get_graph_execution(
|
graph_exec = await execution_db.get_graph_execution(
|
||||||
user_id=api_key.user_id,
|
user_id=auth.user_id,
|
||||||
execution_id=graph_exec_id,
|
execution_id=graph_exec_id,
|
||||||
include_node_executions=True,
|
include_node_executions=True,
|
||||||
)
|
)
|
||||||
@@ -122,7 +155,7 @@ async def get_graph_execution_results(
|
|||||||
if not await graph_db.get_graph(
|
if not await graph_db.get_graph(
|
||||||
graph_id=graph_exec.graph_id,
|
graph_id=graph_exec.graph_id,
|
||||||
version=graph_exec.graph_version,
|
version=graph_exec.graph_version,
|
||||||
user_id=api_key.user_id,
|
user_id=auth.user_id,
|
||||||
):
|
):
|
||||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||||
|
|
||||||
@@ -14,19 +14,19 @@ from fastapi import APIRouter, Security
|
|||||||
from prisma.enums import APIKeyPermission
|
from prisma.enums import APIKeyPermission
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from backend.data.api_key import APIKeyInfo
|
from backend.api.external.middleware import require_permission
|
||||||
from backend.server.external.middleware import require_permission
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.server.v2.chat.model import ChatSession
|
from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
|
||||||
from backend.server.v2.chat.tools import find_agent_tool, run_agent_tool
|
from backend.api.features.chat.tools.models import ToolResponseBase
|
||||||
from backend.server.v2.chat.tools.models import ToolResponseBase
|
from backend.data.auth.base import APIAuthorizationInfo
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
tools_router = APIRouter(prefix="/tools", tags=["tools"])
|
tools_router = APIRouter(prefix="/tools", tags=["tools"])
|
||||||
|
|
||||||
# Note: We use Security() as a function parameter dependency (api_key: APIKeyInfo = Security(...))
|
# Note: We use Security() as a function parameter dependency (auth: APIAuthorizationInfo = Security(...))
|
||||||
# rather than in the decorator's dependencies= list. This avoids duplicate permission checks
|
# rather than in the decorator's dependencies= list. This avoids duplicate permission checks
|
||||||
# while still enforcing auth AND giving us access to the api_key for extracting user_id.
|
# while still enforcing auth AND giving us access to auth for extracting user_id.
|
||||||
|
|
||||||
|
|
||||||
# Request models
|
# Request models
|
||||||
@@ -80,7 +80,9 @@ def _create_ephemeral_session(user_id: str | None) -> ChatSession:
|
|||||||
)
|
)
|
||||||
async def find_agent(
|
async def find_agent(
|
||||||
request: FindAgentRequest,
|
request: FindAgentRequest,
|
||||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
|
auth: APIAuthorizationInfo = Security(
|
||||||
|
require_permission(APIKeyPermission.USE_TOOLS)
|
||||||
|
),
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Search for agents in the marketplace based on capabilities and user needs.
|
Search for agents in the marketplace based on capabilities and user needs.
|
||||||
@@ -91,9 +93,9 @@ async def find_agent(
|
|||||||
Returns:
|
Returns:
|
||||||
List of matching agents or no results response
|
List of matching agents or no results response
|
||||||
"""
|
"""
|
||||||
session = _create_ephemeral_session(api_key.user_id)
|
session = _create_ephemeral_session(auth.user_id)
|
||||||
result = await find_agent_tool._execute(
|
result = await find_agent_tool._execute(
|
||||||
user_id=api_key.user_id,
|
user_id=auth.user_id,
|
||||||
session=session,
|
session=session,
|
||||||
query=request.query,
|
query=request.query,
|
||||||
)
|
)
|
||||||
@@ -105,7 +107,9 @@ async def find_agent(
|
|||||||
)
|
)
|
||||||
async def run_agent(
|
async def run_agent(
|
||||||
request: RunAgentRequest,
|
request: RunAgentRequest,
|
||||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
|
auth: APIAuthorizationInfo = Security(
|
||||||
|
require_permission(APIKeyPermission.USE_TOOLS)
|
||||||
|
),
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Run or schedule an agent from the marketplace.
|
Run or schedule an agent from the marketplace.
|
||||||
@@ -129,9 +133,9 @@ async def run_agent(
|
|||||||
- execution_started: If agent was run or scheduled successfully
|
- execution_started: If agent was run or scheduled successfully
|
||||||
- error: If something went wrong
|
- error: If something went wrong
|
||||||
"""
|
"""
|
||||||
session = _create_ephemeral_session(api_key.user_id)
|
session = _create_ephemeral_session(auth.user_id)
|
||||||
result = await run_agent_tool._execute(
|
result = await run_agent_tool._execute(
|
||||||
user_id=api_key.user_id,
|
user_id=auth.user_id,
|
||||||
session=session,
|
session=session,
|
||||||
username_agent_slug=request.username_agent_slug,
|
username_agent_slug=request.username_agent_slug,
|
||||||
inputs=request.inputs,
|
inputs=request.inputs,
|
||||||
@@ -6,9 +6,10 @@ from fastapi import APIRouter, Body, Security
|
|||||||
from prisma.enums import CreditTransactionType
|
from prisma.enums import CreditTransactionType
|
||||||
|
|
||||||
from backend.data.credit import admin_get_user_history, get_user_credit_model
|
from backend.data.credit import admin_get_user_history, get_user_credit_model
|
||||||
from backend.server.v2.admin.model import AddUserCreditsResponse, UserHistoryResponse
|
|
||||||
from backend.util.json import SafeJson
|
from backend.util.json import SafeJson
|
||||||
|
|
||||||
|
from .model import AddUserCreditsResponse, UserHistoryResponse
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -9,14 +9,15 @@ import pytest_mock
|
|||||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||||
from pytest_snapshot.plugin import Snapshot
|
from pytest_snapshot.plugin import Snapshot
|
||||||
|
|
||||||
import backend.server.v2.admin.credit_admin_routes as credit_admin_routes
|
|
||||||
import backend.server.v2.admin.model as admin_model
|
|
||||||
from backend.data.model import UserTransaction
|
from backend.data.model import UserTransaction
|
||||||
from backend.util.json import SafeJson
|
from backend.util.json import SafeJson
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
|
from .credit_admin_routes import router as credit_admin_router
|
||||||
|
from .model import UserHistoryResponse
|
||||||
|
|
||||||
app = fastapi.FastAPI()
|
app = fastapi.FastAPI()
|
||||||
app.include_router(credit_admin_routes.router)
|
app.include_router(credit_admin_router)
|
||||||
|
|
||||||
client = fastapi.testclient.TestClient(app)
|
client = fastapi.testclient.TestClient(app)
|
||||||
|
|
||||||
@@ -30,7 +31,7 @@ def setup_app_admin_auth(mock_jwt_admin):
|
|||||||
|
|
||||||
|
|
||||||
def test_add_user_credits_success(
|
def test_add_user_credits_success(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockerFixture,
|
||||||
configured_snapshot: Snapshot,
|
configured_snapshot: Snapshot,
|
||||||
admin_user_id: str,
|
admin_user_id: str,
|
||||||
target_user_id: str,
|
target_user_id: str,
|
||||||
@@ -42,7 +43,7 @@ def test_add_user_credits_success(
|
|||||||
return_value=(1500, "transaction-123-uuid")
|
return_value=(1500, "transaction-123-uuid")
|
||||||
)
|
)
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
|
"backend.api.features.admin.credit_admin_routes.get_user_credit_model",
|
||||||
return_value=mock_credit_model,
|
return_value=mock_credit_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -84,7 +85,7 @@ def test_add_user_credits_success(
|
|||||||
|
|
||||||
|
|
||||||
def test_add_user_credits_negative_amount(
|
def test_add_user_credits_negative_amount(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockerFixture,
|
||||||
snapshot: Snapshot,
|
snapshot: Snapshot,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test credit deduction by admin (negative amount)"""
|
"""Test credit deduction by admin (negative amount)"""
|
||||||
@@ -94,7 +95,7 @@ def test_add_user_credits_negative_amount(
|
|||||||
return_value=(200, "transaction-456-uuid")
|
return_value=(200, "transaction-456-uuid")
|
||||||
)
|
)
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
|
"backend.api.features.admin.credit_admin_routes.get_user_credit_model",
|
||||||
return_value=mock_credit_model,
|
return_value=mock_credit_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -119,12 +120,12 @@ def test_add_user_credits_negative_amount(
|
|||||||
|
|
||||||
|
|
||||||
def test_get_user_history_success(
|
def test_get_user_history_success(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockerFixture,
|
||||||
snapshot: Snapshot,
|
snapshot: Snapshot,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test successful retrieval of user credit history"""
|
"""Test successful retrieval of user credit history"""
|
||||||
# Mock the admin_get_user_history function
|
# Mock the admin_get_user_history function
|
||||||
mock_history_response = admin_model.UserHistoryResponse(
|
mock_history_response = UserHistoryResponse(
|
||||||
history=[
|
history=[
|
||||||
UserTransaction(
|
UserTransaction(
|
||||||
user_id="user-1",
|
user_id="user-1",
|
||||||
@@ -150,7 +151,7 @@ def test_get_user_history_success(
|
|||||||
)
|
)
|
||||||
|
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"backend.server.v2.admin.credit_admin_routes.admin_get_user_history",
|
"backend.api.features.admin.credit_admin_routes.admin_get_user_history",
|
||||||
return_value=mock_history_response,
|
return_value=mock_history_response,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -170,12 +171,12 @@ def test_get_user_history_success(
|
|||||||
|
|
||||||
|
|
||||||
def test_get_user_history_with_filters(
|
def test_get_user_history_with_filters(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockerFixture,
|
||||||
snapshot: Snapshot,
|
snapshot: Snapshot,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test user credit history with search and filter parameters"""
|
"""Test user credit history with search and filter parameters"""
|
||||||
# Mock the admin_get_user_history function
|
# Mock the admin_get_user_history function
|
||||||
mock_history_response = admin_model.UserHistoryResponse(
|
mock_history_response = UserHistoryResponse(
|
||||||
history=[
|
history=[
|
||||||
UserTransaction(
|
UserTransaction(
|
||||||
user_id="user-3",
|
user_id="user-3",
|
||||||
@@ -194,7 +195,7 @@ def test_get_user_history_with_filters(
|
|||||||
)
|
)
|
||||||
|
|
||||||
mock_get_history = mocker.patch(
|
mock_get_history = mocker.patch(
|
||||||
"backend.server.v2.admin.credit_admin_routes.admin_get_user_history",
|
"backend.api.features.admin.credit_admin_routes.admin_get_user_history",
|
||||||
return_value=mock_history_response,
|
return_value=mock_history_response,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -230,12 +231,12 @@ def test_get_user_history_with_filters(
|
|||||||
|
|
||||||
|
|
||||||
def test_get_user_history_empty_results(
|
def test_get_user_history_empty_results(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockerFixture,
|
||||||
snapshot: Snapshot,
|
snapshot: Snapshot,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test user credit history with no results"""
|
"""Test user credit history with no results"""
|
||||||
# Mock empty history response
|
# Mock empty history response
|
||||||
mock_history_response = admin_model.UserHistoryResponse(
|
mock_history_response = UserHistoryResponse(
|
||||||
history=[],
|
history=[],
|
||||||
pagination=Pagination(
|
pagination=Pagination(
|
||||||
total_items=0,
|
total_items=0,
|
||||||
@@ -246,7 +247,7 @@ def test_get_user_history_empty_results(
|
|||||||
)
|
)
|
||||||
|
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"backend.server.v2.admin.credit_admin_routes.admin_get_user_history",
|
"backend.api.features.admin.credit_admin_routes.admin_get_user_history",
|
||||||
return_value=mock_history_response,
|
return_value=mock_history_response,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -7,10 +7,10 @@ import fastapi
|
|||||||
import fastapi.responses
|
import fastapi.responses
|
||||||
import prisma.enums
|
import prisma.enums
|
||||||
|
|
||||||
import backend.server.v2.store.cache as store_cache
|
import backend.api.features.store.cache as store_cache
|
||||||
import backend.server.v2.store.db
|
import backend.api.features.store.db as store_db
|
||||||
import backend.server.v2.store.embeddings as store_embeddings
|
import backend.api.features.store.embeddings as store_embeddings
|
||||||
import backend.server.v2.store.model
|
import backend.api.features.store.model as store_model
|
||||||
import backend.util.json
|
import backend.util.json
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -25,7 +25,7 @@ router = fastapi.APIRouter(
|
|||||||
@router.get(
|
@router.get(
|
||||||
"/listings",
|
"/listings",
|
||||||
summary="Get Admin Listings History",
|
summary="Get Admin Listings History",
|
||||||
response_model=backend.server.v2.store.model.StoreListingsWithVersionsResponse,
|
response_model=store_model.StoreListingsWithVersionsResponse,
|
||||||
)
|
)
|
||||||
async def get_admin_listings_with_versions(
|
async def get_admin_listings_with_versions(
|
||||||
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
|
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
|
||||||
@@ -49,7 +49,7 @@ async def get_admin_listings_with_versions(
|
|||||||
StoreListingsWithVersionsResponse with listings and their versions
|
StoreListingsWithVersionsResponse with listings and their versions
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
listings = await backend.server.v2.store.db.get_admin_listings_with_versions(
|
listings = await store_db.get_admin_listings_with_versions(
|
||||||
status=status,
|
status=status,
|
||||||
search_query=search,
|
search_query=search,
|
||||||
page=page,
|
page=page,
|
||||||
@@ -69,11 +69,11 @@ async def get_admin_listings_with_versions(
|
|||||||
@router.post(
|
@router.post(
|
||||||
"/submissions/{store_listing_version_id}/review",
|
"/submissions/{store_listing_version_id}/review",
|
||||||
summary="Review Store Submission",
|
summary="Review Store Submission",
|
||||||
response_model=backend.server.v2.store.model.StoreSubmission,
|
response_model=store_model.StoreSubmission,
|
||||||
)
|
)
|
||||||
async def review_submission(
|
async def review_submission(
|
||||||
store_listing_version_id: str,
|
store_listing_version_id: str,
|
||||||
request: backend.server.v2.store.model.ReviewSubmissionRequest,
|
request: store_model.ReviewSubmissionRequest,
|
||||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -88,12 +88,10 @@ async def review_submission(
|
|||||||
StoreSubmission with updated review information
|
StoreSubmission with updated review information
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
already_approved = (
|
already_approved = await store_db.check_submission_already_approved(
|
||||||
await backend.server.v2.store.db.check_submission_already_approved(
|
store_listing_version_id=store_listing_version_id,
|
||||||
store_listing_version_id=store_listing_version_id,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
submission = await backend.server.v2.store.db.review_store_submission(
|
submission = await store_db.review_store_submission(
|
||||||
store_listing_version_id=store_listing_version_id,
|
store_listing_version_id=store_listing_version_id,
|
||||||
is_approved=request.is_approved,
|
is_approved=request.is_approved,
|
||||||
external_comments=request.comments,
|
external_comments=request.comments,
|
||||||
@@ -137,7 +135,7 @@ async def admin_download_agent_file(
|
|||||||
Raises:
|
Raises:
|
||||||
HTTPException: If the agent is not found or an unexpected error occurs.
|
HTTPException: If the agent is not found or an unexpected error occurs.
|
||||||
"""
|
"""
|
||||||
graph_data = await backend.server.v2.store.db.get_agent_as_admin(
|
graph_data = await store_db.get_agent_as_admin(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
store_listing_version_id=store_listing_version_id,
|
store_listing_version_id=store_listing_version_id,
|
||||||
)
|
)
|
||||||
@@ -6,10 +6,11 @@ from typing import Annotated
|
|||||||
import fastapi
|
import fastapi
|
||||||
import pydantic
|
import pydantic
|
||||||
from autogpt_libs.auth import get_user_id
|
from autogpt_libs.auth import get_user_id
|
||||||
|
from autogpt_libs.auth.dependencies import requires_user
|
||||||
|
|
||||||
import backend.data.analytics
|
import backend.data.analytics
|
||||||
|
|
||||||
router = fastapi.APIRouter()
|
router = fastapi.APIRouter(dependencies=[fastapi.Security(requires_user)])
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
340
autogpt_platform/backend/backend/api/features/analytics_test.py
Normal file
340
autogpt_platform/backend/backend/api/features/analytics_test.py
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
"""Tests for analytics API endpoints."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
|
import fastapi
|
||||||
|
import fastapi.testclient
|
||||||
|
import pytest
|
||||||
|
import pytest_mock
|
||||||
|
from pytest_snapshot.plugin import Snapshot
|
||||||
|
|
||||||
|
from .analytics import router as analytics_router
|
||||||
|
|
||||||
|
app = fastapi.FastAPI()
|
||||||
|
app.include_router(analytics_router)
|
||||||
|
|
||||||
|
client = fastapi.testclient.TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_app_auth(mock_jwt_user):
|
||||||
|
"""Setup auth overrides for all tests in this module."""
|
||||||
|
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||||
|
|
||||||
|
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# /log_raw_metric endpoint tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_raw_metric_success(
|
||||||
|
mocker: pytest_mock.MockFixture,
|
||||||
|
configured_snapshot: Snapshot,
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test successful raw metric logging."""
|
||||||
|
mock_result = Mock(id="metric-123-uuid")
|
||||||
|
mock_log_metric = mocker.patch(
|
||||||
|
"backend.data.analytics.log_raw_metric",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_result,
|
||||||
|
)
|
||||||
|
|
||||||
|
request_data = {
|
||||||
|
"metric_name": "page_load_time",
|
||||||
|
"metric_value": 2.5,
|
||||||
|
"data_string": "/dashboard",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post("/log_raw_metric", json=request_data)
|
||||||
|
|
||||||
|
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||||
|
assert response.json() == "metric-123-uuid"
|
||||||
|
|
||||||
|
mock_log_metric.assert_called_once_with(
|
||||||
|
user_id=test_user_id,
|
||||||
|
metric_name="page_load_time",
|
||||||
|
metric_value=2.5,
|
||||||
|
data_string="/dashboard",
|
||||||
|
)
|
||||||
|
|
||||||
|
configured_snapshot.assert_match(
|
||||||
|
json.dumps({"metric_id": response.json()}, indent=2, sort_keys=True),
|
||||||
|
"analytics_log_metric_success",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"metric_value,metric_name,data_string,test_id",
|
||||||
|
[
|
||||||
|
(100, "api_calls_count", "external_api", "integer_value"),
|
||||||
|
(0, "error_count", "no_errors", "zero_value"),
|
||||||
|
(-5.2, "temperature_delta", "cooling", "negative_value"),
|
||||||
|
(1.23456789, "precision_test", "float_precision", "float_precision"),
|
||||||
|
(999999999, "large_number", "max_value", "large_number"),
|
||||||
|
(0.0000001, "tiny_number", "min_value", "tiny_number"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_log_raw_metric_various_values(
|
||||||
|
mocker: pytest_mock.MockFixture,
|
||||||
|
configured_snapshot: Snapshot,
|
||||||
|
metric_value: float,
|
||||||
|
metric_name: str,
|
||||||
|
data_string: str,
|
||||||
|
test_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test raw metric logging with various metric values."""
|
||||||
|
mock_result = Mock(id=f"metric-{test_id}-uuid")
|
||||||
|
mocker.patch(
|
||||||
|
"backend.data.analytics.log_raw_metric",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_result,
|
||||||
|
)
|
||||||
|
|
||||||
|
request_data = {
|
||||||
|
"metric_name": metric_name,
|
||||||
|
"metric_value": metric_value,
|
||||||
|
"data_string": data_string,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post("/log_raw_metric", json=request_data)
|
||||||
|
|
||||||
|
assert response.status_code == 200, f"Failed for {test_id}: {response.text}"
|
||||||
|
|
||||||
|
configured_snapshot.assert_match(
|
||||||
|
json.dumps(
|
||||||
|
{"metric_id": response.json(), "test_case": test_id},
|
||||||
|
indent=2,
|
||||||
|
sort_keys=True,
|
||||||
|
),
|
||||||
|
f"analytics_metric_{test_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"invalid_data,expected_error",
|
||||||
|
[
|
||||||
|
({}, "Field required"),
|
||||||
|
({"metric_name": "test"}, "Field required"),
|
||||||
|
(
|
||||||
|
{"metric_name": "test", "metric_value": "not_a_number", "data_string": "x"},
|
||||||
|
"Input should be a valid number",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"metric_name": "", "metric_value": 1.0, "data_string": "test"},
|
||||||
|
"String should have at least 1 character",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"metric_name": "test", "metric_value": 1.0, "data_string": ""},
|
||||||
|
"String should have at least 1 character",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
"empty_request",
|
||||||
|
"missing_metric_value_and_data_string",
|
||||||
|
"invalid_metric_value_type",
|
||||||
|
"empty_metric_name",
|
||||||
|
"empty_data_string",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_log_raw_metric_validation_errors(
|
||||||
|
invalid_data: dict,
|
||||||
|
expected_error: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test validation errors for invalid metric requests."""
|
||||||
|
response = client.post("/log_raw_metric", json=invalid_data)
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
error_detail = response.json()
|
||||||
|
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||||
|
|
||||||
|
error_text = json.dumps(error_detail)
|
||||||
|
assert (
|
||||||
|
expected_error in error_text
|
||||||
|
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_raw_metric_service_error(
|
||||||
|
mocker: pytest_mock.MockFixture,
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test error handling when analytics service fails."""
|
||||||
|
mocker.patch(
|
||||||
|
"backend.data.analytics.log_raw_metric",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=Exception("Database connection failed"),
|
||||||
|
)
|
||||||
|
|
||||||
|
request_data = {
|
||||||
|
"metric_name": "test_metric",
|
||||||
|
"metric_value": 1.0,
|
||||||
|
"data_string": "test",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post("/log_raw_metric", json=request_data)
|
||||||
|
|
||||||
|
assert response.status_code == 500
|
||||||
|
error_detail = response.json()["detail"]
|
||||||
|
assert "Database connection failed" in error_detail["message"]
|
||||||
|
assert "hint" in error_detail
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# /log_raw_analytics endpoint tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_raw_analytics_success(
|
||||||
|
mocker: pytest_mock.MockFixture,
|
||||||
|
configured_snapshot: Snapshot,
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test successful raw analytics logging."""
|
||||||
|
mock_result = Mock(id="analytics-789-uuid")
|
||||||
|
mock_log_analytics = mocker.patch(
|
||||||
|
"backend.data.analytics.log_raw_analytics",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_result,
|
||||||
|
)
|
||||||
|
|
||||||
|
request_data = {
|
||||||
|
"type": "user_action",
|
||||||
|
"data": {
|
||||||
|
"action": "button_click",
|
||||||
|
"button_id": "submit_form",
|
||||||
|
"timestamp": "2023-01-01T00:00:00Z",
|
||||||
|
"metadata": {"form_type": "registration", "fields_filled": 5},
|
||||||
|
},
|
||||||
|
"data_index": "button_click_submit_form",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post("/log_raw_analytics", json=request_data)
|
||||||
|
|
||||||
|
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||||
|
assert response.json() == "analytics-789-uuid"
|
||||||
|
|
||||||
|
mock_log_analytics.assert_called_once_with(
|
||||||
|
test_user_id,
|
||||||
|
"user_action",
|
||||||
|
request_data["data"],
|
||||||
|
"button_click_submit_form",
|
||||||
|
)
|
||||||
|
|
||||||
|
configured_snapshot.assert_match(
|
||||||
|
json.dumps({"analytics_id": response.json()}, indent=2, sort_keys=True),
|
||||||
|
"analytics_log_analytics_success",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_raw_analytics_complex_data(
|
||||||
|
mocker: pytest_mock.MockFixture,
|
||||||
|
configured_snapshot: Snapshot,
|
||||||
|
) -> None:
|
||||||
|
"""Test raw analytics logging with complex nested data structures."""
|
||||||
|
mock_result = Mock(id="analytics-complex-uuid")
|
||||||
|
mocker.patch(
|
||||||
|
"backend.data.analytics.log_raw_analytics",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_result,
|
||||||
|
)
|
||||||
|
|
||||||
|
request_data = {
|
||||||
|
"type": "agent_execution",
|
||||||
|
"data": {
|
||||||
|
"agent_id": "agent_123",
|
||||||
|
"execution_id": "exec_456",
|
||||||
|
"status": "completed",
|
||||||
|
"duration_ms": 3500,
|
||||||
|
"nodes_executed": 15,
|
||||||
|
"blocks_used": [
|
||||||
|
{"block_id": "llm_block", "count": 3},
|
||||||
|
{"block_id": "http_block", "count": 5},
|
||||||
|
{"block_id": "code_block", "count": 2},
|
||||||
|
],
|
||||||
|
"errors": [],
|
||||||
|
"metadata": {
|
||||||
|
"trigger": "manual",
|
||||||
|
"user_tier": "premium",
|
||||||
|
"environment": "production",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"data_index": "agent_123_exec_456",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post("/log_raw_analytics", json=request_data)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
configured_snapshot.assert_match(
|
||||||
|
json.dumps(
|
||||||
|
{"analytics_id": response.json(), "logged_data": request_data["data"]},
|
||||||
|
indent=2,
|
||||||
|
sort_keys=True,
|
||||||
|
),
|
||||||
|
"analytics_log_analytics_complex_data",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"invalid_data,expected_error",
|
||||||
|
[
|
||||||
|
({}, "Field required"),
|
||||||
|
({"type": "test"}, "Field required"),
|
||||||
|
(
|
||||||
|
{"type": "test", "data": "not_a_dict", "data_index": "test"},
|
||||||
|
"Input should be a valid dictionary",
|
||||||
|
),
|
||||||
|
({"type": "test", "data": {"key": "value"}}, "Field required"),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
"empty_request",
|
||||||
|
"missing_data_and_data_index",
|
||||||
|
"invalid_data_type",
|
||||||
|
"missing_data_index",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_log_raw_analytics_validation_errors(
|
||||||
|
invalid_data: dict,
|
||||||
|
expected_error: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test validation errors for invalid analytics requests."""
|
||||||
|
response = client.post("/log_raw_analytics", json=invalid_data)
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
error_detail = response.json()
|
||||||
|
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||||
|
|
||||||
|
error_text = json.dumps(error_detail)
|
||||||
|
assert (
|
||||||
|
expected_error in error_text
|
||||||
|
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_raw_analytics_service_error(
|
||||||
|
mocker: pytest_mock.MockFixture,
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test error handling when analytics service fails."""
|
||||||
|
mocker.patch(
|
||||||
|
"backend.data.analytics.log_raw_analytics",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=Exception("Analytics DB unreachable"),
|
||||||
|
)
|
||||||
|
|
||||||
|
request_data = {
|
||||||
|
"type": "test_event",
|
||||||
|
"data": {"key": "value"},
|
||||||
|
"data_index": "test_index",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post("/log_raw_analytics", json=request_data)
|
||||||
|
|
||||||
|
assert response.status_code == 500
|
||||||
|
error_detail = response.json()["detail"]
|
||||||
|
assert "Analytics DB unreachable" in error_detail["message"]
|
||||||
|
assert "hint" in error_detail
|
||||||
@@ -6,17 +6,20 @@ from typing import Sequence
|
|||||||
|
|
||||||
import prisma
|
import prisma
|
||||||
|
|
||||||
|
import backend.api.features.library.db as library_db
|
||||||
|
import backend.api.features.library.model as library_model
|
||||||
|
import backend.api.features.store.db as store_db
|
||||||
|
import backend.api.features.store.model as store_model
|
||||||
import backend.data.block
|
import backend.data.block
|
||||||
import backend.server.v2.library.db as library_db
|
|
||||||
import backend.server.v2.library.model as library_model
|
|
||||||
import backend.server.v2.store.db as store_db
|
|
||||||
import backend.server.v2.store.model as store_model
|
|
||||||
from backend.blocks import load_all_blocks
|
from backend.blocks import load_all_blocks
|
||||||
from backend.blocks.llm import LlmModel
|
from backend.blocks.llm import LlmModel
|
||||||
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
|
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
|
||||||
from backend.data.db import query_raw_with_schema
|
from backend.data.db import query_raw_with_schema
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.server.v2.builder.model import (
|
from backend.util.cache import cached
|
||||||
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
|
from .model import (
|
||||||
BlockCategoryResponse,
|
BlockCategoryResponse,
|
||||||
BlockResponse,
|
BlockResponse,
|
||||||
BlockType,
|
BlockType,
|
||||||
@@ -26,8 +29,6 @@ from backend.server.v2.builder.model import (
|
|||||||
ProviderResponse,
|
ProviderResponse,
|
||||||
SearchEntry,
|
SearchEntry,
|
||||||
)
|
)
|
||||||
from backend.util.cache import cached
|
|
||||||
from backend.util.models import Pagination
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
|
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
|
||||||
@@ -2,8 +2,8 @@ from typing import Literal
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import backend.server.v2.library.model as library_model
|
import backend.api.features.library.model as library_model
|
||||||
import backend.server.v2.store.model as store_model
|
import backend.api.features.store.model as store_model
|
||||||
from backend.data.block import BlockInfo
|
from backend.data.block import BlockInfo
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
@@ -4,11 +4,12 @@ from typing import Annotated, Sequence
|
|||||||
import fastapi
|
import fastapi
|
||||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||||
|
|
||||||
import backend.server.v2.builder.db as builder_db
|
|
||||||
import backend.server.v2.builder.model as builder_model
|
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
|
from . import db as builder_db
|
||||||
|
from . import model as builder_model
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = fastapi.APIRouter(
|
router = fastapi.APIRouter(
|
||||||
@@ -21,11 +21,12 @@ from prisma.models import ChatSession as PrismaChatSession
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.data.redis_client import get_redis_async
|
from backend.data.redis_client import get_redis_async
|
||||||
from backend.server.v2.chat import db as chat_db
|
|
||||||
from backend.server.v2.chat.config import ChatConfig
|
|
||||||
from backend.util import json
|
from backend.util import json
|
||||||
from backend.util.exceptions import RedisError
|
from backend.util.exceptions import RedisError
|
||||||
|
|
||||||
|
from . import db as chat_db
|
||||||
|
from .config import ChatConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
config = ChatConfig()
|
config = ChatConfig()
|
||||||
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from backend.server.v2.chat.model import (
|
from .model import (
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
ChatSession,
|
ChatSession,
|
||||||
Usage,
|
Usage,
|
||||||
@@ -9,10 +9,11 @@ from fastapi import APIRouter, Depends, Query, Security
|
|||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import backend.server.v2.chat.service as chat_service
|
|
||||||
from backend.server.v2.chat.config import ChatConfig
|
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
|
from . import service as chat_service
|
||||||
|
from .config import ChatConfig
|
||||||
|
|
||||||
config = ChatConfig()
|
config = ChatConfig()
|
||||||
|
|
||||||
|
|
||||||
@@ -7,18 +7,23 @@ import orjson
|
|||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
||||||
|
|
||||||
import backend.server.v2.chat.config
|
|
||||||
import backend.server.v2.chat.db as chat_db
|
|
||||||
from backend.data.understanding import (
|
from backend.data.understanding import (
|
||||||
format_understanding_for_prompt,
|
format_understanding_for_prompt,
|
||||||
get_business_understanding,
|
get_business_understanding,
|
||||||
)
|
)
|
||||||
from backend.server.v2.chat.model import ChatMessage, ChatSession, Usage
|
from backend.util.exceptions import NotFoundError
|
||||||
from backend.server.v2.chat.model import (
|
|
||||||
|
from . import db as chat_db
|
||||||
|
from .config import ChatConfig
|
||||||
|
from .model import (
|
||||||
|
ChatMessage,
|
||||||
|
ChatSession,
|
||||||
|
Usage,
|
||||||
create_chat_session as model_create_chat_session,
|
create_chat_session as model_create_chat_session,
|
||||||
|
get_chat_session,
|
||||||
|
upsert_chat_session,
|
||||||
)
|
)
|
||||||
from backend.server.v2.chat.model import get_chat_session, upsert_chat_session
|
from .response_model import (
|
||||||
from backend.server.v2.chat.response_model import (
|
|
||||||
StreamBaseResponse,
|
StreamBaseResponse,
|
||||||
StreamEnd,
|
StreamEnd,
|
||||||
StreamError,
|
StreamError,
|
||||||
@@ -29,12 +34,11 @@ from backend.server.v2.chat.response_model import (
|
|||||||
StreamToolExecutionResult,
|
StreamToolExecutionResult,
|
||||||
StreamUsage,
|
StreamUsage,
|
||||||
)
|
)
|
||||||
from backend.server.v2.chat.tools import execute_tool, tools
|
from .tools import execute_tool, tools
|
||||||
from backend.util.exceptions import NotFoundError
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
config = backend.server.v2.chat.config.ChatConfig()
|
config = ChatConfig()
|
||||||
client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
|
||||||
|
|
||||||
@@ -161,9 +165,7 @@ async def get_user_sessions(
|
|||||||
"""
|
"""
|
||||||
Get all chat sessions for a user.
|
Get all chat sessions for a user.
|
||||||
"""
|
"""
|
||||||
from backend.server.v2.chat.model import (
|
from .model import get_user_sessions as model_get_user_sessions
|
||||||
get_user_sessions as model_get_user_sessions,
|
|
||||||
)
|
|
||||||
|
|
||||||
return await model_get_user_sessions(user_id, limit, offset)
|
return await model_get_user_sessions(user_id, limit, offset)
|
||||||
|
|
||||||
@@ -3,8 +3,8 @@ from os import getenv
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import backend.server.v2.chat.service as chat_service
|
from . import service as chat_service
|
||||||
from backend.server.v2.chat.response_model import (
|
from .response_model import (
|
||||||
StreamEnd,
|
StreamEnd,
|
||||||
StreamError,
|
StreamError,
|
||||||
StreamTextChunk,
|
StreamTextChunk,
|
||||||
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any
|
|||||||
|
|
||||||
from openai.types.chat import ChatCompletionToolParam
|
from openai.types.chat import ChatCompletionToolParam
|
||||||
|
|
||||||
from backend.server.v2.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
|
||||||
from .add_understanding import AddUnderstandingTool
|
from .add_understanding import AddUnderstandingTool
|
||||||
from .agent_output import AgentOutputTool
|
from .agent_output import AgentOutputTool
|
||||||
@@ -17,7 +17,7 @@ from .run_block import RunBlockTool
|
|||||||
from .search_docs import SearchDocsTool
|
from .search_docs import SearchDocsTool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from backend.server.v2.chat.response_model import StreamToolExecutionResult
|
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
||||||
|
|
||||||
# Initialize tool instances
|
# Initialize tool instances
|
||||||
add_understanding_tool = AddUnderstandingTool()
|
add_understanding_tool = AddUnderstandingTool()
|
||||||
@@ -5,6 +5,8 @@ from os import getenv
|
|||||||
import pytest
|
import pytest
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.store import db as store_db
|
||||||
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
||||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||||
from backend.blocks.llm import AITextGeneratorBlock
|
from backend.blocks.llm import AITextGeneratorBlock
|
||||||
@@ -13,8 +15,6 @@ from backend.data.graph import Graph, Link, Node, create_graph
|
|||||||
from backend.data.model import APIKeyCredentials
|
from backend.data.model import APIKeyCredentials
|
||||||
from backend.data.user import get_or_create_user
|
from backend.data.user import get_or_create_user
|
||||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||||
from backend.server.v2.chat.model import ChatSession
|
|
||||||
from backend.server.v2.store import db as store_db
|
|
||||||
|
|
||||||
|
|
||||||
def make_session(user_id: str | None = None):
|
def make_session(user_id: str | None = None):
|
||||||
@@ -5,8 +5,8 @@ from typing import Any
|
|||||||
|
|
||||||
from openai.types.chat import ChatCompletionToolParam
|
from openai.types.chat import ChatCompletionToolParam
|
||||||
|
|
||||||
from backend.server.v2.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.server.v2.chat.response_model import StreamToolExecutionResult
|
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
||||||
|
|
||||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||||
|
|
||||||
@@ -3,17 +3,18 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.server.v2.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.server.v2.chat.tools.base import BaseTool
|
from backend.api.features.store import db as store_db
|
||||||
from backend.server.v2.chat.tools.models import (
|
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||||
|
|
||||||
|
from .base import BaseTool
|
||||||
|
from .models import (
|
||||||
AgentCarouselResponse,
|
AgentCarouselResponse,
|
||||||
AgentInfo,
|
AgentInfo,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
NoResultsResponse,
|
NoResultsResponse,
|
||||||
ToolResponseBase,
|
ToolResponseBase,
|
||||||
)
|
)
|
||||||
from backend.server.v2.store import db as store_db
|
|
||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -5,14 +5,21 @@ from typing import Any
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
from backend.api.features.chat.config import ChatConfig
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.data.graph import GraphModel
|
from backend.data.graph import GraphModel
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
from backend.data.user import get_user_by_id
|
from backend.data.user import get_user_by_id
|
||||||
from backend.executor import utils as execution_utils
|
from backend.executor import utils as execution_utils
|
||||||
from backend.server.v2.chat.config import ChatConfig
|
from backend.util.clients import get_scheduler_client
|
||||||
from backend.server.v2.chat.model import ChatSession
|
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||||
from backend.server.v2.chat.tools.base import BaseTool
|
from backend.util.timezone_utils import (
|
||||||
from backend.server.v2.chat.tools.models import (
|
convert_utc_time_to_user_timezone,
|
||||||
|
get_user_timezone_or_utc,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .base import BaseTool
|
||||||
|
from .models import (
|
||||||
AgentDetails,
|
AgentDetails,
|
||||||
AgentDetailsResponse,
|
AgentDetailsResponse,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
@@ -23,20 +30,15 @@ from backend.server.v2.chat.tools.models import (
|
|||||||
ToolResponseBase,
|
ToolResponseBase,
|
||||||
UserReadiness,
|
UserReadiness,
|
||||||
)
|
)
|
||||||
from backend.server.v2.chat.tools.utils import (
|
from .utils import (
|
||||||
check_user_has_required_credentials,
|
check_user_has_required_credentials,
|
||||||
extract_credentials_from_schema,
|
extract_credentials_from_schema,
|
||||||
fetch_graph_from_store_slug,
|
fetch_graph_from_store_slug,
|
||||||
get_or_create_library_agent,
|
get_or_create_library_agent,
|
||||||
match_user_credentials_to_graph,
|
match_user_credentials_to_graph,
|
||||||
)
|
)
|
||||||
from backend.server.v2.library import db as library_db
|
|
||||||
from backend.util.clients import get_scheduler_client
|
from backend.api.features.library import db as library_db
|
||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
|
||||||
from backend.util.timezone_utils import (
|
|
||||||
convert_utc_time_to_user_timezone,
|
|
||||||
get_user_timezone_or_utc,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
config = ChatConfig()
|
config = ChatConfig()
|
||||||
@@ -3,13 +3,13 @@ import uuid
|
|||||||
import orjson
|
import orjson
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from backend.server.v2.chat.tools._test_data import (
|
from ._test_data import (
|
||||||
make_session,
|
make_session,
|
||||||
setup_firecrawl_test_data,
|
setup_firecrawl_test_data,
|
||||||
setup_llm_test_data,
|
setup_llm_test_data,
|
||||||
setup_test_data,
|
setup_test_data,
|
||||||
)
|
)
|
||||||
from backend.server.v2.chat.tools.run_agent import RunAgentTool
|
from .run_agent import RunAgentTool
|
||||||
|
|
||||||
# This is so the formatter doesn't remove the fixture imports
|
# This is so the formatter doesn't remove the fixture imports
|
||||||
setup_llm_test_data = setup_llm_test_data
|
setup_llm_test_data = setup_llm_test_data
|
||||||
@@ -3,13 +3,13 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.api.features.library import db as library_db
|
||||||
|
from backend.api.features.library import model as library_model
|
||||||
|
from backend.api.features.store import db as store_db
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
from backend.data.graph import GraphModel
|
from backend.data.graph import GraphModel
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.server.v2.library import db as library_db
|
|
||||||
from backend.server.v2.library import model as library_model
|
|
||||||
from backend.server.v2.store import db as store_db
|
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -7,9 +7,10 @@ import pytest_mock
|
|||||||
from prisma.enums import ReviewStatus
|
from prisma.enums import ReviewStatus
|
||||||
from pytest_snapshot.plugin import Snapshot
|
from pytest_snapshot.plugin import Snapshot
|
||||||
|
|
||||||
from backend.server.rest_api import handle_internal_http_error
|
from backend.api.rest_api import handle_internal_http_error
|
||||||
from backend.server.v2.executions.review.model import PendingHumanReviewModel
|
|
||||||
from backend.server.v2.executions.review.routes import router
|
from .model import PendingHumanReviewModel
|
||||||
|
from .routes import router
|
||||||
|
|
||||||
# Using a fixed timestamp for reproducible tests
|
# Using a fixed timestamp for reproducible tests
|
||||||
FIXED_NOW = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
|
FIXED_NOW = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
|
||||||
@@ -54,13 +55,13 @@ def sample_pending_review(test_user_id: str) -> PendingHumanReviewModel:
|
|||||||
|
|
||||||
|
|
||||||
def test_get_pending_reviews_empty(
|
def test_get_pending_reviews_empty(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockerFixture,
|
||||||
snapshot: Snapshot,
|
snapshot: Snapshot,
|
||||||
test_user_id: str,
|
test_user_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test getting pending reviews when none exist"""
|
"""Test getting pending reviews when none exist"""
|
||||||
mock_get_reviews = mocker.patch(
|
mock_get_reviews = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_user"
|
"backend.api.features.executions.review.routes.get_pending_reviews_for_user"
|
||||||
)
|
)
|
||||||
mock_get_reviews.return_value = []
|
mock_get_reviews.return_value = []
|
||||||
|
|
||||||
@@ -72,14 +73,14 @@ def test_get_pending_reviews_empty(
|
|||||||
|
|
||||||
|
|
||||||
def test_get_pending_reviews_with_data(
|
def test_get_pending_reviews_with_data(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockerFixture,
|
||||||
sample_pending_review: PendingHumanReviewModel,
|
sample_pending_review: PendingHumanReviewModel,
|
||||||
snapshot: Snapshot,
|
snapshot: Snapshot,
|
||||||
test_user_id: str,
|
test_user_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test getting pending reviews with data"""
|
"""Test getting pending reviews with data"""
|
||||||
mock_get_reviews = mocker.patch(
|
mock_get_reviews = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_user"
|
"backend.api.features.executions.review.routes.get_pending_reviews_for_user"
|
||||||
)
|
)
|
||||||
mock_get_reviews.return_value = [sample_pending_review]
|
mock_get_reviews.return_value = [sample_pending_review]
|
||||||
|
|
||||||
@@ -94,14 +95,14 @@ def test_get_pending_reviews_with_data(
|
|||||||
|
|
||||||
|
|
||||||
def test_get_pending_reviews_for_execution_success(
|
def test_get_pending_reviews_for_execution_success(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockerFixture,
|
||||||
sample_pending_review: PendingHumanReviewModel,
|
sample_pending_review: PendingHumanReviewModel,
|
||||||
snapshot: Snapshot,
|
snapshot: Snapshot,
|
||||||
test_user_id: str,
|
test_user_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test getting pending reviews for specific execution"""
|
"""Test getting pending reviews for specific execution"""
|
||||||
mock_get_graph_execution = mocker.patch(
|
mock_get_graph_execution = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.get_graph_execution_meta"
|
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||||
)
|
)
|
||||||
mock_get_graph_execution.return_value = {
|
mock_get_graph_execution.return_value = {
|
||||||
"id": "test_graph_exec_456",
|
"id": "test_graph_exec_456",
|
||||||
@@ -109,7 +110,7 @@ def test_get_pending_reviews_for_execution_success(
|
|||||||
}
|
}
|
||||||
|
|
||||||
mock_get_reviews = mocker.patch(
|
mock_get_reviews = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||||
)
|
)
|
||||||
mock_get_reviews.return_value = [sample_pending_review]
|
mock_get_reviews.return_value = [sample_pending_review]
|
||||||
|
|
||||||
@@ -121,24 +122,23 @@ def test_get_pending_reviews_for_execution_success(
|
|||||||
assert data[0]["graph_exec_id"] == "test_graph_exec_456"
|
assert data[0]["graph_exec_id"] == "test_graph_exec_456"
|
||||||
|
|
||||||
|
|
||||||
def test_get_pending_reviews_for_execution_access_denied(
|
def test_get_pending_reviews_for_execution_not_available(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockerFixture,
|
||||||
test_user_id: str,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test access denied when user doesn't own the execution"""
|
"""Test access denied when user doesn't own the execution"""
|
||||||
mock_get_graph_execution = mocker.patch(
|
mock_get_graph_execution = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.get_graph_execution_meta"
|
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||||
)
|
)
|
||||||
mock_get_graph_execution.return_value = None
|
mock_get_graph_execution.return_value = None
|
||||||
|
|
||||||
response = client.get("/api/review/execution/test_graph_exec_456")
|
response = client.get("/api/review/execution/test_graph_exec_456")
|
||||||
|
|
||||||
assert response.status_code == 403
|
assert response.status_code == 404
|
||||||
assert "Access denied" in response.json()["detail"]
|
assert "not found" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
def test_process_review_action_approve_success(
|
def test_process_review_action_approve_success(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockerFixture,
|
||||||
sample_pending_review: PendingHumanReviewModel,
|
sample_pending_review: PendingHumanReviewModel,
|
||||||
test_user_id: str,
|
test_user_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -146,12 +146,12 @@ def test_process_review_action_approve_success(
|
|||||||
# Mock the route functions
|
# Mock the route functions
|
||||||
|
|
||||||
mock_get_reviews_for_execution = mocker.patch(
|
mock_get_reviews_for_execution = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||||
)
|
)
|
||||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||||
|
|
||||||
mock_process_all_reviews = mocker.patch(
|
mock_process_all_reviews = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||||
)
|
)
|
||||||
# Create approved review for return
|
# Create approved review for return
|
||||||
approved_review = PendingHumanReviewModel(
|
approved_review = PendingHumanReviewModel(
|
||||||
@@ -174,11 +174,11 @@ def test_process_review_action_approve_success(
|
|||||||
mock_process_all_reviews.return_value = {"test_node_123": approved_review}
|
mock_process_all_reviews.return_value = {"test_node_123": approved_review}
|
||||||
|
|
||||||
mock_has_pending = mocker.patch(
|
mock_has_pending = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.has_pending_reviews_for_graph_exec"
|
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||||
)
|
)
|
||||||
mock_has_pending.return_value = False
|
mock_has_pending.return_value = False
|
||||||
|
|
||||||
mocker.patch("backend.server.v2.executions.review.routes.add_graph_execution")
|
mocker.patch("backend.api.features.executions.review.routes.add_graph_execution")
|
||||||
|
|
||||||
request_data = {
|
request_data = {
|
||||||
"reviews": [
|
"reviews": [
|
||||||
@@ -202,7 +202,7 @@ def test_process_review_action_approve_success(
|
|||||||
|
|
||||||
|
|
||||||
def test_process_review_action_reject_success(
|
def test_process_review_action_reject_success(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockerFixture,
|
||||||
sample_pending_review: PendingHumanReviewModel,
|
sample_pending_review: PendingHumanReviewModel,
|
||||||
test_user_id: str,
|
test_user_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -210,12 +210,12 @@ def test_process_review_action_reject_success(
|
|||||||
# Mock the route functions
|
# Mock the route functions
|
||||||
|
|
||||||
mock_get_reviews_for_execution = mocker.patch(
|
mock_get_reviews_for_execution = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||||
)
|
)
|
||||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||||
|
|
||||||
mock_process_all_reviews = mocker.patch(
|
mock_process_all_reviews = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||||
)
|
)
|
||||||
rejected_review = PendingHumanReviewModel(
|
rejected_review = PendingHumanReviewModel(
|
||||||
node_exec_id="test_node_123",
|
node_exec_id="test_node_123",
|
||||||
@@ -237,7 +237,7 @@ def test_process_review_action_reject_success(
|
|||||||
mock_process_all_reviews.return_value = {"test_node_123": rejected_review}
|
mock_process_all_reviews.return_value = {"test_node_123": rejected_review}
|
||||||
|
|
||||||
mock_has_pending = mocker.patch(
|
mock_has_pending = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.has_pending_reviews_for_graph_exec"
|
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||||
)
|
)
|
||||||
mock_has_pending.return_value = False
|
mock_has_pending.return_value = False
|
||||||
|
|
||||||
@@ -262,7 +262,7 @@ def test_process_review_action_reject_success(
|
|||||||
|
|
||||||
|
|
||||||
def test_process_review_action_mixed_success(
|
def test_process_review_action_mixed_success(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockerFixture,
|
||||||
sample_pending_review: PendingHumanReviewModel,
|
sample_pending_review: PendingHumanReviewModel,
|
||||||
test_user_id: str,
|
test_user_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -289,12 +289,12 @@ def test_process_review_action_mixed_success(
|
|||||||
# Mock the route functions
|
# Mock the route functions
|
||||||
|
|
||||||
mock_get_reviews_for_execution = mocker.patch(
|
mock_get_reviews_for_execution = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||||
)
|
)
|
||||||
mock_get_reviews_for_execution.return_value = [sample_pending_review, second_review]
|
mock_get_reviews_for_execution.return_value = [sample_pending_review, second_review]
|
||||||
|
|
||||||
mock_process_all_reviews = mocker.patch(
|
mock_process_all_reviews = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||||
)
|
)
|
||||||
# Create approved version of first review
|
# Create approved version of first review
|
||||||
approved_review = PendingHumanReviewModel(
|
approved_review = PendingHumanReviewModel(
|
||||||
@@ -338,7 +338,7 @@ def test_process_review_action_mixed_success(
|
|||||||
}
|
}
|
||||||
|
|
||||||
mock_has_pending = mocker.patch(
|
mock_has_pending = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.has_pending_reviews_for_graph_exec"
|
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||||
)
|
)
|
||||||
mock_has_pending.return_value = False
|
mock_has_pending.return_value = False
|
||||||
|
|
||||||
@@ -369,7 +369,7 @@ def test_process_review_action_mixed_success(
|
|||||||
|
|
||||||
|
|
||||||
def test_process_review_action_empty_request(
|
def test_process_review_action_empty_request(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockerFixture,
|
||||||
test_user_id: str,
|
test_user_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test error when no reviews provided"""
|
"""Test error when no reviews provided"""
|
||||||
@@ -386,19 +386,19 @@ def test_process_review_action_empty_request(
|
|||||||
|
|
||||||
|
|
||||||
def test_process_review_action_review_not_found(
|
def test_process_review_action_review_not_found(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockerFixture,
|
||||||
test_user_id: str,
|
test_user_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test error when review is not found"""
|
"""Test error when review is not found"""
|
||||||
# Mock the functions that extract graph execution ID from the request
|
# Mock the functions that extract graph execution ID from the request
|
||||||
mock_get_reviews_for_execution = mocker.patch(
|
mock_get_reviews_for_execution = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||||
)
|
)
|
||||||
mock_get_reviews_for_execution.return_value = [] # No reviews found
|
mock_get_reviews_for_execution.return_value = [] # No reviews found
|
||||||
|
|
||||||
# Mock process_all_reviews to simulate not finding reviews
|
# Mock process_all_reviews to simulate not finding reviews
|
||||||
mock_process_all_reviews = mocker.patch(
|
mock_process_all_reviews = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||||
)
|
)
|
||||||
# This should raise a ValueError with "Reviews not found" message based on the data/human_review.py logic
|
# This should raise a ValueError with "Reviews not found" message based on the data/human_review.py logic
|
||||||
mock_process_all_reviews.side_effect = ValueError(
|
mock_process_all_reviews.side_effect = ValueError(
|
||||||
@@ -422,20 +422,20 @@ def test_process_review_action_review_not_found(
|
|||||||
|
|
||||||
|
|
||||||
def test_process_review_action_partial_failure(
|
def test_process_review_action_partial_failure(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockerFixture,
|
||||||
sample_pending_review: PendingHumanReviewModel,
|
sample_pending_review: PendingHumanReviewModel,
|
||||||
test_user_id: str,
|
test_user_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test handling of partial failures in review processing"""
|
"""Test handling of partial failures in review processing"""
|
||||||
# Mock the route functions
|
# Mock the route functions
|
||||||
mock_get_reviews_for_execution = mocker.patch(
|
mock_get_reviews_for_execution = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||||
)
|
)
|
||||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||||
|
|
||||||
# Mock partial failure in processing
|
# Mock partial failure in processing
|
||||||
mock_process_all_reviews = mocker.patch(
|
mock_process_all_reviews = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||||
)
|
)
|
||||||
mock_process_all_reviews.side_effect = ValueError("Some reviews failed validation")
|
mock_process_all_reviews.side_effect = ValueError("Some reviews failed validation")
|
||||||
|
|
||||||
@@ -456,20 +456,20 @@ def test_process_review_action_partial_failure(
|
|||||||
|
|
||||||
|
|
||||||
def test_process_review_action_invalid_node_exec_id(
|
def test_process_review_action_invalid_node_exec_id(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockerFixture,
|
||||||
sample_pending_review: PendingHumanReviewModel,
|
sample_pending_review: PendingHumanReviewModel,
|
||||||
test_user_id: str,
|
test_user_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test failure when trying to process review with invalid node execution ID"""
|
"""Test failure when trying to process review with invalid node execution ID"""
|
||||||
# Mock the route functions
|
# Mock the route functions
|
||||||
mock_get_reviews_for_execution = mocker.patch(
|
mock_get_reviews_for_execution = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||||
)
|
)
|
||||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||||
|
|
||||||
# Mock validation failure - this should return 400, not 500
|
# Mock validation failure - this should return 400, not 500
|
||||||
mock_process_all_reviews = mocker.patch(
|
mock_process_all_reviews = mocker.patch(
|
||||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||||
)
|
)
|
||||||
mock_process_all_reviews.side_effect = ValueError(
|
mock_process_all_reviews.side_effect = ValueError(
|
||||||
"Invalid node execution ID format"
|
"Invalid node execution ID format"
|
||||||
@@ -13,11 +13,8 @@ from backend.data.human_review import (
|
|||||||
process_all_reviews_for_execution,
|
process_all_reviews_for_execution,
|
||||||
)
|
)
|
||||||
from backend.executor.utils import add_graph_execution
|
from backend.executor.utils import add_graph_execution
|
||||||
from backend.server.v2.executions.review.model import (
|
|
||||||
PendingHumanReviewModel,
|
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
||||||
ReviewRequest,
|
|
||||||
ReviewResponse,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -70,8 +67,7 @@ async def list_pending_reviews(
|
|||||||
response_model=List[PendingHumanReviewModel],
|
response_model=List[PendingHumanReviewModel],
|
||||||
responses={
|
responses={
|
||||||
200: {"description": "List of pending reviews for the execution"},
|
200: {"description": "List of pending reviews for the execution"},
|
||||||
400: {"description": "Invalid graph execution ID"},
|
404: {"description": "Graph execution not found"},
|
||||||
403: {"description": "Access denied to graph execution"},
|
|
||||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
500: {"description": "Server error", "content": {"application/json": {}}},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -94,7 +90,7 @@ async def list_pending_reviews_for_execution(
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException:
|
HTTPException:
|
||||||
- 403: If user doesn't own the graph execution
|
- 404: If the graph execution doesn't exist or isn't owned by this user
|
||||||
- 500: If authentication fails or database error occurs
|
- 500: If authentication fails or database error occurs
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
@@ -108,8 +104,8 @@ async def list_pending_reviews_for_execution(
|
|||||||
)
|
)
|
||||||
if not graph_exec:
|
if not graph_exec:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="Access denied to graph execution",
|
detail=f"Graph execution #{graph_exec_id} not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
|
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
|
||||||
@@ -17,6 +17,8 @@ from fastapi import (
|
|||||||
from pydantic import BaseModel, Field, SecretStr
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
||||||
|
|
||||||
|
from backend.api.features.library.db import set_preset_webhook, update_preset
|
||||||
|
from backend.api.features.library.model import LibraryAgentPreset
|
||||||
from backend.data.graph import NodeModel, get_graph, set_node_webhook
|
from backend.data.graph import NodeModel, get_graph, set_node_webhook
|
||||||
from backend.data.integrations import (
|
from backend.data.integrations import (
|
||||||
WebhookEvent,
|
WebhookEvent,
|
||||||
@@ -45,13 +47,6 @@ from backend.integrations.creds_manager import IntegrationCredentialsManager
|
|||||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.integrations.webhooks import get_webhook_manager
|
from backend.integrations.webhooks import get_webhook_manager
|
||||||
from backend.server.integrations.models import (
|
|
||||||
ProviderConstants,
|
|
||||||
ProviderNamesResponse,
|
|
||||||
get_all_provider_names,
|
|
||||||
)
|
|
||||||
from backend.server.v2.library.db import set_preset_webhook, update_preset
|
|
||||||
from backend.server.v2.library.model import LibraryAgentPreset
|
|
||||||
from backend.util.exceptions import (
|
from backend.util.exceptions import (
|
||||||
GraphNotInLibraryError,
|
GraphNotInLibraryError,
|
||||||
MissingConfigError,
|
MissingConfigError,
|
||||||
@@ -60,6 +55,8 @@ from backend.util.exceptions import (
|
|||||||
)
|
)
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
from .models import ProviderConstants, ProviderNamesResponse, get_all_provider_names
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from backend.integrations.oauth import BaseOAuthHandler
|
from backend.integrations.oauth import BaseOAuthHandler
|
||||||
|
|
||||||
@@ -4,16 +4,14 @@ from typing import Literal, Optional
|
|||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import prisma.errors
|
import prisma.errors
|
||||||
import prisma.fields
|
|
||||||
import prisma.models
|
import prisma.models
|
||||||
import prisma.types
|
import prisma.types
|
||||||
|
|
||||||
|
import backend.api.features.store.exceptions as store_exceptions
|
||||||
|
import backend.api.features.store.image_gen as store_image_gen
|
||||||
|
import backend.api.features.store.media as store_media
|
||||||
import backend.data.graph as graph_db
|
import backend.data.graph as graph_db
|
||||||
import backend.data.integrations as integrations_db
|
import backend.data.integrations as integrations_db
|
||||||
import backend.server.v2.library.model as library_model
|
|
||||||
import backend.server.v2.store.exceptions as store_exceptions
|
|
||||||
import backend.server.v2.store.image_gen as store_image_gen
|
|
||||||
import backend.server.v2.store.media as store_media
|
|
||||||
from backend.data.block import BlockInput
|
from backend.data.block import BlockInput
|
||||||
from backend.data.db import transaction
|
from backend.data.db import transaction
|
||||||
from backend.data.execution import get_graph_execution
|
from backend.data.execution import get_graph_execution
|
||||||
@@ -28,6 +26,8 @@ from backend.util.json import SafeJson
|
|||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
from backend.util.settings import Config
|
from backend.util.settings import Config
|
||||||
|
|
||||||
|
from . import model as library_model
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
config = Config()
|
config = Config()
|
||||||
integration_creds_manager = IntegrationCredentialsManager()
|
integration_creds_manager = IntegrationCredentialsManager()
|
||||||
@@ -538,6 +538,7 @@ async def update_library_agent(
|
|||||||
library_agent_id: str,
|
library_agent_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
auto_update_version: Optional[bool] = None,
|
auto_update_version: Optional[bool] = None,
|
||||||
|
graph_version: Optional[int] = None,
|
||||||
is_favorite: Optional[bool] = None,
|
is_favorite: Optional[bool] = None,
|
||||||
is_archived: Optional[bool] = None,
|
is_archived: Optional[bool] = None,
|
||||||
is_deleted: Optional[Literal[False]] = None,
|
is_deleted: Optional[Literal[False]] = None,
|
||||||
@@ -550,6 +551,7 @@ async def update_library_agent(
|
|||||||
library_agent_id: The ID of the LibraryAgent to update.
|
library_agent_id: The ID of the LibraryAgent to update.
|
||||||
user_id: The owner of this LibraryAgent.
|
user_id: The owner of this LibraryAgent.
|
||||||
auto_update_version: Whether the agent should auto-update to active version.
|
auto_update_version: Whether the agent should auto-update to active version.
|
||||||
|
graph_version: Specific graph version to update to.
|
||||||
is_favorite: Whether this agent is marked as a favorite.
|
is_favorite: Whether this agent is marked as a favorite.
|
||||||
is_archived: Whether this agent is archived.
|
is_archived: Whether this agent is archived.
|
||||||
settings: User-specific settings for this library agent.
|
settings: User-specific settings for this library agent.
|
||||||
@@ -563,8 +565,8 @@ async def update_library_agent(
|
|||||||
"""
|
"""
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Updating library agent {library_agent_id} for user {user_id} with "
|
f"Updating library agent {library_agent_id} for user {user_id} with "
|
||||||
f"auto_update_version={auto_update_version}, is_favorite={is_favorite}, "
|
f"auto_update_version={auto_update_version}, graph_version={graph_version}, "
|
||||||
f"is_archived={is_archived}, settings={settings}"
|
f"is_favorite={is_favorite}, is_archived={is_archived}, settings={settings}"
|
||||||
)
|
)
|
||||||
update_fields: prisma.types.LibraryAgentUpdateManyMutationInput = {}
|
update_fields: prisma.types.LibraryAgentUpdateManyMutationInput = {}
|
||||||
if auto_update_version is not None:
|
if auto_update_version is not None:
|
||||||
@@ -581,10 +583,23 @@ async def update_library_agent(
|
|||||||
update_fields["isDeleted"] = is_deleted
|
update_fields["isDeleted"] = is_deleted
|
||||||
if settings is not None:
|
if settings is not None:
|
||||||
update_fields["settings"] = SafeJson(settings.model_dump())
|
update_fields["settings"] = SafeJson(settings.model_dump())
|
||||||
if not update_fields:
|
|
||||||
raise ValueError("No values were passed to update")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# If graph_version is provided, update to that specific version
|
||||||
|
if graph_version is not None:
|
||||||
|
# Get the current agent to find its graph_id
|
||||||
|
agent = await get_library_agent(id=library_agent_id, user_id=user_id)
|
||||||
|
# Update to the specified version using existing function
|
||||||
|
return await update_agent_version_in_library(
|
||||||
|
user_id=user_id,
|
||||||
|
agent_graph_id=agent.graph_id,
|
||||||
|
agent_graph_version=graph_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Otherwise, just update the simple fields
|
||||||
|
if not update_fields:
|
||||||
|
raise ValueError("No values were passed to update")
|
||||||
|
|
||||||
n_updated = await prisma.models.LibraryAgent.prisma().update_many(
|
n_updated = await prisma.models.LibraryAgent.prisma().update_many(
|
||||||
where={"id": library_agent_id, "userId": user_id},
|
where={"id": library_agent_id, "userId": user_id},
|
||||||
data=update_fields,
|
data=update_fields,
|
||||||
@@ -1,16 +1,15 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import prisma.enums
|
import prisma.enums
|
||||||
import prisma.errors
|
|
||||||
import prisma.models
|
import prisma.models
|
||||||
import prisma.types
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import backend.server.v2.library.db as db
|
import backend.api.features.store.exceptions
|
||||||
import backend.server.v2.store.exceptions
|
|
||||||
from backend.data.db import connect
|
from backend.data.db import connect
|
||||||
from backend.data.includes import library_agent_include
|
from backend.data.includes import library_agent_include
|
||||||
|
|
||||||
|
from . import db
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_library_agents(mocker):
|
async def test_get_library_agents(mocker):
|
||||||
@@ -88,7 +87,7 @@ async def test_add_agent_to_library(mocker):
|
|||||||
await connect()
|
await connect()
|
||||||
|
|
||||||
# Mock the transaction context
|
# Mock the transaction context
|
||||||
mock_transaction = mocker.patch("backend.server.v2.library.db.transaction")
|
mock_transaction = mocker.patch("backend.api.features.library.db.transaction")
|
||||||
mock_transaction.return_value.__aenter__ = mocker.AsyncMock(return_value=None)
|
mock_transaction.return_value.__aenter__ = mocker.AsyncMock(return_value=None)
|
||||||
mock_transaction.return_value.__aexit__ = mocker.AsyncMock(return_value=None)
|
mock_transaction.return_value.__aexit__ = mocker.AsyncMock(return_value=None)
|
||||||
# Mock data
|
# Mock data
|
||||||
@@ -151,7 +150,7 @@ async def test_add_agent_to_library(mocker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Mock graph_db.get_graph function that's called to check for HITL blocks
|
# Mock graph_db.get_graph function that's called to check for HITL blocks
|
||||||
mock_graph_db = mocker.patch("backend.server.v2.library.db.graph_db")
|
mock_graph_db = mocker.patch("backend.api.features.library.db.graph_db")
|
||||||
mock_graph_model = mocker.Mock()
|
mock_graph_model = mocker.Mock()
|
||||||
mock_graph_model.nodes = (
|
mock_graph_model.nodes = (
|
||||||
[]
|
[]
|
||||||
@@ -159,7 +158,9 @@ async def test_add_agent_to_library(mocker):
|
|||||||
mock_graph_db.get_graph = mocker.AsyncMock(return_value=mock_graph_model)
|
mock_graph_db.get_graph = mocker.AsyncMock(return_value=mock_graph_model)
|
||||||
|
|
||||||
# Mock the model conversion
|
# Mock the model conversion
|
||||||
mock_from_db = mocker.patch("backend.server.v2.library.model.LibraryAgent.from_db")
|
mock_from_db = mocker.patch(
|
||||||
|
"backend.api.features.library.model.LibraryAgent.from_db"
|
||||||
|
)
|
||||||
mock_from_db.return_value = mocker.Mock()
|
mock_from_db.return_value = mocker.Mock()
|
||||||
|
|
||||||
# Call function
|
# Call function
|
||||||
@@ -217,7 +218,7 @@ async def test_add_agent_to_library_not_found(mocker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Call function and verify exception
|
# Call function and verify exception
|
||||||
with pytest.raises(backend.server.v2.store.exceptions.AgentNotFoundError):
|
with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
|
||||||
await db.add_store_agent_to_library("version123", "test-user")
|
await db.add_store_agent_to_library("version123", "test-user")
|
||||||
|
|
||||||
# Verify mock called correctly
|
# Verify mock called correctly
|
||||||
@@ -385,6 +385,9 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
|
|||||||
auto_update_version: Optional[bool] = pydantic.Field(
|
auto_update_version: Optional[bool] = pydantic.Field(
|
||||||
default=None, description="Auto-update the agent version"
|
default=None, description="Auto-update the agent version"
|
||||||
)
|
)
|
||||||
|
graph_version: Optional[int] = pydantic.Field(
|
||||||
|
default=None, description="Specific graph version to update to"
|
||||||
|
)
|
||||||
is_favorite: Optional[bool] = pydantic.Field(
|
is_favorite: Optional[bool] = pydantic.Field(
|
||||||
default=None, description="Mark the agent as a favorite"
|
default=None, description="Mark the agent as a favorite"
|
||||||
)
|
)
|
||||||
@@ -3,7 +3,7 @@ import datetime
|
|||||||
import prisma.models
|
import prisma.models
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import backend.server.v2.library.model as library_model
|
from . import model as library_model
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -6,12 +6,13 @@ from fastapi import APIRouter, Body, HTTPException, Query, Security, status
|
|||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
from prisma.enums import OnboardingStep
|
from prisma.enums import OnboardingStep
|
||||||
|
|
||||||
import backend.server.v2.library.db as library_db
|
import backend.api.features.store.exceptions as store_exceptions
|
||||||
import backend.server.v2.library.model as library_model
|
|
||||||
import backend.server.v2.store.exceptions as store_exceptions
|
|
||||||
from backend.data.onboarding import complete_onboarding_step
|
from backend.data.onboarding import complete_onboarding_step
|
||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||||
|
|
||||||
|
from .. import db as library_db
|
||||||
|
from .. import model as library_model
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
@@ -284,6 +285,7 @@ async def update_library_agent(
|
|||||||
library_agent_id=library_agent_id,
|
library_agent_id=library_agent_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
auto_update_version=payload.auto_update_version,
|
auto_update_version=payload.auto_update_version,
|
||||||
|
graph_version=payload.graph_version,
|
||||||
is_favorite=payload.is_favorite,
|
is_favorite=payload.is_favorite,
|
||||||
is_archived=payload.is_archived,
|
is_archived=payload.is_archived,
|
||||||
settings=payload.settings,
|
settings=payload.settings,
|
||||||
@@ -4,8 +4,6 @@ from typing import Any, Optional
|
|||||||
import autogpt_libs.auth as autogpt_auth_lib
|
import autogpt_libs.auth as autogpt_auth_lib
|
||||||
from fastapi import APIRouter, Body, HTTPException, Query, Security, status
|
from fastapi import APIRouter, Body, HTTPException, Query, Security, status
|
||||||
|
|
||||||
import backend.server.v2.library.db as db
|
|
||||||
import backend.server.v2.library.model as models
|
|
||||||
from backend.data.execution import GraphExecutionMeta
|
from backend.data.execution import GraphExecutionMeta
|
||||||
from backend.data.graph import get_graph
|
from backend.data.graph import get_graph
|
||||||
from backend.data.integrations import get_webhook
|
from backend.data.integrations import get_webhook
|
||||||
@@ -17,6 +15,9 @@ from backend.integrations.webhooks import get_webhook_manager
|
|||||||
from backend.integrations.webhooks.utils import setup_webhook_for_block
|
from backend.integrations.webhooks.utils import setup_webhook_for_block
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
|
from .. import db
|
||||||
|
from .. import model as models
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
credentials_manager = IntegrationCredentialsManager()
|
credentials_manager = IntegrationCredentialsManager()
|
||||||
@@ -7,10 +7,11 @@ import pytest
|
|||||||
import pytest_mock
|
import pytest_mock
|
||||||
from pytest_snapshot.plugin import Snapshot
|
from pytest_snapshot.plugin import Snapshot
|
||||||
|
|
||||||
import backend.server.v2.library.model as library_model
|
|
||||||
from backend.server.v2.library.routes import router as library_router
|
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
|
from . import model as library_model
|
||||||
|
from .routes import router as library_router
|
||||||
|
|
||||||
app = fastapi.FastAPI()
|
app = fastapi.FastAPI()
|
||||||
app.include_router(library_router)
|
app.include_router(library_router)
|
||||||
|
|
||||||
@@ -86,7 +87,7 @@ async def test_get_library_agents_success(
|
|||||||
total_items=2, total_pages=1, current_page=1, page_size=50
|
total_items=2, total_pages=1, current_page=1, page_size=50
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
mock_db_call = mocker.patch("backend.server.v2.library.db.list_library_agents")
|
mock_db_call = mocker.patch("backend.api.features.library.db.list_library_agents")
|
||||||
mock_db_call.return_value = mocked_value
|
mock_db_call.return_value = mocked_value
|
||||||
|
|
||||||
response = client.get("/agents?search_term=test")
|
response = client.get("/agents?search_term=test")
|
||||||
@@ -112,7 +113,7 @@ async def test_get_library_agents_success(
|
|||||||
|
|
||||||
|
|
||||||
def test_get_library_agents_error(mocker: pytest_mock.MockFixture, test_user_id: str):
|
def test_get_library_agents_error(mocker: pytest_mock.MockFixture, test_user_id: str):
|
||||||
mock_db_call = mocker.patch("backend.server.v2.library.db.list_library_agents")
|
mock_db_call = mocker.patch("backend.api.features.library.db.list_library_agents")
|
||||||
mock_db_call.side_effect = Exception("Test error")
|
mock_db_call.side_effect = Exception("Test error")
|
||||||
|
|
||||||
response = client.get("/agents?search_term=test")
|
response = client.get("/agents?search_term=test")
|
||||||
@@ -161,7 +162,7 @@ async def test_get_favorite_library_agents_success(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
mock_db_call = mocker.patch(
|
mock_db_call = mocker.patch(
|
||||||
"backend.server.v2.library.db.list_favorite_library_agents"
|
"backend.api.features.library.db.list_favorite_library_agents"
|
||||||
)
|
)
|
||||||
mock_db_call.return_value = mocked_value
|
mock_db_call.return_value = mocked_value
|
||||||
|
|
||||||
@@ -184,7 +185,7 @@ def test_get_favorite_library_agents_error(
|
|||||||
mocker: pytest_mock.MockFixture, test_user_id: str
|
mocker: pytest_mock.MockFixture, test_user_id: str
|
||||||
):
|
):
|
||||||
mock_db_call = mocker.patch(
|
mock_db_call = mocker.patch(
|
||||||
"backend.server.v2.library.db.list_favorite_library_agents"
|
"backend.api.features.library.db.list_favorite_library_agents"
|
||||||
)
|
)
|
||||||
mock_db_call.side_effect = Exception("Test error")
|
mock_db_call.side_effect = Exception("Test error")
|
||||||
|
|
||||||
@@ -223,11 +224,11 @@ def test_add_agent_to_library_success(
|
|||||||
)
|
)
|
||||||
|
|
||||||
mock_db_call = mocker.patch(
|
mock_db_call = mocker.patch(
|
||||||
"backend.server.v2.library.db.add_store_agent_to_library"
|
"backend.api.features.library.db.add_store_agent_to_library"
|
||||||
)
|
)
|
||||||
mock_db_call.return_value = mock_library_agent
|
mock_db_call.return_value = mock_library_agent
|
||||||
mock_complete_onboarding = mocker.patch(
|
mock_complete_onboarding = mocker.patch(
|
||||||
"backend.server.v2.library.routes.agents.complete_onboarding_step",
|
"backend.api.features.library.routes.agents.complete_onboarding_step",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -249,7 +250,7 @@ def test_add_agent_to_library_success(
|
|||||||
|
|
||||||
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture, test_user_id: str):
|
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture, test_user_id: str):
|
||||||
mock_db_call = mocker.patch(
|
mock_db_call = mocker.patch(
|
||||||
"backend.server.v2.library.db.add_store_agent_to_library"
|
"backend.api.features.library.db.add_store_agent_to_library"
|
||||||
)
|
)
|
||||||
mock_db_call.side_effect = Exception("Test error")
|
mock_db_call.side_effect = Exception("Test error")
|
||||||
|
|
||||||
833
autogpt_platform/backend/backend/api/features/oauth.py
Normal file
833
autogpt_platform/backend/backend/api/features/oauth.py
Normal file
@@ -0,0 +1,833 @@
|
|||||||
|
"""
|
||||||
|
OAuth 2.0 Provider Endpoints
|
||||||
|
|
||||||
|
Implements OAuth 2.0 Authorization Code flow with PKCE support.
|
||||||
|
|
||||||
|
Flow:
|
||||||
|
1. User clicks "Login with AutoGPT" in 3rd party app
|
||||||
|
2. App redirects user to /auth/authorize with client_id, redirect_uri, scope, state
|
||||||
|
3. User sees consent screen (if not already logged in, redirects to login first)
|
||||||
|
4. User approves → backend creates authorization code
|
||||||
|
5. User redirected back to app with code
|
||||||
|
6. App exchanges code for access/refresh tokens at /api/oauth/token
|
||||||
|
7. App uses access token to call external API endpoints
|
||||||
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Literal, Optional
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
from autogpt_libs.auth import get_user_id
|
||||||
|
from fastapi import APIRouter, Body, HTTPException, Security, UploadFile, status
|
||||||
|
from gcloud.aio import storage as async_storage
|
||||||
|
from PIL import Image
|
||||||
|
from prisma.enums import APIKeyPermission
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from backend.data.auth.oauth import (
|
||||||
|
InvalidClientError,
|
||||||
|
InvalidGrantError,
|
||||||
|
OAuthApplicationInfo,
|
||||||
|
TokenIntrospectionResult,
|
||||||
|
consume_authorization_code,
|
||||||
|
create_access_token,
|
||||||
|
create_authorization_code,
|
||||||
|
create_refresh_token,
|
||||||
|
get_oauth_application,
|
||||||
|
get_oauth_application_by_id,
|
||||||
|
introspect_token,
|
||||||
|
list_user_oauth_applications,
|
||||||
|
refresh_tokens,
|
||||||
|
revoke_access_token,
|
||||||
|
revoke_refresh_token,
|
||||||
|
update_oauth_application,
|
||||||
|
validate_client_credentials,
|
||||||
|
validate_redirect_uri,
|
||||||
|
validate_scopes,
|
||||||
|
)
|
||||||
|
from backend.util.settings import Settings
|
||||||
|
from backend.util.virus_scanner import scan_content_safe
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Request/Response Models
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
|
||||||
|
"""OAuth 2.0 token response"""
|
||||||
|
|
||||||
|
token_type: Literal["Bearer"] = "Bearer"
|
||||||
|
access_token: str
|
||||||
|
access_token_expires_at: datetime
|
||||||
|
refresh_token: str
|
||||||
|
refresh_token_expires_at: datetime
|
||||||
|
scopes: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorResponse(BaseModel):
|
||||||
|
"""OAuth 2.0 error response"""
|
||||||
|
|
||||||
|
error: str
|
||||||
|
error_description: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthApplicationPublicInfo(BaseModel):
|
||||||
|
"""Public information about an OAuth application (for consent screen)"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
logo_url: Optional[str] = None
|
||||||
|
scopes: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Application Info Endpoint
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/app/{client_id}",
|
||||||
|
responses={
|
||||||
|
404: {"description": "Application not found or disabled"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_oauth_app_info(
|
||||||
|
client_id: str, user_id: str = Security(get_user_id)
|
||||||
|
) -> OAuthApplicationPublicInfo:
|
||||||
|
"""
|
||||||
|
Get public information about an OAuth application.
|
||||||
|
|
||||||
|
This endpoint is used by the consent screen to display application details
|
||||||
|
to the user before they authorize access.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- name: Application name
|
||||||
|
- description: Application description (if provided)
|
||||||
|
- scopes: List of scopes the application is allowed to request
|
||||||
|
"""
|
||||||
|
app = await get_oauth_application(client_id)
|
||||||
|
if not app or not app.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Application not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
return OAuthApplicationPublicInfo(
|
||||||
|
name=app.name,
|
||||||
|
description=app.description,
|
||||||
|
logo_url=app.logo_url,
|
||||||
|
scopes=[s.value for s in app.scopes],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Authorization Endpoint
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizeRequest(BaseModel):
|
||||||
|
"""OAuth 2.0 authorization request"""
|
||||||
|
|
||||||
|
client_id: str = Field(description="Client identifier")
|
||||||
|
redirect_uri: str = Field(description="Redirect URI")
|
||||||
|
scopes: list[str] = Field(description="List of scopes")
|
||||||
|
state: str = Field(description="Anti-CSRF token from client")
|
||||||
|
response_type: str = Field(
|
||||||
|
default="code", description="Must be 'code' for authorization code flow"
|
||||||
|
)
|
||||||
|
code_challenge: str = Field(description="PKCE code challenge (required)")
|
||||||
|
code_challenge_method: Literal["S256", "plain"] = Field(
|
||||||
|
default="S256", description="PKCE code challenge method (S256 recommended)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizeResponse(BaseModel):
|
||||||
|
"""OAuth 2.0 authorization response with redirect URL"""
|
||||||
|
|
||||||
|
redirect_url: str = Field(description="URL to redirect the user to")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/authorize")
|
||||||
|
async def authorize(
|
||||||
|
request: AuthorizeRequest = Body(),
|
||||||
|
user_id: str = Security(get_user_id),
|
||||||
|
) -> AuthorizeResponse:
|
||||||
|
"""
|
||||||
|
OAuth 2.0 Authorization Endpoint
|
||||||
|
|
||||||
|
User must be logged in (authenticated with Supabase JWT).
|
||||||
|
This endpoint creates an authorization code and returns a redirect URL.
|
||||||
|
|
||||||
|
PKCE (Proof Key for Code Exchange) is REQUIRED for all authorization requests.
|
||||||
|
|
||||||
|
The frontend consent screen should call this endpoint after the user approves,
|
||||||
|
then redirect the user to the returned `redirect_url`.
|
||||||
|
|
||||||
|
Request Body:
|
||||||
|
- client_id: The OAuth application's client ID
|
||||||
|
- redirect_uri: Where to redirect after authorization (must match registered URI)
|
||||||
|
- scopes: List of permissions (e.g., "EXECUTE_GRAPH READ_GRAPH")
|
||||||
|
- state: Anti-CSRF token provided by client (will be returned in redirect)
|
||||||
|
- response_type: Must be "code" (for authorization code flow)
|
||||||
|
- code_challenge: PKCE code challenge (required)
|
||||||
|
- code_challenge_method: "S256" (recommended) or "plain"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- redirect_url: The URL to redirect the user to (includes authorization code)
|
||||||
|
|
||||||
|
Error cases return a redirect_url with error parameters, or raise HTTPException
|
||||||
|
for critical errors (like invalid redirect_uri).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Validate response_type
|
||||||
|
if request.response_type != "code":
|
||||||
|
return _error_redirect_url(
|
||||||
|
request.redirect_uri,
|
||||||
|
request.state,
|
||||||
|
"unsupported_response_type",
|
||||||
|
"Only 'code' response type is supported",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get application
|
||||||
|
app = await get_oauth_application(request.client_id)
|
||||||
|
if not app:
|
||||||
|
return _error_redirect_url(
|
||||||
|
request.redirect_uri,
|
||||||
|
request.state,
|
||||||
|
"invalid_client",
|
||||||
|
"Unknown client_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not app.is_active:
|
||||||
|
return _error_redirect_url(
|
||||||
|
request.redirect_uri,
|
||||||
|
request.state,
|
||||||
|
"invalid_client",
|
||||||
|
"Application is not active",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate redirect URI
|
||||||
|
if not validate_redirect_uri(app, request.redirect_uri):
|
||||||
|
# For invalid redirect_uri, we can't redirect safely
|
||||||
|
# Must return error instead
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=(
|
||||||
|
"Invalid redirect_uri. "
|
||||||
|
f"Must be one of: {', '.join(app.redirect_uris)}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse and validate scopes
|
||||||
|
try:
|
||||||
|
requested_scopes = [APIKeyPermission(s.strip()) for s in request.scopes]
|
||||||
|
except ValueError as e:
|
||||||
|
return _error_redirect_url(
|
||||||
|
request.redirect_uri,
|
||||||
|
request.state,
|
||||||
|
"invalid_scope",
|
||||||
|
f"Invalid scope: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not requested_scopes:
|
||||||
|
return _error_redirect_url(
|
||||||
|
request.redirect_uri,
|
||||||
|
request.state,
|
||||||
|
"invalid_scope",
|
||||||
|
"At least one scope is required",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not validate_scopes(app, requested_scopes):
|
||||||
|
return _error_redirect_url(
|
||||||
|
request.redirect_uri,
|
||||||
|
request.state,
|
||||||
|
"invalid_scope",
|
||||||
|
"Application is not authorized for all requested scopes. "
|
||||||
|
f"Allowed: {', '.join(s.value for s in app.scopes)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create authorization code
|
||||||
|
auth_code = await create_authorization_code(
|
||||||
|
application_id=app.id,
|
||||||
|
user_id=user_id,
|
||||||
|
scopes=requested_scopes,
|
||||||
|
redirect_uri=request.redirect_uri,
|
||||||
|
code_challenge=request.code_challenge,
|
||||||
|
code_challenge_method=request.code_challenge_method,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build redirect URL with authorization code
|
||||||
|
params = {
|
||||||
|
"code": auth_code.code,
|
||||||
|
"state": request.state,
|
||||||
|
}
|
||||||
|
redirect_url = f"{request.redirect_uri}?{urlencode(params)}"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Authorization code issued for user #{user_id} "
|
||||||
|
f"and app {app.name} (#{app.id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthorizeResponse(redirect_url=redirect_url)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in authorization endpoint: {e}", exc_info=True)
|
||||||
|
return _error_redirect_url(
|
||||||
|
request.redirect_uri,
|
||||||
|
request.state,
|
||||||
|
"server_error",
|
||||||
|
"An unexpected error occurred",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _error_redirect_url(
|
||||||
|
redirect_uri: str,
|
||||||
|
state: str,
|
||||||
|
error: str,
|
||||||
|
error_description: Optional[str] = None,
|
||||||
|
) -> AuthorizeResponse:
|
||||||
|
"""Helper to build redirect URL with OAuth error parameters"""
|
||||||
|
params = {
|
||||||
|
"error": error,
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
if error_description:
|
||||||
|
params["error_description"] = error_description
|
||||||
|
|
||||||
|
redirect_url = f"{redirect_uri}?{urlencode(params)}"
|
||||||
|
return AuthorizeResponse(redirect_url=redirect_url)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Token Endpoint
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TokenRequestByCode(BaseModel):
|
||||||
|
grant_type: Literal["authorization_code"]
|
||||||
|
code: str = Field(description="Authorization code")
|
||||||
|
redirect_uri: str = Field(
|
||||||
|
description="Redirect URI (must match authorization request)"
|
||||||
|
)
|
||||||
|
client_id: str
|
||||||
|
client_secret: str
|
||||||
|
code_verifier: str = Field(description="PKCE code verifier")
|
||||||
|
|
||||||
|
|
||||||
|
class TokenRequestByRefreshToken(BaseModel):
|
||||||
|
grant_type: Literal["refresh_token"]
|
||||||
|
refresh_token: str
|
||||||
|
client_id: str
|
||||||
|
client_secret: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/token")
|
||||||
|
async def token(
|
||||||
|
request: TokenRequestByCode | TokenRequestByRefreshToken = Body(),
|
||||||
|
) -> TokenResponse:
|
||||||
|
"""
|
||||||
|
OAuth 2.0 Token Endpoint
|
||||||
|
|
||||||
|
Exchanges authorization code or refresh token for access token.
|
||||||
|
|
||||||
|
Grant Types:
|
||||||
|
1. authorization_code: Exchange authorization code for tokens
|
||||||
|
- Required: grant_type, code, redirect_uri, client_id, client_secret
|
||||||
|
- Optional: code_verifier (required if PKCE was used)
|
||||||
|
|
||||||
|
2. refresh_token: Exchange refresh token for new access token
|
||||||
|
- Required: grant_type, refresh_token, client_id, client_secret
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- access_token: Bearer token for API access (1 hour TTL)
|
||||||
|
- token_type: "Bearer"
|
||||||
|
- expires_in: Seconds until access token expires
|
||||||
|
- refresh_token: Token for refreshing access (30 days TTL)
|
||||||
|
- scopes: List of scopes
|
||||||
|
"""
|
||||||
|
# Validate client credentials
|
||||||
|
try:
|
||||||
|
app = await validate_client_credentials(
|
||||||
|
request.client_id, request.client_secret
|
||||||
|
)
|
||||||
|
except InvalidClientError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle authorization_code grant
|
||||||
|
if request.grant_type == "authorization_code":
|
||||||
|
# Consume authorization code
|
||||||
|
try:
|
||||||
|
user_id, scopes = await consume_authorization_code(
|
||||||
|
code=request.code,
|
||||||
|
application_id=app.id,
|
||||||
|
redirect_uri=request.redirect_uri,
|
||||||
|
code_verifier=request.code_verifier,
|
||||||
|
)
|
||||||
|
except InvalidGrantError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create access and refresh tokens
|
||||||
|
access_token = await create_access_token(app.id, user_id, scopes)
|
||||||
|
refresh_token = await create_refresh_token(app.id, user_id, scopes)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Access token issued for user #{user_id} and app {app.name} (#{app.id})"
|
||||||
|
"via authorization code"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not access_token.token or not refresh_token.token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to generate tokens",
|
||||||
|
)
|
||||||
|
|
||||||
|
return TokenResponse(
|
||||||
|
token_type="Bearer",
|
||||||
|
access_token=access_token.token.get_secret_value(),
|
||||||
|
access_token_expires_at=access_token.expires_at,
|
||||||
|
refresh_token=refresh_token.token.get_secret_value(),
|
||||||
|
refresh_token_expires_at=refresh_token.expires_at,
|
||||||
|
scopes=list(s.value for s in scopes),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle refresh_token grant
|
||||||
|
elif request.grant_type == "refresh_token":
|
||||||
|
# Refresh access token
|
||||||
|
try:
|
||||||
|
new_access_token, new_refresh_token = await refresh_tokens(
|
||||||
|
request.refresh_token, app.id
|
||||||
|
)
|
||||||
|
except InvalidGrantError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Tokens refreshed for user #{new_access_token.user_id} "
|
||||||
|
f"by app {app.name} (#{app.id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not new_access_token.token or not new_refresh_token.token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to generate tokens",
|
||||||
|
)
|
||||||
|
|
||||||
|
return TokenResponse(
|
||||||
|
token_type="Bearer",
|
||||||
|
access_token=new_access_token.token.get_secret_value(),
|
||||||
|
access_token_expires_at=new_access_token.expires_at,
|
||||||
|
refresh_token=new_refresh_token.token.get_secret_value(),
|
||||||
|
refresh_token_expires_at=new_refresh_token.expires_at,
|
||||||
|
scopes=list(s.value for s in new_access_token.scopes),
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Unsupported grant_type: {request.grant_type}. "
|
||||||
|
"Must be 'authorization_code' or 'refresh_token'",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Token Introspection Endpoint
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/introspect")
|
||||||
|
async def introspect(
|
||||||
|
token: str = Body(description="Token to introspect"),
|
||||||
|
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
|
||||||
|
None, description="Hint about token type ('access_token' or 'refresh_token')"
|
||||||
|
),
|
||||||
|
client_id: str = Body(description="Client identifier"),
|
||||||
|
client_secret: str = Body(description="Client secret"),
|
||||||
|
) -> TokenIntrospectionResult:
|
||||||
|
"""
|
||||||
|
OAuth 2.0 Token Introspection Endpoint (RFC 7662)
|
||||||
|
|
||||||
|
Allows clients to check if a token is valid and get its metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- active: Whether the token is currently active
|
||||||
|
- scopes: List of authorized scopes (if active)
|
||||||
|
- client_id: The client the token was issued to (if active)
|
||||||
|
- user_id: The user the token represents (if active)
|
||||||
|
- exp: Expiration timestamp (if active)
|
||||||
|
- token_type: "access_token" or "refresh_token" (if active)
|
||||||
|
"""
|
||||||
|
# Validate client credentials
|
||||||
|
try:
|
||||||
|
await validate_client_credentials(client_id, client_secret)
|
||||||
|
except InvalidClientError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Introspect the token
|
||||||
|
return await introspect_token(token, token_type_hint)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Token Revocation Endpoint
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/revoke")
|
||||||
|
async def revoke(
|
||||||
|
token: str = Body(description="Token to revoke"),
|
||||||
|
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
|
||||||
|
None, description="Hint about token type ('access_token' or 'refresh_token')"
|
||||||
|
),
|
||||||
|
client_id: str = Body(description="Client identifier"),
|
||||||
|
client_secret: str = Body(description="Client secret"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
OAuth 2.0 Token Revocation Endpoint (RFC 7009)
|
||||||
|
|
||||||
|
Allows clients to revoke an access or refresh token.
|
||||||
|
|
||||||
|
Note: Revoking a refresh token does NOT revoke associated access tokens.
|
||||||
|
Revoking an access token does NOT revoke the associated refresh token.
|
||||||
|
"""
|
||||||
|
# Validate client credentials
|
||||||
|
try:
|
||||||
|
app = await validate_client_credentials(client_id, client_secret)
|
||||||
|
except InvalidClientError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to revoke as access token first
|
||||||
|
# Note: We pass app.id to ensure the token belongs to the authenticated app
|
||||||
|
if token_type_hint != "refresh_token":
|
||||||
|
revoked = await revoke_access_token(token, app.id)
|
||||||
|
if revoked:
|
||||||
|
logger.info(
|
||||||
|
f"Access token revoked for app {app.name} (#{app.id}); "
|
||||||
|
f"user #{revoked.user_id}"
|
||||||
|
)
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
# Try to revoke as refresh token
|
||||||
|
revoked = await revoke_refresh_token(token, app.id)
|
||||||
|
if revoked:
|
||||||
|
logger.info(
|
||||||
|
f"Refresh token revoked for app {app.name} (#{app.id}); "
|
||||||
|
f"user #{revoked.user_id}"
|
||||||
|
)
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
# Per RFC 7009, revocation endpoint returns 200 even if token not found
|
||||||
|
# or if token belongs to a different application.
|
||||||
|
# This prevents token scanning attacks.
|
||||||
|
logger.warning(f"Unsuccessful token revocation attempt by app {app.name} #{app.id}")
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Application Management Endpoints (for app owners)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/apps/mine")
|
||||||
|
async def list_my_oauth_apps(
|
||||||
|
user_id: str = Security(get_user_id),
|
||||||
|
) -> list[OAuthApplicationInfo]:
|
||||||
|
"""
|
||||||
|
List all OAuth applications owned by the current user.
|
||||||
|
|
||||||
|
Returns a list of OAuth applications with their details including:
|
||||||
|
- id, name, description, logo_url
|
||||||
|
- client_id (public identifier)
|
||||||
|
- redirect_uris, grant_types, scopes
|
||||||
|
- is_active status
|
||||||
|
- created_at, updated_at timestamps
|
||||||
|
|
||||||
|
Note: client_secret is never returned for security reasons.
|
||||||
|
"""
|
||||||
|
return await list_user_oauth_applications(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/apps/{app_id}/status")
|
||||||
|
async def update_app_status(
|
||||||
|
app_id: str,
|
||||||
|
user_id: str = Security(get_user_id),
|
||||||
|
is_active: bool = Body(description="Whether the app should be active", embed=True),
|
||||||
|
) -> OAuthApplicationInfo:
|
||||||
|
"""
|
||||||
|
Enable or disable an OAuth application.
|
||||||
|
|
||||||
|
Only the application owner can update the status.
|
||||||
|
When disabled, the application cannot be used for new authorizations
|
||||||
|
and existing access tokens will fail validation.
|
||||||
|
|
||||||
|
Returns the updated application info.
|
||||||
|
"""
|
||||||
|
updated_app = await update_oauth_application(
|
||||||
|
app_id=app_id,
|
||||||
|
owner_id=user_id,
|
||||||
|
is_active=is_active,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not updated_app:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Application not found or you don't have permission to update it",
|
||||||
|
)
|
||||||
|
|
||||||
|
action = "enabled" if is_active else "disabled"
|
||||||
|
logger.info(f"OAuth app {updated_app.name} (#{app_id}) {action} by user #{user_id}")
|
||||||
|
|
||||||
|
return updated_app
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateAppLogoRequest(BaseModel):
|
||||||
|
logo_url: str = Field(description="URL of the uploaded logo image")
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/apps/{app_id}/logo")
|
||||||
|
async def update_app_logo(
|
||||||
|
app_id: str,
|
||||||
|
request: UpdateAppLogoRequest = Body(),
|
||||||
|
user_id: str = Security(get_user_id),
|
||||||
|
) -> OAuthApplicationInfo:
|
||||||
|
"""
|
||||||
|
Update the logo URL for an OAuth application.
|
||||||
|
|
||||||
|
Only the application owner can update the logo.
|
||||||
|
The logo should be uploaded first using the media upload endpoint,
|
||||||
|
then this endpoint is called with the resulting URL.
|
||||||
|
|
||||||
|
Logo requirements:
|
||||||
|
- Must be square (1:1 aspect ratio)
|
||||||
|
- Minimum 512x512 pixels
|
||||||
|
- Maximum 2048x2048 pixels
|
||||||
|
|
||||||
|
Returns the updated application info.
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
not (app := await get_oauth_application_by_id(app_id))
|
||||||
|
or app.owner_id != user_id
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="OAuth App not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete the current app logo file (if any and it's in our cloud storage)
|
||||||
|
await _delete_app_current_logo_file(app)
|
||||||
|
|
||||||
|
updated_app = await update_oauth_application(
|
||||||
|
app_id=app_id,
|
||||||
|
owner_id=user_id,
|
||||||
|
logo_url=request.logo_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not updated_app:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Application not found or you don't have permission to update it",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"OAuth app {updated_app.name} (#{app_id}) logo updated by user #{user_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return updated_app
|
||||||
|
|
||||||
|
|
||||||
|
# Logo upload constraints
|
||||||
|
LOGO_MIN_SIZE = 512
|
||||||
|
LOGO_MAX_SIZE = 2048
|
||||||
|
LOGO_ALLOWED_TYPES = {"image/jpeg", "image/png", "image/webp"}
|
||||||
|
LOGO_MAX_FILE_SIZE = 3 * 1024 * 1024 # 3MB
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/apps/{app_id}/logo/upload")
|
||||||
|
async def upload_app_logo(
|
||||||
|
app_id: str,
|
||||||
|
file: UploadFile,
|
||||||
|
user_id: str = Security(get_user_id),
|
||||||
|
) -> OAuthApplicationInfo:
|
||||||
|
"""
|
||||||
|
Upload a logo image for an OAuth application.
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
- Image must be square (1:1 aspect ratio)
|
||||||
|
- Minimum 512x512 pixels
|
||||||
|
- Maximum 2048x2048 pixels
|
||||||
|
- Allowed formats: JPEG, PNG, WebP
|
||||||
|
- Maximum file size: 3MB
|
||||||
|
|
||||||
|
The image is uploaded to cloud storage and the app's logoUrl is updated.
|
||||||
|
Returns the updated application info.
|
||||||
|
"""
|
||||||
|
# Verify ownership to reduce vulnerability to DoS(torage) or DoM(oney) attacks
|
||||||
|
if (
|
||||||
|
not (app := await get_oauth_application_by_id(app_id))
|
||||||
|
or app.owner_id != user_id
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="OAuth App not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check GCS configuration
|
||||||
|
if not settings.config.media_gcs_bucket_name:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Media storage is not configured",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate content type
|
||||||
|
content_type = file.content_type
|
||||||
|
if content_type not in LOGO_ALLOWED_TYPES:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Invalid file type. Allowed: JPEG, PNG, WebP. Got: {content_type}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read file content
|
||||||
|
try:
|
||||||
|
file_bytes = await file.read()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error reading logo file: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Failed to read uploaded file",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check file size
|
||||||
|
if len(file_bytes) > LOGO_MAX_FILE_SIZE:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=(
|
||||||
|
"File too large. "
|
||||||
|
f"Maximum size is {LOGO_MAX_FILE_SIZE // 1024 // 1024}MB"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate image dimensions
|
||||||
|
try:
|
||||||
|
image = Image.open(io.BytesIO(file_bytes))
|
||||||
|
width, height = image.size
|
||||||
|
|
||||||
|
if width != height:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Logo must be square. Got {width}x{height}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if width < LOGO_MIN_SIZE:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Logo too small. Minimum {LOGO_MIN_SIZE}x{LOGO_MIN_SIZE}. "
|
||||||
|
f"Got {width}x{height}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if width > LOGO_MAX_SIZE:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Logo too large. Maximum {LOGO_MAX_SIZE}x{LOGO_MAX_SIZE}. "
|
||||||
|
f"Got {width}x{height}",
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error validating logo image: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid image file",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scan for viruses
|
||||||
|
filename = file.filename or "logo"
|
||||||
|
await scan_content_safe(file_bytes, filename=filename)
|
||||||
|
|
||||||
|
# Generate unique filename
|
||||||
|
file_ext = os.path.splitext(filename)[1].lower() or ".png"
|
||||||
|
unique_filename = f"{uuid.uuid4()}{file_ext}"
|
||||||
|
storage_path = f"oauth-apps/{app_id}/logo/{unique_filename}"
|
||||||
|
|
||||||
|
# Upload to GCS
|
||||||
|
try:
|
||||||
|
async with async_storage.Storage() as async_client:
|
||||||
|
bucket_name = settings.config.media_gcs_bucket_name
|
||||||
|
|
||||||
|
await async_client.upload(
|
||||||
|
bucket_name, storage_path, file_bytes, content_type=content_type
|
||||||
|
)
|
||||||
|
|
||||||
|
logo_url = f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error uploading logo to GCS: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to upload logo",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete the current app logo file (if any and it's in our cloud storage)
|
||||||
|
await _delete_app_current_logo_file(app)
|
||||||
|
|
||||||
|
# Update the app with the new logo URL
|
||||||
|
updated_app = await update_oauth_application(
|
||||||
|
app_id=app_id,
|
||||||
|
owner_id=user_id,
|
||||||
|
logo_url=logo_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not updated_app:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Application not found or you don't have permission to update it",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"OAuth app {updated_app.name} (#{app_id}) logo uploaded by user #{user_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return updated_app
|
||||||
|
|
||||||
|
|
||||||
|
async def _delete_app_current_logo_file(app: OAuthApplicationInfo):
|
||||||
|
"""
|
||||||
|
Delete the current logo file for the given app, if there is one in our cloud storage
|
||||||
|
"""
|
||||||
|
bucket_name = settings.config.media_gcs_bucket_name
|
||||||
|
storage_base_url = f"https://storage.googleapis.com/{bucket_name}/"
|
||||||
|
|
||||||
|
if app.logo_url and app.logo_url.startswith(storage_base_url):
|
||||||
|
# Parse blob path from URL: https://storage.googleapis.com/{bucket}/{path}
|
||||||
|
old_path = app.logo_url.replace(storage_base_url, "")
|
||||||
|
try:
|
||||||
|
async with async_storage.Storage() as async_client:
|
||||||
|
await async_client.delete(bucket_name, old_path)
|
||||||
|
logger.info(f"Deleted old logo for OAuth app #{app.id}: {old_path}")
|
||||||
|
except Exception as e:
|
||||||
|
# Log but don't fail - the new logo was uploaded successfully
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to delete old logo for OAuth app #{app.id}: {e}", exc_info=e
|
||||||
|
)
|
||||||
1784
autogpt_platform/backend/backend/api/features/oauth_test.py
Normal file
1784
autogpt_platform/backend/backend/api/features/oauth_test.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -6,9 +6,9 @@ import pytest
|
|||||||
import pytest_mock
|
import pytest_mock
|
||||||
from pytest_snapshot.plugin import Snapshot
|
from pytest_snapshot.plugin import Snapshot
|
||||||
|
|
||||||
import backend.server.v2.otto.models as otto_models
|
from . import models as otto_models
|
||||||
import backend.server.v2.otto.routes as otto_routes
|
from . import routes as otto_routes
|
||||||
from backend.server.v2.otto.service import OttoService
|
from .service import OttoService
|
||||||
|
|
||||||
app = fastapi.FastAPI()
|
app = fastapi.FastAPI()
|
||||||
app.include_router(otto_routes.router)
|
app.include_router(otto_routes.router)
|
||||||
@@ -4,12 +4,15 @@ from typing import Annotated
|
|||||||
from fastapi import APIRouter, Body, HTTPException, Query, Security
|
from fastapi import APIRouter, Body, HTTPException, Query, Security
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from backend.api.utils.api_key_auth import APIKeyAuthenticator
|
||||||
from backend.data.user import (
|
from backend.data.user import (
|
||||||
get_user_by_email,
|
get_user_by_email,
|
||||||
set_user_email_verification,
|
set_user_email_verification,
|
||||||
unsubscribe_user_by_token,
|
unsubscribe_user_by_token,
|
||||||
)
|
)
|
||||||
from backend.server.routers.postmark.models import (
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
from .models import (
|
||||||
PostmarkBounceEnum,
|
PostmarkBounceEnum,
|
||||||
PostmarkBounceWebhook,
|
PostmarkBounceWebhook,
|
||||||
PostmarkClickWebhook,
|
PostmarkClickWebhook,
|
||||||
@@ -19,8 +22,6 @@ from backend.server.routers.postmark.models import (
|
|||||||
PostmarkSubscriptionChangeWebhook,
|
PostmarkSubscriptionChangeWebhook,
|
||||||
PostmarkWebhook,
|
PostmarkWebhook,
|
||||||
)
|
)
|
||||||
from backend.server.utils.api_key_auth import APIKeyAuthenticator
|
|
||||||
from backend.util.settings import Settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import backend.server.v2.store.db
|
|
||||||
from backend.util.cache import cached
|
from backend.util.cache import cached
|
||||||
|
|
||||||
|
from . import db as store_db
|
||||||
|
|
||||||
##############################################
|
##############################################
|
||||||
############### Caches #######################
|
############### Caches #######################
|
||||||
##############################################
|
##############################################
|
||||||
@@ -29,7 +30,7 @@ async def _get_cached_store_agents(
|
|||||||
page_size: int,
|
page_size: int,
|
||||||
):
|
):
|
||||||
"""Cached helper to get store agents."""
|
"""Cached helper to get store agents."""
|
||||||
return await backend.server.v2.store.db.get_store_agents(
|
return await store_db.get_store_agents(
|
||||||
featured=featured,
|
featured=featured,
|
||||||
creators=[creator] if creator else None,
|
creators=[creator] if creator else None,
|
||||||
sorted_by=sorted_by,
|
sorted_by=sorted_by,
|
||||||
@@ -42,10 +43,12 @@ async def _get_cached_store_agents(
|
|||||||
|
|
||||||
# Cache individual agent details for 15 minutes
|
# Cache individual agent details for 15 minutes
|
||||||
@cached(maxsize=200, ttl_seconds=300, shared_cache=True)
|
@cached(maxsize=200, ttl_seconds=300, shared_cache=True)
|
||||||
async def _get_cached_agent_details(username: str, agent_name: str):
|
async def _get_cached_agent_details(
|
||||||
|
username: str, agent_name: str, include_changelog: bool = False
|
||||||
|
):
|
||||||
"""Cached helper to get agent details."""
|
"""Cached helper to get agent details."""
|
||||||
return await backend.server.v2.store.db.get_store_agent_details(
|
return await store_db.get_store_agent_details(
|
||||||
username=username, agent_name=agent_name
|
username=username, agent_name=agent_name, include_changelog=include_changelog
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -59,7 +62,7 @@ async def _get_cached_store_creators(
|
|||||||
page_size: int,
|
page_size: int,
|
||||||
):
|
):
|
||||||
"""Cached helper to get store creators."""
|
"""Cached helper to get store creators."""
|
||||||
return await backend.server.v2.store.db.get_store_creators(
|
return await store_db.get_store_creators(
|
||||||
featured=featured,
|
featured=featured,
|
||||||
search_query=search_query,
|
search_query=search_query,
|
||||||
sorted_by=sorted_by,
|
sorted_by=sorted_by,
|
||||||
@@ -72,6 +75,4 @@ async def _get_cached_store_creators(
|
|||||||
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
|
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
|
||||||
async def _get_cached_creator_details(username: str):
|
async def _get_cached_creator_details(username: str):
|
||||||
"""Cached helper to get creator details."""
|
"""Cached helper to get creator details."""
|
||||||
return await backend.server.v2.store.db.get_store_creator_details(
|
return await store_db.get_store_creator_details(username=username.lower())
|
||||||
username=username.lower()
|
|
||||||
)
|
|
||||||
@@ -9,9 +9,7 @@ import prisma.errors
|
|||||||
import prisma.models
|
import prisma.models
|
||||||
import prisma.types
|
import prisma.types
|
||||||
|
|
||||||
import backend.server.v2.store.exceptions
|
from backend.data.db import query_raw_with_schema, transaction
|
||||||
import backend.server.v2.store.model
|
|
||||||
from backend.data.db import transaction
|
|
||||||
from backend.data.graph import (
|
from backend.data.graph import (
|
||||||
GraphMeta,
|
GraphMeta,
|
||||||
GraphModel,
|
GraphModel,
|
||||||
@@ -29,6 +27,9 @@ from backend.notifications.notifications import queue_notification_async
|
|||||||
from backend.util.exceptions import DatabaseError
|
from backend.util.exceptions import DatabaseError
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
from . import exceptions as store_exceptions
|
||||||
|
from . import model as store_model
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
@@ -46,7 +47,7 @@ async def get_store_agents(
|
|||||||
category: str | None = None,
|
category: str | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 20,
|
page_size: int = 20,
|
||||||
) -> backend.server.v2.store.model.StoreAgentsResponse:
|
) -> store_model.StoreAgentsResponse:
|
||||||
"""
|
"""
|
||||||
Get PUBLIC store agents from the StoreAgent view
|
Get PUBLIC store agents from the StoreAgent view
|
||||||
"""
|
"""
|
||||||
@@ -73,10 +74,10 @@ async def get_store_agents(
|
|||||||
total_pages = (total + page_size - 1) // page_size
|
total_pages = (total + page_size - 1) // page_size
|
||||||
|
|
||||||
# Convert raw results to StoreAgent models
|
# Convert raw results to StoreAgent models
|
||||||
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
|
store_agents: list[store_model.StoreAgent] = []
|
||||||
for agent in agents:
|
for agent in agents:
|
||||||
try:
|
try:
|
||||||
store_agent = backend.server.v2.store.model.StoreAgent(
|
store_agent = store_model.StoreAgent(
|
||||||
slug=agent["slug"],
|
slug=agent["slug"],
|
||||||
agent_name=agent["agent_name"],
|
agent_name=agent["agent_name"],
|
||||||
agent_image=(
|
agent_image=(
|
||||||
@@ -122,11 +123,11 @@ async def get_store_agents(
|
|||||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||||
total_pages = (total + page_size - 1) // page_size
|
total_pages = (total + page_size - 1) // page_size
|
||||||
|
|
||||||
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
|
store_agents: list[store_model.StoreAgent] = []
|
||||||
for agent in agents:
|
for agent in agents:
|
||||||
try:
|
try:
|
||||||
# Create the StoreAgent object safely
|
# Create the StoreAgent object safely
|
||||||
store_agent = backend.server.v2.store.model.StoreAgent(
|
store_agent = store_model.StoreAgent(
|
||||||
slug=agent.slug,
|
slug=agent.slug,
|
||||||
agent_name=agent.agent_name,
|
agent_name=agent.agent_name,
|
||||||
agent_image=agent.agent_image[0] if agent.agent_image else "",
|
agent_image=agent.agent_image[0] if agent.agent_image else "",
|
||||||
@@ -148,9 +149,9 @@ async def get_store_agents(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
logger.debug(f"Found {len(store_agents)} agents")
|
logger.debug(f"Found {len(store_agents)} agents")
|
||||||
return backend.server.v2.store.model.StoreAgentsResponse(
|
return store_model.StoreAgentsResponse(
|
||||||
agents=store_agents,
|
agents=store_agents,
|
||||||
pagination=backend.server.v2.store.model.Pagination(
|
pagination=store_model.Pagination(
|
||||||
current_page=page,
|
current_page=page,
|
||||||
total_items=total,
|
total_items=total,
|
||||||
total_pages=total_pages,
|
total_pages=total_pages,
|
||||||
@@ -181,8 +182,8 @@ async def log_search_term(search_query: str):
|
|||||||
|
|
||||||
|
|
||||||
async def get_store_agent_details(
|
async def get_store_agent_details(
|
||||||
username: str, agent_name: str
|
username: str, agent_name: str, include_changelog: bool = False
|
||||||
) -> backend.server.v2.store.model.StoreAgentDetails:
|
) -> store_model.StoreAgentDetails:
|
||||||
"""Get PUBLIC store agent details from the StoreAgent view"""
|
"""Get PUBLIC store agent details from the StoreAgent view"""
|
||||||
logger.debug(f"Getting store agent details for {username}/{agent_name}")
|
logger.debug(f"Getting store agent details for {username}/{agent_name}")
|
||||||
|
|
||||||
@@ -193,7 +194,7 @@ async def get_store_agent_details(
|
|||||||
|
|
||||||
if not agent:
|
if not agent:
|
||||||
logger.warning(f"Agent not found: {username}/{agent_name}")
|
logger.warning(f"Agent not found: {username}/{agent_name}")
|
||||||
raise backend.server.v2.store.exceptions.AgentNotFoundError(
|
raise store_exceptions.AgentNotFoundError(
|
||||||
f"Agent {username}/{agent_name} not found"
|
f"Agent {username}/{agent_name} not found"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -246,8 +247,29 @@ async def get_store_agent_details(
|
|||||||
else:
|
else:
|
||||||
recommended_schedule_cron = None
|
recommended_schedule_cron = None
|
||||||
|
|
||||||
|
# Fetch changelog data if requested
|
||||||
|
changelog_data = None
|
||||||
|
if include_changelog and store_listing:
|
||||||
|
changelog_versions = (
|
||||||
|
await prisma.models.StoreListingVersion.prisma().find_many(
|
||||||
|
where={
|
||||||
|
"storeListingId": store_listing.id,
|
||||||
|
"submissionStatus": prisma.enums.SubmissionStatus.APPROVED,
|
||||||
|
},
|
||||||
|
order=[{"version": "desc"}],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
changelog_data = [
|
||||||
|
store_model.ChangelogEntry(
|
||||||
|
version=str(version.version),
|
||||||
|
changes_summary=version.changesSummary or "No changes recorded",
|
||||||
|
date=version.createdAt,
|
||||||
|
)
|
||||||
|
for version in changelog_versions
|
||||||
|
]
|
||||||
|
|
||||||
logger.debug(f"Found agent details for {username}/{agent_name}")
|
logger.debug(f"Found agent details for {username}/{agent_name}")
|
||||||
return backend.server.v2.store.model.StoreAgentDetails(
|
return store_model.StoreAgentDetails(
|
||||||
store_listing_version_id=agent.storeListingVersionId,
|
store_listing_version_id=agent.storeListingVersionId,
|
||||||
slug=agent.slug,
|
slug=agent.slug,
|
||||||
agent_name=agent.agent_name,
|
agent_name=agent.agent_name,
|
||||||
@@ -262,12 +284,15 @@ async def get_store_agent_details(
|
|||||||
runs=agent.runs,
|
runs=agent.runs,
|
||||||
rating=agent.rating,
|
rating=agent.rating,
|
||||||
versions=agent.versions,
|
versions=agent.versions,
|
||||||
|
agentGraphVersions=agent.agentGraphVersions,
|
||||||
|
agentGraphId=agent.agentGraphId,
|
||||||
last_updated=agent.updated_at,
|
last_updated=agent.updated_at,
|
||||||
active_version_id=active_version_id,
|
active_version_id=active_version_id,
|
||||||
has_approved_version=has_approved_version,
|
has_approved_version=has_approved_version,
|
||||||
recommended_schedule_cron=recommended_schedule_cron,
|
recommended_schedule_cron=recommended_schedule_cron,
|
||||||
|
changelog=changelog_data,
|
||||||
)
|
)
|
||||||
except backend.server.v2.store.exceptions.AgentNotFoundError:
|
except store_exceptions.AgentNotFoundError:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting store agent details: {e}")
|
logger.error(f"Error getting store agent details: {e}")
|
||||||
@@ -303,7 +328,7 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
|||||||
|
|
||||||
async def get_store_agent_by_version_id(
|
async def get_store_agent_by_version_id(
|
||||||
store_listing_version_id: str,
|
store_listing_version_id: str,
|
||||||
) -> backend.server.v2.store.model.StoreAgentDetails:
|
) -> store_model.StoreAgentDetails:
|
||||||
logger.debug(f"Getting store agent details for {store_listing_version_id}")
|
logger.debug(f"Getting store agent details for {store_listing_version_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -313,12 +338,12 @@ async def get_store_agent_by_version_id(
|
|||||||
|
|
||||||
if not agent:
|
if not agent:
|
||||||
logger.warning(f"Agent not found: {store_listing_version_id}")
|
logger.warning(f"Agent not found: {store_listing_version_id}")
|
||||||
raise backend.server.v2.store.exceptions.AgentNotFoundError(
|
raise store_exceptions.AgentNotFoundError(
|
||||||
f"Agent {store_listing_version_id} not found"
|
f"Agent {store_listing_version_id} not found"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"Found agent details for {store_listing_version_id}")
|
logger.debug(f"Found agent details for {store_listing_version_id}")
|
||||||
return backend.server.v2.store.model.StoreAgentDetails(
|
return store_model.StoreAgentDetails(
|
||||||
store_listing_version_id=agent.storeListingVersionId,
|
store_listing_version_id=agent.storeListingVersionId,
|
||||||
slug=agent.slug,
|
slug=agent.slug,
|
||||||
agent_name=agent.agent_name,
|
agent_name=agent.agent_name,
|
||||||
@@ -333,9 +358,11 @@ async def get_store_agent_by_version_id(
|
|||||||
runs=agent.runs,
|
runs=agent.runs,
|
||||||
rating=agent.rating,
|
rating=agent.rating,
|
||||||
versions=agent.versions,
|
versions=agent.versions,
|
||||||
|
agentGraphVersions=agent.agentGraphVersions,
|
||||||
|
agentGraphId=agent.agentGraphId,
|
||||||
last_updated=agent.updated_at,
|
last_updated=agent.updated_at,
|
||||||
)
|
)
|
||||||
except backend.server.v2.store.exceptions.AgentNotFoundError:
|
except store_exceptions.AgentNotFoundError:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting store agent details: {e}")
|
logger.error(f"Error getting store agent details: {e}")
|
||||||
@@ -348,7 +375,7 @@ async def get_store_creators(
|
|||||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 20,
|
page_size: int = 20,
|
||||||
) -> backend.server.v2.store.model.CreatorsResponse:
|
) -> store_model.CreatorsResponse:
|
||||||
"""Get PUBLIC store creators from the Creator view"""
|
"""Get PUBLIC store creators from the Creator view"""
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Getting store creators. featured={featured}, search={search_query}, sorted_by={sorted_by}, page={page}"
|
f"Getting store creators. featured={featured}, search={search_query}, sorted_by={sorted_by}, page={page}"
|
||||||
@@ -423,7 +450,7 @@ async def get_store_creators(
|
|||||||
|
|
||||||
# Convert to response model
|
# Convert to response model
|
||||||
creator_models = [
|
creator_models = [
|
||||||
backend.server.v2.store.model.Creator(
|
store_model.Creator(
|
||||||
username=creator.username,
|
username=creator.username,
|
||||||
name=creator.name,
|
name=creator.name,
|
||||||
description=creator.description,
|
description=creator.description,
|
||||||
@@ -437,9 +464,9 @@ async def get_store_creators(
|
|||||||
]
|
]
|
||||||
|
|
||||||
logger.debug(f"Found {len(creator_models)} creators")
|
logger.debug(f"Found {len(creator_models)} creators")
|
||||||
return backend.server.v2.store.model.CreatorsResponse(
|
return store_model.CreatorsResponse(
|
||||||
creators=creator_models,
|
creators=creator_models,
|
||||||
pagination=backend.server.v2.store.model.Pagination(
|
pagination=store_model.Pagination(
|
||||||
current_page=page,
|
current_page=page,
|
||||||
total_items=total,
|
total_items=total,
|
||||||
total_pages=total_pages,
|
total_pages=total_pages,
|
||||||
@@ -453,7 +480,7 @@ async def get_store_creators(
|
|||||||
|
|
||||||
async def get_store_creator_details(
|
async def get_store_creator_details(
|
||||||
username: str,
|
username: str,
|
||||||
) -> backend.server.v2.store.model.CreatorDetails:
|
) -> store_model.CreatorDetails:
|
||||||
logger.debug(f"Getting store creator details for {username}")
|
logger.debug(f"Getting store creator details for {username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -464,12 +491,10 @@ async def get_store_creator_details(
|
|||||||
|
|
||||||
if not creator:
|
if not creator:
|
||||||
logger.warning(f"Creator not found: {username}")
|
logger.warning(f"Creator not found: {username}")
|
||||||
raise backend.server.v2.store.exceptions.CreatorNotFoundError(
|
raise store_exceptions.CreatorNotFoundError(f"Creator {username} not found")
|
||||||
f"Creator {username} not found"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"Found creator details for {username}")
|
logger.debug(f"Found creator details for {username}")
|
||||||
return backend.server.v2.store.model.CreatorDetails(
|
return store_model.CreatorDetails(
|
||||||
name=creator.name,
|
name=creator.name,
|
||||||
username=creator.username,
|
username=creator.username,
|
||||||
description=creator.description,
|
description=creator.description,
|
||||||
@@ -479,7 +504,7 @@ async def get_store_creator_details(
|
|||||||
agent_runs=creator.agent_runs,
|
agent_runs=creator.agent_runs,
|
||||||
top_categories=creator.top_categories,
|
top_categories=creator.top_categories,
|
||||||
)
|
)
|
||||||
except backend.server.v2.store.exceptions.CreatorNotFoundError:
|
except store_exceptions.CreatorNotFoundError:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting store creator details: {e}")
|
logger.error(f"Error getting store creator details: {e}")
|
||||||
@@ -488,7 +513,7 @@ async def get_store_creator_details(
|
|||||||
|
|
||||||
async def get_store_submissions(
|
async def get_store_submissions(
|
||||||
user_id: str, page: int = 1, page_size: int = 20
|
user_id: str, page: int = 1, page_size: int = 20
|
||||||
) -> backend.server.v2.store.model.StoreSubmissionsResponse:
|
) -> store_model.StoreSubmissionsResponse:
|
||||||
"""Get store submissions for the authenticated user -- not an admin"""
|
"""Get store submissions for the authenticated user -- not an admin"""
|
||||||
logger.debug(f"Getting store submissions for user {user_id}, page={page}")
|
logger.debug(f"Getting store submissions for user {user_id}, page={page}")
|
||||||
|
|
||||||
@@ -513,7 +538,7 @@ async def get_store_submissions(
|
|||||||
# Convert to response models
|
# Convert to response models
|
||||||
submission_models = []
|
submission_models = []
|
||||||
for sub in submissions:
|
for sub in submissions:
|
||||||
submission_model = backend.server.v2.store.model.StoreSubmission(
|
submission_model = store_model.StoreSubmission(
|
||||||
agent_id=sub.agent_id,
|
agent_id=sub.agent_id,
|
||||||
agent_version=sub.agent_version,
|
agent_version=sub.agent_version,
|
||||||
name=sub.name,
|
name=sub.name,
|
||||||
@@ -538,9 +563,9 @@ async def get_store_submissions(
|
|||||||
submission_models.append(submission_model)
|
submission_models.append(submission_model)
|
||||||
|
|
||||||
logger.debug(f"Found {len(submission_models)} submissions")
|
logger.debug(f"Found {len(submission_models)} submissions")
|
||||||
return backend.server.v2.store.model.StoreSubmissionsResponse(
|
return store_model.StoreSubmissionsResponse(
|
||||||
submissions=submission_models,
|
submissions=submission_models,
|
||||||
pagination=backend.server.v2.store.model.Pagination(
|
pagination=store_model.Pagination(
|
||||||
current_page=page,
|
current_page=page,
|
||||||
total_items=total,
|
total_items=total,
|
||||||
total_pages=total_pages,
|
total_pages=total_pages,
|
||||||
@@ -551,9 +576,9 @@ async def get_store_submissions(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching store submissions: {e}")
|
logger.error(f"Error fetching store submissions: {e}")
|
||||||
# Return empty response rather than exposing internal errors
|
# Return empty response rather than exposing internal errors
|
||||||
return backend.server.v2.store.model.StoreSubmissionsResponse(
|
return store_model.StoreSubmissionsResponse(
|
||||||
submissions=[],
|
submissions=[],
|
||||||
pagination=backend.server.v2.store.model.Pagination(
|
pagination=store_model.Pagination(
|
||||||
current_page=page,
|
current_page=page,
|
||||||
total_items=0,
|
total_items=0,
|
||||||
total_pages=0,
|
total_pages=0,
|
||||||
@@ -586,7 +611,7 @@ async def delete_store_submission(
|
|||||||
|
|
||||||
if not submission:
|
if not submission:
|
||||||
logger.warning(f"Submission not found for user {user_id}: {submission_id}")
|
logger.warning(f"Submission not found for user {user_id}: {submission_id}")
|
||||||
raise backend.server.v2.store.exceptions.SubmissionNotFoundError(
|
raise store_exceptions.SubmissionNotFoundError(
|
||||||
f"Submission not found for this user. User ID: {user_id}, Submission ID: {submission_id}"
|
f"Submission not found for this user. User ID: {user_id}, Submission ID: {submission_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -618,7 +643,7 @@ async def create_store_submission(
|
|||||||
categories: list[str] = [],
|
categories: list[str] = [],
|
||||||
changes_summary: str | None = "Initial Submission",
|
changes_summary: str | None = "Initial Submission",
|
||||||
recommended_schedule_cron: str | None = None,
|
recommended_schedule_cron: str | None = None,
|
||||||
) -> backend.server.v2.store.model.StoreSubmission:
|
) -> store_model.StoreSubmission:
|
||||||
"""
|
"""
|
||||||
Create the first (and only) store listing and thus submission as a normal user
|
Create the first (and only) store listing and thus submission as a normal user
|
||||||
|
|
||||||
@@ -659,7 +684,7 @@ async def create_store_submission(
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f"Agent not found for user {user_id}: {agent_id} v{agent_version}"
|
f"Agent not found for user {user_id}: {agent_id} v{agent_version}"
|
||||||
)
|
)
|
||||||
raise backend.server.v2.store.exceptions.AgentNotFoundError(
|
raise store_exceptions.AgentNotFoundError(
|
||||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -732,7 +757,7 @@ async def create_store_submission(
|
|||||||
|
|
||||||
logger.debug(f"Created store listing for agent {agent_id}")
|
logger.debug(f"Created store listing for agent {agent_id}")
|
||||||
# Return submission details
|
# Return submission details
|
||||||
return backend.server.v2.store.model.StoreSubmission(
|
return store_model.StoreSubmission(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
agent_version=agent_version,
|
agent_version=agent_version,
|
||||||
name=name,
|
name=name,
|
||||||
@@ -755,7 +780,7 @@ async def create_store_submission(
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"Slug '{slug}' is already in use by another agent (agent_id: {agent_id}) for user {user_id}"
|
f"Slug '{slug}' is already in use by another agent (agent_id: {agent_id}) for user {user_id}"
|
||||||
)
|
)
|
||||||
raise backend.server.v2.store.exceptions.SlugAlreadyInUseError(
|
raise store_exceptions.SlugAlreadyInUseError(
|
||||||
f"The URL slug '{slug}' is already in use by another one of your agents. Please choose a different slug."
|
f"The URL slug '{slug}' is already in use by another one of your agents. Please choose a different slug."
|
||||||
) from exc
|
) from exc
|
||||||
else:
|
else:
|
||||||
@@ -764,8 +789,8 @@ async def create_store_submission(
|
|||||||
f"Unique constraint violated (not slug): {error_str}"
|
f"Unique constraint violated (not slug): {error_str}"
|
||||||
) from exc
|
) from exc
|
||||||
except (
|
except (
|
||||||
backend.server.v2.store.exceptions.AgentNotFoundError,
|
store_exceptions.AgentNotFoundError,
|
||||||
backend.server.v2.store.exceptions.ListingExistsError,
|
store_exceptions.ListingExistsError,
|
||||||
):
|
):
|
||||||
raise
|
raise
|
||||||
except prisma.errors.PrismaError as e:
|
except prisma.errors.PrismaError as e:
|
||||||
@@ -786,7 +811,7 @@ async def edit_store_submission(
|
|||||||
changes_summary: str | None = "Update submission",
|
changes_summary: str | None = "Update submission",
|
||||||
recommended_schedule_cron: str | None = None,
|
recommended_schedule_cron: str | None = None,
|
||||||
instructions: str | None = None,
|
instructions: str | None = None,
|
||||||
) -> backend.server.v2.store.model.StoreSubmission:
|
) -> store_model.StoreSubmission:
|
||||||
"""
|
"""
|
||||||
Edit an existing store listing submission.
|
Edit an existing store listing submission.
|
||||||
|
|
||||||
@@ -828,7 +853,7 @@ async def edit_store_submission(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not current_version:
|
if not current_version:
|
||||||
raise backend.server.v2.store.exceptions.SubmissionNotFoundError(
|
raise store_exceptions.SubmissionNotFoundError(
|
||||||
f"Store listing version not found: {store_listing_version_id}"
|
f"Store listing version not found: {store_listing_version_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -837,7 +862,7 @@ async def edit_store_submission(
|
|||||||
not current_version.StoreListing
|
not current_version.StoreListing
|
||||||
or current_version.StoreListing.owningUserId != user_id
|
or current_version.StoreListing.owningUserId != user_id
|
||||||
):
|
):
|
||||||
raise backend.server.v2.store.exceptions.UnauthorizedError(
|
raise store_exceptions.UnauthorizedError(
|
||||||
f"User {user_id} does not own submission {store_listing_version_id}"
|
f"User {user_id} does not own submission {store_listing_version_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -846,7 +871,7 @@ async def edit_store_submission(
|
|||||||
|
|
||||||
# Check if we can edit this submission
|
# Check if we can edit this submission
|
||||||
if current_version.submissionStatus == prisma.enums.SubmissionStatus.REJECTED:
|
if current_version.submissionStatus == prisma.enums.SubmissionStatus.REJECTED:
|
||||||
raise backend.server.v2.store.exceptions.InvalidOperationError(
|
raise store_exceptions.InvalidOperationError(
|
||||||
"Cannot edit a rejected submission"
|
"Cannot edit a rejected submission"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -895,7 +920,7 @@ async def edit_store_submission(
|
|||||||
|
|
||||||
if not updated_version:
|
if not updated_version:
|
||||||
raise DatabaseError("Failed to update store listing version")
|
raise DatabaseError("Failed to update store listing version")
|
||||||
return backend.server.v2.store.model.StoreSubmission(
|
return store_model.StoreSubmission(
|
||||||
agent_id=current_version.agentGraphId,
|
agent_id=current_version.agentGraphId,
|
||||||
agent_version=current_version.agentGraphVersion,
|
agent_version=current_version.agentGraphVersion,
|
||||||
name=name,
|
name=name,
|
||||||
@@ -916,16 +941,16 @@ async def edit_store_submission(
|
|||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise backend.server.v2.store.exceptions.InvalidOperationError(
|
raise store_exceptions.InvalidOperationError(
|
||||||
f"Cannot edit submission with status: {current_version.submissionStatus}"
|
f"Cannot edit submission with status: {current_version.submissionStatus}"
|
||||||
)
|
)
|
||||||
|
|
||||||
except (
|
except (
|
||||||
backend.server.v2.store.exceptions.SubmissionNotFoundError,
|
store_exceptions.SubmissionNotFoundError,
|
||||||
backend.server.v2.store.exceptions.UnauthorizedError,
|
store_exceptions.UnauthorizedError,
|
||||||
backend.server.v2.store.exceptions.AgentNotFoundError,
|
store_exceptions.AgentNotFoundError,
|
||||||
backend.server.v2.store.exceptions.ListingExistsError,
|
store_exceptions.ListingExistsError,
|
||||||
backend.server.v2.store.exceptions.InvalidOperationError,
|
store_exceptions.InvalidOperationError,
|
||||||
):
|
):
|
||||||
raise
|
raise
|
||||||
except prisma.errors.PrismaError as e:
|
except prisma.errors.PrismaError as e:
|
||||||
@@ -948,7 +973,7 @@ async def create_store_version(
|
|||||||
categories: list[str] = [],
|
categories: list[str] = [],
|
||||||
changes_summary: str | None = "Initial submission",
|
changes_summary: str | None = "Initial submission",
|
||||||
recommended_schedule_cron: str | None = None,
|
recommended_schedule_cron: str | None = None,
|
||||||
) -> backend.server.v2.store.model.StoreSubmission:
|
) -> store_model.StoreSubmission:
|
||||||
"""
|
"""
|
||||||
Create a new version for an existing store listing
|
Create a new version for an existing store listing
|
||||||
|
|
||||||
@@ -981,7 +1006,7 @@ async def create_store_version(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not listing:
|
if not listing:
|
||||||
raise backend.server.v2.store.exceptions.ListingNotFoundError(
|
raise store_exceptions.ListingNotFoundError(
|
||||||
f"Store listing not found. User ID: {user_id}, Listing ID: {store_listing_id}"
|
f"Store listing not found. User ID: {user_id}, Listing ID: {store_listing_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -993,7 +1018,7 @@ async def create_store_version(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not agent:
|
if not agent:
|
||||||
raise backend.server.v2.store.exceptions.AgentNotFoundError(
|
raise store_exceptions.AgentNotFoundError(
|
||||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1028,7 +1053,7 @@ async def create_store_version(
|
|||||||
f"Created new version for listing {store_listing_id} of agent {agent_id}"
|
f"Created new version for listing {store_listing_id} of agent {agent_id}"
|
||||||
)
|
)
|
||||||
# Return submission details
|
# Return submission details
|
||||||
return backend.server.v2.store.model.StoreSubmission(
|
return store_model.StoreSubmission(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
agent_version=agent_version,
|
agent_version=agent_version,
|
||||||
name=name,
|
name=name,
|
||||||
@@ -1055,7 +1080,7 @@ async def create_store_review(
|
|||||||
store_listing_version_id: str,
|
store_listing_version_id: str,
|
||||||
score: int,
|
score: int,
|
||||||
comments: str | None = None,
|
comments: str | None = None,
|
||||||
) -> backend.server.v2.store.model.StoreReview:
|
) -> store_model.StoreReview:
|
||||||
"""Create a review for a store listing as a user to detail their experience"""
|
"""Create a review for a store listing as a user to detail their experience"""
|
||||||
try:
|
try:
|
||||||
data = prisma.types.StoreListingReviewUpsertInput(
|
data = prisma.types.StoreListingReviewUpsertInput(
|
||||||
@@ -1080,7 +1105,7 @@ async def create_store_review(
|
|||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
|
|
||||||
return backend.server.v2.store.model.StoreReview(
|
return store_model.StoreReview(
|
||||||
score=review.score,
|
score=review.score,
|
||||||
comments=review.comments,
|
comments=review.comments,
|
||||||
)
|
)
|
||||||
@@ -1092,7 +1117,7 @@ async def create_store_review(
|
|||||||
|
|
||||||
async def get_user_profile(
|
async def get_user_profile(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
) -> backend.server.v2.store.model.ProfileDetails | None:
|
) -> store_model.ProfileDetails | None:
|
||||||
logger.debug(f"Getting user profile for {user_id}")
|
logger.debug(f"Getting user profile for {user_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -1102,7 +1127,7 @@ async def get_user_profile(
|
|||||||
|
|
||||||
if not profile:
|
if not profile:
|
||||||
return None
|
return None
|
||||||
return backend.server.v2.store.model.ProfileDetails(
|
return store_model.ProfileDetails(
|
||||||
name=profile.name,
|
name=profile.name,
|
||||||
username=profile.username,
|
username=profile.username,
|
||||||
description=profile.description,
|
description=profile.description,
|
||||||
@@ -1115,8 +1140,8 @@ async def get_user_profile(
|
|||||||
|
|
||||||
|
|
||||||
async def update_profile(
|
async def update_profile(
|
||||||
user_id: str, profile: backend.server.v2.store.model.Profile
|
user_id: str, profile: store_model.Profile
|
||||||
) -> backend.server.v2.store.model.CreatorDetails:
|
) -> store_model.CreatorDetails:
|
||||||
"""
|
"""
|
||||||
Update the store profile for a user or create a new one if it doesn't exist.
|
Update the store profile for a user or create a new one if it doesn't exist.
|
||||||
Args:
|
Args:
|
||||||
@@ -1139,7 +1164,7 @@ async def update_profile(
|
|||||||
where={"userId": user_id}
|
where={"userId": user_id}
|
||||||
)
|
)
|
||||||
if not existing_profile:
|
if not existing_profile:
|
||||||
raise backend.server.v2.store.exceptions.ProfileNotFoundError(
|
raise store_exceptions.ProfileNotFoundError(
|
||||||
f"Profile not found for user {user_id}. This should not be possible."
|
f"Profile not found for user {user_id}. This should not be possible."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1175,7 +1200,7 @@ async def update_profile(
|
|||||||
logger.error(f"Failed to update profile for user {user_id}")
|
logger.error(f"Failed to update profile for user {user_id}")
|
||||||
raise DatabaseError("Failed to update profile")
|
raise DatabaseError("Failed to update profile")
|
||||||
|
|
||||||
return backend.server.v2.store.model.CreatorDetails(
|
return store_model.CreatorDetails(
|
||||||
name=updated_profile.name,
|
name=updated_profile.name,
|
||||||
username=updated_profile.username,
|
username=updated_profile.username,
|
||||||
description=updated_profile.description,
|
description=updated_profile.description,
|
||||||
@@ -1195,7 +1220,7 @@ async def get_my_agents(
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 20,
|
page_size: int = 20,
|
||||||
) -> backend.server.v2.store.model.MyAgentsResponse:
|
) -> store_model.MyAgentsResponse:
|
||||||
"""Get the agents for the authenticated user"""
|
"""Get the agents for the authenticated user"""
|
||||||
logger.debug(f"Getting my agents for user {user_id}, page={page}")
|
logger.debug(f"Getting my agents for user {user_id}, page={page}")
|
||||||
|
|
||||||
@@ -1232,7 +1257,7 @@ async def get_my_agents(
|
|||||||
total_pages = (total + page_size - 1) // page_size
|
total_pages = (total + page_size - 1) // page_size
|
||||||
|
|
||||||
my_agents = [
|
my_agents = [
|
||||||
backend.server.v2.store.model.MyAgent(
|
store_model.MyAgent(
|
||||||
agent_id=graph.id,
|
agent_id=graph.id,
|
||||||
agent_version=graph.version,
|
agent_version=graph.version,
|
||||||
agent_name=graph.name or "",
|
agent_name=graph.name or "",
|
||||||
@@ -1245,9 +1270,9 @@ async def get_my_agents(
|
|||||||
if (graph := library_agent.AgentGraph)
|
if (graph := library_agent.AgentGraph)
|
||||||
]
|
]
|
||||||
|
|
||||||
return backend.server.v2.store.model.MyAgentsResponse(
|
return store_model.MyAgentsResponse(
|
||||||
agents=my_agents,
|
agents=my_agents,
|
||||||
pagination=backend.server.v2.store.model.Pagination(
|
pagination=store_model.Pagination(
|
||||||
current_page=page,
|
current_page=page,
|
||||||
total_items=total,
|
total_items=total,
|
||||||
total_pages=total_pages,
|
total_pages=total_pages,
|
||||||
@@ -1394,7 +1419,7 @@ async def review_store_submission(
|
|||||||
external_comments: str,
|
external_comments: str,
|
||||||
internal_comments: str,
|
internal_comments: str,
|
||||||
reviewer_id: str,
|
reviewer_id: str,
|
||||||
) -> backend.server.v2.store.model.StoreSubmission:
|
) -> store_model.StoreSubmission:
|
||||||
"""Review a store listing submission as an admin."""
|
"""Review a store listing submission as an admin."""
|
||||||
try:
|
try:
|
||||||
store_listing_version = (
|
store_listing_version = (
|
||||||
@@ -1625,7 +1650,7 @@ async def review_store_submission(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Convert to Pydantic model for consistency
|
# Convert to Pydantic model for consistency
|
||||||
return backend.server.v2.store.model.StoreSubmission(
|
return store_model.StoreSubmission(
|
||||||
agent_id=submission.agentGraphId,
|
agent_id=submission.agentGraphId,
|
||||||
agent_version=submission.agentGraphVersion,
|
agent_version=submission.agentGraphVersion,
|
||||||
name=submission.name,
|
name=submission.name,
|
||||||
@@ -1660,7 +1685,7 @@ async def get_admin_listings_with_versions(
|
|||||||
search_query: str | None = None,
|
search_query: str | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 20,
|
page_size: int = 20,
|
||||||
) -> backend.server.v2.store.model.StoreListingsWithVersionsResponse:
|
) -> store_model.StoreListingsWithVersionsResponse:
|
||||||
"""
|
"""
|
||||||
Get store listings for admins with all their versions.
|
Get store listings for admins with all their versions.
|
||||||
|
|
||||||
@@ -1759,10 +1784,10 @@ async def get_admin_listings_with_versions(
|
|||||||
# Convert to response models
|
# Convert to response models
|
||||||
listings_with_versions = []
|
listings_with_versions = []
|
||||||
for listing in listings:
|
for listing in listings:
|
||||||
versions: list[backend.server.v2.store.model.StoreSubmission] = []
|
versions: list[store_model.StoreSubmission] = []
|
||||||
# If we have versions, turn them into StoreSubmission models
|
# If we have versions, turn them into StoreSubmission models
|
||||||
for version in listing.Versions or []:
|
for version in listing.Versions or []:
|
||||||
version_model = backend.server.v2.store.model.StoreSubmission(
|
version_model = store_model.StoreSubmission(
|
||||||
agent_id=version.agentGraphId,
|
agent_id=version.agentGraphId,
|
||||||
agent_version=version.agentGraphVersion,
|
agent_version=version.agentGraphVersion,
|
||||||
name=version.name,
|
name=version.name,
|
||||||
@@ -1790,26 +1815,24 @@ async def get_admin_listings_with_versions(
|
|||||||
|
|
||||||
creator_email = listing.OwningUser.email if listing.OwningUser else None
|
creator_email = listing.OwningUser.email if listing.OwningUser else None
|
||||||
|
|
||||||
listing_with_versions = (
|
listing_with_versions = store_model.StoreListingWithVersions(
|
||||||
backend.server.v2.store.model.StoreListingWithVersions(
|
listing_id=listing.id,
|
||||||
listing_id=listing.id,
|
slug=listing.slug,
|
||||||
slug=listing.slug,
|
agent_id=listing.agentGraphId,
|
||||||
agent_id=listing.agentGraphId,
|
agent_version=listing.agentGraphVersion,
|
||||||
agent_version=listing.agentGraphVersion,
|
active_version_id=listing.activeVersionId,
|
||||||
active_version_id=listing.activeVersionId,
|
has_approved_version=listing.hasApprovedVersion,
|
||||||
has_approved_version=listing.hasApprovedVersion,
|
creator_email=creator_email,
|
||||||
creator_email=creator_email,
|
latest_version=latest_version,
|
||||||
latest_version=latest_version,
|
versions=versions,
|
||||||
versions=versions,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
listings_with_versions.append(listing_with_versions)
|
listings_with_versions.append(listing_with_versions)
|
||||||
|
|
||||||
logger.debug(f"Found {len(listings_with_versions)} listings for admin")
|
logger.debug(f"Found {len(listings_with_versions)} listings for admin")
|
||||||
return backend.server.v2.store.model.StoreListingsWithVersionsResponse(
|
return store_model.StoreListingsWithVersionsResponse(
|
||||||
listings=listings_with_versions,
|
listings=listings_with_versions,
|
||||||
pagination=backend.server.v2.store.model.Pagination(
|
pagination=store_model.Pagination(
|
||||||
current_page=page,
|
current_page=page,
|
||||||
total_items=total,
|
total_items=total,
|
||||||
total_pages=total_pages,
|
total_pages=total_pages,
|
||||||
@@ -1819,9 +1842,9 @@ async def get_admin_listings_with_versions(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching admin store listings: {e}")
|
logger.error(f"Error fetching admin store listings: {e}")
|
||||||
# Return empty response rather than exposing internal errors
|
# Return empty response rather than exposing internal errors
|
||||||
return backend.server.v2.store.model.StoreListingsWithVersionsResponse(
|
return store_model.StoreListingsWithVersionsResponse(
|
||||||
listings=[],
|
listings=[],
|
||||||
pagination=backend.server.v2.store.model.Pagination(
|
pagination=store_model.Pagination(
|
||||||
current_page=page,
|
current_page=page,
|
||||||
total_items=0,
|
total_items=0,
|
||||||
total_pages=0,
|
total_pages=0,
|
||||||
@@ -6,8 +6,8 @@ import prisma.models
|
|||||||
import pytest
|
import pytest
|
||||||
from prisma import Prisma
|
from prisma import Prisma
|
||||||
|
|
||||||
import backend.server.v2.store.db as db
|
from . import db
|
||||||
from backend.server.v2.store.model import Profile
|
from .model import Profile
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@@ -40,6 +40,8 @@ async def test_get_store_agents(mocker):
|
|||||||
runs=10,
|
runs=10,
|
||||||
rating=4.5,
|
rating=4.5,
|
||||||
versions=["1.0"],
|
versions=["1.0"],
|
||||||
|
agentGraphVersions=["1"],
|
||||||
|
agentGraphId="test-graph-id",
|
||||||
updated_at=datetime.now(),
|
updated_at=datetime.now(),
|
||||||
is_available=False,
|
is_available=False,
|
||||||
useForOnboarding=False,
|
useForOnboarding=False,
|
||||||
@@ -83,6 +85,8 @@ async def test_get_store_agent_details(mocker):
|
|||||||
runs=10,
|
runs=10,
|
||||||
rating=4.5,
|
rating=4.5,
|
||||||
versions=["1.0"],
|
versions=["1.0"],
|
||||||
|
agentGraphVersions=["1"],
|
||||||
|
agentGraphId="test-graph-id",
|
||||||
updated_at=datetime.now(),
|
updated_at=datetime.now(),
|
||||||
is_available=False,
|
is_available=False,
|
||||||
useForOnboarding=False,
|
useForOnboarding=False,
|
||||||
@@ -105,6 +109,8 @@ async def test_get_store_agent_details(mocker):
|
|||||||
runs=15,
|
runs=15,
|
||||||
rating=4.8,
|
rating=4.8,
|
||||||
versions=["1.0", "2.0"],
|
versions=["1.0", "2.0"],
|
||||||
|
agentGraphVersions=["1", "2"],
|
||||||
|
agentGraphId="test-graph-id-active",
|
||||||
updated_at=datetime.now(),
|
updated_at=datetime.now(),
|
||||||
is_available=True,
|
is_available=True,
|
||||||
useForOnboarding=False,
|
useForOnboarding=False,
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user