Compare commits

...

10 Commits

Author SHA1 Message Date
Swifty
4c851414a9 feat(platform): add login redirect flow for OAuth and Connect endpoints
- Fix OAuth resume CORS issue by adding JSON response support
- Add X-Frame-Options: DENY security headers to OAuth endpoints
- Fix authentication fallthrough vulnerability in OAuth router
- Replace dangerouslySetInnerHTML with React consent components
- Add rate limiting to consent submission endpoint

Connect flow for unauthenticated users:
- Add ConnectLoginState model and Redis storage functions
- Handle unauthenticated users in /connect/{provider} endpoint
- Add /connect/resume endpoint for post-login continuation
- Create /auth/connect-resume frontend page
- Update login page to handle connect_session parameter
- Update auth callback to redirect to connect-resume

This enables the full credential broker popup flow where:
1. External app opens popup to /connect/{provider}
2. If user not logged in -> redirect to /login?connect_session=X
3. User logs in -> redirect to /auth/connect-resume
4. User sees consent form and approves
5. postMessage returns grant_id to opener

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-12 19:06:03 +01:00
Swifty
7e145734c7 fix(platform): address OAuth security vulnerabilities and improve code quality
Backend:
- Fix XSS vulnerabilities in consent_templates.py by escaping all user input
- Add debug logging for auth token validation failures in router.py
- Use urlencode() for redirect URL construction to prevent injection
- Fix HTTP status codes: 401 for InvalidClientError, 400 for other errors
- Migrate consent state from in-memory dict to Redis with TTL
- Remove unused external API routes (middleware.py, v1.py, integrations.py, tools.py)

Frontend:
- Fix open redirect vulnerability in useLoginPage.ts (use strict origin matching)
- Fix XSS in oauth_callback/route.ts (escape < > in JSON, add CSP header)
- Add validateRedirectUri() helper for OAuth client forms
- Migrate icons to Phosphor (CircleNotch, Copy, DotsThreeVertical)
- Fix TypeScript type error in useOAuthClientModals.tsx

Docs:
- Update integration guides with dual-auth flow (API key + login)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-10 11:34:20 +01:00
Swifty
0d0c426209 added editing oauth clients 2025-12-09 17:53:15 +01:00
Swifty
2a4d474ca4 feat(frontend): add OAuth client management UI
Add a new Developer Settings page at /profile/developer for managing
OAuth clients. Users can now register, view, suspend, activate, and
delete OAuth clients through the UI.

Features:
- Register OAuth clients (public or confidential)
- Configure redirect URIs, homepage, privacy policy, terms of service
- View client credentials (secret shown only once)
- Suspend/activate clients
- Delete clients

The page is accessible from the account menu in the navbar.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-09 17:14:48 +01:00
Swifty
9e67f0bf45 fix(backend): use typed Prisma input classes for strict type checking
Replace dict literals with properly typed Prisma input classes to fix
pyright type errors. This ensures type safety when calling Prisma create,
upsert, and update operations.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-09 16:42:38 +01:00
Swifty
0b25c643a4 Merge branch 'swiftyos/oauth-integrations' of github.com:Significant-Gravitas/AutoGPT into swiftyos/oauth-integrations 2025-12-09 16:26:17 +01:00
Swifty
9430ea2354 Merge branch 'dev' into swiftyos/oauth-integrations 2025-12-09 16:26:14 +01:00
Swifty
15033d8ebf feat(backend): complete OAuth provider integration
Add webhook notifications on execution completion and rate limiting
to OAuth endpoints. This completes the OAuth Provider & Credential
Broker implementation.

- Add webhook notification integration in executor manager
- Add rate limiting to /oauth/authorize and /oauth/token endpoints
- Add database migration for OAuth provider tables
- Update OpenAPI spec with OAuth endpoints

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-09 16:18:05 +01:00
Swifty
741d1b40aa fix(backend): fix XSS vulnerability and type errors in OAuth provider
- Add HTML escaping to user input in error messages (XSS fix)
- Add type: ignore comments for Prisma dict type issues

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-09 13:53:17 +01:00
Swifty
1acb18f5ff add external auth flows 2025-12-09 12:49:10 +01:00
69 changed files with 15211 additions and 1468 deletions

View File

@@ -6,7 +6,7 @@ from typing import Optional
from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission, APIKeyStatus
from prisma.models import APIKey as PrismaAPIKey
from prisma.types import APIKeyWhereUniqueInput
from prisma.types import APIKeyCreateInput, APIKeyWhereUniqueInput
from pydantic import BaseModel, Field
from backend.data.includes import MAX_USER_API_KEYS_FETCH
@@ -83,17 +83,17 @@ async def create_api_key(
generated_key = keysmith.generate_key()
saved_key_obj = await PrismaAPIKey.prisma().create(
data={
"id": str(uuid.uuid4()),
"name": name,
"head": generated_key.head,
"tail": generated_key.tail,
"hash": generated_key.hash,
"salt": generated_key.salt,
"permissions": [p for p in permissions],
"description": description,
"userId": user_id,
}
data=APIKeyCreateInput(
id=str(uuid.uuid4()),
name=name,
head=generated_key.head,
tail=generated_key.tail,
hash=generated_key.hash,
salt=generated_key.salt,
permissions=permissions,
description=description,
userId=user_id,
)
)
return APIKeyInfo.from_db(saved_key_obj), generated_key.key

View File

@@ -0,0 +1,327 @@
"""
Credential Grant data layer.
Handles database operations for credential grants which allow OAuth clients
to use credentials on behalf of users.
"""
from datetime import datetime, timezone
from typing import Optional
from prisma.enums import CredentialGrantPermission
from prisma.models import CredentialGrant
from backend.data.db import prisma
async def create_credential_grant(
user_id: str,
client_id: str,
credential_id: str,
provider: str,
granted_scopes: list[str],
permissions: list[CredentialGrantPermission],
expires_at: Optional[datetime] = None,
) -> CredentialGrant:
"""
Create a new credential grant.
Args:
user_id: ID of the user granting access
client_id: Database ID of the OAuth client
credential_id: ID of the credential being granted
provider: Provider name (e.g., "google", "github")
granted_scopes: List of integration scopes granted
permissions: List of permissions (USE, DELETE)
expires_at: Optional expiration datetime
Returns:
Created CredentialGrant
"""
return await prisma.credentialgrant.create(
data={ # type: ignore[typeddict-item]
"userId": user_id,
"clientId": client_id,
"credentialId": credential_id,
"provider": provider,
"grantedScopes": granted_scopes,
"permissions": permissions,
"expiresAt": expires_at,
}
)
async def get_credential_grant(
grant_id: str,
user_id: Optional[str] = None,
client_id: Optional[str] = None,
) -> Optional[CredentialGrant]:
"""
Get a credential grant by ID.
Args:
grant_id: Grant ID
user_id: Optional user ID filter
client_id: Optional client database ID filter
Returns:
CredentialGrant or None
"""
where: dict[str, str] = {"id": grant_id}
if user_id:
where["userId"] = user_id
if client_id:
where["clientId"] = client_id
return await prisma.credentialgrant.find_first(where=where) # type: ignore[arg-type]
async def get_grants_for_user_client(
user_id: str,
client_id: str,
include_revoked: bool = False,
include_expired: bool = False,
) -> list[CredentialGrant]:
"""
Get all credential grants for a user-client pair.
Args:
user_id: User ID
client_id: Client database ID
include_revoked: Include revoked grants
include_expired: Include expired grants
Returns:
List of CredentialGrant objects
"""
where: dict[str, str | None] = {
"userId": user_id,
"clientId": client_id,
}
if not include_revoked:
where["revokedAt"] = None
grants = await prisma.credentialgrant.find_many(
where=where, # type: ignore[arg-type]
order={"createdAt": "desc"},
)
# Filter expired if needed
if not include_expired:
now = datetime.now(timezone.utc)
grants = [g for g in grants if g.expiresAt is None or g.expiresAt > now]
return grants
async def get_grants_for_credential(
user_id: str,
credential_id: str,
) -> list[CredentialGrant]:
"""
Get all active grants for a specific credential.
Args:
user_id: User ID
credential_id: Credential ID
Returns:
List of active CredentialGrant objects
"""
now = datetime.now(timezone.utc)
grants = await prisma.credentialgrant.find_many(
where={
"userId": user_id,
"credentialId": credential_id,
"revokedAt": None,
},
include={"Client": True},
)
# Filter expired
return [g for g in grants if g.expiresAt is None or g.expiresAt > now]
async def get_grant_by_credential_and_client(
user_id: str,
credential_id: str,
client_id: str,
) -> Optional[CredentialGrant]:
"""
Get the grant for a specific credential and client.
Args:
user_id: User ID
credential_id: Credential ID
client_id: Client database ID
Returns:
CredentialGrant or None
"""
return await prisma.credentialgrant.find_first(
where={
"userId": user_id,
"credentialId": credential_id,
"clientId": client_id,
"revokedAt": None,
}
)
async def update_grant_scopes(
grant_id: str,
granted_scopes: list[str],
) -> CredentialGrant:
"""
Update the granted scopes for a credential grant.
Args:
grant_id: Grant ID
granted_scopes: New list of granted scopes
Returns:
Updated CredentialGrant
"""
result = await prisma.credentialgrant.update(
where={"id": grant_id},
data={"grantedScopes": granted_scopes},
)
if result is None:
raise ValueError(f"Grant {grant_id} not found")
return result
async def update_grant_last_used(grant_id: str) -> None:
"""
Update the lastUsedAt timestamp for a grant.
Args:
grant_id: Grant ID
"""
await prisma.credentialgrant.update(
where={"id": grant_id},
data={"lastUsedAt": datetime.now(timezone.utc)},
)
async def revoke_grant(grant_id: str) -> CredentialGrant:
"""
Revoke a credential grant.
Args:
grant_id: Grant ID
Returns:
Revoked CredentialGrant
"""
result = await prisma.credentialgrant.update(
where={"id": grant_id},
data={"revokedAt": datetime.now(timezone.utc)},
)
if result is None:
raise ValueError(f"Grant {grant_id} not found")
return result
async def revoke_grants_for_credential(
user_id: str,
credential_id: str,
) -> int:
"""
Revoke all grants for a specific credential.
Args:
user_id: User ID
credential_id: Credential ID
Returns:
Number of grants revoked
"""
return await prisma.credentialgrant.update_many(
where={
"userId": user_id,
"credentialId": credential_id,
"revokedAt": None,
},
data={"revokedAt": datetime.now(timezone.utc)},
)
async def revoke_grants_for_client(
user_id: str,
client_id: str,
) -> int:
"""
Revoke all grants for a specific client.
Args:
user_id: User ID
client_id: Client database ID
Returns:
Number of grants revoked
"""
return await prisma.credentialgrant.update_many(
where={
"userId": user_id,
"clientId": client_id,
"revokedAt": None,
},
data={"revokedAt": datetime.now(timezone.utc)},
)
async def delete_grant(grant_id: str) -> None:
"""
Permanently delete a credential grant.
Args:
grant_id: Grant ID
"""
await prisma.credentialgrant.delete(where={"id": grant_id})
async def check_grant_permission(
grant_id: str,
required_permission: CredentialGrantPermission,
) -> bool:
"""
Check if a grant has a specific permission.
Args:
grant_id: Grant ID
required_permission: Permission to check
Returns:
True if grant has the permission
"""
grant = await prisma.credentialgrant.find_unique(where={"id": grant_id})
if not grant:
return False
return required_permission in grant.permissions
async def is_grant_valid(grant_id: str) -> bool:
"""
Check if a grant is valid (not revoked and not expired).
Args:
grant_id: Grant ID
Returns:
True if grant is valid
"""
grant = await prisma.credentialgrant.find_unique(where={"id": grant_id})
if not grant:
return False
if grant.revokedAt:
return False
if grant.expiresAt and grant.expiresAt < datetime.now(timezone.utc):
return False
return True

View File

@@ -11,6 +11,7 @@ import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
from backend.data.credit import UserCredit
from backend.util.json import SafeJson
@@ -21,11 +22,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for ceiling tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
data=UserCreateInput(
id=user_id,
email=f"test-{user_id}@example.com",
name=f"Test User {user_id[:8]}",
)
)
except UniqueViolationError:
# User already exists, continue
@@ -33,7 +34,10 @@ async def create_test_user(user_id: str) -> None:
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=0),
update={"balance": 0},
),
)

View File

@@ -14,6 +14,7 @@ import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
from backend.util.exceptions import InsufficientBalanceError
@@ -28,11 +29,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user with initial balance."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
data=UserCreateInput(
id=user_id,
email=f"test-{user_id}@example.com",
name=f"Test User {user_id[:8]}",
)
)
except UniqueViolationError:
# User already exists, continue
@@ -41,7 +42,10 @@ async def create_test_user(user_id: str) -> None:
# Ensure UserBalance record exists
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=0),
update={"balance": 0},
),
)
@@ -342,10 +346,10 @@ async def test_integer_overflow_protection(server: SpinTestServer):
# First, set balance near max
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": max_int - 100},
"update": {"balance": max_int - 100},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=max_int - 100),
update={"balance": max_int - 100},
),
)
# Try to add more than possible - should clamp to POSTGRES_INT_MAX

View File

@@ -8,6 +8,7 @@ which would have caught the CreditTransactionType enum casting bug.
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserCreateInput
from backend.data.credit import (
AutoTopUpConfig,
@@ -29,12 +30,12 @@ async def cleanup_test_user():
# Create the user first
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"topUpConfig": SafeJson({}),
"timezone": "UTC",
}
data=UserCreateInput(
id=user_id,
email=f"test-{user_id}@example.com",
topUpConfig=SafeJson({}),
timezone="UTC",
)
)
except Exception:
# User might already exist, that's fine

View File

@@ -12,6 +12,12 @@ import pytest
import stripe
from prisma.enums import CreditTransactionType
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
from prisma.types import (
CreditRefundRequestCreateInput,
CreditTransactionCreateInput,
UserBalanceCreateInput,
UserCreateInput,
)
from backend.data.credit import UserCredit
from backend.util.json import SafeJson
@@ -35,32 +41,32 @@ async def setup_test_user_with_topup():
# Create user
await User.prisma().create(
data={
"id": REFUND_TEST_USER_ID,
"email": f"{REFUND_TEST_USER_ID}@example.com",
"name": "Refund Test User",
}
data=UserCreateInput(
id=REFUND_TEST_USER_ID,
email=f"{REFUND_TEST_USER_ID}@example.com",
name="Refund Test User",
)
)
# Create user balance
await UserBalance.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"balance": 1000, # $10
}
data=UserBalanceCreateInput(
userId=REFUND_TEST_USER_ID,
balance=1000, # $10
)
)
# Create a top-up transaction that can be refunded
topup_tx = await CreditTransaction.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 1000,
"type": CreditTransactionType.TOP_UP,
"transactionKey": "pi_test_12345",
"runningBalance": 1000,
"isActive": True,
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
}
data=CreditTransactionCreateInput(
userId=REFUND_TEST_USER_ID,
amount=1000,
type=CreditTransactionType.TOP_UP,
transactionKey="pi_test_12345",
runningBalance=1000,
isActive=True,
metadata=SafeJson({"stripe_payment_intent": "pi_test_12345"}),
)
)
return topup_tx
@@ -93,12 +99,12 @@ async def test_deduct_credits_atomic(server: SpinTestServer):
# Create refund request record (simulating webhook flow)
await CreditRefundRequest.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 500,
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
"reason": "Test refund",
}
data=CreditRefundRequestCreateInput(
userId=REFUND_TEST_USER_ID,
amount=500,
transactionKey=topup_tx.transactionKey, # Should match the original transaction
reason="Test refund",
)
)
# Call deduct_credits
@@ -286,12 +292,12 @@ async def test_concurrent_refunds(server: SpinTestServer):
refund_requests = []
for i in range(5):
req = await CreditRefundRequest.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 100, # $1 each
"transactionKey": topup_tx.transactionKey,
"reason": f"Test refund {i}",
}
data=CreditRefundRequestCreateInput(
userId=REFUND_TEST_USER_ID,
amount=100, # $1 each
transactionKey=topup_tx.transactionKey,
reason=f"Test refund {i}",
)
)
refund_requests.append(req)

View File

@@ -3,6 +3,11 @@ from datetime import datetime, timedelta, timezone
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, UserBalance
from prisma.types import (
CreditTransactionCreateInput,
UserBalanceCreateInput,
UserBalanceUpsertInput,
)
from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block
@@ -23,10 +28,10 @@ async def disable_test_user_transactions():
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID},
data={
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
"update": {"balance": 0, "updatedAt": old_date},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=0),
update={"balance": 0, "updatedAt": old_date},
),
)
@@ -140,23 +145,23 @@ async def test_block_credit_reset(server: SpinTestServer):
# Manually create a transaction with month 1 timestamp to establish history
await CreditTransaction.prisma().create(
data={
"userId": DEFAULT_USER_ID,
"amount": 100,
"type": CreditTransactionType.TOP_UP,
"runningBalance": 1100,
"isActive": True,
"createdAt": month1, # Set specific timestamp
}
data=CreditTransactionCreateInput(
userId=DEFAULT_USER_ID,
amount=100,
type=CreditTransactionType.TOP_UP,
runningBalance=1100,
isActive=True,
createdAt=month1, # Set specific timestamp
)
)
# Update user balance to match
await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID},
data={
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
"update": {"balance": 1100},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=1100),
update={"balance": 1100},
),
)
# Now test month 2 behavior
@@ -175,14 +180,14 @@ async def test_block_credit_reset(server: SpinTestServer):
# Create a month 2 transaction to update the last transaction time
await CreditTransaction.prisma().create(
data={
"userId": DEFAULT_USER_ID,
"amount": -700, # Spent 700 to get to 400
"type": CreditTransactionType.USAGE,
"runningBalance": 400,
"isActive": True,
"createdAt": month2,
}
data=CreditTransactionCreateInput(
userId=DEFAULT_USER_ID,
amount=-700, # Spent 700 to get to 400
type=CreditTransactionType.USAGE,
runningBalance=400,
isActive=True,
createdAt=month2,
)
)
# Move to month 3

View File

@@ -12,6 +12,7 @@ import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
from backend.data.credit import POSTGRES_INT_MIN, UserCredit
from backend.util.test import SpinTestServer
@@ -21,11 +22,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for underflow tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
data=UserCreateInput(
id=user_id,
email=f"test-{user_id}@example.com",
name=f"Test User {user_id[:8]}",
)
)
except UniqueViolationError:
# User already exists, continue
@@ -33,7 +34,10 @@ async def create_test_user(user_id: str) -> None:
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=0),
update={"balance": 0},
),
)
@@ -66,14 +70,14 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
initial_balance_target = POSTGRES_INT_MIN + 100
# Use direct database update to set the balance close to underflow
from prisma.models import UserBalance
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance_target},
"update": {"balance": initial_balance_target},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(
userId=user_id, balance=initial_balance_target
),
update={"balance": initial_balance_target},
),
)
current_balance = await credit_system.get_credits(user_id)
@@ -110,10 +114,10 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
# Set balance to exactly POSTGRES_INT_MIN
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
"update": {"balance": POSTGRES_INT_MIN},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=POSTGRES_INT_MIN),
update={"balance": POSTGRES_INT_MIN},
),
)
edge_balance = await credit_system.get_credits(user_id)
@@ -147,15 +151,13 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
# Set up balance close to underflow threshold to test the protection
# Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000
# This should trigger underflow protection
from prisma.models import UserBalance
test_balance = POSTGRES_INT_MIN + 1000
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": test_balance},
"update": {"balance": test_balance},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=test_balance),
update={"balance": test_balance},
),
)
current_balance = await credit_system.get_credits(user_id)
@@ -212,15 +214,13 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
try:
# Set up balance close to underflow threshold
from prisma.models import UserBalance
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance},
"update": {"balance": initial_balance},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=initial_balance),
update={"balance": initial_balance},
),
)
# Apply multiple refunds that would cumulatively underflow
@@ -290,15 +290,13 @@ async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
try:
# Set up balance close to underflow threshold
from prisma.models import UserBalance
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance},
"update": {"balance": initial_balance},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=initial_balance),
update={"balance": initial_balance},
),
)
async def large_refund(amount: int, label: str):

View File

@@ -14,6 +14,7 @@ import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserCreateInput
from backend.data.credit import UsageTransactionMetadata, UserCredit
from backend.util.json import SafeJson
@@ -24,11 +25,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for migration tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
data=UserCreateInput(
id=user_id,
email=f"test-{user_id}@example.com",
name=f"Test User {user_id[:8]}",
)
)
except UniqueViolationError:
# User already exists, continue
@@ -121,7 +122,7 @@ async def test_detect_stale_user_balance_queries(server: SpinTestServer):
try:
# Create UserBalance with specific value
await UserBalance.prisma().create(
data={"userId": user_id, "balance": 5000} # $50
data=UserBalanceCreateInput(userId=user_id, balance=5000) # $50
)
# Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
@@ -160,7 +161,9 @@ async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer
try:
# Set initial balance in UserBalance
await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})
await UserBalance.prisma().create(
data=UserBalanceCreateInput(userId=user_id, balance=1000)
)
# Run concurrent operations to ensure they all use UserBalance atomic operations
async def concurrent_spend(amount: int, label: str):

View File

@@ -27,6 +27,7 @@ from prisma.models import (
AgentNodeExecutionKeyValueData,
)
from prisma.types import (
AgentGraphExecutionCreateInput,
AgentGraphExecutionUpdateManyMutationInput,
AgentGraphExecutionWhereInput,
AgentNodeExecutionCreateInput,
@@ -34,7 +35,7 @@ from prisma.types import (
AgentNodeExecutionKeyValueDataCreateInput,
AgentNodeExecutionUpdateInput,
AgentNodeExecutionWhereInput,
AgentNodeExecutionWhereUniqueInput,
_AgentNodeExecutionWhereUnique_id_Input,
)
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
from pydantic.fields import Field
@@ -71,6 +72,13 @@ logger = logging.getLogger(__name__)
config = Config()
class GrantResolverContext(BaseModel):
"""Context for grant-based credential resolution in external API executions."""
client_db_id: str # The OAuth client database UUID
grant_ids: list[str] # List of grant IDs to use for credential resolution
class ExecutionContext(BaseModel):
"""
Unified context that carries execution-level data throughout the entire execution flow.
@@ -81,6 +89,8 @@ class ExecutionContext(BaseModel):
user_timezone: str = "UTC"
root_execution_id: Optional[str] = None
parent_execution_id: Optional[str] = None
# For external API executions using credential grants
grant_resolver_context: Optional[GrantResolverContext] = None
# -------------------------- Models -------------------------- #
@@ -705,18 +715,18 @@ async def create_graph_execution(
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
"""
result = await AgentGraphExecution.prisma().create(
data={
"agentGraphId": graph_id,
"agentGraphVersion": graph_version,
"executionStatus": ExecutionStatus.INCOMPLETE,
"inputs": SafeJson(inputs),
"credentialInputs": (
data=AgentGraphExecutionCreateInput(
agentGraphId=graph_id,
agentGraphVersion=graph_version,
executionStatus=ExecutionStatus.INCOMPLETE,
inputs=SafeJson(inputs),
credentialInputs=(
SafeJson(credential_inputs) if credential_inputs else Json({})
),
"nodesInputMasks": (
nodesInputMasks=(
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
),
"NodeExecutions": {
NodeExecutions={
"create": [
AgentNodeExecutionCreateInput(
agentNodeId=node_id,
@@ -732,10 +742,10 @@ async def create_graph_execution(
for node_id, node_input in starting_nodes_input
]
},
"userId": user_id,
"agentPresetId": preset_id,
"parentGraphExecutionId": parent_graph_exec_id,
},
userId=user_id,
agentPresetId=preset_id,
parentGraphExecutionId=parent_graph_exec_id,
),
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
@@ -827,10 +837,10 @@ async def upsert_execution_output(
"""
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
"""
data: AgentNodeExecutionInputOutputCreateInput = {
"name": output_name,
"referencedByOutputExecId": node_exec_id,
}
data = AgentNodeExecutionInputOutputCreateInput(
name=output_name,
referencedByOutputExecId=node_exec_id,
)
if output_data is not None:
data["data"] = SafeJson(output_data)
await AgentNodeExecutionInputOutput.prisma().create(data=data)
@@ -948,7 +958,7 @@ async def update_node_execution_status(
if res := await AgentNodeExecution.prisma().update(
where=cast(
AgentNodeExecutionWhereUniqueInput,
_AgentNodeExecutionWhereUnique_id_Input,
{
"id": node_exec_id,
"executionStatus": {"in": [s.value for s in allowed_from]},

View File

@@ -10,7 +10,11 @@ from typing import Optional
from prisma.enums import ReviewStatus
from prisma.models import PendingHumanReview
from prisma.types import PendingHumanReviewUpdateInput
from prisma.types import (
PendingHumanReviewCreateInput,
PendingHumanReviewUpdateInput,
PendingHumanReviewUpsertInput,
)
from pydantic import BaseModel
from backend.server.v2.executions.review.model import (
@@ -66,20 +70,20 @@ async def get_or_create_human_review(
# Upsert - get existing or create new review
review = await PendingHumanReview.prisma().upsert(
where={"nodeExecId": node_exec_id},
data={
"create": {
"userId": user_id,
"nodeExecId": node_exec_id,
"graphExecId": graph_exec_id,
"graphId": graph_id,
"graphVersion": graph_version,
"payload": SafeJson(input_data),
"instructions": message,
"editable": editable,
"status": ReviewStatus.WAITING,
},
"update": {}, # Do nothing on update - keep existing review as is
},
data=PendingHumanReviewUpsertInput(
create=PendingHumanReviewCreateInput(
userId=user_id,
nodeExecId=node_exec_id,
graphExecId=graph_exec_id,
graphId=graph_id,
graphVersion=graph_version,
payload=SafeJson(input_data),
instructions=message,
editable=editable,
status=ReviewStatus.WAITING,
),
update={}, # Do nothing on update - keep existing review as is
),
)
logger.info(

View File

@@ -0,0 +1,302 @@
"""
Integration scopes mapping.
Maps AutoGPT's fine-grained integration scopes to provider-specific OAuth scopes.
These scopes are used to request granular permissions when connecting integrations
through the Credential Broker.
"""
from enum import Enum
from typing import Optional
from backend.integrations.providers import ProviderName
class IntegrationScope(str, Enum):
"""
Fine-grained integration scopes for credential grants.
Format: {provider}:{resource}.{permission}
"""
# Google scopes
GOOGLE_EMAIL_READ = "google:email.read"
GOOGLE_GMAIL_READONLY = "google:gmail.readonly"
GOOGLE_GMAIL_SEND = "google:gmail.send"
GOOGLE_GMAIL_MODIFY = "google:gmail.modify"
GOOGLE_DRIVE_READONLY = "google:drive.readonly"
GOOGLE_DRIVE_FILE = "google:drive.file"
GOOGLE_CALENDAR_READONLY = "google:calendar.readonly"
GOOGLE_CALENDAR_EVENTS = "google:calendar.events"
GOOGLE_SHEETS_READONLY = "google:sheets.readonly"
GOOGLE_SHEETS = "google:sheets"
GOOGLE_DOCS_READONLY = "google:docs.readonly"
GOOGLE_DOCS = "google:docs"
# GitHub scopes
GITHUB_REPOS_READ = "github:repos.read"
GITHUB_REPOS_WRITE = "github:repos.write"
GITHUB_ISSUES_READ = "github:issues.read"
GITHUB_ISSUES_WRITE = "github:issues.write"
GITHUB_USER_READ = "github:user.read"
GITHUB_GISTS = "github:gists"
GITHUB_NOTIFICATIONS = "github:notifications"
# Discord scopes
DISCORD_IDENTIFY = "discord:identify"
DISCORD_EMAIL = "discord:email"
DISCORD_GUILDS = "discord:guilds"
DISCORD_MESSAGES_READ = "discord:messages.read"
# Twitter scopes
TWITTER_READ = "twitter:read"
TWITTER_WRITE = "twitter:write"
TWITTER_DM = "twitter:dm"
# Notion scopes
NOTION_READ = "notion:read"
NOTION_WRITE = "notion:write"
# Todoist scopes
TODOIST_READ = "todoist:read"
TODOIST_WRITE = "todoist:write"
# Scope descriptions for consent UI
INTEGRATION_SCOPE_DESCRIPTIONS: dict[str, str] = {
# Google
IntegrationScope.GOOGLE_EMAIL_READ.value: "Read your email address",
IntegrationScope.GOOGLE_GMAIL_READONLY.value: "Read your Gmail messages",
IntegrationScope.GOOGLE_GMAIL_SEND.value: "Send emails on your behalf",
IntegrationScope.GOOGLE_GMAIL_MODIFY.value: "Read, send, and manage your emails",
IntegrationScope.GOOGLE_DRIVE_READONLY.value: "View files in your Google Drive",
IntegrationScope.GOOGLE_DRIVE_FILE.value: "Create and edit files in Google Drive",
IntegrationScope.GOOGLE_CALENDAR_READONLY.value: "View your calendar",
IntegrationScope.GOOGLE_CALENDAR_EVENTS.value: "Create and edit calendar events",
IntegrationScope.GOOGLE_SHEETS_READONLY.value: "View your spreadsheets",
IntegrationScope.GOOGLE_SHEETS.value: "Create and edit spreadsheets",
IntegrationScope.GOOGLE_DOCS_READONLY.value: "View your documents",
IntegrationScope.GOOGLE_DOCS.value: "Create and edit documents",
# GitHub
IntegrationScope.GITHUB_REPOS_READ.value: "Read repository information",
IntegrationScope.GITHUB_REPOS_WRITE.value: "Create and manage repositories",
IntegrationScope.GITHUB_ISSUES_READ.value: "Read issues and pull requests",
IntegrationScope.GITHUB_ISSUES_WRITE.value: "Create and manage issues",
IntegrationScope.GITHUB_USER_READ.value: "Read your GitHub profile",
IntegrationScope.GITHUB_GISTS.value: "Create and manage gists",
IntegrationScope.GITHUB_NOTIFICATIONS.value: "Access notifications",
# Discord
IntegrationScope.DISCORD_IDENTIFY.value: "Access your Discord username",
IntegrationScope.DISCORD_EMAIL.value: "Access your Discord email",
IntegrationScope.DISCORD_GUILDS.value: "View your server list",
IntegrationScope.DISCORD_MESSAGES_READ.value: "Read messages",
# Twitter
IntegrationScope.TWITTER_READ.value: "Read tweets and profile",
IntegrationScope.TWITTER_WRITE.value: "Post tweets on your behalf",
IntegrationScope.TWITTER_DM.value: "Send and read direct messages",
# Notion
IntegrationScope.NOTION_READ.value: "View Notion pages",
IntegrationScope.NOTION_WRITE.value: "Create and edit Notion pages",
# Todoist
IntegrationScope.TODOIST_READ.value: "View your tasks",
IntegrationScope.TODOIST_WRITE.value: "Create and manage tasks",
}
# Mapping from integration scopes to provider OAuth scopes
INTEGRATION_SCOPE_MAPPING: dict[str, dict[str, list[str]]] = {
ProviderName.GOOGLE.value: {
IntegrationScope.GOOGLE_EMAIL_READ.value: [
"https://www.googleapis.com/auth/userinfo.email",
"openid",
],
IntegrationScope.GOOGLE_GMAIL_READONLY.value: [
"https://www.googleapis.com/auth/gmail.readonly",
],
IntegrationScope.GOOGLE_GMAIL_SEND.value: [
"https://www.googleapis.com/auth/gmail.send",
],
IntegrationScope.GOOGLE_GMAIL_MODIFY.value: [
"https://www.googleapis.com/auth/gmail.modify",
],
IntegrationScope.GOOGLE_DRIVE_READONLY.value: [
"https://www.googleapis.com/auth/drive.readonly",
],
IntegrationScope.GOOGLE_DRIVE_FILE.value: [
"https://www.googleapis.com/auth/drive.file",
],
IntegrationScope.GOOGLE_CALENDAR_READONLY.value: [
"https://www.googleapis.com/auth/calendar.readonly",
],
IntegrationScope.GOOGLE_CALENDAR_EVENTS.value: [
"https://www.googleapis.com/auth/calendar.events",
],
IntegrationScope.GOOGLE_SHEETS_READONLY.value: [
"https://www.googleapis.com/auth/spreadsheets.readonly",
],
IntegrationScope.GOOGLE_SHEETS.value: [
"https://www.googleapis.com/auth/spreadsheets",
],
IntegrationScope.GOOGLE_DOCS_READONLY.value: [
"https://www.googleapis.com/auth/documents.readonly",
],
IntegrationScope.GOOGLE_DOCS.value: [
"https://www.googleapis.com/auth/documents",
],
},
ProviderName.GITHUB.value: {
IntegrationScope.GITHUB_REPOS_READ.value: [
"repo:status",
"public_repo",
],
IntegrationScope.GITHUB_REPOS_WRITE.value: [
"repo",
],
IntegrationScope.GITHUB_ISSUES_READ.value: [
"repo:status",
],
IntegrationScope.GITHUB_ISSUES_WRITE.value: [
"repo",
],
IntegrationScope.GITHUB_USER_READ.value: [
"read:user",
"user:email",
],
IntegrationScope.GITHUB_GISTS.value: [
"gist",
],
IntegrationScope.GITHUB_NOTIFICATIONS.value: [
"notifications",
],
},
ProviderName.DISCORD.value: {
IntegrationScope.DISCORD_IDENTIFY.value: [
"identify",
],
IntegrationScope.DISCORD_EMAIL.value: [
"email",
],
IntegrationScope.DISCORD_GUILDS.value: [
"guilds",
],
IntegrationScope.DISCORD_MESSAGES_READ.value: [
"messages.read",
],
},
ProviderName.TWITTER.value: {
IntegrationScope.TWITTER_READ.value: [
"tweet.read",
"users.read",
],
IntegrationScope.TWITTER_WRITE.value: [
"tweet.write",
],
IntegrationScope.TWITTER_DM.value: [
"dm.read",
"dm.write",
],
},
ProviderName.NOTION.value: {
IntegrationScope.NOTION_READ.value: [], # Notion uses workspace-level access
IntegrationScope.NOTION_WRITE.value: [],
},
ProviderName.TODOIST.value: {
IntegrationScope.TODOIST_READ.value: [
"data:read",
],
IntegrationScope.TODOIST_WRITE.value: [
"data:read_write",
],
},
}
def get_provider_scopes(
provider: ProviderName | str, integration_scopes: list[str]
) -> list[str]:
"""
Convert integration scopes to provider-specific OAuth scopes.
Args:
provider: The provider name
integration_scopes: List of integration scope strings
Returns:
List of provider-specific OAuth scopes
"""
provider_value = provider.value if isinstance(provider, ProviderName) else provider
provider_mapping = INTEGRATION_SCOPE_MAPPING.get(provider_value, {})
oauth_scopes: set[str] = set()
for scope in integration_scopes:
if scope in provider_mapping:
oauth_scopes.update(provider_mapping[scope])
return list(oauth_scopes)
def get_provider_for_scope(scope: str) -> Optional[ProviderName]:
"""
Get the provider for an integration scope.
Args:
scope: Integration scope string (e.g., "google:gmail.readonly")
Returns:
ProviderName or None if not recognized
"""
if ":" not in scope:
return None
provider_prefix = scope.split(":")[0]
# Map prefixes to providers
prefix_mapping = {
"google": ProviderName.GOOGLE,
"github": ProviderName.GITHUB,
"discord": ProviderName.DISCORD,
"twitter": ProviderName.TWITTER,
"notion": ProviderName.NOTION,
"todoist": ProviderName.TODOIST,
}
return prefix_mapping.get(provider_prefix)
def validate_integration_scopes(scopes: list[str]) -> tuple[bool, list[str]]:
"""
Validate a list of integration scopes.
Args:
scopes: List of integration scope strings
Returns:
Tuple of (valid, invalid_scopes)
"""
valid_scopes = {s.value for s in IntegrationScope}
invalid = [s for s in scopes if s not in valid_scopes]
return len(invalid) == 0, invalid
def group_scopes_by_provider(
scopes: list[str],
) -> dict[ProviderName, list[str]]:
"""
Group integration scopes by their provider.
Args:
scopes: List of integration scope strings
Returns:
Dictionary mapping providers to their scopes
"""
grouped: dict[ProviderName, list[str]] = {}
for scope in scopes:
provider = get_provider_for_scope(scope)
if provider:
if provider not in grouped:
grouped[provider] = []
grouped[provider].append(scope)
return grouped

View File

@@ -0,0 +1,176 @@
"""
OAuth Audit Logging.
Logs all OAuth-related operations for security auditing and compliance.
"""
import logging
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Optional
from backend.data.db import prisma
logger = logging.getLogger(__name__)
class OAuthEventType(str, Enum):
"""Types of OAuth events to audit."""
# Client events
CLIENT_REGISTERED = "client.registered"
CLIENT_UPDATED = "client.updated"
CLIENT_DELETED = "client.deleted"
CLIENT_SECRET_ROTATED = "client.secret_rotated"
CLIENT_SUSPENDED = "client.suspended"
CLIENT_ACTIVATED = "client.activated"
# Authorization events
AUTHORIZATION_REQUESTED = "authorization.requested"
AUTHORIZATION_GRANTED = "authorization.granted"
AUTHORIZATION_DENIED = "authorization.denied"
AUTHORIZATION_REVOKED = "authorization.revoked"
# Token events
TOKEN_ISSUED = "token.issued"
TOKEN_REFRESHED = "token.refreshed"
TOKEN_REVOKED = "token.revoked"
TOKEN_EXPIRED = "token.expired"
# Grant events
GRANT_CREATED = "grant.created"
GRANT_UPDATED = "grant.updated"
GRANT_REVOKED = "grant.revoked"
GRANT_USED = "grant.used"
# Credential events
CREDENTIAL_CONNECTED = "credential.connected"
CREDENTIAL_DELETED = "credential.deleted"
# Execution events
EXECUTION_STARTED = "execution.started"
EXECUTION_COMPLETED = "execution.completed"
EXECUTION_FAILED = "execution.failed"
EXECUTION_CANCELLED = "execution.cancelled"
async def log_oauth_event(
event_type: OAuthEventType,
user_id: Optional[str] = None,
client_id: Optional[str] = None,
grant_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
details: Optional[dict[str, Any]] = None,
) -> str:
"""
Log an OAuth audit event.
Args:
event_type: Type of event
user_id: User ID involved (if any)
client_id: OAuth client ID involved (if any)
grant_id: Grant ID involved (if any)
ip_address: Client IP address
user_agent: Client user agent
details: Additional event details
Returns:
ID of the created audit log entry
"""
try:
from prisma import Json
audit_entry = await prisma.oauthauditlog.create(
data={ # type: ignore[typeddict-item]
"eventType": event_type.value,
"userId": user_id,
"clientId": client_id,
"grantId": grant_id,
"ipAddress": ip_address,
"userAgent": user_agent,
"details": Json(details or {}),
}
)
logger.debug(
f"OAuth audit: {event_type.value} - "
f"user={user_id}, client={client_id}, grant={grant_id}"
)
return audit_entry.id
except Exception as e:
# Log but don't fail the operation if audit logging fails
logger.error(f"Failed to create OAuth audit log: {e}")
return ""
async def get_audit_logs(
user_id: Optional[str] = None,
client_id: Optional[str] = None,
event_type: Optional[OAuthEventType] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: int = 100,
offset: int = 0,
) -> list:
"""
Query OAuth audit logs.
Args:
user_id: Filter by user ID
client_id: Filter by client ID
event_type: Filter by event type
start_date: Filter by start date
end_date: Filter by end date
limit: Maximum number of results
offset: Offset for pagination
Returns:
List of audit log entries
"""
where: dict[str, Any] = {}
if user_id:
where["userId"] = user_id
if client_id:
where["clientId"] = client_id
if event_type:
where["eventType"] = event_type.value
if start_date:
where["createdAt"] = {"gte": start_date}
if end_date:
if "createdAt" in where:
where["createdAt"]["lte"] = end_date
else:
where["createdAt"] = {"lte": end_date}
return await prisma.oauthauditlog.find_many(
where=where if where else None, # type: ignore[arg-type]
order={"createdAt": "desc"},
take=limit,
skip=offset,
)
async def cleanup_old_audit_logs(days_to_keep: int = 90) -> int:
"""
Delete audit logs older than the specified number of days.
Args:
days_to_keep: Number of days of logs to retain
Returns:
Number of logs deleted
"""
from datetime import timedelta
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
result = await prisma.oauthauditlog.delete_many(
where={"createdAt": {"lt": cutoff_date}}
)
logger.info(f"Cleaned up {result} OAuth audit logs older than {days_to_keep} days")
return result

View File

@@ -7,7 +7,11 @@ import prisma
import pydantic
from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
from prisma.types import (
UserOnboardingCreateInput,
UserOnboardingUpdateInput,
UserOnboardingUpsertInput,
)
from backend.data import execution as execution_db
from backend.data.credit import get_user_credit_model
@@ -112,10 +116,10 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, **update},
"update": update,
},
data=UserOnboardingUpsertInput(
create=UserOnboardingCreateInput(userId=user_id, **update),
update=update,
),
)

View File

@@ -67,6 +67,7 @@ from backend.executor.utils import (
validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhook_notifier import get_webhook_notifier
from backend.notifications.notifications import queue_notification
from backend.server.v2.AutoMod.manager import automod_manager
from backend.util import json
@@ -221,11 +222,31 @@ async def execute_node(
creds_locks: list[AsyncRedisLock] = []
input_model = cast(type[BlockSchema], node_block.input_schema)
# Check if this is an external API execution using grant-based credential resolution
grant_resolver = None
if execution_context and execution_context.grant_resolver_context:
from backend.integrations.grant_resolver import GrantBasedCredentialResolver
grant_ctx = execution_context.grant_resolver_context
grant_resolver = GrantBasedCredentialResolver(
user_id=user_id,
client_id=grant_ctx.client_db_id,
grant_ids=grant_ctx.grant_ids,
)
await grant_resolver.initialize()
# Handle regular credentials fields
for field_name, input_type in input_model.get_credentials_fields().items():
credentials_meta = input_type(**input_data[field_name])
credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id)
creds_locks.append(lock)
if grant_resolver:
# External API execution - use grant resolver (no locking needed)
credentials = await grant_resolver.resolve_credential(credentials_meta.id)
else:
# Normal execution - use credentials manager with locking
credentials, lock = await creds_manager.acquire(
user_id, credentials_meta.id
)
creds_locks.append(lock)
extra_exec_kwargs[field_name] = credentials
# Handle auto-generated credentials (e.g., from GoogleDriveFileInput)
@@ -243,10 +264,17 @@ async def execute_node(
)
file_name = field_data.get("name", "selected file")
try:
credentials, lock = await creds_manager.acquire(
user_id, cred_id
)
creds_locks.append(lock)
if grant_resolver:
# External API execution - use grant resolver
credentials = await grant_resolver.resolve_credential(
cred_id
)
else:
# Normal execution - use credentials manager
credentials, lock = await creds_manager.acquire(
user_id, cred_id
)
creds_locks.append(lock)
extra_exec_kwargs[kwarg_name] = credentials
except ValueError:
# Credential was deleted or doesn't exist
@@ -785,6 +813,7 @@ class ExecutionProcessor:
graph_exec_id=graph_exec.graph_exec_id,
status=exec_meta.status,
stats=exec_stats,
event_loop=self.node_execution_loop,
)
def _charge_usage(
@@ -1916,6 +1945,53 @@ def update_node_execution_status(
return exec_update
async def _notify_execution_webhook(
execution_id: str,
agent_id: str,
status: ExecutionStatus,
outputs: dict[str, Any] | None = None,
error: str | None = None,
) -> None:
"""
Send webhook notification for execution completion if registered.
This is a fire-and-forget operation that checks if a webhook was registered
for this execution and sends the appropriate notification.
"""
from backend.data.db import prisma
try:
webhook = await prisma.executionwebhook.find_first(
where={"executionId": execution_id}
)
if not webhook:
return
notifier = get_webhook_notifier()
if status == ExecutionStatus.COMPLETED:
await notifier.notify_execution_completed(
execution_id=execution_id,
agent_id=agent_id,
client_id=webhook.clientId,
webhook_url=webhook.webhookUrl,
outputs=outputs or {},
webhook_secret=webhook.secret,
)
elif status == ExecutionStatus.FAILED:
await notifier.notify_execution_failed(
execution_id=execution_id,
agent_id=agent_id,
client_id=webhook.clientId,
webhook_url=webhook.webhookUrl,
error=error or "Execution failed",
webhook_secret=webhook.secret,
)
except Exception as e:
# Don't let webhook failures affect execution state updates
logger.warning(f"Failed to send webhook notification for {execution_id}: {e}")
async def async_update_graph_execution_state(
db_client: "DatabaseManagerAsyncClient",
graph_exec_id: str,
@@ -1928,6 +2004,17 @@ async def async_update_graph_execution_state(
)
if graph_update:
await send_async_execution_update(graph_update)
# Send webhook notification for terminal states
if status == ExecutionStatus.COMPLETED or status == ExecutionStatus.FAILED:
await _notify_execution_webhook(
execution_id=graph_exec_id,
agent_id=graph_update.graph_id,
status=status,
outputs=(
graph_update.outputs if hasattr(graph_update, "outputs") else None
),
)
else:
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
return graph_update
@@ -1938,11 +2025,33 @@ def update_graph_execution_state(
graph_exec_id: str,
status: ExecutionStatus | None = None,
stats: GraphExecutionStats | None = None,
event_loop: asyncio.AbstractEventLoop | None = None,
) -> GraphExecution | None:
"""Sets status and fetches+broadcasts the latest state of the graph execution"""
graph_update = db_client.update_graph_execution_stats(graph_exec_id, status, stats)
if graph_update:
send_execution_update(graph_update)
# Send webhook notification for terminal states (fire-and-forget)
if (
status == ExecutionStatus.COMPLETED or status == ExecutionStatus.FAILED
) and event_loop:
try:
asyncio.run_coroutine_threadsafe(
_notify_execution_webhook(
execution_id=graph_exec_id,
agent_id=graph_update.graph_id,
status=status,
outputs=(
graph_update.outputs
if hasattr(graph_update, "outputs")
else None
),
),
event_loop,
)
except Exception as e:
logger.warning(f"Failed to schedule webhook notification: {e}")
else:
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
return graph_update

View File

@@ -0,0 +1,278 @@
"""
Grant-Based Credential Resolver.
Resolves credentials during agent execution based on credential grants.
External applications can only use credentials they have been granted access to,
and only for the scopes that were granted.
Credentials are NEVER exposed to external applications - this resolver
provides the credentials to the execution engine internally.
"""
import logging
from datetime import datetime, timezone
from typing import Optional
from prisma.enums import CredentialGrantPermission
from prisma.models import CredentialGrant
from backend.data import credential_grants as grants_db
from backend.data.db import prisma
from backend.data.model import Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
logger = logging.getLogger(__name__)
class GrantValidationError(Exception):
"""Raised when a grant is invalid or lacks required permissions."""
pass
class CredentialNotFoundError(Exception):
"""Raised when a credential referenced by a grant is not found."""
pass
class ScopeMismatchError(Exception):
"""Raised when the grant doesn't cover required scopes."""
pass
class GrantBasedCredentialResolver:
"""
Resolves credentials for agent execution based on credential grants.
This resolver validates that:
1. The grant exists and is valid (not revoked/expired)
2. The grant has USE permission
3. The grant covers the required scopes (if specified)
4. The underlying credential exists
Then it provides the credential to the execution engine internally.
The credential value is NEVER exposed to external applications.
"""
def __init__(
self,
user_id: str,
client_id: str,
grant_ids: list[str],
):
"""
Initialize the resolver.
Args:
user_id: User ID who owns the credentials
client_id: Database ID of the OAuth client
grant_ids: List of grant IDs the client is using for this execution
"""
self.user_id = user_id
self.client_id = client_id
self.grant_ids = grant_ids
self._grants: dict[str, CredentialGrant] = {}
self._credentials_manager = IntegrationCredentialsManager()
self._initialized = False
async def initialize(self) -> None:
"""
Load and validate all grants.
This should be called before any credential resolution.
Raises:
GrantValidationError: If any grant is invalid
"""
now = datetime.now(timezone.utc)
for grant_id in self.grant_ids:
grant = await grants_db.get_credential_grant(
grant_id=grant_id,
user_id=self.user_id,
client_id=self.client_id,
)
if not grant:
raise GrantValidationError(f"Grant {grant_id} not found")
# Check if revoked
if grant.revokedAt:
raise GrantValidationError(f"Grant {grant_id} has been revoked")
# Check if expired
if grant.expiresAt and grant.expiresAt < now:
raise GrantValidationError(f"Grant {grant_id} has expired")
# Check USE permission
if CredentialGrantPermission.USE not in grant.permissions:
raise GrantValidationError(
f"Grant {grant_id} does not have USE permission"
)
self._grants[grant_id] = grant
self._initialized = True
logger.info(
f"Initialized grant resolver with {len(self._grants)} grants "
f"for user {self.user_id}, client {self.client_id}"
)
async def resolve_credential(
self,
credential_id: str,
required_scopes: Optional[list[str]] = None,
) -> Credentials:
"""
Resolve a credential for agent execution.
This method:
1. Finds a grant that covers this credential
2. Validates the grant covers required scopes
3. Retrieves the actual credential
4. Updates grant usage tracking
Args:
credential_id: ID of the credential to resolve
required_scopes: Optional list of scopes the credential must have
Returns:
The resolved Credentials object
Raises:
GrantValidationError: If no valid grant covers this credential
ScopeMismatchError: If the grant doesn't cover required scopes
CredentialNotFoundError: If the underlying credential doesn't exist
"""
if not self._initialized:
raise RuntimeError("Resolver not initialized. Call initialize() first.")
# Find a grant that covers this credential
matching_grant: Optional[CredentialGrant] = None
for grant in self._grants.values():
if grant.credentialId == credential_id:
matching_grant = grant
break
if not matching_grant:
raise GrantValidationError(f"No grant found for credential {credential_id}")
# Validate scopes if required
if required_scopes:
granted_scopes = set(matching_grant.grantedScopes)
required_scopes_set = set(required_scopes)
missing_scopes = required_scopes_set - granted_scopes
if missing_scopes:
raise ScopeMismatchError(
f"Grant {matching_grant.id} is missing required scopes: "
f"{', '.join(missing_scopes)}"
)
# Get the actual credential
credentials = await self._credentials_manager.get(
user_id=self.user_id,
credentials_id=credential_id,
lock=True,
)
if not credentials:
raise CredentialNotFoundError(
f"Credential {credential_id} not found for user {self.user_id}"
)
# Update last used timestamp for the grant
await grants_db.update_grant_last_used(matching_grant.id)
logger.debug(
f"Resolved credential {credential_id} via grant {matching_grant.id} "
f"for client {self.client_id}"
)
return credentials
async def get_available_credentials(self) -> list[dict]:
"""
Get list of available credentials based on grants.
Returns a list of credential metadata (NOT the actual credential values).
Returns:
List of dicts with credential metadata
"""
if not self._initialized:
raise RuntimeError("Resolver not initialized. Call initialize() first.")
credentials_info = []
for grant in self._grants.values():
credentials_info.append(
{
"grant_id": grant.id,
"credential_id": grant.credentialId,
"provider": grant.provider,
"granted_scopes": grant.grantedScopes,
}
)
return credentials_info
def get_grant_for_credential(self, credential_id: str) -> Optional[CredentialGrant]:
"""
Get the grant for a specific credential.
Args:
credential_id: ID of the credential
Returns:
CredentialGrant or None if not found
"""
for grant in self._grants.values():
if grant.credentialId == credential_id:
return grant
return None
async def create_resolver_from_oauth_token(
user_id: str,
client_public_id: str,
grant_ids: Optional[list[str]] = None,
) -> GrantBasedCredentialResolver:
"""
Create a credential resolver from OAuth token context.
This is a convenience function for creating a resolver from
the context available in OAuth-authenticated requests.
Args:
user_id: User ID from the OAuth token
client_public_id: Public client ID from the OAuth token
grant_ids: Optional list of grant IDs to use
Returns:
Initialized GrantBasedCredentialResolver
"""
# Look up the OAuth client database ID from the public client ID
client = await prisma.oauthclient.find_unique(where={"clientId": client_public_id})
if not client:
raise GrantValidationError(f"OAuth client {client_public_id} not found")
# If no grant IDs specified, get all grants for this client+user
if grant_ids is None:
grants = await grants_db.get_grants_for_user_client(
user_id=user_id,
client_id=client.id,
include_revoked=False,
include_expired=False,
)
grant_ids = [g.id for g in grants]
resolver = GrantBasedCredentialResolver(
user_id=user_id,
client_id=client.id,
grant_ids=grant_ids,
)
await resolver.initialize()
return resolver

View File

@@ -0,0 +1,331 @@
"""
Webhook Notification System for External API.
Sends webhook notifications to external applications for execution events.
"""
import asyncio
import hashlib
import hmac
import json
import logging
import weakref
from datetime import datetime, timezone
from typing import Any, Coroutine, Optional
from urllib.parse import urlparse
import httpx
logger = logging.getLogger(__name__)
# Webhook delivery settings
WEBHOOK_TIMEOUT_SECONDS = 30
WEBHOOK_MAX_RETRIES = 3
WEBHOOK_RETRY_DELAYS = [5, 30, 300] # seconds: 5s, 30s, 5min
class WebhookDeliveryError(Exception):
"""Raised when webhook delivery fails."""
pass
def sign_webhook_payload(payload: dict[str, Any], secret: str) -> str:
"""
Create HMAC-SHA256 signature for webhook payload.
Args:
payload: The webhook payload to sign
secret: The webhook secret key
Returns:
Hex-encoded HMAC-SHA256 signature
"""
payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode()
signature = hmac.new(
secret.encode(),
payload_bytes,
hashlib.sha256,
).hexdigest()
return signature
def verify_webhook_signature(
payload: dict[str, Any],
signature: str,
secret: str,
) -> bool:
"""
Verify a webhook signature.
Args:
payload: The webhook payload
signature: The signature to verify
secret: The webhook secret key
Returns:
True if signature is valid
"""
expected = sign_webhook_payload(payload, secret)
return hmac.compare_digest(expected, signature)
def validate_webhook_url(url: str, allowed_domains: list[str]) -> bool:
"""
Validate that a webhook URL is allowed.
Args:
url: The webhook URL to validate
allowed_domains: List of allowed domains (from OAuth client config)
Returns:
True if URL is valid and allowed
"""
from backend.util.url import hostname_matches_any_domain
try:
parsed = urlparse(url)
# Must be HTTPS (except for localhost in development)
if parsed.scheme != "https":
if not (
parsed.scheme == "http"
and parsed.hostname in ["localhost", "127.0.0.1"]
):
return False
# Must have a host
if not parsed.hostname:
return False
# Check against allowed domains
return hostname_matches_any_domain(parsed.hostname, allowed_domains)
except Exception:
return False
async def send_webhook(
url: str,
payload: dict[str, Any],
secret: Optional[str] = None,
timeout: int = WEBHOOK_TIMEOUT_SECONDS,
) -> bool:
"""
Send a webhook notification.
Args:
url: Webhook URL
payload: Payload to send
secret: Optional secret for signature
timeout: Request timeout in seconds
Returns:
True if webhook was delivered successfully
"""
headers = {
"Content-Type": "application/json",
"User-Agent": "AutoGPT-Webhook/1.0",
"X-Webhook-Timestamp": datetime.now(timezone.utc).isoformat(),
}
if secret:
signature = sign_webhook_payload(payload, secret)
headers["X-Webhook-Signature"] = f"sha256={signature}"
try:
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers=headers,
)
if response.status_code >= 200 and response.status_code < 300:
logger.debug(f"Webhook delivered successfully to {url}")
return True
else:
logger.warning(
f"Webhook delivery failed: {url} returned {response.status_code}"
)
return False
except httpx.TimeoutException:
logger.warning(f"Webhook delivery timed out: {url}")
return False
except Exception as e:
logger.error(f"Webhook delivery error: {url} - {str(e)}")
return False
async def send_webhook_with_retry(
url: str,
payload: dict[str, Any],
secret: Optional[str] = None,
max_retries: int = WEBHOOK_MAX_RETRIES,
) -> bool:
"""
Send a webhook with automatic retries.
Args:
url: Webhook URL
payload: Payload to send
secret: Optional secret for signature
max_retries: Maximum number of retry attempts
Returns:
True if webhook was eventually delivered successfully
"""
for attempt in range(max_retries + 1):
if await send_webhook(url, payload, secret):
return True
if attempt < max_retries:
delay = WEBHOOK_RETRY_DELAYS[min(attempt, len(WEBHOOK_RETRY_DELAYS) - 1)]
logger.info(
f"Webhook delivery failed, retrying in {delay}s (attempt {attempt + 1})"
)
await asyncio.sleep(delay)
logger.error(f"Webhook delivery failed after {max_retries} retries: {url}")
return False
# Track pending webhook tasks to prevent garbage collection
# Using WeakSet so tasks are automatically removed when they complete and are dereferenced
_pending_webhook_tasks: weakref.WeakSet[asyncio.Task[Any]] = weakref.WeakSet()
def _create_tracked_task(coro: Coroutine[Any, Any, bool]) -> asyncio.Task[bool]:
"""Create a task that is tracked to prevent garbage collection."""
task = asyncio.create_task(coro)
_pending_webhook_tasks.add(task)
# No explicit done callback needed - WeakSet automatically removes
# references when tasks are garbage collected after completion
return task
class WebhookNotifier:
"""
Service for sending webhook notifications to external applications.
"""
def __init__(self):
pass
async def notify_execution_started(
self,
execution_id: str,
agent_id: str,
client_id: str,
webhook_url: str,
webhook_secret: Optional[str] = None,
) -> None:
"""
Notify external app that an execution has started.
"""
payload = {
"event": "execution.started",
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": {
"execution_id": execution_id,
"agent_id": agent_id,
"status": "running",
},
}
_create_tracked_task(
send_webhook_with_retry(webhook_url, payload, webhook_secret)
)
async def notify_execution_completed(
self,
execution_id: str,
agent_id: str,
client_id: str,
webhook_url: str,
outputs: dict[str, Any],
webhook_secret: Optional[str] = None,
) -> None:
"""
Notify external app that an execution has completed successfully.
"""
payload = {
"event": "execution.completed",
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": {
"execution_id": execution_id,
"agent_id": agent_id,
"status": "completed",
"outputs": outputs,
},
}
_create_tracked_task(
send_webhook_with_retry(webhook_url, payload, webhook_secret)
)
async def notify_execution_failed(
self,
execution_id: str,
agent_id: str,
client_id: str,
webhook_url: str,
error: str,
webhook_secret: Optional[str] = None,
) -> None:
"""
Notify external app that an execution has failed.
"""
payload = {
"event": "execution.failed",
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": {
"execution_id": execution_id,
"agent_id": agent_id,
"status": "failed",
"error": error,
},
}
_create_tracked_task(
send_webhook_with_retry(webhook_url, payload, webhook_secret)
)
async def notify_grant_revoked(
self,
grant_id: str,
credential_id: str,
provider: str,
client_id: str,
webhook_url: str,
webhook_secret: Optional[str] = None,
) -> None:
"""
Notify external app that a credential grant has been revoked.
"""
payload = {
"event": "grant.revoked",
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": {
"grant_id": grant_id,
"credential_id": credential_id,
"provider": provider,
},
}
_create_tracked_task(
send_webhook_with_retry(webhook_url, payload, webhook_secret)
)
# Module-level singleton
_webhook_notifier: Optional[WebhookNotifier] = None
def get_webhook_notifier() -> WebhookNotifier:
"""Get the singleton webhook notifier instance."""
global _webhook_notifier
if _webhook_notifier is None:
_webhook_notifier = WebhookNotifier()
return _webhook_notifier

View File

@@ -3,21 +3,19 @@ from fastapi import FastAPI
from backend.monitoring.instrumentation import instrument_fastapi
from backend.server.middleware.security import SecurityHeadersMiddleware
from .routes.integrations import integrations_router
from .routes.tools import tools_router
from .routes.v1 import v1_router
from .routes.execution import execution_router
from .routes.grants import grants_router
external_app = FastAPI(
title="AutoGPT External API",
description="External API for AutoGPT integrations",
description="External API for AutoGPT integrations (OAuth-based)",
docs_url="/docs",
version="1.0",
)
external_app.add_middleware(SecurityHeadersMiddleware)
external_app.include_router(v1_router, prefix="/v1")
external_app.include_router(tools_router, prefix="/v1")
external_app.include_router(integrations_router, prefix="/v1")
external_app.include_router(grants_router, prefix="/v1")
external_app.include_router(execution_router, prefix="/v1")
# Add Prometheus instrumentation
instrument_fastapi(

View File

@@ -1,36 +0,0 @@
from fastapi import HTTPException, Security
from fastapi.security import APIKeyHeader
from prisma.enums import APIKeyPermission
from backend.data.api_key import APIKeyInfo, has_permission, validate_api_key
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
"""Base middleware for API key authentication"""
if api_key is None:
raise HTTPException(status_code=401, detail="Missing API key")
api_key_obj = await validate_api_key(api_key)
if not api_key_obj:
raise HTTPException(status_code=401, detail="Invalid API key")
return api_key_obj
def require_permission(permission: APIKeyPermission):
"""Dependency function for checking specific permissions"""
async def check_permission(
api_key: APIKeyInfo = Security(require_api_key),
) -> APIKeyInfo:
if not has_permission(api_key, permission):
raise HTTPException(
status_code=403,
detail=f"API key lacks the required permission '{permission}'",
)
return api_key
return check_permission

View File

@@ -0,0 +1,164 @@
"""
OAuth Access Token middleware for external API.
Validates OAuth access tokens and provides user/client context
for external API endpoints that use OAuth authentication.
"""
from datetime import datetime, timezone
from typing import Optional
import jwt
from fastapi import HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
from backend.data.db import prisma
from backend.server.oauth.token_service import get_token_service
class OAuthTokenInfo(BaseModel):
"""Information extracted from a validated OAuth access token."""
user_id: str
client_id: str
scopes: list[str]
token_id: str
# HTTP Bearer token extractor
oauth_bearer = HTTPBearer(auto_error=False)
async def require_oauth_token(
credentials: Optional[HTTPAuthorizationCredentials] = Security(oauth_bearer),
) -> OAuthTokenInfo:
"""
Validate an OAuth access token and return token info.
Extracts the Bearer token from the Authorization header,
validates the JWT signature and claims, and checks that
the token hasn't been revoked.
Raises:
HTTPException: 401 if token is missing, invalid, or revoked
"""
if credentials is None:
raise HTTPException(
status_code=401,
detail="Missing authorization token",
headers={"WWW-Authenticate": "Bearer"},
)
token = credentials.credentials
token_service = get_token_service()
try:
# Verify JWT signature and claims
claims = token_service.verify_access_token(token)
# Check if token is in database and not revoked
token_hash = token_service.hash_token(token)
stored_token = await prisma.oauthaccesstoken.find_unique(
where={"tokenHash": token_hash}
)
if not stored_token:
raise HTTPException(
status_code=401,
detail="Token not found",
headers={"WWW-Authenticate": "Bearer"},
)
if stored_token.revokedAt:
raise HTTPException(
status_code=401,
detail="Token has been revoked",
headers={"WWW-Authenticate": "Bearer"},
)
if stored_token.expiresAt < datetime.now(timezone.utc):
raise HTTPException(
status_code=401,
detail="Token has expired",
headers={"WWW-Authenticate": "Bearer"},
)
# Update last used timestamp (fire and forget)
await prisma.oauthaccesstoken.update(
where={"id": stored_token.id},
data={"lastUsedAt": datetime.now(timezone.utc)},
)
return OAuthTokenInfo(
user_id=claims.sub,
client_id=claims.client_id,
scopes=claims.scope.split() if claims.scope else [],
token_id=stored_token.id,
)
except jwt.ExpiredSignatureError:
raise HTTPException(
status_code=401,
detail="Token has expired",
headers={"WWW-Authenticate": "Bearer"},
)
except jwt.InvalidTokenError as e:
raise HTTPException(
status_code=401,
detail=f"Invalid token: {str(e)}",
headers={"WWW-Authenticate": "Bearer"},
)
def require_scope(required_scope: str):
"""
Dependency that validates OAuth token and checks for required scope.
Args:
required_scope: The scope required for this endpoint
Returns:
Dependency function that returns OAuthTokenInfo if authorized
"""
async def check_scope(
token: OAuthTokenInfo = Security(require_oauth_token),
) -> OAuthTokenInfo:
if required_scope not in token.scopes:
raise HTTPException(
status_code=403,
detail=f"Token lacks required scope '{required_scope}'",
headers={"WWW-Authenticate": f'Bearer scope="{required_scope}"'},
)
return token
return check_scope
def require_any_scope(*required_scopes: str):
"""
Dependency that validates OAuth token and checks for any of the required scopes.
Args:
required_scopes: At least one of these scopes is required
Returns:
Dependency function that returns OAuthTokenInfo if authorized
"""
async def check_scopes(
token: OAuthTokenInfo = Security(require_oauth_token),
) -> OAuthTokenInfo:
for scope in required_scopes:
if scope in token.scopes:
return token
scope_list = " ".join(required_scopes)
raise HTTPException(
status_code=403,
detail=f"Token lacks required scopes (need one of: {scope_list})",
headers={"WWW-Authenticate": f'Bearer scope="{scope_list}"'},
)
return check_scopes

View File

@@ -0,0 +1,377 @@
"""
Agent Execution endpoints for external OAuth clients.
Allows external applications to:
- Execute agents using granted credentials
- Poll execution status
- Cancel running executions
- Get available capabilities
External apps can only use credentials they have been granted access to.
"""
import logging
from datetime import datetime
from typing import Any, Optional
from fastapi import APIRouter, HTTPException, Security
from prisma.enums import AgentExecutionStatus
from pydantic import BaseModel, Field
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.db import prisma
from backend.data.execution import ExecutionContext, GrantResolverContext
from backend.executor.utils import add_graph_execution
from backend.integrations.grant_resolver import (
GrantValidationError,
create_resolver_from_oauth_token,
)
from backend.integrations.webhook_notifier import validate_webhook_url
from backend.server.external.oauth_middleware import OAuthTokenInfo, require_scope
logger = logging.getLogger(__name__)
execution_router = APIRouter(prefix="/executions", tags=["executions"])
# ================================================================
# Request/Response Models
# ================================================================
class ExecuteAgentRequest(BaseModel):
"""Request to execute an agent."""
inputs: dict[str, Any] = Field(
default_factory=dict,
description="Input values for the agent",
)
grant_ids: Optional[list[str]] = Field(
default=None,
description="Specific grant IDs to use. If not provided, uses all available grants.",
)
webhook_url: Optional[str] = Field(
default=None,
description="URL to receive execution status webhooks",
)
class ExecuteAgentResponse(BaseModel):
"""Response from starting an agent execution."""
execution_id: str
status: str
message: str
class ExecutionStatusResponse(BaseModel):
"""Response with execution status."""
execution_id: str
status: str
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
outputs: Optional[dict[str, Any]] = None
error: Optional[str] = None
class GrantInfo(BaseModel):
"""Summary of a credential grant for capabilities."""
grant_id: str
provider: str
scopes: list[str]
class CapabilitiesResponse(BaseModel):
"""Response describing what the client can do."""
user_id: str
client_id: str
grants: list[GrantInfo]
available_scopes: list[str]
# ================================================================
# Endpoints
# ================================================================
@execution_router.get("/capabilities", response_model=CapabilitiesResponse)
async def get_capabilities(
token: OAuthTokenInfo = Security(require_scope("agents:execute")),
) -> CapabilitiesResponse:
"""
Get the capabilities available to this client for the authenticated user.
Returns information about:
- Available credential grants (NOT credential values)
- Scopes the client has access to
"""
try:
resolver = await create_resolver_from_oauth_token(
user_id=token.user_id,
client_public_id=token.client_id,
)
credentials_info = await resolver.get_available_credentials()
grants = [
GrantInfo(
grant_id=info["grant_id"],
provider=info["provider"],
scopes=info["granted_scopes"],
)
for info in credentials_info
]
return CapabilitiesResponse(
user_id=token.user_id,
client_id=token.client_id,
grants=grants,
available_scopes=token.scopes,
)
except GrantValidationError:
# No grants available is not an error, just empty capabilities
return CapabilitiesResponse(
user_id=token.user_id,
client_id=token.client_id,
grants=[],
available_scopes=token.scopes,
)
@execution_router.post(
"/agents/{agent_id}/execute",
response_model=ExecuteAgentResponse,
)
async def execute_agent(
agent_id: str,
request: ExecuteAgentRequest,
token: OAuthTokenInfo = Security(require_scope("agents:execute")),
) -> ExecuteAgentResponse:
"""
Execute an agent using granted credentials.
The agent must be accessible to the user, and the client must have
valid credential grants that satisfy the agent's requirements.
Args:
agent_id: The agent (graph) ID to execute
request: Execution parameters including inputs and optional grant IDs
"""
# Verify the agent exists and user has access
# First try to get the latest version
graph = await graph_db.get_graph(
graph_id=agent_id,
version=None,
user_id=token.user_id,
)
if not graph:
# Try to find it in the store (public agents)
graph = await graph_db.get_graph(
graph_id=agent_id,
version=None,
user_id=None,
skip_access_check=True,
)
if not graph:
raise HTTPException(
status_code=404,
detail=f"Agent {agent_id} not found or not accessible",
)
# Initialize the grant resolver to validate grants exist
# The resolver context will be passed to the execution engine
grant_resolver_context = None
try:
resolver = await create_resolver_from_oauth_token(
user_id=token.user_id,
client_public_id=token.client_id,
grant_ids=request.grant_ids,
)
# Get available credentials info to build resolver context
credentials_info = await resolver.get_available_credentials()
grant_resolver_context = GrantResolverContext(
client_db_id=resolver.client_id,
grant_ids=[c["grant_id"] for c in credentials_info],
)
except GrantValidationError as e:
raise HTTPException(
status_code=403,
detail=f"Grant validation failed: {str(e)}",
)
try:
# Build execution context with grant resolver info
execution_context = ExecutionContext(
grant_resolver_context=grant_resolver_context,
)
# Execute the agent with grant resolver context
graph_exec = await add_graph_execution(
graph_id=agent_id,
user_id=token.user_id,
inputs=request.inputs,
graph_version=graph.version,
execution_context=execution_context,
)
# Log the execution for audit
logger.info(
f"External execution started: agent={agent_id}, "
f"execution={graph_exec.id}, client={token.client_id}, "
f"user={token.user_id}"
)
# Register webhook if provided
if request.webhook_url:
# Get client to check webhook domains
client = await prisma.oauthclient.find_unique(
where={"clientId": token.client_id}
)
if client:
if not validate_webhook_url(request.webhook_url, client.webhookDomains):
raise HTTPException(
status_code=400,
detail="Webhook URL not in allowed domains for this client",
)
# Store webhook registration with client's webhook secret
await prisma.executionwebhook.create(
data={ # type: ignore[typeddict-item]
"executionId": graph_exec.id,
"webhookUrl": request.webhook_url,
"clientId": client.id,
"userId": token.user_id,
"secret": client.webhookSecret,
}
)
logger.info(
f"Registered webhook for execution {graph_exec.id}: {request.webhook_url}"
)
return ExecuteAgentResponse(
execution_id=graph_exec.id,
status="queued",
message="Agent execution has been queued",
)
except ValueError as e:
# Client error - invalid input or configuration
logger.warning(
f"Invalid execution request: agent={agent_id}, "
f"client={token.client_id}, error={str(e)}"
)
raise HTTPException(
status_code=400,
detail=f"Invalid request: {str(e)}",
)
except HTTPException:
# Re-raise HTTP exceptions as-is
raise
except Exception:
# Server error - log full exception but don't expose details to client
logger.exception(
f"Unexpected error starting execution: agent={agent_id}, "
f"client={token.client_id}"
)
raise HTTPException(
status_code=500,
detail="An internal error occurred while starting execution",
)
@execution_router.get(
"/{execution_id}",
response_model=ExecutionStatusResponse,
)
async def get_execution_status(
execution_id: str,
token: OAuthTokenInfo = Security(require_scope("agents:execute")),
) -> ExecutionStatusResponse:
"""
Get the status of an agent execution.
Returns current status, outputs (if completed), and any error messages.
"""
graph_exec = await execution_db.get_graph_execution(
user_id=token.user_id,
execution_id=execution_id,
include_node_executions=False,
)
if not graph_exec:
raise HTTPException(
status_code=404,
detail=f"Execution {execution_id} not found",
)
# Build response
outputs = None
error = None
if graph_exec.status == AgentExecutionStatus.COMPLETED:
outputs = graph_exec.outputs
elif graph_exec.status == AgentExecutionStatus.FAILED:
# Get error from execution stats
# Note: Currently no standard error field in stats, but could be added
error = "Execution failed"
return ExecutionStatusResponse(
execution_id=execution_id,
status=graph_exec.status.value,
started_at=graph_exec.started_at,
completed_at=graph_exec.ended_at,
outputs=outputs,
error=error,
)
@execution_router.post("/{execution_id}/cancel")
async def cancel_execution(
execution_id: str,
token: OAuthTokenInfo = Security(require_scope("agents:execute")),
) -> dict:
"""
Cancel a running agent execution.
Only executions in QUEUED or RUNNING status can be cancelled.
"""
graph_exec = await execution_db.get_graph_execution(
user_id=token.user_id,
execution_id=execution_id,
include_node_executions=False,
)
if not graph_exec:
raise HTTPException(
status_code=404,
detail=f"Execution {execution_id} not found",
)
# Check if execution can be cancelled
if graph_exec.status not in [
AgentExecutionStatus.QUEUED,
AgentExecutionStatus.RUNNING,
]:
raise HTTPException(
status_code=400,
detail=f"Cannot cancel execution with status {graph_exec.status.value}",
)
# Update execution status to TERMINATED
# Note: This is a simplified implementation. A full implementation would
# need to signal the executor to stop processing.
await prisma.agentgraphexecution.update(
where={"id": execution_id},
data={"executionStatus": AgentExecutionStatus.TERMINATED},
)
logger.info(
f"Execution terminated: execution={execution_id}, "
f"client={token.client_id}, user={token.user_id}"
)
return {"message": "Execution terminated", "execution_id": execution_id}

View File

@@ -0,0 +1,207 @@
"""
Credential Grants endpoints for external OAuth clients.
Allows external applications to:
- List their credential grants (metadata only, NOT credential values)
- Get grant details
- Delete credentials via grants (if permitted)
Credentials are NEVER returned to external applications.
"""
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, HTTPException, Security
from pydantic import BaseModel
from backend.data import credential_grants as grants_db
from backend.data.db import prisma
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.server.external.oauth_middleware import OAuthTokenInfo, require_scope
grants_router = APIRouter(prefix="/grants", tags=["grants"])
# ================================================================
# Response Models
# ================================================================
class GrantSummary(BaseModel):
"""Summary of a credential grant (returned in list endpoints)."""
id: str
provider: str
granted_scopes: list[str]
permissions: list[str]
created_at: datetime
last_used_at: Optional[datetime] = None
expires_at: Optional[datetime] = None
class GrantDetail(BaseModel):
"""Detailed grant information."""
id: str
provider: str
credential_id: str
granted_scopes: list[str]
permissions: list[str]
created_at: datetime
updated_at: datetime
last_used_at: Optional[datetime] = None
expires_at: Optional[datetime] = None
revoked_at: Optional[datetime] = None
# ================================================================
# Endpoints
# ================================================================
@grants_router.get("/", response_model=list[GrantSummary])
async def list_grants(
token: OAuthTokenInfo = Security(require_scope("integrations:list")),
) -> list[GrantSummary]:
"""
List all active credential grants for this client and user.
Returns grant metadata but NOT credential values.
Credentials are never exposed to external applications.
"""
# Get the OAuth client's database ID from the public client_id
client = await prisma.oauthclient.find_unique(where={"clientId": token.client_id})
if not client:
raise HTTPException(status_code=400, detail="Invalid client")
grants = await grants_db.get_grants_for_user_client(
user_id=token.user_id,
client_id=client.id,
include_revoked=False,
include_expired=False,
)
return [
GrantSummary(
id=grant.id,
provider=grant.provider,
granted_scopes=grant.grantedScopes,
permissions=[p.value for p in grant.permissions],
created_at=grant.createdAt,
last_used_at=grant.lastUsedAt,
expires_at=grant.expiresAt,
)
for grant in grants
]
@grants_router.get("/{grant_id}", response_model=GrantDetail)
async def get_grant(
grant_id: str,
token: OAuthTokenInfo = Security(require_scope("integrations:list")),
) -> GrantDetail:
"""
Get detailed information about a specific grant.
Returns grant metadata including scopes and permissions.
Does NOT return the credential value.
"""
# Get the OAuth client's database ID
client = await prisma.oauthclient.find_unique(where={"clientId": token.client_id})
if not client:
raise HTTPException(status_code=400, detail="Invalid client")
grant = await grants_db.get_credential_grant(
grant_id=grant_id,
user_id=token.user_id,
client_id=client.id,
)
if not grant:
raise HTTPException(status_code=404, detail="Grant not found")
# Check if expired
if grant.expiresAt and grant.expiresAt < datetime.now(timezone.utc):
raise HTTPException(status_code=404, detail="Grant has expired")
# Check if revoked
if grant.revokedAt:
raise HTTPException(status_code=404, detail="Grant has been revoked")
return GrantDetail(
id=grant.id,
provider=grant.provider,
credential_id=grant.credentialId,
granted_scopes=grant.grantedScopes,
permissions=[p.value for p in grant.permissions],
created_at=grant.createdAt,
updated_at=grant.updatedAt,
last_used_at=grant.lastUsedAt,
expires_at=grant.expiresAt,
revoked_at=grant.revokedAt,
)
@grants_router.delete("/{grant_id}/credential")
async def delete_credential_via_grant(
grant_id: str,
token: OAuthTokenInfo = Security(require_scope("integrations:delete")),
) -> dict:
"""
Delete the underlying credential associated with a grant.
This requires the grant to have the DELETE permission.
Deleting the credential also invalidates all grants for that credential.
"""
from prisma.enums import CredentialGrantPermission
# Get the OAuth client's database ID
client = await prisma.oauthclient.find_unique(where={"clientId": token.client_id})
if not client:
raise HTTPException(status_code=400, detail="Invalid client")
# Get the grant
grant = await grants_db.get_credential_grant(
grant_id=grant_id,
user_id=token.user_id,
client_id=client.id,
)
if not grant:
raise HTTPException(status_code=404, detail="Grant not found")
# Check if grant is valid
if grant.revokedAt:
raise HTTPException(status_code=400, detail="Grant has been revoked")
if grant.expiresAt and grant.expiresAt < datetime.now(timezone.utc):
raise HTTPException(status_code=400, detail="Grant has expired")
# Check DELETE permission
if CredentialGrantPermission.DELETE not in grant.permissions:
raise HTTPException(
status_code=403,
detail="Grant does not have DELETE permission for this credential",
)
# Delete the credential using the credentials store
try:
creds_store = IntegrationCredentialsStore()
await creds_store.delete_creds_by_id(
user_id=token.user_id,
credentials_id=grant.credentialId,
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to delete credential: {str(e)}",
)
# Revoke all grants for this credential
await grants_db.revoke_grants_for_credential(
user_id=token.user_id,
credential_id=grant.credentialId,
)
return {"message": "Credential deleted successfully"}

View File

@@ -1,650 +0,0 @@
"""
External API endpoints for integrations and credentials.
This module provides endpoints for external applications (like Autopilot) to:
- Initiate OAuth flows with custom callback URLs
- Complete OAuth flows by exchanging authorization codes
- Create API key, user/password, and host-scoped credentials
- List and manage user credentials
"""
import logging
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Union
from urllib.parse import urlparse
from fastapi import APIRouter, Body, HTTPException, Path, Security, status
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field, SecretStr
from backend.data.api_key import APIKeyInfo
from backend.data.model import (
APIKeyCredentials,
Credentials,
CredentialsType,
HostScopedCredentials,
OAuth2Credentials,
UserPasswordCredentials,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.server.external.middleware import require_permission
from backend.server.integrations.models import get_all_provider_names
from backend.util.settings import Settings
if TYPE_CHECKING:
from backend.integrations.oauth import BaseOAuthHandler
logger = logging.getLogger(__name__)
settings = Settings()
creds_manager = IntegrationCredentialsManager()
integrations_router = APIRouter(prefix="/integrations", tags=["integrations"])
# ==================== Request/Response Models ==================== #
class OAuthInitiateRequest(BaseModel):
"""Request model for initiating an OAuth flow."""
callback_url: str = Field(
..., description="The external app's callback URL for OAuth redirect"
)
scopes: list[str] = Field(
default_factory=list, description="OAuth scopes to request"
)
state_metadata: dict[str, Any] = Field(
default_factory=dict,
description="Arbitrary metadata to echo back on completion",
)
class OAuthInitiateResponse(BaseModel):
"""Response model for OAuth initiation."""
login_url: str = Field(..., description="URL to redirect user for OAuth consent")
state_token: str = Field(..., description="State token for CSRF protection")
expires_at: int = Field(
..., description="Unix timestamp when the state token expires"
)
class OAuthCompleteRequest(BaseModel):
"""Request model for completing an OAuth flow."""
code: str = Field(..., description="Authorization code from OAuth provider")
state_token: str = Field(..., description="State token from initiate request")
class OAuthCompleteResponse(BaseModel):
"""Response model for OAuth completion."""
credentials_id: str = Field(..., description="ID of the stored credentials")
provider: str = Field(..., description="Provider name")
type: str = Field(..., description="Credential type (oauth2)")
title: Optional[str] = Field(None, description="Credential title")
scopes: list[str] = Field(default_factory=list, description="Granted scopes")
username: Optional[str] = Field(None, description="Username from provider")
state_metadata: dict[str, Any] = Field(
default_factory=dict, description="Echoed metadata from initiate request"
)
class CredentialSummary(BaseModel):
"""Summary of a credential without sensitive data."""
id: str
provider: str
type: CredentialsType
title: Optional[str] = None
scopes: Optional[list[str]] = None
username: Optional[str] = None
host: Optional[str] = None
class ProviderInfo(BaseModel):
"""Information about an integration provider."""
name: str
supports_oauth: bool = False
supports_api_key: bool = False
supports_user_password: bool = False
supports_host_scoped: bool = False
default_scopes: list[str] = Field(default_factory=list)
# ==================== Credential Creation Models ==================== #
class CreateAPIKeyCredentialRequest(BaseModel):
"""Request model for creating API key credentials."""
type: Literal["api_key"] = "api_key"
api_key: str = Field(..., description="The API key")
title: str = Field(..., description="A name for this credential")
expires_at: Optional[int] = Field(
None, description="Unix timestamp when the API key expires"
)
class CreateUserPasswordCredentialRequest(BaseModel):
"""Request model for creating username/password credentials."""
type: Literal["user_password"] = "user_password"
username: str = Field(..., description="Username")
password: str = Field(..., description="Password")
title: str = Field(..., description="A name for this credential")
class CreateHostScopedCredentialRequest(BaseModel):
"""Request model for creating host-scoped credentials."""
type: Literal["host_scoped"] = "host_scoped"
host: str = Field(..., description="Host/domain pattern to match")
headers: dict[str, str] = Field(..., description="Headers to include in requests")
title: str = Field(..., description="A name for this credential")
# Union type for credential creation
CreateCredentialRequest = Annotated[
CreateAPIKeyCredentialRequest
| CreateUserPasswordCredentialRequest
| CreateHostScopedCredentialRequest,
Field(discriminator="type"),
]
class CreateCredentialResponse(BaseModel):
"""Response model for credential creation."""
id: str
provider: str
type: CredentialsType
title: Optional[str] = None
# ==================== Helper Functions ==================== #
def validate_callback_url(callback_url: str) -> bool:
"""Validate that the callback URL is from an allowed origin."""
allowed_origins = settings.config.external_oauth_callback_origins
try:
parsed = urlparse(callback_url)
callback_origin = f"{parsed.scheme}://{parsed.netloc}"
for allowed in allowed_origins:
# Simple origin matching
if callback_origin == allowed:
return True
# Allow localhost with any port in development (proper hostname check)
if parsed.hostname == "localhost":
for allowed in allowed_origins:
allowed_parsed = urlparse(allowed)
if allowed_parsed.hostname == "localhost":
return True
return False
except Exception:
return False
def _get_oauth_handler_for_external(
provider_name: str, redirect_uri: str
) -> "BaseOAuthHandler":
"""Get an OAuth handler configured with an external redirect URI."""
# Ensure blocks are loaded so SDK providers are available
try:
from backend.blocks import load_all_blocks
load_all_blocks()
except Exception as e:
logger.warning(f"Failed to load blocks: {e}")
if provider_name not in HANDLERS_BY_NAME:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Provider '{provider_name}' does not support OAuth",
)
# Check if this provider has custom OAuth credentials
oauth_credentials = CREDENTIALS_BY_PROVIDER.get(provider_name)
if oauth_credentials and not oauth_credentials.use_secrets:
import os
client_id = (
os.getenv(oauth_credentials.client_id_env_var)
if oauth_credentials.client_id_env_var
else None
)
client_secret = (
os.getenv(oauth_credentials.client_secret_env_var)
if oauth_credentials.client_secret_env_var
else None
)
else:
client_id = getattr(settings.secrets, f"{provider_name}_client_id", None)
client_secret = getattr(
settings.secrets, f"{provider_name}_client_secret", None
)
if not (client_id and client_secret):
logger.error(f"Attempt to use unconfigured {provider_name} OAuth integration")
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail={
"message": f"Integration with provider '{provider_name}' is not configured.",
"hint": "Set client ID and secret in the application's deployment environment",
},
)
handler_class = HANDLERS_BY_NAME[provider_name]
return handler_class(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
)
# ==================== Endpoints ==================== #
@integrations_router.get("/providers", response_model=list[ProviderInfo])
async def list_providers(
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> list[ProviderInfo]:
"""
List all available integration providers.
Returns a list of all providers with their supported credential types.
Most providers support API key credentials, and some also support OAuth.
"""
# Ensure blocks are loaded
try:
from backend.blocks import load_all_blocks
load_all_blocks()
except Exception as e:
logger.warning(f"Failed to load blocks: {e}")
from backend.sdk.registry import AutoRegistry
providers = []
for name in get_all_provider_names():
supports_oauth = name in HANDLERS_BY_NAME
handler_class = HANDLERS_BY_NAME.get(name)
default_scopes = (
getattr(handler_class, "DEFAULT_SCOPES", []) if handler_class else []
)
# Check if provider has specific auth types from SDK registration
sdk_provider = AutoRegistry.get_provider(name)
if sdk_provider and sdk_provider.supported_auth_types:
supports_api_key = "api_key" in sdk_provider.supported_auth_types
supports_user_password = (
"user_password" in sdk_provider.supported_auth_types
)
supports_host_scoped = "host_scoped" in sdk_provider.supported_auth_types
else:
# Fallback for legacy providers
supports_api_key = True # All providers can accept API keys
supports_user_password = name in ("smtp",)
supports_host_scoped = name == "http"
providers.append(
ProviderInfo(
name=name,
supports_oauth=supports_oauth,
supports_api_key=supports_api_key,
supports_user_password=supports_user_password,
supports_host_scoped=supports_host_scoped,
default_scopes=default_scopes,
)
)
return providers
@integrations_router.post(
"/{provider}/oauth/initiate",
response_model=OAuthInitiateResponse,
summary="Initiate OAuth flow",
)
async def initiate_oauth(
provider: Annotated[str, Path(title="The OAuth provider")],
request: OAuthInitiateRequest,
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
),
) -> OAuthInitiateResponse:
"""
Initiate an OAuth flow for an external application.
This endpoint allows external apps to start an OAuth flow with a custom
callback URL. The callback URL must be from an allowed origin configured
in the platform settings.
Returns a login URL to redirect the user to, along with a state token
for CSRF protection.
"""
# Validate callback URL
if not validate_callback_url(request.callback_url):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Callback URL origin is not allowed. Allowed origins: {settings.config.external_oauth_callback_origins}",
)
# Validate provider
try:
provider_name = ProviderName(provider)
except ValueError:
# Check if it's a dynamically registered provider
if provider not in HANDLERS_BY_NAME:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Provider '{provider}' not found",
)
provider_name = provider
# Get OAuth handler with external callback URL
handler = _get_oauth_handler_for_external(
provider if isinstance(provider_name, str) else provider_name.value,
request.callback_url,
)
# Store state token with external flow metadata
state_token, code_challenge = await creds_manager.store.store_state_token(
user_id=api_key.user_id,
provider=provider if isinstance(provider_name, str) else provider_name.value,
scopes=request.scopes,
callback_url=request.callback_url,
state_metadata=request.state_metadata,
initiated_by_api_key_id=api_key.id,
)
# Build login URL
login_url = handler.get_login_url(
request.scopes, state_token, code_challenge=code_challenge
)
# Calculate expiration (10 minutes from now)
from datetime import datetime, timedelta, timezone
expires_at = int((datetime.now(timezone.utc) + timedelta(minutes=10)).timestamp())
return OAuthInitiateResponse(
login_url=login_url,
state_token=state_token,
expires_at=expires_at,
)
@integrations_router.post(
"/{provider}/oauth/complete",
response_model=OAuthCompleteResponse,
summary="Complete OAuth flow",
)
async def complete_oauth(
provider: Annotated[str, Path(title="The OAuth provider")],
request: OAuthCompleteRequest,
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
),
) -> OAuthCompleteResponse:
"""
Complete an OAuth flow by exchanging the authorization code for tokens.
This endpoint should be called after the user has authorized the application
and been redirected back to the external app's callback URL with an
authorization code.
"""
# Verify state token
valid_state = await creds_manager.store.verify_state_token(
api_key.user_id, request.state_token, provider
)
if not valid_state:
logger.warning(f"Invalid or expired state token for provider {provider}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid or expired state token",
)
# Verify this is an external flow (callback_url must be set)
if not valid_state.callback_url:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="State token was not created for external OAuth flow",
)
# Get OAuth handler with the original callback URL
handler = _get_oauth_handler_for_external(provider, valid_state.callback_url)
try:
scopes = valid_state.scopes
scopes = handler.handle_default_scopes(scopes)
credentials = await handler.exchange_code_for_tokens(
request.code, scopes, valid_state.code_verifier
)
# Handle Linear's space-separated scopes
if len(credentials.scopes) == 1 and " " in credentials.scopes[0]:
credentials.scopes = credentials.scopes[0].split(" ")
# Check scope mismatch
if not set(scopes).issubset(set(credentials.scopes)):
logger.warning(
f"Granted scopes {credentials.scopes} for provider {provider} "
f"do not include all requested scopes {scopes}"
)
except Exception as e:
logger.error(f"OAuth2 Code->Token exchange failed for provider {provider}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"OAuth2 callback failed to exchange code for tokens: {str(e)}",
)
# Store credentials
await creds_manager.create(api_key.user_id, credentials)
logger.info(f"Successfully completed external OAuth for provider {provider}")
return OAuthCompleteResponse(
credentials_id=credentials.id,
provider=credentials.provider,
type=credentials.type,
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
state_metadata=valid_state.state_metadata,
)
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
async def list_credentials(
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> list[CredentialSummary]:
"""
List all credentials for the authenticated user.
Returns metadata about each credential without exposing sensitive tokens.
"""
credentials = await creds_manager.store.get_all_creds(api_key.user_id)
return [
CredentialSummary(
id=cred.id,
provider=cred.provider,
type=cred.type,
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@integrations_router.get(
"/{provider}/credentials", response_model=list[CredentialSummary]
)
async def list_credentials_by_provider(
provider: Annotated[str, Path(title="The provider to list credentials for")],
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> list[CredentialSummary]:
"""
List credentials for a specific provider.
"""
credentials = await creds_manager.store.get_creds_by_provider(
api_key.user_id, provider
)
return [
CredentialSummary(
id=cred.id,
provider=cred.provider,
type=cred.type,
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@integrations_router.post(
"/{provider}/credentials",
response_model=CreateCredentialResponse,
status_code=status.HTTP_201_CREATED,
summary="Create credentials",
)
async def create_credential(
provider: Annotated[str, Path(title="The provider to create credentials for")],
request: Union[
CreateAPIKeyCredentialRequest,
CreateUserPasswordCredentialRequest,
CreateHostScopedCredentialRequest,
] = Body(..., discriminator="type"),
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
),
) -> CreateCredentialResponse:
"""
Create non-OAuth credentials for a provider.
Supports creating:
- API key credentials (type: "api_key")
- Username/password credentials (type: "user_password")
- Host-scoped credentials (type: "host_scoped")
For OAuth credentials, use the OAuth initiate/complete flow instead.
"""
# Validate provider exists
all_providers = get_all_provider_names()
if provider not in all_providers:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Provider '{provider}' not found",
)
# Create the appropriate credential type
credentials: Credentials
if request.type == "api_key":
credentials = APIKeyCredentials(
provider=provider,
api_key=SecretStr(request.api_key),
title=request.title,
expires_at=request.expires_at,
)
elif request.type == "user_password":
credentials = UserPasswordCredentials(
provider=provider,
username=SecretStr(request.username),
password=SecretStr(request.password),
title=request.title,
)
elif request.type == "host_scoped":
# Convert string headers to SecretStr
secret_headers = {k: SecretStr(v) for k, v in request.headers.items()}
credentials = HostScopedCredentials(
provider=provider,
host=request.host,
headers=secret_headers,
title=request.title,
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported credential type: {request.type}",
)
# Store credentials
try:
await creds_manager.create(api_key.user_id, credentials)
except Exception as e:
logger.error(f"Failed to store credentials: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to store credentials: {str(e)}",
)
logger.info(f"Created {request.type} credentials for provider {provider}")
return CreateCredentialResponse(
id=credentials.id,
provider=provider,
type=credentials.type,
title=credentials.title,
)
class DeleteCredentialResponse(BaseModel):
"""Response model for deleting a credential."""
deleted: bool = Field(..., description="Whether the credential was deleted")
credentials_id: str = Field(..., description="ID of the deleted credential")
@integrations_router.delete(
"/{provider}/credentials/{cred_id}",
response_model=DeleteCredentialResponse,
)
async def delete_credential(
provider: Annotated[str, Path(title="The provider")],
cred_id: Annotated[str, Path(title="The credential ID to delete")],
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.DELETE_INTEGRATIONS)
),
) -> DeleteCredentialResponse:
"""
Delete a credential.
Note: This does not revoke the tokens with the provider. For full cleanup,
use the main API's delete endpoint which handles webhook cleanup and
token revocation.
"""
creds = await creds_manager.store.get_creds_by_id(api_key.user_id, cred_id)
if not creds:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if creds.provider != provider:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials do not match the specified provider",
)
await creds_manager.delete(api_key.user_id, cred_id)
return DeleteCredentialResponse(deleted=True, credentials_id=cred_id)

View File

@@ -1,148 +0,0 @@
"""External API routes for chat tools - stateless HTTP endpoints.
Note: These endpoints use ephemeral sessions that are not persisted to Redis.
As a result, session-based rate limiting (max_agent_runs, max_agent_schedules)
is not enforced for external API calls. Each request creates a fresh session
with zeroed counters. Rate limiting for external API consumers should be
handled separately (e.g., via API key quotas).
"""
import logging
from typing import Any
from fastapi import APIRouter, Security
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.data.api_key import APIKeyInfo
from backend.server.external.middleware import require_permission
from backend.server.v2.chat.model import ChatSession
from backend.server.v2.chat.tools import find_agent_tool, run_agent_tool
from backend.server.v2.chat.tools.models import ToolResponseBase
logger = logging.getLogger(__name__)
tools_router = APIRouter(prefix="/tools", tags=["tools"])
# Note: We use Security() as a function parameter dependency (api_key: APIKeyInfo = Security(...))
# rather than in the decorator's dependencies= list. This avoids duplicate permission checks
# while still enforcing auth AND giving us access to the api_key for extracting user_id.
# Request models
class FindAgentRequest(BaseModel):
query: str = Field(..., description="Search query for finding agents")
class RunAgentRequest(BaseModel):
"""Request to run or schedule an agent.
The tool automatically handles the setup flow:
- First call returns available inputs so user can decide what values to use
- Returns missing credentials if user needs to configure them
- Executes when inputs are provided OR use_defaults=true
- Schedules execution if schedule_name and cron are provided
"""
username_agent_slug: str = Field(
...,
description="The marketplace agent slug (e.g., 'username/agent-name')",
)
inputs: dict[str, Any] = Field(
default_factory=dict,
description="Dictionary of input values for the agent",
)
use_defaults: bool = Field(
default=False,
description="Set to true to run with default values (user must confirm)",
)
schedule_name: str | None = Field(
None,
description="Name for scheduled execution (triggers scheduling mode)",
)
cron: str | None = Field(
None,
description="Cron expression (5 fields: minute hour day month weekday)",
)
timezone: str = Field(
default="UTC",
description="IANA timezone (e.g., 'America/New_York', 'UTC')",
)
def _create_ephemeral_session(user_id: str | None) -> ChatSession:
"""Create an ephemeral session for stateless API requests."""
return ChatSession.new(user_id)
@tools_router.post(
path="/find-agent",
)
async def find_agent(
request: FindAgentRequest,
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
) -> dict[str, Any]:
"""
Search for agents in the marketplace based on capabilities and user needs.
Args:
request: Search query for finding agents
Returns:
List of matching agents or no results response
"""
session = _create_ephemeral_session(api_key.user_id)
result = await find_agent_tool._execute(
user_id=api_key.user_id,
session=session,
query=request.query,
)
return _response_to_dict(result)
@tools_router.post(
path="/run-agent",
)
async def run_agent(
request: RunAgentRequest,
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
) -> dict[str, Any]:
"""
Run or schedule an agent from the marketplace.
The endpoint automatically handles the setup flow:
- Returns missing inputs if required fields are not provided
- Returns missing credentials if user needs to configure them
- Executes immediately if all requirements are met
- Schedules execution if schedule_name and cron are provided
For scheduled execution:
- Cron format: "minute hour day month weekday"
- Examples: "0 9 * * 1-5" (9am weekdays), "0 0 * * *" (daily at midnight)
- Timezone: Use IANA timezone names like "America/New_York"
Args:
request: Agent slug, inputs, and optional schedule config
Returns:
- setup_requirements: If inputs or credentials are missing
- execution_started: If agent was run or scheduled successfully
- error: If something went wrong
"""
session = _create_ephemeral_session(api_key.user_id)
result = await run_agent_tool._execute(
user_id=api_key.user_id,
session=session,
username_agent_slug=request.username_agent_slug,
inputs=request.inputs,
use_defaults=request.use_defaults,
schedule_name=request.schedule_name or "",
cron=request.cron or "",
timezone=request.timezone,
)
return _response_to_dict(result)
def _response_to_dict(result: ToolResponseBase) -> dict[str, Any]:
"""Convert a tool response to a dictionary for JSON serialization."""
return result.model_dump()

View File

@@ -1,295 +0,0 @@
import logging
import urllib.parse
from collections import defaultdict
from typing import Annotated, Any, Literal, Optional, Sequence
from fastapi import APIRouter, Body, HTTPException, Security
from prisma.enums import AgentExecutionStatus, APIKeyPermission
from typing_extensions import TypedDict
import backend.data.block
import backend.server.v2.store.cache as store_cache
import backend.server.v2.store.model as store_model
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.api_key import APIKeyInfo
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.executor.utils import add_graph_execution
from backend.server.external.middleware import require_permission
from backend.util.settings import Settings
settings = Settings()
logger = logging.getLogger(__name__)
v1_router = APIRouter()
class NodeOutput(TypedDict):
key: str
value: Any
class ExecutionNode(TypedDict):
node_id: str
input: Any
output: dict[str, Any]
class ExecutionNodeOutput(TypedDict):
node_id: str
outputs: list[NodeOutput]
class GraphExecutionResult(TypedDict):
execution_id: str
status: str
nodes: list[ExecutionNode]
output: Optional[list[dict[str, str]]]
@v1_router.get(
path="/blocks",
tags=["blocks"],
dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
)
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in backend.data.block.get_blocks().values()]
return [b.to_dict() for b in blocks if not b.disabled]
@v1_router.post(
path="/blocks/{block_id}/execute",
tags=["blocks"],
dependencies=[Security(require_permission(APIKeyPermission.EXECUTE_BLOCK))],
)
async def execute_graph_block(
block_id: str,
data: BlockInput,
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
) -> CompletedBlockOutput:
obj = backend.data.block.get_block(block_id)
if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
output = defaultdict(list)
async for name, data in obj.execute(data):
output[name].append(data)
return output
@v1_router.post(
path="/graphs/{graph_id}/execute/{graph_version}",
tags=["graphs"],
)
async def execute_graph(
graph_id: str,
graph_version: int,
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
) -> dict[str, Any]:
try:
graph_exec = await add_graph_execution(
graph_id=graph_id,
user_id=api_key.user_id,
inputs=node_input,
graph_version=graph_version,
)
return {"id": graph_exec.id}
except Exception as e:
msg = str(e).encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)
@v1_router.get(
path="/graphs/{graph_id}/executions/{graph_exec_id}/results",
tags=["graphs"],
)
async def get_graph_execution_results(
graph_id: str,
graph_exec_id: str,
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.READ_GRAPH)),
) -> GraphExecutionResult:
graph_exec = await execution_db.get_graph_execution(
user_id=api_key.user_id,
execution_id=graph_exec_id,
include_node_executions=True,
)
if not graph_exec:
raise HTTPException(
status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
)
if not await graph_db.get_graph(
graph_id=graph_exec.graph_id,
version=graph_exec.graph_version,
user_id=api_key.user_id,
):
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return GraphExecutionResult(
execution_id=graph_exec_id,
status=graph_exec.status.value,
nodes=[
ExecutionNode(
node_id=node_exec.node_id,
input=node_exec.input_data.get("value", node_exec.input_data),
output={k: v for k, v in node_exec.output_data.items()},
)
for node_exec in graph_exec.node_executions
],
output=(
[
{name: value}
for name, values in graph_exec.outputs.items()
for value in values
]
if graph_exec.status == AgentExecutionStatus.COMPLETED
else None
),
)
##############################################
############### Store Endpoints ##############
##############################################
@v1_router.get(
path="/store/agents",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.StoreAgentsResponse,
)
async def get_store_agents(
featured: bool = False,
creator: str | None = None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
search_query: str | None = None,
category: str | None = None,
page: int = 1,
page_size: int = 20,
) -> store_model.StoreAgentsResponse:
"""
Get a paginated list of agents from the store with optional filtering and sorting.
Args:
featured: Filter to only show featured agents
creator: Filter agents by creator username
sorted_by: Sort agents by "runs", "rating", "name", or "updated_at"
search_query: Search agents by name, subheading and description
category: Filter agents by category
page: Page number for pagination (default 1)
page_size: Number of agents per page (default 20)
Returns:
StoreAgentsResponse: Paginated list of agents matching the filters
"""
if page < 1:
raise HTTPException(status_code=422, detail="Page must be greater than 0")
if page_size < 1:
raise HTTPException(status_code=422, detail="Page size must be greater than 0")
agents = await store_cache._get_cached_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
return agents
@v1_router.get(
path="/store/agents/{username}/{agent_name}",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.StoreAgentDetails,
)
async def get_store_agent(
username: str,
agent_name: str,
) -> store_model.StoreAgentDetails:
"""
Get details of a specific store agent by username and agent name.
Args:
username: Creator's username
agent_name: Name/slug of the agent
Returns:
StoreAgentDetails: Detailed information about the agent
"""
username = urllib.parse.unquote(username).lower()
agent_name = urllib.parse.unquote(agent_name).lower()
agent = await store_cache._get_cached_agent_details(
username=username, agent_name=agent_name
)
return agent
@v1_router.get(
path="/store/creators",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.CreatorsResponse,
)
async def get_store_creators(
featured: bool = False,
search_query: str | None = None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
page: int = 1,
page_size: int = 20,
) -> store_model.CreatorsResponse:
"""
Get a paginated list of store creators with optional filtering and sorting.
Args:
featured: Filter to only show featured creators
search_query: Search creators by profile description
sorted_by: Sort by "agent_rating", "agent_runs", or "num_agents"
page: Page number for pagination (default 1)
page_size: Number of creators per page (default 20)
Returns:
CreatorsResponse: Paginated list of creators matching the filters
"""
if page < 1:
raise HTTPException(status_code=422, detail="Page must be greater than 0")
if page_size < 1:
raise HTTPException(status_code=422, detail="Page size must be greater than 0")
creators = await store_cache._get_cached_store_creators(
featured=featured,
search_query=search_query,
sorted_by=sorted_by,
page=page,
page_size=page_size,
)
return creators
@v1_router.get(
path="/store/creators/{username}",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.CreatorDetails,
)
async def get_store_creator(
username: str,
) -> store_model.CreatorDetails:
"""
Get details of a specific store creator by username.
Args:
username: Creator's username
Returns:
CreatorDetails: Detailed information about the creator
"""
username = urllib.parse.unquote(username).lower()
creator = await store_cache._get_cached_creator_details(username=username)
return creator

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,471 @@
"""
Security utilities for the integration connect popup flow.
Handles state management, nonce validation, and origin verification
for the OAuth-style popup flow when connecting integrations.
"""
import hashlib
import logging
import secrets
from datetime import datetime, timezone
from typing import Any, Optional
from urllib.parse import urlparse
from prisma.models import OAuthClient
from pydantic import BaseModel
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
# State expiration time
STATE_EXPIRATION_SECONDS = 600 # 10 minutes
NONCE_EXPIRATION_SECONDS = 3600 # 1 hour (nonces valid for longer to prevent races)
LOGIN_STATE_EXPIRATION_SECONDS = 600 # 10 minutes for login redirect flow
class ConnectState(BaseModel):
"""Pydantic model for connect state stored in Redis."""
user_id: str
client_id: str
provider: str
requested_scopes: list[str]
redirect_origin: str
nonce: str
credential_id: Optional[str] = None
created_at: str
expires_at: str
class ConnectContinuationState(BaseModel):
"""
State for continuing the connect flow after OAuth completes.
When a user chooses to "connect new" during the connect flow,
we store this state so we can complete the grant creation after
the OAuth callback.
"""
user_id: str
client_id: str # Public client ID
client_db_id: str # Database UUID of the OAuth client
provider: str
requested_scopes: list[str] # Integration scopes (e.g., "google:gmail.readonly")
redirect_origin: str
nonce: str
created_at: str
class ConnectLoginState(BaseModel):
"""
State for connect flow when user needs to log in first.
When an unauthenticated user tries to access /connect/{provider},
we store the connect parameters and redirect to login. After login,
the user is redirected back to complete the connect flow.
"""
client_id: str
provider: str
requested_scopes: list[str]
redirect_origin: str
nonce: str
created_at: str
expires_at: str
# Continuation state expiration (same as regular state)
CONTINUATION_EXPIRATION_SECONDS = 600 # 10 minutes
async def store_connect_continuation(
user_id: str,
client_id: str,
client_db_id: str,
provider: str,
requested_scopes: list[str],
redirect_origin: str,
nonce: str,
) -> str:
"""
Store continuation state for completing connect flow after OAuth.
Args:
user_id: User initiating the connection
client_id: Public OAuth client ID
client_db_id: Database UUID of the OAuth client
provider: Integration provider name
requested_scopes: Requested integration scopes
redirect_origin: Origin to send postMessage to
nonce: Client-provided nonce for replay protection
Returns:
Continuation token to be stored in OAuth state metadata
"""
token = generate_connect_token()
now = datetime.now(timezone.utc)
state = ConnectContinuationState(
user_id=user_id,
client_id=client_id,
client_db_id=client_db_id,
provider=provider,
requested_scopes=requested_scopes,
redirect_origin=redirect_origin,
nonce=nonce,
created_at=now.isoformat(),
)
redis = await get_redis_async()
key = f"connect_continuation:{token}"
await redis.setex(key, CONTINUATION_EXPIRATION_SECONDS, state.model_dump_json())
logger.debug(f"Stored connect continuation state for token {token[:8]}...")
return token
async def get_connect_continuation(token: str) -> Optional[ConnectContinuationState]:
"""
Get continuation state without consuming it.
Args:
token: Continuation token
Returns:
ConnectContinuationState or None if not found/expired
"""
redis = await get_redis_async()
key = f"connect_continuation:{token}"
data = await redis.get(key)
if not data:
return None
return ConnectContinuationState.model_validate_json(data)
async def consume_connect_continuation(
token: str,
) -> Optional[ConnectContinuationState]:
"""
Get and consume (delete) continuation state.
This ensures the token can only be used once.
Args:
token: Continuation token
Returns:
ConnectContinuationState or None if not found/expired
"""
redis = await get_redis_async()
key = f"connect_continuation:{token}"
# Atomic get-and-delete to prevent race conditions
data = await redis.getdel(key)
if not data:
return None
state = ConnectContinuationState.model_validate_json(data)
logger.debug(f"Consumed connect continuation state for token {token[:8]}...")
return state
def generate_connect_token() -> str:
"""Generate a secure random token for connect state."""
return secrets.token_urlsafe(32)
async def store_connect_state(
user_id: str,
client_id: str,
provider: str,
requested_scopes: list[str],
redirect_origin: str,
nonce: str,
credential_id: Optional[str] = None,
) -> str:
"""
Store connect state in Redis and return a state token.
Args:
user_id: User initiating the connection
client_id: OAuth client ID (public identifier)
provider: Integration provider name
requested_scopes: Requested integration scopes
redirect_origin: Origin to send postMessage to
nonce: Client-provided nonce for replay protection
credential_id: Optional existing credential to grant access to
Returns:
State token to be used in the connect flow
"""
token = generate_connect_token()
now = datetime.now(timezone.utc)
expires_at = now.timestamp() + STATE_EXPIRATION_SECONDS
state = ConnectState(
user_id=user_id,
client_id=client_id,
provider=provider,
requested_scopes=requested_scopes,
redirect_origin=redirect_origin,
nonce=nonce,
credential_id=credential_id,
created_at=now.isoformat(),
expires_at=datetime.fromtimestamp(expires_at, tz=timezone.utc).isoformat(),
)
redis = await get_redis_async()
key = f"connect_state:{token}"
await redis.setex(key, STATE_EXPIRATION_SECONDS, state.model_dump_json())
logger.debug(f"Stored connect state for token {token[:8]}...")
return token
async def get_connect_state(token: str) -> Optional[ConnectState]:
"""
Get connect state without consuming it.
Args:
token: State token
Returns:
ConnectState or None if not found/expired
"""
redis = await get_redis_async()
key = f"connect_state:{token}"
data = await redis.get(key)
if not data:
return None
return ConnectState.model_validate_json(data)
async def consume_connect_state(token: str) -> Optional[ConnectState]:
"""
Get and consume (delete) connect state.
This ensures the token can only be used once.
Args:
token: State token
Returns:
ConnectState or None if not found/expired
"""
redis = await get_redis_async()
key = f"connect_state:{token}"
# Atomic get-and-delete to prevent race conditions
data = await redis.getdel(key)
if not data:
return None
state = ConnectState.model_validate_json(data)
logger.debug(f"Consumed connect state for token {token[:8]}...")
return state
async def store_connect_login_state(
client_id: str,
provider: str,
requested_scopes: list[str],
redirect_origin: str,
nonce: str,
) -> str:
"""
Store connect parameters for unauthenticated users.
When a user isn't logged in, we store the connect params and redirect
to login. After login, the frontend calls /connect/resume with the token.
Args:
client_id: OAuth client ID
provider: Integration provider name
requested_scopes: Requested integration scopes
redirect_origin: Origin to send postMessage to
nonce: Client-provided nonce for replay protection
Returns:
Login state token to be used after login completes
"""
token = generate_connect_token()
now = datetime.now(timezone.utc)
expires_at = now.timestamp() + LOGIN_STATE_EXPIRATION_SECONDS
state = ConnectLoginState(
client_id=client_id,
provider=provider,
requested_scopes=requested_scopes,
redirect_origin=redirect_origin,
nonce=nonce,
created_at=now.isoformat(),
expires_at=datetime.fromtimestamp(expires_at, tz=timezone.utc).isoformat(),
)
redis = await get_redis_async()
key = f"connect_login_state:{token}"
await redis.setex(key, LOGIN_STATE_EXPIRATION_SECONDS, state.model_dump_json())
logger.debug(f"Stored connect login state for token {token[:8]}...")
return token
async def get_connect_login_state(token: str) -> Optional[ConnectLoginState]:
"""
Get connect login state without consuming it.
Args:
token: Login state token
Returns:
ConnectLoginState or None if not found/expired
"""
redis = await get_redis_async()
key = f"connect_login_state:{token}"
data = await redis.get(key)
if not data:
return None
return ConnectLoginState.model_validate_json(data)
async def consume_connect_login_state(token: str) -> Optional[ConnectLoginState]:
"""
Get and consume (delete) connect login state.
This ensures the token can only be used once.
Args:
token: Login state token
Returns:
ConnectLoginState or None if not found/expired
"""
redis = await get_redis_async()
key = f"connect_login_state:{token}"
# Atomic get-and-delete to prevent race conditions
data = await redis.getdel(key)
if not data:
return None
state = ConnectLoginState.model_validate_json(data)
logger.debug(f"Consumed connect login state for token {token[:8]}...")
return state
async def validate_nonce(client_id: str, nonce: str) -> bool:
"""
Validate that a nonce hasn't been used before (replay protection).
Uses atomic SET NX EX for check-and-set with automatic TTL expiry.
Args:
client_id: OAuth client ID
nonce: Client-provided nonce
Returns:
True if nonce is valid (not replayed)
"""
redis = await get_redis_async()
# Create a hash of the nonce for storage
nonce_hash = hashlib.sha256(nonce.encode()).hexdigest()
key = f"nonce:{client_id}:{nonce_hash}"
# Atomic set-if-not-exists with expiration (prevents race condition)
was_set = await redis.set(key, "1", nx=True, ex=NONCE_EXPIRATION_SECONDS)
if was_set:
return True
logger.warning(f"Nonce replay detected for client {client_id}")
return False
def validate_redirect_origin(origin: str, client: OAuthClient) -> bool:
"""
Validate that a redirect origin is allowed for the client.
The origin must match one of the client's registered redirect URIs
or webhook domains.
Args:
origin: Origin URL to validate
client: OAuth client to check against
Returns:
True if origin is allowed
"""
from backend.util.url import hostname_matches_any_domain
try:
parsed_origin = urlparse(origin)
origin_host = parsed_origin.netloc.lower()
# Check against redirect URIs
for redirect_uri in client.redirectUris:
parsed_redirect = urlparse(redirect_uri)
if parsed_redirect.netloc.lower() == origin_host:
return True
# Check against webhook domains
if hostname_matches_any_domain(origin_host, client.webhookDomains):
return True
return False
except Exception:
return False
def create_post_message_data(
success: bool,
grant_id: Optional[str] = None,
credential_id: Optional[str] = None,
provider: Optional[str] = None,
error: Optional[str] = None,
error_description: Optional[str] = None,
nonce: Optional[str] = None,
) -> dict[str, Any]:
"""
Create the postMessage data to send back to the opener.
Args:
success: Whether the operation succeeded
grant_id: ID of the created grant (if successful)
credential_id: ID of the credential (if successful)
provider: Provider name
error: Error code (if failed)
error_description: Human-readable error description
nonce: Original nonce for correlation
Returns:
Dictionary to be sent via postMessage
"""
data: dict[str, Any] = {
"type": "autogpt_connect_result",
"success": success,
}
if nonce:
data["nonce"] = nonce
if success:
data["grant_id"] = grant_id
data["credential_id"] = credential_id
data["provider"] = provider
else:
data["error"] = error
data["error_description"] = error_description
return data

View File

@@ -0,0 +1,20 @@
"""
OAuth 2.0 Provider module for AutoGPT Platform.
This module implements AutoGPT as an OAuth 2.0 Authorization Server,
allowing external applications to authenticate users and access
platform resources with user consent.
Key components:
- router.py: OAuth authorization and token endpoints
- discovery_router.py: OIDC discovery endpoints
- client_router.py: OAuth client management
- token_service.py: JWT generation and validation
- service.py: Core OAuth business logic
"""
from backend.server.oauth.client_router import client_router
from backend.server.oauth.discovery_router import discovery_router
from backend.server.oauth.router import oauth_router
__all__ = ["oauth_router", "discovery_router", "client_router"]

View File

@@ -0,0 +1,367 @@
"""
OAuth Client Management endpoints.
Implements self-service client registration and management:
- POST /oauth/clients - Register a new client
- GET /oauth/clients - List owned clients
- GET /oauth/clients/{client_id} - Get client details
- PATCH /oauth/clients/{client_id} - Update client
- DELETE /oauth/clients/{client_id} - Delete client
- POST /oauth/clients/{client_id}/rotate-secret - Rotate client secret
"""
import hashlib
import secrets
from autogpt_libs.auth import get_user_id
from fastapi import APIRouter, HTTPException, Security
from prisma.enums import OAuthClientStatus
from pydantic import BaseModel
from backend.data.db import prisma
from backend.server.oauth.models import (
ClientResponse,
ClientSecretResponse,
OAuthScope,
RegisterClientRequest,
UpdateClientRequest,
)
client_router = APIRouter(prefix="/oauth/clients", tags=["oauth-clients"])
def _generate_client_id() -> str:
"""Generate a unique client ID."""
return f"app_{secrets.token_urlsafe(16)}"
def _generate_client_secret() -> str:
"""Generate a secure client secret."""
return secrets.token_urlsafe(32)
def _generate_webhook_secret() -> str:
"""Generate a secure webhook secret for HMAC signing."""
return secrets.token_urlsafe(32)
def _hash_secret(secret: str, salt: str) -> str:
"""Hash a client secret with salt."""
return hashlib.sha256(f"{salt}{secret}".encode()).hexdigest()
def _client_to_response(client) -> ClientResponse:
"""Convert Prisma client to response model."""
return ClientResponse(
id=client.id,
client_id=client.clientId,
client_type=client.clientType,
name=client.name,
description=client.description,
logo_url=client.logoUrl,
homepage_url=client.homepageUrl,
privacy_policy_url=client.privacyPolicyUrl,
terms_of_service_url=client.termsOfServiceUrl,
redirect_uris=client.redirectUris,
allowed_scopes=client.allowedScopes,
webhook_domains=client.webhookDomains,
status=client.status,
created_at=client.createdAt,
updated_at=client.updatedAt,
)
# Default allowed scopes for new clients
DEFAULT_ALLOWED_SCOPES = [
OAuthScope.OPENID.value,
OAuthScope.PROFILE.value,
OAuthScope.EMAIL.value,
OAuthScope.INTEGRATIONS_LIST.value,
OAuthScope.INTEGRATIONS_CONNECT.value,
OAuthScope.INTEGRATIONS_DELETE.value,
OAuthScope.AGENTS_EXECUTE.value,
]
@client_router.post("/", response_model=ClientSecretResponse)
async def register_client(
request: RegisterClientRequest,
user_id: str = Security(get_user_id),
) -> ClientSecretResponse:
"""
Register a new OAuth client.
The client is immediately active (no admin approval required).
For confidential clients, the client_secret is returned only once.
The webhook_secret is always generated and returned only once.
"""
# Generate client credentials
client_id = _generate_client_id()
client_secret = None
client_secret_hash = None
client_secret_salt = None
if request.client_type == "confidential":
client_secret = _generate_client_secret()
client_secret_salt = secrets.token_urlsafe(16)
client_secret_hash = _hash_secret(client_secret, client_secret_salt)
# Generate webhook secret for HMAC signing
webhook_secret = _generate_webhook_secret()
# Create client
await prisma.oauthclient.create(
data={ # type: ignore[typeddict-item]
"clientId": client_id,
"clientSecretHash": client_secret_hash,
"clientSecretSalt": client_secret_salt,
"clientType": request.client_type,
"name": request.name,
"description": request.description,
"logoUrl": str(request.logo_url) if request.logo_url else None,
"homepageUrl": str(request.homepage_url) if request.homepage_url else None,
"privacyPolicyUrl": (
str(request.privacy_policy_url) if request.privacy_policy_url else None
),
"termsOfServiceUrl": (
str(request.terms_of_service_url)
if request.terms_of_service_url
else None
),
"redirectUris": request.redirect_uris,
"allowedScopes": DEFAULT_ALLOWED_SCOPES,
"webhookDomains": request.webhook_domains,
"webhookSecret": webhook_secret,
"status": OAuthClientStatus.ACTIVE,
"ownerId": user_id,
}
)
return ClientSecretResponse(
client_id=client_id,
client_secret=client_secret or "",
webhook_secret=webhook_secret,
)
@client_router.get("/", response_model=list[ClientResponse])
async def list_clients(
user_id: str = Security(get_user_id),
) -> list[ClientResponse]:
"""List all OAuth clients owned by the current user."""
clients = await prisma.oauthclient.find_many(
where={"ownerId": user_id},
order={"createdAt": "desc"},
)
return [_client_to_response(c) for c in clients]
@client_router.get("/{client_id}", response_model=ClientResponse)
async def get_client(
client_id: str,
user_id: str = Security(get_user_id),
) -> ClientResponse:
"""Get details of a specific OAuth client."""
client = await prisma.oauthclient.find_first(
where={"clientId": client_id, "ownerId": user_id}
)
if not client:
raise HTTPException(status_code=404, detail="Client not found")
return _client_to_response(client)
@client_router.patch("/{client_id}", response_model=ClientResponse)
async def update_client(
client_id: str,
request: UpdateClientRequest,
user_id: str = Security(get_user_id),
) -> ClientResponse:
"""Update an OAuth client."""
client = await prisma.oauthclient.find_first(
where={"clientId": client_id, "ownerId": user_id}
)
if not client:
raise HTTPException(status_code=404, detail="Client not found")
# Build update data
update_data: dict[str, str | list[str] | None] = {}
if request.name is not None:
update_data["name"] = request.name
if request.description is not None:
update_data["description"] = request.description
if request.logo_url is not None:
update_data["logoUrl"] = str(request.logo_url)
if request.homepage_url is not None:
update_data["homepageUrl"] = str(request.homepage_url)
if request.privacy_policy_url is not None:
update_data["privacyPolicyUrl"] = str(request.privacy_policy_url)
if request.terms_of_service_url is not None:
update_data["termsOfServiceUrl"] = str(request.terms_of_service_url)
if request.redirect_uris is not None:
update_data["redirectUris"] = request.redirect_uris
if request.webhook_domains is not None:
update_data["webhookDomains"] = request.webhook_domains
if not update_data:
return _client_to_response(client)
updated = await prisma.oauthclient.update(
where={"id": client.id},
data=update_data, # type: ignore[arg-type]
)
return _client_to_response(updated)
@client_router.delete("/{client_id}")
async def delete_client(
client_id: str,
user_id: str = Security(get_user_id),
) -> dict:
"""
Delete an OAuth client.
This will also revoke all tokens and authorizations for this client.
"""
client = await prisma.oauthclient.find_first(
where={"clientId": client_id, "ownerId": user_id}
)
if not client:
raise HTTPException(status_code=404, detail="Client not found")
# Delete cascades will handle tokens, codes, and authorizations
await prisma.oauthclient.delete(where={"id": client.id})
return {"status": "deleted", "client_id": client_id}
@client_router.post("/{client_id}/rotate-secret", response_model=ClientSecretResponse)
async def rotate_client_secret(
client_id: str,
user_id: str = Security(get_user_id),
) -> ClientSecretResponse:
"""
Rotate the client secret for a confidential client.
The new secret is returned only once. All existing tokens remain valid.
Also rotates the webhook secret for security.
"""
client = await prisma.oauthclient.find_first(
where={"clientId": client_id, "ownerId": user_id}
)
if not client:
raise HTTPException(status_code=404, detail="Client not found")
if client.clientType != "confidential":
raise HTTPException(
status_code=400,
detail="Cannot rotate secret for public clients",
)
# Generate new secrets
new_secret = _generate_client_secret()
new_salt = secrets.token_urlsafe(16)
new_hash = _hash_secret(new_secret, new_salt)
new_webhook_secret = _generate_webhook_secret()
await prisma.oauthclient.update(
where={"id": client.id},
data={
"clientSecretHash": new_hash,
"clientSecretSalt": new_salt,
"webhookSecret": new_webhook_secret,
},
)
return ClientSecretResponse(
client_id=client_id,
client_secret=new_secret,
webhook_secret=new_webhook_secret,
)
class WebhookSecretResponse(BaseModel):
"""Response containing newly generated webhook secret."""
client_id: str
webhook_secret: str
@client_router.post(
"/{client_id}/rotate-webhook-secret", response_model=WebhookSecretResponse
)
async def rotate_webhook_secret(
client_id: str,
user_id: str = Security(get_user_id),
) -> WebhookSecretResponse:
"""
Rotate only the webhook secret for a client.
The new webhook secret is returned only once.
"""
client = await prisma.oauthclient.find_first(
where={"clientId": client_id, "ownerId": user_id}
)
if not client:
raise HTTPException(status_code=404, detail="Client not found")
# Generate new webhook secret
new_webhook_secret = _generate_webhook_secret()
await prisma.oauthclient.update(
where={"id": client.id},
data={"webhookSecret": new_webhook_secret},
)
return WebhookSecretResponse(
client_id=client_id,
webhook_secret=new_webhook_secret,
)
@client_router.post("/{client_id}/suspend")
async def suspend_client(
client_id: str,
user_id: str = Security(get_user_id),
) -> ClientResponse:
"""Suspend an OAuth client (prevents new authorizations)."""
client = await prisma.oauthclient.find_first(
where={"clientId": client_id, "ownerId": user_id}
)
if not client:
raise HTTPException(status_code=404, detail="Client not found")
updated = await prisma.oauthclient.update(
where={"id": client.id},
data={"status": OAuthClientStatus.SUSPENDED},
)
return _client_to_response(updated)
@client_router.post("/{client_id}/activate")
async def activate_client(
client_id: str,
user_id: str = Security(get_user_id),
) -> ClientResponse:
"""Reactivate a suspended OAuth client."""
client = await prisma.oauthclient.find_first(
where={"clientId": client_id, "ownerId": user_id}
)
if not client:
raise HTTPException(status_code=404, detail="Client not found")
updated = await prisma.oauthclient.update(
where={"id": client.id},
data={"status": OAuthClientStatus.ACTIVE},
)
return _client_to_response(updated)

View File

@@ -0,0 +1,678 @@
"""
Server-rendered HTML templates for OAuth consent UI.
These templates are used for the OAuth authorization flow
when the user needs to approve access for an external application.
"""
import html
from typing import Optional
from backend.server.oauth.models import SCOPE_DESCRIPTIONS
def _base_styles() -> str:
"""Common CSS styles for all OAuth pages."""
return """
* {
box-sizing: border-box;
margin: 0;
padding: 0;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
padding: 20px;
color: #e4e4e7;
}
.container {
background: #27272a;
border-radius: 16px;
box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.5);
max-width: 420px;
width: 100%;
padding: 32px;
}
.header {
text-align: center;
margin-bottom: 24px;
}
.logo {
width: 64px;
height: 64px;
border-radius: 12px;
margin-bottom: 16px;
background: #3f3f46;
display: flex;
align-items: center;
justify-content: center;
margin-left: auto;
margin-right: auto;
}
.logo img {
max-width: 48px;
max-height: 48px;
border-radius: 8px;
}
.logo-placeholder {
font-size: 28px;
color: #a1a1aa;
}
h1 {
font-size: 20px;
font-weight: 600;
margin-bottom: 8px;
}
.subtitle {
color: #a1a1aa;
font-size: 14px;
}
.app-name {
color: #22d3ee;
font-weight: 600;
}
.divider {
height: 1px;
background: #3f3f46;
margin: 24px 0;
}
.scopes-section h2 {
font-size: 14px;
font-weight: 500;
color: #a1a1aa;
margin-bottom: 16px;
}
.scope-item {
display: flex;
align-items: flex-start;
gap: 12px;
padding: 12px 0;
border-bottom: 1px solid #3f3f46;
}
.scope-item:last-child {
border-bottom: none;
}
.scope-icon {
width: 20px;
height: 20px;
color: #22d3ee;
flex-shrink: 0;
margin-top: 2px;
}
.scope-text {
font-size: 14px;
line-height: 1.5;
}
.buttons {
display: flex;
gap: 12px;
margin-top: 24px;
}
.btn {
flex: 1;
padding: 12px 24px;
border-radius: 8px;
font-size: 14px;
font-weight: 500;
cursor: pointer;
border: none;
transition: all 0.2s;
}
.btn-cancel {
background: #3f3f46;
color: #e4e4e7;
}
.btn-cancel:hover {
background: #52525b;
}
.btn-allow {
background: #22d3ee;
color: #0f172a;
}
.btn-allow:hover {
background: #06b6d4;
}
.footer {
margin-top: 24px;
text-align: center;
font-size: 12px;
color: #71717a;
}
.footer a {
color: #a1a1aa;
text-decoration: none;
}
.footer a:hover {
text-decoration: underline;
}
.error-container {
text-align: center;
}
.error-icon {
width: 64px;
height: 64px;
margin: 0 auto 16px;
color: #ef4444;
}
.error-title {
color: #ef4444;
font-size: 18px;
font-weight: 600;
margin-bottom: 8px;
}
.error-message {
color: #a1a1aa;
font-size: 14px;
margin-bottom: 24px;
}
.success-icon {
width: 64px;
height: 64px;
margin: 0 auto 16px;
color: #22c55e;
}
.success-title {
color: #22c55e;
}
"""
def _check_icon() -> str:
"""SVG checkmark icon."""
return """
<svg class="scope-icon" viewBox="0 0 20 20" fill="currentColor">
<path fill-rule="evenodd" d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z" clip-rule="evenodd"/>
</svg>
"""
def _error_icon() -> str:
"""SVG error icon."""
return """
<svg class="error-icon" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
<circle cx="12" cy="12" r="10"/>
<line x1="15" y1="9" x2="9" y2="15"/>
<line x1="9" y1="9" x2="15" y2="15"/>
</svg>
"""
def _success_icon() -> str:
"""SVG success icon."""
return """
<svg class="success-icon" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
<circle cx="12" cy="12" r="10"/>
<path d="M9 12l2 2 4-4"/>
</svg>
"""
def render_consent_page(
client_name: str,
client_logo: Optional[str],
scopes: list[str],
consent_token: str,
action_url: str,
privacy_policy_url: Optional[str] = None,
terms_url: Optional[str] = None,
) -> str:
"""
Render the OAuth consent page.
Args:
client_name: Name of the requesting application
client_logo: URL to the client's logo (optional)
scopes: List of requested scopes
consent_token: CSRF token for the consent form
action_url: URL to submit the consent form
privacy_policy_url: Client's privacy policy URL (optional)
terms_url: Client's terms of service URL (optional)
Returns:
HTML string for the consent page
"""
# Escape user-provided values to prevent XSS
safe_client_name = html.escape(client_name)
safe_client_logo = html.escape(client_logo) if client_logo else None
# Build logo HTML
if safe_client_logo:
logo_html = f'<img src="{safe_client_logo}" alt="{safe_client_name}">'
else:
logo_html = f'<span class="logo-placeholder">{html.escape(client_name[0].upper())}</span>'
# Build scopes HTML
scopes_html = ""
for scope in scopes:
description = SCOPE_DESCRIPTIONS.get(scope, scope)
scopes_html += f"""
<div class="scope-item">
{_check_icon()}
<span class="scope-text">{html.escape(description)}</span>
</div>
"""
# Build footer links (escape URLs)
footer_links = []
if privacy_policy_url:
footer_links.append(
f'<a href="{html.escape(privacy_policy_url)}" target="_blank">Privacy Policy</a>'
)
if terms_url:
footer_links.append(
f'<a href="{html.escape(terms_url)}" target="_blank">Terms of Service</a>'
)
footer_html = " &bull; ".join(footer_links) if footer_links else ""
# Escape action_url and consent_token
safe_action_url = html.escape(action_url)
safe_consent_token = html.escape(consent_token)
return f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Authorize {safe_client_name} - AutoGPT</title>
<style>{_base_styles()}</style>
</head>
<body>
<div class="container">
<div class="header">
<div class="logo">{logo_html}</div>
<h1>Authorize <span class="app-name">{safe_client_name}</span></h1>
<p class="subtitle">wants to access your AutoGPT account</p>
</div>
<div class="divider"></div>
<div class="scopes-section">
<h2>This will allow {safe_client_name} to:</h2>
{scopes_html}
</div>
<form method="POST" action="{safe_action_url}">
<input type="hidden" name="consent_token" value="{safe_consent_token}">
<div class="buttons">
<button type="submit" name="authorize" value="false" class="btn btn-cancel">
Cancel
</button>
<button type="submit" name="authorize" value="true" class="btn btn-allow">
Allow
</button>
</div>
</form>
{f'<div class="footer">{footer_html}</div>' if footer_html else ''}
</div>
</body>
</html>
"""
def render_error_page(
error: str,
error_description: str,
redirect_url: Optional[str] = None,
) -> str:
"""
Render an OAuth error page.
Args:
error: Error code
error_description: Human-readable error description
redirect_url: Optional URL to redirect back (if safe)
Returns:
HTML string for the error page
"""
# Escape user-provided values to prevent XSS
safe_error = html.escape(error)
safe_error_description = html.escape(error_description)
redirect_html = ""
if redirect_url:
safe_redirect_url = html.escape(redirect_url)
redirect_html = f"""
<a href="{safe_redirect_url}" class="btn btn-cancel" style="display: inline-block; text-decoration: none;">
Go Back
</a>
"""
return f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Authorization Error - AutoGPT</title>
<style>{_base_styles()}</style>
</head>
<body>
<div class="container">
<div class="error-container">
{_error_icon()}
<h1 class="error-title">Authorization Failed</h1>
<p class="error-message">{safe_error_description}</p>
<p class="error-message" style="font-size: 12px; color: #52525b;">
Error code: {safe_error}
</p>
{redirect_html}
</div>
</div>
</body>
</html>
"""
def render_success_page(
message: str,
redirect_origin: Optional[str] = None,
post_message_data: Optional[dict] = None,
) -> str:
"""
Render a success page, optionally with postMessage for popup flows.
Args:
message: Success message to display
redirect_origin: Origin for postMessage (popup flows)
post_message_data: Data to send via postMessage (popup flows)
Returns:
HTML string for the success page
"""
# Escape user-provided values to prevent XSS
safe_message = html.escape(message)
# PostMessage script for popup flows
post_message_script = ""
if redirect_origin and post_message_data:
import json
# json.dumps escapes for JS context, but we also escape < > for HTML context
safe_json_origin = (
json.dumps(redirect_origin).replace("<", "\\u003c").replace(">", "\\u003e")
)
safe_json_data = (
json.dumps(post_message_data)
.replace("<", "\\u003c")
.replace(">", "\\u003e")
)
post_message_script = f"""
<script>
(function() {{
var targetOrigin = {safe_json_origin};
var message = {safe_json_data};
if (window.opener) {{
window.opener.postMessage(message, targetOrigin);
setTimeout(function() {{ window.close(); }}, 1000);
}}
}})();
</script>
"""
return f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Authorization Successful - AutoGPT</title>
<style>{_base_styles()}</style>
</head>
<body>
<div class="container">
<div class="error-container">
{_success_icon()}
<h1 class="success-title">Success!</h1>
<p class="error-message">{safe_message}</p>
<p class="error-message" style="font-size: 12px;">
This window will close automatically...
</p>
</div>
</div>
{post_message_script}
</body>
</html>
"""
def render_login_redirect_page(login_url: str) -> str:
"""
Render a page that redirects to login.
Args:
login_url: URL to redirect to for login
Returns:
HTML string with auto-redirect
"""
# Escape URL to prevent XSS
safe_login_url = html.escape(login_url)
return f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="refresh" content="0;url={safe_login_url}">
<title>Login Required - AutoGPT</title>
<style>{_base_styles()}</style>
</head>
<body>
<div class="container">
<div class="error-container">
<p class="error-message">Redirecting to login...</p>
<a href="{safe_login_url}" class="btn btn-allow" style="display: inline-block; text-decoration: none;">
Click here if not redirected
</a>
</div>
</div>
</body>
</html>
"""
def _login_form_styles() -> str:
"""Additional CSS styles for login form."""
return """
.form-group {
margin-bottom: 16px;
}
.form-group label {
display: block;
font-size: 14px;
font-weight: 500;
color: #a1a1aa;
margin-bottom: 8px;
}
.form-group input {
width: 100%;
padding: 12px 16px;
border-radius: 8px;
border: 1px solid #3f3f46;
background: #18181b;
color: #e4e4e7;
font-size: 14px;
outline: none;
transition: border-color 0.2s;
}
.form-group input:focus {
border-color: #22d3ee;
}
.form-group input::placeholder {
color: #52525b;
}
.error-alert {
background: rgba(239, 68, 68, 0.1);
border: 1px solid #ef4444;
border-radius: 8px;
padding: 12px 16px;
margin-bottom: 16px;
color: #fca5a5;
font-size: 14px;
}
.btn-login {
width: 100%;
padding: 12px 24px;
border-radius: 8px;
font-size: 14px;
font-weight: 500;
cursor: pointer;
border: none;
background: #22d3ee;
color: #0f172a;
transition: all 0.2s;
margin-top: 8px;
}
.btn-login:hover {
background: #06b6d4;
}
.btn-login:disabled {
background: #3f3f46;
color: #71717a;
cursor: not-allowed;
}
.signup-link {
text-align: center;
margin-top: 16px;
font-size: 14px;
color: #a1a1aa;
}
.signup-link a {
color: #22d3ee;
text-decoration: none;
}
.signup-link a:hover {
text-decoration: underline;
}
"""
def render_login_page(
action_url: str,
login_state: str,
client_name: Optional[str] = None,
error_message: Optional[str] = None,
signup_url: Optional[str] = None,
browser_login_url: Optional[str] = None,
) -> str:
"""
Render an embedded login page for OAuth flow.
Args:
action_url: URL to submit the login form
login_state: State token to preserve OAuth parameters
client_name: Name of the application requesting access (optional)
error_message: Error message to display (optional)
signup_url: URL to signup page (optional)
browser_login_url: URL to redirect to frontend login (optional)
Returns:
HTML string for the login page
"""
# Escape all user-provided values to prevent XSS
safe_action_url = html.escape(action_url)
safe_login_state = html.escape(login_state)
safe_client_name = html.escape(client_name) if client_name else None
error_html = ""
if error_message:
safe_error_message = html.escape(error_message)
error_html = f'<div class="error-alert">{safe_error_message}</div>'
subtitle = "wants to access your AutoGPT account" if safe_client_name else ""
title_html = (
'<h1>Sign in to <span class="app-name">AutoGPT</span></h1>'
if not safe_client_name
else f'<h1><span class="app-name">{safe_client_name}</span></h1>'
)
signup_html = ""
if signup_url:
safe_signup_url = html.escape(signup_url)
signup_html = f"""
<div class="signup-link">
Don't have an account? <a href="{safe_signup_url}">Sign up</a>
</div>
"""
browser_login_html = ""
if browser_login_url:
safe_browser_login_url = html.escape(browser_login_url)
browser_login_html = f"""
<div class="divider"></div>
<div class="signup-link">
<a href="{safe_browser_login_url}">Sign in with Google or other providers</a>
</div>
"""
return f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Sign In - AutoGPT</title>
<style>
{_base_styles()}
{_login_form_styles()}
</style>
</head>
<body>
<div class="container">
<div class="header">
<div class="logo">
<span class="logo-placeholder">A</span>
</div>
{title_html}
<p class="subtitle">{subtitle}</p>
</div>
<div class="divider"></div>
{error_html}
<form method="POST" action="{safe_action_url}">
<input type="hidden" name="login_state" value="{safe_login_state}">
<div class="form-group">
<label for="email">Email</label>
<input
type="email"
id="email"
name="email"
placeholder="you@example.com"
required
autocomplete="email"
>
</div>
<div class="form-group">
<label for="password">Password</label>
<input
type="password"
id="password"
name="password"
placeholder="Enter your password"
required
autocomplete="current-password"
>
</div>
<button type="submit" class="btn-login">Sign In</button>
</form>
{signup_html}
{browser_login_html}
</div>
</body>
</html>
"""

View File

@@ -0,0 +1,71 @@
"""
OIDC Discovery endpoints.
Implements:
- GET /.well-known/openid-configuration - OIDC Discovery Document
- GET /.well-known/jwks.json - JSON Web Key Set
"""
from fastapi import APIRouter
from backend.server.oauth.models import JWKS, OpenIDConfiguration
from backend.server.oauth.token_service import get_token_service
from backend.util.settings import Settings
discovery_router = APIRouter(tags=["oidc-discovery"])
@discovery_router.get(
"/.well-known/openid-configuration",
response_model=OpenIDConfiguration,
)
async def openid_configuration() -> OpenIDConfiguration:
"""
OIDC Discovery Document.
Returns metadata about the OAuth 2.0 authorization server including
endpoints, supported features, and algorithms.
"""
settings = Settings()
base_url = settings.config.platform_base_url or "https://platform.agpt.co"
return OpenIDConfiguration(
issuer=base_url,
authorization_endpoint=f"{base_url}/oauth/authorize",
token_endpoint=f"{base_url}/oauth/token",
userinfo_endpoint=f"{base_url}/oauth/userinfo",
revocation_endpoint=f"{base_url}/oauth/revoke",
jwks_uri=f"{base_url}/.well-known/jwks.json",
scopes_supported=[
"openid",
"profile",
"email",
"integrations:list",
"integrations:connect",
"integrations:delete",
"agents:execute",
],
response_types_supported=["code"],
grant_types_supported=["authorization_code", "refresh_token"],
token_endpoint_auth_methods_supported=[
"client_secret_post",
"client_secret_basic",
"none", # For public clients with PKCE
],
code_challenge_methods_supported=["S256"],
subject_types_supported=["public"],
id_token_signing_alg_values_supported=["RS256"],
)
@discovery_router.get("/.well-known/jwks.json", response_model=JWKS)
async def jwks() -> dict:
"""
JSON Web Key Set (JWKS).
Returns the public key(s) used to verify JWT signatures.
External applications can use these keys to verify access tokens
and ID tokens issued by this authorization server.
"""
token_service = get_token_service()
return token_service.get_jwks()

View File

@@ -0,0 +1,162 @@
"""
OAuth 2.0 Error Responses (RFC 6749 Section 5.2).
"""
from enum import Enum
from typing import Optional
from urllib.parse import urlencode
from fastapi import HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
class OAuthErrorCode(str, Enum):
"""Standard OAuth 2.0 error codes."""
# Authorization endpoint errors (RFC 6749 Section 4.1.2.1)
INVALID_REQUEST = "invalid_request"
UNAUTHORIZED_CLIENT = "unauthorized_client"
ACCESS_DENIED = "access_denied"
UNSUPPORTED_RESPONSE_TYPE = "unsupported_response_type"
INVALID_SCOPE = "invalid_scope"
SERVER_ERROR = "server_error"
TEMPORARILY_UNAVAILABLE = "temporarily_unavailable"
# Token endpoint errors (RFC 6749 Section 5.2)
INVALID_CLIENT = "invalid_client"
INVALID_GRANT = "invalid_grant"
UNSUPPORTED_GRANT_TYPE = "unsupported_grant_type"
# Extension errors
LOGIN_REQUIRED = "login_required"
CONSENT_REQUIRED = "consent_required"
class OAuthErrorResponse(BaseModel):
"""OAuth error response model."""
error: str
error_description: Optional[str] = None
error_uri: Optional[str] = None
class OAuthError(Exception):
"""Base OAuth error exception."""
def __init__(
self,
error: OAuthErrorCode,
description: Optional[str] = None,
uri: Optional[str] = None,
state: Optional[str] = None,
):
self.error = error
self.description = description
self.uri = uri
self.state = state
super().__init__(description or error.value)
def to_response(self) -> OAuthErrorResponse:
"""Convert to response model."""
return OAuthErrorResponse(
error=self.error.value,
error_description=self.description,
error_uri=self.uri,
)
def to_redirect(self, redirect_uri: str) -> RedirectResponse:
"""Convert to redirect response with error in query params."""
params = {"error": self.error.value}
if self.description:
params["error_description"] = self.description
if self.uri:
params["error_uri"] = self.uri
if self.state:
params["state"] = self.state
separator = "&" if "?" in redirect_uri else "?"
url = f"{redirect_uri}{separator}{urlencode(params)}"
return RedirectResponse(url=url, status_code=302)
def to_http_exception(self, status_code: int = 400) -> HTTPException:
"""Convert to FastAPI HTTPException."""
return HTTPException(
status_code=status_code,
detail=self.to_response().model_dump(exclude_none=True),
)
# Convenience error classes
class InvalidRequestError(OAuthError):
"""The request is missing a required parameter or is otherwise malformed."""
def __init__(self, description: str, state: Optional[str] = None):
super().__init__(OAuthErrorCode.INVALID_REQUEST, description, state=state)
class UnauthorizedClientError(OAuthError):
"""The client is not authorized to request an authorization code."""
def __init__(self, description: str, state: Optional[str] = None):
super().__init__(OAuthErrorCode.UNAUTHORIZED_CLIENT, description, state=state)
class AccessDeniedError(OAuthError):
"""The resource owner denied the request."""
def __init__(self, description: str = "Access denied", state: Optional[str] = None):
super().__init__(OAuthErrorCode.ACCESS_DENIED, description, state=state)
class InvalidScopeError(OAuthError):
"""The requested scope is invalid, unknown, or malformed."""
def __init__(self, description: str, state: Optional[str] = None):
super().__init__(OAuthErrorCode.INVALID_SCOPE, description, state=state)
class InvalidClientError(OAuthError):
"""Client authentication failed."""
def __init__(self, description: str = "Invalid client"):
super().__init__(OAuthErrorCode.INVALID_CLIENT, description)
class InvalidGrantError(OAuthError):
"""The provided authorization code or refresh token is invalid."""
def __init__(self, description: str = "Invalid grant"):
super().__init__(OAuthErrorCode.INVALID_GRANT, description)
class UnsupportedGrantTypeError(OAuthError):
"""The authorization grant type is not supported."""
def __init__(self, grant_type: str):
super().__init__(
OAuthErrorCode.UNSUPPORTED_GRANT_TYPE,
f"Grant type '{grant_type}' is not supported",
)
class LoginRequiredError(OAuthError):
"""User must be logged in to complete the request."""
def __init__(self, state: Optional[str] = None):
super().__init__(
OAuthErrorCode.LOGIN_REQUIRED,
"User authentication required",
state=state,
)
class ConsentRequiredError(OAuthError):
"""User consent is required for the requested scopes."""
def __init__(self, state: Optional[str] = None):
super().__init__(
OAuthErrorCode.CONSENT_REQUIRED,
"User consent required",
state=state,
)

View File

@@ -0,0 +1,288 @@
"""
Pydantic models for OAuth 2.0 requests and responses.
"""
from datetime import datetime
from enum import Enum
from typing import Literal, Optional
from pydantic import BaseModel, Field, HttpUrl
# ============================================================
# Enums and Constants
# ============================================================
class OAuthScope(str, Enum):
"""Supported OAuth scopes."""
# OpenID Connect standard scopes
OPENID = "openid"
PROFILE = "profile"
EMAIL = "email"
# AutoGPT-specific scopes
INTEGRATIONS_LIST = "integrations:list"
INTEGRATIONS_CONNECT = "integrations:connect"
INTEGRATIONS_DELETE = "integrations:delete"
AGENTS_EXECUTE = "agents:execute"
SCOPE_DESCRIPTIONS: dict[str, str] = {
OAuthScope.OPENID.value: "Access your user ID",
OAuthScope.PROFILE.value: "Access your profile information (name)",
OAuthScope.EMAIL.value: "Access your email address",
OAuthScope.INTEGRATIONS_LIST.value: "View your connected integrations",
OAuthScope.INTEGRATIONS_CONNECT.value: "Connect new integrations on your behalf",
OAuthScope.INTEGRATIONS_DELETE.value: "Delete integrations on your behalf",
OAuthScope.AGENTS_EXECUTE.value: "Run agents on your behalf",
}
# ============================================================
# Authorization Request/Response Models
# ============================================================
class AuthorizationRequest(BaseModel):
"""OAuth 2.0 Authorization Request (RFC 6749 Section 4.1.1)."""
response_type: Literal["code"] = Field(
..., description="Must be 'code' for authorization code flow"
)
client_id: str = Field(..., description="Client identifier")
redirect_uri: str = Field(..., description="Redirect URI after authorization")
scope: str = Field(default="", description="Space-separated list of scopes")
state: str = Field(..., description="CSRF protection token (required)")
code_challenge: str = Field(..., description="PKCE code challenge (required)")
code_challenge_method: Literal["S256"] = Field(
default="S256", description="PKCE method (only S256 supported)"
)
nonce: Optional[str] = Field(None, description="OIDC nonce for replay protection")
prompt: Optional[Literal["consent", "login", "none"]] = Field(
None, description="Prompt behavior"
)
class ConsentFormData(BaseModel):
"""Consent form submission data."""
consent_token: str = Field(..., description="CSRF token for consent")
authorize: bool = Field(..., description="Whether user authorized")
# ============================================================
# Token Request/Response Models
# ============================================================
class TokenRequest(BaseModel):
"""OAuth 2.0 Token Request (RFC 6749 Section 4.1.3)."""
grant_type: Literal["authorization_code", "refresh_token"] = Field(
..., description="Grant type"
)
code: Optional[str] = Field(
None, description="Authorization code (for authorization_code grant)"
)
redirect_uri: Optional[str] = Field(
None, description="Must match authorization request"
)
client_id: str = Field(..., description="Client identifier")
client_secret: Optional[str] = Field(
None, description="Client secret (for confidential clients)"
)
code_verifier: Optional[str] = Field(
None, description="PKCE code verifier (for authorization_code grant)"
)
refresh_token: Optional[str] = Field(
None, description="Refresh token (for refresh_token grant)"
)
scope: Optional[str] = Field(
None, description="Requested scopes (for refresh_token grant)"
)
class TokenResponse(BaseModel):
"""OAuth 2.0 Token Response (RFC 6749 Section 5.1)."""
access_token: str = Field(..., description="Access token")
token_type: Literal["Bearer"] = Field(default="Bearer", description="Token type")
expires_in: int = Field(..., description="Token lifetime in seconds")
refresh_token: Optional[str] = Field(None, description="Refresh token")
scope: Optional[str] = Field(None, description="Granted scopes")
id_token: Optional[str] = Field(None, description="OIDC ID token")
# ============================================================
# UserInfo Response Model
# ============================================================
class UserInfoResponse(BaseModel):
"""OIDC UserInfo Response."""
sub: str = Field(..., description="User ID (subject)")
email: Optional[str] = Field(None, description="User email")
email_verified: Optional[bool] = Field(
None, description="Whether email is verified"
)
name: Optional[str] = Field(None, description="User display name")
updated_at: Optional[int] = Field(None, description="Last profile update timestamp")
# ============================================================
# OIDC Discovery Models
# ============================================================
class OpenIDConfiguration(BaseModel):
"""OIDC Discovery Document."""
issuer: str
authorization_endpoint: str
token_endpoint: str
userinfo_endpoint: str
revocation_endpoint: str
jwks_uri: str
scopes_supported: list[str]
response_types_supported: list[str]
grant_types_supported: list[str]
token_endpoint_auth_methods_supported: list[str]
code_challenge_methods_supported: list[str]
subject_types_supported: list[str]
id_token_signing_alg_values_supported: list[str]
class JWK(BaseModel):
"""JSON Web Key."""
kty: str = Field(..., description="Key type (RSA)")
use: str = Field(default="sig", description="Key use (signature)")
kid: str = Field(..., description="Key ID")
alg: str = Field(default="RS256", description="Algorithm")
n: str = Field(..., description="RSA modulus")
e: str = Field(..., description="RSA exponent")
class JWKS(BaseModel):
"""JSON Web Key Set."""
keys: list[JWK]
# ============================================================
# Client Management Models
# ============================================================
class RegisterClientRequest(BaseModel):
"""Request to register a new OAuth client."""
name: str = Field(..., min_length=1, max_length=100, description="Client name")
description: Optional[str] = Field(
None, max_length=500, description="Client description"
)
logo_url: Optional[HttpUrl] = Field(None, description="Logo URL")
homepage_url: Optional[HttpUrl] = Field(None, description="Homepage URL")
privacy_policy_url: Optional[HttpUrl] = Field(
None, description="Privacy policy URL"
)
terms_of_service_url: Optional[HttpUrl] = Field(
None, description="Terms of service URL"
)
redirect_uris: list[str] = Field(
..., min_length=1, description="Allowed redirect URIs"
)
client_type: Literal["public", "confidential"] = Field(
default="public", description="Client type"
)
webhook_domains: list[str] = Field(
default_factory=list, description="Allowed webhook domains"
)
class UpdateClientRequest(BaseModel):
"""Request to update an OAuth client."""
name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=500)
logo_url: Optional[HttpUrl] = None
homepage_url: Optional[HttpUrl] = None
privacy_policy_url: Optional[HttpUrl] = None
terms_of_service_url: Optional[HttpUrl] = None
redirect_uris: Optional[list[str]] = None
webhook_domains: Optional[list[str]] = None
class ClientResponse(BaseModel):
"""OAuth client response."""
id: str
client_id: str
client_type: str
name: str
description: Optional[str]
logo_url: Optional[str]
homepage_url: Optional[str]
privacy_policy_url: Optional[str]
terms_of_service_url: Optional[str]
redirect_uris: list[str]
allowed_scopes: list[str]
webhook_domains: list[str]
status: str
created_at: datetime
updated_at: datetime
class ClientSecretResponse(BaseModel):
"""Response containing newly generated client credentials."""
client_id: str
client_secret: str = Field(
..., description="Client secret (only shown once, store securely)"
)
webhook_secret: str = Field(
...,
description="Webhook secret for HMAC signing (only shown once, store securely)",
)
# ============================================================
# Token Introspection/Revocation Models
# ============================================================
class TokenRevocationRequest(BaseModel):
"""Token revocation request (RFC 7009)."""
token: str = Field(..., description="Token to revoke")
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Field(
None, description="Hint about token type"
)
class TokenIntrospectionRequest(BaseModel):
"""Token introspection request (RFC 7662)."""
token: str = Field(..., description="Token to introspect")
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Field(
None, description="Hint about token type"
)
class TokenIntrospectionResponse(BaseModel):
"""Token introspection response."""
active: bool = Field(..., description="Whether the token is active")
scope: Optional[str] = Field(None, description="Token scopes")
client_id: Optional[str] = Field(
None, description="Client that token was issued to"
)
username: Optional[str] = Field(None, description="User identifier")
token_type: Optional[str] = Field(None, description="Token type")
exp: Optional[int] = Field(None, description="Expiration timestamp")
iat: Optional[int] = Field(None, description="Issued at timestamp")
sub: Optional[str] = Field(None, description="Subject (user ID)")
aud: Optional[str] = Field(None, description="Audience")
iss: Optional[str] = Field(None, description="Issuer")

View File

@@ -0,0 +1,66 @@
"""
PKCE (Proof Key for Code Exchange) implementation for OAuth 2.0.
RFC 7636: https://tools.ietf.org/html/rfc7636
"""
import base64
import hashlib
import secrets
def generate_code_verifier(length: int = 64) -> str:
"""
Generate a cryptographically random code verifier.
Args:
length: Length of the verifier (43-128 characters, default 64)
Returns:
URL-safe base64 encoded random string
"""
if not 43 <= length <= 128:
raise ValueError("Code verifier length must be between 43 and 128")
return secrets.token_urlsafe(length)[:length]
def generate_code_challenge(verifier: str, method: str = "S256") -> str:
"""
Generate a code challenge from the verifier.
Args:
verifier: The code verifier string
method: Challenge method ("S256" or "plain")
Returns:
The code challenge string
"""
if method == "S256":
digest = hashlib.sha256(verifier.encode("ascii")).digest()
# URL-safe base64 encoding without padding
return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
elif method == "plain":
return verifier
else:
raise ValueError(f"Unsupported code challenge method: {method}")
def verify_code_challenge(
verifier: str,
challenge: str,
method: str = "S256",
) -> bool:
"""
Verify that a code verifier matches the stored challenge.
Args:
verifier: The code verifier from the token request
challenge: The code challenge stored during authorization
method: The challenge method used
Returns:
True if the verifier matches the challenge
"""
expected = generate_code_challenge(verifier, method)
# Use constant-time comparison to prevent timing attacks
return secrets.compare_digest(expected, challenge)

View File

@@ -0,0 +1,860 @@
"""
OAuth 2.0 Authorization Server endpoints.
Implements:
- GET /oauth/authorize - Authorization endpoint
- POST /oauth/authorize/consent - Consent form submission
- POST /oauth/token - Token endpoint
- GET /oauth/userinfo - OIDC UserInfo endpoint
- POST /oauth/revoke - Token revocation endpoint
Authentication:
- X-API-Key header - API key for external apps (preferred)
- Authorization: Bearer <jwt> - JWT token authentication
- access_token cookie - Browser-based auth
"""
import json
import logging
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional
from urllib.parse import urlencode
from fastapi import APIRouter, Form, HTTPException, Query, Request
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from backend.data.db import prisma
from backend.data.redis_client import get_redis_async
from backend.server.oauth.consent_templates import (
render_consent_page,
render_error_page,
render_login_redirect_page,
)
from backend.server.oauth.errors import (
InvalidClientError,
InvalidRequestError,
OAuthError,
UnsupportedGrantTypeError,
)
from backend.server.oauth.models import TokenResponse, UserInfoResponse
from backend.server.oauth.service import get_oauth_service
from backend.server.oauth.token_service import get_token_service
from backend.util.rate_limiter import check_rate_limit
logger = logging.getLogger(__name__)
oauth_router = APIRouter(prefix="/oauth", tags=["oauth"])
# Redis key prefix and TTL for consent state storage
CONSENT_STATE_PREFIX = "oauth:consent:"
CONSENT_STATE_TTL = 600 # 10 minutes
# Redis key prefix and TTL for login redirect state storage
LOGIN_STATE_PREFIX = "oauth:login:"
LOGIN_STATE_TTL = 900 # 15 minutes (longer to allow time for login)
async def _store_login_state(token: str, state: dict) -> None:
"""Store OAuth login state in Redis with TTL."""
redis = await get_redis_async()
await redis.setex(
f"{LOGIN_STATE_PREFIX}{token}",
LOGIN_STATE_TTL,
json.dumps(state, default=str),
)
async def _get_and_delete_login_state(token: str) -> Optional[dict]:
"""Retrieve and delete login state from Redis (one-time use, atomic)."""
redis = await get_redis_async()
key = f"{LOGIN_STATE_PREFIX}{token}"
# Use GETDEL for atomic get+delete to prevent race conditions
state_json = await redis.getdel(key)
if state_json:
return json.loads(state_json)
return None
async def _store_consent_state(token: str, state: dict) -> None:
"""Store consent state in Redis with TTL."""
redis = await get_redis_async()
await redis.setex(
f"{CONSENT_STATE_PREFIX}{token}",
CONSENT_STATE_TTL,
json.dumps(state, default=str),
)
async def _get_and_delete_consent_state(token: str) -> Optional[dict]:
"""Retrieve and delete consent state from Redis (atomic get+delete)."""
redis = await get_redis_async()
key = f"{CONSENT_STATE_PREFIX}{token}"
# Use GETDEL for atomic get+delete to prevent race conditions
state_json = await redis.getdel(key)
if state_json:
return json.loads(state_json)
return None
async def _get_user_id_from_request(
request: Request, strict_bearer: bool = False
) -> Optional[str]:
"""
Extract user ID from request, checking API key, Authorization header, and cookie.
Supports:
1. X-API-Key header - API key authentication (preferred for external apps)
2. Authorization: Bearer <jwt> - JWT token authentication
3. access_token cookie - Cookie-based auth (for browser flows)
Args:
request: The incoming request
strict_bearer: If True and Bearer token is provided but invalid,
do NOT fallthrough to cookie auth (prevents auth downgrade attacks)
"""
from autogpt_libs.auth.jwt_utils import parse_jwt_token
from backend.data.api_key import validate_api_key
# First try X-API-Key header (for external apps)
api_key = request.headers.get("X-API-Key")
if api_key:
try:
api_key_info = await validate_api_key(api_key)
if api_key_info:
return api_key_info.user_id
except Exception:
logger.debug("API key validation failed")
# Then try Authorization header (JWT)
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
try:
token = auth_header[7:]
payload = parse_jwt_token(token)
return payload.get("sub")
except Exception as e:
logger.debug("JWT token validation failed: %s", type(e).__name__)
# Security fix: If Bearer token was provided but invalid,
# don't fallthrough to weaker auth methods when strict_bearer is True
if strict_bearer:
return None
# Finally try cookie (browser-based auth)
token = request.cookies.get("access_token")
if token:
try:
payload = parse_jwt_token(token)
return payload.get("sub")
except Exception as e:
logger.debug("Cookie token validation failed: %s", type(e).__name__)
return None
def _parse_scopes(scope_str: str) -> list[str]:
"""Parse space-separated scope string into list."""
if not scope_str:
return []
return [s.strip() for s in scope_str.split() if s.strip()]
def _get_client_ip(request: Request) -> str:
"""Get client IP address from request."""
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
return request.client.host if request.client else "unknown"
# ================================================================
# Authorization Endpoint
# ================================================================
@oauth_router.get("/authorize", response_model=None)
async def authorize(
request: Request,
response_type: str = Query(..., description="Must be 'code'"),
client_id: str = Query(..., description="Client identifier"),
redirect_uri: str = Query(..., description="Redirect URI"),
state: str = Query(..., description="CSRF state parameter"),
code_challenge: str = Query(..., description="PKCE code challenge"),
code_challenge_method: str = Query("S256", description="PKCE method"),
scope: str = Query("", description="Space-separated scopes"),
nonce: Optional[str] = Query(None, description="OIDC nonce"),
prompt: Optional[str] = Query(None, description="Prompt behavior"),
) -> HTMLResponse | RedirectResponse:
"""
OAuth 2.0 Authorization Endpoint.
Validates the request, checks user authentication, and either:
- Returns error if user is not authenticated (API key or JWT required)
- Shows consent page if user hasn't authorized these scopes
- Redirects with authorization code if already authorized
Authentication methods (in order of preference):
1. X-API-Key header - API key for external apps
2. Authorization: Bearer <jwt> - JWT token
3. access_token cookie - Browser-based auth
"""
# Get user ID from API key, Authorization header, or cookie
user_id = await _get_user_id_from_request(request)
# Rate limiting - use client IP as identifier for authorize endpoint
client_ip = _get_client_ip(request)
rate_result = await check_rate_limit(client_ip, "oauth_authorize")
if not rate_result.allowed:
return HTMLResponse(
render_error_page(
"rate_limit_exceeded",
"Too many authorization requests. Please try again later.",
),
status_code=429,
)
oauth_service = get_oauth_service()
try:
# Validate response_type
if response_type != "code":
raise InvalidRequestError(
"Only 'code' response_type is supported", state=state
)
# Validate PKCE method
if code_challenge_method != "S256":
raise InvalidRequestError(
"Only 'S256' code_challenge_method is supported", state=state
)
# Parse scopes
scopes = _parse_scopes(scope)
# Validate client and redirect URI
client = await oauth_service.validate_client(client_id, redirect_uri, scopes)
# Check if user is authenticated
if not user_id:
# User needs to log in - store OAuth params and redirect to frontend login
from backend.util.settings import Settings
settings = Settings()
login_token = secrets.token_urlsafe(32)
logger.info(f"Storing login state with token: {login_token}")
await _store_login_state(
login_token,
{
"client_id": client_id,
"redirect_uri": redirect_uri,
"scopes": scopes,
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"nonce": nonce,
"prompt": prompt,
"created_at": datetime.now(timezone.utc).isoformat(),
"expires_at": (
datetime.now(timezone.utc) + timedelta(seconds=LOGIN_STATE_TTL)
).isoformat(),
},
)
logger.info(f"Login state stored successfully for token: {login_token}")
# Build redirect URL to frontend login
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
return _add_security_headers(
HTMLResponse(
render_error_page(
"server_error", "Frontend URL not configured"
),
status_code=500,
)
)
# Redirect to frontend login with oauth_session parameter
login_url = f"{frontend_base_url}/login?oauth_session={login_token}"
return _add_security_headers(
HTMLResponse(render_login_redirect_page(login_url))
)
# Check if user has already authorized these scopes
if prompt != "consent":
has_auth = await oauth_service.has_valid_authorization(
user_id, client_id, scopes
)
if has_auth:
# Skip consent, issue code directly
code = await oauth_service.create_authorization_code(
user_id=user_id,
client_id=client_id,
redirect_uri=redirect_uri,
scopes=scopes,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
nonce=nonce,
)
redirect_url = (
f"{redirect_uri}?{urlencode({'code': code, 'state': state})}"
)
return RedirectResponse(url=redirect_url, status_code=302)
# Generate consent token and store state in Redis
consent_token = secrets.token_urlsafe(32)
await _store_consent_state(
consent_token,
{
"user_id": user_id,
"client_id": client_id,
"redirect_uri": redirect_uri,
"scopes": scopes,
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"nonce": nonce,
"expires_at": (
datetime.now(timezone.utc) + timedelta(minutes=10)
).isoformat(),
},
)
# Render consent page
return _add_security_headers(
HTMLResponse(
render_consent_page(
client_name=client.name,
client_logo=client.logoUrl,
scopes=scopes,
consent_token=consent_token,
action_url="/oauth/authorize/consent",
privacy_policy_url=client.privacyPolicyUrl,
terms_url=client.termsOfServiceUrl,
)
)
)
except OAuthError as e:
# If we have a valid redirect_uri, redirect with error
# Otherwise show error page
try:
client = await oauth_service.get_client(client_id)
if client and redirect_uri in client.redirectUris:
return e.to_redirect(redirect_uri)
except Exception:
pass
return _add_security_headers(
HTMLResponse(
render_error_page(e.error.value, e.description or "An error occurred"),
status_code=400,
)
)
@oauth_router.post("/authorize/consent", response_model=None)
async def submit_consent(
request: Request,
consent_token: str = Form(...),
authorize: str = Form(...),
) -> HTMLResponse | RedirectResponse:
"""
Process consent form submission.
Creates authorization code and redirects to client's redirect_uri.
"""
# Rate limiting on consent submission to prevent brute force attacks
client_ip = _get_client_ip(request)
rate_result = await check_rate_limit(client_ip, "oauth_consent")
if not rate_result.allowed:
return _add_security_headers(
HTMLResponse(
render_error_page(
"rate_limit_exceeded",
"Too many consent requests. Please try again later.",
),
status_code=429,
)
)
oauth_service = get_oauth_service()
# Validate consent token (retrieves and deletes from Redis atomically)
consent_state = await _get_and_delete_consent_state(consent_token)
if not consent_state:
return HTMLResponse(
render_error_page("invalid_request", "Invalid or expired consent token"),
status_code=400,
)
# Check expiration (expires_at is stored as ISO string in Redis)
expires_at = datetime.fromisoformat(consent_state["expires_at"])
if expires_at < datetime.now(timezone.utc):
return HTMLResponse(
render_error_page("invalid_request", "Consent session expired"),
status_code=400,
)
redirect_uri = consent_state["redirect_uri"]
state = consent_state["state"]
# Check if user denied
if authorize.lower() != "true":
error_params = urlencode(
{
"error": "access_denied",
"error_description": "User denied the authorization request",
"state": state,
}
)
return RedirectResponse(
url=f"{redirect_uri}?{error_params}",
status_code=302,
)
try:
# Create authorization code
code = await oauth_service.create_authorization_code(
user_id=consent_state["user_id"],
client_id=consent_state["client_id"],
redirect_uri=redirect_uri,
scopes=consent_state["scopes"],
code_challenge=consent_state["code_challenge"],
code_challenge_method=consent_state["code_challenge_method"],
nonce=consent_state["nonce"],
)
# Redirect with code
return RedirectResponse(
url=f"{redirect_uri}?{urlencode({'code': code, 'state': state})}",
status_code=302,
)
except OAuthError as e:
return e.to_redirect(redirect_uri)
def _wants_json(request: Request) -> bool:
"""Check if client prefers JSON response (for frontend fetch calls)."""
accept = request.headers.get("Accept", "")
return "application/json" in accept
def _add_security_headers(response: HTMLResponse) -> HTMLResponse:
"""Add security headers to OAuth HTML responses."""
response.headers["X-Frame-Options"] = "DENY"
response.headers["Content-Security-Policy"] = "frame-ancestors 'none'"
response.headers["X-Content-Type-Options"] = "nosniff"
return response
@oauth_router.get("/authorize/resume", response_model=None)
async def resume_authorization(
request: Request,
session_id: str = Query(..., description="OAuth login session ID"),
) -> HTMLResponse | RedirectResponse | JSONResponse:
"""
Resume OAuth authorization after user login.
This endpoint is called after the user completes login on the frontend.
It retrieves the stored OAuth parameters and continues the authorization flow.
Supports Accept: application/json header to return JSON for frontend fetch calls,
solving CORS issues with redirect responses.
"""
wants_json = _wants_json(request)
# Rate limiting - use client IP
client_ip = _get_client_ip(request)
rate_result = await check_rate_limit(client_ip, "oauth_authorize")
if not rate_result.allowed:
if wants_json:
return JSONResponse(
{
"error": "rate_limit_exceeded",
"error_description": "Too many requests",
},
status_code=429,
)
return _add_security_headers(
HTMLResponse(
render_error_page(
"rate_limit_exceeded",
"Too many authorization requests. Please try again later.",
),
status_code=429,
)
)
# Verify user is now authenticated (use strict_bearer to prevent auth downgrade)
user_id = await _get_user_id_from_request(request, strict_bearer=True)
if not user_id:
from backend.util.settings import Settings
frontend_url = Settings().config.frontend_base_url or "http://localhost:3000"
if wants_json:
return JSONResponse(
{
"error": "login_required",
"error_description": "Authentication required",
"redirect_url": f"{frontend_url}/login",
},
status_code=401,
)
return _add_security_headers(
HTMLResponse(
render_error_page(
"login_required",
"Authentication required. Please log in and try again.",
redirect_url=f"{frontend_url}/login",
),
status_code=401,
)
)
# Retrieve and delete login state (one-time use)
logger.info(f"Attempting to retrieve login state for session_id: {session_id}")
login_state = await _get_and_delete_login_state(session_id)
if not login_state:
logger.warning(f"Login state not found for session_id: {session_id}")
if wants_json:
return JSONResponse(
{
"error": "invalid_request",
"error_description": "Invalid or expired authorization session",
},
status_code=400,
)
return _add_security_headers(
HTMLResponse(
render_error_page(
"invalid_request",
"Invalid or expired authorization session. Please start over.",
),
status_code=400,
)
)
# Check expiration
expires_at = datetime.fromisoformat(login_state["expires_at"])
if expires_at < datetime.now(timezone.utc):
if wants_json:
return JSONResponse(
{
"error": "invalid_request",
"error_description": "Authorization session has expired",
},
status_code=400,
)
return _add_security_headers(
HTMLResponse(
render_error_page(
"invalid_request",
"Authorization session has expired. Please start over.",
),
status_code=400,
)
)
# Extract stored OAuth parameters
client_id = login_state["client_id"]
redirect_uri = login_state["redirect_uri"]
scopes = login_state["scopes"]
state = login_state["state"]
code_challenge = login_state["code_challenge"]
code_challenge_method = login_state["code_challenge_method"]
nonce = login_state.get("nonce")
prompt = login_state.get("prompt")
oauth_service = get_oauth_service()
try:
# Re-validate client (in case it was deactivated during login)
client = await oauth_service.validate_client(client_id, redirect_uri, scopes)
# Check if user has already authorized these scopes (skip consent if yes)
if prompt != "consent":
has_auth = await oauth_service.has_valid_authorization(
user_id, client_id, scopes
)
if has_auth:
# Skip consent, issue code directly
code = await oauth_service.create_authorization_code(
user_id=user_id,
client_id=client_id,
redirect_uri=redirect_uri,
scopes=scopes,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
nonce=nonce,
)
redirect_url = (
f"{redirect_uri}?{urlencode({'code': code, 'state': state})}"
)
# Return JSON with redirect URL for frontend to handle
if wants_json:
return JSONResponse(
{"redirect_url": redirect_url, "needs_consent": False}
)
return RedirectResponse(url=redirect_url, status_code=302)
# Generate consent token and store state in Redis
consent_token = secrets.token_urlsafe(32)
await _store_consent_state(
consent_token,
{
"user_id": user_id,
"client_id": client_id,
"redirect_uri": redirect_uri,
"scopes": scopes,
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"nonce": nonce,
"expires_at": (
datetime.now(timezone.utc) + timedelta(minutes=10)
).isoformat(),
},
)
# For JSON requests, return consent data instead of HTML
if wants_json:
from backend.server.oauth.models import SCOPE_DESCRIPTIONS
scope_details = [
{"scope": s, "description": SCOPE_DESCRIPTIONS.get(s, s)}
for s in scopes
]
return JSONResponse(
{
"needs_consent": True,
"consent_token": consent_token,
"client": {
"name": client.name,
"logo_url": client.logoUrl,
"privacy_policy_url": client.privacyPolicyUrl,
"terms_url": client.termsOfServiceUrl,
},
"scopes": scope_details,
"action_url": "/oauth/authorize/consent",
}
)
# Render consent page (HTML response)
return _add_security_headers(
HTMLResponse(
render_consent_page(
client_name=client.name,
client_logo=client.logoUrl,
scopes=scopes,
consent_token=consent_token,
action_url="/oauth/authorize/consent",
privacy_policy_url=client.privacyPolicyUrl,
terms_url=client.termsOfServiceUrl,
)
)
)
except OAuthError as e:
if wants_json:
return JSONResponse(
{"error": e.error.value, "error_description": e.description},
status_code=400,
)
# If we have a valid redirect_uri, redirect with error
try:
client = await oauth_service.get_client(client_id)
if client and redirect_uri in client.redirectUris:
return e.to_redirect(redirect_uri)
except Exception:
pass
return _add_security_headers(
HTMLResponse(
render_error_page(e.error.value, e.description or "An error occurred"),
status_code=400,
)
)
# ================================================================
# Token Endpoint
# ================================================================
@oauth_router.post("/token", response_model=TokenResponse)
async def token(
request: Request,
grant_type: str = Form(...),
code: Optional[str] = Form(None),
redirect_uri: Optional[str] = Form(None),
client_id: str = Form(...),
client_secret: Optional[str] = Form(None),
code_verifier: Optional[str] = Form(None),
refresh_token: Optional[str] = Form(None),
scope: Optional[str] = Form(None),
) -> TokenResponse:
"""
OAuth 2.0 Token Endpoint.
Supports:
- authorization_code grant (with PKCE)
- refresh_token grant
"""
# Rate limiting - use client_id as identifier
rate_result = await check_rate_limit(client_id, "oauth_token")
if not rate_result.allowed:
raise HTTPException(
status_code=429,
detail="Rate limit exceeded",
headers={
"Retry-After": str(int(rate_result.retry_after or 60)),
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(int(rate_result.reset_at)),
},
)
oauth_service = get_oauth_service()
try:
# Validate client authentication
await oauth_service.validate_client_secret(client_id, client_secret)
if grant_type == "authorization_code":
# Validate required parameters
if not code:
raise InvalidRequestError("'code' is required")
if not redirect_uri:
raise InvalidRequestError("'redirect_uri' is required")
if not code_verifier:
raise InvalidRequestError("'code_verifier' is required for PKCE")
return await oauth_service.exchange_authorization_code(
code=code,
client_id=client_id,
redirect_uri=redirect_uri,
code_verifier=code_verifier,
)
elif grant_type == "refresh_token":
if not refresh_token:
raise InvalidRequestError("'refresh_token' is required")
requested_scopes = _parse_scopes(scope) if scope else None
return await oauth_service.refresh_access_token(
refresh_token=refresh_token,
client_id=client_id,
requested_scopes=requested_scopes,
)
else:
raise UnsupportedGrantTypeError(grant_type)
except OAuthError as e:
# 401 for client auth failure, 400 for other validation errors (per RFC 6749)
raise e.to_http_exception(401 if isinstance(e, InvalidClientError) else 400)
# ================================================================
# UserInfo Endpoint
# ================================================================
@oauth_router.get("/userinfo", response_model=UserInfoResponse)
async def userinfo(request: Request) -> UserInfoResponse:
"""
OIDC UserInfo Endpoint.
Returns user profile information based on the granted scopes.
"""
token_service = get_token_service()
# Extract bearer token
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
raise HTTPException(
status_code=401,
detail="Bearer token required",
headers={"WWW-Authenticate": "Bearer"},
)
token = auth_header[7:]
try:
# Verify token
claims = token_service.verify_access_token(token)
# Check token is not revoked
token_hash = token_service.hash_token(token)
stored_token = await prisma.oauthaccesstoken.find_unique(
where={"tokenHash": token_hash}
)
if not stored_token or stored_token.revokedAt:
raise HTTPException(
status_code=401,
detail="Token has been revoked",
headers={"WWW-Authenticate": "Bearer"},
)
# Update last used
await prisma.oauthaccesstoken.update(
where={"id": stored_token.id},
data={"lastUsedAt": datetime.now(timezone.utc)},
)
# Get user info based on scopes
user = await prisma.user.find_unique(where={"id": claims.sub})
if not user:
raise HTTPException(status_code=404, detail="User not found")
scopes = claims.scope.split()
# Build response based on scopes
email = user.email if "email" in scopes else None
email_verified = user.emailVerified if "email" in scopes else None
name = user.name if "profile" in scopes else None
updated_at = int(user.updatedAt.timestamp()) if "profile" in scopes else None
return UserInfoResponse(
sub=claims.sub,
email=email,
email_verified=email_verified,
name=name,
updated_at=updated_at,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=401,
detail=f"Invalid token: {str(e)}",
headers={"WWW-Authenticate": "Bearer"},
)
# ================================================================
# Token Revocation Endpoint
# ================================================================
@oauth_router.post("/revoke")
async def revoke(
request: Request,
token: str = Form(...),
token_type_hint: Optional[str] = Form(None),
) -> JSONResponse:
"""
OAuth 2.0 Token Revocation Endpoint (RFC 7009).
Revokes an access token or refresh token.
"""
oauth_service = get_oauth_service()
# Note: Per RFC 7009, always return 200 even if token not found
await oauth_service.revoke_token(token, token_type_hint)
return JSONResponse(content={}, status_code=200)

View File

@@ -0,0 +1,625 @@
"""
Core OAuth 2.0 service logic.
Handles:
- Client validation and lookup
- Authorization code generation and exchange
- Token issuance and refresh
- User consent management
- Audit logging
"""
import hashlib
import json
import secrets
from datetime import datetime, timedelta, timezone
from typing import Any, Optional
from prisma.enums import OAuthClientStatus
from prisma.models import OAuthAuthorization, OAuthClient, User
from backend.data.db import prisma
from backend.server.oauth.errors import (
InvalidClientError,
InvalidGrantError,
InvalidRequestError,
InvalidScopeError,
)
from backend.server.oauth.models import TokenResponse
from backend.server.oauth.pkce import verify_code_challenge
from backend.server.oauth.token_service import OAuthTokenService, get_token_service
class OAuthService:
"""Core OAuth 2.0 service."""
def __init__(self, token_service: Optional[OAuthTokenService] = None):
self.token_service = token_service or get_token_service()
# ================================================================
# Client Operations
# ================================================================
async def get_client(self, client_id: str) -> Optional[OAuthClient]:
"""Get an OAuth client by client_id."""
return await prisma.oauthclient.find_unique(where={"clientId": client_id})
async def validate_client(
self,
client_id: str,
redirect_uri: str,
scopes: list[str],
) -> OAuthClient:
"""
Validate a client for authorization.
Args:
client_id: Client identifier
redirect_uri: Requested redirect URI
scopes: Requested scopes
Returns:
Validated OAuthClient
Raises:
InvalidClientError: Client not found or inactive
InvalidRequestError: Invalid redirect URI
InvalidScopeError: Invalid scopes requested
"""
client = await self.get_client(client_id)
if not client:
raise InvalidClientError(f"Client '{client_id}' not found")
if client.status != OAuthClientStatus.ACTIVE:
raise InvalidClientError(f"Client '{client_id}' is not active")
# Validate redirect URI (exact match required)
if redirect_uri not in client.redirectUris:
raise InvalidRequestError(
f"Redirect URI '{redirect_uri}' is not registered for this client"
)
# Validate scopes
invalid_scopes = set(scopes) - set(client.allowedScopes)
if invalid_scopes:
raise InvalidScopeError(
f"Scopes not allowed for this client: {', '.join(invalid_scopes)}"
)
return client
async def validate_client_secret(
self,
client_id: str,
client_secret: Optional[str],
) -> OAuthClient:
"""
Validate client authentication for token endpoint.
Args:
client_id: Client identifier
client_secret: Client secret (for confidential clients)
Returns:
Validated OAuthClient
Raises:
InvalidClientError: Invalid client or credentials
"""
client = await self.get_client(client_id)
if not client:
raise InvalidClientError(f"Client '{client_id}' not found")
if client.status != OAuthClientStatus.ACTIVE:
raise InvalidClientError(f"Client '{client_id}' is not active")
# Confidential clients must provide secret
if client.clientType == "confidential":
if not client_secret:
raise InvalidClientError("Client secret required")
# Hash and compare
secret_hash = self._hash_secret(
client_secret, client.clientSecretSalt or ""
)
if not secrets.compare_digest(secret_hash, client.clientSecretHash or ""):
raise InvalidClientError("Invalid client credentials")
return client
@staticmethod
def _hash_secret(secret: str, salt: str) -> str:
"""Hash a client secret with salt."""
return hashlib.sha256(f"{salt}{secret}".encode()).hexdigest()
# ================================================================
# Authorization Code Operations
# ================================================================
async def create_authorization_code(
self,
user_id: str,
client_id: str,
redirect_uri: str,
scopes: list[str],
code_challenge: str,
code_challenge_method: str = "S256",
nonce: Optional[str] = None,
) -> str:
"""
Create a new authorization code.
Args:
user_id: User who authorized
client_id: Client being authorized
redirect_uri: Redirect URI for callback
scopes: Granted scopes
code_challenge: PKCE code challenge
code_challenge_method: PKCE method (S256)
nonce: OIDC nonce (optional)
Returns:
Authorization code string
"""
code = secrets.token_urlsafe(32)
code_hash = self.token_service.hash_token(code)
# Get the OAuthClient to link
client = await self.get_client(client_id)
if not client:
raise InvalidClientError(f"Client '{client_id}' not found")
await prisma.oauthauthorizationcode.create(
data={ # type: ignore[typeddict-item]
"codeHash": code_hash,
"userId": user_id,
"clientId": client.id,
"redirectUri": redirect_uri,
"scopes": scopes,
"codeChallenge": code_challenge,
"codeChallengeMethod": code_challenge_method,
"nonce": nonce,
"expiresAt": datetime.now(timezone.utc) + timedelta(minutes=10),
}
)
return code
async def exchange_authorization_code(
self,
code: str,
client_id: str,
redirect_uri: str,
code_verifier: str,
) -> TokenResponse:
"""
Exchange an authorization code for tokens.
Args:
code: Authorization code
client_id: Client identifier
redirect_uri: Must match original redirect URI
code_verifier: PKCE code verifier
Returns:
TokenResponse with access token, refresh token, etc.
Raises:
InvalidGrantError: Invalid or expired code
InvalidRequestError: PKCE verification failed
"""
code_hash = self.token_service.hash_token(code)
# Find the authorization code
auth_code = await prisma.oauthauthorizationcode.find_unique(
where={"codeHash": code_hash},
include={"Client": True, "User": True},
)
if not auth_code:
raise InvalidGrantError("Authorization code not found")
# Ensure Client relation is loaded
if not auth_code.Client:
raise InvalidGrantError("Authorization code client not found")
# Check if already used
if auth_code.usedAt:
# Code reuse is a security incident - revoke all tokens for this authorization
await self._revoke_tokens_for_client_user(
auth_code.Client.clientId, auth_code.userId
)
raise InvalidGrantError("Authorization code has already been used")
# Check expiration
if auth_code.expiresAt < datetime.now(timezone.utc):
raise InvalidGrantError("Authorization code has expired")
# Validate client
if auth_code.Client.clientId != client_id:
raise InvalidGrantError("Client ID mismatch")
# Validate redirect URI
if auth_code.redirectUri != redirect_uri:
raise InvalidGrantError("Redirect URI mismatch")
# Verify PKCE
if not verify_code_challenge(
code_verifier, auth_code.codeChallenge, auth_code.codeChallengeMethod
):
raise InvalidRequestError("PKCE verification failed")
# Mark code as used
await prisma.oauthauthorizationcode.update(
where={"id": auth_code.id},
data={"usedAt": datetime.now(timezone.utc)},
)
# Create or update authorization record
await self._upsert_authorization(
auth_code.userId, auth_code.Client.id, auth_code.scopes
)
# Generate tokens
return await self._create_tokens(
user_id=auth_code.userId,
client=auth_code.Client,
scopes=auth_code.scopes,
nonce=auth_code.nonce,
user=auth_code.User,
)
async def refresh_access_token(
self,
refresh_token: str,
client_id: str,
requested_scopes: Optional[list[str]] = None,
) -> TokenResponse:
"""
Refresh an access token using a refresh token.
Args:
refresh_token: Refresh token string
client_id: Client identifier
requested_scopes: Optionally request fewer scopes
Returns:
New TokenResponse
Raises:
InvalidGrantError: Invalid or expired refresh token
"""
token_hash = self.token_service.hash_token(refresh_token)
# Find the refresh token
stored_token = await prisma.oauthrefreshtoken.find_unique(
where={"tokenHash": token_hash},
include={"Client": True, "User": True},
)
if not stored_token:
raise InvalidGrantError("Refresh token not found")
# Ensure Client relation is loaded
if not stored_token.Client:
raise InvalidGrantError("Refresh token client not found")
# Check if revoked
if stored_token.revokedAt:
raise InvalidGrantError("Refresh token has been revoked")
# Check expiration
if stored_token.expiresAt < datetime.now(timezone.utc):
raise InvalidGrantError("Refresh token has expired")
# Validate client
if stored_token.Client.clientId != client_id:
raise InvalidGrantError("Client ID mismatch")
# Determine scopes
scopes = stored_token.scopes
if requested_scopes:
# Can only request a subset of original scopes
invalid = set(requested_scopes) - set(stored_token.scopes)
if invalid:
raise InvalidScopeError(
f"Cannot request scopes not in original grant: {', '.join(invalid)}"
)
scopes = requested_scopes
# Generate new tokens (rotates refresh token)
return await self._create_tokens(
user_id=stored_token.userId,
client=stored_token.Client,
scopes=scopes,
user=stored_token.User,
old_refresh_token_id=stored_token.id,
)
# ================================================================
# Token Operations
# ================================================================
async def _create_tokens(
self,
user_id: str,
client: OAuthClient,
scopes: list[str],
user: Optional[User] = None,
nonce: Optional[str] = None,
old_refresh_token_id: Optional[str] = None,
) -> TokenResponse:
"""
Create access and refresh tokens.
Args:
user_id: User ID
client: OAuth client
scopes: Granted scopes
user: User object (for ID token claims)
nonce: OIDC nonce
old_refresh_token_id: ID of refresh token being rotated
Returns:
TokenResponse
"""
# Generate access token
access_token, access_expires_at = self.token_service.generate_access_token(
user_id=user_id,
client_id=client.clientId,
scopes=scopes,
expires_in=client.tokenLifetimeSecs,
)
# Store access token hash
await prisma.oauthaccesstoken.create(
data={ # type: ignore[typeddict-item]
"tokenHash": self.token_service.hash_token(access_token),
"userId": user_id,
"clientId": client.id,
"scopes": scopes,
"expiresAt": access_expires_at,
}
)
# Generate refresh token
refresh_token = self.token_service.generate_refresh_token()
refresh_expires_at = datetime.now(timezone.utc) + timedelta(
seconds=client.refreshTokenLifetimeSecs
)
await prisma.oauthrefreshtoken.create(
data={ # type: ignore[typeddict-item]
"tokenHash": self.token_service.hash_token(refresh_token),
"userId": user_id,
"clientId": client.id,
"scopes": scopes,
"expiresAt": refresh_expires_at,
}
)
# Revoke old refresh token if rotating
if old_refresh_token_id:
await prisma.oauthrefreshtoken.update(
where={"id": old_refresh_token_id},
data={"revokedAt": datetime.now(timezone.utc)},
)
# Generate ID token if openid scope requested
id_token = None
if "openid" in scopes and user:
email = user.email if "email" in scopes else None
name = user.name if "profile" in scopes else None
id_token = self.token_service.generate_id_token(
user_id=user_id,
client_id=client.clientId,
email=email,
name=name,
nonce=nonce,
)
# Audit log
await self._audit_log(
event_type="token.issued",
user_id=user_id,
client_id=client.clientId,
details={"scopes": scopes},
)
return TokenResponse(
access_token=access_token,
token_type="Bearer",
expires_in=client.tokenLifetimeSecs,
refresh_token=refresh_token,
scope=" ".join(scopes),
id_token=id_token,
)
async def revoke_token(
self,
token: str,
token_type_hint: Optional[str] = None,
) -> bool:
"""
Revoke an access or refresh token.
Args:
token: Token to revoke
token_type_hint: Hint about token type
Returns:
True if token was found and revoked
"""
token_hash = self.token_service.hash_token(token)
now = datetime.now(timezone.utc)
# Try refresh token first if hinted or no hint
if token_type_hint in (None, "refresh_token"):
result = await prisma.oauthrefreshtoken.update_many(
where={"tokenHash": token_hash, "revokedAt": None},
data={"revokedAt": now},
)
if result > 0:
return True
# Try access token
if token_type_hint in (None, "access_token"):
result = await prisma.oauthaccesstoken.update_many(
where={"tokenHash": token_hash, "revokedAt": None},
data={"revokedAt": now},
)
if result > 0:
return True
return False
async def _revoke_tokens_for_client_user(
self,
client_id: str,
user_id: str,
) -> None:
"""Revoke all tokens for a client-user pair (security incident response)."""
client = await self.get_client(client_id)
if not client:
return
now = datetime.now(timezone.utc)
await prisma.oauthaccesstoken.update_many(
where={"clientId": client.id, "userId": user_id, "revokedAt": None},
data={"revokedAt": now},
)
await prisma.oauthrefreshtoken.update_many(
where={"clientId": client.id, "userId": user_id, "revokedAt": None},
data={"revokedAt": now},
)
await self._audit_log(
event_type="tokens.revoked.security",
user_id=user_id,
client_id=client_id,
details={"reason": "authorization_code_reuse"},
)
# ================================================================
# Authorization (Consent) Operations
# ================================================================
async def get_authorization(
self,
user_id: str,
client_id: str,
) -> Optional[OAuthAuthorization]:
"""Get existing authorization for user-client pair."""
client = await self.get_client(client_id)
if not client:
return None
return await prisma.oauthauthorization.find_unique(
where={
"userId_clientId": {
"userId": user_id,
"clientId": client.id,
}
}
)
async def has_valid_authorization(
self,
user_id: str,
client_id: str,
scopes: list[str],
) -> bool:
"""
Check if user has already authorized these scopes for this client.
Args:
user_id: User ID
client_id: Client identifier
scopes: Requested scopes
Returns:
True if user has already authorized all requested scopes
"""
auth = await self.get_authorization(user_id, client_id)
if not auth or auth.revokedAt:
return False
# Check if all requested scopes are already authorized
return set(scopes).issubset(set(auth.scopes))
async def _upsert_authorization(
self,
user_id: str,
client_db_id: str,
scopes: list[str],
) -> None:
"""Create or update an authorization record."""
existing = await prisma.oauthauthorization.find_unique(
where={
"userId_clientId": {
"userId": user_id,
"clientId": client_db_id,
}
}
)
if existing:
# Merge scopes
merged_scopes = list(set(existing.scopes) | set(scopes))
await prisma.oauthauthorization.update(
where={"id": existing.id},
data={"scopes": merged_scopes, "revokedAt": None},
)
else:
await prisma.oauthauthorization.create(
data={ # type: ignore[typeddict-item]
"userId": user_id,
"clientId": client_db_id,
"scopes": scopes,
}
)
# ================================================================
# Audit Logging
# ================================================================
async def _audit_log(
self,
event_type: str,
user_id: Optional[str] = None,
client_id: Optional[str] = None,
grant_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
details: Optional[dict[str, Any]] = None,
) -> None:
"""Create an audit log entry."""
# Convert details to JSON for Prisma's Json field
details_json = json.dumps(details or {})
await prisma.oauthauditlog.create(
data={
"eventType": event_type,
"userId": user_id,
"clientId": client_id,
"grantId": grant_id,
"ipAddress": ip_address,
"userAgent": user_agent,
"details": json.loads(details_json), # type: ignore[arg-type]
}
)
# Module-level singleton
_oauth_service: Optional[OAuthService] = None
def get_oauth_service() -> OAuthService:
"""Get the singleton OAuth service instance."""
global _oauth_service
if _oauth_service is None:
_oauth_service = OAuthService()
return _oauth_service

View File

@@ -0,0 +1,298 @@
"""
JWT Token Service for OAuth 2.0 Provider.
Handles generation and validation of:
- Access tokens (JWT)
- Refresh tokens (opaque)
- ID tokens (JWT, OIDC)
"""
import base64
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional
import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import (
RSAPrivateKey,
RSAPublicKey,
generate_private_key,
)
from pydantic import BaseModel
from backend.util.settings import Settings
class TokenClaims(BaseModel):
"""Decoded token claims."""
iss: str # Issuer
sub: str # Subject (user ID)
aud: str # Audience (client ID)
exp: int # Expiration timestamp
iat: int # Issued at timestamp
jti: str # JWT ID
scope: str # Space-separated scopes
client_id: str # Client ID
class OAuthTokenService:
"""
Service for generating and validating OAuth tokens.
Uses RS256 (RSA with SHA-256) for JWT signing.
"""
def __init__(self, settings: Optional[Settings] = None):
self._settings = settings or Settings()
self._private_key: Optional[RSAPrivateKey] = None
self._public_key: Optional[RSAPublicKey] = None
self._algorithm = "RS256"
@property
def issuer(self) -> str:
"""Get the token issuer URL."""
return self._settings.config.platform_base_url or "https://platform.agpt.co"
@property
def key_id(self) -> str:
"""Get the key ID for JWKS."""
return self._settings.secrets.oauth_jwt_key_id or "default-key-id"
def _get_private_key(self) -> RSAPrivateKey:
"""Load or generate the private key."""
if self._private_key is not None:
return self._private_key
key_pem = self._settings.secrets.oauth_jwt_private_key
if key_pem:
loaded_key = serialization.load_pem_private_key(
key_pem.encode(), password=None
)
if not isinstance(loaded_key, RSAPrivateKey):
raise ValueError("OAuth JWT private key must be RSA")
self._private_key = loaded_key
else:
# Generate a key for development (should not be used in production)
self._private_key = generate_private_key(
public_exponent=65537,
key_size=2048,
)
return self._private_key
def _get_public_key(self) -> RSAPublicKey:
"""Get the public key from the private key."""
if self._public_key is not None:
return self._public_key
key_pem = self._settings.secrets.oauth_jwt_public_key
if key_pem:
loaded_key = serialization.load_pem_public_key(key_pem.encode())
if not isinstance(loaded_key, RSAPublicKey):
raise ValueError("OAuth JWT public key must be RSA")
self._public_key = loaded_key
else:
self._public_key = self._get_private_key().public_key()
return self._public_key
def generate_access_token(
self,
user_id: str,
client_id: str,
scopes: list[str],
expires_in: int = 3600,
) -> tuple[str, datetime]:
"""
Generate a JWT access token.
Args:
user_id: User ID (subject)
client_id: Client ID (audience)
scopes: List of granted scopes
expires_in: Token lifetime in seconds
Returns:
Tuple of (token string, expiration datetime)
"""
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=expires_in)
payload = {
"iss": self.issuer,
"sub": user_id,
"aud": client_id,
"exp": int(expires_at.timestamp()),
"iat": int(now.timestamp()),
"jti": secrets.token_urlsafe(16),
"scope": " ".join(scopes),
"client_id": client_id,
}
token = jwt.encode(
payload,
self._get_private_key(),
algorithm=self._algorithm,
headers={"kid": self.key_id},
)
return token, expires_at
def generate_refresh_token(self) -> str:
"""
Generate an opaque refresh token.
Returns:
URL-safe random token string
"""
return secrets.token_urlsafe(48)
def generate_id_token(
self,
user_id: str,
client_id: str,
email: Optional[str] = None,
name: Optional[str] = None,
nonce: Optional[str] = None,
expires_in: int = 3600,
) -> str:
"""
Generate an OIDC ID token.
Args:
user_id: User ID (subject)
client_id: Client ID (audience)
email: User's email (optional)
name: User's name (optional)
nonce: OIDC nonce for replay protection (optional)
expires_in: Token lifetime in seconds
Returns:
JWT ID token string
"""
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=expires_in)
payload = {
"iss": self.issuer,
"sub": user_id,
"aud": client_id,
"exp": int(expires_at.timestamp()),
"iat": int(now.timestamp()),
"auth_time": int(now.timestamp()),
}
if email:
payload["email"] = email
payload["email_verified"] = True
if name:
payload["name"] = name
if nonce:
payload["nonce"] = nonce
return jwt.encode(
payload,
self._get_private_key(),
algorithm=self._algorithm,
headers={"kid": self.key_id},
)
def verify_access_token(
self,
token: str,
expected_client_id: Optional[str] = None,
) -> TokenClaims:
"""
Verify and decode a JWT access token.
Args:
token: JWT token string
expected_client_id: Expected client ID (audience)
Returns:
Decoded token claims
Raises:
jwt.ExpiredSignatureError: Token has expired
jwt.InvalidTokenError: Token is invalid
"""
options = {}
if expected_client_id:
options["audience"] = expected_client_id
payload = jwt.decode(
token,
self._get_public_key(),
algorithms=[self._algorithm],
issuer=self.issuer,
options={"verify_aud": bool(expected_client_id)},
**options,
)
return TokenClaims(
iss=payload["iss"],
sub=payload["sub"],
aud=payload.get("aud", payload.get("client_id", "")),
exp=payload["exp"],
iat=payload["iat"],
jti=payload["jti"],
scope=payload.get("scope", ""),
client_id=payload.get("client_id", payload.get("aud", "")),
)
@staticmethod
def hash_token(token: str) -> str:
"""
Hash a token for secure storage.
Args:
token: Token string to hash
Returns:
SHA-256 hash of the token
"""
return hashlib.sha256(token.encode()).hexdigest()
def get_jwks(self) -> dict:
"""
Get the JSON Web Key Set (JWKS) for public key distribution.
Returns:
JWKS dictionary with public key(s)
"""
public_key = self._get_public_key()
public_numbers = public_key.public_numbers()
# Convert to base64url encoding without padding
def int_to_base64url(n: int, length: int) -> str:
data = n.to_bytes(length, byteorder="big")
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
# RSA modulus and exponent
n = int_to_base64url(public_numbers.n, (public_numbers.n.bit_length() + 7) // 8)
e = int_to_base64url(public_numbers.e, 3)
return {
"keys": [
{
"kty": "RSA",
"use": "sig",
"kid": self.key_id,
"alg": self._algorithm,
"n": n,
"e": e,
}
]
}
# Module-level singleton
_token_service: Optional[OAuthTokenService] = None
def get_token_service() -> OAuthTokenService:
"""Get the singleton token service instance."""
global _token_service
if _token_service is None:
_token_service = OAuthTokenService()
return _token_service

View File

@@ -21,6 +21,7 @@ import backend.data.db
import backend.data.graph
import backend.data.user
import backend.integrations.webhooks.utils
import backend.server.integrations.connect_router
import backend.server.routers.postmark.postmark
import backend.server.routers.v1
import backend.server.v2.admin.credit_admin_routes
@@ -44,6 +45,7 @@ from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
from backend.server.external.api import external_app
from backend.server.middleware.security import SecurityHeadersMiddleware
from backend.server.oauth import client_router, discovery_router, oauth_router
from backend.server.utils.cors import build_cors_params
from backend.util import json
from backend.util.cloud_storage import shutdown_cloud_storage_handler
@@ -300,6 +302,18 @@ app.include_router(
app.mount("/external-api", external_app)
# OAuth Provider routes
app.include_router(oauth_router, tags=["oauth"], prefix="")
app.include_router(discovery_router, tags=["oidc-discovery"], prefix="")
app.include_router(client_router, tags=["oauth-clients"], prefix="")
# Integration Connect popup routes (for Credential Broker)
app.include_router(
backend.server.integrations.connect_router.connect_router,
tags=["integration-connect"],
prefix="",
)
@app.get(path="/health", tags=["health"], dependencies=[])
async def health():

View File

@@ -3,6 +3,7 @@ from datetime import UTC, datetime
from os import getenv
import pytest
from prisma.types import ProfileCreateInput
from pydantic import SecretStr
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
@@ -49,13 +50,13 @@ async def setup_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile",
"links": [], # Required field - empty array for test profiles
}
data=ProfileCreateInput(
userId=user.id,
username=username,
name=f"Test User {username}",
description="Test user profile",
links=[], # Required field - empty array for test profiles
)
)
# 2. Create a test graph with agent input -> agent output
@@ -172,13 +173,13 @@ async def setup_llm_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile for LLM tests",
"links": [], # Required field - empty array for test profiles
}
data=ProfileCreateInput(
userId=user.id,
username=username,
name=f"Test User {username}",
description="Test user profile for LLM tests",
links=[], # Required field - empty array for test profiles
)
)
# 2. Create test OpenAI credentials for the user
@@ -332,13 +333,13 @@ async def setup_firecrawl_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile for Firecrawl tests",
"links": [], # Required field - empty array for test profiles
}
data=ProfileCreateInput(
userId=user.id,
username=username,
name=f"Test User {username}",
description="Test user profile for Firecrawl tests",
links=[], # Required field - empty array for test profiles
)
)
# NOTE: We deliberately do NOT create Firecrawl credentials for this user

View File

@@ -802,18 +802,16 @@ async def add_store_agent_to_library(
# Create LibraryAgent entry
added_agent = await prisma.models.LibraryAgent.prisma().create(
data={
"User": {"connect": {"id": user_id}},
"AgentGraph": {
data=prisma.types.LibraryAgentCreateInput(
User={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {"id": graph.id, "version": graph.version}
}
},
"isCreatedByUser": False,
"settings": SafeJson(
_initialize_graph_settings(graph_model).model_dump()
),
},
isCreatedByUser=False,
settings=SafeJson(_initialize_graph_settings(graph_model).model_dump()),
),
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
),

View File

@@ -248,7 +248,9 @@ async def log_search_term(search_query: str):
date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
try:
await prisma.models.SearchTerms.prisma().create(
data={"searchTerm": search_query, "createdDate": date}
data=prisma.types.SearchTermsCreateInput(
searchTerm=search_query, createdDate=date
)
)
except Exception as e:
# Fail silently here so that logging search terms doesn't break the app
@@ -1430,13 +1432,10 @@ async def _approve_sub_agent(
# Create new version if no matching version found
next_version = max((v.version for v in listing.Versions or []), default=0) + 1
await prisma.models.StoreListingVersion.prisma(tx).create(
data={
**_create_sub_agent_version_data(sub_graph, heading, main_agent_name),
"version": next_version,
"storeListingId": listing.id,
}
)
sub_agent_data = _create_sub_agent_version_data(sub_graph, heading, main_agent_name)
sub_agent_data["version"] = next_version
sub_agent_data["storeListingId"] = listing.id
await prisma.models.StoreListingVersion.prisma(tx).create(data=sub_agent_data)
await prisma.models.StoreListing.prisma(tx).update(
where={"id": listing.id}, data={"hasApprovedVersion": True}
)

View File

@@ -0,0 +1,228 @@
"""
Rate Limiting for External API.
Implements sliding window rate limiting using Redis for distributed systems.
"""
import logging
import time
from dataclasses import dataclass
from typing import Optional
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
@dataclass
class RateLimitResult:
"""Result of a rate limit check."""
allowed: bool
remaining: int
reset_at: float
retry_after: Optional[float] = None
class RateLimiter:
"""
Redis-based sliding window rate limiter.
Supports multiple limit tiers (per-minute, per-hour, per-day).
"""
def __init__(self, prefix: str = "ratelimit"):
self.prefix = prefix
def _make_key(self, identifier: str, window: str) -> str:
"""Create a Redis key for the rate limit counter."""
return f"{self.prefix}:{identifier}:{window}"
async def check_and_increment(
self,
identifier: str,
limits: dict[str, tuple[int, int]], # window_name -> (limit, window_seconds)
) -> RateLimitResult:
"""
Check rate limits and increment counters if allowed.
Uses atomic increment-first approach to prevent race conditions:
1. Increment all counters atomically
2. Check if any limit exceeded
3. If exceeded, decrement and return rate limit error
Args:
identifier: Unique identifier (e.g., client_id, client_id:user_id)
limits: Dictionary of limit configurations
e.g., {"minute": (60, 60), "hour": (1000, 3600)}
Returns:
RateLimitResult with allowed status and remaining quota
"""
if not limits:
# No limits configured, allow request
return RateLimitResult(
allowed=True,
remaining=999999,
reset_at=time.time() + 60,
)
redis = await get_redis_async()
current_time = time.time()
# Increment all counters atomically first
incremented_keys: list[tuple[str, int, int, int]] = (
[]
) # (key, new_count, limit, window_seconds)
for window_name, (limit, window_seconds) in limits.items():
key = self._make_key(identifier, window_name)
# Atomic increment
new_count = await redis.incr(key)
# Set expiry if this is a new key
if new_count == 1:
await redis.expire(key, window_seconds)
incremented_keys.append((key, new_count, limit, window_seconds))
# Check if any limit exceeded
for key, new_count, limit, window_seconds in incremented_keys:
if new_count > limit:
# Rate limit exceeded - decrement all counters we just incremented
for decr_key, _, _, _ in incremented_keys:
await redis.decr(decr_key)
ttl = await redis.ttl(key)
reset_at = current_time + (ttl if ttl > 0 else window_seconds)
return RateLimitResult(
allowed=False,
remaining=0,
reset_at=reset_at,
retry_after=ttl if ttl > 0 else window_seconds,
)
# All limits passed
min_remaining = float("inf")
earliest_reset = current_time
for key, new_count, limit, window_seconds in incremented_keys:
remaining = max(0, limit - new_count)
min_remaining = min(min_remaining, remaining)
ttl = await redis.ttl(key)
reset_at = current_time + (ttl if ttl > 0 else window_seconds)
earliest_reset = max(earliest_reset, reset_at)
return RateLimitResult(
allowed=True,
remaining=int(min_remaining),
reset_at=earliest_reset,
)
async def get_remaining(
self,
identifier: str,
limits: dict[str, tuple[int, int]],
) -> dict[str, int]:
"""
Get remaining quota for all windows without incrementing.
Args:
identifier: Unique identifier
limits: Dictionary of limit configurations
Returns:
Dictionary of remaining quota per window
"""
redis = await get_redis_async()
remaining = {}
for window_name, (limit, _) in limits.items():
key = self._make_key(identifier, window_name)
count = await redis.get(key)
current_count = int(count) if count else 0
remaining[window_name] = max(0, limit - current_count)
return remaining
async def reset(self, identifier: str, window: Optional[str] = None) -> None:
"""
Reset rate limit counters.
Args:
identifier: Unique identifier
window: Optional specific window to reset (resets all if None)
"""
redis = await get_redis_async()
if window:
key = self._make_key(identifier, window)
await redis.delete(key)
else:
# Delete known window keys instead of scanning
# This avoids potentially slow scan operations with many keys
known_windows = ["minute", "hour", "day"]
keys_to_delete = [self._make_key(identifier, w) for w in known_windows]
# Delete all in one call (Redis handles non-existent keys gracefully)
if keys_to_delete:
await redis.delete(*keys_to_delete)
# Default rate limits for different endpoints
DEFAULT_RATE_LIMITS = {
# OAuth endpoints
"oauth_authorize": {"minute": (30, 60)}, # 30/min per IP
"oauth_token": {"minute": (20, 60)}, # 20/min per client
"oauth_consent": {"minute": (20, 60)}, # 20/min per IP for consent submission
# External API endpoints
"api_execute": {
"minute": (10, 60),
"hour": (100, 3600),
}, # 10/min, 100/hour per client+user
"api_read": {
"minute": (60, 60),
"hour": (1000, 3600),
}, # 60/min, 1000/hour per client+user
}
# Module-level singleton
_rate_limiter: Optional[RateLimiter] = None
def get_rate_limiter() -> RateLimiter:
"""Get the singleton rate limiter instance."""
global _rate_limiter
if _rate_limiter is None:
_rate_limiter = RateLimiter()
return _rate_limiter
async def check_rate_limit(
identifier: str,
limit_type: str,
) -> RateLimitResult:
"""
Convenience function to check rate limits.
Args:
identifier: Unique identifier for the rate limit
limit_type: Type of limit from DEFAULT_RATE_LIMITS
Returns:
RateLimitResult
"""
limits = DEFAULT_RATE_LIMITS.get(limit_type)
if not limits:
# No rate limit configured, allow
return RateLimitResult(
allowed=True,
remaining=999999,
reset_at=time.time() + 60,
)
rate_limiter = get_rate_limiter()
return await rate_limiter.check_and_increment(identifier, limits)

View File

@@ -651,6 +651,23 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
ayrshare_api_key: str = Field(default="", description="Ayrshare API Key")
ayrshare_jwt_key: str = Field(default="", description="Ayrshare private Key")
# OAuth Provider JWT keys
oauth_jwt_private_key: str = Field(
default="",
description="RSA private key for signing OAuth tokens (PEM format). "
"If not set, a development key will be auto-generated.",
)
oauth_jwt_public_key: str = Field(
default="",
description="RSA public key for verifying OAuth tokens (PEM format). "
"If not set, derived from private key.",
)
oauth_jwt_key_id: str = Field(
default="autogpt-oauth-key-1",
description="Key ID (kid) for JWKS. Used to identify the signing key.",
)
# Add more secret fields as needed
model_config = SettingsConfigDict(
env_file=".env",

View File

@@ -0,0 +1,43 @@
"""
Time utilities for the backend.
Common datetime operations used across the codebase.
"""
from datetime import datetime, timedelta, timezone
def expiration_datetime(seconds: int) -> datetime:
"""
Calculate an expiration datetime from now.
Args:
seconds: Number of seconds until expiration
Returns:
Datetime when the item will expire (UTC)
"""
return datetime.now(timezone.utc) + timedelta(seconds=seconds)
def is_expired(dt: datetime) -> bool:
"""
Check if a datetime has passed.
Args:
dt: The datetime to check (should be timezone-aware)
Returns:
True if the datetime is in the past
"""
return dt < datetime.now(timezone.utc)
def utc_now() -> datetime:
"""
Get the current UTC time.
Returns:
Current datetime in UTC
"""
return datetime.now(timezone.utc)

View File

@@ -0,0 +1,46 @@
"""
URL and domain validation utilities.
Common URL validation operations used across the codebase.
"""
def matches_domain_pattern(hostname: str, domain_pattern: str) -> bool:
"""
Check if a hostname matches a domain pattern.
Supports wildcard patterns (*.example.com) which match:
- The base domain (example.com)
- Any subdomain (sub.example.com, deep.sub.example.com)
Args:
hostname: The hostname to check (e.g., "api.example.com")
domain_pattern: The pattern to match against (e.g., "*.example.com" or "example.com")
Returns:
True if the hostname matches the pattern
"""
hostname = hostname.lower()
domain_pattern = domain_pattern.lower()
if domain_pattern.startswith("*."):
# Wildcard domain - matches base and any subdomains
base_domain = domain_pattern[2:]
return hostname == base_domain or hostname.endswith("." + base_domain)
# Exact match
return hostname == domain_pattern
def hostname_matches_any_domain(hostname: str, allowed_domains: list[str]) -> bool:
"""
Check if a hostname matches any of the allowed domain patterns.
Args:
hostname: The hostname to check
allowed_domains: List of allowed domain patterns (supports wildcards)
Returns:
True if the hostname matches any pattern
"""
return any(matches_domain_pattern(hostname, domain) for domain in allowed_domains)

View File

@@ -0,0 +1,249 @@
-- CreateEnum
CREATE TYPE "OAuthClientStatus" AS ENUM ('ACTIVE', 'SUSPENDED');
-- CreateEnum
CREATE TYPE "CredentialGrantPermission" AS ENUM ('USE', 'DELETE');
-- CreateTable
CREATE TABLE "OAuthClient" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"clientId" TEXT NOT NULL,
"clientSecretHash" TEXT,
"clientSecretSalt" TEXT,
"clientType" TEXT NOT NULL,
"name" TEXT NOT NULL,
"description" TEXT,
"logoUrl" TEXT,
"homepageUrl" TEXT,
"privacyPolicyUrl" TEXT,
"termsOfServiceUrl" TEXT,
"redirectUris" TEXT[],
"allowedScopes" TEXT[],
"webhookDomains" TEXT[],
"requirePkce" BOOLEAN NOT NULL DEFAULT true,
"tokenLifetimeSecs" INTEGER NOT NULL DEFAULT 3600,
"refreshTokenLifetimeSecs" INTEGER NOT NULL DEFAULT 2592000,
"status" "OAuthClientStatus" NOT NULL DEFAULT 'ACTIVE',
"ownerId" TEXT NOT NULL,
CONSTRAINT "OAuthClient_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "OAuthAuthorization" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"userId" TEXT NOT NULL,
"clientId" TEXT NOT NULL,
"scopes" TEXT[],
"revokedAt" TIMESTAMP(3),
CONSTRAINT "OAuthAuthorization_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "OAuthAuthorizationCode" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"codeHash" TEXT NOT NULL,
"userId" TEXT NOT NULL,
"clientId" TEXT NOT NULL,
"redirectUri" TEXT NOT NULL,
"scopes" TEXT[],
"nonce" TEXT,
"codeChallenge" TEXT NOT NULL,
"codeChallengeMethod" TEXT NOT NULL DEFAULT 'S256',
"expiresAt" TIMESTAMP(3) NOT NULL,
"usedAt" TIMESTAMP(3),
CONSTRAINT "OAuthAuthorizationCode_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "OAuthAccessToken" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"tokenHash" TEXT NOT NULL,
"userId" TEXT NOT NULL,
"clientId" TEXT NOT NULL,
"scopes" TEXT[],
"expiresAt" TIMESTAMP(3) NOT NULL,
"revokedAt" TIMESTAMP(3),
"lastUsedAt" TIMESTAMP(3),
CONSTRAINT "OAuthAccessToken_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "OAuthRefreshToken" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"tokenHash" TEXT NOT NULL,
"userId" TEXT NOT NULL,
"clientId" TEXT NOT NULL,
"scopes" TEXT[],
"expiresAt" TIMESTAMP(3) NOT NULL,
"revokedAt" TIMESTAMP(3),
CONSTRAINT "OAuthRefreshToken_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "CredentialGrant" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"userId" TEXT NOT NULL,
"clientId" TEXT NOT NULL,
"credentialId" TEXT NOT NULL,
"provider" TEXT NOT NULL,
"grantedScopes" TEXT[],
"permissions" "CredentialGrantPermission"[],
"expiresAt" TIMESTAMP(3),
"revokedAt" TIMESTAMP(3),
"lastUsedAt" TIMESTAMP(3),
CONSTRAINT "CredentialGrant_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "OAuthAuditLog" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"eventType" TEXT NOT NULL,
"userId" TEXT,
"clientId" TEXT,
"grantId" TEXT,
"ipAddress" TEXT,
"userAgent" TEXT,
"details" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "OAuthAuditLog_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "ExecutionWebhook" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"executionId" TEXT NOT NULL,
"webhookUrl" TEXT NOT NULL,
"clientId" TEXT NOT NULL,
"userId" TEXT NOT NULL,
"secret" TEXT,
CONSTRAINT "ExecutionWebhook_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "OAuthClient_clientId_key" ON "OAuthClient"("clientId");
-- CreateIndex
CREATE INDEX "OAuthClient_clientId_idx" ON "OAuthClient"("clientId");
-- CreateIndex
CREATE INDEX "OAuthClient_ownerId_idx" ON "OAuthClient"("ownerId");
-- CreateIndex
CREATE INDEX "OAuthClient_status_idx" ON "OAuthClient"("status");
-- CreateIndex
CREATE INDEX "OAuthAuthorization_userId_idx" ON "OAuthAuthorization"("userId");
-- CreateIndex
CREATE INDEX "OAuthAuthorization_clientId_idx" ON "OAuthAuthorization"("clientId");
-- CreateIndex
CREATE UNIQUE INDEX "OAuthAuthorization_userId_clientId_key" ON "OAuthAuthorization"("userId", "clientId");
-- CreateIndex
CREATE UNIQUE INDEX "OAuthAuthorizationCode_codeHash_key" ON "OAuthAuthorizationCode"("codeHash");
-- CreateIndex
CREATE INDEX "OAuthAuthorizationCode_codeHash_idx" ON "OAuthAuthorizationCode"("codeHash");
-- CreateIndex
CREATE INDEX "OAuthAuthorizationCode_expiresAt_idx" ON "OAuthAuthorizationCode"("expiresAt");
-- CreateIndex
CREATE UNIQUE INDEX "OAuthAccessToken_tokenHash_key" ON "OAuthAccessToken"("tokenHash");
-- CreateIndex
CREATE INDEX "OAuthAccessToken_tokenHash_idx" ON "OAuthAccessToken"("tokenHash");
-- CreateIndex
CREATE INDEX "OAuthAccessToken_userId_clientId_idx" ON "OAuthAccessToken"("userId", "clientId");
-- CreateIndex
CREATE INDEX "OAuthAccessToken_expiresAt_idx" ON "OAuthAccessToken"("expiresAt");
-- CreateIndex
CREATE UNIQUE INDEX "OAuthRefreshToken_tokenHash_key" ON "OAuthRefreshToken"("tokenHash");
-- CreateIndex
CREATE INDEX "OAuthRefreshToken_tokenHash_idx" ON "OAuthRefreshToken"("tokenHash");
-- CreateIndex
CREATE INDEX "OAuthRefreshToken_expiresAt_idx" ON "OAuthRefreshToken"("expiresAt");
-- CreateIndex
CREATE INDEX "CredentialGrant_userId_clientId_idx" ON "CredentialGrant"("userId", "clientId");
-- CreateIndex
CREATE INDEX "CredentialGrant_clientId_idx" ON "CredentialGrant"("clientId");
-- CreateIndex
CREATE UNIQUE INDEX "CredentialGrant_userId_clientId_credentialId_key" ON "CredentialGrant"("userId", "clientId", "credentialId");
-- CreateIndex
CREATE INDEX "OAuthAuditLog_createdAt_idx" ON "OAuthAuditLog"("createdAt");
-- CreateIndex
CREATE INDEX "OAuthAuditLog_eventType_idx" ON "OAuthAuditLog"("eventType");
-- CreateIndex
CREATE INDEX "OAuthAuditLog_userId_idx" ON "OAuthAuditLog"("userId");
-- CreateIndex
CREATE INDEX "OAuthAuditLog_clientId_idx" ON "OAuthAuditLog"("clientId");
-- CreateIndex
CREATE INDEX "ExecutionWebhook_executionId_idx" ON "ExecutionWebhook"("executionId");
-- CreateIndex
CREATE INDEX "ExecutionWebhook_clientId_idx" ON "ExecutionWebhook"("clientId");
-- AddForeignKey
ALTER TABLE "OAuthClient" ADD CONSTRAINT "OAuthClient_ownerId_fkey" FOREIGN KEY ("ownerId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OAuthAuthorization" ADD CONSTRAINT "OAuthAuthorization_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OAuthAuthorization" ADD CONSTRAINT "OAuthAuthorization_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OAuthAuthorizationCode" ADD CONSTRAINT "OAuthAuthorizationCode_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OAuthAuthorizationCode" ADD CONSTRAINT "OAuthAuthorizationCode_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OAuthAccessToken" ADD CONSTRAINT "OAuthAccessToken_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OAuthAccessToken" ADD CONSTRAINT "OAuthAccessToken_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OAuthRefreshToken" ADD CONSTRAINT "OAuthRefreshToken_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OAuthRefreshToken" ADD CONSTRAINT "OAuthRefreshToken_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "CredentialGrant" ADD CONSTRAINT "CredentialGrant_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "CredentialGrant" ADD CONSTRAINT "CredentialGrant_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "platform"."OAuthClient" ADD COLUMN "webhookSecret" TEXT;

View File

@@ -60,6 +60,14 @@ model User {
IntegrationWebhooks IntegrationWebhook[]
NotificationBatches UserNotificationBatch[]
PendingHumanReviews PendingHumanReview[]
// OAuth Provider relations
OAuthClientsOwned OAuthClient[] @relation("OAuthClientOwner")
OAuthAuthorizations OAuthAuthorization[]
OAuthAuthorizationCodes OAuthAuthorizationCode[]
OAuthAccessTokens OAuthAccessToken[]
OAuthRefreshTokens OAuthRefreshToken[]
CredentialGrants CredentialGrant[]
}
enum OnboardingStep {
@@ -701,11 +709,11 @@ view StoreAgent {
storeListingVersionId String
updated_at DateTime
slug String
agent_name String
agent_video String?
agent_output_demo String?
agent_image String[]
slug String
agent_name String
agent_video String?
agent_output_demo String?
agent_image String[]
featured Boolean @default(false)
creator_username String?
@@ -834,14 +842,14 @@ model StoreListingVersion {
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
// Content fields
name String
subHeading String
videoUrl String?
agentOutputDemoUrl String?
imageUrls String[]
description String
instructions String?
categories String[]
name String
subHeading String
videoUrl String?
agentOutputDemoUrl String?
imageUrls String[]
description String
instructions String?
categories String[]
isFeatured Boolean @default(false)
@@ -961,3 +969,226 @@ enum APIKeyStatus {
REVOKED
SUSPENDED
}
// ============================================================
// OAuth Provider & Credential Broker Models
// ============================================================
enum OAuthClientStatus {
ACTIVE
SUSPENDED
}
enum CredentialGrantPermission {
USE // Can use credential for agent execution
DELETE // Can delete the credential
}
// OAuth Client - Registered external applications
model OAuthClient {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
// Client identification
clientId String @unique // Public identifier (e.g., "app_abc123")
clientSecretHash String? // Hashed (null for public clients)
clientSecretSalt String?
clientType String // "public" or "confidential"
// Metadata (shown on consent screen)
name String
description String?
logoUrl String?
homepageUrl String?
privacyPolicyUrl String?
termsOfServiceUrl String?
// Configuration
redirectUris String[]
allowedScopes String[]
webhookDomains String[] // For webhook URL validation
webhookSecret String? // Secret for HMAC signing webhooks
// Security
requirePkce Boolean @default(true)
tokenLifetimeSecs Int @default(3600)
refreshTokenLifetimeSecs Int @default(2592000) // 30 days
// Status
status OAuthClientStatus @default(ACTIVE)
// Owner
ownerId String
Owner User @relation("OAuthClientOwner", fields: [ownerId], references: [id], onDelete: Cascade)
// Relations
Authorizations OAuthAuthorization[]
AuthorizationCodes OAuthAuthorizationCode[]
AccessTokens OAuthAccessToken[]
RefreshTokens OAuthRefreshToken[]
CredentialGrants CredentialGrant[]
@@index([clientId])
@@index([ownerId])
@@index([status])
}
// OAuth Authorization - User consent record
model OAuthAuthorization {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
clientId String
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
scopes String[]
revokedAt DateTime?
@@unique([userId, clientId])
@@index([userId])
@@index([clientId])
}
// OAuth Authorization Code - Short-lived, single-use
model OAuthAuthorizationCode {
id String @id @default(uuid())
createdAt DateTime @default(now())
codeHash String @unique
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
clientId String
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
redirectUri String
scopes String[]
nonce String? // OIDC nonce
// PKCE
codeChallenge String
codeChallengeMethod String @default("S256")
expiresAt DateTime // 10 minutes
usedAt DateTime?
@@index([codeHash])
@@index([expiresAt])
}
// OAuth Access Token
model OAuthAccessToken {
id String @id @default(uuid())
createdAt DateTime @default(now())
tokenHash String @unique // SHA256 of token
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
clientId String
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
scopes String[]
expiresAt DateTime
revokedAt DateTime?
lastUsedAt DateTime?
@@index([tokenHash])
@@index([userId, clientId])
@@index([expiresAt])
}
// OAuth Refresh Token
model OAuthRefreshToken {
id String @id @default(uuid())
createdAt DateTime @default(now())
tokenHash String @unique
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
clientId String
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
scopes String[]
expiresAt DateTime
revokedAt DateTime?
@@index([tokenHash])
@@index([expiresAt])
}
// Credential Grant - Links external app to user's credential with scoped access
model CredentialGrant {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
clientId String
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
credentialId String // Reference to credential in User.integrations
provider String
// Fine-grained integration scopes (e.g., "google:gmail.readonly")
grantedScopes String[]
// Permissions for the credential itself
permissions CredentialGrantPermission[]
expiresAt DateTime?
revokedAt DateTime?
lastUsedAt DateTime?
@@unique([userId, clientId, credentialId])
@@index([userId, clientId])
@@index([clientId])
}
// OAuth Audit Log
model OAuthAuditLog {
id String @id @default(uuid())
createdAt DateTime @default(now())
eventType String // e.g., "token.issued", "grant.created"
userId String?
clientId String?
grantId String?
ipAddress String?
userAgent String?
details Json @default("{}")
@@index([createdAt])
@@index([eventType])
@@index([userId])
@@index([clientId])
}
// Execution Webhook - Webhook registration for external API executions
model ExecutionWebhook {
id String @id @default(uuid())
createdAt DateTime @default(now())
executionId String // The graph execution ID
webhookUrl String // URL to send notifications to
clientId String // The OAuth client database ID
userId String // The user who started the execution
secret String? // Optional webhook secret for HMAC signing
@@index([executionId])
@@index([clientId])
}

View File

@@ -22,6 +22,7 @@ import random
from typing import Any, Dict, List
from faker import Faker
from prisma.types import AgentBlockCreateInput
from backend.data.api_key import create_api_key
from backend.data.credit import get_user_credit_model
@@ -177,12 +178,12 @@ class TestDataCreator:
for block in blocks_to_create:
try:
await prisma.agentblock.create(
data={
"id": block.id,
"name": block.name,
"inputSchema": "{}",
"outputSchema": "{}",
}
data=AgentBlockCreateInput(
id=block.id,
name=block.name,
inputSchema="{}",
outputSchema="{}",
)
)
except Exception as e:
print(f"Error creating block {block.name}: {e}")

View File

@@ -30,13 +30,19 @@ from prisma.types import (
AgentGraphCreateInput,
AgentNodeCreateInput,
AgentNodeLinkCreateInput,
AgentPresetCreateInput,
AnalyticsDetailsCreateInput,
AnalyticsMetricsCreateInput,
APIKeyCreateInput,
CreditTransactionCreateInput,
IntegrationWebhookCreateInput,
LibraryAgentCreateInput,
ProfileCreateInput,
StoreListingCreateInput,
StoreListingReviewCreateInput,
StoreListingVersionCreateInput,
UserCreateInput,
UserOnboardingCreateInput,
)
faker = Faker()
@@ -172,14 +178,14 @@ async def main():
for _ in range(num_presets): # Create 1 AgentPreset per user
graph = random.choice(agent_graphs)
preset = await db.agentpreset.create(
data={
"name": faker.sentence(nb_words=3),
"description": faker.text(max_nb_chars=200),
"userId": user.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"isActive": True,
}
data=AgentPresetCreateInput(
name=faker.sentence(nb_words=3),
description=faker.text(max_nb_chars=200),
userId=user.id,
agentGraphId=graph.id,
agentGraphVersion=graph.version,
isActive=True,
)
)
agent_presets.append(preset)
@@ -220,18 +226,18 @@ async def main():
)
library_agent = await db.libraryagent.create(
data={
"userId": user.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"creatorId": creator_profile.id if creator_profile else None,
"imageUrl": get_image() if random.random() < 0.5 else None,
"useGraphIsActiveVersion": random.choice([True, False]),
"isFavorite": random.choice([True, False]),
"isCreatedByUser": random.choice([True, False]),
"isArchived": random.choice([True, False]),
"isDeleted": random.choice([True, False]),
}
data=LibraryAgentCreateInput(
userId=user.id,
agentGraphId=graph.id,
agentGraphVersion=graph.version,
creatorId=creator_profile.id if creator_profile else None,
imageUrl=get_image() if random.random() < 0.5 else None,
useGraphIsActiveVersion=random.choice([True, False]),
isFavorite=random.choice([True, False]),
isCreatedByUser=random.choice([True, False]),
isArchived=random.choice([True, False]),
isDeleted=random.choice([True, False]),
)
)
library_agents.append(library_agent)
@@ -392,13 +398,13 @@ async def main():
user = random.choice(users)
slug = faker.slug()
listing = await db.storelisting.create(
data={
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"owningUserId": user.id,
"hasApprovedVersion": random.choice([True, False]),
"slug": slug,
}
data=StoreListingCreateInput(
agentGraphId=graph.id,
agentGraphVersion=graph.version,
owningUserId=user.id,
hasApprovedVersion=random.choice([True, False]),
slug=slug,
)
)
store_listings.append(listing)
@@ -408,26 +414,26 @@ async def main():
for listing in store_listings:
graph = [g for g in agent_graphs if g.id == listing.agentGraphId][0]
version = await db.storelistingversion.create(
data={
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"name": graph.name or faker.sentence(nb_words=3),
"subHeading": faker.sentence(),
"videoUrl": get_video_url() if random.random() < 0.3 else None,
"imageUrls": [get_image() for _ in range(3)],
"description": faker.text(),
"categories": [faker.word() for _ in range(3)],
"isFeatured": random.choice([True, False]),
"isAvailable": True,
"storeListingId": listing.id,
"submissionStatus": random.choice(
data=StoreListingVersionCreateInput(
agentGraphId=graph.id,
agentGraphVersion=graph.version,
name=graph.name or faker.sentence(nb_words=3),
subHeading=faker.sentence(),
videoUrl=get_video_url() if random.random() < 0.3 else None,
imageUrls=[get_image() for _ in range(3)],
description=faker.text(),
categories=[faker.word() for _ in range(3)],
isFeatured=random.choice([True, False]),
isAvailable=True,
storeListingId=listing.id,
submissionStatus=random.choice(
[
prisma.enums.SubmissionStatus.PENDING,
prisma.enums.SubmissionStatus.APPROVED,
prisma.enums.SubmissionStatus.REJECTED,
]
),
}
)
)
store_listing_versions.append(version)
@@ -469,51 +475,47 @@ async def main():
try:
await db.useronboarding.create(
data={
"userId": user.id,
"completedSteps": completed_steps,
"walletShown": random.choice([True, False]),
"notified": (
data=UserOnboardingCreateInput(
userId=user.id,
completedSteps=completed_steps,
walletShown=random.choice([True, False]),
notified=(
random.sample(completed_steps, k=min(3, len(completed_steps)))
if completed_steps
else []
),
"rewardedFor": (
rewardedFor=(
random.sample(completed_steps, k=min(2, len(completed_steps)))
if completed_steps
else []
),
"usageReason": (
usageReason=(
random.choice(["personal", "business", "research", "learning"])
if random.random() < 0.7
else None
),
"integrations": random.sample(
integrations=random.sample(
["github", "google", "discord", "slack"], k=random.randint(0, 2)
),
"otherIntegrations": (
faker.word() if random.random() < 0.2 else None
),
"selectedStoreListingVersionId": (
otherIntegrations=(faker.word() if random.random() < 0.2 else None),
selectedStoreListingVersionId=(
random.choice(store_listing_versions).id
if store_listing_versions and random.random() < 0.5
else None
),
"onboardingAgentExecutionId": (
onboardingAgentExecutionId=(
random.choice(agent_graph_executions).id
if agent_graph_executions and random.random() < 0.3
else None
),
"agentRuns": random.randint(0, 10),
}
agentRuns=random.randint(0, 10),
)
)
except Exception as e:
print(f"Error creating onboarding for user {user.id}: {e}")
# Try simpler version
await db.useronboarding.create(
data={
"userId": user.id,
}
data=UserOnboardingCreateInput(userId=user.id)
)
# Insert IntegrationWebhooks for some users
@@ -544,20 +546,20 @@ async def main():
for user in users:
api_key = APIKeySmith().generate_key()
await db.apikey.create(
data={
"name": faker.word(),
"head": api_key.head,
"tail": api_key.tail,
"hash": api_key.hash,
"salt": api_key.salt,
"status": prisma.enums.APIKeyStatus.ACTIVE,
"permissions": [
data=APIKeyCreateInput(
name=faker.word(),
head=api_key.head,
tail=api_key.tail,
hash=api_key.hash,
salt=api_key.salt,
status=prisma.enums.APIKeyStatus.ACTIVE,
permissions=[
prisma.enums.APIKeyPermission.EXECUTE_GRAPH,
prisma.enums.APIKeyPermission.READ_GRAPH,
],
"description": faker.text(),
"userId": user.id,
}
description=faker.text(),
userId=user.id,
)
)
# Refresh materialized views

View File

@@ -16,6 +16,7 @@ from datetime import datetime, timedelta
import prisma.enums
from faker import Faker
from prisma import Json, Prisma
from prisma.types import CreditTransactionCreateInput, StoreListingReviewCreateInput
faker = Faker()
@@ -166,16 +167,16 @@ async def main():
score = random.choices([1, 2, 3, 4, 5], weights=[5, 10, 20, 40, 25])[0]
await db.storelistingreview.create(
data={
"storeListingVersionId": version.id,
"reviewByUserId": reviewer.id,
"score": score,
"comments": (
data=StoreListingReviewCreateInput(
storeListingVersionId=version.id,
reviewByUserId=reviewer.id,
score=score,
comments=(
faker.text(max_nb_chars=200)
if random.random() < 0.7
else None
),
}
)
)
new_reviews_count += 1
@@ -244,17 +245,17 @@ async def main():
)
await db.credittransaction.create(
data={
"userId": user.id,
"amount": amount,
"type": transaction_type,
"metadata": Json(
data=CreditTransactionCreateInput(
userId=user.id,
amount=amount,
type=transaction_type,
metadata=Json(
{
"source": "test_updater",
"timestamp": datetime.now().isoformat(),
}
),
}
)
)
transaction_count += 1

View File

@@ -8,6 +8,8 @@ import { shouldShowOnboarding } from "@/app/api/helpers";
export async function GET(request: Request) {
const { searchParams, origin } = new URL(request.url);
const code = searchParams.get("code");
const oauthSession = searchParams.get("oauth_session");
const connectSession = searchParams.get("connect_session");
let next = "/marketplace";
@@ -25,6 +27,22 @@ export async function GET(request: Request) {
const api = new BackendAPI();
await api.createUser();
// Handle oauth_session redirect - resume OAuth flow after login
// Redirect to a frontend page that will handle the OAuth resume with proper auth
if (oauthSession) {
return NextResponse.redirect(
`${origin}/auth/oauth-resume?session_id=${encodeURIComponent(oauthSession)}`,
);
}
// Handle connect_session redirect - resume connect flow after login
// Redirect to a frontend page that will handle the connect resume with proper auth
if (connectSession) {
return NextResponse.redirect(
`${origin}/auth/connect-resume?session_id=${encodeURIComponent(connectSession)}`,
);
}
if (await shouldShowOnboarding()) {
next = "/onboarding";
revalidatePath("/onboarding", "layout");

View File

@@ -0,0 +1,400 @@
"use client";
import { useEffect, useState, useRef, useCallback } from "react";
import { useSearchParams } from "next/navigation";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { getWebSocketToken } from "@/lib/supabase/actions";
// Module-level flag to prevent duplicate requests across React StrictMode re-renders
const attemptedSessions = new Set<string>();
interface ScopeInfo {
scope: string;
description: string;
}
interface CredentialInfo {
id: string;
title: string;
username: string;
}
interface ClientInfo {
name: string;
logo_url: string | null;
}
interface ConnectData {
connect_token: string;
client: ClientInfo;
provider: string;
scopes: ScopeInfo[];
credentials: CredentialInfo[];
action_url: string;
}
interface ErrorData {
error: string;
error_description: string;
}
type ResumeResponse = ConnectData | ErrorData;
function isConnectData(data: ResumeResponse): data is ConnectData {
return "connect_token" in data;
}
function isErrorData(data: ResumeResponse): data is ErrorData {
return "error" in data;
}
/**
* Connect Consent Form Component
*
* Renders a proper React component for the integration connect consent form
*/
function ConnectForm({
client,
provider,
scopes,
credentials,
connectToken,
actionUrl,
}: {
client: ClientInfo;
provider: string;
scopes: ScopeInfo[];
credentials: CredentialInfo[];
connectToken: string;
actionUrl: string;
}) {
const [isSubmitting, setIsSubmitting] = useState(false);
const [selectedCredential, setSelectedCredential] = useState<string>(
credentials.length > 0 ? credentials[0].id : "",
);
const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
const backendOrigin = backendUrl
? new URL(backendUrl).origin
: "http://localhost:8006";
const fullActionUrl = `${backendOrigin}${actionUrl}`;
function handleSubmit() {
setIsSubmitting(true);
}
return (
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800 p-5">
<div className="w-full max-w-md rounded-2xl bg-zinc-800 p-8 shadow-2xl">
{/* Header */}
<div className="mb-6 text-center">
<h1 className="text-xl font-semibold text-zinc-100">
Connect{" "}
<span className="rounded bg-zinc-700 px-2 py-1 text-sm capitalize">
{provider}
</span>
</h1>
<p className="mt-2 text-sm text-zinc-400">
<span className="font-semibold text-cyan-400">{client.name}</span>{" "}
wants to use your {provider} integration
</p>
</div>
{/* Divider */}
<div className="my-6 h-px bg-zinc-700" />
{/* Scopes Section */}
<div className="mb-6">
<h2 className="mb-4 text-sm font-medium text-zinc-400">
This will allow {client.name} to:
</h2>
<div className="space-y-2">
{scopes.map((scope) => (
<div key={scope.scope} className="flex items-start gap-2 py-2">
<span className="flex-shrink-0 text-cyan-400">&#10003;</span>
<span className="text-sm text-zinc-300">
{scope.description}
</span>
</div>
))}
</div>
</div>
{/* Divider */}
<div className="my-6 h-px bg-zinc-700" />
{/* Form */}
<form method="POST" action={fullActionUrl} onSubmit={handleSubmit}>
<input type="hidden" name="connect_token" value={connectToken} />
{/* Existing credentials selection */}
{credentials.length > 0 && (
<>
<h3 className="mb-3 text-sm font-medium text-zinc-400">
Select an existing credential:
</h3>
<div className="mb-4 space-y-2">
{credentials.map((cred) => (
<label
key={cred.id}
className={`flex cursor-pointer items-center gap-3 rounded-lg border p-3 transition-colors ${
selectedCredential === cred.id
? "border-cyan-400 bg-cyan-400/10"
: "border-zinc-700 hover:border-cyan-400/50"
}`}
>
<input
type="radio"
name="credential_id"
value={cred.id}
checked={selectedCredential === cred.id}
onChange={() => setSelectedCredential(cred.id)}
className="hidden"
/>
<div>
<div className="text-sm font-medium text-zinc-200">
{cred.title}
</div>
{cred.username && (
<div className="text-xs text-zinc-500">
{cred.username}
</div>
)}
</div>
</label>
))}
</div>
<div className="my-4 h-px bg-zinc-700" />
</>
)}
{/* Connect new account */}
<div className="mb-4">
{credentials.length > 0 ? (
<h3 className="mb-3 text-sm font-medium text-zinc-400">
Or connect a new account:
</h3>
) : (
<p className="mb-3 text-sm text-zinc-400">
You don&apos;t have any {provider} credentials yet.
</p>
)}
<button
type="submit"
name="action"
value="connect_new"
disabled={isSubmitting}
className="w-full rounded-lg bg-blue-500 px-6 py-3 text-sm font-medium text-white transition-colors hover:bg-blue-400 disabled:cursor-not-allowed disabled:opacity-50"
>
Connect {provider.charAt(0).toUpperCase() + provider.slice(1)}{" "}
Account
</button>
</div>
{/* Action buttons */}
<div className="flex gap-3">
<button
type="submit"
name="action"
value="deny"
disabled={isSubmitting}
className="flex-1 rounded-lg bg-zinc-700 px-6 py-3 text-sm font-medium text-zinc-200 transition-colors hover:bg-zinc-600 disabled:cursor-not-allowed disabled:opacity-50"
>
Cancel
</button>
{credentials.length > 0 && (
<button
type="submit"
name="action"
value="approve"
disabled={isSubmitting}
className="flex-1 rounded-lg bg-cyan-400 px-6 py-3 text-sm font-medium text-slate-900 transition-colors hover:bg-cyan-300 disabled:cursor-not-allowed disabled:opacity-50"
>
{isSubmitting ? "Approving..." : "Approve"}
</button>
)}
</div>
</form>
</div>
</div>
);
}
/**
* Connect Resume Page
*
* This page handles resuming the integration connect flow after a user logs in.
* It fetches the connect data from the backend via JSON API and renders the consent form.
*/
export default function ConnectResumePage() {
const searchParams = useSearchParams();
const sessionId = searchParams.get("session_id");
const { isUserLoading, refreshSession } = useSupabase();
const [connectData, setConnectData] = useState<ConnectData | null>(null);
const [error, setError] = useState<string | null>(null);
const [isLoading, setIsLoading] = useState(true);
const retryCountRef = useRef(0);
const maxRetries = 5;
const resumeConnectFlow = useCallback(async () => {
if (!sessionId) {
setError(
"Missing session ID. Please start the connection process again.",
);
setIsLoading(false);
return;
}
if (attemptedSessions.has(sessionId)) {
return;
}
if (isUserLoading) {
return;
}
attemptedSessions.add(sessionId);
try {
let tokenResult = await getWebSocketToken();
let accessToken = tokenResult.token;
while (!accessToken && retryCountRef.current < maxRetries) {
retryCountRef.current += 1;
console.log(
`Retrying to get access token (attempt ${retryCountRef.current}/${maxRetries})...`,
);
await refreshSession();
await new Promise((resolve) => setTimeout(resolve, 1000));
tokenResult = await getWebSocketToken();
accessToken = tokenResult.token;
}
if (!accessToken) {
setError(
"Unable to retrieve authentication token. Please log in again.",
);
setIsLoading(false);
return;
}
const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
if (!backendUrl) {
setError("Backend URL not configured.");
setIsLoading(false);
return;
}
let backendOrigin: string;
try {
const url = new URL(backendUrl);
backendOrigin = url.origin;
} catch {
setError("Invalid backend URL configuration.");
setIsLoading(false);
return;
}
const response = await fetch(
`${backendOrigin}/connect/resume?session_id=${encodeURIComponent(sessionId)}`,
{
method: "GET",
headers: {
Authorization: `Bearer ${accessToken}`,
Accept: "application/json",
},
},
);
const data: ResumeResponse = await response.json();
if (!response.ok) {
if (isErrorData(data)) {
setError(data.error_description || data.error);
} else {
setError(`Connection failed (${response.status}). Please try again.`);
}
setIsLoading(false);
return;
}
if (isConnectData(data)) {
setConnectData(data);
setIsLoading(false);
return;
}
setError("Unexpected response from server. Please try again.");
setIsLoading(false);
} catch (err) {
console.error("Connect resume error:", err);
setError(
"An error occurred while resuming connection. Please try again.",
);
setIsLoading(false);
}
}, [sessionId, isUserLoading, refreshSession]);
useEffect(() => {
resumeConnectFlow();
}, [resumeConnectFlow]);
if (isLoading || isUserLoading) {
return (
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800">
<div className="text-center">
<div className="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-zinc-600 border-t-cyan-400"></div>
<p className="text-zinc-400">Resuming connection...</p>
</div>
</div>
);
}
if (error) {
return (
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800">
<div className="mx-auto max-w-md rounded-2xl bg-zinc-800 p-8 text-center shadow-2xl">
<div className="mx-auto mb-4 h-16 w-16 text-red-500">
<svg
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
>
<circle cx="12" cy="12" r="10" />
<line x1="15" y1="9" x2="9" y2="15" />
<line x1="9" y1="9" x2="15" y2="15" />
</svg>
</div>
<h1 className="mb-2 text-xl font-semibold text-red-400">
Connection Error
</h1>
<p className="mb-6 text-zinc-400">{error}</p>
<button
onClick={() => window.close()}
className="rounded-lg bg-zinc-700 px-6 py-3 text-sm font-medium text-zinc-200 transition-colors hover:bg-zinc-600"
>
Close
</button>
</div>
</div>
);
}
if (connectData) {
return (
<ConnectForm
client={connectData.client}
provider={connectData.provider}
scopes={connectData.scopes}
credentials={connectData.credentials}
connectToken={connectData.connect_token}
actionUrl={connectData.action_url}
/>
);
}
return null;
}

View File

@@ -22,20 +22,28 @@ export async function GET(request: Request) {
console.debug("Sending message to opener:", message);
// Escape JSON to prevent XSS attacks via </script> injection
const safeJson = JSON.stringify(message)
.replace(/</g, "\\u003c")
.replace(/>/g, "\\u003e");
// Return a response with the message as JSON and a script to close the window
return new NextResponse(
`
<html>
<body>
<script>
window.opener.postMessage(${JSON.stringify(message)});
window.close();
</script>
</body>
</html>
`,
`<!DOCTYPE html>
<html>
<body>
<script>
window.opener.postMessage(${safeJson}, '*');
window.close();
</script>
</body>
</html>`,
{
headers: { "Content-Type": "text/html" },
headers: {
"Content-Type": "text/html",
"Content-Security-Policy":
"default-src 'none'; script-src 'unsafe-inline'",
},
},
);
}

View File

@@ -0,0 +1,399 @@
"use client";
import { useEffect, useState, useRef, useCallback } from "react";
import { useSearchParams } from "next/navigation";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { getWebSocketToken } from "@/lib/supabase/actions";
// Module-level flag to prevent duplicate requests across React StrictMode re-renders
// This is keyed by session_id to allow different sessions
const attemptedSessions = new Set<string>();
interface ScopeInfo {
scope: string;
description: string;
}
interface ClientInfo {
name: string;
logo_url: string | null;
privacy_policy_url: string | null;
terms_url: string | null;
}
interface ConsentData {
needs_consent: true;
consent_token: string;
client: ClientInfo;
scopes: ScopeInfo[];
action_url: string;
}
interface RedirectData {
redirect_url: string;
needs_consent: false;
}
interface ErrorData {
error: string;
error_description: string;
redirect_url?: string;
}
type ResumeResponse = ConsentData | RedirectData | ErrorData;
function isConsentData(data: ResumeResponse): data is ConsentData {
return "needs_consent" in data && data.needs_consent === true;
}
function isRedirectData(data: ResumeResponse): data is RedirectData {
return "redirect_url" in data && !("error" in data);
}
function isErrorData(data: ResumeResponse): data is ErrorData {
return "error" in data;
}
/**
* OAuth Consent Form Component
*
* Renders a proper React component for the consent form instead of dangerouslySetInnerHTML
*/
function ConsentForm({
client,
scopes,
consentToken,
actionUrl,
}: {
client: ClientInfo;
scopes: ScopeInfo[];
consentToken: string;
actionUrl: string;
}) {
const [isSubmitting, setIsSubmitting] = useState(false);
const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
const backendOrigin = backendUrl
? new URL(backendUrl).origin
: "http://localhost:8006";
// Full action URL for form submission
const fullActionUrl = `${backendOrigin}${actionUrl}`;
function handleSubmit() {
setIsSubmitting(true);
}
return (
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800 p-5">
<div className="w-full max-w-md rounded-2xl bg-zinc-800 p-8 shadow-2xl">
{/* Header */}
<div className="mb-6 text-center">
<div className="mx-auto mb-4 flex h-16 w-16 items-center justify-center rounded-xl bg-zinc-700">
{client.logo_url ? (
<img
src={client.logo_url}
alt={client.name}
className="h-12 w-12 rounded-lg"
/>
) : (
<span className="text-3xl text-zinc-400">
{client.name.charAt(0).toUpperCase()}
</span>
)}
</div>
<h1 className="text-xl font-semibold text-zinc-100">
Authorize <span className="text-cyan-400">{client.name}</span>
</h1>
<p className="mt-2 text-sm text-zinc-400">
wants to access your AutoGPT account
</p>
</div>
{/* Divider */}
<div className="my-6 h-px bg-zinc-700" />
{/* Scopes Section */}
<div className="mb-6">
<h2 className="mb-4 text-sm font-medium text-zinc-400">
This will allow {client.name} to:
</h2>
<div className="space-y-3">
{scopes.map((scope) => (
<div
key={scope.scope}
className="flex items-start gap-3 border-b border-zinc-700 pb-3 last:border-0"
>
<svg
className="mt-0.5 h-5 w-5 flex-shrink-0 text-cyan-400"
viewBox="0 0 20 20"
fill="currentColor"
>
<path
fillRule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clipRule="evenodd"
/>
</svg>
<span className="text-sm leading-relaxed text-zinc-300">
{scope.description}
</span>
</div>
))}
</div>
</div>
{/* Form */}
<form method="POST" action={fullActionUrl} onSubmit={handleSubmit}>
<input type="hidden" name="consent_token" value={consentToken} />
<div className="flex gap-3">
<button
type="submit"
name="authorize"
value="false"
disabled={isSubmitting}
className="flex-1 rounded-lg bg-zinc-700 px-6 py-3 text-sm font-medium text-zinc-200 transition-colors hover:bg-zinc-600 disabled:cursor-not-allowed disabled:opacity-50"
>
Cancel
</button>
<button
type="submit"
name="authorize"
value="true"
disabled={isSubmitting}
className="flex-1 rounded-lg bg-cyan-400 px-6 py-3 text-sm font-medium text-slate-900 transition-colors hover:bg-cyan-300 disabled:cursor-not-allowed disabled:opacity-50"
>
{isSubmitting ? "Authorizing..." : "Allow"}
</button>
</div>
</form>
{/* Footer Links */}
{(client.privacy_policy_url || client.terms_url) && (
<div className="mt-6 text-center text-xs text-zinc-500">
{client.privacy_policy_url && (
<a
href={client.privacy_policy_url}
target="_blank"
rel="noopener noreferrer"
className="text-zinc-400 hover:underline"
>
Privacy Policy
</a>
)}
{client.privacy_policy_url && client.terms_url && (
<span className="mx-2"></span>
)}
{client.terms_url && (
<a
href={client.terms_url}
target="_blank"
rel="noopener noreferrer"
className="text-zinc-400 hover:underline"
>
Terms of Service
</a>
)}
</div>
)}
</div>
</div>
);
}
/**
* OAuth Resume Page
*
* This page handles resuming the OAuth authorization flow after a user logs in.
* It fetches the consent data from the backend via JSON API and renders the consent form.
*/
export default function OAuthResumePage() {
const searchParams = useSearchParams();
const sessionId = searchParams.get("session_id");
const { isUserLoading, refreshSession } = useSupabase();
const [consentData, setConsentData] = useState<ConsentData | null>(null);
const [error, setError] = useState<string | null>(null);
const [isLoading, setIsLoading] = useState(true);
const retryCountRef = useRef(0);
const maxRetries = 5;
const resumeOAuthFlow = useCallback(async () => {
// Prevent multiple attempts for the same session (handles React StrictMode)
if (!sessionId) {
setError(
"Missing session ID. Please start the authorization process again.",
);
setIsLoading(false);
return;
}
if (attemptedSessions.has(sessionId)) {
// Already attempted this session, don't retry
return;
}
if (isUserLoading) {
return; // Wait for auth state to load
}
// Mark this session as attempted IMMEDIATELY to prevent duplicate requests
attemptedSessions.add(sessionId);
try {
// Get the access token from server action (which reads cookies properly)
let tokenResult = await getWebSocketToken();
let accessToken = tokenResult.token;
// If no token, retry a few times with delays
while (!accessToken && retryCountRef.current < maxRetries) {
retryCountRef.current += 1;
console.log(
`Retrying to get access token (attempt ${retryCountRef.current}/${maxRetries})...`,
);
// Try refreshing the session
await refreshSession();
await new Promise((resolve) => setTimeout(resolve, 1000));
tokenResult = await getWebSocketToken();
accessToken = tokenResult.token;
}
if (!accessToken) {
setError(
"Unable to retrieve authentication token. Please log in again.",
);
setIsLoading(false);
return;
}
// Call the backend resume endpoint with JSON accept header
const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
if (!backendUrl) {
setError("Backend URL not configured.");
setIsLoading(false);
return;
}
// Extract the origin from the backend URL
let backendOrigin: string;
try {
const url = new URL(backendUrl);
backendOrigin = url.origin;
} catch {
setError("Invalid backend URL configuration.");
setIsLoading(false);
return;
}
// Use Accept: application/json to get JSON response instead of HTML
// This solves the CORS/redirect issue by letting us handle redirects client-side
const response = await fetch(
`${backendOrigin}/oauth/authorize/resume?session_id=${encodeURIComponent(sessionId)}`,
{
method: "GET",
headers: {
Authorization: `Bearer ${accessToken}`,
Accept: "application/json",
},
},
);
const data: ResumeResponse = await response.json();
if (!response.ok) {
if (isErrorData(data)) {
setError(data.error_description || data.error);
} else {
setError(
`Authorization failed (${response.status}). Please try again.`,
);
}
setIsLoading(false);
return;
}
// Handle redirect response (user already authorized these scopes)
if (isRedirectData(data)) {
window.location.href = data.redirect_url;
return;
}
// Handle consent required
if (isConsentData(data)) {
setConsentData(data);
setIsLoading(false);
return;
}
// Unexpected response
setError("Unexpected response from server. Please try again.");
setIsLoading(false);
} catch (err) {
console.error("OAuth resume error:", err);
setError(
"An error occurred while resuming authorization. Please try again.",
);
setIsLoading(false);
}
}, [sessionId, isUserLoading, refreshSession]);
useEffect(() => {
resumeOAuthFlow();
}, [resumeOAuthFlow]);
if (isLoading || isUserLoading) {
return (
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800">
<div className="text-center">
<div className="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-zinc-600 border-t-cyan-400"></div>
<p className="text-zinc-400">Resuming authorization...</p>
</div>
</div>
);
}
if (error) {
return (
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800">
<div className="mx-auto max-w-md rounded-2xl bg-zinc-800 p-8 text-center shadow-2xl">
<div className="mx-auto mb-4 h-16 w-16 text-red-500">
<svg
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
>
<circle cx="12" cy="12" r="10" />
<line x1="15" y1="9" x2="9" y2="15" />
<line x1="9" y1="9" x2="15" y2="15" />
</svg>
</div>
<h1 className="mb-2 text-xl font-semibold text-red-400">
Authorization Error
</h1>
<p className="mb-6 text-zinc-400">{error}</p>
<button
onClick={() => window.close()}
className="rounded-lg bg-zinc-700 px-6 py-3 text-sm font-medium text-zinc-200 transition-colors hover:bg-zinc-600"
>
Close
</button>
</div>
</div>
);
}
if (consentData) {
return (
<ConsentForm
client={consentData.client}
scopes={consentData.scopes}
consentToken={consentData.consent_token}
actionUrl={consentData.action_url}
/>
);
}
return null;
}

View File

@@ -3,7 +3,7 @@ import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { environment } from "@/services/environment";
import { loginFormSchema, LoginProvider } from "@/types/auth";
import { zodResolver } from "@hookform/resolvers/zod";
import { useRouter } from "next/navigation";
import { useRouter, useSearchParams } from "next/navigation";
import { useEffect, useState } from "react";
import { useForm } from "react-hook-form";
import z from "zod";
@@ -13,6 +13,7 @@ export function useLoginPage() {
const { supabase, user, isUserLoading, isLoggedIn } = useSupabase();
const [feedback, setFeedback] = useState<string | null>(null);
const router = useRouter();
const searchParams = useSearchParams();
const { toast } = useToast();
const [isLoading, setIsLoading] = useState(false);
const [isLoggingIn, setIsLoggingIn] = useState(false);
@@ -20,11 +21,59 @@ export function useLoginPage() {
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
const isCloudEnv = environment.isCloud();
// Get returnUrl, oauth_session, and connect_session from query params
const returnUrl = searchParams.get("returnUrl");
const oauthSession = searchParams.get("oauth_session");
const connectSession = searchParams.get("connect_session");
function getRedirectUrl(onboarding: boolean): string {
// OAuth session takes priority - redirect to frontend oauth-resume page
// which will handle the backend call with proper authentication
if (oauthSession) {
return `/auth/oauth-resume?session_id=${encodeURIComponent(oauthSession)}`;
}
// Connect session - redirect to frontend connect-resume page
// for integration credential connection flow
if (connectSession) {
return `/auth/connect-resume?session_id=${encodeURIComponent(connectSession)}`;
}
// If returnUrl is set and is a valid URL, redirect there after login
if (returnUrl) {
try {
const url = new URL(returnUrl, window.location.origin);
const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
// Same origin - return normalized path only (prevents open redirect)
if (url.origin === window.location.origin) {
return url.pathname + url.search;
}
// Backend URL - strict origin match (not startsWith to prevent prefix attacks)
if (backendUrl) {
try {
const backendOrigin = new URL(backendUrl).origin;
if (url.origin === backendOrigin) {
return url.href;
}
} catch {
// Invalid backend URL config, fall through to default
}
}
} catch {
// Invalid URL, fall through to default
}
}
return onboarding ? "/onboarding" : "/marketplace";
}
useEffect(() => {
if (isLoggedIn && !isLoggingIn) {
router.push("/marketplace");
const redirectTo = getRedirectUrl(false);
router.push(redirectTo);
}
}, [isLoggedIn, isLoggingIn]);
}, [isLoggedIn, isLoggingIn, returnUrl, oauthSession, connectSession]);
const form = useForm<z.infer<typeof loginFormSchema>>({
resolver: zodResolver(loginFormSchema),
@@ -39,10 +88,20 @@ export function useLoginPage() {
setIsLoggingIn(true);
try {
// Build redirect URL that preserves oauth_session or connect_session through the OAuth flow
let callbackUrl: string | undefined;
const origin = window.location.origin;
if (oauthSession) {
callbackUrl = `${origin}/auth/callback?oauth_session=${encodeURIComponent(oauthSession)}`;
} else if (connectSession) {
callbackUrl = `${origin}/auth/callback?connect_session=${encodeURIComponent(connectSession)}`;
}
const response = await fetch("/api/auth/provider", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ provider }),
body: JSON.stringify({ provider, redirectTo: callbackUrl }),
});
if (!response.ok) {
@@ -83,11 +142,7 @@ export function useLoginPage() {
throw new Error(result.error || "Login failed");
}
if (result.onboarding) {
router.replace("/onboarding");
} else {
router.replace("/marketplace");
}
router.replace(getRedirectUrl(result.onboarding ?? false));
} catch (error) {
toast({
title:

View File

@@ -0,0 +1,287 @@
"use client";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
DialogTrigger,
} from "@/components/__legacy__/ui/dialog";
import { Copy } from "@phosphor-icons/react";
import { Label } from "@/components/__legacy__/ui/label";
import { Input } from "@/components/__legacy__/ui/input";
import { Textarea } from "@/components/__legacy__/ui/textarea";
import { Button } from "@/components/__legacy__/ui/button";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/__legacy__/ui/select";
import { useOAuthClientModals } from "./useOAuthClientModals";
export function OAuthClientModals() {
const {
isCreateOpen,
setIsCreateOpen,
isSecretDialogOpen,
setIsSecretDialogOpen,
formState,
setFormState,
newClientSecret,
isCreating,
handleCreateClient,
handleCopyClientId,
handleCopyClientSecret,
handleCopyWebhookSecret,
resetForm,
} = useOAuthClientModals();
return (
<div className="mb-4 flex justify-end">
<Dialog
open={isCreateOpen}
onOpenChange={(open) => {
setIsCreateOpen(open);
if (!open) resetForm();
}}
>
<DialogTrigger asChild>
<Button>Register OAuth Client</Button>
</DialogTrigger>
<DialogContent className="max-h-[90vh] overflow-y-auto sm:max-w-[525px]">
<DialogHeader>
<DialogTitle>Register New OAuth Client</DialogTitle>
<DialogDescription>
Register a new OAuth client to integrate with the AutoGPT
Platform. For confidential clients, the client secret will only be
shown once.
</DialogDescription>
</DialogHeader>
<div className="grid gap-4 py-4">
<div className="grid gap-2">
<Label htmlFor="name">
Name <span className="text-destructive">*</span>
</Label>
<Input
id="name"
value={formState.name}
onChange={(e) =>
setFormState((prev) => ({
...prev,
name: e.target.value,
}))
}
placeholder="My Application"
maxLength={100}
/>
</div>
<div className="grid gap-2">
<Label htmlFor="description">Description</Label>
<Input
id="description"
value={formState.description}
onChange={(e) =>
setFormState((prev) => ({
...prev,
description: e.target.value,
}))
}
placeholder="A brief description of your application"
maxLength={500}
/>
</div>
<div className="grid gap-2">
<Label htmlFor="redirectUris">
Redirect URIs <span className="text-destructive">*</span>
</Label>
<Textarea
id="redirectUris"
value={formState.redirectUris}
onChange={(e) =>
setFormState((prev) => ({
...prev,
redirectUris: e.target.value,
}))
}
placeholder="https://myapp.com/callback&#10;https://localhost:3000/callback"
rows={3}
/>
<p className="text-xs text-muted-foreground">
Enter one URI per line or separate with commas
</p>
</div>
<div className="grid gap-2">
<Label htmlFor="clientType">Client Type</Label>
<Select
value={formState.clientType}
onValueChange={(value: "public" | "confidential") =>
setFormState((prev) => ({
...prev,
clientType: value,
}))
}
>
<SelectTrigger>
<SelectValue placeholder="Select client type" />
</SelectTrigger>
<SelectContent>
<SelectItem value="public">
Public (SPA, Mobile apps)
</SelectItem>
<SelectItem value="confidential">
Confidential (Server-side apps)
</SelectItem>
</SelectContent>
</Select>
<p className="text-xs text-muted-foreground">
Public clients cannot securely store secrets. Confidential
clients receive a client secret.
</p>
</div>
<div className="grid gap-2">
<Label htmlFor="homepageUrl">Homepage URL</Label>
<Input
id="homepageUrl"
type="url"
value={formState.homepageUrl}
onChange={(e) =>
setFormState((prev) => ({
...prev,
homepageUrl: e.target.value,
}))
}
placeholder="https://myapp.com"
/>
</div>
<div className="grid gap-2">
<Label htmlFor="privacyPolicyUrl">Privacy Policy URL</Label>
<Input
id="privacyPolicyUrl"
type="url"
value={formState.privacyPolicyUrl}
onChange={(e) =>
setFormState((prev) => ({
...prev,
privacyPolicyUrl: e.target.value,
}))
}
placeholder="https://myapp.com/privacy"
/>
</div>
<div className="grid gap-2">
<Label htmlFor="termsOfServiceUrl">Terms of Service URL</Label>
<Input
id="termsOfServiceUrl"
type="url"
value={formState.termsOfServiceUrl}
onChange={(e) =>
setFormState((prev) => ({
...prev,
termsOfServiceUrl: e.target.value,
}))
}
placeholder="https://myapp.com/terms"
/>
</div>
</div>
<DialogFooter>
<Button
variant="outline"
onClick={() => {
setIsCreateOpen(false);
resetForm();
}}
>
Cancel
</Button>
<Button onClick={handleCreateClient} disabled={isCreating}>
{isCreating ? "Creating..." : "Create Client"}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
<Dialog open={isSecretDialogOpen} onOpenChange={setIsSecretDialogOpen}>
<DialogContent className="sm:max-w-[525px]">
<DialogHeader>
<DialogTitle>OAuth Client Created</DialogTitle>
<DialogDescription>
Please copy your client credentials now. These secrets will not be
shown again!
</DialogDescription>
</DialogHeader>
<div className="space-y-4">
<div className="space-y-2">
<Label>Client ID</Label>
<div className="flex items-center space-x-2">
<code className="flex-1 rounded-md bg-secondary p-2 font-mono text-sm">
{newClientSecret?.client_id}
</code>
<Button
size="icon"
variant="outline"
onClick={handleCopyClientId}
>
<Copy className="h-4 w-4" weight="bold" />
</Button>
</div>
</div>
{newClientSecret?.client_secret && (
<div className="space-y-2">
<Label>Client Secret</Label>
<div className="flex items-center space-x-2">
<code className="flex-1 break-all rounded-md bg-secondary p-2 font-mono text-sm">
{newClientSecret.client_secret}
</code>
<Button
size="icon"
variant="outline"
onClick={handleCopyClientSecret}
>
<Copy className="h-4 w-4" weight="bold" />
</Button>
</div>
</div>
)}
{newClientSecret?.webhook_secret && (
<div className="space-y-2">
<Label>Webhook Secret</Label>
<div className="flex items-center space-x-2">
<code className="flex-1 break-all rounded-md bg-secondary p-2 font-mono text-sm">
{newClientSecret.webhook_secret}
</code>
<Button
size="icon"
variant="outline"
onClick={handleCopyWebhookSecret}
>
<Copy className="h-4 w-4" weight="bold" />
</Button>
</div>
<p className="text-xs text-muted-foreground">
Use this secret to verify webhook signatures (HMAC-SHA256)
</p>
</div>
)}
<p className="text-xs text-destructive">
These secrets will only be shown once. Store them securely!
</p>
</div>
<DialogFooter>
<Button onClick={() => setIsSecretDialogOpen(false)}>Close</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</div>
);
}

View File

@@ -0,0 +1,232 @@
"use client";
import {
getGetOauthClientsListClientsQueryKey,
usePostOauthClientsRegisterClient,
usePostOauthClientsRotateClientSecret,
} from "@/app/api/__generated__/endpoints/oauth-clients/oauth-clients";
import { ClientSecretResponse } from "@/app/api/__generated__/models/clientSecretResponse";
import { RegisterClientRequest } from "@/app/api/__generated__/models/registerClientRequest";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { getQueryClient } from "@/lib/react-query/queryClient";
import { validateRedirectUri } from "@/lib/utils";
import { useState } from "react";
type ClientType = "public" | "confidential";
// ClientSecretResponse already includes webhook_secret from the generated types
type ClientSecretResponseWithWebhook = ClientSecretResponse;
interface ClientFormState {
name: string;
description: string;
redirectUris: string;
clientType: ClientType;
homepageUrl: string;
privacyPolicyUrl: string;
termsOfServiceUrl: string;
}
const initialFormState: ClientFormState = {
name: "",
description: "",
redirectUris: "",
clientType: "public",
homepageUrl: "",
privacyPolicyUrl: "",
termsOfServiceUrl: "",
};
export function useOAuthClientModals() {
const [isCreateOpen, setIsCreateOpen] = useState(false);
const [isSecretDialogOpen, setIsSecretDialogOpen] = useState(false);
const [formState, setFormState] = useState<ClientFormState>(initialFormState);
const [newClientSecret, setNewClientSecret] =
useState<ClientSecretResponseWithWebhook | null>(null);
const queryClient = getQueryClient();
const { toast } = useToast();
const { mutateAsync: registerClient, isPending: isCreating } =
usePostOauthClientsRegisterClient({
mutation: {
onSettled: () => {
return queryClient.invalidateQueries({
queryKey: getGetOauthClientsListClientsQueryKey(),
});
},
},
});
const { mutateAsync: rotateSecret, isPending: isRotating } =
usePostOauthClientsRotateClientSecret({
mutation: {
onSettled: () => {
return queryClient.invalidateQueries({
queryKey: getGetOauthClientsListClientsQueryKey(),
});
},
},
});
function resetForm() {
setFormState(initialFormState);
}
async function handleCreateClient() {
// Parse redirect URIs (comma or newline separated)
const redirectUris = formState.redirectUris
.split(/[,\n]/)
.map((uri) => uri.trim())
.filter((uri) => uri.length > 0);
if (redirectUris.length === 0) {
toast({
title: "Error",
description: "At least one redirect URI is required",
variant: "destructive",
});
return;
}
// Validate each redirect URI
for (const uri of redirectUris) {
const validation = validateRedirectUri(uri);
if (!validation.valid) {
toast({
title: "Invalid Redirect URI",
description: `"${uri}": ${validation.error}`,
variant: "destructive",
});
return;
}
}
if (!formState.name.trim()) {
toast({
title: "Error",
description: "Client name is required",
variant: "destructive",
});
return;
}
try {
const requestData: RegisterClientRequest = {
name: formState.name.trim(),
redirect_uris: redirectUris,
client_type: formState.clientType,
};
if (formState.description.trim()) {
requestData.description = formState.description.trim();
}
if (formState.homepageUrl.trim()) {
requestData.homepage_url = formState.homepageUrl.trim();
}
if (formState.privacyPolicyUrl.trim()) {
requestData.privacy_policy_url = formState.privacyPolicyUrl.trim();
}
if (formState.termsOfServiceUrl.trim()) {
requestData.terms_of_service_url = formState.termsOfServiceUrl.trim();
}
const response = await registerClient({
data: requestData,
});
if (response.status === 200) {
const secretData = response.data as ClientSecretResponseWithWebhook;
setNewClientSecret(secretData);
setIsCreateOpen(false);
setIsSecretDialogOpen(true);
resetForm();
toast({
title: "Success",
description: "OAuth client created successfully",
});
}
} catch {
toast({
title: "Error",
description: "Failed to create OAuth client",
variant: "destructive",
});
}
}
async function handleRotateSecret(clientId: string) {
try {
const response = await rotateSecret({ clientId });
if (response.status === 200) {
const secretData = response.data as ClientSecretResponseWithWebhook;
setNewClientSecret(secretData);
setIsSecretDialogOpen(true);
toast({
title: "Success",
description: "Client secret rotated successfully",
});
}
} catch {
toast({
title: "Error",
description: "Failed to rotate client secret",
variant: "destructive",
});
}
}
function handleCopyClientId() {
if (newClientSecret?.client_id) {
navigator.clipboard.writeText(newClientSecret.client_id);
toast({
title: "Copied",
description: "Client ID copied to clipboard",
});
}
}
function handleCopyClientSecret() {
if (newClientSecret?.client_secret) {
navigator.clipboard.writeText(newClientSecret.client_secret);
toast({
title: "Copied",
description: "Client secret copied to clipboard",
});
}
}
function handleCopyWebhookSecret() {
if (newClientSecret?.webhook_secret) {
navigator.clipboard.writeText(newClientSecret.webhook_secret);
toast({
title: "Copied",
description: "Webhook secret copied to clipboard",
});
}
}
function handleSecretDialogChange(open: boolean) {
setIsSecretDialogOpen(open);
if (!open) {
setNewClientSecret(null);
}
}
return {
isCreateOpen,
setIsCreateOpen,
isSecretDialogOpen,
setIsSecretDialogOpen: handleSecretDialogChange,
formState,
setFormState,
newClientSecret,
isCreating,
isRotating,
handleCreateClient,
handleRotateSecret,
handleCopyClientId,
handleCopyClientSecret,
handleCopyWebhookSecret,
resetForm,
};
}

View File

@@ -0,0 +1,381 @@
"use client";
import { CircleNotch, Copy, DotsThreeVertical } from "@phosphor-icons/react";
import { Button } from "@/components/__legacy__/ui/button";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/__legacy__/ui/table";
import { Badge } from "@/components/__legacy__/ui/badge";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuSeparator,
DropdownMenuTrigger,
} from "@/components/__legacy__/ui/dropdown-menu";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/__legacy__/ui/dialog";
import { Label } from "@/components/__legacy__/ui/label";
import { Input } from "@/components/__legacy__/ui/input";
import { Textarea } from "@/components/__legacy__/ui/textarea";
import { useOAuthClientSection } from "./useOAuthClientSection";
export function OAuthClientSection() {
const {
oauthClients,
isLoading,
isDeleting,
isSuspending,
isActivating,
isRotatingWebhookSecret,
isUpdating,
handleDeleteClient,
handleSuspendClient,
handleActivateClient,
handleRotateWebhookSecret,
handleCopyWebhookSecret,
handleEditClient,
handleSaveClient,
webhookSecretDialogOpen,
setWebhookSecretDialogOpen,
newWebhookSecret,
editDialogOpen,
setEditDialogOpen,
editingClient,
editFormState,
setEditFormState,
} = useOAuthClientSection();
const isActionPending =
isDeleting ||
isSuspending ||
isActivating ||
isRotatingWebhookSecret ||
isUpdating;
return (
<>
{isLoading ? (
<div className="flex justify-center p-4">
<CircleNotch className="h-6 w-6 animate-spin" weight="bold" />
</div>
) : oauthClients && oauthClients.length > 0 ? (
<Table>
<TableHeader>
<TableRow>
<TableHead>Name</TableHead>
<TableHead>Client ID</TableHead>
<TableHead>Type</TableHead>
<TableHead>Status</TableHead>
<TableHead>Created</TableHead>
<TableHead></TableHead>
</TableRow>
</TableHeader>
<TableBody>
{oauthClients.map((client) => (
<TableRow key={client.id} data-testid="oauth-client-row">
<TableCell>
<div className="flex flex-col">
<span className="font-medium">{client.name}</span>
{client.description && (
<span className="text-xs text-muted-foreground">
{client.description}
</span>
)}
</div>
</TableCell>
<TableCell data-testid="oauth-client-id">
<div className="rounded-md border p-1 px-2 font-mono text-xs">
{client.client_id}
</div>
</TableCell>
<TableCell>
<Badge variant="outline">
{client.client_type === "confidential"
? "Confidential"
: "Public"}
</Badge>
</TableCell>
<TableCell>
<Badge
variant={
client.status === "ACTIVE" ? "default" : "destructive"
}
className={
client.status === "ACTIVE"
? "border-green-600 bg-green-100 text-green-800"
: "border-red-600 bg-red-100 text-red-800"
}
>
{client.status}
</Badge>
</TableCell>
<TableCell>
{new Date(client.created_at).toLocaleDateString()}
</TableCell>
<TableCell>
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button
data-testid="oauth-client-actions"
variant="ghost"
size="sm"
>
<DotsThreeVertical className="h-4 w-4" weight="bold" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
onClick={() => handleEditClient(client)}
disabled={isActionPending}
>
Edit
</DropdownMenuItem>
<DropdownMenuItem
onClick={() =>
handleRotateWebhookSecret(client.client_id)
}
disabled={isActionPending}
>
Rotate Webhook Secret
</DropdownMenuItem>
<DropdownMenuSeparator />
{client.status === "ACTIVE" ? (
<DropdownMenuItem
onClick={() => handleSuspendClient(client.client_id)}
disabled={isActionPending}
>
Suspend
</DropdownMenuItem>
) : (
<DropdownMenuItem
onClick={() => handleActivateClient(client.client_id)}
disabled={isActionPending}
>
Activate
</DropdownMenuItem>
)}
<DropdownMenuSeparator />
<DropdownMenuItem
className="text-destructive"
onClick={() => handleDeleteClient(client.client_id)}
disabled={isActionPending}
>
Delete
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
) : (
<div className="py-8 text-center text-muted-foreground">
No OAuth clients registered yet. Create one to get started.
</div>
)}
<Dialog
open={webhookSecretDialogOpen}
onOpenChange={setWebhookSecretDialogOpen}
>
<DialogContent className="sm:max-w-[525px]">
<DialogHeader>
<DialogTitle>Webhook Secret Rotated</DialogTitle>
<DialogDescription>
Your new webhook secret has been generated. Please copy it now as
it will not be shown again.
</DialogDescription>
</DialogHeader>
<div className="space-y-4">
<div className="space-y-2">
<Label>New Webhook Secret</Label>
<div className="flex items-center space-x-2">
<code className="flex-1 break-all rounded-md bg-secondary p-2 font-mono text-sm">
{newWebhookSecret}
</code>
<Button
size="icon"
variant="outline"
onClick={handleCopyWebhookSecret}
>
<Copy className="h-4 w-4" weight="bold" />
</Button>
</div>
<p className="text-xs text-muted-foreground">
Use this secret to verify webhook signatures (HMAC-SHA256)
</p>
</div>
<p className="text-xs text-destructive">
This secret will only be shown once. Store it securely!
</p>
</div>
<DialogFooter>
<Button onClick={() => setWebhookSecretDialogOpen(false)}>
Close
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
<Dialog open={editDialogOpen} onOpenChange={setEditDialogOpen}>
<DialogContent className="max-h-[90vh] overflow-y-auto sm:max-w-[525px]">
<DialogHeader>
<DialogTitle>Edit OAuth Client</DialogTitle>
<DialogDescription>
Update your OAuth client settings. Changes will take effect
immediately.
</DialogDescription>
</DialogHeader>
<div className="grid gap-4 py-4">
<div className="grid gap-2">
<Label htmlFor="edit-name">Name</Label>
<Input
id="edit-name"
value={editFormState.name ?? ""}
onChange={(e) =>
setEditFormState((prev) => ({
...prev,
name: e.target.value,
}))
}
placeholder="My Application"
maxLength={100}
/>
</div>
<div className="grid gap-2">
<Label htmlFor="edit-description">Description</Label>
<Input
id="edit-description"
value={editFormState.description ?? ""}
onChange={(e) =>
setEditFormState((prev) => ({
...prev,
description: e.target.value,
}))
}
placeholder="A brief description of your application"
maxLength={500}
/>
</div>
<div className="grid gap-2">
<Label htmlFor="edit-redirectUris">Redirect URIs</Label>
<Textarea
id="edit-redirectUris"
value={editFormState.redirect_uris?.join("\n") ?? ""}
onChange={(e) =>
setEditFormState((prev) => ({
...prev,
redirect_uris: e.target.value
.split(/[\n,]/)
.map((uri) => uri.trim())
.filter(Boolean),
}))
}
placeholder="https://myapp.com/callback&#10;https://localhost:3000/callback"
rows={3}
/>
<p className="text-xs text-muted-foreground">
Enter one URI per line or separate with commas
</p>
</div>
<div className="grid gap-2">
<Label htmlFor="edit-webhookDomains">Webhook Domains</Label>
<Textarea
id="edit-webhookDomains"
value={editFormState.webhook_domains?.join("\n") ?? ""}
onChange={(e) =>
setEditFormState((prev) => ({
...prev,
webhook_domains: e.target.value
.split(/[\n,]/)
.map((domain) => domain.trim())
.filter(Boolean),
}))
}
placeholder="https://myapp.com&#10;https://api.myapp.com"
rows={3}
/>
<p className="text-xs text-muted-foreground">
Domains that can receive webhook notifications
</p>
</div>
<div className="grid gap-2">
<Label htmlFor="edit-homepageUrl">Homepage URL</Label>
<Input
id="edit-homepageUrl"
type="url"
value={editFormState.homepage_url ?? ""}
onChange={(e) =>
setEditFormState((prev) => ({
...prev,
homepage_url: e.target.value || undefined,
}))
}
placeholder="https://myapp.com"
/>
</div>
<div className="grid gap-2">
<Label htmlFor="edit-privacyPolicyUrl">Privacy Policy URL</Label>
<Input
id="edit-privacyPolicyUrl"
type="url"
value={editFormState.privacy_policy_url ?? ""}
onChange={(e) =>
setEditFormState((prev) => ({
...prev,
privacy_policy_url: e.target.value || undefined,
}))
}
placeholder="https://myapp.com/privacy"
/>
</div>
<div className="grid gap-2">
<Label htmlFor="edit-termsOfServiceUrl">
Terms of Service URL
</Label>
<Input
id="edit-termsOfServiceUrl"
type="url"
value={editFormState.terms_of_service_url ?? ""}
onChange={(e) =>
setEditFormState((prev) => ({
...prev,
terms_of_service_url: e.target.value || undefined,
}))
}
placeholder="https://myapp.com/terms"
/>
</div>
</div>
<DialogFooter>
<Button variant="outline" onClick={() => setEditDialogOpen(false)}>
Cancel
</Button>
<Button onClick={handleSaveClient} disabled={isUpdating}>
{isUpdating ? "Saving..." : "Save Changes"}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</>
);
}

View File

@@ -0,0 +1,242 @@
"use client";
import { useState } from "react";
import {
getGetOauthClientsListClientsQueryKey,
useDeleteOauthClientsDeleteClient,
useGetOauthClientsListClients,
usePatchOauthClientsUpdateClient,
usePostOauthClientsActivateClient,
usePostOauthClientsRotateWebhookSecret,
usePostOauthClientsSuspendClient,
} from "@/app/api/__generated__/endpoints/oauth-clients/oauth-clients";
import type { ClientResponse } from "@/app/api/__generated__/models/clientResponse";
import type { UpdateClientRequest } from "@/app/api/__generated__/models/updateClientRequest";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { getQueryClient } from "@/lib/react-query/queryClient";
import { validateRedirectUri } from "@/lib/utils";
export function useOAuthClientSection() {
const queryClient = getQueryClient();
const { toast } = useToast();
const [webhookSecretDialogOpen, setWebhookSecretDialogOpen] = useState(false);
const [newWebhookSecret, setNewWebhookSecret] = useState<string | null>(null);
const [editDialogOpen, setEditDialogOpen] = useState(false);
const [editingClient, setEditingClient] = useState<ClientResponse | null>(
null,
);
const [editFormState, setEditFormState] = useState<UpdateClientRequest>({});
const { data: oauthClients, isLoading } = useGetOauthClientsListClients({
query: {
select: (res) => {
if (res.status !== 200) return undefined;
return res.data;
},
},
});
const { mutateAsync: deleteClient, isPending: isDeleting } =
useDeleteOauthClientsDeleteClient({
mutation: {
onSettled: () => {
return queryClient.invalidateQueries({
queryKey: getGetOauthClientsListClientsQueryKey(),
});
},
},
});
const { mutateAsync: suspendClient, isPending: isSuspending } =
usePostOauthClientsSuspendClient({
mutation: {
onSettled: () => {
return queryClient.invalidateQueries({
queryKey: getGetOauthClientsListClientsQueryKey(),
});
},
},
});
const { mutateAsync: activateClient, isPending: isActivating } =
usePostOauthClientsActivateClient({
mutation: {
onSettled: () => {
return queryClient.invalidateQueries({
queryKey: getGetOauthClientsListClientsQueryKey(),
});
},
},
});
const {
mutateAsync: rotateWebhookSecret,
isPending: isRotatingWebhookSecret,
} = usePostOauthClientsRotateWebhookSecret();
const { mutateAsync: updateClient, isPending: isUpdating } =
usePatchOauthClientsUpdateClient({
mutation: {
onSettled: () => {
return queryClient.invalidateQueries({
queryKey: getGetOauthClientsListClientsQueryKey(),
});
},
},
});
async function handleDeleteClient(clientId: string) {
try {
await deleteClient({ clientId });
toast({
title: "Success",
description: "OAuth client deleted successfully",
});
} catch {
toast({
title: "Error",
description: "Failed to delete OAuth client",
variant: "destructive",
});
}
}
async function handleSuspendClient(clientId: string) {
try {
await suspendClient({ clientId });
toast({
title: "Success",
description: "OAuth client suspended successfully",
});
} catch {
toast({
title: "Error",
description: "Failed to suspend OAuth client",
variant: "destructive",
});
}
}
async function handleActivateClient(clientId: string) {
try {
await activateClient({ clientId });
toast({
title: "Success",
description: "OAuth client activated successfully",
});
} catch {
toast({
title: "Error",
description: "Failed to activate OAuth client",
variant: "destructive",
});
}
}
async function handleRotateWebhookSecret(clientId: string) {
try {
const response = await rotateWebhookSecret({ clientId });
if (response.status === 200) {
setNewWebhookSecret(response.data.webhook_secret);
setWebhookSecretDialogOpen(true);
toast({
title: "Success",
description: "Webhook secret rotated successfully",
});
}
} catch {
toast({
title: "Error",
description: "Failed to rotate webhook secret",
variant: "destructive",
});
}
}
function handleCopyWebhookSecret() {
if (newWebhookSecret) {
navigator.clipboard.writeText(newWebhookSecret);
toast({
title: "Copied",
description: "Webhook secret copied to clipboard",
});
}
}
function handleEditClient(client: ClientResponse) {
setEditingClient(client);
setEditFormState({
name: client.name,
description: client.description ?? undefined,
homepage_url: client.homepage_url ?? undefined,
privacy_policy_url: client.privacy_policy_url ?? undefined,
terms_of_service_url: client.terms_of_service_url ?? undefined,
redirect_uris: client.redirect_uris,
webhook_domains: client.webhook_domains,
});
setEditDialogOpen(true);
}
async function handleSaveClient() {
if (!editingClient) return;
// Validate redirect URIs before saving
if (editFormState.redirect_uris) {
for (const uri of editFormState.redirect_uris) {
const validation = validateRedirectUri(uri);
if (!validation.valid) {
toast({
title: "Invalid Redirect URI",
description: `"${uri}": ${validation.error}`,
variant: "destructive",
});
return;
}
}
}
try {
await updateClient({
clientId: editingClient.client_id,
data: editFormState,
});
toast({
title: "Success",
description: "OAuth client updated successfully",
});
setEditDialogOpen(false);
setEditingClient(null);
} catch {
toast({
title: "Error",
description: "Failed to update OAuth client",
variant: "destructive",
});
}
}
return {
oauthClients,
isLoading,
isDeleting,
isSuspending,
isActivating,
isRotatingWebhookSecret,
isUpdating,
handleDeleteClient,
handleSuspendClient,
handleActivateClient,
handleRotateWebhookSecret,
handleCopyWebhookSecret,
handleEditClient,
handleSaveClient,
webhookSecretDialogOpen,
setWebhookSecretDialogOpen,
newWebhookSecret,
editDialogOpen,
setEditDialogOpen,
editingClient,
editFormState,
setEditFormState,
};
}

View File

@@ -0,0 +1,37 @@
import { Metadata } from "next/types";
import { OAuthClientSection } from "@/app/(platform)/profile/(user)/developer/components/OAuthClientSection/OAuthClientSection";
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle,
} from "@/components/__legacy__/ui/card";
import { OAuthClientModals } from "./components/OAuthClientModals/OAuthClientModals";
export const metadata: Metadata = {
title: "Developer Settings - AutoGPT Platform",
};
function DeveloperPage() {
return (
<div className="w-full pr-4 pt-24 md:pt-0">
<Card>
<CardHeader>
<CardTitle>OAuth Applications</CardTitle>
<CardDescription>
Register and manage OAuth clients to integrate third-party
applications with the AutoGPT Platform. OAuth clients allow external
applications to access AutoGPT APIs on behalf of users.
</CardDescription>
</CardHeader>
<CardContent>
<OAuthClientModals />
<OAuthClientSection />
</CardContent>
</Card>
</div>
);
}
export default DeveloperPage;

File diff suppressed because it is too large Load Diff

View File

@@ -883,6 +883,30 @@ export const IconUploadCloud = createIcon((props) => (
</svg>
));
/**
* Code icon component for developer settings.
*
* @component IconCode
* @param {IconProps} props - The props object containing additional attributes and event handlers for the icon.
* @returns {JSX.Element} - The code icon.
*/
export const IconCode = createIcon((props) => (
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
aria-label="Code Icon"
{...props}
>
<polyline points="16 18 22 12 16 6" />
<polyline points="8 6 2 12 8 18" />
</svg>
));
/**
* Chevron up icon component.
*
@@ -1838,6 +1862,7 @@ export enum IconType {
AutoGPTLogo,
Sliders,
Chat,
Code,
}
export function getIconForSocial(

View File

@@ -1,5 +1,6 @@
import {
IconBuilder,
IconCode,
IconEdit,
IconLibrary,
IconLogOut,
@@ -130,10 +131,15 @@ export function getAccountMenuItems(userRole?: string): MenuItemGroup[] {
});
}
// Add settings and logout
// Add developer settings and settings
baseMenuItems.push(
{
items: [
{
icon: IconType.Code,
text: "Developer",
href: "/profile/developer",
},
{
icon: IconType.Settings,
text: "Settings",
@@ -177,6 +183,8 @@ export function getAccountMenuOptionIcon(icon: IconType) {
return <IconSliders className={iconClass} />;
case IconType.Chat:
return <ChatsIcon className={iconClass} />;
case IconType.Code:
return <IconCode className={iconClass} />;
default:
return <IconRefresh className={iconClass} />;
}

View File

@@ -445,3 +445,31 @@ export function validateYouTubeUrl(val: string): boolean {
return false;
}
}
/**
* Validate OAuth redirect URI.
* Allows HTTPS URLs or http://localhost for development.
*/
export function validateRedirectUri(uri: string): {
valid: boolean;
error?: string;
} {
try {
const url = new URL(uri);
if (url.protocol === "https:") {
return { valid: true };
}
if (
url.protocol === "http:" &&
(url.hostname === "localhost" || url.hostname === "127.0.0.1")
) {
return { valid: true };
}
return {
valid: false,
error: "Must be HTTPS (or http://localhost for development)",
};
} catch {
return { valid: false, error: "Invalid URL format" };
}
}

View File

@@ -0,0 +1,701 @@
# External API Integration Guide
This guide explains how third-party applications can integrate with AutoGPT Platform to execute agents on behalf of users using the OAuth Provider and Credential Broker system.
## Overview
The AutoGPT External API allows your application to:
- **Execute agents** - Run user-owned or marketplace agents with user-granted credentials
- **Access integrations** - Use third-party service credentials (Google, GitHub, etc.) that users have connected
- **Receive webhooks** - Get notified when agent executions complete
The integration uses standard OAuth 2.0 with PKCE for secure authentication, with API key-based user identification during the authorization flow.
## Architecture
```
┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐
│ Your App │────▶│ AutoGPT OAuth │────▶│ External API │
│ │ │ Provider │ │ │
└─────────────────┘ └──────────────────┘ └─────────────────┘
┌──────────────────┐
│ Credential │
│ Broker │
└──────────────────┘
```
**Key concepts:**
1. **OAuth Client** - Your registered application with AutoGPT
2. **API Key** - User's AutoGPT API key for identifying the user during authorization
3. **OAuth Tokens** - Access/refresh tokens for API authentication
4. **Credential Grants** - User permissions to use their connected integrations
5. **Integration Scopes** - Specific permissions for each integration (e.g., `google:gmail.readonly`)
## Getting Started
### 1. Register Your OAuth Client
Register your application to get a `client_id` and `client_secret`:
```bash
# Requires user authentication (JWT token)
curl -X POST https://platform.agpt.co/oauth/clients/ \
-H "Authorization: Bearer <user_jwt>" \
-H "Content-Type: application/json" \
-d '{
"name": "My App",
"description": "Description of your app",
"client_type": "confidential",
"redirect_uris": ["https://myapp.com/oauth/callback"],
"webhook_domains": ["myapp.com", "*.myapp.com"]
}'
```
**Response:**
```json
{
"client_id": "app_abc123xyz",
"client_secret": "secret_xyz789..."
}
```
> **Important:** Store the `client_secret` securely - it's only shown once!
**Client types:**
- `confidential` - Server-side apps that can securely store secrets
- `public` - Browser/mobile apps (no client secret)
### 2. OAuth Authorization Flow
Use the standard OAuth 2.0 Authorization Code flow with PKCE to get user consent and tokens.
#### Authentication Methods
The authorization endpoint supports two authentication methods:
1. **API Key (Recommended for server-side apps)**: Pass the user's AutoGPT API key via `X-API-Key` header. This shows the consent page directly without requiring a browser login.
2. **Login Flow (For browser-based apps)**: If no API key is provided, the user is redirected to the AutoGPT login page, which then continues the OAuth flow automatically after login.
#### Generate PKCE Parameters
```javascript
// Generate code verifier and challenge
function generateCodeVerifier() {
const array = new Uint8Array(32);
crypto.getRandomValues(array);
return base64UrlEncode(array);
}
async function generateCodeChallenge(verifier) {
const encoder = new TextEncoder();
const data = encoder.encode(verifier);
const digest = await crypto.subtle.digest('SHA-256', data);
return base64UrlEncode(new Uint8Array(digest));
}
```
#### Request Authorization
**Option 1: With API Key (Server-side apps)**
If you have the user's API key, make a server-side request to get the consent page directly:
```javascript
const state = crypto.randomUUID(); // Store this for validation
const codeVerifier = generateCodeVerifier(); // Store this securely
const codeChallenge = await generateCodeChallenge(codeVerifier);
const authUrl = 'https://platform.agpt.co/oauth/authorize?' + new URLSearchParams({
response_type: 'code',
client_id: CLIENT_ID,
redirect_uri: REDIRECT_URI,
state: state,
code_challenge: codeChallenge,
code_challenge_method: 'S256',
scope: 'openid profile email agents:execute integrations:connect',
});
// Server-side request with user's API key
const response = await fetch(authUrl, {
headers: {
'X-API-Key': userApiKey, // User's AutoGPT API key (optional)
},
});
// Response is HTML consent page - render it to the user
const consentHtml = await response.text();
```
**Option 2: Browser Login Flow (No API Key)**
If you don't have the user's API key, redirect them to the authorization URL. They will be prompted to log in to AutoGPT, and the OAuth flow will continue automatically after login:
```javascript
const state = crypto.randomUUID(); // Store this for validation
const codeVerifier = generateCodeVerifier(); // Store this securely
const codeChallenge = await generateCodeChallenge(codeVerifier);
const authUrl = 'https://platform.agpt.co/oauth/authorize?' + new URLSearchParams({
response_type: 'code',
client_id: CLIENT_ID,
redirect_uri: REDIRECT_URI,
state: state,
code_challenge: codeChallenge,
code_challenge_method: 'S256',
scope: 'openid profile email agents:execute integrations:connect',
});
// Redirect user to AutoGPT login page
// After login, they'll see the consent page and be redirected to your callback
window.location.href = authUrl;
```
**Authentication methods (in order of preference):**
| Method | Header | Description |
|--------|--------|-------------|
| API Key | `X-API-Key: agpt_xxx` | User's AutoGPT API key (preferred for external apps) |
| JWT Token | `Authorization: Bearer <jwt>` | Supabase JWT token |
| Cookie | `access_token` cookie | Browser-based authentication |
**Available scopes:**
| Scope | Description |
|-------|-------------|
| `openid` | Required for OIDC |
| `profile` | Access user profile (name) |
| `email` | Access user email |
| `agents:execute` | Execute agents and check status |
| `integrations:list` | List user's credential grants |
| `integrations:connect` | Request new credential grants |
| `integrations:delete` | Delete credentials via grants |
#### Handle OAuth Callback
```javascript
// Your callback endpoint receives: ?code=xxx&state=xxx
app.get('/oauth/callback', async (req, res) => {
const { code, state } = req.query;
// Verify state matches what you stored
if (state !== storedState) {
return res.status(400).send('Invalid state');
}
// Exchange code for tokens
const response = await fetch('https://platform.agpt.co/oauth/token', {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: new URLSearchParams({
grant_type: 'authorization_code',
code,
redirect_uri: REDIRECT_URI,
client_id: CLIENT_ID,
client_secret: CLIENT_SECRET,
code_verifier: storedCodeVerifier,
}),
});
const tokens = await response.json();
// { access_token, refresh_token, token_type, expires_in }
});
```
### 3. Request Credential Grants (Connect Flow)
Before executing agents that require integrations (like Gmail, Google Sheets, GitHub), you need credential grants from the user.
#### Open Connect Popup
```javascript
function requestCredentialGrant(provider, scopes) {
const nonce = crypto.randomUUID();
// Store nonce to validate response
sessionStorage.setItem('connect_nonce', nonce);
const connectUrl = new URL(`https://platform.agpt.co/connect/${provider}`);
connectUrl.searchParams.set('client_id', CLIENT_ID);
connectUrl.searchParams.set('scopes', scopes.join(','));
connectUrl.searchParams.set('nonce', nonce);
connectUrl.searchParams.set('redirect_origin', window.location.origin);
// Open popup (user must be logged into AutoGPT)
const popup = window.open(
connectUrl.toString(),
'AutoGPT Connect',
'width=500,height=600,popup=true'
);
// Listen for result
window.addEventListener('message', handleConnectResult, { once: true });
}
function handleConnectResult(event) {
// Verify origin
if (event.origin !== 'https://platform.agpt.co') return;
const data = event.data;
if (data.type !== 'autogpt_connect_result') return;
// Verify nonce
if (data.nonce !== sessionStorage.getItem('connect_nonce')) return;
if (data.success) {
console.log('Grant created:', data.grant_id);
console.log('Credential ID:', data.credential_id);
console.log('Provider:', data.provider);
} else {
console.error('Connect failed:', data.error, data.error_description);
}
}
```
**Integration scopes by provider:**
| Provider | Available Scopes |
|----------|------------------|
| Google | `google:gmail.readonly`, `google:gmail.send`, `google:sheets.read`, `google:sheets.write`, `google:calendar.read`, `google:calendar.write`, `google:drive.read`, `google:drive.write` |
| GitHub | `github:repo.read`, `github:repo.write`, `github:user.read` |
| Twitter/X | `twitter:tweet.read`, `twitter:tweet.write`, `twitter:user.read` |
| Linear | `linear:read`, `linear:write` |
| Notion | `notion:read`, `notion:write` |
| Slack | `slack:read`, `slack:write` |
### 4. Execute Agents
With an OAuth token and credential grants, you can execute agents:
```javascript
async function executeAgent(agentId, inputs, grantIds = null, webhookUrl = null) {
const response = await fetch(
`https://platform.agpt.co/api/external/v1/executions/agents/${agentId}/execute`,
{
method: 'POST',
headers: {
'Authorization': `Bearer ${accessToken}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
inputs,
grant_ids: grantIds, // Optional: specific grants to use
webhook_url: webhookUrl, // Optional: receive completion webhook
}),
}
);
const result = await response.json();
// { execution_id, status: "queued", message: "..." }
return result;
}
```
### 5. Check Execution Status
Poll for execution status or use webhooks:
```javascript
async function getExecutionStatus(executionId) {
const response = await fetch(
`https://platform.agpt.co/api/external/v1/executions/${executionId}`,
{
headers: { 'Authorization': `Bearer ${accessToken}` },
}
);
return await response.json();
// {
// execution_id,
// status: "queued" | "running" | "completed" | "failed",
// started_at,
// completed_at,
// outputs, // Present when completed
// error, // Present when failed
// }
}
```
### 6. Handle Webhooks
If you provided a `webhook_url`, you'll receive POST requests with execution events:
```javascript
app.post('/webhooks/autogpt', (req, res) => {
// Verify webhook signature (if configured)
const signature = req.headers['x-webhook-signature'];
const timestamp = req.headers['x-webhook-timestamp'];
if (signature) {
const expectedSignature = crypto
.createHmac('sha256', WEBHOOK_SECRET)
.update(JSON.stringify(req.body))
.digest('hex');
if (signature !== `sha256=${expectedSignature}`) {
return res.status(401).send('Invalid signature');
}
}
const { event, timestamp, data } = req.body;
switch (event) {
case 'execution.started':
console.log(`Execution ${data.execution_id} started`);
break;
case 'execution.completed':
console.log(`Execution ${data.execution_id} completed`, data.outputs);
break;
case 'execution.failed':
console.error(`Execution ${data.execution_id} failed:`, data.error);
break;
case 'grant.revoked':
console.log(`Grant ${data.grant_id} was revoked`);
break;
}
res.status(200).send('OK');
});
```
> **Note:** Webhook URLs must match domains registered in your OAuth client's `webhook_domains`.
## API Reference
### External API Endpoints
Base URL: `https://platform.agpt.co/api/external/v1`
| Method | Endpoint | Scope Required | Description |
|--------|----------|----------------|-------------|
| GET | `/executions/capabilities` | `agents:execute` | Get available grants and scopes |
| POST | `/executions/agents/{agent_id}/execute` | `agents:execute` | Execute an agent |
| GET | `/executions/{execution_id}` | `agents:execute` | Get execution status |
| POST | `/executions/{execution_id}/cancel` | `agents:execute` | Cancel execution |
| GET | `/grants/` | `integrations:list` | List credential grants |
| GET | `/grants/{grant_id}` | `integrations:list` | Get grant details |
| DELETE | `/grants/{grant_id}/credential` | `integrations:delete` | Delete credential via grant |
### OAuth Endpoints
Base URL: `https://platform.agpt.co`
| Method | Endpoint | Auth Required | Description |
|--------|----------|---------------|-------------|
| GET | `/oauth/authorize` | API Key or JWT | Authorization endpoint |
| POST | `/oauth/token` | None | Token endpoint |
| GET | `/oauth/userinfo` | OAuth Token | OIDC UserInfo |
| POST | `/oauth/revoke` | OAuth Token | Revoke tokens |
| GET | `/.well-known/openid-configuration` | None | OIDC Discovery |
### Client Management Endpoints
| Method | Endpoint | Description |
|--------|----------|-------------|
| POST | `/oauth/clients/` | Register new client |
| GET | `/oauth/clients/` | List your clients |
| GET | `/oauth/clients/{client_id}` | Get client details |
| PATCH | `/oauth/clients/{client_id}` | Update client |
| DELETE | `/oauth/clients/{client_id}` | Delete client |
| POST | `/oauth/clients/{client_id}/rotate-secret` | Rotate client secret |
## Rate Limits
| Endpoint Type | Limit |
|--------------|-------|
| OAuth endpoints | 20-30 requests/minute |
| Agent execution | 10 requests/minute, 100/hour |
| Read endpoints | 60 requests/minute, 1000/hour |
Rate limit headers are included in responses:
- `X-RateLimit-Remaining` - Requests remaining in current window
- `X-RateLimit-Reset` - Unix timestamp when limit resets
- `Retry-After` - Seconds to wait (when rate limited)
## Error Handling
### OAuth Errors
```json
{
"error": "invalid_grant",
"error_description": "Authorization code has expired"
}
```
Common OAuth errors:
- `invalid_client` - Unknown or invalid client
- `invalid_grant` - Expired/invalid authorization code
- `access_denied` - User denied consent
- `invalid_scope` - Requested scope not allowed
- `login_required` - User authentication required (provide API key or JWT)
### API Errors
```json
{
"detail": "Grant validation failed: No valid grants found for requested integrations"
}
```
HTTP status codes:
- `400` - Bad request (invalid parameters)
- `401` - Unauthorized (invalid/expired token or missing API key)
- `403` - Forbidden (insufficient scopes or grants)
- `404` - Resource not found
- `429` - Rate limited
- `500` - Internal server error
## Security Best Practices
1. **Store secrets securely** - Never expose `client_secret` in client-side code
2. **Protect API keys** - User API keys should be stored securely and transmitted over HTTPS
3. **Validate state parameter** - Prevent CSRF attacks
4. **Use PKCE** - Required for all authorization flows
5. **Verify webhook signatures** - Prevent spoofed webhooks
6. **Request minimal scopes** - Only request what you need
7. **Handle token refresh** - Refresh tokens before they expire
8. **Validate redirect origins** - Only accept messages from expected origins
## Complete Integration Example
```javascript
class AutoGPTClient {
constructor(clientId, clientSecret, redirectUri) {
this.clientId = clientId;
this.clientSecret = clientSecret;
this.redirectUri = redirectUri;
this.baseUrl = 'https://platform.agpt.co';
}
// Step 1: Build authorization URL
async buildAuthorizationUrl(scopes) {
const state = crypto.randomUUID();
const codeVerifier = this.generateCodeVerifier();
const codeChallenge = await this.generateCodeChallenge(codeVerifier);
// Store for callback
this.pendingAuth = { state, codeVerifier };
const params = new URLSearchParams({
response_type: 'code',
client_id: this.clientId,
redirect_uri: this.redirectUri,
state: state,
code_challenge: codeChallenge,
code_challenge_method: 'S256',
scope: scopes.join(' '),
});
return `${this.baseUrl}/oauth/authorize?${params}`;
}
// Step 1a: Start authorization with user's API key (server-side apps)
async startAuthorizationWithApiKey(userApiKey, scopes) {
const authUrl = await this.buildAuthorizationUrl(scopes);
// Request authorization with user's API key
const response = await fetch(authUrl, {
headers: {
'X-API-Key': userApiKey,
},
});
if (!response.ok) {
const error = await response.text();
throw new Error(`Authorization failed: ${error}`);
}
// Return consent page HTML for rendering to user
return await response.text();
}
// Step 1b: Start authorization via login flow (browser-based apps)
async startAuthorizationWithLogin(scopes) {
const authUrl = await this.buildAuthorizationUrl(scopes);
// Redirect user to AutoGPT login - they'll be redirected back after login
window.location.href = authUrl;
}
// Step 2: Exchange code for tokens (after user consents)
async exchangeCode(code, state) {
if (state !== this.pendingAuth?.state) {
throw new Error('Invalid state');
}
const response = await fetch(`${this.baseUrl}/oauth/token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: new URLSearchParams({
grant_type: 'authorization_code',
code,
redirect_uri: this.redirectUri,
client_id: this.clientId,
client_secret: this.clientSecret,
code_verifier: this.pendingAuth.codeVerifier,
}),
});
if (!response.ok) {
throw new Error(await response.text());
}
return response.json();
}
// Step 3: Request credential grant via popup
requestGrant(provider, scopes) {
return new Promise((resolve, reject) => {
const nonce = crypto.randomUUID();
const url = new URL(`${this.baseUrl}/connect/${provider}`);
url.searchParams.set('client_id', this.clientId);
url.searchParams.set('scopes', scopes.join(','));
url.searchParams.set('nonce', nonce);
url.searchParams.set('redirect_origin', window.location.origin);
const popup = window.open(url.toString(), 'connect', 'width=500,height=600');
const handler = (event) => {
if (event.origin !== this.baseUrl) return;
if (event.data?.type !== 'autogpt_connect_result') return;
if (event.data?.nonce !== nonce) return;
window.removeEventListener('message', handler);
if (event.data.success) {
resolve(event.data);
} else {
reject(new Error(event.data.error_description));
}
};
window.addEventListener('message', handler);
});
}
// Step 4: Execute agent
async executeAgent(accessToken, agentId, inputs, options = {}) {
const response = await fetch(
`${this.baseUrl}/api/external/v1/executions/agents/${agentId}/execute`,
{
method: 'POST',
headers: {
'Authorization': `Bearer ${accessToken}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
inputs,
grant_ids: options.grantIds,
webhook_url: options.webhookUrl,
}),
}
);
if (!response.ok) {
throw new Error(await response.text());
}
return response.json();
}
// Step 5: Poll for completion
async waitForCompletion(accessToken, executionId, timeoutMs = 300000) {
const startTime = Date.now();
while (Date.now() - startTime < timeoutMs) {
const response = await fetch(
`${this.baseUrl}/api/external/v1/executions/${executionId}`,
{ headers: { 'Authorization': `Bearer ${accessToken}` } }
);
const status = await response.json();
if (status.status === 'completed') {
return status.outputs;
}
if (status.status === 'failed') {
throw new Error(status.error || 'Execution failed');
}
// Wait before polling again
await new Promise(resolve => setTimeout(resolve, 2000));
}
throw new Error('Execution timeout');
}
// Helper methods
generateCodeVerifier() {
const array = new Uint8Array(32);
crypto.getRandomValues(array);
return this.base64UrlEncode(array);
}
async generateCodeChallenge(verifier) {
const encoder = new TextEncoder();
const data = encoder.encode(verifier);
const digest = await crypto.subtle.digest('SHA-256', data);
return this.base64UrlEncode(new Uint8Array(digest));
}
base64UrlEncode(buffer) {
return btoa(String.fromCharCode(...buffer))
.replace(/\+/g, '-')
.replace(/\//g, '_')
.replace(/=+$/, '');
}
}
// Usage
const client = new AutoGPTClient(
'app_abc123',
'secret_xyz789',
'https://myapp.com/oauth/callback'
);
const scopes = ['openid', 'profile', 'agents:execute', 'integrations:connect'];
// Option A: Start authorization with user's API key (server-side apps)
// User provides their AutoGPT API key (from Settings > Developer)
const userApiKey = 'agpt_xxx...';
const consentHtml = await client.startAuthorizationWithApiKey(userApiKey, scopes);
// Render consent page HTML to user in popup/iframe
// Option B: Start authorization via login flow (browser-based apps)
// No API key needed - user will log in to AutoGPT
await client.startAuthorizationWithLogin(scopes);
// User is redirected to AutoGPT login, then back to your callback
// 2. After user consents, your callback receives ?code=xxx&state=xxx
// Exchange code for tokens
const tokens = await client.exchangeCode(code, state);
// 3. Request Google credentials (user must be logged into AutoGPT in browser)
const grant = await client.requestGrant('google', ['google:gmail.readonly']);
// 4. Execute an agent
const execution = await client.executeAgent(
tokens.access_token,
'agent-uuid-here',
{ query: 'Search my emails for invoices' },
{ grantIds: [grant.grant_id] }
);
// 5. Wait for results
const outputs = await client.waitForCompletion(
tokens.access_token,
execution.execution_id
);
console.log('Agent outputs:', outputs);
```
## Support
- [GitHub Issues](https://github.com/Significant-Gravitas/AutoGPT/issues) - Bug reports and feature requests
- [Discord Community](https://discord.gg/autogpt) - Community support

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -23,6 +23,10 @@ nav:
- Using AI/ML API: platform/aimlapi.md
- Using D-ID: platform/d_id.md
- Blocks: platform/blocks/blocks.md
- External API Integration: platform/external-api-integration.md
- Framework Integrations:
- Next.js: platform/integrations/nextjs.md
- Ruby on Rails: platform/integrations/rails.md
- Contributing:
- Tests: platform/contributing/tests.md
- OAuth Flows: platform/contributing/oauth-integration-flow.md