mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-12 16:48:06 -05:00
Compare commits
10 Commits
dev
...
swiftyos/o
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4c851414a9 | ||
|
|
7e145734c7 | ||
|
|
0d0c426209 | ||
|
|
2a4d474ca4 | ||
|
|
9e67f0bf45 | ||
|
|
0b25c643a4 | ||
|
|
9430ea2354 | ||
|
|
15033d8ebf | ||
|
|
741d1b40aa | ||
|
|
1acb18f5ff |
@@ -6,7 +6,7 @@ from typing import Optional
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission, APIKeyStatus
|
||||
from prisma.models import APIKey as PrismaAPIKey
|
||||
from prisma.types import APIKeyWhereUniqueInput
|
||||
from prisma.types import APIKeyCreateInput, APIKeyWhereUniqueInput
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.includes import MAX_USER_API_KEYS_FETCH
|
||||
@@ -83,17 +83,17 @@ async def create_api_key(
|
||||
generated_key = keysmith.generate_key()
|
||||
|
||||
saved_key_obj = await PrismaAPIKey.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"head": generated_key.head,
|
||||
"tail": generated_key.tail,
|
||||
"hash": generated_key.hash,
|
||||
"salt": generated_key.salt,
|
||||
"permissions": [p for p in permissions],
|
||||
"description": description,
|
||||
"userId": user_id,
|
||||
}
|
||||
data=APIKeyCreateInput(
|
||||
id=str(uuid.uuid4()),
|
||||
name=name,
|
||||
head=generated_key.head,
|
||||
tail=generated_key.tail,
|
||||
hash=generated_key.hash,
|
||||
salt=generated_key.salt,
|
||||
permissions=permissions,
|
||||
description=description,
|
||||
userId=user_id,
|
||||
)
|
||||
)
|
||||
|
||||
return APIKeyInfo.from_db(saved_key_obj), generated_key.key
|
||||
|
||||
327
autogpt_platform/backend/backend/data/credential_grants.py
Normal file
327
autogpt_platform/backend/backend/data/credential_grants.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""
|
||||
Credential Grant data layer.
|
||||
|
||||
Handles database operations for credential grants which allow OAuth clients
|
||||
to use credentials on behalf of users.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from prisma.enums import CredentialGrantPermission
|
||||
from prisma.models import CredentialGrant
|
||||
|
||||
from backend.data.db import prisma
|
||||
|
||||
|
||||
async def create_credential_grant(
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
credential_id: str,
|
||||
provider: str,
|
||||
granted_scopes: list[str],
|
||||
permissions: list[CredentialGrantPermission],
|
||||
expires_at: Optional[datetime] = None,
|
||||
) -> CredentialGrant:
|
||||
"""
|
||||
Create a new credential grant.
|
||||
|
||||
Args:
|
||||
user_id: ID of the user granting access
|
||||
client_id: Database ID of the OAuth client
|
||||
credential_id: ID of the credential being granted
|
||||
provider: Provider name (e.g., "google", "github")
|
||||
granted_scopes: List of integration scopes granted
|
||||
permissions: List of permissions (USE, DELETE)
|
||||
expires_at: Optional expiration datetime
|
||||
|
||||
Returns:
|
||||
Created CredentialGrant
|
||||
"""
|
||||
return await prisma.credentialgrant.create(
|
||||
data={ # type: ignore[typeddict-item]
|
||||
"userId": user_id,
|
||||
"clientId": client_id,
|
||||
"credentialId": credential_id,
|
||||
"provider": provider,
|
||||
"grantedScopes": granted_scopes,
|
||||
"permissions": permissions,
|
||||
"expiresAt": expires_at,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def get_credential_grant(
|
||||
grant_id: str,
|
||||
user_id: Optional[str] = None,
|
||||
client_id: Optional[str] = None,
|
||||
) -> Optional[CredentialGrant]:
|
||||
"""
|
||||
Get a credential grant by ID.
|
||||
|
||||
Args:
|
||||
grant_id: Grant ID
|
||||
user_id: Optional user ID filter
|
||||
client_id: Optional client database ID filter
|
||||
|
||||
Returns:
|
||||
CredentialGrant or None
|
||||
"""
|
||||
where: dict[str, str] = {"id": grant_id}
|
||||
if user_id:
|
||||
where["userId"] = user_id
|
||||
if client_id:
|
||||
where["clientId"] = client_id
|
||||
|
||||
return await prisma.credentialgrant.find_first(where=where) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def get_grants_for_user_client(
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
include_revoked: bool = False,
|
||||
include_expired: bool = False,
|
||||
) -> list[CredentialGrant]:
|
||||
"""
|
||||
Get all credential grants for a user-client pair.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
client_id: Client database ID
|
||||
include_revoked: Include revoked grants
|
||||
include_expired: Include expired grants
|
||||
|
||||
Returns:
|
||||
List of CredentialGrant objects
|
||||
"""
|
||||
where: dict[str, str | None] = {
|
||||
"userId": user_id,
|
||||
"clientId": client_id,
|
||||
}
|
||||
|
||||
if not include_revoked:
|
||||
where["revokedAt"] = None
|
||||
|
||||
grants = await prisma.credentialgrant.find_many(
|
||||
where=where, # type: ignore[arg-type]
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
|
||||
# Filter expired if needed
|
||||
if not include_expired:
|
||||
now = datetime.now(timezone.utc)
|
||||
grants = [g for g in grants if g.expiresAt is None or g.expiresAt > now]
|
||||
|
||||
return grants
|
||||
|
||||
|
||||
async def get_grants_for_credential(
|
||||
user_id: str,
|
||||
credential_id: str,
|
||||
) -> list[CredentialGrant]:
|
||||
"""
|
||||
Get all active grants for a specific credential.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
credential_id: Credential ID
|
||||
|
||||
Returns:
|
||||
List of active CredentialGrant objects
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
grants = await prisma.credentialgrant.find_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"credentialId": credential_id,
|
||||
"revokedAt": None,
|
||||
},
|
||||
include={"Client": True},
|
||||
)
|
||||
|
||||
# Filter expired
|
||||
return [g for g in grants if g.expiresAt is None or g.expiresAt > now]
|
||||
|
||||
|
||||
async def get_grant_by_credential_and_client(
|
||||
user_id: str,
|
||||
credential_id: str,
|
||||
client_id: str,
|
||||
) -> Optional[CredentialGrant]:
|
||||
"""
|
||||
Get the grant for a specific credential and client.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
credential_id: Credential ID
|
||||
client_id: Client database ID
|
||||
|
||||
Returns:
|
||||
CredentialGrant or None
|
||||
"""
|
||||
return await prisma.credentialgrant.find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"credentialId": credential_id,
|
||||
"clientId": client_id,
|
||||
"revokedAt": None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def update_grant_scopes(
|
||||
grant_id: str,
|
||||
granted_scopes: list[str],
|
||||
) -> CredentialGrant:
|
||||
"""
|
||||
Update the granted scopes for a credential grant.
|
||||
|
||||
Args:
|
||||
grant_id: Grant ID
|
||||
granted_scopes: New list of granted scopes
|
||||
|
||||
Returns:
|
||||
Updated CredentialGrant
|
||||
"""
|
||||
result = await prisma.credentialgrant.update(
|
||||
where={"id": grant_id},
|
||||
data={"grantedScopes": granted_scopes},
|
||||
)
|
||||
if result is None:
|
||||
raise ValueError(f"Grant {grant_id} not found")
|
||||
return result
|
||||
|
||||
|
||||
async def update_grant_last_used(grant_id: str) -> None:
|
||||
"""
|
||||
Update the lastUsedAt timestamp for a grant.
|
||||
|
||||
Args:
|
||||
grant_id: Grant ID
|
||||
"""
|
||||
await prisma.credentialgrant.update(
|
||||
where={"id": grant_id},
|
||||
data={"lastUsedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
async def revoke_grant(grant_id: str) -> CredentialGrant:
|
||||
"""
|
||||
Revoke a credential grant.
|
||||
|
||||
Args:
|
||||
grant_id: Grant ID
|
||||
|
||||
Returns:
|
||||
Revoked CredentialGrant
|
||||
"""
|
||||
result = await prisma.credentialgrant.update(
|
||||
where={"id": grant_id},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
if result is None:
|
||||
raise ValueError(f"Grant {grant_id} not found")
|
||||
return result
|
||||
|
||||
|
||||
async def revoke_grants_for_credential(
|
||||
user_id: str,
|
||||
credential_id: str,
|
||||
) -> int:
|
||||
"""
|
||||
Revoke all grants for a specific credential.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
credential_id: Credential ID
|
||||
|
||||
Returns:
|
||||
Number of grants revoked
|
||||
"""
|
||||
return await prisma.credentialgrant.update_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"credentialId": credential_id,
|
||||
"revokedAt": None,
|
||||
},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
async def revoke_grants_for_client(
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
) -> int:
|
||||
"""
|
||||
Revoke all grants for a specific client.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
client_id: Client database ID
|
||||
|
||||
Returns:
|
||||
Number of grants revoked
|
||||
"""
|
||||
return await prisma.credentialgrant.update_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"clientId": client_id,
|
||||
"revokedAt": None,
|
||||
},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
async def delete_grant(grant_id: str) -> None:
|
||||
"""
|
||||
Permanently delete a credential grant.
|
||||
|
||||
Args:
|
||||
grant_id: Grant ID
|
||||
"""
|
||||
await prisma.credentialgrant.delete(where={"id": grant_id})
|
||||
|
||||
|
||||
async def check_grant_permission(
|
||||
grant_id: str,
|
||||
required_permission: CredentialGrantPermission,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a grant has a specific permission.
|
||||
|
||||
Args:
|
||||
grant_id: Grant ID
|
||||
required_permission: Permission to check
|
||||
|
||||
Returns:
|
||||
True if grant has the permission
|
||||
"""
|
||||
grant = await prisma.credentialgrant.find_unique(where={"id": grant_id})
|
||||
if not grant:
|
||||
return False
|
||||
|
||||
return required_permission in grant.permissions
|
||||
|
||||
|
||||
async def is_grant_valid(grant_id: str) -> bool:
|
||||
"""
|
||||
Check if a grant is valid (not revoked and not expired).
|
||||
|
||||
Args:
|
||||
grant_id: Grant ID
|
||||
|
||||
Returns:
|
||||
True if grant is valid
|
||||
"""
|
||||
grant = await prisma.credentialgrant.find_unique(where={"id": grant_id})
|
||||
if not grant:
|
||||
return False
|
||||
|
||||
if grant.revokedAt:
|
||||
return False
|
||||
|
||||
if grant.expiresAt and grant.expiresAt < datetime.now(timezone.utc):
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -11,6 +11,7 @@ import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
@@ -21,11 +22,11 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for ceiling tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
data=UserCreateInput(
|
||||
id=user_id,
|
||||
email=f"test-{user_id}@example.com",
|
||||
name=f"Test User {user_id[:8]}",
|
||||
)
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -33,7 +34,10 @@ async def create_test_user(user_id: str) -> None:
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=user_id, balance=0),
|
||||
update={"balance": 0},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
@@ -28,11 +29,11 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user with initial balance."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
data=UserCreateInput(
|
||||
id=user_id,
|
||||
email=f"test-{user_id}@example.com",
|
||||
name=f"Test User {user_id[:8]}",
|
||||
)
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -41,7 +42,10 @@ async def create_test_user(user_id: str) -> None:
|
||||
# Ensure UserBalance record exists
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=user_id, balance=0),
|
||||
update={"balance": 0},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -342,10 +346,10 @@ async def test_integer_overflow_protection(server: SpinTestServer):
|
||||
# First, set balance near max
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": max_int - 100},
|
||||
"update": {"balance": max_int - 100},
|
||||
},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=user_id, balance=max_int - 100),
|
||||
update={"balance": max_int - 100},
|
||||
),
|
||||
)
|
||||
|
||||
# Try to add more than possible - should clamp to POSTGRES_INT_MAX
|
||||
|
||||
@@ -8,6 +8,7 @@ which would have caught the CreditTransactionType enum casting bug.
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserCreateInput
|
||||
|
||||
from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
@@ -29,12 +30,12 @@ async def cleanup_test_user():
|
||||
# Create the user first
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"topUpConfig": SafeJson({}),
|
||||
"timezone": "UTC",
|
||||
}
|
||||
data=UserCreateInput(
|
||||
id=user_id,
|
||||
email=f"test-{user_id}@example.com",
|
||||
topUpConfig=SafeJson({}),
|
||||
timezone="UTC",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# User might already exist, that's fine
|
||||
|
||||
@@ -12,6 +12,12 @@ import pytest
|
||||
import stripe
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
|
||||
from prisma.types import (
|
||||
CreditRefundRequestCreateInput,
|
||||
CreditTransactionCreateInput,
|
||||
UserBalanceCreateInput,
|
||||
UserCreateInput,
|
||||
)
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
@@ -35,32 +41,32 @@ async def setup_test_user_with_topup():
|
||||
|
||||
# Create user
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": REFUND_TEST_USER_ID,
|
||||
"email": f"{REFUND_TEST_USER_ID}@example.com",
|
||||
"name": "Refund Test User",
|
||||
}
|
||||
data=UserCreateInput(
|
||||
id=REFUND_TEST_USER_ID,
|
||||
email=f"{REFUND_TEST_USER_ID}@example.com",
|
||||
name="Refund Test User",
|
||||
)
|
||||
)
|
||||
|
||||
# Create user balance
|
||||
await UserBalance.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"balance": 1000, # $10
|
||||
}
|
||||
data=UserBalanceCreateInput(
|
||||
userId=REFUND_TEST_USER_ID,
|
||||
balance=1000, # $10
|
||||
)
|
||||
)
|
||||
|
||||
# Create a top-up transaction that can be refunded
|
||||
topup_tx = await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 1000,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"transactionKey": "pi_test_12345",
|
||||
"runningBalance": 1000,
|
||||
"isActive": True,
|
||||
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
|
||||
}
|
||||
data=CreditTransactionCreateInput(
|
||||
userId=REFUND_TEST_USER_ID,
|
||||
amount=1000,
|
||||
type=CreditTransactionType.TOP_UP,
|
||||
transactionKey="pi_test_12345",
|
||||
runningBalance=1000,
|
||||
isActive=True,
|
||||
metadata=SafeJson({"stripe_payment_intent": "pi_test_12345"}),
|
||||
)
|
||||
)
|
||||
|
||||
return topup_tx
|
||||
@@ -93,12 +99,12 @@ async def test_deduct_credits_atomic(server: SpinTestServer):
|
||||
|
||||
# Create refund request record (simulating webhook flow)
|
||||
await CreditRefundRequest.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 500,
|
||||
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
|
||||
"reason": "Test refund",
|
||||
}
|
||||
data=CreditRefundRequestCreateInput(
|
||||
userId=REFUND_TEST_USER_ID,
|
||||
amount=500,
|
||||
transactionKey=topup_tx.transactionKey, # Should match the original transaction
|
||||
reason="Test refund",
|
||||
)
|
||||
)
|
||||
|
||||
# Call deduct_credits
|
||||
@@ -286,12 +292,12 @@ async def test_concurrent_refunds(server: SpinTestServer):
|
||||
refund_requests = []
|
||||
for i in range(5):
|
||||
req = await CreditRefundRequest.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 100, # $1 each
|
||||
"transactionKey": topup_tx.transactionKey,
|
||||
"reason": f"Test refund {i}",
|
||||
}
|
||||
data=CreditRefundRequestCreateInput(
|
||||
userId=REFUND_TEST_USER_ID,
|
||||
amount=100, # $1 each
|
||||
transactionKey=topup_tx.transactionKey,
|
||||
reason=f"Test refund {i}",
|
||||
)
|
||||
)
|
||||
refund_requests.append(req)
|
||||
|
||||
|
||||
@@ -3,6 +3,11 @@ from datetime import datetime, timedelta, timezone
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditTransaction, UserBalance
|
||||
from prisma.types import (
|
||||
CreditTransactionCreateInput,
|
||||
UserBalanceCreateInput,
|
||||
UserBalanceUpsertInput,
|
||||
)
|
||||
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.data.block import get_block
|
||||
@@ -23,10 +28,10 @@ async def disable_test_user_transactions():
|
||||
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": DEFAULT_USER_ID},
|
||||
data={
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
|
||||
"update": {"balance": 0, "updatedAt": old_date},
|
||||
},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=0),
|
||||
update={"balance": 0, "updatedAt": old_date},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -140,23 +145,23 @@ async def test_block_credit_reset(server: SpinTestServer):
|
||||
|
||||
# Manually create a transaction with month 1 timestamp to establish history
|
||||
await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": 100,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"runningBalance": 1100,
|
||||
"isActive": True,
|
||||
"createdAt": month1, # Set specific timestamp
|
||||
}
|
||||
data=CreditTransactionCreateInput(
|
||||
userId=DEFAULT_USER_ID,
|
||||
amount=100,
|
||||
type=CreditTransactionType.TOP_UP,
|
||||
runningBalance=1100,
|
||||
isActive=True,
|
||||
createdAt=month1, # Set specific timestamp
|
||||
)
|
||||
)
|
||||
|
||||
# Update user balance to match
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": DEFAULT_USER_ID},
|
||||
data={
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
|
||||
"update": {"balance": 1100},
|
||||
},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=1100),
|
||||
update={"balance": 1100},
|
||||
),
|
||||
)
|
||||
|
||||
# Now test month 2 behavior
|
||||
@@ -175,14 +180,14 @@ async def test_block_credit_reset(server: SpinTestServer):
|
||||
|
||||
# Create a month 2 transaction to update the last transaction time
|
||||
await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": -700, # Spent 700 to get to 400
|
||||
"type": CreditTransactionType.USAGE,
|
||||
"runningBalance": 400,
|
||||
"isActive": True,
|
||||
"createdAt": month2,
|
||||
}
|
||||
data=CreditTransactionCreateInput(
|
||||
userId=DEFAULT_USER_ID,
|
||||
amount=-700, # Spent 700 to get to 400
|
||||
type=CreditTransactionType.USAGE,
|
||||
runningBalance=400,
|
||||
isActive=True,
|
||||
createdAt=month2,
|
||||
)
|
||||
)
|
||||
|
||||
# Move to month 3
|
||||
|
||||
@@ -12,6 +12,7 @@ import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import POSTGRES_INT_MIN, UserCredit
|
||||
from backend.util.test import SpinTestServer
|
||||
@@ -21,11 +22,11 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for underflow tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
data=UserCreateInput(
|
||||
id=user_id,
|
||||
email=f"test-{user_id}@example.com",
|
||||
name=f"Test User {user_id[:8]}",
|
||||
)
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -33,7 +34,10 @@ async def create_test_user(user_id: str) -> None:
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=user_id, balance=0),
|
||||
update={"balance": 0},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -66,14 +70,14 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
|
||||
initial_balance_target = POSTGRES_INT_MIN + 100
|
||||
|
||||
# Use direct database update to set the balance close to underflow
|
||||
from prisma.models import UserBalance
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance_target},
|
||||
"update": {"balance": initial_balance_target},
|
||||
},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(
|
||||
userId=user_id, balance=initial_balance_target
|
||||
),
|
||||
update={"balance": initial_balance_target},
|
||||
),
|
||||
)
|
||||
|
||||
current_balance = await credit_system.get_credits(user_id)
|
||||
@@ -110,10 +114,10 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
|
||||
# Set balance to exactly POSTGRES_INT_MIN
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
|
||||
"update": {"balance": POSTGRES_INT_MIN},
|
||||
},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=user_id, balance=POSTGRES_INT_MIN),
|
||||
update={"balance": POSTGRES_INT_MIN},
|
||||
),
|
||||
)
|
||||
|
||||
edge_balance = await credit_system.get_credits(user_id)
|
||||
@@ -147,15 +151,13 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
|
||||
# Set up balance close to underflow threshold to test the protection
|
||||
# Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000
|
||||
# This should trigger underflow protection
|
||||
from prisma.models import UserBalance
|
||||
|
||||
test_balance = POSTGRES_INT_MIN + 1000
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": test_balance},
|
||||
"update": {"balance": test_balance},
|
||||
},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=user_id, balance=test_balance),
|
||||
update={"balance": test_balance},
|
||||
),
|
||||
)
|
||||
|
||||
current_balance = await credit_system.get_credits(user_id)
|
||||
@@ -212,15 +214,13 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
|
||||
|
||||
try:
|
||||
# Set up balance close to underflow threshold
|
||||
from prisma.models import UserBalance
|
||||
|
||||
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=user_id, balance=initial_balance),
|
||||
update={"balance": initial_balance},
|
||||
),
|
||||
)
|
||||
|
||||
# Apply multiple refunds that would cumulatively underflow
|
||||
@@ -290,15 +290,13 @@ async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
|
||||
|
||||
try:
|
||||
# Set up balance close to underflow threshold
|
||||
from prisma.models import UserBalance
|
||||
|
||||
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=user_id, balance=initial_balance),
|
||||
update={"balance": initial_balance},
|
||||
),
|
||||
)
|
||||
|
||||
async def large_refund(amount: int, label: str):
|
||||
|
||||
@@ -14,6 +14,7 @@ import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceCreateInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import UsageTransactionMetadata, UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
@@ -24,11 +25,11 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for migration tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
data=UserCreateInput(
|
||||
id=user_id,
|
||||
email=f"test-{user_id}@example.com",
|
||||
name=f"Test User {user_id[:8]}",
|
||||
)
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -121,7 +122,7 @@ async def test_detect_stale_user_balance_queries(server: SpinTestServer):
|
||||
try:
|
||||
# Create UserBalance with specific value
|
||||
await UserBalance.prisma().create(
|
||||
data={"userId": user_id, "balance": 5000} # $50
|
||||
data=UserBalanceCreateInput(userId=user_id, balance=5000) # $50
|
||||
)
|
||||
|
||||
# Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
|
||||
@@ -160,7 +161,9 @@ async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer
|
||||
|
||||
try:
|
||||
# Set initial balance in UserBalance
|
||||
await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})
|
||||
await UserBalance.prisma().create(
|
||||
data=UserBalanceCreateInput(userId=user_id, balance=1000)
|
||||
)
|
||||
|
||||
# Run concurrent operations to ensure they all use UserBalance atomic operations
|
||||
async def concurrent_spend(amount: int, label: str):
|
||||
|
||||
@@ -27,6 +27,7 @@ from prisma.models import (
|
||||
AgentNodeExecutionKeyValueData,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionCreateInput,
|
||||
AgentGraphExecutionUpdateManyMutationInput,
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionCreateInput,
|
||||
@@ -34,7 +35,7 @@ from prisma.types import (
|
||||
AgentNodeExecutionKeyValueDataCreateInput,
|
||||
AgentNodeExecutionUpdateInput,
|
||||
AgentNodeExecutionWhereInput,
|
||||
AgentNodeExecutionWhereUniqueInput,
|
||||
_AgentNodeExecutionWhereUnique_id_Input,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
|
||||
from pydantic.fields import Field
|
||||
@@ -71,6 +72,13 @@ logger = logging.getLogger(__name__)
|
||||
config = Config()
|
||||
|
||||
|
||||
class GrantResolverContext(BaseModel):
|
||||
"""Context for grant-based credential resolution in external API executions."""
|
||||
|
||||
client_db_id: str # The OAuth client database UUID
|
||||
grant_ids: list[str] # List of grant IDs to use for credential resolution
|
||||
|
||||
|
||||
class ExecutionContext(BaseModel):
|
||||
"""
|
||||
Unified context that carries execution-level data throughout the entire execution flow.
|
||||
@@ -81,6 +89,8 @@ class ExecutionContext(BaseModel):
|
||||
user_timezone: str = "UTC"
|
||||
root_execution_id: Optional[str] = None
|
||||
parent_execution_id: Optional[str] = None
|
||||
# For external API executions using credential grants
|
||||
grant_resolver_context: Optional[GrantResolverContext] = None
|
||||
|
||||
|
||||
# -------------------------- Models -------------------------- #
|
||||
@@ -705,18 +715,18 @@ async def create_graph_execution(
|
||||
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
|
||||
"""
|
||||
result = await AgentGraphExecution.prisma().create(
|
||||
data={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"inputs": SafeJson(inputs),
|
||||
"credentialInputs": (
|
||||
data=AgentGraphExecutionCreateInput(
|
||||
agentGraphId=graph_id,
|
||||
agentGraphVersion=graph_version,
|
||||
executionStatus=ExecutionStatus.INCOMPLETE,
|
||||
inputs=SafeJson(inputs),
|
||||
credentialInputs=(
|
||||
SafeJson(credential_inputs) if credential_inputs else Json({})
|
||||
),
|
||||
"nodesInputMasks": (
|
||||
nodesInputMasks=(
|
||||
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
|
||||
),
|
||||
"NodeExecutions": {
|
||||
NodeExecutions={
|
||||
"create": [
|
||||
AgentNodeExecutionCreateInput(
|
||||
agentNodeId=node_id,
|
||||
@@ -732,10 +742,10 @@ async def create_graph_execution(
|
||||
for node_id, node_input in starting_nodes_input
|
||||
]
|
||||
},
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
"parentGraphExecutionId": parent_graph_exec_id,
|
||||
},
|
||||
userId=user_id,
|
||||
agentPresetId=preset_id,
|
||||
parentGraphExecutionId=parent_graph_exec_id,
|
||||
),
|
||||
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
|
||||
@@ -827,10 +837,10 @@ async def upsert_execution_output(
|
||||
"""
|
||||
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
|
||||
"""
|
||||
data: AgentNodeExecutionInputOutputCreateInput = {
|
||||
"name": output_name,
|
||||
"referencedByOutputExecId": node_exec_id,
|
||||
}
|
||||
data = AgentNodeExecutionInputOutputCreateInput(
|
||||
name=output_name,
|
||||
referencedByOutputExecId=node_exec_id,
|
||||
)
|
||||
if output_data is not None:
|
||||
data["data"] = SafeJson(output_data)
|
||||
await AgentNodeExecutionInputOutput.prisma().create(data=data)
|
||||
@@ -948,7 +958,7 @@ async def update_node_execution_status(
|
||||
|
||||
if res := await AgentNodeExecution.prisma().update(
|
||||
where=cast(
|
||||
AgentNodeExecutionWhereUniqueInput,
|
||||
_AgentNodeExecutionWhereUnique_id_Input,
|
||||
{
|
||||
"id": node_exec_id,
|
||||
"executionStatus": {"in": [s.value for s in allowed_from]},
|
||||
|
||||
@@ -10,7 +10,11 @@ from typing import Optional
|
||||
|
||||
from prisma.enums import ReviewStatus
|
||||
from prisma.models import PendingHumanReview
|
||||
from prisma.types import PendingHumanReviewUpdateInput
|
||||
from prisma.types import (
|
||||
PendingHumanReviewCreateInput,
|
||||
PendingHumanReviewUpdateInput,
|
||||
PendingHumanReviewUpsertInput,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.server.v2.executions.review.model import (
|
||||
@@ -66,20 +70,20 @@ async def get_or_create_human_review(
|
||||
# Upsert - get existing or create new review
|
||||
review = await PendingHumanReview.prisma().upsert(
|
||||
where={"nodeExecId": node_exec_id},
|
||||
data={
|
||||
"create": {
|
||||
"userId": user_id,
|
||||
"nodeExecId": node_exec_id,
|
||||
"graphExecId": graph_exec_id,
|
||||
"graphId": graph_id,
|
||||
"graphVersion": graph_version,
|
||||
"payload": SafeJson(input_data),
|
||||
"instructions": message,
|
||||
"editable": editable,
|
||||
"status": ReviewStatus.WAITING,
|
||||
},
|
||||
"update": {}, # Do nothing on update - keep existing review as is
|
||||
},
|
||||
data=PendingHumanReviewUpsertInput(
|
||||
create=PendingHumanReviewCreateInput(
|
||||
userId=user_id,
|
||||
nodeExecId=node_exec_id,
|
||||
graphExecId=graph_exec_id,
|
||||
graphId=graph_id,
|
||||
graphVersion=graph_version,
|
||||
payload=SafeJson(input_data),
|
||||
instructions=message,
|
||||
editable=editable,
|
||||
status=ReviewStatus.WAITING,
|
||||
),
|
||||
update={}, # Do nothing on update - keep existing review as is
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
||||
302
autogpt_platform/backend/backend/data/integration_scopes.py
Normal file
302
autogpt_platform/backend/backend/data/integration_scopes.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""
|
||||
Integration scopes mapping.
|
||||
|
||||
Maps AutoGPT's fine-grained integration scopes to provider-specific OAuth scopes.
|
||||
These scopes are used to request granular permissions when connecting integrations
|
||||
through the Credential Broker.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
|
||||
class IntegrationScope(str, Enum):
|
||||
"""
|
||||
Fine-grained integration scopes for credential grants.
|
||||
|
||||
Format: {provider}:{resource}.{permission}
|
||||
"""
|
||||
|
||||
# Google scopes
|
||||
GOOGLE_EMAIL_READ = "google:email.read"
|
||||
GOOGLE_GMAIL_READONLY = "google:gmail.readonly"
|
||||
GOOGLE_GMAIL_SEND = "google:gmail.send"
|
||||
GOOGLE_GMAIL_MODIFY = "google:gmail.modify"
|
||||
GOOGLE_DRIVE_READONLY = "google:drive.readonly"
|
||||
GOOGLE_DRIVE_FILE = "google:drive.file"
|
||||
GOOGLE_CALENDAR_READONLY = "google:calendar.readonly"
|
||||
GOOGLE_CALENDAR_EVENTS = "google:calendar.events"
|
||||
GOOGLE_SHEETS_READONLY = "google:sheets.readonly"
|
||||
GOOGLE_SHEETS = "google:sheets"
|
||||
GOOGLE_DOCS_READONLY = "google:docs.readonly"
|
||||
GOOGLE_DOCS = "google:docs"
|
||||
|
||||
# GitHub scopes
|
||||
GITHUB_REPOS_READ = "github:repos.read"
|
||||
GITHUB_REPOS_WRITE = "github:repos.write"
|
||||
GITHUB_ISSUES_READ = "github:issues.read"
|
||||
GITHUB_ISSUES_WRITE = "github:issues.write"
|
||||
GITHUB_USER_READ = "github:user.read"
|
||||
GITHUB_GISTS = "github:gists"
|
||||
GITHUB_NOTIFICATIONS = "github:notifications"
|
||||
|
||||
# Discord scopes
|
||||
DISCORD_IDENTIFY = "discord:identify"
|
||||
DISCORD_EMAIL = "discord:email"
|
||||
DISCORD_GUILDS = "discord:guilds"
|
||||
DISCORD_MESSAGES_READ = "discord:messages.read"
|
||||
|
||||
# Twitter scopes
|
||||
TWITTER_READ = "twitter:read"
|
||||
TWITTER_WRITE = "twitter:write"
|
||||
TWITTER_DM = "twitter:dm"
|
||||
|
||||
# Notion scopes
|
||||
NOTION_READ = "notion:read"
|
||||
NOTION_WRITE = "notion:write"
|
||||
|
||||
# Todoist scopes
|
||||
TODOIST_READ = "todoist:read"
|
||||
TODOIST_WRITE = "todoist:write"
|
||||
|
||||
|
||||
# Scope descriptions for consent UI
|
||||
INTEGRATION_SCOPE_DESCRIPTIONS: dict[str, str] = {
|
||||
# Google
|
||||
IntegrationScope.GOOGLE_EMAIL_READ.value: "Read your email address",
|
||||
IntegrationScope.GOOGLE_GMAIL_READONLY.value: "Read your Gmail messages",
|
||||
IntegrationScope.GOOGLE_GMAIL_SEND.value: "Send emails on your behalf",
|
||||
IntegrationScope.GOOGLE_GMAIL_MODIFY.value: "Read, send, and manage your emails",
|
||||
IntegrationScope.GOOGLE_DRIVE_READONLY.value: "View files in your Google Drive",
|
||||
IntegrationScope.GOOGLE_DRIVE_FILE.value: "Create and edit files in Google Drive",
|
||||
IntegrationScope.GOOGLE_CALENDAR_READONLY.value: "View your calendar",
|
||||
IntegrationScope.GOOGLE_CALENDAR_EVENTS.value: "Create and edit calendar events",
|
||||
IntegrationScope.GOOGLE_SHEETS_READONLY.value: "View your spreadsheets",
|
||||
IntegrationScope.GOOGLE_SHEETS.value: "Create and edit spreadsheets",
|
||||
IntegrationScope.GOOGLE_DOCS_READONLY.value: "View your documents",
|
||||
IntegrationScope.GOOGLE_DOCS.value: "Create and edit documents",
|
||||
# GitHub
|
||||
IntegrationScope.GITHUB_REPOS_READ.value: "Read repository information",
|
||||
IntegrationScope.GITHUB_REPOS_WRITE.value: "Create and manage repositories",
|
||||
IntegrationScope.GITHUB_ISSUES_READ.value: "Read issues and pull requests",
|
||||
IntegrationScope.GITHUB_ISSUES_WRITE.value: "Create and manage issues",
|
||||
IntegrationScope.GITHUB_USER_READ.value: "Read your GitHub profile",
|
||||
IntegrationScope.GITHUB_GISTS.value: "Create and manage gists",
|
||||
IntegrationScope.GITHUB_NOTIFICATIONS.value: "Access notifications",
|
||||
# Discord
|
||||
IntegrationScope.DISCORD_IDENTIFY.value: "Access your Discord username",
|
||||
IntegrationScope.DISCORD_EMAIL.value: "Access your Discord email",
|
||||
IntegrationScope.DISCORD_GUILDS.value: "View your server list",
|
||||
IntegrationScope.DISCORD_MESSAGES_READ.value: "Read messages",
|
||||
# Twitter
|
||||
IntegrationScope.TWITTER_READ.value: "Read tweets and profile",
|
||||
IntegrationScope.TWITTER_WRITE.value: "Post tweets on your behalf",
|
||||
IntegrationScope.TWITTER_DM.value: "Send and read direct messages",
|
||||
# Notion
|
||||
IntegrationScope.NOTION_READ.value: "View Notion pages",
|
||||
IntegrationScope.NOTION_WRITE.value: "Create and edit Notion pages",
|
||||
# Todoist
|
||||
IntegrationScope.TODOIST_READ.value: "View your tasks",
|
||||
IntegrationScope.TODOIST_WRITE.value: "Create and manage tasks",
|
||||
}
|
||||
|
||||
|
||||
# Mapping from integration scopes to provider OAuth scopes
|
||||
INTEGRATION_SCOPE_MAPPING: dict[str, dict[str, list[str]]] = {
|
||||
ProviderName.GOOGLE.value: {
|
||||
IntegrationScope.GOOGLE_EMAIL_READ.value: [
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"openid",
|
||||
],
|
||||
IntegrationScope.GOOGLE_GMAIL_READONLY.value: [
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
],
|
||||
IntegrationScope.GOOGLE_GMAIL_SEND.value: [
|
||||
"https://www.googleapis.com/auth/gmail.send",
|
||||
],
|
||||
IntegrationScope.GOOGLE_GMAIL_MODIFY.value: [
|
||||
"https://www.googleapis.com/auth/gmail.modify",
|
||||
],
|
||||
IntegrationScope.GOOGLE_DRIVE_READONLY.value: [
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
],
|
||||
IntegrationScope.GOOGLE_DRIVE_FILE.value: [
|
||||
"https://www.googleapis.com/auth/drive.file",
|
||||
],
|
||||
IntegrationScope.GOOGLE_CALENDAR_READONLY.value: [
|
||||
"https://www.googleapis.com/auth/calendar.readonly",
|
||||
],
|
||||
IntegrationScope.GOOGLE_CALENDAR_EVENTS.value: [
|
||||
"https://www.googleapis.com/auth/calendar.events",
|
||||
],
|
||||
IntegrationScope.GOOGLE_SHEETS_READONLY.value: [
|
||||
"https://www.googleapis.com/auth/spreadsheets.readonly",
|
||||
],
|
||||
IntegrationScope.GOOGLE_SHEETS.value: [
|
||||
"https://www.googleapis.com/auth/spreadsheets",
|
||||
],
|
||||
IntegrationScope.GOOGLE_DOCS_READONLY.value: [
|
||||
"https://www.googleapis.com/auth/documents.readonly",
|
||||
],
|
||||
IntegrationScope.GOOGLE_DOCS.value: [
|
||||
"https://www.googleapis.com/auth/documents",
|
||||
],
|
||||
},
|
||||
ProviderName.GITHUB.value: {
|
||||
IntegrationScope.GITHUB_REPOS_READ.value: [
|
||||
"repo:status",
|
||||
"public_repo",
|
||||
],
|
||||
IntegrationScope.GITHUB_REPOS_WRITE.value: [
|
||||
"repo",
|
||||
],
|
||||
IntegrationScope.GITHUB_ISSUES_READ.value: [
|
||||
"repo:status",
|
||||
],
|
||||
IntegrationScope.GITHUB_ISSUES_WRITE.value: [
|
||||
"repo",
|
||||
],
|
||||
IntegrationScope.GITHUB_USER_READ.value: [
|
||||
"read:user",
|
||||
"user:email",
|
||||
],
|
||||
IntegrationScope.GITHUB_GISTS.value: [
|
||||
"gist",
|
||||
],
|
||||
IntegrationScope.GITHUB_NOTIFICATIONS.value: [
|
||||
"notifications",
|
||||
],
|
||||
},
|
||||
ProviderName.DISCORD.value: {
|
||||
IntegrationScope.DISCORD_IDENTIFY.value: [
|
||||
"identify",
|
||||
],
|
||||
IntegrationScope.DISCORD_EMAIL.value: [
|
||||
"email",
|
||||
],
|
||||
IntegrationScope.DISCORD_GUILDS.value: [
|
||||
"guilds",
|
||||
],
|
||||
IntegrationScope.DISCORD_MESSAGES_READ.value: [
|
||||
"messages.read",
|
||||
],
|
||||
},
|
||||
ProviderName.TWITTER.value: {
|
||||
IntegrationScope.TWITTER_READ.value: [
|
||||
"tweet.read",
|
||||
"users.read",
|
||||
],
|
||||
IntegrationScope.TWITTER_WRITE.value: [
|
||||
"tweet.write",
|
||||
],
|
||||
IntegrationScope.TWITTER_DM.value: [
|
||||
"dm.read",
|
||||
"dm.write",
|
||||
],
|
||||
},
|
||||
ProviderName.NOTION.value: {
|
||||
IntegrationScope.NOTION_READ.value: [], # Notion uses workspace-level access
|
||||
IntegrationScope.NOTION_WRITE.value: [],
|
||||
},
|
||||
ProviderName.TODOIST.value: {
|
||||
IntegrationScope.TODOIST_READ.value: [
|
||||
"data:read",
|
||||
],
|
||||
IntegrationScope.TODOIST_WRITE.value: [
|
||||
"data:read_write",
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_provider_scopes(
|
||||
provider: ProviderName | str, integration_scopes: list[str]
|
||||
) -> list[str]:
|
||||
"""
|
||||
Convert integration scopes to provider-specific OAuth scopes.
|
||||
|
||||
Args:
|
||||
provider: The provider name
|
||||
integration_scopes: List of integration scope strings
|
||||
|
||||
Returns:
|
||||
List of provider-specific OAuth scopes
|
||||
"""
|
||||
provider_value = provider.value if isinstance(provider, ProviderName) else provider
|
||||
provider_mapping = INTEGRATION_SCOPE_MAPPING.get(provider_value, {})
|
||||
|
||||
oauth_scopes: set[str] = set()
|
||||
for scope in integration_scopes:
|
||||
if scope in provider_mapping:
|
||||
oauth_scopes.update(provider_mapping[scope])
|
||||
|
||||
return list(oauth_scopes)
|
||||
|
||||
|
||||
def get_provider_for_scope(scope: str) -> Optional[ProviderName]:
|
||||
"""
|
||||
Get the provider for an integration scope.
|
||||
|
||||
Args:
|
||||
scope: Integration scope string (e.g., "google:gmail.readonly")
|
||||
|
||||
Returns:
|
||||
ProviderName or None if not recognized
|
||||
"""
|
||||
if ":" not in scope:
|
||||
return None
|
||||
|
||||
provider_prefix = scope.split(":")[0]
|
||||
|
||||
# Map prefixes to providers
|
||||
prefix_mapping = {
|
||||
"google": ProviderName.GOOGLE,
|
||||
"github": ProviderName.GITHUB,
|
||||
"discord": ProviderName.DISCORD,
|
||||
"twitter": ProviderName.TWITTER,
|
||||
"notion": ProviderName.NOTION,
|
||||
"todoist": ProviderName.TODOIST,
|
||||
}
|
||||
|
||||
return prefix_mapping.get(provider_prefix)
|
||||
|
||||
|
||||
def validate_integration_scopes(scopes: list[str]) -> tuple[bool, list[str]]:
|
||||
"""
|
||||
Validate a list of integration scopes.
|
||||
|
||||
Args:
|
||||
scopes: List of integration scope strings
|
||||
|
||||
Returns:
|
||||
Tuple of (valid, invalid_scopes)
|
||||
"""
|
||||
valid_scopes = {s.value for s in IntegrationScope}
|
||||
invalid = [s for s in scopes if s not in valid_scopes]
|
||||
return len(invalid) == 0, invalid
|
||||
|
||||
|
||||
def group_scopes_by_provider(
|
||||
scopes: list[str],
|
||||
) -> dict[ProviderName, list[str]]:
|
||||
"""
|
||||
Group integration scopes by their provider.
|
||||
|
||||
Args:
|
||||
scopes: List of integration scope strings
|
||||
|
||||
Returns:
|
||||
Dictionary mapping providers to their scopes
|
||||
"""
|
||||
grouped: dict[ProviderName, list[str]] = {}
|
||||
|
||||
for scope in scopes:
|
||||
provider = get_provider_for_scope(scope)
|
||||
if provider:
|
||||
if provider not in grouped:
|
||||
grouped[provider] = []
|
||||
grouped[provider].append(scope)
|
||||
|
||||
return grouped
|
||||
176
autogpt_platform/backend/backend/data/oauth_audit.py
Normal file
176
autogpt_platform/backend/backend/data/oauth_audit.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
OAuth Audit Logging.
|
||||
|
||||
Logs all OAuth-related operations for security auditing and compliance.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from backend.data.db import prisma
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthEventType(str, Enum):
|
||||
"""Types of OAuth events to audit."""
|
||||
|
||||
# Client events
|
||||
CLIENT_REGISTERED = "client.registered"
|
||||
CLIENT_UPDATED = "client.updated"
|
||||
CLIENT_DELETED = "client.deleted"
|
||||
CLIENT_SECRET_ROTATED = "client.secret_rotated"
|
||||
CLIENT_SUSPENDED = "client.suspended"
|
||||
CLIENT_ACTIVATED = "client.activated"
|
||||
|
||||
# Authorization events
|
||||
AUTHORIZATION_REQUESTED = "authorization.requested"
|
||||
AUTHORIZATION_GRANTED = "authorization.granted"
|
||||
AUTHORIZATION_DENIED = "authorization.denied"
|
||||
AUTHORIZATION_REVOKED = "authorization.revoked"
|
||||
|
||||
# Token events
|
||||
TOKEN_ISSUED = "token.issued"
|
||||
TOKEN_REFRESHED = "token.refreshed"
|
||||
TOKEN_REVOKED = "token.revoked"
|
||||
TOKEN_EXPIRED = "token.expired"
|
||||
|
||||
# Grant events
|
||||
GRANT_CREATED = "grant.created"
|
||||
GRANT_UPDATED = "grant.updated"
|
||||
GRANT_REVOKED = "grant.revoked"
|
||||
GRANT_USED = "grant.used"
|
||||
|
||||
# Credential events
|
||||
CREDENTIAL_CONNECTED = "credential.connected"
|
||||
CREDENTIAL_DELETED = "credential.deleted"
|
||||
|
||||
# Execution events
|
||||
EXECUTION_STARTED = "execution.started"
|
||||
EXECUTION_COMPLETED = "execution.completed"
|
||||
EXECUTION_FAILED = "execution.failed"
|
||||
EXECUTION_CANCELLED = "execution.cancelled"
|
||||
|
||||
|
||||
async def log_oauth_event(
|
||||
event_type: OAuthEventType,
|
||||
user_id: Optional[str] = None,
|
||||
client_id: Optional[str] = None,
|
||||
grant_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
details: Optional[dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Log an OAuth audit event.
|
||||
|
||||
Args:
|
||||
event_type: Type of event
|
||||
user_id: User ID involved (if any)
|
||||
client_id: OAuth client ID involved (if any)
|
||||
grant_id: Grant ID involved (if any)
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
details: Additional event details
|
||||
|
||||
Returns:
|
||||
ID of the created audit log entry
|
||||
"""
|
||||
try:
|
||||
from prisma import Json
|
||||
|
||||
audit_entry = await prisma.oauthauditlog.create(
|
||||
data={ # type: ignore[typeddict-item]
|
||||
"eventType": event_type.value,
|
||||
"userId": user_id,
|
||||
"clientId": client_id,
|
||||
"grantId": grant_id,
|
||||
"ipAddress": ip_address,
|
||||
"userAgent": user_agent,
|
||||
"details": Json(details or {}),
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"OAuth audit: {event_type.value} - "
|
||||
f"user={user_id}, client={client_id}, grant={grant_id}"
|
||||
)
|
||||
|
||||
return audit_entry.id
|
||||
|
||||
except Exception as e:
|
||||
# Log but don't fail the operation if audit logging fails
|
||||
logger.error(f"Failed to create OAuth audit log: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
async def get_audit_logs(
|
||||
user_id: Optional[str] = None,
|
||||
client_id: Optional[str] = None,
|
||||
event_type: Optional[OAuthEventType] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list:
|
||||
"""
|
||||
Query OAuth audit logs.
|
||||
|
||||
Args:
|
||||
user_id: Filter by user ID
|
||||
client_id: Filter by client ID
|
||||
event_type: Filter by event type
|
||||
start_date: Filter by start date
|
||||
end_date: Filter by end date
|
||||
limit: Maximum number of results
|
||||
offset: Offset for pagination
|
||||
|
||||
Returns:
|
||||
List of audit log entries
|
||||
"""
|
||||
where: dict[str, Any] = {}
|
||||
|
||||
if user_id:
|
||||
where["userId"] = user_id
|
||||
if client_id:
|
||||
where["clientId"] = client_id
|
||||
if event_type:
|
||||
where["eventType"] = event_type.value
|
||||
if start_date:
|
||||
where["createdAt"] = {"gte": start_date}
|
||||
if end_date:
|
||||
if "createdAt" in where:
|
||||
where["createdAt"]["lte"] = end_date
|
||||
else:
|
||||
where["createdAt"] = {"lte": end_date}
|
||||
|
||||
return await prisma.oauthauditlog.find_many(
|
||||
where=where if where else None, # type: ignore[arg-type]
|
||||
order={"createdAt": "desc"},
|
||||
take=limit,
|
||||
skip=offset,
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_old_audit_logs(days_to_keep: int = 90) -> int:
|
||||
"""
|
||||
Delete audit logs older than the specified number of days.
|
||||
|
||||
Args:
|
||||
days_to_keep: Number of days of logs to retain
|
||||
|
||||
Returns:
|
||||
Number of logs deleted
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
||||
|
||||
result = await prisma.oauthauditlog.delete_many(
|
||||
where={"createdAt": {"lt": cutoff_date}}
|
||||
)
|
||||
|
||||
logger.info(f"Cleaned up {result} OAuth audit logs older than {days_to_keep} days")
|
||||
return result
|
||||
@@ -7,7 +7,11 @@ import prisma
|
||||
import pydantic
|
||||
from prisma.enums import OnboardingStep
|
||||
from prisma.models import UserOnboarding
|
||||
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
|
||||
from prisma.types import (
|
||||
UserOnboardingCreateInput,
|
||||
UserOnboardingUpdateInput,
|
||||
UserOnboardingUpsertInput,
|
||||
)
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.credit import get_user_credit_model
|
||||
@@ -112,10 +116,10 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
||||
|
||||
return await UserOnboarding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, **update},
|
||||
"update": update,
|
||||
},
|
||||
data=UserOnboardingUpsertInput(
|
||||
create=UserOnboardingCreateInput(userId=user_id, **update),
|
||||
update=update,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -67,6 +67,7 @@ from backend.executor.utils import (
|
||||
validate_exec,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhook_notifier import get_webhook_notifier
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.server.v2.AutoMod.manager import automod_manager
|
||||
from backend.util import json
|
||||
@@ -221,11 +222,31 @@ async def execute_node(
|
||||
creds_locks: list[AsyncRedisLock] = []
|
||||
input_model = cast(type[BlockSchema], node_block.input_schema)
|
||||
|
||||
# Check if this is an external API execution using grant-based credential resolution
|
||||
grant_resolver = None
|
||||
if execution_context and execution_context.grant_resolver_context:
|
||||
from backend.integrations.grant_resolver import GrantBasedCredentialResolver
|
||||
|
||||
grant_ctx = execution_context.grant_resolver_context
|
||||
grant_resolver = GrantBasedCredentialResolver(
|
||||
user_id=user_id,
|
||||
client_id=grant_ctx.client_db_id,
|
||||
grant_ids=grant_ctx.grant_ids,
|
||||
)
|
||||
await grant_resolver.initialize()
|
||||
|
||||
# Handle regular credentials fields
|
||||
for field_name, input_type in input_model.get_credentials_fields().items():
|
||||
credentials_meta = input_type(**input_data[field_name])
|
||||
credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id)
|
||||
creds_locks.append(lock)
|
||||
if grant_resolver:
|
||||
# External API execution - use grant resolver (no locking needed)
|
||||
credentials = await grant_resolver.resolve_credential(credentials_meta.id)
|
||||
else:
|
||||
# Normal execution - use credentials manager with locking
|
||||
credentials, lock = await creds_manager.acquire(
|
||||
user_id, credentials_meta.id
|
||||
)
|
||||
creds_locks.append(lock)
|
||||
extra_exec_kwargs[field_name] = credentials
|
||||
|
||||
# Handle auto-generated credentials (e.g., from GoogleDriveFileInput)
|
||||
@@ -243,10 +264,17 @@ async def execute_node(
|
||||
)
|
||||
file_name = field_data.get("name", "selected file")
|
||||
try:
|
||||
credentials, lock = await creds_manager.acquire(
|
||||
user_id, cred_id
|
||||
)
|
||||
creds_locks.append(lock)
|
||||
if grant_resolver:
|
||||
# External API execution - use grant resolver
|
||||
credentials = await grant_resolver.resolve_credential(
|
||||
cred_id
|
||||
)
|
||||
else:
|
||||
# Normal execution - use credentials manager
|
||||
credentials, lock = await creds_manager.acquire(
|
||||
user_id, cred_id
|
||||
)
|
||||
creds_locks.append(lock)
|
||||
extra_exec_kwargs[kwarg_name] = credentials
|
||||
except ValueError:
|
||||
# Credential was deleted or doesn't exist
|
||||
@@ -785,6 +813,7 @@ class ExecutionProcessor:
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
status=exec_meta.status,
|
||||
stats=exec_stats,
|
||||
event_loop=self.node_execution_loop,
|
||||
)
|
||||
|
||||
def _charge_usage(
|
||||
@@ -1916,6 +1945,53 @@ def update_node_execution_status(
|
||||
return exec_update
|
||||
|
||||
|
||||
async def _notify_execution_webhook(
|
||||
execution_id: str,
|
||||
agent_id: str,
|
||||
status: ExecutionStatus,
|
||||
outputs: dict[str, Any] | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send webhook notification for execution completion if registered.
|
||||
|
||||
This is a fire-and-forget operation that checks if a webhook was registered
|
||||
for this execution and sends the appropriate notification.
|
||||
"""
|
||||
from backend.data.db import prisma
|
||||
|
||||
try:
|
||||
webhook = await prisma.executionwebhook.find_first(
|
||||
where={"executionId": execution_id}
|
||||
)
|
||||
if not webhook:
|
||||
return
|
||||
|
||||
notifier = get_webhook_notifier()
|
||||
|
||||
if status == ExecutionStatus.COMPLETED:
|
||||
await notifier.notify_execution_completed(
|
||||
execution_id=execution_id,
|
||||
agent_id=agent_id,
|
||||
client_id=webhook.clientId,
|
||||
webhook_url=webhook.webhookUrl,
|
||||
outputs=outputs or {},
|
||||
webhook_secret=webhook.secret,
|
||||
)
|
||||
elif status == ExecutionStatus.FAILED:
|
||||
await notifier.notify_execution_failed(
|
||||
execution_id=execution_id,
|
||||
agent_id=agent_id,
|
||||
client_id=webhook.clientId,
|
||||
webhook_url=webhook.webhookUrl,
|
||||
error=error or "Execution failed",
|
||||
webhook_secret=webhook.secret,
|
||||
)
|
||||
except Exception as e:
|
||||
# Don't let webhook failures affect execution state updates
|
||||
logger.warning(f"Failed to send webhook notification for {execution_id}: {e}")
|
||||
|
||||
|
||||
async def async_update_graph_execution_state(
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
graph_exec_id: str,
|
||||
@@ -1928,6 +2004,17 @@ async def async_update_graph_execution_state(
|
||||
)
|
||||
if graph_update:
|
||||
await send_async_execution_update(graph_update)
|
||||
|
||||
# Send webhook notification for terminal states
|
||||
if status == ExecutionStatus.COMPLETED or status == ExecutionStatus.FAILED:
|
||||
await _notify_execution_webhook(
|
||||
execution_id=graph_exec_id,
|
||||
agent_id=graph_update.graph_id,
|
||||
status=status,
|
||||
outputs=(
|
||||
graph_update.outputs if hasattr(graph_update, "outputs") else None
|
||||
),
|
||||
)
|
||||
else:
|
||||
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
|
||||
return graph_update
|
||||
@@ -1938,11 +2025,33 @@ def update_graph_execution_state(
|
||||
graph_exec_id: str,
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
event_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> GraphExecution | None:
|
||||
"""Sets status and fetches+broadcasts the latest state of the graph execution"""
|
||||
graph_update = db_client.update_graph_execution_stats(graph_exec_id, status, stats)
|
||||
if graph_update:
|
||||
send_execution_update(graph_update)
|
||||
|
||||
# Send webhook notification for terminal states (fire-and-forget)
|
||||
if (
|
||||
status == ExecutionStatus.COMPLETED or status == ExecutionStatus.FAILED
|
||||
) and event_loop:
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
_notify_execution_webhook(
|
||||
execution_id=graph_exec_id,
|
||||
agent_id=graph_update.graph_id,
|
||||
status=status,
|
||||
outputs=(
|
||||
graph_update.outputs
|
||||
if hasattr(graph_update, "outputs")
|
||||
else None
|
||||
),
|
||||
),
|
||||
event_loop,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to schedule webhook notification: {e}")
|
||||
else:
|
||||
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
|
||||
return graph_update
|
||||
|
||||
278
autogpt_platform/backend/backend/integrations/grant_resolver.py
Normal file
278
autogpt_platform/backend/backend/integrations/grant_resolver.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
Grant-Based Credential Resolver.
|
||||
|
||||
Resolves credentials during agent execution based on credential grants.
|
||||
External applications can only use credentials they have been granted access to,
|
||||
and only for the scopes that were granted.
|
||||
|
||||
Credentials are NEVER exposed to external applications - this resolver
|
||||
provides the credentials to the execution engine internally.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from prisma.enums import CredentialGrantPermission
|
||||
from prisma.models import CredentialGrant
|
||||
|
||||
from backend.data import credential_grants as grants_db
|
||||
from backend.data.db import prisma
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GrantValidationError(Exception):
|
||||
"""Raised when a grant is invalid or lacks required permissions."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CredentialNotFoundError(Exception):
|
||||
"""Raised when a credential referenced by a grant is not found."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ScopeMismatchError(Exception):
|
||||
"""Raised when the grant doesn't cover required scopes."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GrantBasedCredentialResolver:
|
||||
"""
|
||||
Resolves credentials for agent execution based on credential grants.
|
||||
|
||||
This resolver validates that:
|
||||
1. The grant exists and is valid (not revoked/expired)
|
||||
2. The grant has USE permission
|
||||
3. The grant covers the required scopes (if specified)
|
||||
4. The underlying credential exists
|
||||
|
||||
Then it provides the credential to the execution engine internally.
|
||||
The credential value is NEVER exposed to external applications.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
grant_ids: list[str],
|
||||
):
|
||||
"""
|
||||
Initialize the resolver.
|
||||
|
||||
Args:
|
||||
user_id: User ID who owns the credentials
|
||||
client_id: Database ID of the OAuth client
|
||||
grant_ids: List of grant IDs the client is using for this execution
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.client_id = client_id
|
||||
self.grant_ids = grant_ids
|
||||
self._grants: dict[str, CredentialGrant] = {}
|
||||
self._credentials_manager = IntegrationCredentialsManager()
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""
|
||||
Load and validate all grants.
|
||||
|
||||
This should be called before any credential resolution.
|
||||
|
||||
Raises:
|
||||
GrantValidationError: If any grant is invalid
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
for grant_id in self.grant_ids:
|
||||
grant = await grants_db.get_credential_grant(
|
||||
grant_id=grant_id,
|
||||
user_id=self.user_id,
|
||||
client_id=self.client_id,
|
||||
)
|
||||
|
||||
if not grant:
|
||||
raise GrantValidationError(f"Grant {grant_id} not found")
|
||||
|
||||
# Check if revoked
|
||||
if grant.revokedAt:
|
||||
raise GrantValidationError(f"Grant {grant_id} has been revoked")
|
||||
|
||||
# Check if expired
|
||||
if grant.expiresAt and grant.expiresAt < now:
|
||||
raise GrantValidationError(f"Grant {grant_id} has expired")
|
||||
|
||||
# Check USE permission
|
||||
if CredentialGrantPermission.USE not in grant.permissions:
|
||||
raise GrantValidationError(
|
||||
f"Grant {grant_id} does not have USE permission"
|
||||
)
|
||||
|
||||
self._grants[grant_id] = grant
|
||||
|
||||
self._initialized = True
|
||||
logger.info(
|
||||
f"Initialized grant resolver with {len(self._grants)} grants "
|
||||
f"for user {self.user_id}, client {self.client_id}"
|
||||
)
|
||||
|
||||
async def resolve_credential(
|
||||
self,
|
||||
credential_id: str,
|
||||
required_scopes: Optional[list[str]] = None,
|
||||
) -> Credentials:
|
||||
"""
|
||||
Resolve a credential for agent execution.
|
||||
|
||||
This method:
|
||||
1. Finds a grant that covers this credential
|
||||
2. Validates the grant covers required scopes
|
||||
3. Retrieves the actual credential
|
||||
4. Updates grant usage tracking
|
||||
|
||||
Args:
|
||||
credential_id: ID of the credential to resolve
|
||||
required_scopes: Optional list of scopes the credential must have
|
||||
|
||||
Returns:
|
||||
The resolved Credentials object
|
||||
|
||||
Raises:
|
||||
GrantValidationError: If no valid grant covers this credential
|
||||
ScopeMismatchError: If the grant doesn't cover required scopes
|
||||
CredentialNotFoundError: If the underlying credential doesn't exist
|
||||
"""
|
||||
if not self._initialized:
|
||||
raise RuntimeError("Resolver not initialized. Call initialize() first.")
|
||||
|
||||
# Find a grant that covers this credential
|
||||
matching_grant: Optional[CredentialGrant] = None
|
||||
for grant in self._grants.values():
|
||||
if grant.credentialId == credential_id:
|
||||
matching_grant = grant
|
||||
break
|
||||
|
||||
if not matching_grant:
|
||||
raise GrantValidationError(f"No grant found for credential {credential_id}")
|
||||
|
||||
# Validate scopes if required
|
||||
if required_scopes:
|
||||
granted_scopes = set(matching_grant.grantedScopes)
|
||||
required_scopes_set = set(required_scopes)
|
||||
|
||||
missing_scopes = required_scopes_set - granted_scopes
|
||||
if missing_scopes:
|
||||
raise ScopeMismatchError(
|
||||
f"Grant {matching_grant.id} is missing required scopes: "
|
||||
f"{', '.join(missing_scopes)}"
|
||||
)
|
||||
|
||||
# Get the actual credential
|
||||
credentials = await self._credentials_manager.get(
|
||||
user_id=self.user_id,
|
||||
credentials_id=credential_id,
|
||||
lock=True,
|
||||
)
|
||||
|
||||
if not credentials:
|
||||
raise CredentialNotFoundError(
|
||||
f"Credential {credential_id} not found for user {self.user_id}"
|
||||
)
|
||||
|
||||
# Update last used timestamp for the grant
|
||||
await grants_db.update_grant_last_used(matching_grant.id)
|
||||
|
||||
logger.debug(
|
||||
f"Resolved credential {credential_id} via grant {matching_grant.id} "
|
||||
f"for client {self.client_id}"
|
||||
)
|
||||
|
||||
return credentials
|
||||
|
||||
async def get_available_credentials(self) -> list[dict]:
|
||||
"""
|
||||
Get list of available credentials based on grants.
|
||||
|
||||
Returns a list of credential metadata (NOT the actual credential values).
|
||||
|
||||
Returns:
|
||||
List of dicts with credential metadata
|
||||
"""
|
||||
if not self._initialized:
|
||||
raise RuntimeError("Resolver not initialized. Call initialize() first.")
|
||||
|
||||
credentials_info = []
|
||||
for grant in self._grants.values():
|
||||
credentials_info.append(
|
||||
{
|
||||
"grant_id": grant.id,
|
||||
"credential_id": grant.credentialId,
|
||||
"provider": grant.provider,
|
||||
"granted_scopes": grant.grantedScopes,
|
||||
}
|
||||
)
|
||||
|
||||
return credentials_info
|
||||
|
||||
def get_grant_for_credential(self, credential_id: str) -> Optional[CredentialGrant]:
|
||||
"""
|
||||
Get the grant for a specific credential.
|
||||
|
||||
Args:
|
||||
credential_id: ID of the credential
|
||||
|
||||
Returns:
|
||||
CredentialGrant or None if not found
|
||||
"""
|
||||
for grant in self._grants.values():
|
||||
if grant.credentialId == credential_id:
|
||||
return grant
|
||||
return None
|
||||
|
||||
|
||||
async def create_resolver_from_oauth_token(
|
||||
user_id: str,
|
||||
client_public_id: str,
|
||||
grant_ids: Optional[list[str]] = None,
|
||||
) -> GrantBasedCredentialResolver:
|
||||
"""
|
||||
Create a credential resolver from OAuth token context.
|
||||
|
||||
This is a convenience function for creating a resolver from
|
||||
the context available in OAuth-authenticated requests.
|
||||
|
||||
Args:
|
||||
user_id: User ID from the OAuth token
|
||||
client_public_id: Public client ID from the OAuth token
|
||||
grant_ids: Optional list of grant IDs to use
|
||||
|
||||
Returns:
|
||||
Initialized GrantBasedCredentialResolver
|
||||
"""
|
||||
# Look up the OAuth client database ID from the public client ID
|
||||
client = await prisma.oauthclient.find_unique(where={"clientId": client_public_id})
|
||||
if not client:
|
||||
raise GrantValidationError(f"OAuth client {client_public_id} not found")
|
||||
|
||||
# If no grant IDs specified, get all grants for this client+user
|
||||
if grant_ids is None:
|
||||
grants = await grants_db.get_grants_for_user_client(
|
||||
user_id=user_id,
|
||||
client_id=client.id,
|
||||
include_revoked=False,
|
||||
include_expired=False,
|
||||
)
|
||||
grant_ids = [g.id for g in grants]
|
||||
|
||||
resolver = GrantBasedCredentialResolver(
|
||||
user_id=user_id,
|
||||
client_id=client.id,
|
||||
grant_ids=grant_ids,
|
||||
)
|
||||
await resolver.initialize()
|
||||
|
||||
return resolver
|
||||
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
Webhook Notification System for External API.
|
||||
|
||||
Sends webhook notifications to external applications for execution events.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import weakref
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Coroutine, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Webhook delivery settings
|
||||
WEBHOOK_TIMEOUT_SECONDS = 30
|
||||
WEBHOOK_MAX_RETRIES = 3
|
||||
WEBHOOK_RETRY_DELAYS = [5, 30, 300] # seconds: 5s, 30s, 5min
|
||||
|
||||
|
||||
class WebhookDeliveryError(Exception):
|
||||
"""Raised when webhook delivery fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def sign_webhook_payload(payload: dict[str, Any], secret: str) -> str:
|
||||
"""
|
||||
Create HMAC-SHA256 signature for webhook payload.
|
||||
|
||||
Args:
|
||||
payload: The webhook payload to sign
|
||||
secret: The webhook secret key
|
||||
|
||||
Returns:
|
||||
Hex-encoded HMAC-SHA256 signature
|
||||
"""
|
||||
payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode()
|
||||
signature = hmac.new(
|
||||
secret.encode(),
|
||||
payload_bytes,
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
return signature
|
||||
|
||||
|
||||
def verify_webhook_signature(
|
||||
payload: dict[str, Any],
|
||||
signature: str,
|
||||
secret: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Verify a webhook signature.
|
||||
|
||||
Args:
|
||||
payload: The webhook payload
|
||||
signature: The signature to verify
|
||||
secret: The webhook secret key
|
||||
|
||||
Returns:
|
||||
True if signature is valid
|
||||
"""
|
||||
expected = sign_webhook_payload(payload, secret)
|
||||
return hmac.compare_digest(expected, signature)
|
||||
|
||||
|
||||
def validate_webhook_url(url: str, allowed_domains: list[str]) -> bool:
|
||||
"""
|
||||
Validate that a webhook URL is allowed.
|
||||
|
||||
Args:
|
||||
url: The webhook URL to validate
|
||||
allowed_domains: List of allowed domains (from OAuth client config)
|
||||
|
||||
Returns:
|
||||
True if URL is valid and allowed
|
||||
"""
|
||||
from backend.util.url import hostname_matches_any_domain
|
||||
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
|
||||
# Must be HTTPS (except for localhost in development)
|
||||
if parsed.scheme != "https":
|
||||
if not (
|
||||
parsed.scheme == "http"
|
||||
and parsed.hostname in ["localhost", "127.0.0.1"]
|
||||
):
|
||||
return False
|
||||
|
||||
# Must have a host
|
||||
if not parsed.hostname:
|
||||
return False
|
||||
|
||||
# Check against allowed domains
|
||||
return hostname_matches_any_domain(parsed.hostname, allowed_domains)
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
async def send_webhook(
|
||||
url: str,
|
||||
payload: dict[str, Any],
|
||||
secret: Optional[str] = None,
|
||||
timeout: int = WEBHOOK_TIMEOUT_SECONDS,
|
||||
) -> bool:
|
||||
"""
|
||||
Send a webhook notification.
|
||||
|
||||
Args:
|
||||
url: Webhook URL
|
||||
payload: Payload to send
|
||||
secret: Optional secret for signature
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
True if webhook was delivered successfully
|
||||
"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "AutoGPT-Webhook/1.0",
|
||||
"X-Webhook-Timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
if secret:
|
||||
signature = sign_webhook_payload(payload, secret)
|
||||
headers["X-Webhook-Signature"] = f"sha256={signature}"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if response.status_code >= 200 and response.status_code < 300:
|
||||
logger.debug(f"Webhook delivered successfully to {url}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
f"Webhook delivery failed: {url} returned {response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.warning(f"Webhook delivery timed out: {url}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Webhook delivery error: {url} - {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def send_webhook_with_retry(
|
||||
url: str,
|
||||
payload: dict[str, Any],
|
||||
secret: Optional[str] = None,
|
||||
max_retries: int = WEBHOOK_MAX_RETRIES,
|
||||
) -> bool:
|
||||
"""
|
||||
Send a webhook with automatic retries.
|
||||
|
||||
Args:
|
||||
url: Webhook URL
|
||||
payload: Payload to send
|
||||
secret: Optional secret for signature
|
||||
max_retries: Maximum number of retry attempts
|
||||
|
||||
Returns:
|
||||
True if webhook was eventually delivered successfully
|
||||
"""
|
||||
for attempt in range(max_retries + 1):
|
||||
if await send_webhook(url, payload, secret):
|
||||
return True
|
||||
|
||||
if attempt < max_retries:
|
||||
delay = WEBHOOK_RETRY_DELAYS[min(attempt, len(WEBHOOK_RETRY_DELAYS) - 1)]
|
||||
logger.info(
|
||||
f"Webhook delivery failed, retrying in {delay}s (attempt {attempt + 1})"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
logger.error(f"Webhook delivery failed after {max_retries} retries: {url}")
|
||||
return False
|
||||
|
||||
|
||||
# Track pending webhook tasks to prevent garbage collection
|
||||
# Using WeakSet so tasks are automatically removed when they complete and are dereferenced
|
||||
_pending_webhook_tasks: weakref.WeakSet[asyncio.Task[Any]] = weakref.WeakSet()
|
||||
|
||||
|
||||
def _create_tracked_task(coro: Coroutine[Any, Any, bool]) -> asyncio.Task[bool]:
|
||||
"""Create a task that is tracked to prevent garbage collection."""
|
||||
task = asyncio.create_task(coro)
|
||||
_pending_webhook_tasks.add(task)
|
||||
# No explicit done callback needed - WeakSet automatically removes
|
||||
# references when tasks are garbage collected after completion
|
||||
return task
|
||||
|
||||
|
||||
class WebhookNotifier:
|
||||
"""
|
||||
Service for sending webhook notifications to external applications.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def notify_execution_started(
|
||||
self,
|
||||
execution_id: str,
|
||||
agent_id: str,
|
||||
client_id: str,
|
||||
webhook_url: str,
|
||||
webhook_secret: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Notify external app that an execution has started.
|
||||
"""
|
||||
payload = {
|
||||
"event": "execution.started",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"data": {
|
||||
"execution_id": execution_id,
|
||||
"agent_id": agent_id,
|
||||
"status": "running",
|
||||
},
|
||||
}
|
||||
|
||||
_create_tracked_task(
|
||||
send_webhook_with_retry(webhook_url, payload, webhook_secret)
|
||||
)
|
||||
|
||||
async def notify_execution_completed(
|
||||
self,
|
||||
execution_id: str,
|
||||
agent_id: str,
|
||||
client_id: str,
|
||||
webhook_url: str,
|
||||
outputs: dict[str, Any],
|
||||
webhook_secret: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Notify external app that an execution has completed successfully.
|
||||
"""
|
||||
payload = {
|
||||
"event": "execution.completed",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"data": {
|
||||
"execution_id": execution_id,
|
||||
"agent_id": agent_id,
|
||||
"status": "completed",
|
||||
"outputs": outputs,
|
||||
},
|
||||
}
|
||||
|
||||
_create_tracked_task(
|
||||
send_webhook_with_retry(webhook_url, payload, webhook_secret)
|
||||
)
|
||||
|
||||
async def notify_execution_failed(
|
||||
self,
|
||||
execution_id: str,
|
||||
agent_id: str,
|
||||
client_id: str,
|
||||
webhook_url: str,
|
||||
error: str,
|
||||
webhook_secret: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Notify external app that an execution has failed.
|
||||
"""
|
||||
payload = {
|
||||
"event": "execution.failed",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"data": {
|
||||
"execution_id": execution_id,
|
||||
"agent_id": agent_id,
|
||||
"status": "failed",
|
||||
"error": error,
|
||||
},
|
||||
}
|
||||
|
||||
_create_tracked_task(
|
||||
send_webhook_with_retry(webhook_url, payload, webhook_secret)
|
||||
)
|
||||
|
||||
async def notify_grant_revoked(
|
||||
self,
|
||||
grant_id: str,
|
||||
credential_id: str,
|
||||
provider: str,
|
||||
client_id: str,
|
||||
webhook_url: str,
|
||||
webhook_secret: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Notify external app that a credential grant has been revoked.
|
||||
"""
|
||||
payload = {
|
||||
"event": "grant.revoked",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"data": {
|
||||
"grant_id": grant_id,
|
||||
"credential_id": credential_id,
|
||||
"provider": provider,
|
||||
},
|
||||
}
|
||||
|
||||
_create_tracked_task(
|
||||
send_webhook_with_retry(webhook_url, payload, webhook_secret)
|
||||
)
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
_webhook_notifier: Optional[WebhookNotifier] = None
|
||||
|
||||
|
||||
def get_webhook_notifier() -> WebhookNotifier:
|
||||
"""Get the singleton webhook notifier instance."""
|
||||
global _webhook_notifier
|
||||
if _webhook_notifier is None:
|
||||
_webhook_notifier = WebhookNotifier()
|
||||
return _webhook_notifier
|
||||
@@ -3,21 +3,19 @@ from fastapi import FastAPI
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
|
||||
from .routes.integrations import integrations_router
|
||||
from .routes.tools import tools_router
|
||||
from .routes.v1 import v1_router
|
||||
from .routes.execution import execution_router
|
||||
from .routes.grants import grants_router
|
||||
|
||||
external_app = FastAPI(
|
||||
title="AutoGPT External API",
|
||||
description="External API for AutoGPT integrations",
|
||||
description="External API for AutoGPT integrations (OAuth-based)",
|
||||
docs_url="/docs",
|
||||
version="1.0",
|
||||
)
|
||||
|
||||
external_app.add_middleware(SecurityHeadersMiddleware)
|
||||
external_app.include_router(v1_router, prefix="/v1")
|
||||
external_app.include_router(tools_router, prefix="/v1")
|
||||
external_app.include_router(integrations_router, prefix="/v1")
|
||||
external_app.include_router(grants_router, prefix="/v1")
|
||||
external_app.include_router(execution_router, prefix="/v1")
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi.security import APIKeyHeader
|
||||
from prisma.enums import APIKeyPermission
|
||||
|
||||
from backend.data.api_key import APIKeyInfo, has_permission, validate_api_key
|
||||
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
|
||||
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
|
||||
"""Base middleware for API key authentication"""
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=401, detail="Missing API key")
|
||||
|
||||
api_key_obj = await validate_api_key(api_key)
|
||||
|
||||
if not api_key_obj:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
return api_key_obj
|
||||
|
||||
|
||||
def require_permission(permission: APIKeyPermission):
|
||||
"""Dependency function for checking specific permissions"""
|
||||
|
||||
async def check_permission(
|
||||
api_key: APIKeyInfo = Security(require_api_key),
|
||||
) -> APIKeyInfo:
|
||||
if not has_permission(api_key, permission):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"API key lacks the required permission '{permission}'",
|
||||
)
|
||||
return api_key
|
||||
|
||||
return check_permission
|
||||
164
autogpt_platform/backend/backend/server/external/oauth_middleware.py
vendored
Normal file
164
autogpt_platform/backend/backend/server/external/oauth_middleware.py
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
OAuth Access Token middleware for external API.
|
||||
|
||||
Validates OAuth access tokens and provides user/client context
|
||||
for external API endpoints that use OAuth authentication.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import jwt
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.db import prisma
|
||||
from backend.server.oauth.token_service import get_token_service
|
||||
|
||||
|
||||
class OAuthTokenInfo(BaseModel):
|
||||
"""Information extracted from a validated OAuth access token."""
|
||||
|
||||
user_id: str
|
||||
client_id: str
|
||||
scopes: list[str]
|
||||
token_id: str
|
||||
|
||||
|
||||
# HTTP Bearer token extractor
|
||||
oauth_bearer = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def require_oauth_token(
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Security(oauth_bearer),
|
||||
) -> OAuthTokenInfo:
|
||||
"""
|
||||
Validate an OAuth access token and return token info.
|
||||
|
||||
Extracts the Bearer token from the Authorization header,
|
||||
validates the JWT signature and claims, and checks that
|
||||
the token hasn't been revoked.
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if token is missing, invalid, or revoked
|
||||
"""
|
||||
if credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing authorization token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
token = credentials.credentials
|
||||
token_service = get_token_service()
|
||||
|
||||
try:
|
||||
# Verify JWT signature and claims
|
||||
claims = token_service.verify_access_token(token)
|
||||
|
||||
# Check if token is in database and not revoked
|
||||
token_hash = token_service.hash_token(token)
|
||||
stored_token = await prisma.oauthaccesstoken.find_unique(
|
||||
where={"tokenHash": token_hash}
|
||||
)
|
||||
|
||||
if not stored_token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Token not found",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if stored_token.revokedAt:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Token has been revoked",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if stored_token.expiresAt < datetime.now(timezone.utc):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Token has expired",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Update last used timestamp (fire and forget)
|
||||
await prisma.oauthaccesstoken.update(
|
||||
where={"id": stored_token.id},
|
||||
data={"lastUsedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
return OAuthTokenInfo(
|
||||
user_id=claims.sub,
|
||||
client_id=claims.client_id,
|
||||
scopes=claims.scope.split() if claims.scope else [],
|
||||
token_id=stored_token.id,
|
||||
)
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Token has expired",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Invalid token: {str(e)}",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def require_scope(required_scope: str):
|
||||
"""
|
||||
Dependency that validates OAuth token and checks for required scope.
|
||||
|
||||
Args:
|
||||
required_scope: The scope required for this endpoint
|
||||
|
||||
Returns:
|
||||
Dependency function that returns OAuthTokenInfo if authorized
|
||||
"""
|
||||
|
||||
async def check_scope(
|
||||
token: OAuthTokenInfo = Security(require_oauth_token),
|
||||
) -> OAuthTokenInfo:
|
||||
if required_scope not in token.scopes:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Token lacks required scope '{required_scope}'",
|
||||
headers={"WWW-Authenticate": f'Bearer scope="{required_scope}"'},
|
||||
)
|
||||
return token
|
||||
|
||||
return check_scope
|
||||
|
||||
|
||||
def require_any_scope(*required_scopes: str):
|
||||
"""
|
||||
Dependency that validates OAuth token and checks for any of the required scopes.
|
||||
|
||||
Args:
|
||||
required_scopes: At least one of these scopes is required
|
||||
|
||||
Returns:
|
||||
Dependency function that returns OAuthTokenInfo if authorized
|
||||
"""
|
||||
|
||||
async def check_scopes(
|
||||
token: OAuthTokenInfo = Security(require_oauth_token),
|
||||
) -> OAuthTokenInfo:
|
||||
for scope in required_scopes:
|
||||
if scope in token.scopes:
|
||||
return token
|
||||
|
||||
scope_list = " ".join(required_scopes)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Token lacks required scopes (need one of: {scope_list})",
|
||||
headers={"WWW-Authenticate": f'Bearer scope="{scope_list}"'},
|
||||
)
|
||||
|
||||
return check_scopes
|
||||
377
autogpt_platform/backend/backend/server/external/routes/execution.py
vendored
Normal file
377
autogpt_platform/backend/backend/server/external/routes/execution.py
vendored
Normal file
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
Agent Execution endpoints for external OAuth clients.
|
||||
|
||||
Allows external applications to:
|
||||
- Execute agents using granted credentials
|
||||
- Poll execution status
|
||||
- Cancel running executions
|
||||
- Get available capabilities
|
||||
|
||||
External apps can only use credentials they have been granted access to.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Security
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.db import prisma
|
||||
from backend.data.execution import ExecutionContext, GrantResolverContext
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.grant_resolver import (
|
||||
GrantValidationError,
|
||||
create_resolver_from_oauth_token,
|
||||
)
|
||||
from backend.integrations.webhook_notifier import validate_webhook_url
|
||||
from backend.server.external.oauth_middleware import OAuthTokenInfo, require_scope
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
execution_router = APIRouter(prefix="/executions", tags=["executions"])
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request/Response Models
|
||||
# ================================================================
|
||||
|
||||
|
||||
class ExecuteAgentRequest(BaseModel):
|
||||
"""Request to execute an agent."""
|
||||
|
||||
inputs: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Input values for the agent",
|
||||
)
|
||||
grant_ids: Optional[list[str]] = Field(
|
||||
default=None,
|
||||
description="Specific grant IDs to use. If not provided, uses all available grants.",
|
||||
)
|
||||
webhook_url: Optional[str] = Field(
|
||||
default=None,
|
||||
description="URL to receive execution status webhooks",
|
||||
)
|
||||
|
||||
|
||||
class ExecuteAgentResponse(BaseModel):
|
||||
"""Response from starting an agent execution."""
|
||||
|
||||
execution_id: str
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
class ExecutionStatusResponse(BaseModel):
|
||||
"""Response with execution status."""
|
||||
|
||||
execution_id: str
|
||||
status: str
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class GrantInfo(BaseModel):
|
||||
"""Summary of a credential grant for capabilities."""
|
||||
|
||||
grant_id: str
|
||||
provider: str
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
class CapabilitiesResponse(BaseModel):
|
||||
"""Response describing what the client can do."""
|
||||
|
||||
user_id: str
|
||||
client_id: str
|
||||
grants: list[GrantInfo]
|
||||
available_scopes: list[str]
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
|
||||
@execution_router.get("/capabilities", response_model=CapabilitiesResponse)
|
||||
async def get_capabilities(
|
||||
token: OAuthTokenInfo = Security(require_scope("agents:execute")),
|
||||
) -> CapabilitiesResponse:
|
||||
"""
|
||||
Get the capabilities available to this client for the authenticated user.
|
||||
|
||||
Returns information about:
|
||||
- Available credential grants (NOT credential values)
|
||||
- Scopes the client has access to
|
||||
"""
|
||||
try:
|
||||
resolver = await create_resolver_from_oauth_token(
|
||||
user_id=token.user_id,
|
||||
client_public_id=token.client_id,
|
||||
)
|
||||
credentials_info = await resolver.get_available_credentials()
|
||||
|
||||
grants = [
|
||||
GrantInfo(
|
||||
grant_id=info["grant_id"],
|
||||
provider=info["provider"],
|
||||
scopes=info["granted_scopes"],
|
||||
)
|
||||
for info in credentials_info
|
||||
]
|
||||
|
||||
return CapabilitiesResponse(
|
||||
user_id=token.user_id,
|
||||
client_id=token.client_id,
|
||||
grants=grants,
|
||||
available_scopes=token.scopes,
|
||||
)
|
||||
except GrantValidationError:
|
||||
# No grants available is not an error, just empty capabilities
|
||||
return CapabilitiesResponse(
|
||||
user_id=token.user_id,
|
||||
client_id=token.client_id,
|
||||
grants=[],
|
||||
available_scopes=token.scopes,
|
||||
)
|
||||
|
||||
|
||||
@execution_router.post(
|
||||
"/agents/{agent_id}/execute",
|
||||
response_model=ExecuteAgentResponse,
|
||||
)
|
||||
async def execute_agent(
|
||||
agent_id: str,
|
||||
request: ExecuteAgentRequest,
|
||||
token: OAuthTokenInfo = Security(require_scope("agents:execute")),
|
||||
) -> ExecuteAgentResponse:
|
||||
"""
|
||||
Execute an agent using granted credentials.
|
||||
|
||||
The agent must be accessible to the user, and the client must have
|
||||
valid credential grants that satisfy the agent's requirements.
|
||||
|
||||
Args:
|
||||
agent_id: The agent (graph) ID to execute
|
||||
request: Execution parameters including inputs and optional grant IDs
|
||||
"""
|
||||
# Verify the agent exists and user has access
|
||||
# First try to get the latest version
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=agent_id,
|
||||
version=None,
|
||||
user_id=token.user_id,
|
||||
)
|
||||
|
||||
if not graph:
|
||||
# Try to find it in the store (public agents)
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=agent_id,
|
||||
version=None,
|
||||
user_id=None,
|
||||
skip_access_check=True,
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Agent {agent_id} not found or not accessible",
|
||||
)
|
||||
|
||||
# Initialize the grant resolver to validate grants exist
|
||||
# The resolver context will be passed to the execution engine
|
||||
grant_resolver_context = None
|
||||
try:
|
||||
resolver = await create_resolver_from_oauth_token(
|
||||
user_id=token.user_id,
|
||||
client_public_id=token.client_id,
|
||||
grant_ids=request.grant_ids,
|
||||
)
|
||||
# Get available credentials info to build resolver context
|
||||
credentials_info = await resolver.get_available_credentials()
|
||||
grant_resolver_context = GrantResolverContext(
|
||||
client_db_id=resolver.client_id,
|
||||
grant_ids=[c["grant_id"] for c in credentials_info],
|
||||
)
|
||||
except GrantValidationError as e:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Grant validation failed: {str(e)}",
|
||||
)
|
||||
|
||||
try:
|
||||
# Build execution context with grant resolver info
|
||||
execution_context = ExecutionContext(
|
||||
grant_resolver_context=grant_resolver_context,
|
||||
)
|
||||
|
||||
# Execute the agent with grant resolver context
|
||||
graph_exec = await add_graph_execution(
|
||||
graph_id=agent_id,
|
||||
user_id=token.user_id,
|
||||
inputs=request.inputs,
|
||||
graph_version=graph.version,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
|
||||
# Log the execution for audit
|
||||
logger.info(
|
||||
f"External execution started: agent={agent_id}, "
|
||||
f"execution={graph_exec.id}, client={token.client_id}, "
|
||||
f"user={token.user_id}"
|
||||
)
|
||||
|
||||
# Register webhook if provided
|
||||
if request.webhook_url:
|
||||
# Get client to check webhook domains
|
||||
client = await prisma.oauthclient.find_unique(
|
||||
where={"clientId": token.client_id}
|
||||
)
|
||||
if client:
|
||||
if not validate_webhook_url(request.webhook_url, client.webhookDomains):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Webhook URL not in allowed domains for this client",
|
||||
)
|
||||
|
||||
# Store webhook registration with client's webhook secret
|
||||
await prisma.executionwebhook.create(
|
||||
data={ # type: ignore[typeddict-item]
|
||||
"executionId": graph_exec.id,
|
||||
"webhookUrl": request.webhook_url,
|
||||
"clientId": client.id,
|
||||
"userId": token.user_id,
|
||||
"secret": client.webhookSecret,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Registered webhook for execution {graph_exec.id}: {request.webhook_url}"
|
||||
)
|
||||
|
||||
return ExecuteAgentResponse(
|
||||
execution_id=graph_exec.id,
|
||||
status="queued",
|
||||
message="Agent execution has been queued",
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
# Client error - invalid input or configuration
|
||||
logger.warning(
|
||||
f"Invalid execution request: agent={agent_id}, "
|
||||
f"client={token.client_id}, error={str(e)}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid request: {str(e)}",
|
||||
)
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions as-is
|
||||
raise
|
||||
except Exception:
|
||||
# Server error - log full exception but don't expose details to client
|
||||
logger.exception(
|
||||
f"Unexpected error starting execution: agent={agent_id}, "
|
||||
f"client={token.client_id}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An internal error occurred while starting execution",
|
||||
)
|
||||
|
||||
|
||||
@execution_router.get(
|
||||
"/{execution_id}",
|
||||
response_model=ExecutionStatusResponse,
|
||||
)
|
||||
async def get_execution_status(
|
||||
execution_id: str,
|
||||
token: OAuthTokenInfo = Security(require_scope("agents:execute")),
|
||||
) -> ExecutionStatusResponse:
|
||||
"""
|
||||
Get the status of an agent execution.
|
||||
|
||||
Returns current status, outputs (if completed), and any error messages.
|
||||
"""
|
||||
graph_exec = await execution_db.get_graph_execution(
|
||||
user_id=token.user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Execution {execution_id} not found",
|
||||
)
|
||||
|
||||
# Build response
|
||||
outputs = None
|
||||
error = None
|
||||
|
||||
if graph_exec.status == AgentExecutionStatus.COMPLETED:
|
||||
outputs = graph_exec.outputs
|
||||
elif graph_exec.status == AgentExecutionStatus.FAILED:
|
||||
# Get error from execution stats
|
||||
# Note: Currently no standard error field in stats, but could be added
|
||||
error = "Execution failed"
|
||||
|
||||
return ExecutionStatusResponse(
|
||||
execution_id=execution_id,
|
||||
status=graph_exec.status.value,
|
||||
started_at=graph_exec.started_at,
|
||||
completed_at=graph_exec.ended_at,
|
||||
outputs=outputs,
|
||||
error=error,
|
||||
)
|
||||
|
||||
|
||||
@execution_router.post("/{execution_id}/cancel")
|
||||
async def cancel_execution(
|
||||
execution_id: str,
|
||||
token: OAuthTokenInfo = Security(require_scope("agents:execute")),
|
||||
) -> dict:
|
||||
"""
|
||||
Cancel a running agent execution.
|
||||
|
||||
Only executions in QUEUED or RUNNING status can be cancelled.
|
||||
"""
|
||||
graph_exec = await execution_db.get_graph_execution(
|
||||
user_id=token.user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Execution {execution_id} not found",
|
||||
)
|
||||
|
||||
# Check if execution can be cancelled
|
||||
if graph_exec.status not in [
|
||||
AgentExecutionStatus.QUEUED,
|
||||
AgentExecutionStatus.RUNNING,
|
||||
]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot cancel execution with status {graph_exec.status.value}",
|
||||
)
|
||||
|
||||
# Update execution status to TERMINATED
|
||||
# Note: This is a simplified implementation. A full implementation would
|
||||
# need to signal the executor to stop processing.
|
||||
await prisma.agentgraphexecution.update(
|
||||
where={"id": execution_id},
|
||||
data={"executionStatus": AgentExecutionStatus.TERMINATED},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Execution terminated: execution={execution_id}, "
|
||||
f"client={token.client_id}, user={token.user_id}"
|
||||
)
|
||||
|
||||
return {"message": "Execution terminated", "execution_id": execution_id}
|
||||
207
autogpt_platform/backend/backend/server/external/routes/grants.py
vendored
Normal file
207
autogpt_platform/backend/backend/server/external/routes/grants.py
vendored
Normal file
@@ -0,0 +1,207 @@
|
||||
"""
|
||||
Credential Grants endpoints for external OAuth clients.
|
||||
|
||||
Allows external applications to:
|
||||
- List their credential grants (metadata only, NOT credential values)
|
||||
- Get grant details
|
||||
- Delete credentials via grants (if permitted)
|
||||
|
||||
Credentials are NEVER returned to external applications.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Security
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data import credential_grants as grants_db
|
||||
from backend.data.db import prisma
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.server.external.oauth_middleware import OAuthTokenInfo, require_scope
|
||||
|
||||
grants_router = APIRouter(prefix="/grants", tags=["grants"])
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Response Models
|
||||
# ================================================================
|
||||
|
||||
|
||||
class GrantSummary(BaseModel):
|
||||
"""Summary of a credential grant (returned in list endpoints)."""
|
||||
|
||||
id: str
|
||||
provider: str
|
||||
granted_scopes: list[str]
|
||||
permissions: list[str]
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class GrantDetail(BaseModel):
|
||||
"""Detailed grant information."""
|
||||
|
||||
id: str
|
||||
provider: str
|
||||
credential_id: str
|
||||
granted_scopes: list[str]
|
||||
permissions: list[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_used_at: Optional[datetime] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
revoked_at: Optional[datetime] = None
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
|
||||
@grants_router.get("/", response_model=list[GrantSummary])
|
||||
async def list_grants(
|
||||
token: OAuthTokenInfo = Security(require_scope("integrations:list")),
|
||||
) -> list[GrantSummary]:
|
||||
"""
|
||||
List all active credential grants for this client and user.
|
||||
|
||||
Returns grant metadata but NOT credential values.
|
||||
Credentials are never exposed to external applications.
|
||||
"""
|
||||
# Get the OAuth client's database ID from the public client_id
|
||||
client = await prisma.oauthclient.find_unique(where={"clientId": token.client_id})
|
||||
if not client:
|
||||
raise HTTPException(status_code=400, detail="Invalid client")
|
||||
|
||||
grants = await grants_db.get_grants_for_user_client(
|
||||
user_id=token.user_id,
|
||||
client_id=client.id,
|
||||
include_revoked=False,
|
||||
include_expired=False,
|
||||
)
|
||||
|
||||
return [
|
||||
GrantSummary(
|
||||
id=grant.id,
|
||||
provider=grant.provider,
|
||||
granted_scopes=grant.grantedScopes,
|
||||
permissions=[p.value for p in grant.permissions],
|
||||
created_at=grant.createdAt,
|
||||
last_used_at=grant.lastUsedAt,
|
||||
expires_at=grant.expiresAt,
|
||||
)
|
||||
for grant in grants
|
||||
]
|
||||
|
||||
|
||||
@grants_router.get("/{grant_id}", response_model=GrantDetail)
|
||||
async def get_grant(
|
||||
grant_id: str,
|
||||
token: OAuthTokenInfo = Security(require_scope("integrations:list")),
|
||||
) -> GrantDetail:
|
||||
"""
|
||||
Get detailed information about a specific grant.
|
||||
|
||||
Returns grant metadata including scopes and permissions.
|
||||
Does NOT return the credential value.
|
||||
"""
|
||||
# Get the OAuth client's database ID
|
||||
client = await prisma.oauthclient.find_unique(where={"clientId": token.client_id})
|
||||
if not client:
|
||||
raise HTTPException(status_code=400, detail="Invalid client")
|
||||
|
||||
grant = await grants_db.get_credential_grant(
|
||||
grant_id=grant_id,
|
||||
user_id=token.user_id,
|
||||
client_id=client.id,
|
||||
)
|
||||
|
||||
if not grant:
|
||||
raise HTTPException(status_code=404, detail="Grant not found")
|
||||
|
||||
# Check if expired
|
||||
if grant.expiresAt and grant.expiresAt < datetime.now(timezone.utc):
|
||||
raise HTTPException(status_code=404, detail="Grant has expired")
|
||||
|
||||
# Check if revoked
|
||||
if grant.revokedAt:
|
||||
raise HTTPException(status_code=404, detail="Grant has been revoked")
|
||||
|
||||
return GrantDetail(
|
||||
id=grant.id,
|
||||
provider=grant.provider,
|
||||
credential_id=grant.credentialId,
|
||||
granted_scopes=grant.grantedScopes,
|
||||
permissions=[p.value for p in grant.permissions],
|
||||
created_at=grant.createdAt,
|
||||
updated_at=grant.updatedAt,
|
||||
last_used_at=grant.lastUsedAt,
|
||||
expires_at=grant.expiresAt,
|
||||
revoked_at=grant.revokedAt,
|
||||
)
|
||||
|
||||
|
||||
@grants_router.delete("/{grant_id}/credential")
|
||||
async def delete_credential_via_grant(
|
||||
grant_id: str,
|
||||
token: OAuthTokenInfo = Security(require_scope("integrations:delete")),
|
||||
) -> dict:
|
||||
"""
|
||||
Delete the underlying credential associated with a grant.
|
||||
|
||||
This requires the grant to have the DELETE permission.
|
||||
Deleting the credential also invalidates all grants for that credential.
|
||||
"""
|
||||
from prisma.enums import CredentialGrantPermission
|
||||
|
||||
# Get the OAuth client's database ID
|
||||
client = await prisma.oauthclient.find_unique(where={"clientId": token.client_id})
|
||||
if not client:
|
||||
raise HTTPException(status_code=400, detail="Invalid client")
|
||||
|
||||
# Get the grant
|
||||
grant = await grants_db.get_credential_grant(
|
||||
grant_id=grant_id,
|
||||
user_id=token.user_id,
|
||||
client_id=client.id,
|
||||
)
|
||||
|
||||
if not grant:
|
||||
raise HTTPException(status_code=404, detail="Grant not found")
|
||||
|
||||
# Check if grant is valid
|
||||
if grant.revokedAt:
|
||||
raise HTTPException(status_code=400, detail="Grant has been revoked")
|
||||
|
||||
if grant.expiresAt and grant.expiresAt < datetime.now(timezone.utc):
|
||||
raise HTTPException(status_code=400, detail="Grant has expired")
|
||||
|
||||
# Check DELETE permission
|
||||
if CredentialGrantPermission.DELETE not in grant.permissions:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Grant does not have DELETE permission for this credential",
|
||||
)
|
||||
|
||||
# Delete the credential using the credentials store
|
||||
try:
|
||||
creds_store = IntegrationCredentialsStore()
|
||||
await creds_store.delete_creds_by_id(
|
||||
user_id=token.user_id,
|
||||
credentials_id=grant.credentialId,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete credential: {str(e)}",
|
||||
)
|
||||
|
||||
# Revoke all grants for this credential
|
||||
await grants_db.revoke_grants_for_credential(
|
||||
user_id=token.user_id,
|
||||
credential_id=grant.credentialId,
|
||||
)
|
||||
|
||||
return {"message": "Credential deleted successfully"}
|
||||
@@ -1,650 +0,0 @@
|
||||
"""
|
||||
External API endpoints for integrations and credentials.
|
||||
|
||||
This module provides endpoints for external applications (like Autopilot) to:
|
||||
- Initiate OAuth flows with custom callback URLs
|
||||
- Complete OAuth flows by exchanging authorization codes
|
||||
- Create API key, user/password, and host-scoped credentials
|
||||
- List and manage user credentials
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Path, Security, status
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.server.integrations.models import get_all_provider_names
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.integrations.oauth import BaseOAuthHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
|
||||
integrations_router = APIRouter(prefix="/integrations", tags=["integrations"])
|
||||
|
||||
|
||||
# ==================== Request/Response Models ==================== #
|
||||
|
||||
|
||||
class OAuthInitiateRequest(BaseModel):
|
||||
"""Request model for initiating an OAuth flow."""
|
||||
|
||||
callback_url: str = Field(
|
||||
..., description="The external app's callback URL for OAuth redirect"
|
||||
)
|
||||
scopes: list[str] = Field(
|
||||
default_factory=list, description="OAuth scopes to request"
|
||||
)
|
||||
state_metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Arbitrary metadata to echo back on completion",
|
||||
)
|
||||
|
||||
|
||||
class OAuthInitiateResponse(BaseModel):
|
||||
"""Response model for OAuth initiation."""
|
||||
|
||||
login_url: str = Field(..., description="URL to redirect user for OAuth consent")
|
||||
state_token: str = Field(..., description="State token for CSRF protection")
|
||||
expires_at: int = Field(
|
||||
..., description="Unix timestamp when the state token expires"
|
||||
)
|
||||
|
||||
|
||||
class OAuthCompleteRequest(BaseModel):
|
||||
"""Request model for completing an OAuth flow."""
|
||||
|
||||
code: str = Field(..., description="Authorization code from OAuth provider")
|
||||
state_token: str = Field(..., description="State token from initiate request")
|
||||
|
||||
|
||||
class OAuthCompleteResponse(BaseModel):
|
||||
"""Response model for OAuth completion."""
|
||||
|
||||
credentials_id: str = Field(..., description="ID of the stored credentials")
|
||||
provider: str = Field(..., description="Provider name")
|
||||
type: str = Field(..., description="Credential type (oauth2)")
|
||||
title: Optional[str] = Field(None, description="Credential title")
|
||||
scopes: list[str] = Field(default_factory=list, description="Granted scopes")
|
||||
username: Optional[str] = Field(None, description="Username from provider")
|
||||
state_metadata: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Echoed metadata from initiate request"
|
||||
)
|
||||
|
||||
|
||||
class CredentialSummary(BaseModel):
|
||||
"""Summary of a credential without sensitive data."""
|
||||
|
||||
id: str
|
||||
provider: str
|
||||
type: CredentialsType
|
||||
title: Optional[str] = None
|
||||
scopes: Optional[list[str]] = None
|
||||
username: Optional[str] = None
|
||||
host: Optional[str] = None
|
||||
|
||||
|
||||
class ProviderInfo(BaseModel):
|
||||
"""Information about an integration provider."""
|
||||
|
||||
name: str
|
||||
supports_oauth: bool = False
|
||||
supports_api_key: bool = False
|
||||
supports_user_password: bool = False
|
||||
supports_host_scoped: bool = False
|
||||
default_scopes: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ==================== Credential Creation Models ==================== #
|
||||
|
||||
|
||||
class CreateAPIKeyCredentialRequest(BaseModel):
|
||||
"""Request model for creating API key credentials."""
|
||||
|
||||
type: Literal["api_key"] = "api_key"
|
||||
api_key: str = Field(..., description="The API key")
|
||||
title: str = Field(..., description="A name for this credential")
|
||||
expires_at: Optional[int] = Field(
|
||||
None, description="Unix timestamp when the API key expires"
|
||||
)
|
||||
|
||||
|
||||
class CreateUserPasswordCredentialRequest(BaseModel):
|
||||
"""Request model for creating username/password credentials."""
|
||||
|
||||
type: Literal["user_password"] = "user_password"
|
||||
username: str = Field(..., description="Username")
|
||||
password: str = Field(..., description="Password")
|
||||
title: str = Field(..., description="A name for this credential")
|
||||
|
||||
|
||||
class CreateHostScopedCredentialRequest(BaseModel):
|
||||
"""Request model for creating host-scoped credentials."""
|
||||
|
||||
type: Literal["host_scoped"] = "host_scoped"
|
||||
host: str = Field(..., description="Host/domain pattern to match")
|
||||
headers: dict[str, str] = Field(..., description="Headers to include in requests")
|
||||
title: str = Field(..., description="A name for this credential")
|
||||
|
||||
|
||||
# Union type for credential creation
|
||||
CreateCredentialRequest = Annotated[
|
||||
CreateAPIKeyCredentialRequest
|
||||
| CreateUserPasswordCredentialRequest
|
||||
| CreateHostScopedCredentialRequest,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class CreateCredentialResponse(BaseModel):
|
||||
"""Response model for credential creation."""
|
||||
|
||||
id: str
|
||||
provider: str
|
||||
type: CredentialsType
|
||||
title: Optional[str] = None
|
||||
|
||||
|
||||
# ==================== Helper Functions ==================== #
|
||||
|
||||
|
||||
def validate_callback_url(callback_url: str) -> bool:
|
||||
"""Validate that the callback URL is from an allowed origin."""
|
||||
allowed_origins = settings.config.external_oauth_callback_origins
|
||||
|
||||
try:
|
||||
parsed = urlparse(callback_url)
|
||||
callback_origin = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
for allowed in allowed_origins:
|
||||
# Simple origin matching
|
||||
if callback_origin == allowed:
|
||||
return True
|
||||
|
||||
# Allow localhost with any port in development (proper hostname check)
|
||||
if parsed.hostname == "localhost":
|
||||
for allowed in allowed_origins:
|
||||
allowed_parsed = urlparse(allowed)
|
||||
if allowed_parsed.hostname == "localhost":
|
||||
return True
|
||||
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _get_oauth_handler_for_external(
|
||||
provider_name: str, redirect_uri: str
|
||||
) -> "BaseOAuthHandler":
|
||||
"""Get an OAuth handler configured with an external redirect URI."""
|
||||
# Ensure blocks are loaded so SDK providers are available
|
||||
try:
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
load_all_blocks()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load blocks: {e}")
|
||||
|
||||
if provider_name not in HANDLERS_BY_NAME:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Provider '{provider_name}' does not support OAuth",
|
||||
)
|
||||
|
||||
# Check if this provider has custom OAuth credentials
|
||||
oauth_credentials = CREDENTIALS_BY_PROVIDER.get(provider_name)
|
||||
|
||||
if oauth_credentials and not oauth_credentials.use_secrets:
|
||||
import os
|
||||
|
||||
client_id = (
|
||||
os.getenv(oauth_credentials.client_id_env_var)
|
||||
if oauth_credentials.client_id_env_var
|
||||
else None
|
||||
)
|
||||
client_secret = (
|
||||
os.getenv(oauth_credentials.client_secret_env_var)
|
||||
if oauth_credentials.client_secret_env_var
|
||||
else None
|
||||
)
|
||||
else:
|
||||
client_id = getattr(settings.secrets, f"{provider_name}_client_id", None)
|
||||
client_secret = getattr(
|
||||
settings.secrets, f"{provider_name}_client_secret", None
|
||||
)
|
||||
|
||||
if not (client_id and client_secret):
|
||||
logger.error(f"Attempt to use unconfigured {provider_name} OAuth integration")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail={
|
||||
"message": f"Integration with provider '{provider_name}' is not configured.",
|
||||
"hint": "Set client ID and secret in the application's deployment environment",
|
||||
},
|
||||
)
|
||||
|
||||
handler_class = HANDLERS_BY_NAME[provider_name]
|
||||
return handler_class(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
|
||||
|
||||
# ==================== Endpoints ==================== #
|
||||
|
||||
|
||||
@integrations_router.get("/providers", response_model=list[ProviderInfo])
|
||||
async def list_providers(
|
||||
api_key: APIKeyInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[ProviderInfo]:
|
||||
"""
|
||||
List all available integration providers.
|
||||
|
||||
Returns a list of all providers with their supported credential types.
|
||||
Most providers support API key credentials, and some also support OAuth.
|
||||
"""
|
||||
# Ensure blocks are loaded
|
||||
try:
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
load_all_blocks()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load blocks: {e}")
|
||||
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
|
||||
providers = []
|
||||
for name in get_all_provider_names():
|
||||
supports_oauth = name in HANDLERS_BY_NAME
|
||||
handler_class = HANDLERS_BY_NAME.get(name)
|
||||
default_scopes = (
|
||||
getattr(handler_class, "DEFAULT_SCOPES", []) if handler_class else []
|
||||
)
|
||||
|
||||
# Check if provider has specific auth types from SDK registration
|
||||
sdk_provider = AutoRegistry.get_provider(name)
|
||||
if sdk_provider and sdk_provider.supported_auth_types:
|
||||
supports_api_key = "api_key" in sdk_provider.supported_auth_types
|
||||
supports_user_password = (
|
||||
"user_password" in sdk_provider.supported_auth_types
|
||||
)
|
||||
supports_host_scoped = "host_scoped" in sdk_provider.supported_auth_types
|
||||
else:
|
||||
# Fallback for legacy providers
|
||||
supports_api_key = True # All providers can accept API keys
|
||||
supports_user_password = name in ("smtp",)
|
||||
supports_host_scoped = name == "http"
|
||||
|
||||
providers.append(
|
||||
ProviderInfo(
|
||||
name=name,
|
||||
supports_oauth=supports_oauth,
|
||||
supports_api_key=supports_api_key,
|
||||
supports_user_password=supports_user_password,
|
||||
supports_host_scoped=supports_host_scoped,
|
||||
default_scopes=default_scopes,
|
||||
)
|
||||
)
|
||||
|
||||
return providers
|
||||
|
||||
|
||||
@integrations_router.post(
|
||||
"/{provider}/oauth/initiate",
|
||||
response_model=OAuthInitiateResponse,
|
||||
summary="Initiate OAuth flow",
|
||||
)
|
||||
async def initiate_oauth(
|
||||
provider: Annotated[str, Path(title="The OAuth provider")],
|
||||
request: OAuthInitiateRequest,
|
||||
api_key: APIKeyInfo = Security(
|
||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||
),
|
||||
) -> OAuthInitiateResponse:
|
||||
"""
|
||||
Initiate an OAuth flow for an external application.
|
||||
|
||||
This endpoint allows external apps to start an OAuth flow with a custom
|
||||
callback URL. The callback URL must be from an allowed origin configured
|
||||
in the platform settings.
|
||||
|
||||
Returns a login URL to redirect the user to, along with a state token
|
||||
for CSRF protection.
|
||||
"""
|
||||
# Validate callback URL
|
||||
if not validate_callback_url(request.callback_url):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Callback URL origin is not allowed. Allowed origins: {settings.config.external_oauth_callback_origins}",
|
||||
)
|
||||
|
||||
# Validate provider
|
||||
try:
|
||||
provider_name = ProviderName(provider)
|
||||
except ValueError:
|
||||
# Check if it's a dynamically registered provider
|
||||
if provider not in HANDLERS_BY_NAME:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Provider '{provider}' not found",
|
||||
)
|
||||
provider_name = provider
|
||||
|
||||
# Get OAuth handler with external callback URL
|
||||
handler = _get_oauth_handler_for_external(
|
||||
provider if isinstance(provider_name, str) else provider_name.value,
|
||||
request.callback_url,
|
||||
)
|
||||
|
||||
# Store state token with external flow metadata
|
||||
state_token, code_challenge = await creds_manager.store.store_state_token(
|
||||
user_id=api_key.user_id,
|
||||
provider=provider if isinstance(provider_name, str) else provider_name.value,
|
||||
scopes=request.scopes,
|
||||
callback_url=request.callback_url,
|
||||
state_metadata=request.state_metadata,
|
||||
initiated_by_api_key_id=api_key.id,
|
||||
)
|
||||
|
||||
# Build login URL
|
||||
login_url = handler.get_login_url(
|
||||
request.scopes, state_token, code_challenge=code_challenge
|
||||
)
|
||||
|
||||
# Calculate expiration (10 minutes from now)
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
expires_at = int((datetime.now(timezone.utc) + timedelta(minutes=10)).timestamp())
|
||||
|
||||
return OAuthInitiateResponse(
|
||||
login_url=login_url,
|
||||
state_token=state_token,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
|
||||
@integrations_router.post(
|
||||
"/{provider}/oauth/complete",
|
||||
response_model=OAuthCompleteResponse,
|
||||
summary="Complete OAuth flow",
|
||||
)
|
||||
async def complete_oauth(
|
||||
provider: Annotated[str, Path(title="The OAuth provider")],
|
||||
request: OAuthCompleteRequest,
|
||||
api_key: APIKeyInfo = Security(
|
||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||
),
|
||||
) -> OAuthCompleteResponse:
|
||||
"""
|
||||
Complete an OAuth flow by exchanging the authorization code for tokens.
|
||||
|
||||
This endpoint should be called after the user has authorized the application
|
||||
and been redirected back to the external app's callback URL with an
|
||||
authorization code.
|
||||
"""
|
||||
# Verify state token
|
||||
valid_state = await creds_manager.store.verify_state_token(
|
||||
api_key.user_id, request.state_token, provider
|
||||
)
|
||||
|
||||
if not valid_state:
|
||||
logger.warning(f"Invalid or expired state token for provider {provider}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid or expired state token",
|
||||
)
|
||||
|
||||
# Verify this is an external flow (callback_url must be set)
|
||||
if not valid_state.callback_url:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="State token was not created for external OAuth flow",
|
||||
)
|
||||
|
||||
# Get OAuth handler with the original callback URL
|
||||
handler = _get_oauth_handler_for_external(provider, valid_state.callback_url)
|
||||
|
||||
try:
|
||||
scopes = valid_state.scopes
|
||||
scopes = handler.handle_default_scopes(scopes)
|
||||
|
||||
credentials = await handler.exchange_code_for_tokens(
|
||||
request.code, scopes, valid_state.code_verifier
|
||||
)
|
||||
|
||||
# Handle Linear's space-separated scopes
|
||||
if len(credentials.scopes) == 1 and " " in credentials.scopes[0]:
|
||||
credentials.scopes = credentials.scopes[0].split(" ")
|
||||
|
||||
# Check scope mismatch
|
||||
if not set(scopes).issubset(set(credentials.scopes)):
|
||||
logger.warning(
|
||||
f"Granted scopes {credentials.scopes} for provider {provider} "
|
||||
f"do not include all requested scopes {scopes}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth2 Code->Token exchange failed for provider {provider}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"OAuth2 callback failed to exchange code for tokens: {str(e)}",
|
||||
)
|
||||
|
||||
# Store credentials
|
||||
await creds_manager.create(api_key.user_id, credentials)
|
||||
|
||||
logger.info(f"Successfully completed external OAuth for provider {provider}")
|
||||
|
||||
return OAuthCompleteResponse(
|
||||
credentials_id=credentials.id,
|
||||
provider=credentials.provider,
|
||||
type=credentials.type,
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
state_metadata=valid_state.state_metadata,
|
||||
)
|
||||
|
||||
|
||||
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
|
||||
async def list_credentials(
|
||||
api_key: APIKeyInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
"""
|
||||
List all credentials for the authenticated user.
|
||||
|
||||
Returns metadata about each credential without exposing sensitive tokens.
|
||||
"""
|
||||
credentials = await creds_manager.store.get_all_creds(api_key.user_id)
|
||||
return [
|
||||
CredentialSummary(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
|
||||
|
||||
@integrations_router.get(
|
||||
"/{provider}/credentials", response_model=list[CredentialSummary]
|
||||
)
|
||||
async def list_credentials_by_provider(
|
||||
provider: Annotated[str, Path(title="The provider to list credentials for")],
|
||||
api_key: APIKeyInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
"""
|
||||
List credentials for a specific provider.
|
||||
"""
|
||||
credentials = await creds_manager.store.get_creds_by_provider(
|
||||
api_key.user_id, provider
|
||||
)
|
||||
return [
|
||||
CredentialSummary(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
|
||||
|
||||
@integrations_router.post(
|
||||
"/{provider}/credentials",
|
||||
response_model=CreateCredentialResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create credentials",
|
||||
)
|
||||
async def create_credential(
|
||||
provider: Annotated[str, Path(title="The provider to create credentials for")],
|
||||
request: Union[
|
||||
CreateAPIKeyCredentialRequest,
|
||||
CreateUserPasswordCredentialRequest,
|
||||
CreateHostScopedCredentialRequest,
|
||||
] = Body(..., discriminator="type"),
|
||||
api_key: APIKeyInfo = Security(
|
||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||
),
|
||||
) -> CreateCredentialResponse:
|
||||
"""
|
||||
Create non-OAuth credentials for a provider.
|
||||
|
||||
Supports creating:
|
||||
- API key credentials (type: "api_key")
|
||||
- Username/password credentials (type: "user_password")
|
||||
- Host-scoped credentials (type: "host_scoped")
|
||||
|
||||
For OAuth credentials, use the OAuth initiate/complete flow instead.
|
||||
"""
|
||||
# Validate provider exists
|
||||
all_providers = get_all_provider_names()
|
||||
if provider not in all_providers:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Provider '{provider}' not found",
|
||||
)
|
||||
|
||||
# Create the appropriate credential type
|
||||
credentials: Credentials
|
||||
if request.type == "api_key":
|
||||
credentials = APIKeyCredentials(
|
||||
provider=provider,
|
||||
api_key=SecretStr(request.api_key),
|
||||
title=request.title,
|
||||
expires_at=request.expires_at,
|
||||
)
|
||||
elif request.type == "user_password":
|
||||
credentials = UserPasswordCredentials(
|
||||
provider=provider,
|
||||
username=SecretStr(request.username),
|
||||
password=SecretStr(request.password),
|
||||
title=request.title,
|
||||
)
|
||||
elif request.type == "host_scoped":
|
||||
# Convert string headers to SecretStr
|
||||
secret_headers = {k: SecretStr(v) for k, v in request.headers.items()}
|
||||
credentials = HostScopedCredentials(
|
||||
provider=provider,
|
||||
host=request.host,
|
||||
headers=secret_headers,
|
||||
title=request.title,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported credential type: {request.type}",
|
||||
)
|
||||
|
||||
# Store credentials
|
||||
try:
|
||||
await creds_manager.create(api_key.user_id, credentials)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store credentials: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to store credentials: {str(e)}",
|
||||
)
|
||||
|
||||
logger.info(f"Created {request.type} credentials for provider {provider}")
|
||||
|
||||
return CreateCredentialResponse(
|
||||
id=credentials.id,
|
||||
provider=provider,
|
||||
type=credentials.type,
|
||||
title=credentials.title,
|
||||
)
|
||||
|
||||
|
||||
class DeleteCredentialResponse(BaseModel):
|
||||
"""Response model for deleting a credential."""
|
||||
|
||||
deleted: bool = Field(..., description="Whether the credential was deleted")
|
||||
credentials_id: str = Field(..., description="ID of the deleted credential")
|
||||
|
||||
|
||||
@integrations_router.delete(
|
||||
"/{provider}/credentials/{cred_id}",
|
||||
response_model=DeleteCredentialResponse,
|
||||
)
|
||||
async def delete_credential(
|
||||
provider: Annotated[str, Path(title="The provider")],
|
||||
cred_id: Annotated[str, Path(title="The credential ID to delete")],
|
||||
api_key: APIKeyInfo = Security(
|
||||
require_permission(APIKeyPermission.DELETE_INTEGRATIONS)
|
||||
),
|
||||
) -> DeleteCredentialResponse:
|
||||
"""
|
||||
Delete a credential.
|
||||
|
||||
Note: This does not revoke the tokens with the provider. For full cleanup,
|
||||
use the main API's delete endpoint which handles webhook cleanup and
|
||||
token revocation.
|
||||
"""
|
||||
creds = await creds_manager.store.get_creds_by_id(api_key.user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if creds.provider != provider:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
)
|
||||
|
||||
await creds_manager.delete(api_key.user_id, cred_id)
|
||||
|
||||
return DeleteCredentialResponse(deleted=True, credentials_id=cred_id)
|
||||
@@ -1,148 +0,0 @@
|
||||
"""External API routes for chat tools - stateless HTTP endpoints.
|
||||
|
||||
Note: These endpoints use ephemeral sessions that are not persisted to Redis.
|
||||
As a result, session-based rate limiting (max_agent_runs, max_agent_schedules)
|
||||
is not enforced for external API calls. Each request creates a fresh session
|
||||
with zeroed counters. Rate limiting for external API consumers should be
|
||||
handled separately (e.g., via API key quotas).
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Security
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools import find_agent_tool, run_agent_tool
|
||||
from backend.server.v2.chat.tools.models import ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
tools_router = APIRouter(prefix="/tools", tags=["tools"])
|
||||
|
||||
# Note: We use Security() as a function parameter dependency (api_key: APIKeyInfo = Security(...))
|
||||
# rather than in the decorator's dependencies= list. This avoids duplicate permission checks
|
||||
# while still enforcing auth AND giving us access to the api_key for extracting user_id.
|
||||
|
||||
|
||||
# Request models
|
||||
class FindAgentRequest(BaseModel):
|
||||
query: str = Field(..., description="Search query for finding agents")
|
||||
|
||||
|
||||
class RunAgentRequest(BaseModel):
|
||||
"""Request to run or schedule an agent.
|
||||
|
||||
The tool automatically handles the setup flow:
|
||||
- First call returns available inputs so user can decide what values to use
|
||||
- Returns missing credentials if user needs to configure them
|
||||
- Executes when inputs are provided OR use_defaults=true
|
||||
- Schedules execution if schedule_name and cron are provided
|
||||
"""
|
||||
|
||||
username_agent_slug: str = Field(
|
||||
...,
|
||||
description="The marketplace agent slug (e.g., 'username/agent-name')",
|
||||
)
|
||||
inputs: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Dictionary of input values for the agent",
|
||||
)
|
||||
use_defaults: bool = Field(
|
||||
default=False,
|
||||
description="Set to true to run with default values (user must confirm)",
|
||||
)
|
||||
schedule_name: str | None = Field(
|
||||
None,
|
||||
description="Name for scheduled execution (triggers scheduling mode)",
|
||||
)
|
||||
cron: str | None = Field(
|
||||
None,
|
||||
description="Cron expression (5 fields: minute hour day month weekday)",
|
||||
)
|
||||
timezone: str = Field(
|
||||
default="UTC",
|
||||
description="IANA timezone (e.g., 'America/New_York', 'UTC')",
|
||||
)
|
||||
|
||||
|
||||
def _create_ephemeral_session(user_id: str | None) -> ChatSession:
|
||||
"""Create an ephemeral session for stateless API requests."""
|
||||
return ChatSession.new(user_id)
|
||||
|
||||
|
||||
@tools_router.post(
|
||||
path="/find-agent",
|
||||
)
|
||||
async def find_agent(
|
||||
request: FindAgentRequest,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Search for agents in the marketplace based on capabilities and user needs.
|
||||
|
||||
Args:
|
||||
request: Search query for finding agents
|
||||
|
||||
Returns:
|
||||
List of matching agents or no results response
|
||||
"""
|
||||
session = _create_ephemeral_session(api_key.user_id)
|
||||
result = await find_agent_tool._execute(
|
||||
user_id=api_key.user_id,
|
||||
session=session,
|
||||
query=request.query,
|
||||
)
|
||||
return _response_to_dict(result)
|
||||
|
||||
|
||||
@tools_router.post(
|
||||
path="/run-agent",
|
||||
)
|
||||
async def run_agent(
|
||||
request: RunAgentRequest,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run or schedule an agent from the marketplace.
|
||||
|
||||
The endpoint automatically handles the setup flow:
|
||||
- Returns missing inputs if required fields are not provided
|
||||
- Returns missing credentials if user needs to configure them
|
||||
- Executes immediately if all requirements are met
|
||||
- Schedules execution if schedule_name and cron are provided
|
||||
|
||||
For scheduled execution:
|
||||
- Cron format: "minute hour day month weekday"
|
||||
- Examples: "0 9 * * 1-5" (9am weekdays), "0 0 * * *" (daily at midnight)
|
||||
- Timezone: Use IANA timezone names like "America/New_York"
|
||||
|
||||
Args:
|
||||
request: Agent slug, inputs, and optional schedule config
|
||||
|
||||
Returns:
|
||||
- setup_requirements: If inputs or credentials are missing
|
||||
- execution_started: If agent was run or scheduled successfully
|
||||
- error: If something went wrong
|
||||
"""
|
||||
session = _create_ephemeral_session(api_key.user_id)
|
||||
result = await run_agent_tool._execute(
|
||||
user_id=api_key.user_id,
|
||||
session=session,
|
||||
username_agent_slug=request.username_agent_slug,
|
||||
inputs=request.inputs,
|
||||
use_defaults=request.use_defaults,
|
||||
schedule_name=request.schedule_name or "",
|
||||
cron=request.cron or "",
|
||||
timezone=request.timezone,
|
||||
)
|
||||
return _response_to_dict(result)
|
||||
|
||||
|
||||
def _response_to_dict(result: ToolResponseBase) -> dict[str, Any]:
|
||||
"""Convert a tool response to a dictionary for JSON serialization."""
|
||||
return result.model_dump()
|
||||
@@ -1,295 +0,0 @@
|
||||
import logging
|
||||
import urllib.parse
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, Any, Literal, Optional, Sequence
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import backend.data.block
|
||||
import backend.server.v2.store.cache as store_cache
|
||||
import backend.server.v2.store.model as store_model
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
v1_router = APIRouter()
|
||||
|
||||
|
||||
class NodeOutput(TypedDict):
|
||||
key: str
|
||||
value: Any
|
||||
|
||||
|
||||
class ExecutionNode(TypedDict):
|
||||
node_id: str
|
||||
input: Any
|
||||
output: dict[str, Any]
|
||||
|
||||
|
||||
class ExecutionNodeOutput(TypedDict):
|
||||
node_id: str
|
||||
outputs: list[NodeOutput]
|
||||
|
||||
|
||||
class GraphExecutionResult(TypedDict):
|
||||
execution_id: str
|
||||
status: str
|
||||
nodes: list[ExecutionNode]
|
||||
output: Optional[list[dict[str, str]]]
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/blocks",
|
||||
tags=["blocks"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
|
||||
)
|
||||
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
blocks = [block() for block in backend.data.block.get_blocks().values()]
|
||||
return [b.to_dict() for b in blocks if not b.disabled]
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/blocks/{block_id}/execute",
|
||||
tags=["blocks"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.EXECUTE_BLOCK))],
|
||||
)
|
||||
async def execute_graph_block(
|
||||
block_id: str,
|
||||
data: BlockInput,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
|
||||
) -> CompletedBlockOutput:
|
||||
obj = backend.data.block.get_block(block_id)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||
|
||||
output = defaultdict(list)
|
||||
async for name, data in obj.execute(data):
|
||||
output[name].append(data)
|
||||
return output
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs/{graph_id}/execute/{graph_version}",
|
||||
tags=["graphs"],
|
||||
)
|
||||
async def execute_graph(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
graph_exec = await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=api_key.user_id,
|
||||
inputs=node_input,
|
||||
graph_version=graph_version,
|
||||
)
|
||||
return {"id": graph_exec.id}
|
||||
except Exception as e:
|
||||
msg = str(e).encode().decode("unicode_escape")
|
||||
raise HTTPException(status_code=400, detail=msg)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}/results",
|
||||
tags=["graphs"],
|
||||
)
|
||||
async def get_graph_execution_results(
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.READ_GRAPH)),
|
||||
) -> GraphExecutionResult:
|
||||
graph_exec = await execution_db.get_graph_execution(
|
||||
user_id=api_key.user_id,
|
||||
execution_id=graph_exec_id,
|
||||
include_node_executions=True,
|
||||
)
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
|
||||
)
|
||||
|
||||
if not await graph_db.get_graph(
|
||||
graph_id=graph_exec.graph_id,
|
||||
version=graph_exec.graph_version,
|
||||
user_id=api_key.user_id,
|
||||
):
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
return GraphExecutionResult(
|
||||
execution_id=graph_exec_id,
|
||||
status=graph_exec.status.value,
|
||||
nodes=[
|
||||
ExecutionNode(
|
||||
node_id=node_exec.node_id,
|
||||
input=node_exec.input_data.get("value", node_exec.input_data),
|
||||
output={k: v for k, v in node_exec.output_data.items()},
|
||||
)
|
||||
for node_exec in graph_exec.node_executions
|
||||
],
|
||||
output=(
|
||||
[
|
||||
{name: value}
|
||||
for name, values in graph_exec.outputs.items()
|
||||
for value in values
|
||||
]
|
||||
if graph_exec.status == AgentExecutionStatus.COMPLETED
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
############### Store Endpoints ##############
|
||||
##############################################
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/store/agents",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
response_model=store_model.StoreAgentsResponse,
|
||||
)
|
||||
async def get_store_agents(
|
||||
featured: bool = False,
|
||||
creator: str | None = None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
|
||||
search_query: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> store_model.StoreAgentsResponse:
|
||||
"""
|
||||
Get a paginated list of agents from the store with optional filtering and sorting.
|
||||
|
||||
Args:
|
||||
featured: Filter to only show featured agents
|
||||
creator: Filter agents by creator username
|
||||
sorted_by: Sort agents by "runs", "rating", "name", or "updated_at"
|
||||
search_query: Search agents by name, subheading and description
|
||||
category: Filter agents by category
|
||||
page: Page number for pagination (default 1)
|
||||
page_size: Number of agents per page (default 20)
|
||||
|
||||
Returns:
|
||||
StoreAgentsResponse: Paginated list of agents matching the filters
|
||||
"""
|
||||
if page < 1:
|
||||
raise HTTPException(status_code=422, detail="Page must be greater than 0")
|
||||
|
||||
if page_size < 1:
|
||||
raise HTTPException(status_code=422, detail="Page size must be greater than 0")
|
||||
|
||||
agents = await store_cache._get_cached_store_agents(
|
||||
featured=featured,
|
||||
creator=creator,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return agents
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/store/agents/{username}/{agent_name}",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_store_agent(
|
||||
username: str,
|
||||
agent_name: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""
|
||||
Get details of a specific store agent by username and agent name.
|
||||
|
||||
Args:
|
||||
username: Creator's username
|
||||
agent_name: Name/slug of the agent
|
||||
|
||||
Returns:
|
||||
StoreAgentDetails: Detailed information about the agent
|
||||
"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
agent = await store_cache._get_cached_agent_details(
|
||||
username=username, agent_name=agent_name
|
||||
)
|
||||
return agent
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/store/creators",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
response_model=store_model.CreatorsResponse,
|
||||
)
|
||||
async def get_store_creators(
|
||||
featured: bool = False,
|
||||
search_query: str | None = None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> store_model.CreatorsResponse:
|
||||
"""
|
||||
Get a paginated list of store creators with optional filtering and sorting.
|
||||
|
||||
Args:
|
||||
featured: Filter to only show featured creators
|
||||
search_query: Search creators by profile description
|
||||
sorted_by: Sort by "agent_rating", "agent_runs", or "num_agents"
|
||||
page: Page number for pagination (default 1)
|
||||
page_size: Number of creators per page (default 20)
|
||||
|
||||
Returns:
|
||||
CreatorsResponse: Paginated list of creators matching the filters
|
||||
"""
|
||||
if page < 1:
|
||||
raise HTTPException(status_code=422, detail="Page must be greater than 0")
|
||||
|
||||
if page_size < 1:
|
||||
raise HTTPException(status_code=422, detail="Page size must be greater than 0")
|
||||
|
||||
creators = await store_cache._get_cached_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
sorted_by=sorted_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return creators
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/store/creators/{username}",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
async def get_store_creator(
|
||||
username: str,
|
||||
) -> store_model.CreatorDetails:
|
||||
"""
|
||||
Get details of a specific store creator by username.
|
||||
|
||||
Args:
|
||||
username: Creator's username
|
||||
|
||||
Returns:
|
||||
CreatorDetails: Detailed information about the creator
|
||||
"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
creator = await store_cache._get_cached_creator_details(username=username)
|
||||
return creator
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,471 @@
|
||||
"""
|
||||
Security utilities for the integration connect popup flow.
|
||||
|
||||
Handles state management, nonce validation, and origin verification
|
||||
for the OAuth-style popup flow when connecting integrations.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from prisma.models import OAuthClient
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# State expiration time
|
||||
STATE_EXPIRATION_SECONDS = 600 # 10 minutes
|
||||
NONCE_EXPIRATION_SECONDS = 3600 # 1 hour (nonces valid for longer to prevent races)
|
||||
LOGIN_STATE_EXPIRATION_SECONDS = 600 # 10 minutes for login redirect flow
|
||||
|
||||
|
||||
class ConnectState(BaseModel):
|
||||
"""Pydantic model for connect state stored in Redis."""
|
||||
|
||||
user_id: str
|
||||
client_id: str
|
||||
provider: str
|
||||
requested_scopes: list[str]
|
||||
redirect_origin: str
|
||||
nonce: str
|
||||
credential_id: Optional[str] = None
|
||||
created_at: str
|
||||
expires_at: str
|
||||
|
||||
|
||||
class ConnectContinuationState(BaseModel):
|
||||
"""
|
||||
State for continuing the connect flow after OAuth completes.
|
||||
|
||||
When a user chooses to "connect new" during the connect flow,
|
||||
we store this state so we can complete the grant creation after
|
||||
the OAuth callback.
|
||||
"""
|
||||
|
||||
user_id: str
|
||||
client_id: str # Public client ID
|
||||
client_db_id: str # Database UUID of the OAuth client
|
||||
provider: str
|
||||
requested_scopes: list[str] # Integration scopes (e.g., "google:gmail.readonly")
|
||||
redirect_origin: str
|
||||
nonce: str
|
||||
created_at: str
|
||||
|
||||
|
||||
class ConnectLoginState(BaseModel):
|
||||
"""
|
||||
State for connect flow when user needs to log in first.
|
||||
|
||||
When an unauthenticated user tries to access /connect/{provider},
|
||||
we store the connect parameters and redirect to login. After login,
|
||||
the user is redirected back to complete the connect flow.
|
||||
"""
|
||||
|
||||
client_id: str
|
||||
provider: str
|
||||
requested_scopes: list[str]
|
||||
redirect_origin: str
|
||||
nonce: str
|
||||
created_at: str
|
||||
expires_at: str
|
||||
|
||||
|
||||
# Continuation state expiration (same as regular state)
|
||||
CONTINUATION_EXPIRATION_SECONDS = 600 # 10 minutes
|
||||
|
||||
|
||||
async def store_connect_continuation(
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
client_db_id: str,
|
||||
provider: str,
|
||||
requested_scopes: list[str],
|
||||
redirect_origin: str,
|
||||
nonce: str,
|
||||
) -> str:
|
||||
"""
|
||||
Store continuation state for completing connect flow after OAuth.
|
||||
|
||||
Args:
|
||||
user_id: User initiating the connection
|
||||
client_id: Public OAuth client ID
|
||||
client_db_id: Database UUID of the OAuth client
|
||||
provider: Integration provider name
|
||||
requested_scopes: Requested integration scopes
|
||||
redirect_origin: Origin to send postMessage to
|
||||
nonce: Client-provided nonce for replay protection
|
||||
|
||||
Returns:
|
||||
Continuation token to be stored in OAuth state metadata
|
||||
"""
|
||||
token = generate_connect_token()
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
state = ConnectContinuationState(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
client_db_id=client_db_id,
|
||||
provider=provider,
|
||||
requested_scopes=requested_scopes,
|
||||
redirect_origin=redirect_origin,
|
||||
nonce=nonce,
|
||||
created_at=now.isoformat(),
|
||||
)
|
||||
|
||||
redis = await get_redis_async()
|
||||
key = f"connect_continuation:{token}"
|
||||
await redis.setex(key, CONTINUATION_EXPIRATION_SECONDS, state.model_dump_json())
|
||||
|
||||
logger.debug(f"Stored connect continuation state for token {token[:8]}...")
|
||||
return token
|
||||
|
||||
|
||||
async def get_connect_continuation(token: str) -> Optional[ConnectContinuationState]:
|
||||
"""
|
||||
Get continuation state without consuming it.
|
||||
|
||||
Args:
|
||||
token: Continuation token
|
||||
|
||||
Returns:
|
||||
ConnectContinuationState or None if not found/expired
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = f"connect_continuation:{token}"
|
||||
data = await redis.get(key)
|
||||
|
||||
if not data:
|
||||
return None
|
||||
|
||||
return ConnectContinuationState.model_validate_json(data)
|
||||
|
||||
|
||||
async def consume_connect_continuation(
|
||||
token: str,
|
||||
) -> Optional[ConnectContinuationState]:
|
||||
"""
|
||||
Get and consume (delete) continuation state.
|
||||
|
||||
This ensures the token can only be used once.
|
||||
|
||||
Args:
|
||||
token: Continuation token
|
||||
|
||||
Returns:
|
||||
ConnectContinuationState or None if not found/expired
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = f"connect_continuation:{token}"
|
||||
|
||||
# Atomic get-and-delete to prevent race conditions
|
||||
data = await redis.getdel(key)
|
||||
if not data:
|
||||
return None
|
||||
|
||||
state = ConnectContinuationState.model_validate_json(data)
|
||||
logger.debug(f"Consumed connect continuation state for token {token[:8]}...")
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def generate_connect_token() -> str:
|
||||
"""Generate a secure random token for connect state."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
async def store_connect_state(
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
provider: str,
|
||||
requested_scopes: list[str],
|
||||
redirect_origin: str,
|
||||
nonce: str,
|
||||
credential_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Store connect state in Redis and return a state token.
|
||||
|
||||
Args:
|
||||
user_id: User initiating the connection
|
||||
client_id: OAuth client ID (public identifier)
|
||||
provider: Integration provider name
|
||||
requested_scopes: Requested integration scopes
|
||||
redirect_origin: Origin to send postMessage to
|
||||
nonce: Client-provided nonce for replay protection
|
||||
credential_id: Optional existing credential to grant access to
|
||||
|
||||
Returns:
|
||||
State token to be used in the connect flow
|
||||
"""
|
||||
token = generate_connect_token()
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now.timestamp() + STATE_EXPIRATION_SECONDS
|
||||
|
||||
state = ConnectState(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
provider=provider,
|
||||
requested_scopes=requested_scopes,
|
||||
redirect_origin=redirect_origin,
|
||||
nonce=nonce,
|
||||
credential_id=credential_id,
|
||||
created_at=now.isoformat(),
|
||||
expires_at=datetime.fromtimestamp(expires_at, tz=timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
redis = await get_redis_async()
|
||||
key = f"connect_state:{token}"
|
||||
await redis.setex(key, STATE_EXPIRATION_SECONDS, state.model_dump_json())
|
||||
|
||||
logger.debug(f"Stored connect state for token {token[:8]}...")
|
||||
return token
|
||||
|
||||
|
||||
async def get_connect_state(token: str) -> Optional[ConnectState]:
|
||||
"""
|
||||
Get connect state without consuming it.
|
||||
|
||||
Args:
|
||||
token: State token
|
||||
|
||||
Returns:
|
||||
ConnectState or None if not found/expired
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = f"connect_state:{token}"
|
||||
data = await redis.get(key)
|
||||
|
||||
if not data:
|
||||
return None
|
||||
|
||||
return ConnectState.model_validate_json(data)
|
||||
|
||||
|
||||
async def consume_connect_state(token: str) -> Optional[ConnectState]:
|
||||
"""
|
||||
Get and consume (delete) connect state.
|
||||
|
||||
This ensures the token can only be used once.
|
||||
|
||||
Args:
|
||||
token: State token
|
||||
|
||||
Returns:
|
||||
ConnectState or None if not found/expired
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = f"connect_state:{token}"
|
||||
|
||||
# Atomic get-and-delete to prevent race conditions
|
||||
data = await redis.getdel(key)
|
||||
if not data:
|
||||
return None
|
||||
|
||||
state = ConnectState.model_validate_json(data)
|
||||
logger.debug(f"Consumed connect state for token {token[:8]}...")
|
||||
|
||||
return state
|
||||
|
||||
|
||||
async def store_connect_login_state(
|
||||
client_id: str,
|
||||
provider: str,
|
||||
requested_scopes: list[str],
|
||||
redirect_origin: str,
|
||||
nonce: str,
|
||||
) -> str:
|
||||
"""
|
||||
Store connect parameters for unauthenticated users.
|
||||
|
||||
When a user isn't logged in, we store the connect params and redirect
|
||||
to login. After login, the frontend calls /connect/resume with the token.
|
||||
|
||||
Args:
|
||||
client_id: OAuth client ID
|
||||
provider: Integration provider name
|
||||
requested_scopes: Requested integration scopes
|
||||
redirect_origin: Origin to send postMessage to
|
||||
nonce: Client-provided nonce for replay protection
|
||||
|
||||
Returns:
|
||||
Login state token to be used after login completes
|
||||
"""
|
||||
token = generate_connect_token()
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now.timestamp() + LOGIN_STATE_EXPIRATION_SECONDS
|
||||
|
||||
state = ConnectLoginState(
|
||||
client_id=client_id,
|
||||
provider=provider,
|
||||
requested_scopes=requested_scopes,
|
||||
redirect_origin=redirect_origin,
|
||||
nonce=nonce,
|
||||
created_at=now.isoformat(),
|
||||
expires_at=datetime.fromtimestamp(expires_at, tz=timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
redis = await get_redis_async()
|
||||
key = f"connect_login_state:{token}"
|
||||
await redis.setex(key, LOGIN_STATE_EXPIRATION_SECONDS, state.model_dump_json())
|
||||
|
||||
logger.debug(f"Stored connect login state for token {token[:8]}...")
|
||||
return token
|
||||
|
||||
|
||||
async def get_connect_login_state(token: str) -> Optional[ConnectLoginState]:
|
||||
"""
|
||||
Get connect login state without consuming it.
|
||||
|
||||
Args:
|
||||
token: Login state token
|
||||
|
||||
Returns:
|
||||
ConnectLoginState or None if not found/expired
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = f"connect_login_state:{token}"
|
||||
data = await redis.get(key)
|
||||
|
||||
if not data:
|
||||
return None
|
||||
|
||||
return ConnectLoginState.model_validate_json(data)
|
||||
|
||||
|
||||
async def consume_connect_login_state(token: str) -> Optional[ConnectLoginState]:
|
||||
"""
|
||||
Get and consume (delete) connect login state.
|
||||
|
||||
This ensures the token can only be used once.
|
||||
|
||||
Args:
|
||||
token: Login state token
|
||||
|
||||
Returns:
|
||||
ConnectLoginState or None if not found/expired
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = f"connect_login_state:{token}"
|
||||
|
||||
# Atomic get-and-delete to prevent race conditions
|
||||
data = await redis.getdel(key)
|
||||
if not data:
|
||||
return None
|
||||
|
||||
state = ConnectLoginState.model_validate_json(data)
|
||||
logger.debug(f"Consumed connect login state for token {token[:8]}...")
|
||||
|
||||
return state
|
||||
|
||||
|
||||
async def validate_nonce(client_id: str, nonce: str) -> bool:
|
||||
"""
|
||||
Validate that a nonce hasn't been used before (replay protection).
|
||||
|
||||
Uses atomic SET NX EX for check-and-set with automatic TTL expiry.
|
||||
|
||||
Args:
|
||||
client_id: OAuth client ID
|
||||
nonce: Client-provided nonce
|
||||
|
||||
Returns:
|
||||
True if nonce is valid (not replayed)
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
|
||||
# Create a hash of the nonce for storage
|
||||
nonce_hash = hashlib.sha256(nonce.encode()).hexdigest()
|
||||
key = f"nonce:{client_id}:{nonce_hash}"
|
||||
|
||||
# Atomic set-if-not-exists with expiration (prevents race condition)
|
||||
was_set = await redis.set(key, "1", nx=True, ex=NONCE_EXPIRATION_SECONDS)
|
||||
if was_set:
|
||||
return True
|
||||
|
||||
logger.warning(f"Nonce replay detected for client {client_id}")
|
||||
return False
|
||||
|
||||
|
||||
def validate_redirect_origin(origin: str, client: OAuthClient) -> bool:
|
||||
"""
|
||||
Validate that a redirect origin is allowed for the client.
|
||||
|
||||
The origin must match one of the client's registered redirect URIs
|
||||
or webhook domains.
|
||||
|
||||
Args:
|
||||
origin: Origin URL to validate
|
||||
client: OAuth client to check against
|
||||
|
||||
Returns:
|
||||
True if origin is allowed
|
||||
"""
|
||||
from backend.util.url import hostname_matches_any_domain
|
||||
|
||||
try:
|
||||
parsed_origin = urlparse(origin)
|
||||
origin_host = parsed_origin.netloc.lower()
|
||||
|
||||
# Check against redirect URIs
|
||||
for redirect_uri in client.redirectUris:
|
||||
parsed_redirect = urlparse(redirect_uri)
|
||||
if parsed_redirect.netloc.lower() == origin_host:
|
||||
return True
|
||||
|
||||
# Check against webhook domains
|
||||
if hostname_matches_any_domain(origin_host, client.webhookDomains):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def create_post_message_data(
|
||||
success: bool,
|
||||
grant_id: Optional[str] = None,
|
||||
credential_id: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
error_description: Optional[str] = None,
|
||||
nonce: Optional[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create the postMessage data to send back to the opener.
|
||||
|
||||
Args:
|
||||
success: Whether the operation succeeded
|
||||
grant_id: ID of the created grant (if successful)
|
||||
credential_id: ID of the credential (if successful)
|
||||
provider: Provider name
|
||||
error: Error code (if failed)
|
||||
error_description: Human-readable error description
|
||||
nonce: Original nonce for correlation
|
||||
|
||||
Returns:
|
||||
Dictionary to be sent via postMessage
|
||||
"""
|
||||
data: dict[str, Any] = {
|
||||
"type": "autogpt_connect_result",
|
||||
"success": success,
|
||||
}
|
||||
|
||||
if nonce:
|
||||
data["nonce"] = nonce
|
||||
|
||||
if success:
|
||||
data["grant_id"] = grant_id
|
||||
data["credential_id"] = credential_id
|
||||
data["provider"] = provider
|
||||
else:
|
||||
data["error"] = error
|
||||
data["error_description"] = error_description
|
||||
|
||||
return data
|
||||
20
autogpt_platform/backend/backend/server/oauth/__init__.py
Normal file
20
autogpt_platform/backend/backend/server/oauth/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
OAuth 2.0 Provider module for AutoGPT Platform.
|
||||
|
||||
This module implements AutoGPT as an OAuth 2.0 Authorization Server,
|
||||
allowing external applications to authenticate users and access
|
||||
platform resources with user consent.
|
||||
|
||||
Key components:
|
||||
- router.py: OAuth authorization and token endpoints
|
||||
- discovery_router.py: OIDC discovery endpoints
|
||||
- client_router.py: OAuth client management
|
||||
- token_service.py: JWT generation and validation
|
||||
- service.py: Core OAuth business logic
|
||||
"""
|
||||
|
||||
from backend.server.oauth.client_router import client_router
|
||||
from backend.server.oauth.discovery_router import discovery_router
|
||||
from backend.server.oauth.router import oauth_router
|
||||
|
||||
__all__ = ["oauth_router", "discovery_router", "client_router"]
|
||||
367
autogpt_platform/backend/backend/server/oauth/client_router.py
Normal file
367
autogpt_platform/backend/backend/server/oauth/client_router.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
OAuth Client Management endpoints.
|
||||
|
||||
Implements self-service client registration and management:
|
||||
- POST /oauth/clients - Register a new client
|
||||
- GET /oauth/clients - List owned clients
|
||||
- GET /oauth/clients/{client_id} - Get client details
|
||||
- PATCH /oauth/clients/{client_id} - Update client
|
||||
- DELETE /oauth/clients/{client_id} - Delete client
|
||||
- POST /oauth/clients/{client_id}/rotate-secret - Rotate client secret
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import APIRouter, HTTPException, Security
|
||||
from prisma.enums import OAuthClientStatus
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.db import prisma
|
||||
from backend.server.oauth.models import (
|
||||
ClientResponse,
|
||||
ClientSecretResponse,
|
||||
OAuthScope,
|
||||
RegisterClientRequest,
|
||||
UpdateClientRequest,
|
||||
)
|
||||
|
||||
client_router = APIRouter(prefix="/oauth/clients", tags=["oauth-clients"])
|
||||
|
||||
|
||||
def _generate_client_id() -> str:
|
||||
"""Generate a unique client ID."""
|
||||
return f"app_{secrets.token_urlsafe(16)}"
|
||||
|
||||
|
||||
def _generate_client_secret() -> str:
|
||||
"""Generate a secure client secret."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def _generate_webhook_secret() -> str:
|
||||
"""Generate a secure webhook secret for HMAC signing."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def _hash_secret(secret: str, salt: str) -> str:
|
||||
"""Hash a client secret with salt."""
|
||||
return hashlib.sha256(f"{salt}{secret}".encode()).hexdigest()
|
||||
|
||||
|
||||
def _client_to_response(client) -> ClientResponse:
|
||||
"""Convert Prisma client to response model."""
|
||||
return ClientResponse(
|
||||
id=client.id,
|
||||
client_id=client.clientId,
|
||||
client_type=client.clientType,
|
||||
name=client.name,
|
||||
description=client.description,
|
||||
logo_url=client.logoUrl,
|
||||
homepage_url=client.homepageUrl,
|
||||
privacy_policy_url=client.privacyPolicyUrl,
|
||||
terms_of_service_url=client.termsOfServiceUrl,
|
||||
redirect_uris=client.redirectUris,
|
||||
allowed_scopes=client.allowedScopes,
|
||||
webhook_domains=client.webhookDomains,
|
||||
status=client.status,
|
||||
created_at=client.createdAt,
|
||||
updated_at=client.updatedAt,
|
||||
)
|
||||
|
||||
|
||||
# Default allowed scopes for new clients
|
||||
DEFAULT_ALLOWED_SCOPES = [
|
||||
OAuthScope.OPENID.value,
|
||||
OAuthScope.PROFILE.value,
|
||||
OAuthScope.EMAIL.value,
|
||||
OAuthScope.INTEGRATIONS_LIST.value,
|
||||
OAuthScope.INTEGRATIONS_CONNECT.value,
|
||||
OAuthScope.INTEGRATIONS_DELETE.value,
|
||||
OAuthScope.AGENTS_EXECUTE.value,
|
||||
]
|
||||
|
||||
|
||||
@client_router.post("/", response_model=ClientSecretResponse)
|
||||
async def register_client(
|
||||
request: RegisterClientRequest,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> ClientSecretResponse:
|
||||
"""
|
||||
Register a new OAuth client.
|
||||
|
||||
The client is immediately active (no admin approval required).
|
||||
For confidential clients, the client_secret is returned only once.
|
||||
The webhook_secret is always generated and returned only once.
|
||||
"""
|
||||
# Generate client credentials
|
||||
client_id = _generate_client_id()
|
||||
client_secret = None
|
||||
client_secret_hash = None
|
||||
client_secret_salt = None
|
||||
|
||||
if request.client_type == "confidential":
|
||||
client_secret = _generate_client_secret()
|
||||
client_secret_salt = secrets.token_urlsafe(16)
|
||||
client_secret_hash = _hash_secret(client_secret, client_secret_salt)
|
||||
|
||||
# Generate webhook secret for HMAC signing
|
||||
webhook_secret = _generate_webhook_secret()
|
||||
|
||||
# Create client
|
||||
await prisma.oauthclient.create(
|
||||
data={ # type: ignore[typeddict-item]
|
||||
"clientId": client_id,
|
||||
"clientSecretHash": client_secret_hash,
|
||||
"clientSecretSalt": client_secret_salt,
|
||||
"clientType": request.client_type,
|
||||
"name": request.name,
|
||||
"description": request.description,
|
||||
"logoUrl": str(request.logo_url) if request.logo_url else None,
|
||||
"homepageUrl": str(request.homepage_url) if request.homepage_url else None,
|
||||
"privacyPolicyUrl": (
|
||||
str(request.privacy_policy_url) if request.privacy_policy_url else None
|
||||
),
|
||||
"termsOfServiceUrl": (
|
||||
str(request.terms_of_service_url)
|
||||
if request.terms_of_service_url
|
||||
else None
|
||||
),
|
||||
"redirectUris": request.redirect_uris,
|
||||
"allowedScopes": DEFAULT_ALLOWED_SCOPES,
|
||||
"webhookDomains": request.webhook_domains,
|
||||
"webhookSecret": webhook_secret,
|
||||
"status": OAuthClientStatus.ACTIVE,
|
||||
"ownerId": user_id,
|
||||
}
|
||||
)
|
||||
|
||||
return ClientSecretResponse(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret or "",
|
||||
webhook_secret=webhook_secret,
|
||||
)
|
||||
|
||||
|
||||
@client_router.get("/", response_model=list[ClientResponse])
|
||||
async def list_clients(
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> list[ClientResponse]:
|
||||
"""List all OAuth clients owned by the current user."""
|
||||
clients = await prisma.oauthclient.find_many(
|
||||
where={"ownerId": user_id},
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
return [_client_to_response(c) for c in clients]
|
||||
|
||||
|
||||
@client_router.get("/{client_id}", response_model=ClientResponse)
|
||||
async def get_client(
|
||||
client_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> ClientResponse:
|
||||
"""Get details of a specific OAuth client."""
|
||||
client = await prisma.oauthclient.find_first(
|
||||
where={"clientId": client_id, "ownerId": user_id}
|
||||
)
|
||||
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="Client not found")
|
||||
|
||||
return _client_to_response(client)
|
||||
|
||||
|
||||
@client_router.patch("/{client_id}", response_model=ClientResponse)
|
||||
async def update_client(
|
||||
client_id: str,
|
||||
request: UpdateClientRequest,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> ClientResponse:
|
||||
"""Update an OAuth client."""
|
||||
client = await prisma.oauthclient.find_first(
|
||||
where={"clientId": client_id, "ownerId": user_id}
|
||||
)
|
||||
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="Client not found")
|
||||
|
||||
# Build update data
|
||||
update_data: dict[str, str | list[str] | None] = {}
|
||||
if request.name is not None:
|
||||
update_data["name"] = request.name
|
||||
if request.description is not None:
|
||||
update_data["description"] = request.description
|
||||
if request.logo_url is not None:
|
||||
update_data["logoUrl"] = str(request.logo_url)
|
||||
if request.homepage_url is not None:
|
||||
update_data["homepageUrl"] = str(request.homepage_url)
|
||||
if request.privacy_policy_url is not None:
|
||||
update_data["privacyPolicyUrl"] = str(request.privacy_policy_url)
|
||||
if request.terms_of_service_url is not None:
|
||||
update_data["termsOfServiceUrl"] = str(request.terms_of_service_url)
|
||||
if request.redirect_uris is not None:
|
||||
update_data["redirectUris"] = request.redirect_uris
|
||||
if request.webhook_domains is not None:
|
||||
update_data["webhookDomains"] = request.webhook_domains
|
||||
|
||||
if not update_data:
|
||||
return _client_to_response(client)
|
||||
|
||||
updated = await prisma.oauthclient.update(
|
||||
where={"id": client.id},
|
||||
data=update_data, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
return _client_to_response(updated)
|
||||
|
||||
|
||||
@client_router.delete("/{client_id}")
|
||||
async def delete_client(
|
||||
client_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> dict:
|
||||
"""
|
||||
Delete an OAuth client.
|
||||
|
||||
This will also revoke all tokens and authorizations for this client.
|
||||
"""
|
||||
client = await prisma.oauthclient.find_first(
|
||||
where={"clientId": client_id, "ownerId": user_id}
|
||||
)
|
||||
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="Client not found")
|
||||
|
||||
# Delete cascades will handle tokens, codes, and authorizations
|
||||
await prisma.oauthclient.delete(where={"id": client.id})
|
||||
|
||||
return {"status": "deleted", "client_id": client_id}
|
||||
|
||||
|
||||
@client_router.post("/{client_id}/rotate-secret", response_model=ClientSecretResponse)
|
||||
async def rotate_client_secret(
|
||||
client_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> ClientSecretResponse:
|
||||
"""
|
||||
Rotate the client secret for a confidential client.
|
||||
|
||||
The new secret is returned only once. All existing tokens remain valid.
|
||||
Also rotates the webhook secret for security.
|
||||
"""
|
||||
client = await prisma.oauthclient.find_first(
|
||||
where={"clientId": client_id, "ownerId": user_id}
|
||||
)
|
||||
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="Client not found")
|
||||
|
||||
if client.clientType != "confidential":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot rotate secret for public clients",
|
||||
)
|
||||
|
||||
# Generate new secrets
|
||||
new_secret = _generate_client_secret()
|
||||
new_salt = secrets.token_urlsafe(16)
|
||||
new_hash = _hash_secret(new_secret, new_salt)
|
||||
new_webhook_secret = _generate_webhook_secret()
|
||||
|
||||
await prisma.oauthclient.update(
|
||||
where={"id": client.id},
|
||||
data={
|
||||
"clientSecretHash": new_hash,
|
||||
"clientSecretSalt": new_salt,
|
||||
"webhookSecret": new_webhook_secret,
|
||||
},
|
||||
)
|
||||
|
||||
return ClientSecretResponse(
|
||||
client_id=client_id,
|
||||
client_secret=new_secret,
|
||||
webhook_secret=new_webhook_secret,
|
||||
)
|
||||
|
||||
|
||||
class WebhookSecretResponse(BaseModel):
|
||||
"""Response containing newly generated webhook secret."""
|
||||
|
||||
client_id: str
|
||||
webhook_secret: str
|
||||
|
||||
|
||||
@client_router.post(
|
||||
"/{client_id}/rotate-webhook-secret", response_model=WebhookSecretResponse
|
||||
)
|
||||
async def rotate_webhook_secret(
|
||||
client_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> WebhookSecretResponse:
|
||||
"""
|
||||
Rotate only the webhook secret for a client.
|
||||
|
||||
The new webhook secret is returned only once.
|
||||
"""
|
||||
client = await prisma.oauthclient.find_first(
|
||||
where={"clientId": client_id, "ownerId": user_id}
|
||||
)
|
||||
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="Client not found")
|
||||
|
||||
# Generate new webhook secret
|
||||
new_webhook_secret = _generate_webhook_secret()
|
||||
|
||||
await prisma.oauthclient.update(
|
||||
where={"id": client.id},
|
||||
data={"webhookSecret": new_webhook_secret},
|
||||
)
|
||||
|
||||
return WebhookSecretResponse(
|
||||
client_id=client_id,
|
||||
webhook_secret=new_webhook_secret,
|
||||
)
|
||||
|
||||
|
||||
@client_router.post("/{client_id}/suspend")
|
||||
async def suspend_client(
|
||||
client_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> ClientResponse:
|
||||
"""Suspend an OAuth client (prevents new authorizations)."""
|
||||
client = await prisma.oauthclient.find_first(
|
||||
where={"clientId": client_id, "ownerId": user_id}
|
||||
)
|
||||
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="Client not found")
|
||||
|
||||
updated = await prisma.oauthclient.update(
|
||||
where={"id": client.id},
|
||||
data={"status": OAuthClientStatus.SUSPENDED},
|
||||
)
|
||||
|
||||
return _client_to_response(updated)
|
||||
|
||||
|
||||
@client_router.post("/{client_id}/activate")
|
||||
async def activate_client(
|
||||
client_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> ClientResponse:
|
||||
"""Reactivate a suspended OAuth client."""
|
||||
client = await prisma.oauthclient.find_first(
|
||||
where={"clientId": client_id, "ownerId": user_id}
|
||||
)
|
||||
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="Client not found")
|
||||
|
||||
updated = await prisma.oauthclient.update(
|
||||
where={"id": client.id},
|
||||
data={"status": OAuthClientStatus.ACTIVE},
|
||||
)
|
||||
|
||||
return _client_to_response(updated)
|
||||
@@ -0,0 +1,678 @@
|
||||
"""
|
||||
Server-rendered HTML templates for OAuth consent UI.
|
||||
|
||||
These templates are used for the OAuth authorization flow
|
||||
when the user needs to approve access for an external application.
|
||||
"""
|
||||
|
||||
import html
|
||||
from typing import Optional
|
||||
|
||||
from backend.server.oauth.models import SCOPE_DESCRIPTIONS
|
||||
|
||||
|
||||
def _base_styles() -> str:
|
||||
"""Common CSS styles for all OAuth pages."""
|
||||
return """
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 20px;
|
||||
color: #e4e4e7;
|
||||
}
|
||||
.container {
|
||||
background: #27272a;
|
||||
border-radius: 16px;
|
||||
box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.5);
|
||||
max-width: 420px;
|
||||
width: 100%;
|
||||
padding: 32px;
|
||||
}
|
||||
.header {
|
||||
text-align: center;
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
.logo {
|
||||
width: 64px;
|
||||
height: 64px;
|
||||
border-radius: 12px;
|
||||
margin-bottom: 16px;
|
||||
background: #3f3f46;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
}
|
||||
.logo img {
|
||||
max-width: 48px;
|
||||
max-height: 48px;
|
||||
border-radius: 8px;
|
||||
}
|
||||
.logo-placeholder {
|
||||
font-size: 28px;
|
||||
color: #a1a1aa;
|
||||
}
|
||||
h1 {
|
||||
font-size: 20px;
|
||||
font-weight: 600;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.subtitle {
|
||||
color: #a1a1aa;
|
||||
font-size: 14px;
|
||||
}
|
||||
.app-name {
|
||||
color: #22d3ee;
|
||||
font-weight: 600;
|
||||
}
|
||||
.divider {
|
||||
height: 1px;
|
||||
background: #3f3f46;
|
||||
margin: 24px 0;
|
||||
}
|
||||
.scopes-section h2 {
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
color: #a1a1aa;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.scope-item {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
gap: 12px;
|
||||
padding: 12px 0;
|
||||
border-bottom: 1px solid #3f3f46;
|
||||
}
|
||||
.scope-item:last-child {
|
||||
border-bottom: none;
|
||||
}
|
||||
.scope-icon {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
color: #22d3ee;
|
||||
flex-shrink: 0;
|
||||
margin-top: 2px;
|
||||
}
|
||||
.scope-text {
|
||||
font-size: 14px;
|
||||
line-height: 1.5;
|
||||
}
|
||||
.buttons {
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
margin-top: 24px;
|
||||
}
|
||||
.btn {
|
||||
flex: 1;
|
||||
padding: 12px 24px;
|
||||
border-radius: 8px;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
border: none;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
.btn-cancel {
|
||||
background: #3f3f46;
|
||||
color: #e4e4e7;
|
||||
}
|
||||
.btn-cancel:hover {
|
||||
background: #52525b;
|
||||
}
|
||||
.btn-allow {
|
||||
background: #22d3ee;
|
||||
color: #0f172a;
|
||||
}
|
||||
.btn-allow:hover {
|
||||
background: #06b6d4;
|
||||
}
|
||||
.footer {
|
||||
margin-top: 24px;
|
||||
text-align: center;
|
||||
font-size: 12px;
|
||||
color: #71717a;
|
||||
}
|
||||
.footer a {
|
||||
color: #a1a1aa;
|
||||
text-decoration: none;
|
||||
}
|
||||
.footer a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
.error-container {
|
||||
text-align: center;
|
||||
}
|
||||
.error-icon {
|
||||
width: 64px;
|
||||
height: 64px;
|
||||
margin: 0 auto 16px;
|
||||
color: #ef4444;
|
||||
}
|
||||
.error-title {
|
||||
color: #ef4444;
|
||||
font-size: 18px;
|
||||
font-weight: 600;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.error-message {
|
||||
color: #a1a1aa;
|
||||
font-size: 14px;
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
.success-icon {
|
||||
width: 64px;
|
||||
height: 64px;
|
||||
margin: 0 auto 16px;
|
||||
color: #22c55e;
|
||||
}
|
||||
.success-title {
|
||||
color: #22c55e;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def _check_icon() -> str:
|
||||
"""SVG checkmark icon."""
|
||||
return """
|
||||
<svg class="scope-icon" viewBox="0 0 20 20" fill="currentColor">
|
||||
<path fill-rule="evenodd" d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z" clip-rule="evenodd"/>
|
||||
</svg>
|
||||
"""
|
||||
|
||||
|
||||
def _error_icon() -> str:
|
||||
"""SVG error icon."""
|
||||
return """
|
||||
<svg class="error-icon" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
|
||||
<circle cx="12" cy="12" r="10"/>
|
||||
<line x1="15" y1="9" x2="9" y2="15"/>
|
||||
<line x1="9" y1="9" x2="15" y2="15"/>
|
||||
</svg>
|
||||
"""
|
||||
|
||||
|
||||
def _success_icon() -> str:
|
||||
"""SVG success icon."""
|
||||
return """
|
||||
<svg class="success-icon" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
|
||||
<circle cx="12" cy="12" r="10"/>
|
||||
<path d="M9 12l2 2 4-4"/>
|
||||
</svg>
|
||||
"""
|
||||
|
||||
|
||||
def render_consent_page(
|
||||
client_name: str,
|
||||
client_logo: Optional[str],
|
||||
scopes: list[str],
|
||||
consent_token: str,
|
||||
action_url: str,
|
||||
privacy_policy_url: Optional[str] = None,
|
||||
terms_url: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Render the OAuth consent page.
|
||||
|
||||
Args:
|
||||
client_name: Name of the requesting application
|
||||
client_logo: URL to the client's logo (optional)
|
||||
scopes: List of requested scopes
|
||||
consent_token: CSRF token for the consent form
|
||||
action_url: URL to submit the consent form
|
||||
privacy_policy_url: Client's privacy policy URL (optional)
|
||||
terms_url: Client's terms of service URL (optional)
|
||||
|
||||
Returns:
|
||||
HTML string for the consent page
|
||||
"""
|
||||
# Escape user-provided values to prevent XSS
|
||||
safe_client_name = html.escape(client_name)
|
||||
safe_client_logo = html.escape(client_logo) if client_logo else None
|
||||
|
||||
# Build logo HTML
|
||||
if safe_client_logo:
|
||||
logo_html = f'<img src="{safe_client_logo}" alt="{safe_client_name}">'
|
||||
else:
|
||||
logo_html = f'<span class="logo-placeholder">{html.escape(client_name[0].upper())}</span>'
|
||||
|
||||
# Build scopes HTML
|
||||
scopes_html = ""
|
||||
for scope in scopes:
|
||||
description = SCOPE_DESCRIPTIONS.get(scope, scope)
|
||||
scopes_html += f"""
|
||||
<div class="scope-item">
|
||||
{_check_icon()}
|
||||
<span class="scope-text">{html.escape(description)}</span>
|
||||
</div>
|
||||
"""
|
||||
|
||||
# Build footer links (escape URLs)
|
||||
footer_links = []
|
||||
if privacy_policy_url:
|
||||
footer_links.append(
|
||||
f'<a href="{html.escape(privacy_policy_url)}" target="_blank">Privacy Policy</a>'
|
||||
)
|
||||
if terms_url:
|
||||
footer_links.append(
|
||||
f'<a href="{html.escape(terms_url)}" target="_blank">Terms of Service</a>'
|
||||
)
|
||||
footer_html = " • ".join(footer_links) if footer_links else ""
|
||||
|
||||
# Escape action_url and consent_token
|
||||
safe_action_url = html.escape(action_url)
|
||||
safe_consent_token = html.escape(consent_token)
|
||||
|
||||
return f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authorize {safe_client_name} - AutoGPT</title>
|
||||
<style>{_base_styles()}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<div class="logo">{logo_html}</div>
|
||||
<h1>Authorize <span class="app-name">{safe_client_name}</span></h1>
|
||||
<p class="subtitle">wants to access your AutoGPT account</p>
|
||||
</div>
|
||||
|
||||
<div class="divider"></div>
|
||||
|
||||
<div class="scopes-section">
|
||||
<h2>This will allow {safe_client_name} to:</h2>
|
||||
{scopes_html}
|
||||
</div>
|
||||
|
||||
<form method="POST" action="{safe_action_url}">
|
||||
<input type="hidden" name="consent_token" value="{safe_consent_token}">
|
||||
<div class="buttons">
|
||||
<button type="submit" name="authorize" value="false" class="btn btn-cancel">
|
||||
Cancel
|
||||
</button>
|
||||
<button type="submit" name="authorize" value="true" class="btn btn-allow">
|
||||
Allow
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
{f'<div class="footer">{footer_html}</div>' if footer_html else ''}
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def render_error_page(
|
||||
error: str,
|
||||
error_description: str,
|
||||
redirect_url: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Render an OAuth error page.
|
||||
|
||||
Args:
|
||||
error: Error code
|
||||
error_description: Human-readable error description
|
||||
redirect_url: Optional URL to redirect back (if safe)
|
||||
|
||||
Returns:
|
||||
HTML string for the error page
|
||||
"""
|
||||
# Escape user-provided values to prevent XSS
|
||||
safe_error = html.escape(error)
|
||||
safe_error_description = html.escape(error_description)
|
||||
|
||||
redirect_html = ""
|
||||
if redirect_url:
|
||||
safe_redirect_url = html.escape(redirect_url)
|
||||
redirect_html = f"""
|
||||
<a href="{safe_redirect_url}" class="btn btn-cancel" style="display: inline-block; text-decoration: none;">
|
||||
Go Back
|
||||
</a>
|
||||
"""
|
||||
|
||||
return f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authorization Error - AutoGPT</title>
|
||||
<style>{_base_styles()}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="error-container">
|
||||
{_error_icon()}
|
||||
<h1 class="error-title">Authorization Failed</h1>
|
||||
<p class="error-message">{safe_error_description}</p>
|
||||
<p class="error-message" style="font-size: 12px; color: #52525b;">
|
||||
Error code: {safe_error}
|
||||
</p>
|
||||
{redirect_html}
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def render_success_page(
|
||||
message: str,
|
||||
redirect_origin: Optional[str] = None,
|
||||
post_message_data: Optional[dict] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Render a success page, optionally with postMessage for popup flows.
|
||||
|
||||
Args:
|
||||
message: Success message to display
|
||||
redirect_origin: Origin for postMessage (popup flows)
|
||||
post_message_data: Data to send via postMessage (popup flows)
|
||||
|
||||
Returns:
|
||||
HTML string for the success page
|
||||
"""
|
||||
# Escape user-provided values to prevent XSS
|
||||
safe_message = html.escape(message)
|
||||
|
||||
# PostMessage script for popup flows
|
||||
post_message_script = ""
|
||||
if redirect_origin and post_message_data:
|
||||
import json
|
||||
|
||||
# json.dumps escapes for JS context, but we also escape < > for HTML context
|
||||
safe_json_origin = (
|
||||
json.dumps(redirect_origin).replace("<", "\\u003c").replace(">", "\\u003e")
|
||||
)
|
||||
safe_json_data = (
|
||||
json.dumps(post_message_data)
|
||||
.replace("<", "\\u003c")
|
||||
.replace(">", "\\u003e")
|
||||
)
|
||||
|
||||
post_message_script = f"""
|
||||
<script>
|
||||
(function() {{
|
||||
var targetOrigin = {safe_json_origin};
|
||||
var message = {safe_json_data};
|
||||
if (window.opener) {{
|
||||
window.opener.postMessage(message, targetOrigin);
|
||||
setTimeout(function() {{ window.close(); }}, 1000);
|
||||
}}
|
||||
}})();
|
||||
</script>
|
||||
"""
|
||||
|
||||
return f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authorization Successful - AutoGPT</title>
|
||||
<style>{_base_styles()}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="error-container">
|
||||
{_success_icon()}
|
||||
<h1 class="success-title">Success!</h1>
|
||||
<p class="error-message">{safe_message}</p>
|
||||
<p class="error-message" style="font-size: 12px;">
|
||||
This window will close automatically...
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{post_message_script}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def render_login_redirect_page(login_url: str) -> str:
|
||||
"""
|
||||
Render a page that redirects to login.
|
||||
|
||||
Args:
|
||||
login_url: URL to redirect to for login
|
||||
|
||||
Returns:
|
||||
HTML string with auto-redirect
|
||||
"""
|
||||
# Escape URL to prevent XSS
|
||||
safe_login_url = html.escape(login_url)
|
||||
|
||||
return f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<meta http-equiv="refresh" content="0;url={safe_login_url}">
|
||||
<title>Login Required - AutoGPT</title>
|
||||
<style>{_base_styles()}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="error-container">
|
||||
<p class="error-message">Redirecting to login...</p>
|
||||
<a href="{safe_login_url}" class="btn btn-allow" style="display: inline-block; text-decoration: none;">
|
||||
Click here if not redirected
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def _login_form_styles() -> str:
|
||||
"""Additional CSS styles for login form."""
|
||||
return """
|
||||
.form-group {
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.form-group label {
|
||||
display: block;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
color: #a1a1aa;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.form-group input {
|
||||
width: 100%;
|
||||
padding: 12px 16px;
|
||||
border-radius: 8px;
|
||||
border: 1px solid #3f3f46;
|
||||
background: #18181b;
|
||||
color: #e4e4e7;
|
||||
font-size: 14px;
|
||||
outline: none;
|
||||
transition: border-color 0.2s;
|
||||
}
|
||||
.form-group input:focus {
|
||||
border-color: #22d3ee;
|
||||
}
|
||||
.form-group input::placeholder {
|
||||
color: #52525b;
|
||||
}
|
||||
.error-alert {
|
||||
background: rgba(239, 68, 68, 0.1);
|
||||
border: 1px solid #ef4444;
|
||||
border-radius: 8px;
|
||||
padding: 12px 16px;
|
||||
margin-bottom: 16px;
|
||||
color: #fca5a5;
|
||||
font-size: 14px;
|
||||
}
|
||||
.btn-login {
|
||||
width: 100%;
|
||||
padding: 12px 24px;
|
||||
border-radius: 8px;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
border: none;
|
||||
background: #22d3ee;
|
||||
color: #0f172a;
|
||||
transition: all 0.2s;
|
||||
margin-top: 8px;
|
||||
}
|
||||
.btn-login:hover {
|
||||
background: #06b6d4;
|
||||
}
|
||||
.btn-login:disabled {
|
||||
background: #3f3f46;
|
||||
color: #71717a;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
.signup-link {
|
||||
text-align: center;
|
||||
margin-top: 16px;
|
||||
font-size: 14px;
|
||||
color: #a1a1aa;
|
||||
}
|
||||
.signup-link a {
|
||||
color: #22d3ee;
|
||||
text-decoration: none;
|
||||
}
|
||||
.signup-link a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def render_login_page(
|
||||
action_url: str,
|
||||
login_state: str,
|
||||
client_name: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
signup_url: Optional[str] = None,
|
||||
browser_login_url: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Render an embedded login page for OAuth flow.
|
||||
|
||||
Args:
|
||||
action_url: URL to submit the login form
|
||||
login_state: State token to preserve OAuth parameters
|
||||
client_name: Name of the application requesting access (optional)
|
||||
error_message: Error message to display (optional)
|
||||
signup_url: URL to signup page (optional)
|
||||
browser_login_url: URL to redirect to frontend login (optional)
|
||||
|
||||
Returns:
|
||||
HTML string for the login page
|
||||
"""
|
||||
# Escape all user-provided values to prevent XSS
|
||||
safe_action_url = html.escape(action_url)
|
||||
safe_login_state = html.escape(login_state)
|
||||
safe_client_name = html.escape(client_name) if client_name else None
|
||||
|
||||
error_html = ""
|
||||
if error_message:
|
||||
safe_error_message = html.escape(error_message)
|
||||
error_html = f'<div class="error-alert">{safe_error_message}</div>'
|
||||
|
||||
subtitle = "wants to access your AutoGPT account" if safe_client_name else ""
|
||||
title_html = (
|
||||
'<h1>Sign in to <span class="app-name">AutoGPT</span></h1>'
|
||||
if not safe_client_name
|
||||
else f'<h1><span class="app-name">{safe_client_name}</span></h1>'
|
||||
)
|
||||
|
||||
signup_html = ""
|
||||
if signup_url:
|
||||
safe_signup_url = html.escape(signup_url)
|
||||
signup_html = f"""
|
||||
<div class="signup-link">
|
||||
Don't have an account? <a href="{safe_signup_url}">Sign up</a>
|
||||
</div>
|
||||
"""
|
||||
|
||||
browser_login_html = ""
|
||||
if browser_login_url:
|
||||
safe_browser_login_url = html.escape(browser_login_url)
|
||||
browser_login_html = f"""
|
||||
<div class="divider"></div>
|
||||
<div class="signup-link">
|
||||
<a href="{safe_browser_login_url}">Sign in with Google or other providers</a>
|
||||
</div>
|
||||
"""
|
||||
|
||||
return f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Sign In - AutoGPT</title>
|
||||
<style>
|
||||
{_base_styles()}
|
||||
{_login_form_styles()}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<div class="logo">
|
||||
<span class="logo-placeholder">A</span>
|
||||
</div>
|
||||
{title_html}
|
||||
<p class="subtitle">{subtitle}</p>
|
||||
</div>
|
||||
|
||||
<div class="divider"></div>
|
||||
|
||||
{error_html}
|
||||
|
||||
<form method="POST" action="{safe_action_url}">
|
||||
<input type="hidden" name="login_state" value="{safe_login_state}">
|
||||
|
||||
<div class="form-group">
|
||||
<label for="email">Email</label>
|
||||
<input
|
||||
type="email"
|
||||
id="email"
|
||||
name="email"
|
||||
placeholder="you@example.com"
|
||||
required
|
||||
autocomplete="email"
|
||||
>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="password">Password</label>
|
||||
<input
|
||||
type="password"
|
||||
id="password"
|
||||
name="password"
|
||||
placeholder="Enter your password"
|
||||
required
|
||||
autocomplete="current-password"
|
||||
>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="btn-login">Sign In</button>
|
||||
</form>
|
||||
|
||||
{signup_html}
|
||||
{browser_login_html}
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
OIDC Discovery endpoints.
|
||||
|
||||
Implements:
|
||||
- GET /.well-known/openid-configuration - OIDC Discovery Document
|
||||
- GET /.well-known/jwks.json - JSON Web Key Set
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from backend.server.oauth.models import JWKS, OpenIDConfiguration
|
||||
from backend.server.oauth.token_service import get_token_service
|
||||
from backend.util.settings import Settings
|
||||
|
||||
discovery_router = APIRouter(tags=["oidc-discovery"])
|
||||
|
||||
|
||||
@discovery_router.get(
|
||||
"/.well-known/openid-configuration",
|
||||
response_model=OpenIDConfiguration,
|
||||
)
|
||||
async def openid_configuration() -> OpenIDConfiguration:
|
||||
"""
|
||||
OIDC Discovery Document.
|
||||
|
||||
Returns metadata about the OAuth 2.0 authorization server including
|
||||
endpoints, supported features, and algorithms.
|
||||
"""
|
||||
settings = Settings()
|
||||
base_url = settings.config.platform_base_url or "https://platform.agpt.co"
|
||||
|
||||
return OpenIDConfiguration(
|
||||
issuer=base_url,
|
||||
authorization_endpoint=f"{base_url}/oauth/authorize",
|
||||
token_endpoint=f"{base_url}/oauth/token",
|
||||
userinfo_endpoint=f"{base_url}/oauth/userinfo",
|
||||
revocation_endpoint=f"{base_url}/oauth/revoke",
|
||||
jwks_uri=f"{base_url}/.well-known/jwks.json",
|
||||
scopes_supported=[
|
||||
"openid",
|
||||
"profile",
|
||||
"email",
|
||||
"integrations:list",
|
||||
"integrations:connect",
|
||||
"integrations:delete",
|
||||
"agents:execute",
|
||||
],
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code", "refresh_token"],
|
||||
token_endpoint_auth_methods_supported=[
|
||||
"client_secret_post",
|
||||
"client_secret_basic",
|
||||
"none", # For public clients with PKCE
|
||||
],
|
||||
code_challenge_methods_supported=["S256"],
|
||||
subject_types_supported=["public"],
|
||||
id_token_signing_alg_values_supported=["RS256"],
|
||||
)
|
||||
|
||||
|
||||
@discovery_router.get("/.well-known/jwks.json", response_model=JWKS)
|
||||
async def jwks() -> dict:
|
||||
"""
|
||||
JSON Web Key Set (JWKS).
|
||||
|
||||
Returns the public key(s) used to verify JWT signatures.
|
||||
External applications can use these keys to verify access tokens
|
||||
and ID tokens issued by this authorization server.
|
||||
"""
|
||||
token_service = get_token_service()
|
||||
return token_service.get_jwks()
|
||||
162
autogpt_platform/backend/backend/server/oauth/errors.py
Normal file
162
autogpt_platform/backend/backend/server/oauth/errors.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
OAuth 2.0 Error Responses (RFC 6749 Section 5.2).
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OAuthErrorCode(str, Enum):
|
||||
"""Standard OAuth 2.0 error codes."""
|
||||
|
||||
# Authorization endpoint errors (RFC 6749 Section 4.1.2.1)
|
||||
INVALID_REQUEST = "invalid_request"
|
||||
UNAUTHORIZED_CLIENT = "unauthorized_client"
|
||||
ACCESS_DENIED = "access_denied"
|
||||
UNSUPPORTED_RESPONSE_TYPE = "unsupported_response_type"
|
||||
INVALID_SCOPE = "invalid_scope"
|
||||
SERVER_ERROR = "server_error"
|
||||
TEMPORARILY_UNAVAILABLE = "temporarily_unavailable"
|
||||
|
||||
# Token endpoint errors (RFC 6749 Section 5.2)
|
||||
INVALID_CLIENT = "invalid_client"
|
||||
INVALID_GRANT = "invalid_grant"
|
||||
UNSUPPORTED_GRANT_TYPE = "unsupported_grant_type"
|
||||
|
||||
# Extension errors
|
||||
LOGIN_REQUIRED = "login_required"
|
||||
CONSENT_REQUIRED = "consent_required"
|
||||
|
||||
|
||||
class OAuthErrorResponse(BaseModel):
|
||||
"""OAuth error response model."""
|
||||
|
||||
error: str
|
||||
error_description: Optional[str] = None
|
||||
error_uri: Optional[str] = None
|
||||
|
||||
|
||||
class OAuthError(Exception):
|
||||
"""Base OAuth error exception."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error: OAuthErrorCode,
|
||||
description: Optional[str] = None,
|
||||
uri: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
):
|
||||
self.error = error
|
||||
self.description = description
|
||||
self.uri = uri
|
||||
self.state = state
|
||||
super().__init__(description or error.value)
|
||||
|
||||
def to_response(self) -> OAuthErrorResponse:
|
||||
"""Convert to response model."""
|
||||
return OAuthErrorResponse(
|
||||
error=self.error.value,
|
||||
error_description=self.description,
|
||||
error_uri=self.uri,
|
||||
)
|
||||
|
||||
def to_redirect(self, redirect_uri: str) -> RedirectResponse:
|
||||
"""Convert to redirect response with error in query params."""
|
||||
params = {"error": self.error.value}
|
||||
if self.description:
|
||||
params["error_description"] = self.description
|
||||
if self.uri:
|
||||
params["error_uri"] = self.uri
|
||||
if self.state:
|
||||
params["state"] = self.state
|
||||
|
||||
separator = "&" if "?" in redirect_uri else "?"
|
||||
url = f"{redirect_uri}{separator}{urlencode(params)}"
|
||||
return RedirectResponse(url=url, status_code=302)
|
||||
|
||||
def to_http_exception(self, status_code: int = 400) -> HTTPException:
|
||||
"""Convert to FastAPI HTTPException."""
|
||||
return HTTPException(
|
||||
status_code=status_code,
|
||||
detail=self.to_response().model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
# Convenience error classes
|
||||
class InvalidRequestError(OAuthError):
|
||||
"""The request is missing a required parameter or is otherwise malformed."""
|
||||
|
||||
def __init__(self, description: str, state: Optional[str] = None):
|
||||
super().__init__(OAuthErrorCode.INVALID_REQUEST, description, state=state)
|
||||
|
||||
|
||||
class UnauthorizedClientError(OAuthError):
|
||||
"""The client is not authorized to request an authorization code."""
|
||||
|
||||
def __init__(self, description: str, state: Optional[str] = None):
|
||||
super().__init__(OAuthErrorCode.UNAUTHORIZED_CLIENT, description, state=state)
|
||||
|
||||
|
||||
class AccessDeniedError(OAuthError):
|
||||
"""The resource owner denied the request."""
|
||||
|
||||
def __init__(self, description: str = "Access denied", state: Optional[str] = None):
|
||||
super().__init__(OAuthErrorCode.ACCESS_DENIED, description, state=state)
|
||||
|
||||
|
||||
class InvalidScopeError(OAuthError):
|
||||
"""The requested scope is invalid, unknown, or malformed."""
|
||||
|
||||
def __init__(self, description: str, state: Optional[str] = None):
|
||||
super().__init__(OAuthErrorCode.INVALID_SCOPE, description, state=state)
|
||||
|
||||
|
||||
class InvalidClientError(OAuthError):
|
||||
"""Client authentication failed."""
|
||||
|
||||
def __init__(self, description: str = "Invalid client"):
|
||||
super().__init__(OAuthErrorCode.INVALID_CLIENT, description)
|
||||
|
||||
|
||||
class InvalidGrantError(OAuthError):
|
||||
"""The provided authorization code or refresh token is invalid."""
|
||||
|
||||
def __init__(self, description: str = "Invalid grant"):
|
||||
super().__init__(OAuthErrorCode.INVALID_GRANT, description)
|
||||
|
||||
|
||||
class UnsupportedGrantTypeError(OAuthError):
|
||||
"""The authorization grant type is not supported."""
|
||||
|
||||
def __init__(self, grant_type: str):
|
||||
super().__init__(
|
||||
OAuthErrorCode.UNSUPPORTED_GRANT_TYPE,
|
||||
f"Grant type '{grant_type}' is not supported",
|
||||
)
|
||||
|
||||
|
||||
class LoginRequiredError(OAuthError):
|
||||
"""User must be logged in to complete the request."""
|
||||
|
||||
def __init__(self, state: Optional[str] = None):
|
||||
super().__init__(
|
||||
OAuthErrorCode.LOGIN_REQUIRED,
|
||||
"User authentication required",
|
||||
state=state,
|
||||
)
|
||||
|
||||
|
||||
class ConsentRequiredError(OAuthError):
|
||||
"""User consent is required for the requested scopes."""
|
||||
|
||||
def __init__(self, state: Optional[str] = None):
|
||||
super().__init__(
|
||||
OAuthErrorCode.CONSENT_REQUIRED,
|
||||
"User consent required",
|
||||
state=state,
|
||||
)
|
||||
288
autogpt_platform/backend/backend/server/oauth/models.py
Normal file
288
autogpt_platform/backend/backend/server/oauth/models.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Pydantic models for OAuth 2.0 requests and responses.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
# ============================================================
|
||||
# Enums and Constants
|
||||
# ============================================================
|
||||
|
||||
|
||||
class OAuthScope(str, Enum):
|
||||
"""Supported OAuth scopes."""
|
||||
|
||||
# OpenID Connect standard scopes
|
||||
OPENID = "openid"
|
||||
PROFILE = "profile"
|
||||
EMAIL = "email"
|
||||
|
||||
# AutoGPT-specific scopes
|
||||
INTEGRATIONS_LIST = "integrations:list"
|
||||
INTEGRATIONS_CONNECT = "integrations:connect"
|
||||
INTEGRATIONS_DELETE = "integrations:delete"
|
||||
AGENTS_EXECUTE = "agents:execute"
|
||||
|
||||
|
||||
SCOPE_DESCRIPTIONS: dict[str, str] = {
|
||||
OAuthScope.OPENID.value: "Access your user ID",
|
||||
OAuthScope.PROFILE.value: "Access your profile information (name)",
|
||||
OAuthScope.EMAIL.value: "Access your email address",
|
||||
OAuthScope.INTEGRATIONS_LIST.value: "View your connected integrations",
|
||||
OAuthScope.INTEGRATIONS_CONNECT.value: "Connect new integrations on your behalf",
|
||||
OAuthScope.INTEGRATIONS_DELETE.value: "Delete integrations on your behalf",
|
||||
OAuthScope.AGENTS_EXECUTE.value: "Run agents on your behalf",
|
||||
}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Authorization Request/Response Models
|
||||
# ============================================================
|
||||
|
||||
|
||||
class AuthorizationRequest(BaseModel):
|
||||
"""OAuth 2.0 Authorization Request (RFC 6749 Section 4.1.1)."""
|
||||
|
||||
response_type: Literal["code"] = Field(
|
||||
..., description="Must be 'code' for authorization code flow"
|
||||
)
|
||||
client_id: str = Field(..., description="Client identifier")
|
||||
redirect_uri: str = Field(..., description="Redirect URI after authorization")
|
||||
scope: str = Field(default="", description="Space-separated list of scopes")
|
||||
state: str = Field(..., description="CSRF protection token (required)")
|
||||
code_challenge: str = Field(..., description="PKCE code challenge (required)")
|
||||
code_challenge_method: Literal["S256"] = Field(
|
||||
default="S256", description="PKCE method (only S256 supported)"
|
||||
)
|
||||
nonce: Optional[str] = Field(None, description="OIDC nonce for replay protection")
|
||||
prompt: Optional[Literal["consent", "login", "none"]] = Field(
|
||||
None, description="Prompt behavior"
|
||||
)
|
||||
|
||||
|
||||
class ConsentFormData(BaseModel):
|
||||
"""Consent form submission data."""
|
||||
|
||||
consent_token: str = Field(..., description="CSRF token for consent")
|
||||
authorize: bool = Field(..., description="Whether user authorized")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Token Request/Response Models
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TokenRequest(BaseModel):
|
||||
"""OAuth 2.0 Token Request (RFC 6749 Section 4.1.3)."""
|
||||
|
||||
grant_type: Literal["authorization_code", "refresh_token"] = Field(
|
||||
..., description="Grant type"
|
||||
)
|
||||
code: Optional[str] = Field(
|
||||
None, description="Authorization code (for authorization_code grant)"
|
||||
)
|
||||
redirect_uri: Optional[str] = Field(
|
||||
None, description="Must match authorization request"
|
||||
)
|
||||
client_id: str = Field(..., description="Client identifier")
|
||||
client_secret: Optional[str] = Field(
|
||||
None, description="Client secret (for confidential clients)"
|
||||
)
|
||||
code_verifier: Optional[str] = Field(
|
||||
None, description="PKCE code verifier (for authorization_code grant)"
|
||||
)
|
||||
refresh_token: Optional[str] = Field(
|
||||
None, description="Refresh token (for refresh_token grant)"
|
||||
)
|
||||
scope: Optional[str] = Field(
|
||||
None, description="Requested scopes (for refresh_token grant)"
|
||||
)
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""OAuth 2.0 Token Response (RFC 6749 Section 5.1)."""
|
||||
|
||||
access_token: str = Field(..., description="Access token")
|
||||
token_type: Literal["Bearer"] = Field(default="Bearer", description="Token type")
|
||||
expires_in: int = Field(..., description="Token lifetime in seconds")
|
||||
refresh_token: Optional[str] = Field(None, description="Refresh token")
|
||||
scope: Optional[str] = Field(None, description="Granted scopes")
|
||||
id_token: Optional[str] = Field(None, description="OIDC ID token")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# UserInfo Response Model
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UserInfoResponse(BaseModel):
|
||||
"""OIDC UserInfo Response."""
|
||||
|
||||
sub: str = Field(..., description="User ID (subject)")
|
||||
email: Optional[str] = Field(None, description="User email")
|
||||
email_verified: Optional[bool] = Field(
|
||||
None, description="Whether email is verified"
|
||||
)
|
||||
name: Optional[str] = Field(None, description="User display name")
|
||||
updated_at: Optional[int] = Field(None, description="Last profile update timestamp")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# OIDC Discovery Models
|
||||
# ============================================================
|
||||
|
||||
|
||||
class OpenIDConfiguration(BaseModel):
|
||||
"""OIDC Discovery Document."""
|
||||
|
||||
issuer: str
|
||||
authorization_endpoint: str
|
||||
token_endpoint: str
|
||||
userinfo_endpoint: str
|
||||
revocation_endpoint: str
|
||||
jwks_uri: str
|
||||
scopes_supported: list[str]
|
||||
response_types_supported: list[str]
|
||||
grant_types_supported: list[str]
|
||||
token_endpoint_auth_methods_supported: list[str]
|
||||
code_challenge_methods_supported: list[str]
|
||||
subject_types_supported: list[str]
|
||||
id_token_signing_alg_values_supported: list[str]
|
||||
|
||||
|
||||
class JWK(BaseModel):
|
||||
"""JSON Web Key."""
|
||||
|
||||
kty: str = Field(..., description="Key type (RSA)")
|
||||
use: str = Field(default="sig", description="Key use (signature)")
|
||||
kid: str = Field(..., description="Key ID")
|
||||
alg: str = Field(default="RS256", description="Algorithm")
|
||||
n: str = Field(..., description="RSA modulus")
|
||||
e: str = Field(..., description="RSA exponent")
|
||||
|
||||
|
||||
class JWKS(BaseModel):
|
||||
"""JSON Web Key Set."""
|
||||
|
||||
keys: list[JWK]
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Client Management Models
|
||||
# ============================================================
|
||||
|
||||
|
||||
class RegisterClientRequest(BaseModel):
|
||||
"""Request to register a new OAuth client."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=100, description="Client name")
|
||||
description: Optional[str] = Field(
|
||||
None, max_length=500, description="Client description"
|
||||
)
|
||||
logo_url: Optional[HttpUrl] = Field(None, description="Logo URL")
|
||||
homepage_url: Optional[HttpUrl] = Field(None, description="Homepage URL")
|
||||
privacy_policy_url: Optional[HttpUrl] = Field(
|
||||
None, description="Privacy policy URL"
|
||||
)
|
||||
terms_of_service_url: Optional[HttpUrl] = Field(
|
||||
None, description="Terms of service URL"
|
||||
)
|
||||
redirect_uris: list[str] = Field(
|
||||
..., min_length=1, description="Allowed redirect URIs"
|
||||
)
|
||||
client_type: Literal["public", "confidential"] = Field(
|
||||
default="public", description="Client type"
|
||||
)
|
||||
webhook_domains: list[str] = Field(
|
||||
default_factory=list, description="Allowed webhook domains"
|
||||
)
|
||||
|
||||
|
||||
class UpdateClientRequest(BaseModel):
|
||||
"""Request to update an OAuth client."""
|
||||
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
description: Optional[str] = Field(None, max_length=500)
|
||||
logo_url: Optional[HttpUrl] = None
|
||||
homepage_url: Optional[HttpUrl] = None
|
||||
privacy_policy_url: Optional[HttpUrl] = None
|
||||
terms_of_service_url: Optional[HttpUrl] = None
|
||||
redirect_uris: Optional[list[str]] = None
|
||||
webhook_domains: Optional[list[str]] = None
|
||||
|
||||
|
||||
class ClientResponse(BaseModel):
|
||||
"""OAuth client response."""
|
||||
|
||||
id: str
|
||||
client_id: str
|
||||
client_type: str
|
||||
name: str
|
||||
description: Optional[str]
|
||||
logo_url: Optional[str]
|
||||
homepage_url: Optional[str]
|
||||
privacy_policy_url: Optional[str]
|
||||
terms_of_service_url: Optional[str]
|
||||
redirect_uris: list[str]
|
||||
allowed_scopes: list[str]
|
||||
webhook_domains: list[str]
|
||||
status: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class ClientSecretResponse(BaseModel):
|
||||
"""Response containing newly generated client credentials."""
|
||||
|
||||
client_id: str
|
||||
client_secret: str = Field(
|
||||
..., description="Client secret (only shown once, store securely)"
|
||||
)
|
||||
webhook_secret: str = Field(
|
||||
...,
|
||||
description="Webhook secret for HMAC signing (only shown once, store securely)",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Token Introspection/Revocation Models
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TokenRevocationRequest(BaseModel):
|
||||
"""Token revocation request (RFC 7009)."""
|
||||
|
||||
token: str = Field(..., description="Token to revoke")
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Field(
|
||||
None, description="Hint about token type"
|
||||
)
|
||||
|
||||
|
||||
class TokenIntrospectionRequest(BaseModel):
|
||||
"""Token introspection request (RFC 7662)."""
|
||||
|
||||
token: str = Field(..., description="Token to introspect")
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Field(
|
||||
None, description="Hint about token type"
|
||||
)
|
||||
|
||||
|
||||
class TokenIntrospectionResponse(BaseModel):
|
||||
"""Token introspection response."""
|
||||
|
||||
active: bool = Field(..., description="Whether the token is active")
|
||||
scope: Optional[str] = Field(None, description="Token scopes")
|
||||
client_id: Optional[str] = Field(
|
||||
None, description="Client that token was issued to"
|
||||
)
|
||||
username: Optional[str] = Field(None, description="User identifier")
|
||||
token_type: Optional[str] = Field(None, description="Token type")
|
||||
exp: Optional[int] = Field(None, description="Expiration timestamp")
|
||||
iat: Optional[int] = Field(None, description="Issued at timestamp")
|
||||
sub: Optional[str] = Field(None, description="Subject (user ID)")
|
||||
aud: Optional[str] = Field(None, description="Audience")
|
||||
iss: Optional[str] = Field(None, description="Issuer")
|
||||
66
autogpt_platform/backend/backend/server/oauth/pkce.py
Normal file
66
autogpt_platform/backend/backend/server/oauth/pkce.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
PKCE (Proof Key for Code Exchange) implementation for OAuth 2.0.
|
||||
|
||||
RFC 7636: https://tools.ietf.org/html/rfc7636
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
|
||||
def generate_code_verifier(length: int = 64) -> str:
|
||||
"""
|
||||
Generate a cryptographically random code verifier.
|
||||
|
||||
Args:
|
||||
length: Length of the verifier (43-128 characters, default 64)
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded random string
|
||||
"""
|
||||
if not 43 <= length <= 128:
|
||||
raise ValueError("Code verifier length must be between 43 and 128")
|
||||
return secrets.token_urlsafe(length)[:length]
|
||||
|
||||
|
||||
def generate_code_challenge(verifier: str, method: str = "S256") -> str:
|
||||
"""
|
||||
Generate a code challenge from the verifier.
|
||||
|
||||
Args:
|
||||
verifier: The code verifier string
|
||||
method: Challenge method ("S256" or "plain")
|
||||
|
||||
Returns:
|
||||
The code challenge string
|
||||
"""
|
||||
if method == "S256":
|
||||
digest = hashlib.sha256(verifier.encode("ascii")).digest()
|
||||
# URL-safe base64 encoding without padding
|
||||
return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
|
||||
elif method == "plain":
|
||||
return verifier
|
||||
else:
|
||||
raise ValueError(f"Unsupported code challenge method: {method}")
|
||||
|
||||
|
||||
def verify_code_challenge(
|
||||
verifier: str,
|
||||
challenge: str,
|
||||
method: str = "S256",
|
||||
) -> bool:
|
||||
"""
|
||||
Verify that a code verifier matches the stored challenge.
|
||||
|
||||
Args:
|
||||
verifier: The code verifier from the token request
|
||||
challenge: The code challenge stored during authorization
|
||||
method: The challenge method used
|
||||
|
||||
Returns:
|
||||
True if the verifier matches the challenge
|
||||
"""
|
||||
expected = generate_code_challenge(verifier, method)
|
||||
# Use constant-time comparison to prevent timing attacks
|
||||
return secrets.compare_digest(expected, challenge)
|
||||
860
autogpt_platform/backend/backend/server/oauth/router.py
Normal file
860
autogpt_platform/backend/backend/server/oauth/router.py
Normal file
@@ -0,0 +1,860 @@
|
||||
"""
|
||||
OAuth 2.0 Authorization Server endpoints.
|
||||
|
||||
Implements:
|
||||
- GET /oauth/authorize - Authorization endpoint
|
||||
- POST /oauth/authorize/consent - Consent form submission
|
||||
- POST /oauth/token - Token endpoint
|
||||
- GET /oauth/userinfo - OIDC UserInfo endpoint
|
||||
- POST /oauth/revoke - Token revocation endpoint
|
||||
|
||||
Authentication:
|
||||
- X-API-Key header - API key for external apps (preferred)
|
||||
- Authorization: Bearer <jwt> - JWT token authentication
|
||||
- access_token cookie - Browser-based auth
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from fastapi import APIRouter, Form, HTTPException, Query, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
||||
|
||||
from backend.data.db import prisma
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.server.oauth.consent_templates import (
|
||||
render_consent_page,
|
||||
render_error_page,
|
||||
render_login_redirect_page,
|
||||
)
|
||||
from backend.server.oauth.errors import (
|
||||
InvalidClientError,
|
||||
InvalidRequestError,
|
||||
OAuthError,
|
||||
UnsupportedGrantTypeError,
|
||||
)
|
||||
from backend.server.oauth.models import TokenResponse, UserInfoResponse
|
||||
from backend.server.oauth.service import get_oauth_service
|
||||
from backend.server.oauth.token_service import get_token_service
|
||||
from backend.util.rate_limiter import check_rate_limit
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
oauth_router = APIRouter(prefix="/oauth", tags=["oauth"])
|
||||
|
||||
# Redis key prefix and TTL for consent state storage
|
||||
CONSENT_STATE_PREFIX = "oauth:consent:"
|
||||
CONSENT_STATE_TTL = 600 # 10 minutes
|
||||
|
||||
# Redis key prefix and TTL for login redirect state storage
|
||||
LOGIN_STATE_PREFIX = "oauth:login:"
|
||||
LOGIN_STATE_TTL = 900 # 15 minutes (longer to allow time for login)
|
||||
|
||||
|
||||
async def _store_login_state(token: str, state: dict) -> None:
|
||||
"""Store OAuth login state in Redis with TTL."""
|
||||
redis = await get_redis_async()
|
||||
await redis.setex(
|
||||
f"{LOGIN_STATE_PREFIX}{token}",
|
||||
LOGIN_STATE_TTL,
|
||||
json.dumps(state, default=str),
|
||||
)
|
||||
|
||||
|
||||
async def _get_and_delete_login_state(token: str) -> Optional[dict]:
|
||||
"""Retrieve and delete login state from Redis (one-time use, atomic)."""
|
||||
redis = await get_redis_async()
|
||||
key = f"{LOGIN_STATE_PREFIX}{token}"
|
||||
# Use GETDEL for atomic get+delete to prevent race conditions
|
||||
state_json = await redis.getdel(key)
|
||||
if state_json:
|
||||
return json.loads(state_json)
|
||||
return None
|
||||
|
||||
|
||||
async def _store_consent_state(token: str, state: dict) -> None:
|
||||
"""Store consent state in Redis with TTL."""
|
||||
redis = await get_redis_async()
|
||||
await redis.setex(
|
||||
f"{CONSENT_STATE_PREFIX}{token}",
|
||||
CONSENT_STATE_TTL,
|
||||
json.dumps(state, default=str),
|
||||
)
|
||||
|
||||
|
||||
async def _get_and_delete_consent_state(token: str) -> Optional[dict]:
|
||||
"""Retrieve and delete consent state from Redis (atomic get+delete)."""
|
||||
redis = await get_redis_async()
|
||||
key = f"{CONSENT_STATE_PREFIX}{token}"
|
||||
# Use GETDEL for atomic get+delete to prevent race conditions
|
||||
state_json = await redis.getdel(key)
|
||||
if state_json:
|
||||
return json.loads(state_json)
|
||||
return None
|
||||
|
||||
|
||||
async def _get_user_id_from_request(
|
||||
request: Request, strict_bearer: bool = False
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Extract user ID from request, checking API key, Authorization header, and cookie.
|
||||
|
||||
Supports:
|
||||
1. X-API-Key header - API key authentication (preferred for external apps)
|
||||
2. Authorization: Bearer <jwt> - JWT token authentication
|
||||
3. access_token cookie - Cookie-based auth (for browser flows)
|
||||
|
||||
Args:
|
||||
request: The incoming request
|
||||
strict_bearer: If True and Bearer token is provided but invalid,
|
||||
do NOT fallthrough to cookie auth (prevents auth downgrade attacks)
|
||||
"""
|
||||
from autogpt_libs.auth.jwt_utils import parse_jwt_token
|
||||
|
||||
from backend.data.api_key import validate_api_key
|
||||
|
||||
# First try X-API-Key header (for external apps)
|
||||
api_key = request.headers.get("X-API-Key")
|
||||
if api_key:
|
||||
try:
|
||||
api_key_info = await validate_api_key(api_key)
|
||||
if api_key_info:
|
||||
return api_key_info.user_id
|
||||
except Exception:
|
||||
logger.debug("API key validation failed")
|
||||
|
||||
# Then try Authorization header (JWT)
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
try:
|
||||
token = auth_header[7:]
|
||||
payload = parse_jwt_token(token)
|
||||
return payload.get("sub")
|
||||
except Exception as e:
|
||||
logger.debug("JWT token validation failed: %s", type(e).__name__)
|
||||
# Security fix: If Bearer token was provided but invalid,
|
||||
# don't fallthrough to weaker auth methods when strict_bearer is True
|
||||
if strict_bearer:
|
||||
return None
|
||||
|
||||
# Finally try cookie (browser-based auth)
|
||||
token = request.cookies.get("access_token")
|
||||
if token:
|
||||
try:
|
||||
payload = parse_jwt_token(token)
|
||||
return payload.get("sub")
|
||||
except Exception as e:
|
||||
logger.debug("Cookie token validation failed: %s", type(e).__name__)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _parse_scopes(scope_str: str) -> list[str]:
|
||||
"""Parse space-separated scope string into list."""
|
||||
if not scope_str:
|
||||
return []
|
||||
return [s.strip() for s in scope_str.split() if s.strip()]
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
"""Get client IP address from request."""
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Authorization Endpoint
|
||||
# ================================================================
|
||||
|
||||
|
||||
@oauth_router.get("/authorize", response_model=None)
|
||||
async def authorize(
|
||||
request: Request,
|
||||
response_type: str = Query(..., description="Must be 'code'"),
|
||||
client_id: str = Query(..., description="Client identifier"),
|
||||
redirect_uri: str = Query(..., description="Redirect URI"),
|
||||
state: str = Query(..., description="CSRF state parameter"),
|
||||
code_challenge: str = Query(..., description="PKCE code challenge"),
|
||||
code_challenge_method: str = Query("S256", description="PKCE method"),
|
||||
scope: str = Query("", description="Space-separated scopes"),
|
||||
nonce: Optional[str] = Query(None, description="OIDC nonce"),
|
||||
prompt: Optional[str] = Query(None, description="Prompt behavior"),
|
||||
) -> HTMLResponse | RedirectResponse:
|
||||
"""
|
||||
OAuth 2.0 Authorization Endpoint.
|
||||
|
||||
Validates the request, checks user authentication, and either:
|
||||
- Returns error if user is not authenticated (API key or JWT required)
|
||||
- Shows consent page if user hasn't authorized these scopes
|
||||
- Redirects with authorization code if already authorized
|
||||
|
||||
Authentication methods (in order of preference):
|
||||
1. X-API-Key header - API key for external apps
|
||||
2. Authorization: Bearer <jwt> - JWT token
|
||||
3. access_token cookie - Browser-based auth
|
||||
"""
|
||||
# Get user ID from API key, Authorization header, or cookie
|
||||
user_id = await _get_user_id_from_request(request)
|
||||
|
||||
# Rate limiting - use client IP as identifier for authorize endpoint
|
||||
client_ip = _get_client_ip(request)
|
||||
rate_result = await check_rate_limit(client_ip, "oauth_authorize")
|
||||
if not rate_result.allowed:
|
||||
return HTMLResponse(
|
||||
render_error_page(
|
||||
"rate_limit_exceeded",
|
||||
"Too many authorization requests. Please try again later.",
|
||||
),
|
||||
status_code=429,
|
||||
)
|
||||
|
||||
oauth_service = get_oauth_service()
|
||||
|
||||
try:
|
||||
# Validate response_type
|
||||
if response_type != "code":
|
||||
raise InvalidRequestError(
|
||||
"Only 'code' response_type is supported", state=state
|
||||
)
|
||||
|
||||
# Validate PKCE method
|
||||
if code_challenge_method != "S256":
|
||||
raise InvalidRequestError(
|
||||
"Only 'S256' code_challenge_method is supported", state=state
|
||||
)
|
||||
|
||||
# Parse scopes
|
||||
scopes = _parse_scopes(scope)
|
||||
|
||||
# Validate client and redirect URI
|
||||
client = await oauth_service.validate_client(client_id, redirect_uri, scopes)
|
||||
|
||||
# Check if user is authenticated
|
||||
if not user_id:
|
||||
# User needs to log in - store OAuth params and redirect to frontend login
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
login_token = secrets.token_urlsafe(32)
|
||||
logger.info(f"Storing login state with token: {login_token}")
|
||||
await _store_login_state(
|
||||
login_token,
|
||||
{
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scopes": scopes,
|
||||
"state": state,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": code_challenge_method,
|
||||
"nonce": nonce,
|
||||
"prompt": prompt,
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"expires_at": (
|
||||
datetime.now(timezone.utc) + timedelta(seconds=LOGIN_STATE_TTL)
|
||||
).isoformat(),
|
||||
},
|
||||
)
|
||||
logger.info(f"Login state stored successfully for token: {login_token}")
|
||||
|
||||
# Build redirect URL to frontend login
|
||||
frontend_base_url = settings.config.frontend_base_url
|
||||
if not frontend_base_url:
|
||||
return _add_security_headers(
|
||||
HTMLResponse(
|
||||
render_error_page(
|
||||
"server_error", "Frontend URL not configured"
|
||||
),
|
||||
status_code=500,
|
||||
)
|
||||
)
|
||||
|
||||
# Redirect to frontend login with oauth_session parameter
|
||||
login_url = f"{frontend_base_url}/login?oauth_session={login_token}"
|
||||
return _add_security_headers(
|
||||
HTMLResponse(render_login_redirect_page(login_url))
|
||||
)
|
||||
|
||||
# Check if user has already authorized these scopes
|
||||
if prompt != "consent":
|
||||
has_auth = await oauth_service.has_valid_authorization(
|
||||
user_id, client_id, scopes
|
||||
)
|
||||
if has_auth:
|
||||
# Skip consent, issue code directly
|
||||
code = await oauth_service.create_authorization_code(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scopes=scopes,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
nonce=nonce,
|
||||
)
|
||||
redirect_url = (
|
||||
f"{redirect_uri}?{urlencode({'code': code, 'state': state})}"
|
||||
)
|
||||
return RedirectResponse(url=redirect_url, status_code=302)
|
||||
|
||||
# Generate consent token and store state in Redis
|
||||
consent_token = secrets.token_urlsafe(32)
|
||||
await _store_consent_state(
|
||||
consent_token,
|
||||
{
|
||||
"user_id": user_id,
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scopes": scopes,
|
||||
"state": state,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": code_challenge_method,
|
||||
"nonce": nonce,
|
||||
"expires_at": (
|
||||
datetime.now(timezone.utc) + timedelta(minutes=10)
|
||||
).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
# Render consent page
|
||||
return _add_security_headers(
|
||||
HTMLResponse(
|
||||
render_consent_page(
|
||||
client_name=client.name,
|
||||
client_logo=client.logoUrl,
|
||||
scopes=scopes,
|
||||
consent_token=consent_token,
|
||||
action_url="/oauth/authorize/consent",
|
||||
privacy_policy_url=client.privacyPolicyUrl,
|
||||
terms_url=client.termsOfServiceUrl,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
except OAuthError as e:
|
||||
# If we have a valid redirect_uri, redirect with error
|
||||
# Otherwise show error page
|
||||
try:
|
||||
client = await oauth_service.get_client(client_id)
|
||||
if client and redirect_uri in client.redirectUris:
|
||||
return e.to_redirect(redirect_uri)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return _add_security_headers(
|
||||
HTMLResponse(
|
||||
render_error_page(e.error.value, e.description or "An error occurred"),
|
||||
status_code=400,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@oauth_router.post("/authorize/consent", response_model=None)
|
||||
async def submit_consent(
|
||||
request: Request,
|
||||
consent_token: str = Form(...),
|
||||
authorize: str = Form(...),
|
||||
) -> HTMLResponse | RedirectResponse:
|
||||
"""
|
||||
Process consent form submission.
|
||||
|
||||
Creates authorization code and redirects to client's redirect_uri.
|
||||
"""
|
||||
# Rate limiting on consent submission to prevent brute force attacks
|
||||
client_ip = _get_client_ip(request)
|
||||
rate_result = await check_rate_limit(client_ip, "oauth_consent")
|
||||
if not rate_result.allowed:
|
||||
return _add_security_headers(
|
||||
HTMLResponse(
|
||||
render_error_page(
|
||||
"rate_limit_exceeded",
|
||||
"Too many consent requests. Please try again later.",
|
||||
),
|
||||
status_code=429,
|
||||
)
|
||||
)
|
||||
|
||||
oauth_service = get_oauth_service()
|
||||
|
||||
# Validate consent token (retrieves and deletes from Redis atomically)
|
||||
consent_state = await _get_and_delete_consent_state(consent_token)
|
||||
if not consent_state:
|
||||
return HTMLResponse(
|
||||
render_error_page("invalid_request", "Invalid or expired consent token"),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Check expiration (expires_at is stored as ISO string in Redis)
|
||||
expires_at = datetime.fromisoformat(consent_state["expires_at"])
|
||||
if expires_at < datetime.now(timezone.utc):
|
||||
return HTMLResponse(
|
||||
render_error_page("invalid_request", "Consent session expired"),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
redirect_uri = consent_state["redirect_uri"]
|
||||
state = consent_state["state"]
|
||||
|
||||
# Check if user denied
|
||||
if authorize.lower() != "true":
|
||||
error_params = urlencode(
|
||||
{
|
||||
"error": "access_denied",
|
||||
"error_description": "User denied the authorization request",
|
||||
"state": state,
|
||||
}
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{error_params}",
|
||||
status_code=302,
|
||||
)
|
||||
|
||||
try:
|
||||
# Create authorization code
|
||||
code = await oauth_service.create_authorization_code(
|
||||
user_id=consent_state["user_id"],
|
||||
client_id=consent_state["client_id"],
|
||||
redirect_uri=redirect_uri,
|
||||
scopes=consent_state["scopes"],
|
||||
code_challenge=consent_state["code_challenge"],
|
||||
code_challenge_method=consent_state["code_challenge_method"],
|
||||
nonce=consent_state["nonce"],
|
||||
)
|
||||
|
||||
# Redirect with code
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode({'code': code, 'state': state})}",
|
||||
status_code=302,
|
||||
)
|
||||
|
||||
except OAuthError as e:
|
||||
return e.to_redirect(redirect_uri)
|
||||
|
||||
|
||||
def _wants_json(request: Request) -> bool:
|
||||
"""Check if client prefers JSON response (for frontend fetch calls)."""
|
||||
accept = request.headers.get("Accept", "")
|
||||
return "application/json" in accept
|
||||
|
||||
|
||||
def _add_security_headers(response: HTMLResponse) -> HTMLResponse:
|
||||
"""Add security headers to OAuth HTML responses."""
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["Content-Security-Policy"] = "frame-ancestors 'none'"
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
return response
|
||||
|
||||
|
||||
@oauth_router.get("/authorize/resume", response_model=None)
|
||||
async def resume_authorization(
|
||||
request: Request,
|
||||
session_id: str = Query(..., description="OAuth login session ID"),
|
||||
) -> HTMLResponse | RedirectResponse | JSONResponse:
|
||||
"""
|
||||
Resume OAuth authorization after user login.
|
||||
|
||||
This endpoint is called after the user completes login on the frontend.
|
||||
It retrieves the stored OAuth parameters and continues the authorization flow.
|
||||
|
||||
Supports Accept: application/json header to return JSON for frontend fetch calls,
|
||||
solving CORS issues with redirect responses.
|
||||
"""
|
||||
wants_json = _wants_json(request)
|
||||
|
||||
# Rate limiting - use client IP
|
||||
client_ip = _get_client_ip(request)
|
||||
rate_result = await check_rate_limit(client_ip, "oauth_authorize")
|
||||
if not rate_result.allowed:
|
||||
if wants_json:
|
||||
return JSONResponse(
|
||||
{
|
||||
"error": "rate_limit_exceeded",
|
||||
"error_description": "Too many requests",
|
||||
},
|
||||
status_code=429,
|
||||
)
|
||||
return _add_security_headers(
|
||||
HTMLResponse(
|
||||
render_error_page(
|
||||
"rate_limit_exceeded",
|
||||
"Too many authorization requests. Please try again later.",
|
||||
),
|
||||
status_code=429,
|
||||
)
|
||||
)
|
||||
|
||||
# Verify user is now authenticated (use strict_bearer to prevent auth downgrade)
|
||||
user_id = await _get_user_id_from_request(request, strict_bearer=True)
|
||||
if not user_id:
|
||||
from backend.util.settings import Settings
|
||||
|
||||
frontend_url = Settings().config.frontend_base_url or "http://localhost:3000"
|
||||
if wants_json:
|
||||
return JSONResponse(
|
||||
{
|
||||
"error": "login_required",
|
||||
"error_description": "Authentication required",
|
||||
"redirect_url": f"{frontend_url}/login",
|
||||
},
|
||||
status_code=401,
|
||||
)
|
||||
return _add_security_headers(
|
||||
HTMLResponse(
|
||||
render_error_page(
|
||||
"login_required",
|
||||
"Authentication required. Please log in and try again.",
|
||||
redirect_url=f"{frontend_url}/login",
|
||||
),
|
||||
status_code=401,
|
||||
)
|
||||
)
|
||||
|
||||
# Retrieve and delete login state (one-time use)
|
||||
logger.info(f"Attempting to retrieve login state for session_id: {session_id}")
|
||||
login_state = await _get_and_delete_login_state(session_id)
|
||||
if not login_state:
|
||||
logger.warning(f"Login state not found for session_id: {session_id}")
|
||||
if wants_json:
|
||||
return JSONResponse(
|
||||
{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Invalid or expired authorization session",
|
||||
},
|
||||
status_code=400,
|
||||
)
|
||||
return _add_security_headers(
|
||||
HTMLResponse(
|
||||
render_error_page(
|
||||
"invalid_request",
|
||||
"Invalid or expired authorization session. Please start over.",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
)
|
||||
|
||||
# Check expiration
|
||||
expires_at = datetime.fromisoformat(login_state["expires_at"])
|
||||
if expires_at < datetime.now(timezone.utc):
|
||||
if wants_json:
|
||||
return JSONResponse(
|
||||
{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Authorization session has expired",
|
||||
},
|
||||
status_code=400,
|
||||
)
|
||||
return _add_security_headers(
|
||||
HTMLResponse(
|
||||
render_error_page(
|
||||
"invalid_request",
|
||||
"Authorization session has expired. Please start over.",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
)
|
||||
|
||||
# Extract stored OAuth parameters
|
||||
client_id = login_state["client_id"]
|
||||
redirect_uri = login_state["redirect_uri"]
|
||||
scopes = login_state["scopes"]
|
||||
state = login_state["state"]
|
||||
code_challenge = login_state["code_challenge"]
|
||||
code_challenge_method = login_state["code_challenge_method"]
|
||||
nonce = login_state.get("nonce")
|
||||
prompt = login_state.get("prompt")
|
||||
|
||||
oauth_service = get_oauth_service()
|
||||
|
||||
try:
|
||||
# Re-validate client (in case it was deactivated during login)
|
||||
client = await oauth_service.validate_client(client_id, redirect_uri, scopes)
|
||||
|
||||
# Check if user has already authorized these scopes (skip consent if yes)
|
||||
if prompt != "consent":
|
||||
has_auth = await oauth_service.has_valid_authorization(
|
||||
user_id, client_id, scopes
|
||||
)
|
||||
if has_auth:
|
||||
# Skip consent, issue code directly
|
||||
code = await oauth_service.create_authorization_code(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scopes=scopes,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
nonce=nonce,
|
||||
)
|
||||
redirect_url = (
|
||||
f"{redirect_uri}?{urlencode({'code': code, 'state': state})}"
|
||||
)
|
||||
# Return JSON with redirect URL for frontend to handle
|
||||
if wants_json:
|
||||
return JSONResponse(
|
||||
{"redirect_url": redirect_url, "needs_consent": False}
|
||||
)
|
||||
return RedirectResponse(url=redirect_url, status_code=302)
|
||||
|
||||
# Generate consent token and store state in Redis
|
||||
consent_token = secrets.token_urlsafe(32)
|
||||
await _store_consent_state(
|
||||
consent_token,
|
||||
{
|
||||
"user_id": user_id,
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scopes": scopes,
|
||||
"state": state,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": code_challenge_method,
|
||||
"nonce": nonce,
|
||||
"expires_at": (
|
||||
datetime.now(timezone.utc) + timedelta(minutes=10)
|
||||
).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
# For JSON requests, return consent data instead of HTML
|
||||
if wants_json:
|
||||
from backend.server.oauth.models import SCOPE_DESCRIPTIONS
|
||||
|
||||
scope_details = [
|
||||
{"scope": s, "description": SCOPE_DESCRIPTIONS.get(s, s)}
|
||||
for s in scopes
|
||||
]
|
||||
return JSONResponse(
|
||||
{
|
||||
"needs_consent": True,
|
||||
"consent_token": consent_token,
|
||||
"client": {
|
||||
"name": client.name,
|
||||
"logo_url": client.logoUrl,
|
||||
"privacy_policy_url": client.privacyPolicyUrl,
|
||||
"terms_url": client.termsOfServiceUrl,
|
||||
},
|
||||
"scopes": scope_details,
|
||||
"action_url": "/oauth/authorize/consent",
|
||||
}
|
||||
)
|
||||
|
||||
# Render consent page (HTML response)
|
||||
return _add_security_headers(
|
||||
HTMLResponse(
|
||||
render_consent_page(
|
||||
client_name=client.name,
|
||||
client_logo=client.logoUrl,
|
||||
scopes=scopes,
|
||||
consent_token=consent_token,
|
||||
action_url="/oauth/authorize/consent",
|
||||
privacy_policy_url=client.privacyPolicyUrl,
|
||||
terms_url=client.termsOfServiceUrl,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
except OAuthError as e:
|
||||
if wants_json:
|
||||
return JSONResponse(
|
||||
{"error": e.error.value, "error_description": e.description},
|
||||
status_code=400,
|
||||
)
|
||||
# If we have a valid redirect_uri, redirect with error
|
||||
try:
|
||||
client = await oauth_service.get_client(client_id)
|
||||
if client and redirect_uri in client.redirectUris:
|
||||
return e.to_redirect(redirect_uri)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return _add_security_headers(
|
||||
HTMLResponse(
|
||||
render_error_page(e.error.value, e.description or "An error occurred"),
|
||||
status_code=400,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Token Endpoint
|
||||
# ================================================================
|
||||
|
||||
|
||||
@oauth_router.post("/token", response_model=TokenResponse)
|
||||
async def token(
|
||||
request: Request,
|
||||
grant_type: str = Form(...),
|
||||
code: Optional[str] = Form(None),
|
||||
redirect_uri: Optional[str] = Form(None),
|
||||
client_id: str = Form(...),
|
||||
client_secret: Optional[str] = Form(None),
|
||||
code_verifier: Optional[str] = Form(None),
|
||||
refresh_token: Optional[str] = Form(None),
|
||||
scope: Optional[str] = Form(None),
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
OAuth 2.0 Token Endpoint.
|
||||
|
||||
Supports:
|
||||
- authorization_code grant (with PKCE)
|
||||
- refresh_token grant
|
||||
"""
|
||||
# Rate limiting - use client_id as identifier
|
||||
rate_result = await check_rate_limit(client_id, "oauth_token")
|
||||
if not rate_result.allowed:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="Rate limit exceeded",
|
||||
headers={
|
||||
"Retry-After": str(int(rate_result.retry_after or 60)),
|
||||
"X-RateLimit-Remaining": "0",
|
||||
"X-RateLimit-Reset": str(int(rate_result.reset_at)),
|
||||
},
|
||||
)
|
||||
|
||||
oauth_service = get_oauth_service()
|
||||
|
||||
try:
|
||||
# Validate client authentication
|
||||
await oauth_service.validate_client_secret(client_id, client_secret)
|
||||
|
||||
if grant_type == "authorization_code":
|
||||
# Validate required parameters
|
||||
if not code:
|
||||
raise InvalidRequestError("'code' is required")
|
||||
if not redirect_uri:
|
||||
raise InvalidRequestError("'redirect_uri' is required")
|
||||
if not code_verifier:
|
||||
raise InvalidRequestError("'code_verifier' is required for PKCE")
|
||||
|
||||
return await oauth_service.exchange_authorization_code(
|
||||
code=code,
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=code_verifier,
|
||||
)
|
||||
|
||||
elif grant_type == "refresh_token":
|
||||
if not refresh_token:
|
||||
raise InvalidRequestError("'refresh_token' is required")
|
||||
|
||||
requested_scopes = _parse_scopes(scope) if scope else None
|
||||
return await oauth_service.refresh_access_token(
|
||||
refresh_token=refresh_token,
|
||||
client_id=client_id,
|
||||
requested_scopes=requested_scopes,
|
||||
)
|
||||
|
||||
else:
|
||||
raise UnsupportedGrantTypeError(grant_type)
|
||||
|
||||
except OAuthError as e:
|
||||
# 401 for client auth failure, 400 for other validation errors (per RFC 6749)
|
||||
raise e.to_http_exception(401 if isinstance(e, InvalidClientError) else 400)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# UserInfo Endpoint
|
||||
# ================================================================
|
||||
|
||||
|
||||
@oauth_router.get("/userinfo", response_model=UserInfoResponse)
|
||||
async def userinfo(request: Request) -> UserInfoResponse:
|
||||
"""
|
||||
OIDC UserInfo Endpoint.
|
||||
|
||||
Returns user profile information based on the granted scopes.
|
||||
"""
|
||||
token_service = get_token_service()
|
||||
|
||||
# Extract bearer token
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Bearer token required",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
token = auth_header[7:]
|
||||
|
||||
try:
|
||||
# Verify token
|
||||
claims = token_service.verify_access_token(token)
|
||||
|
||||
# Check token is not revoked
|
||||
token_hash = token_service.hash_token(token)
|
||||
stored_token = await prisma.oauthaccesstoken.find_unique(
|
||||
where={"tokenHash": token_hash}
|
||||
)
|
||||
|
||||
if not stored_token or stored_token.revokedAt:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Token has been revoked",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Update last used
|
||||
await prisma.oauthaccesstoken.update(
|
||||
where={"id": stored_token.id},
|
||||
data={"lastUsedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
# Get user info based on scopes
|
||||
user = await prisma.user.find_unique(where={"id": claims.sub})
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
scopes = claims.scope.split()
|
||||
|
||||
# Build response based on scopes
|
||||
email = user.email if "email" in scopes else None
|
||||
email_verified = user.emailVerified if "email" in scopes else None
|
||||
name = user.name if "profile" in scopes else None
|
||||
updated_at = int(user.updatedAt.timestamp()) if "profile" in scopes else None
|
||||
|
||||
return UserInfoResponse(
|
||||
sub=claims.sub,
|
||||
email=email,
|
||||
email_verified=email_verified,
|
||||
name=name,
|
||||
updated_at=updated_at,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Invalid token: {str(e)}",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Token Revocation Endpoint
|
||||
# ================================================================
|
||||
|
||||
|
||||
@oauth_router.post("/revoke")
|
||||
async def revoke(
|
||||
request: Request,
|
||||
token: str = Form(...),
|
||||
token_type_hint: Optional[str] = Form(None),
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
OAuth 2.0 Token Revocation Endpoint (RFC 7009).
|
||||
|
||||
Revokes an access token or refresh token.
|
||||
"""
|
||||
oauth_service = get_oauth_service()
|
||||
|
||||
# Note: Per RFC 7009, always return 200 even if token not found
|
||||
await oauth_service.revoke_token(token, token_type_hint)
|
||||
|
||||
return JSONResponse(content={}, status_code=200)
|
||||
625
autogpt_platform/backend/backend/server/oauth/service.py
Normal file
625
autogpt_platform/backend/backend/server/oauth/service.py
Normal file
@@ -0,0 +1,625 @@
|
||||
"""
|
||||
Core OAuth 2.0 service logic.
|
||||
|
||||
Handles:
|
||||
- Client validation and lookup
|
||||
- Authorization code generation and exchange
|
||||
- Token issuance and refresh
|
||||
- User consent management
|
||||
- Audit logging
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from prisma.enums import OAuthClientStatus
|
||||
from prisma.models import OAuthAuthorization, OAuthClient, User
|
||||
|
||||
from backend.data.db import prisma
|
||||
from backend.server.oauth.errors import (
|
||||
InvalidClientError,
|
||||
InvalidGrantError,
|
||||
InvalidRequestError,
|
||||
InvalidScopeError,
|
||||
)
|
||||
from backend.server.oauth.models import TokenResponse
|
||||
from backend.server.oauth.pkce import verify_code_challenge
|
||||
from backend.server.oauth.token_service import OAuthTokenService, get_token_service
|
||||
|
||||
|
||||
class OAuthService:
|
||||
"""Core OAuth 2.0 service."""
|
||||
|
||||
def __init__(self, token_service: Optional[OAuthTokenService] = None):
|
||||
self.token_service = token_service or get_token_service()
|
||||
|
||||
# ================================================================
|
||||
# Client Operations
|
||||
# ================================================================
|
||||
|
||||
async def get_client(self, client_id: str) -> Optional[OAuthClient]:
|
||||
"""Get an OAuth client by client_id."""
|
||||
return await prisma.oauthclient.find_unique(where={"clientId": client_id})
|
||||
|
||||
async def validate_client(
|
||||
self,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
scopes: list[str],
|
||||
) -> OAuthClient:
|
||||
"""
|
||||
Validate a client for authorization.
|
||||
|
||||
Args:
|
||||
client_id: Client identifier
|
||||
redirect_uri: Requested redirect URI
|
||||
scopes: Requested scopes
|
||||
|
||||
Returns:
|
||||
Validated OAuthClient
|
||||
|
||||
Raises:
|
||||
InvalidClientError: Client not found or inactive
|
||||
InvalidRequestError: Invalid redirect URI
|
||||
InvalidScopeError: Invalid scopes requested
|
||||
"""
|
||||
client = await self.get_client(client_id)
|
||||
|
||||
if not client:
|
||||
raise InvalidClientError(f"Client '{client_id}' not found")
|
||||
|
||||
if client.status != OAuthClientStatus.ACTIVE:
|
||||
raise InvalidClientError(f"Client '{client_id}' is not active")
|
||||
|
||||
# Validate redirect URI (exact match required)
|
||||
if redirect_uri not in client.redirectUris:
|
||||
raise InvalidRequestError(
|
||||
f"Redirect URI '{redirect_uri}' is not registered for this client"
|
||||
)
|
||||
|
||||
# Validate scopes
|
||||
invalid_scopes = set(scopes) - set(client.allowedScopes)
|
||||
if invalid_scopes:
|
||||
raise InvalidScopeError(
|
||||
f"Scopes not allowed for this client: {', '.join(invalid_scopes)}"
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
async def validate_client_secret(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: Optional[str],
|
||||
) -> OAuthClient:
|
||||
"""
|
||||
Validate client authentication for token endpoint.
|
||||
|
||||
Args:
|
||||
client_id: Client identifier
|
||||
client_secret: Client secret (for confidential clients)
|
||||
|
||||
Returns:
|
||||
Validated OAuthClient
|
||||
|
||||
Raises:
|
||||
InvalidClientError: Invalid client or credentials
|
||||
"""
|
||||
client = await self.get_client(client_id)
|
||||
|
||||
if not client:
|
||||
raise InvalidClientError(f"Client '{client_id}' not found")
|
||||
|
||||
if client.status != OAuthClientStatus.ACTIVE:
|
||||
raise InvalidClientError(f"Client '{client_id}' is not active")
|
||||
|
||||
# Confidential clients must provide secret
|
||||
if client.clientType == "confidential":
|
||||
if not client_secret:
|
||||
raise InvalidClientError("Client secret required")
|
||||
|
||||
# Hash and compare
|
||||
secret_hash = self._hash_secret(
|
||||
client_secret, client.clientSecretSalt or ""
|
||||
)
|
||||
if not secrets.compare_digest(secret_hash, client.clientSecretHash or ""):
|
||||
raise InvalidClientError("Invalid client credentials")
|
||||
|
||||
return client
|
||||
|
||||
@staticmethod
|
||||
def _hash_secret(secret: str, salt: str) -> str:
|
||||
"""Hash a client secret with salt."""
|
||||
return hashlib.sha256(f"{salt}{secret}".encode()).hexdigest()
|
||||
|
||||
# ================================================================
|
||||
# Authorization Code Operations
|
||||
# ================================================================
|
||||
|
||||
async def create_authorization_code(
|
||||
self,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
scopes: list[str],
|
||||
code_challenge: str,
|
||||
code_challenge_method: str = "S256",
|
||||
nonce: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a new authorization code.
|
||||
|
||||
Args:
|
||||
user_id: User who authorized
|
||||
client_id: Client being authorized
|
||||
redirect_uri: Redirect URI for callback
|
||||
scopes: Granted scopes
|
||||
code_challenge: PKCE code challenge
|
||||
code_challenge_method: PKCE method (S256)
|
||||
nonce: OIDC nonce (optional)
|
||||
|
||||
Returns:
|
||||
Authorization code string
|
||||
"""
|
||||
code = secrets.token_urlsafe(32)
|
||||
code_hash = self.token_service.hash_token(code)
|
||||
|
||||
# Get the OAuthClient to link
|
||||
client = await self.get_client(client_id)
|
||||
if not client:
|
||||
raise InvalidClientError(f"Client '{client_id}' not found")
|
||||
|
||||
await prisma.oauthauthorizationcode.create(
|
||||
data={ # type: ignore[typeddict-item]
|
||||
"codeHash": code_hash,
|
||||
"userId": user_id,
|
||||
"clientId": client.id,
|
||||
"redirectUri": redirect_uri,
|
||||
"scopes": scopes,
|
||||
"codeChallenge": code_challenge,
|
||||
"codeChallengeMethod": code_challenge_method,
|
||||
"nonce": nonce,
|
||||
"expiresAt": datetime.now(timezone.utc) + timedelta(minutes=10),
|
||||
}
|
||||
)
|
||||
|
||||
return code
|
||||
|
||||
async def exchange_authorization_code(
|
||||
self,
|
||||
code: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
code_verifier: str,
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
Exchange an authorization code for tokens.
|
||||
|
||||
Args:
|
||||
code: Authorization code
|
||||
client_id: Client identifier
|
||||
redirect_uri: Must match original redirect URI
|
||||
code_verifier: PKCE code verifier
|
||||
|
||||
Returns:
|
||||
TokenResponse with access token, refresh token, etc.
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: Invalid or expired code
|
||||
InvalidRequestError: PKCE verification failed
|
||||
"""
|
||||
code_hash = self.token_service.hash_token(code)
|
||||
|
||||
# Find the authorization code
|
||||
auth_code = await prisma.oauthauthorizationcode.find_unique(
|
||||
where={"codeHash": code_hash},
|
||||
include={"Client": True, "User": True},
|
||||
)
|
||||
|
||||
if not auth_code:
|
||||
raise InvalidGrantError("Authorization code not found")
|
||||
|
||||
# Ensure Client relation is loaded
|
||||
if not auth_code.Client:
|
||||
raise InvalidGrantError("Authorization code client not found")
|
||||
|
||||
# Check if already used
|
||||
if auth_code.usedAt:
|
||||
# Code reuse is a security incident - revoke all tokens for this authorization
|
||||
await self._revoke_tokens_for_client_user(
|
||||
auth_code.Client.clientId, auth_code.userId
|
||||
)
|
||||
raise InvalidGrantError("Authorization code has already been used")
|
||||
|
||||
# Check expiration
|
||||
if auth_code.expiresAt < datetime.now(timezone.utc):
|
||||
raise InvalidGrantError("Authorization code has expired")
|
||||
|
||||
# Validate client
|
||||
if auth_code.Client.clientId != client_id:
|
||||
raise InvalidGrantError("Client ID mismatch")
|
||||
|
||||
# Validate redirect URI
|
||||
if auth_code.redirectUri != redirect_uri:
|
||||
raise InvalidGrantError("Redirect URI mismatch")
|
||||
|
||||
# Verify PKCE
|
||||
if not verify_code_challenge(
|
||||
code_verifier, auth_code.codeChallenge, auth_code.codeChallengeMethod
|
||||
):
|
||||
raise InvalidRequestError("PKCE verification failed")
|
||||
|
||||
# Mark code as used
|
||||
await prisma.oauthauthorizationcode.update(
|
||||
where={"id": auth_code.id},
|
||||
data={"usedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
# Create or update authorization record
|
||||
await self._upsert_authorization(
|
||||
auth_code.userId, auth_code.Client.id, auth_code.scopes
|
||||
)
|
||||
|
||||
# Generate tokens
|
||||
return await self._create_tokens(
|
||||
user_id=auth_code.userId,
|
||||
client=auth_code.Client,
|
||||
scopes=auth_code.scopes,
|
||||
nonce=auth_code.nonce,
|
||||
user=auth_code.User,
|
||||
)
|
||||
|
||||
async def refresh_access_token(
|
||||
self,
|
||||
refresh_token: str,
|
||||
client_id: str,
|
||||
requested_scopes: Optional[list[str]] = None,
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
Refresh an access token using a refresh token.
|
||||
|
||||
Args:
|
||||
refresh_token: Refresh token string
|
||||
client_id: Client identifier
|
||||
requested_scopes: Optionally request fewer scopes
|
||||
|
||||
Returns:
|
||||
New TokenResponse
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: Invalid or expired refresh token
|
||||
"""
|
||||
token_hash = self.token_service.hash_token(refresh_token)
|
||||
|
||||
# Find the refresh token
|
||||
stored_token = await prisma.oauthrefreshtoken.find_unique(
|
||||
where={"tokenHash": token_hash},
|
||||
include={"Client": True, "User": True},
|
||||
)
|
||||
|
||||
if not stored_token:
|
||||
raise InvalidGrantError("Refresh token not found")
|
||||
|
||||
# Ensure Client relation is loaded
|
||||
if not stored_token.Client:
|
||||
raise InvalidGrantError("Refresh token client not found")
|
||||
|
||||
# Check if revoked
|
||||
if stored_token.revokedAt:
|
||||
raise InvalidGrantError("Refresh token has been revoked")
|
||||
|
||||
# Check expiration
|
||||
if stored_token.expiresAt < datetime.now(timezone.utc):
|
||||
raise InvalidGrantError("Refresh token has expired")
|
||||
|
||||
# Validate client
|
||||
if stored_token.Client.clientId != client_id:
|
||||
raise InvalidGrantError("Client ID mismatch")
|
||||
|
||||
# Determine scopes
|
||||
scopes = stored_token.scopes
|
||||
if requested_scopes:
|
||||
# Can only request a subset of original scopes
|
||||
invalid = set(requested_scopes) - set(stored_token.scopes)
|
||||
if invalid:
|
||||
raise InvalidScopeError(
|
||||
f"Cannot request scopes not in original grant: {', '.join(invalid)}"
|
||||
)
|
||||
scopes = requested_scopes
|
||||
|
||||
# Generate new tokens (rotates refresh token)
|
||||
return await self._create_tokens(
|
||||
user_id=stored_token.userId,
|
||||
client=stored_token.Client,
|
||||
scopes=scopes,
|
||||
user=stored_token.User,
|
||||
old_refresh_token_id=stored_token.id,
|
||||
)
|
||||
|
||||
# ================================================================
|
||||
# Token Operations
|
||||
# ================================================================
|
||||
|
||||
async def _create_tokens(
|
||||
self,
|
||||
user_id: str,
|
||||
client: OAuthClient,
|
||||
scopes: list[str],
|
||||
user: Optional[User] = None,
|
||||
nonce: Optional[str] = None,
|
||||
old_refresh_token_id: Optional[str] = None,
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
Create access and refresh tokens.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
client: OAuth client
|
||||
scopes: Granted scopes
|
||||
user: User object (for ID token claims)
|
||||
nonce: OIDC nonce
|
||||
old_refresh_token_id: ID of refresh token being rotated
|
||||
|
||||
Returns:
|
||||
TokenResponse
|
||||
"""
|
||||
# Generate access token
|
||||
access_token, access_expires_at = self.token_service.generate_access_token(
|
||||
user_id=user_id,
|
||||
client_id=client.clientId,
|
||||
scopes=scopes,
|
||||
expires_in=client.tokenLifetimeSecs,
|
||||
)
|
||||
|
||||
# Store access token hash
|
||||
await prisma.oauthaccesstoken.create(
|
||||
data={ # type: ignore[typeddict-item]
|
||||
"tokenHash": self.token_service.hash_token(access_token),
|
||||
"userId": user_id,
|
||||
"clientId": client.id,
|
||||
"scopes": scopes,
|
||||
"expiresAt": access_expires_at,
|
||||
}
|
||||
)
|
||||
|
||||
# Generate refresh token
|
||||
refresh_token = self.token_service.generate_refresh_token()
|
||||
refresh_expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=client.refreshTokenLifetimeSecs
|
||||
)
|
||||
|
||||
await prisma.oauthrefreshtoken.create(
|
||||
data={ # type: ignore[typeddict-item]
|
||||
"tokenHash": self.token_service.hash_token(refresh_token),
|
||||
"userId": user_id,
|
||||
"clientId": client.id,
|
||||
"scopes": scopes,
|
||||
"expiresAt": refresh_expires_at,
|
||||
}
|
||||
)
|
||||
|
||||
# Revoke old refresh token if rotating
|
||||
if old_refresh_token_id:
|
||||
await prisma.oauthrefreshtoken.update(
|
||||
where={"id": old_refresh_token_id},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
# Generate ID token if openid scope requested
|
||||
id_token = None
|
||||
if "openid" in scopes and user:
|
||||
email = user.email if "email" in scopes else None
|
||||
name = user.name if "profile" in scopes else None
|
||||
id_token = self.token_service.generate_id_token(
|
||||
user_id=user_id,
|
||||
client_id=client.clientId,
|
||||
email=email,
|
||||
name=name,
|
||||
nonce=nonce,
|
||||
)
|
||||
|
||||
# Audit log
|
||||
await self._audit_log(
|
||||
event_type="token.issued",
|
||||
user_id=user_id,
|
||||
client_id=client.clientId,
|
||||
details={"scopes": scopes},
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
expires_in=client.tokenLifetimeSecs,
|
||||
refresh_token=refresh_token,
|
||||
scope=" ".join(scopes),
|
||||
id_token=id_token,
|
||||
)
|
||||
|
||||
async def revoke_token(
|
||||
self,
|
||||
token: str,
|
||||
token_type_hint: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Revoke an access or refresh token.
|
||||
|
||||
Args:
|
||||
token: Token to revoke
|
||||
token_type_hint: Hint about token type
|
||||
|
||||
Returns:
|
||||
True if token was found and revoked
|
||||
"""
|
||||
token_hash = self.token_service.hash_token(token)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Try refresh token first if hinted or no hint
|
||||
if token_type_hint in (None, "refresh_token"):
|
||||
result = await prisma.oauthrefreshtoken.update_many(
|
||||
where={"tokenHash": token_hash, "revokedAt": None},
|
||||
data={"revokedAt": now},
|
||||
)
|
||||
if result > 0:
|
||||
return True
|
||||
|
||||
# Try access token
|
||||
if token_type_hint in (None, "access_token"):
|
||||
result = await prisma.oauthaccesstoken.update_many(
|
||||
where={"tokenHash": token_hash, "revokedAt": None},
|
||||
data={"revokedAt": now},
|
||||
)
|
||||
if result > 0:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _revoke_tokens_for_client_user(
|
||||
self,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
) -> None:
|
||||
"""Revoke all tokens for a client-user pair (security incident response)."""
|
||||
client = await self.get_client(client_id)
|
||||
if not client:
|
||||
return
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
await prisma.oauthaccesstoken.update_many(
|
||||
where={"clientId": client.id, "userId": user_id, "revokedAt": None},
|
||||
data={"revokedAt": now},
|
||||
)
|
||||
|
||||
await prisma.oauthrefreshtoken.update_many(
|
||||
where={"clientId": client.id, "userId": user_id, "revokedAt": None},
|
||||
data={"revokedAt": now},
|
||||
)
|
||||
|
||||
await self._audit_log(
|
||||
event_type="tokens.revoked.security",
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
details={"reason": "authorization_code_reuse"},
|
||||
)
|
||||
|
||||
# ================================================================
|
||||
# Authorization (Consent) Operations
|
||||
# ================================================================
|
||||
|
||||
async def get_authorization(
|
||||
self,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
) -> Optional[OAuthAuthorization]:
|
||||
"""Get existing authorization for user-client pair."""
|
||||
client = await self.get_client(client_id)
|
||||
if not client:
|
||||
return None
|
||||
|
||||
return await prisma.oauthauthorization.find_unique(
|
||||
where={
|
||||
"userId_clientId": {
|
||||
"userId": user_id,
|
||||
"clientId": client.id,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
async def has_valid_authorization(
|
||||
self,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
scopes: list[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user has already authorized these scopes for this client.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
client_id: Client identifier
|
||||
scopes: Requested scopes
|
||||
|
||||
Returns:
|
||||
True if user has already authorized all requested scopes
|
||||
"""
|
||||
auth = await self.get_authorization(user_id, client_id)
|
||||
if not auth or auth.revokedAt:
|
||||
return False
|
||||
|
||||
# Check if all requested scopes are already authorized
|
||||
return set(scopes).issubset(set(auth.scopes))
|
||||
|
||||
async def _upsert_authorization(
|
||||
self,
|
||||
user_id: str,
|
||||
client_db_id: str,
|
||||
scopes: list[str],
|
||||
) -> None:
|
||||
"""Create or update an authorization record."""
|
||||
existing = await prisma.oauthauthorization.find_unique(
|
||||
where={
|
||||
"userId_clientId": {
|
||||
"userId": user_id,
|
||||
"clientId": client_db_id,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if existing:
|
||||
# Merge scopes
|
||||
merged_scopes = list(set(existing.scopes) | set(scopes))
|
||||
await prisma.oauthauthorization.update(
|
||||
where={"id": existing.id},
|
||||
data={"scopes": merged_scopes, "revokedAt": None},
|
||||
)
|
||||
else:
|
||||
await prisma.oauthauthorization.create(
|
||||
data={ # type: ignore[typeddict-item]
|
||||
"userId": user_id,
|
||||
"clientId": client_db_id,
|
||||
"scopes": scopes,
|
||||
}
|
||||
)
|
||||
|
||||
# ================================================================
|
||||
# Audit Logging
|
||||
# ================================================================
|
||||
|
||||
async def _audit_log(
|
||||
self,
|
||||
event_type: str,
|
||||
user_id: Optional[str] = None,
|
||||
client_id: Optional[str] = None,
|
||||
grant_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
details: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Create an audit log entry."""
|
||||
# Convert details to JSON for Prisma's Json field
|
||||
details_json = json.dumps(details or {})
|
||||
await prisma.oauthauditlog.create(
|
||||
data={
|
||||
"eventType": event_type,
|
||||
"userId": user_id,
|
||||
"clientId": client_id,
|
||||
"grantId": grant_id,
|
||||
"ipAddress": ip_address,
|
||||
"userAgent": user_agent,
|
||||
"details": json.loads(details_json), # type: ignore[arg-type]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
_oauth_service: Optional[OAuthService] = None
|
||||
|
||||
|
||||
def get_oauth_service() -> OAuthService:
|
||||
"""Get the singleton OAuth service instance."""
|
||||
global _oauth_service
|
||||
if _oauth_service is None:
|
||||
_oauth_service = OAuthService()
|
||||
return _oauth_service
|
||||
298
autogpt_platform/backend/backend/server/oauth/token_service.py
Normal file
298
autogpt_platform/backend/backend/server/oauth/token_service.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""
|
||||
JWT Token Service for OAuth 2.0 Provider.
|
||||
|
||||
Handles generation and validation of:
|
||||
- Access tokens (JWT)
|
||||
- Refresh tokens (opaque)
|
||||
- ID tokens (JWT, OIDC)
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
import jwt
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import (
|
||||
RSAPrivateKey,
|
||||
RSAPublicKey,
|
||||
generate_private_key,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
|
||||
class TokenClaims(BaseModel):
|
||||
"""Decoded token claims."""
|
||||
|
||||
iss: str # Issuer
|
||||
sub: str # Subject (user ID)
|
||||
aud: str # Audience (client ID)
|
||||
exp: int # Expiration timestamp
|
||||
iat: int # Issued at timestamp
|
||||
jti: str # JWT ID
|
||||
scope: str # Space-separated scopes
|
||||
client_id: str # Client ID
|
||||
|
||||
|
||||
class OAuthTokenService:
|
||||
"""
|
||||
Service for generating and validating OAuth tokens.
|
||||
|
||||
Uses RS256 (RSA with SHA-256) for JWT signing.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: Optional[Settings] = None):
|
||||
self._settings = settings or Settings()
|
||||
self._private_key: Optional[RSAPrivateKey] = None
|
||||
self._public_key: Optional[RSAPublicKey] = None
|
||||
self._algorithm = "RS256"
|
||||
|
||||
@property
|
||||
def issuer(self) -> str:
|
||||
"""Get the token issuer URL."""
|
||||
return self._settings.config.platform_base_url or "https://platform.agpt.co"
|
||||
|
||||
@property
|
||||
def key_id(self) -> str:
|
||||
"""Get the key ID for JWKS."""
|
||||
return self._settings.secrets.oauth_jwt_key_id or "default-key-id"
|
||||
|
||||
def _get_private_key(self) -> RSAPrivateKey:
|
||||
"""Load or generate the private key."""
|
||||
if self._private_key is not None:
|
||||
return self._private_key
|
||||
|
||||
key_pem = self._settings.secrets.oauth_jwt_private_key
|
||||
if key_pem:
|
||||
loaded_key = serialization.load_pem_private_key(
|
||||
key_pem.encode(), password=None
|
||||
)
|
||||
if not isinstance(loaded_key, RSAPrivateKey):
|
||||
raise ValueError("OAuth JWT private key must be RSA")
|
||||
self._private_key = loaded_key
|
||||
else:
|
||||
# Generate a key for development (should not be used in production)
|
||||
self._private_key = generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048,
|
||||
)
|
||||
return self._private_key
|
||||
|
||||
def _get_public_key(self) -> RSAPublicKey:
|
||||
"""Get the public key from the private key."""
|
||||
if self._public_key is not None:
|
||||
return self._public_key
|
||||
|
||||
key_pem = self._settings.secrets.oauth_jwt_public_key
|
||||
if key_pem:
|
||||
loaded_key = serialization.load_pem_public_key(key_pem.encode())
|
||||
if not isinstance(loaded_key, RSAPublicKey):
|
||||
raise ValueError("OAuth JWT public key must be RSA")
|
||||
self._public_key = loaded_key
|
||||
else:
|
||||
self._public_key = self._get_private_key().public_key()
|
||||
return self._public_key
|
||||
|
||||
def generate_access_token(
|
||||
self,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
scopes: list[str],
|
||||
expires_in: int = 3600,
|
||||
) -> tuple[str, datetime]:
|
||||
"""
|
||||
Generate a JWT access token.
|
||||
|
||||
Args:
|
||||
user_id: User ID (subject)
|
||||
client_id: Client ID (audience)
|
||||
scopes: List of granted scopes
|
||||
expires_in: Token lifetime in seconds
|
||||
|
||||
Returns:
|
||||
Tuple of (token string, expiration datetime)
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + timedelta(seconds=expires_in)
|
||||
|
||||
payload = {
|
||||
"iss": self.issuer,
|
||||
"sub": user_id,
|
||||
"aud": client_id,
|
||||
"exp": int(expires_at.timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": secrets.token_urlsafe(16),
|
||||
"scope": " ".join(scopes),
|
||||
"client_id": client_id,
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
self._get_private_key(),
|
||||
algorithm=self._algorithm,
|
||||
headers={"kid": self.key_id},
|
||||
)
|
||||
return token, expires_at
|
||||
|
||||
def generate_refresh_token(self) -> str:
|
||||
"""
|
||||
Generate an opaque refresh token.
|
||||
|
||||
Returns:
|
||||
URL-safe random token string
|
||||
"""
|
||||
return secrets.token_urlsafe(48)
|
||||
|
||||
def generate_id_token(
|
||||
self,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
email: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
nonce: Optional[str] = None,
|
||||
expires_in: int = 3600,
|
||||
) -> str:
|
||||
"""
|
||||
Generate an OIDC ID token.
|
||||
|
||||
Args:
|
||||
user_id: User ID (subject)
|
||||
client_id: Client ID (audience)
|
||||
email: User's email (optional)
|
||||
name: User's name (optional)
|
||||
nonce: OIDC nonce for replay protection (optional)
|
||||
expires_in: Token lifetime in seconds
|
||||
|
||||
Returns:
|
||||
JWT ID token string
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + timedelta(seconds=expires_in)
|
||||
|
||||
payload = {
|
||||
"iss": self.issuer,
|
||||
"sub": user_id,
|
||||
"aud": client_id,
|
||||
"exp": int(expires_at.timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"auth_time": int(now.timestamp()),
|
||||
}
|
||||
|
||||
if email:
|
||||
payload["email"] = email
|
||||
payload["email_verified"] = True
|
||||
if name:
|
||||
payload["name"] = name
|
||||
if nonce:
|
||||
payload["nonce"] = nonce
|
||||
|
||||
return jwt.encode(
|
||||
payload,
|
||||
self._get_private_key(),
|
||||
algorithm=self._algorithm,
|
||||
headers={"kid": self.key_id},
|
||||
)
|
||||
|
||||
def verify_access_token(
|
||||
self,
|
||||
token: str,
|
||||
expected_client_id: Optional[str] = None,
|
||||
) -> TokenClaims:
|
||||
"""
|
||||
Verify and decode a JWT access token.
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
expected_client_id: Expected client ID (audience)
|
||||
|
||||
Returns:
|
||||
Decoded token claims
|
||||
|
||||
Raises:
|
||||
jwt.ExpiredSignatureError: Token has expired
|
||||
jwt.InvalidTokenError: Token is invalid
|
||||
"""
|
||||
options = {}
|
||||
if expected_client_id:
|
||||
options["audience"] = expected_client_id
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
self._get_public_key(),
|
||||
algorithms=[self._algorithm],
|
||||
issuer=self.issuer,
|
||||
options={"verify_aud": bool(expected_client_id)},
|
||||
**options,
|
||||
)
|
||||
|
||||
return TokenClaims(
|
||||
iss=payload["iss"],
|
||||
sub=payload["sub"],
|
||||
aud=payload.get("aud", payload.get("client_id", "")),
|
||||
exp=payload["exp"],
|
||||
iat=payload["iat"],
|
||||
jti=payload["jti"],
|
||||
scope=payload.get("scope", ""),
|
||||
client_id=payload.get("client_id", payload.get("aud", "")),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def hash_token(token: str) -> str:
|
||||
"""
|
||||
Hash a token for secure storage.
|
||||
|
||||
Args:
|
||||
token: Token string to hash
|
||||
|
||||
Returns:
|
||||
SHA-256 hash of the token
|
||||
"""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
def get_jwks(self) -> dict:
|
||||
"""
|
||||
Get the JSON Web Key Set (JWKS) for public key distribution.
|
||||
|
||||
Returns:
|
||||
JWKS dictionary with public key(s)
|
||||
"""
|
||||
public_key = self._get_public_key()
|
||||
public_numbers = public_key.public_numbers()
|
||||
|
||||
# Convert to base64url encoding without padding
|
||||
def int_to_base64url(n: int, length: int) -> str:
|
||||
data = n.to_bytes(length, byteorder="big")
|
||||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
|
||||
|
||||
# RSA modulus and exponent
|
||||
n = int_to_base64url(public_numbers.n, (public_numbers.n.bit_length() + 7) // 8)
|
||||
e = int_to_base64url(public_numbers.e, 3)
|
||||
|
||||
return {
|
||||
"keys": [
|
||||
{
|
||||
"kty": "RSA",
|
||||
"use": "sig",
|
||||
"kid": self.key_id,
|
||||
"alg": self._algorithm,
|
||||
"n": n,
|
||||
"e": e,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
_token_service: Optional[OAuthTokenService] = None
|
||||
|
||||
|
||||
def get_token_service() -> OAuthTokenService:
|
||||
"""Get the singleton token service instance."""
|
||||
global _token_service
|
||||
if _token_service is None:
|
||||
_token_service = OAuthTokenService()
|
||||
return _token_service
|
||||
@@ -21,6 +21,7 @@ import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.integrations.connect_router
|
||||
import backend.server.routers.postmark.postmark
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.admin.credit_admin_routes
|
||||
@@ -44,6 +45,7 @@ from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.server.external.api import external_app
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.server.oauth import client_router, discovery_router, oauth_router
|
||||
from backend.server.utils.cors import build_cors_params
|
||||
from backend.util import json
|
||||
from backend.util.cloud_storage import shutdown_cloud_storage_handler
|
||||
@@ -300,6 +302,18 @@ app.include_router(
|
||||
|
||||
app.mount("/external-api", external_app)
|
||||
|
||||
# OAuth Provider routes
|
||||
app.include_router(oauth_router, tags=["oauth"], prefix="")
|
||||
app.include_router(discovery_router, tags=["oidc-discovery"], prefix="")
|
||||
app.include_router(client_router, tags=["oauth-clients"], prefix="")
|
||||
|
||||
# Integration Connect popup routes (for Credential Broker)
|
||||
app.include_router(
|
||||
backend.server.integrations.connect_router.connect_router,
|
||||
tags=["integration-connect"],
|
||||
prefix="",
|
||||
)
|
||||
|
||||
|
||||
@app.get(path="/health", tags=["health"], dependencies=[])
|
||||
async def health():
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import UTC, datetime
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
from prisma.types import ProfileCreateInput
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
||||
@@ -49,13 +50,13 @@ async def setup_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Create a test graph with agent input -> agent output
|
||||
@@ -172,13 +173,13 @@ async def setup_llm_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for LLM tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile for LLM tests",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Create test OpenAI credentials for the user
|
||||
@@ -332,13 +333,13 @@ async def setup_firecrawl_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for Firecrawl tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile for Firecrawl tests",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# NOTE: We deliberately do NOT create Firecrawl credentials for this user
|
||||
|
||||
@@ -802,18 +802,16 @@ async def add_store_agent_to_library(
|
||||
|
||||
# Create LibraryAgent entry
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data={
|
||||
"User": {"connect": {"id": user_id}},
|
||||
"AgentGraph": {
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
User={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {"id": graph.id, "version": graph.version}
|
||||
}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"settings": SafeJson(
|
||||
_initialize_graph_settings(graph_model).model_dump()
|
||||
),
|
||||
},
|
||||
isCreatedByUser=False,
|
||||
settings=SafeJson(_initialize_graph_settings(graph_model).model_dump()),
|
||||
),
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
|
||||
@@ -248,7 +248,9 @@ async def log_search_term(search_query: str):
|
||||
date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
try:
|
||||
await prisma.models.SearchTerms.prisma().create(
|
||||
data={"searchTerm": search_query, "createdDate": date}
|
||||
data=prisma.types.SearchTermsCreateInput(
|
||||
searchTerm=search_query, createdDate=date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# Fail silently here so that logging search terms doesn't break the app
|
||||
@@ -1430,13 +1432,10 @@ async def _approve_sub_agent(
|
||||
|
||||
# Create new version if no matching version found
|
||||
next_version = max((v.version for v in listing.Versions or []), default=0) + 1
|
||||
await prisma.models.StoreListingVersion.prisma(tx).create(
|
||||
data={
|
||||
**_create_sub_agent_version_data(sub_graph, heading, main_agent_name),
|
||||
"version": next_version,
|
||||
"storeListingId": listing.id,
|
||||
}
|
||||
)
|
||||
sub_agent_data = _create_sub_agent_version_data(sub_graph, heading, main_agent_name)
|
||||
sub_agent_data["version"] = next_version
|
||||
sub_agent_data["storeListingId"] = listing.id
|
||||
await prisma.models.StoreListingVersion.prisma(tx).create(data=sub_agent_data)
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": listing.id}, data={"hasApprovedVersion": True}
|
||||
)
|
||||
|
||||
228
autogpt_platform/backend/backend/util/rate_limiter.py
Normal file
228
autogpt_platform/backend/backend/util/rate_limiter.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Rate Limiting for External API.
|
||||
|
||||
Implements sliding window rate limiting using Redis for distributed systems.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitResult:
|
||||
"""Result of a rate limit check."""
|
||||
|
||||
allowed: bool
|
||||
remaining: int
|
||||
reset_at: float
|
||||
retry_after: Optional[float] = None
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
Redis-based sliding window rate limiter.
|
||||
|
||||
Supports multiple limit tiers (per-minute, per-hour, per-day).
|
||||
"""
|
||||
|
||||
def __init__(self, prefix: str = "ratelimit"):
|
||||
self.prefix = prefix
|
||||
|
||||
def _make_key(self, identifier: str, window: str) -> str:
|
||||
"""Create a Redis key for the rate limit counter."""
|
||||
return f"{self.prefix}:{identifier}:{window}"
|
||||
|
||||
async def check_and_increment(
|
||||
self,
|
||||
identifier: str,
|
||||
limits: dict[str, tuple[int, int]], # window_name -> (limit, window_seconds)
|
||||
) -> RateLimitResult:
|
||||
"""
|
||||
Check rate limits and increment counters if allowed.
|
||||
|
||||
Uses atomic increment-first approach to prevent race conditions:
|
||||
1. Increment all counters atomically
|
||||
2. Check if any limit exceeded
|
||||
3. If exceeded, decrement and return rate limit error
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier (e.g., client_id, client_id:user_id)
|
||||
limits: Dictionary of limit configurations
|
||||
e.g., {"minute": (60, 60), "hour": (1000, 3600)}
|
||||
|
||||
Returns:
|
||||
RateLimitResult with allowed status and remaining quota
|
||||
"""
|
||||
if not limits:
|
||||
# No limits configured, allow request
|
||||
return RateLimitResult(
|
||||
allowed=True,
|
||||
remaining=999999,
|
||||
reset_at=time.time() + 60,
|
||||
)
|
||||
|
||||
redis = await get_redis_async()
|
||||
current_time = time.time()
|
||||
|
||||
# Increment all counters atomically first
|
||||
incremented_keys: list[tuple[str, int, int, int]] = (
|
||||
[]
|
||||
) # (key, new_count, limit, window_seconds)
|
||||
|
||||
for window_name, (limit, window_seconds) in limits.items():
|
||||
key = self._make_key(identifier, window_name)
|
||||
|
||||
# Atomic increment
|
||||
new_count = await redis.incr(key)
|
||||
|
||||
# Set expiry if this is a new key
|
||||
if new_count == 1:
|
||||
await redis.expire(key, window_seconds)
|
||||
|
||||
incremented_keys.append((key, new_count, limit, window_seconds))
|
||||
|
||||
# Check if any limit exceeded
|
||||
for key, new_count, limit, window_seconds in incremented_keys:
|
||||
if new_count > limit:
|
||||
# Rate limit exceeded - decrement all counters we just incremented
|
||||
for decr_key, _, _, _ in incremented_keys:
|
||||
await redis.decr(decr_key)
|
||||
|
||||
ttl = await redis.ttl(key)
|
||||
reset_at = current_time + (ttl if ttl > 0 else window_seconds)
|
||||
|
||||
return RateLimitResult(
|
||||
allowed=False,
|
||||
remaining=0,
|
||||
reset_at=reset_at,
|
||||
retry_after=ttl if ttl > 0 else window_seconds,
|
||||
)
|
||||
|
||||
# All limits passed
|
||||
min_remaining = float("inf")
|
||||
earliest_reset = current_time
|
||||
|
||||
for key, new_count, limit, window_seconds in incremented_keys:
|
||||
remaining = max(0, limit - new_count)
|
||||
min_remaining = min(min_remaining, remaining)
|
||||
|
||||
ttl = await redis.ttl(key)
|
||||
reset_at = current_time + (ttl if ttl > 0 else window_seconds)
|
||||
earliest_reset = max(earliest_reset, reset_at)
|
||||
|
||||
return RateLimitResult(
|
||||
allowed=True,
|
||||
remaining=int(min_remaining),
|
||||
reset_at=earliest_reset,
|
||||
)
|
||||
|
||||
async def get_remaining(
|
||||
self,
|
||||
identifier: str,
|
||||
limits: dict[str, tuple[int, int]],
|
||||
) -> dict[str, int]:
|
||||
"""
|
||||
Get remaining quota for all windows without incrementing.
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier
|
||||
limits: Dictionary of limit configurations
|
||||
|
||||
Returns:
|
||||
Dictionary of remaining quota per window
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
remaining = {}
|
||||
|
||||
for window_name, (limit, _) in limits.items():
|
||||
key = self._make_key(identifier, window_name)
|
||||
count = await redis.get(key)
|
||||
current_count = int(count) if count else 0
|
||||
remaining[window_name] = max(0, limit - current_count)
|
||||
|
||||
return remaining
|
||||
|
||||
async def reset(self, identifier: str, window: Optional[str] = None) -> None:
|
||||
"""
|
||||
Reset rate limit counters.
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier
|
||||
window: Optional specific window to reset (resets all if None)
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
|
||||
if window:
|
||||
key = self._make_key(identifier, window)
|
||||
await redis.delete(key)
|
||||
else:
|
||||
# Delete known window keys instead of scanning
|
||||
# This avoids potentially slow scan operations with many keys
|
||||
known_windows = ["minute", "hour", "day"]
|
||||
keys_to_delete = [self._make_key(identifier, w) for w in known_windows]
|
||||
# Delete all in one call (Redis handles non-existent keys gracefully)
|
||||
if keys_to_delete:
|
||||
await redis.delete(*keys_to_delete)
|
||||
|
||||
|
||||
# Default rate limits for different endpoints
|
||||
DEFAULT_RATE_LIMITS = {
|
||||
# OAuth endpoints
|
||||
"oauth_authorize": {"minute": (30, 60)}, # 30/min per IP
|
||||
"oauth_token": {"minute": (20, 60)}, # 20/min per client
|
||||
"oauth_consent": {"minute": (20, 60)}, # 20/min per IP for consent submission
|
||||
# External API endpoints
|
||||
"api_execute": {
|
||||
"minute": (10, 60),
|
||||
"hour": (100, 3600),
|
||||
}, # 10/min, 100/hour per client+user
|
||||
"api_read": {
|
||||
"minute": (60, 60),
|
||||
"hour": (1000, 3600),
|
||||
}, # 60/min, 1000/hour per client+user
|
||||
}
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
_rate_limiter: Optional[RateLimiter] = None
|
||||
|
||||
|
||||
def get_rate_limiter() -> RateLimiter:
|
||||
"""Get the singleton rate limiter instance."""
|
||||
global _rate_limiter
|
||||
if _rate_limiter is None:
|
||||
_rate_limiter = RateLimiter()
|
||||
return _rate_limiter
|
||||
|
||||
|
||||
async def check_rate_limit(
|
||||
identifier: str,
|
||||
limit_type: str,
|
||||
) -> RateLimitResult:
|
||||
"""
|
||||
Convenience function to check rate limits.
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier for the rate limit
|
||||
limit_type: Type of limit from DEFAULT_RATE_LIMITS
|
||||
|
||||
Returns:
|
||||
RateLimitResult
|
||||
"""
|
||||
limits = DEFAULT_RATE_LIMITS.get(limit_type)
|
||||
if not limits:
|
||||
# No rate limit configured, allow
|
||||
return RateLimitResult(
|
||||
allowed=True,
|
||||
remaining=999999,
|
||||
reset_at=time.time() + 60,
|
||||
)
|
||||
|
||||
rate_limiter = get_rate_limiter()
|
||||
return await rate_limiter.check_and_increment(identifier, limits)
|
||||
@@ -651,6 +651,23 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
|
||||
ayrshare_api_key: str = Field(default="", description="Ayrshare API Key")
|
||||
ayrshare_jwt_key: str = Field(default="", description="Ayrshare private Key")
|
||||
|
||||
# OAuth Provider JWT keys
|
||||
oauth_jwt_private_key: str = Field(
|
||||
default="",
|
||||
description="RSA private key for signing OAuth tokens (PEM format). "
|
||||
"If not set, a development key will be auto-generated.",
|
||||
)
|
||||
oauth_jwt_public_key: str = Field(
|
||||
default="",
|
||||
description="RSA public key for verifying OAuth tokens (PEM format). "
|
||||
"If not set, derived from private key.",
|
||||
)
|
||||
oauth_jwt_key_id: str = Field(
|
||||
default="autogpt-oauth-key-1",
|
||||
description="Key ID (kid) for JWKS. Used to identify the signing key.",
|
||||
)
|
||||
|
||||
# Add more secret fields as needed
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
|
||||
43
autogpt_platform/backend/backend/util/time.py
Normal file
43
autogpt_platform/backend/backend/util/time.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Time utilities for the backend.
|
||||
|
||||
Common datetime operations used across the codebase.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
|
||||
def expiration_datetime(seconds: int) -> datetime:
|
||||
"""
|
||||
Calculate an expiration datetime from now.
|
||||
|
||||
Args:
|
||||
seconds: Number of seconds until expiration
|
||||
|
||||
Returns:
|
||||
Datetime when the item will expire (UTC)
|
||||
"""
|
||||
return datetime.now(timezone.utc) + timedelta(seconds=seconds)
|
||||
|
||||
|
||||
def is_expired(dt: datetime) -> bool:
|
||||
"""
|
||||
Check if a datetime has passed.
|
||||
|
||||
Args:
|
||||
dt: The datetime to check (should be timezone-aware)
|
||||
|
||||
Returns:
|
||||
True if the datetime is in the past
|
||||
"""
|
||||
return dt < datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
"""
|
||||
Get the current UTC time.
|
||||
|
||||
Returns:
|
||||
Current datetime in UTC
|
||||
"""
|
||||
return datetime.now(timezone.utc)
|
||||
46
autogpt_platform/backend/backend/util/url.py
Normal file
46
autogpt_platform/backend/backend/util/url.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
URL and domain validation utilities.
|
||||
|
||||
Common URL validation operations used across the codebase.
|
||||
"""
|
||||
|
||||
|
||||
def matches_domain_pattern(hostname: str, domain_pattern: str) -> bool:
|
||||
"""
|
||||
Check if a hostname matches a domain pattern.
|
||||
|
||||
Supports wildcard patterns (*.example.com) which match:
|
||||
- The base domain (example.com)
|
||||
- Any subdomain (sub.example.com, deep.sub.example.com)
|
||||
|
||||
Args:
|
||||
hostname: The hostname to check (e.g., "api.example.com")
|
||||
domain_pattern: The pattern to match against (e.g., "*.example.com" or "example.com")
|
||||
|
||||
Returns:
|
||||
True if the hostname matches the pattern
|
||||
"""
|
||||
hostname = hostname.lower()
|
||||
domain_pattern = domain_pattern.lower()
|
||||
|
||||
if domain_pattern.startswith("*."):
|
||||
# Wildcard domain - matches base and any subdomains
|
||||
base_domain = domain_pattern[2:]
|
||||
return hostname == base_domain or hostname.endswith("." + base_domain)
|
||||
|
||||
# Exact match
|
||||
return hostname == domain_pattern
|
||||
|
||||
|
||||
def hostname_matches_any_domain(hostname: str, allowed_domains: list[str]) -> bool:
|
||||
"""
|
||||
Check if a hostname matches any of the allowed domain patterns.
|
||||
|
||||
Args:
|
||||
hostname: The hostname to check
|
||||
allowed_domains: List of allowed domain patterns (supports wildcards)
|
||||
|
||||
Returns:
|
||||
True if the hostname matches any pattern
|
||||
"""
|
||||
return any(matches_domain_pattern(hostname, domain) for domain in allowed_domains)
|
||||
@@ -0,0 +1,249 @@
|
||||
-- CreateEnum
|
||||
CREATE TYPE "OAuthClientStatus" AS ENUM ('ACTIVE', 'SUSPENDED');
|
||||
|
||||
-- CreateEnum
|
||||
CREATE TYPE "CredentialGrantPermission" AS ENUM ('USE', 'DELETE');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "OAuthClient" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"clientId" TEXT NOT NULL,
|
||||
"clientSecretHash" TEXT,
|
||||
"clientSecretSalt" TEXT,
|
||||
"clientType" TEXT NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"logoUrl" TEXT,
|
||||
"homepageUrl" TEXT,
|
||||
"privacyPolicyUrl" TEXT,
|
||||
"termsOfServiceUrl" TEXT,
|
||||
"redirectUris" TEXT[],
|
||||
"allowedScopes" TEXT[],
|
||||
"webhookDomains" TEXT[],
|
||||
"requirePkce" BOOLEAN NOT NULL DEFAULT true,
|
||||
"tokenLifetimeSecs" INTEGER NOT NULL DEFAULT 3600,
|
||||
"refreshTokenLifetimeSecs" INTEGER NOT NULL DEFAULT 2592000,
|
||||
"status" "OAuthClientStatus" NOT NULL DEFAULT 'ACTIVE',
|
||||
"ownerId" TEXT NOT NULL,
|
||||
|
||||
CONSTRAINT "OAuthClient_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "OAuthAuthorization" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"clientId" TEXT NOT NULL,
|
||||
"scopes" TEXT[],
|
||||
"revokedAt" TIMESTAMP(3),
|
||||
|
||||
CONSTRAINT "OAuthAuthorization_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "OAuthAuthorizationCode" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"codeHash" TEXT NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"clientId" TEXT NOT NULL,
|
||||
"redirectUri" TEXT NOT NULL,
|
||||
"scopes" TEXT[],
|
||||
"nonce" TEXT,
|
||||
"codeChallenge" TEXT NOT NULL,
|
||||
"codeChallengeMethod" TEXT NOT NULL DEFAULT 'S256',
|
||||
"expiresAt" TIMESTAMP(3) NOT NULL,
|
||||
"usedAt" TIMESTAMP(3),
|
||||
|
||||
CONSTRAINT "OAuthAuthorizationCode_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "OAuthAccessToken" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"tokenHash" TEXT NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"clientId" TEXT NOT NULL,
|
||||
"scopes" TEXT[],
|
||||
"expiresAt" TIMESTAMP(3) NOT NULL,
|
||||
"revokedAt" TIMESTAMP(3),
|
||||
"lastUsedAt" TIMESTAMP(3),
|
||||
|
||||
CONSTRAINT "OAuthAccessToken_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "OAuthRefreshToken" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"tokenHash" TEXT NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"clientId" TEXT NOT NULL,
|
||||
"scopes" TEXT[],
|
||||
"expiresAt" TIMESTAMP(3) NOT NULL,
|
||||
"revokedAt" TIMESTAMP(3),
|
||||
|
||||
CONSTRAINT "OAuthRefreshToken_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "CredentialGrant" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"clientId" TEXT NOT NULL,
|
||||
"credentialId" TEXT NOT NULL,
|
||||
"provider" TEXT NOT NULL,
|
||||
"grantedScopes" TEXT[],
|
||||
"permissions" "CredentialGrantPermission"[],
|
||||
"expiresAt" TIMESTAMP(3),
|
||||
"revokedAt" TIMESTAMP(3),
|
||||
"lastUsedAt" TIMESTAMP(3),
|
||||
|
||||
CONSTRAINT "CredentialGrant_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "OAuthAuditLog" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"eventType" TEXT NOT NULL,
|
||||
"userId" TEXT,
|
||||
"clientId" TEXT,
|
||||
"grantId" TEXT,
|
||||
"ipAddress" TEXT,
|
||||
"userAgent" TEXT,
|
||||
"details" JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
CONSTRAINT "OAuthAuditLog_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "ExecutionWebhook" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"executionId" TEXT NOT NULL,
|
||||
"webhookUrl" TEXT NOT NULL,
|
||||
"clientId" TEXT NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"secret" TEXT,
|
||||
|
||||
CONSTRAINT "ExecutionWebhook_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "OAuthClient_clientId_key" ON "OAuthClient"("clientId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthClient_clientId_idx" ON "OAuthClient"("clientId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthClient_ownerId_idx" ON "OAuthClient"("ownerId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthClient_status_idx" ON "OAuthClient"("status");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAuthorization_userId_idx" ON "OAuthAuthorization"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAuthorization_clientId_idx" ON "OAuthAuthorization"("clientId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "OAuthAuthorization_userId_clientId_key" ON "OAuthAuthorization"("userId", "clientId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "OAuthAuthorizationCode_codeHash_key" ON "OAuthAuthorizationCode"("codeHash");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAuthorizationCode_codeHash_idx" ON "OAuthAuthorizationCode"("codeHash");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAuthorizationCode_expiresAt_idx" ON "OAuthAuthorizationCode"("expiresAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "OAuthAccessToken_tokenHash_key" ON "OAuthAccessToken"("tokenHash");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAccessToken_tokenHash_idx" ON "OAuthAccessToken"("tokenHash");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAccessToken_userId_clientId_idx" ON "OAuthAccessToken"("userId", "clientId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAccessToken_expiresAt_idx" ON "OAuthAccessToken"("expiresAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "OAuthRefreshToken_tokenHash_key" ON "OAuthRefreshToken"("tokenHash");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthRefreshToken_tokenHash_idx" ON "OAuthRefreshToken"("tokenHash");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthRefreshToken_expiresAt_idx" ON "OAuthRefreshToken"("expiresAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "CredentialGrant_userId_clientId_idx" ON "CredentialGrant"("userId", "clientId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "CredentialGrant_clientId_idx" ON "CredentialGrant"("clientId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "CredentialGrant_userId_clientId_credentialId_key" ON "CredentialGrant"("userId", "clientId", "credentialId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAuditLog_createdAt_idx" ON "OAuthAuditLog"("createdAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAuditLog_eventType_idx" ON "OAuthAuditLog"("eventType");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAuditLog_userId_idx" ON "OAuthAuditLog"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "OAuthAuditLog_clientId_idx" ON "OAuthAuditLog"("clientId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ExecutionWebhook_executionId_idx" ON "ExecutionWebhook"("executionId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ExecutionWebhook_clientId_idx" ON "ExecutionWebhook"("clientId");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OAuthClient" ADD CONSTRAINT "OAuthClient_ownerId_fkey" FOREIGN KEY ("ownerId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OAuthAuthorization" ADD CONSTRAINT "OAuthAuthorization_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OAuthAuthorization" ADD CONSTRAINT "OAuthAuthorization_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OAuthAuthorizationCode" ADD CONSTRAINT "OAuthAuthorizationCode_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OAuthAuthorizationCode" ADD CONSTRAINT "OAuthAuthorizationCode_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OAuthAccessToken" ADD CONSTRAINT "OAuthAccessToken_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OAuthAccessToken" ADD CONSTRAINT "OAuthAccessToken_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OAuthRefreshToken" ADD CONSTRAINT "OAuthRefreshToken_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OAuthRefreshToken" ADD CONSTRAINT "OAuthRefreshToken_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "CredentialGrant" ADD CONSTRAINT "CredentialGrant_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "CredentialGrant" ADD CONSTRAINT "CredentialGrant_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
@@ -0,0 +1,2 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "platform"."OAuthClient" ADD COLUMN "webhookSecret" TEXT;
|
||||
@@ -60,6 +60,14 @@ model User {
|
||||
IntegrationWebhooks IntegrationWebhook[]
|
||||
NotificationBatches UserNotificationBatch[]
|
||||
PendingHumanReviews PendingHumanReview[]
|
||||
|
||||
// OAuth Provider relations
|
||||
OAuthClientsOwned OAuthClient[] @relation("OAuthClientOwner")
|
||||
OAuthAuthorizations OAuthAuthorization[]
|
||||
OAuthAuthorizationCodes OAuthAuthorizationCode[]
|
||||
OAuthAccessTokens OAuthAccessToken[]
|
||||
OAuthRefreshTokens OAuthRefreshToken[]
|
||||
CredentialGrants CredentialGrant[]
|
||||
}
|
||||
|
||||
enum OnboardingStep {
|
||||
@@ -701,11 +709,11 @@ view StoreAgent {
|
||||
storeListingVersionId String
|
||||
updated_at DateTime
|
||||
|
||||
slug String
|
||||
agent_name String
|
||||
agent_video String?
|
||||
agent_output_demo String?
|
||||
agent_image String[]
|
||||
slug String
|
||||
agent_name String
|
||||
agent_video String?
|
||||
agent_output_demo String?
|
||||
agent_image String[]
|
||||
|
||||
featured Boolean @default(false)
|
||||
creator_username String?
|
||||
@@ -834,14 +842,14 @@ model StoreListingVersion {
|
||||
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
|
||||
|
||||
// Content fields
|
||||
name String
|
||||
subHeading String
|
||||
videoUrl String?
|
||||
agentOutputDemoUrl String?
|
||||
imageUrls String[]
|
||||
description String
|
||||
instructions String?
|
||||
categories String[]
|
||||
name String
|
||||
subHeading String
|
||||
videoUrl String?
|
||||
agentOutputDemoUrl String?
|
||||
imageUrls String[]
|
||||
description String
|
||||
instructions String?
|
||||
categories String[]
|
||||
|
||||
isFeatured Boolean @default(false)
|
||||
|
||||
@@ -961,3 +969,226 @@ enum APIKeyStatus {
|
||||
REVOKED
|
||||
SUSPENDED
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// OAuth Provider & Credential Broker Models
|
||||
// ============================================================
|
||||
|
||||
enum OAuthClientStatus {
|
||||
ACTIVE
|
||||
SUSPENDED
|
||||
}
|
||||
|
||||
enum CredentialGrantPermission {
|
||||
USE // Can use credential for agent execution
|
||||
DELETE // Can delete the credential
|
||||
}
|
||||
|
||||
// OAuth Client - Registered external applications
|
||||
model OAuthClient {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
// Client identification
|
||||
clientId String @unique // Public identifier (e.g., "app_abc123")
|
||||
clientSecretHash String? // Hashed (null for public clients)
|
||||
clientSecretSalt String?
|
||||
clientType String // "public" or "confidential"
|
||||
|
||||
// Metadata (shown on consent screen)
|
||||
name String
|
||||
description String?
|
||||
logoUrl String?
|
||||
homepageUrl String?
|
||||
privacyPolicyUrl String?
|
||||
termsOfServiceUrl String?
|
||||
|
||||
// Configuration
|
||||
redirectUris String[]
|
||||
allowedScopes String[]
|
||||
webhookDomains String[] // For webhook URL validation
|
||||
webhookSecret String? // Secret for HMAC signing webhooks
|
||||
|
||||
// Security
|
||||
requirePkce Boolean @default(true)
|
||||
tokenLifetimeSecs Int @default(3600)
|
||||
refreshTokenLifetimeSecs Int @default(2592000) // 30 days
|
||||
|
||||
// Status
|
||||
status OAuthClientStatus @default(ACTIVE)
|
||||
|
||||
// Owner
|
||||
ownerId String
|
||||
Owner User @relation("OAuthClientOwner", fields: [ownerId], references: [id], onDelete: Cascade)
|
||||
|
||||
// Relations
|
||||
Authorizations OAuthAuthorization[]
|
||||
AuthorizationCodes OAuthAuthorizationCode[]
|
||||
AccessTokens OAuthAccessToken[]
|
||||
RefreshTokens OAuthRefreshToken[]
|
||||
CredentialGrants CredentialGrant[]
|
||||
|
||||
@@index([clientId])
|
||||
@@index([ownerId])
|
||||
@@index([status])
|
||||
}
|
||||
|
||||
// OAuth Authorization - User consent record
|
||||
model OAuthAuthorization {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
clientId String
|
||||
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
|
||||
|
||||
scopes String[]
|
||||
revokedAt DateTime?
|
||||
|
||||
@@unique([userId, clientId])
|
||||
@@index([userId])
|
||||
@@index([clientId])
|
||||
}
|
||||
|
||||
// OAuth Authorization Code - Short-lived, single-use
|
||||
model OAuthAuthorizationCode {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
codeHash String @unique
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
clientId String
|
||||
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
|
||||
|
||||
redirectUri String
|
||||
scopes String[]
|
||||
nonce String? // OIDC nonce
|
||||
|
||||
// PKCE
|
||||
codeChallenge String
|
||||
codeChallengeMethod String @default("S256")
|
||||
|
||||
expiresAt DateTime // 10 minutes
|
||||
usedAt DateTime?
|
||||
|
||||
@@index([codeHash])
|
||||
@@index([expiresAt])
|
||||
}
|
||||
|
||||
// OAuth Access Token
|
||||
model OAuthAccessToken {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
tokenHash String @unique // SHA256 of token
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
clientId String
|
||||
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
|
||||
|
||||
scopes String[]
|
||||
expiresAt DateTime
|
||||
revokedAt DateTime?
|
||||
lastUsedAt DateTime?
|
||||
|
||||
@@index([tokenHash])
|
||||
@@index([userId, clientId])
|
||||
@@index([expiresAt])
|
||||
}
|
||||
|
||||
// OAuth Refresh Token
|
||||
model OAuthRefreshToken {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
tokenHash String @unique
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
clientId String
|
||||
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
|
||||
|
||||
scopes String[]
|
||||
expiresAt DateTime
|
||||
revokedAt DateTime?
|
||||
|
||||
@@index([tokenHash])
|
||||
@@index([expiresAt])
|
||||
}
|
||||
|
||||
// Credential Grant - Links external app to user's credential with scoped access
|
||||
model CredentialGrant {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
clientId String
|
||||
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
|
||||
|
||||
credentialId String // Reference to credential in User.integrations
|
||||
provider String
|
||||
|
||||
// Fine-grained integration scopes (e.g., "google:gmail.readonly")
|
||||
grantedScopes String[]
|
||||
|
||||
// Permissions for the credential itself
|
||||
permissions CredentialGrantPermission[]
|
||||
|
||||
expiresAt DateTime?
|
||||
revokedAt DateTime?
|
||||
lastUsedAt DateTime?
|
||||
|
||||
@@unique([userId, clientId, credentialId])
|
||||
@@index([userId, clientId])
|
||||
@@index([clientId])
|
||||
}
|
||||
|
||||
// OAuth Audit Log
|
||||
model OAuthAuditLog {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
eventType String // e.g., "token.issued", "grant.created"
|
||||
|
||||
userId String?
|
||||
clientId String?
|
||||
grantId String?
|
||||
|
||||
ipAddress String?
|
||||
userAgent String?
|
||||
|
||||
details Json @default("{}")
|
||||
|
||||
@@index([createdAt])
|
||||
@@index([eventType])
|
||||
@@index([userId])
|
||||
@@index([clientId])
|
||||
}
|
||||
|
||||
// Execution Webhook - Webhook registration for external API executions
|
||||
model ExecutionWebhook {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
executionId String // The graph execution ID
|
||||
webhookUrl String // URL to send notifications to
|
||||
clientId String // The OAuth client database ID
|
||||
userId String // The user who started the execution
|
||||
secret String? // Optional webhook secret for HMAC signing
|
||||
|
||||
@@index([executionId])
|
||||
@@index([clientId])
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import random
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from faker import Faker
|
||||
from prisma.types import AgentBlockCreateInput
|
||||
|
||||
from backend.data.api_key import create_api_key
|
||||
from backend.data.credit import get_user_credit_model
|
||||
@@ -177,12 +178,12 @@ class TestDataCreator:
|
||||
for block in blocks_to_create:
|
||||
try:
|
||||
await prisma.agentblock.create(
|
||||
data={
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"inputSchema": "{}",
|
||||
"outputSchema": "{}",
|
||||
}
|
||||
data=AgentBlockCreateInput(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
inputSchema="{}",
|
||||
outputSchema="{}",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error creating block {block.name}: {e}")
|
||||
|
||||
@@ -30,13 +30,19 @@ from prisma.types import (
|
||||
AgentGraphCreateInput,
|
||||
AgentNodeCreateInput,
|
||||
AgentNodeLinkCreateInput,
|
||||
AgentPresetCreateInput,
|
||||
AnalyticsDetailsCreateInput,
|
||||
AnalyticsMetricsCreateInput,
|
||||
APIKeyCreateInput,
|
||||
CreditTransactionCreateInput,
|
||||
IntegrationWebhookCreateInput,
|
||||
LibraryAgentCreateInput,
|
||||
ProfileCreateInput,
|
||||
StoreListingCreateInput,
|
||||
StoreListingReviewCreateInput,
|
||||
StoreListingVersionCreateInput,
|
||||
UserCreateInput,
|
||||
UserOnboardingCreateInput,
|
||||
)
|
||||
|
||||
faker = Faker()
|
||||
@@ -172,14 +178,14 @@ async def main():
|
||||
for _ in range(num_presets): # Create 1 AgentPreset per user
|
||||
graph = random.choice(agent_graphs)
|
||||
preset = await db.agentpreset.create(
|
||||
data={
|
||||
"name": faker.sentence(nb_words=3),
|
||||
"description": faker.text(max_nb_chars=200),
|
||||
"userId": user.id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"isActive": True,
|
||||
}
|
||||
data=AgentPresetCreateInput(
|
||||
name=faker.sentence(nb_words=3),
|
||||
description=faker.text(max_nb_chars=200),
|
||||
userId=user.id,
|
||||
agentGraphId=graph.id,
|
||||
agentGraphVersion=graph.version,
|
||||
isActive=True,
|
||||
)
|
||||
)
|
||||
agent_presets.append(preset)
|
||||
|
||||
@@ -220,18 +226,18 @@ async def main():
|
||||
)
|
||||
|
||||
library_agent = await db.libraryagent.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"creatorId": creator_profile.id if creator_profile else None,
|
||||
"imageUrl": get_image() if random.random() < 0.5 else None,
|
||||
"useGraphIsActiveVersion": random.choice([True, False]),
|
||||
"isFavorite": random.choice([True, False]),
|
||||
"isCreatedByUser": random.choice([True, False]),
|
||||
"isArchived": random.choice([True, False]),
|
||||
"isDeleted": random.choice([True, False]),
|
||||
}
|
||||
data=LibraryAgentCreateInput(
|
||||
userId=user.id,
|
||||
agentGraphId=graph.id,
|
||||
agentGraphVersion=graph.version,
|
||||
creatorId=creator_profile.id if creator_profile else None,
|
||||
imageUrl=get_image() if random.random() < 0.5 else None,
|
||||
useGraphIsActiveVersion=random.choice([True, False]),
|
||||
isFavorite=random.choice([True, False]),
|
||||
isCreatedByUser=random.choice([True, False]),
|
||||
isArchived=random.choice([True, False]),
|
||||
isDeleted=random.choice([True, False]),
|
||||
)
|
||||
)
|
||||
library_agents.append(library_agent)
|
||||
|
||||
@@ -392,13 +398,13 @@ async def main():
|
||||
user = random.choice(users)
|
||||
slug = faker.slug()
|
||||
listing = await db.storelisting.create(
|
||||
data={
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"owningUserId": user.id,
|
||||
"hasApprovedVersion": random.choice([True, False]),
|
||||
"slug": slug,
|
||||
}
|
||||
data=StoreListingCreateInput(
|
||||
agentGraphId=graph.id,
|
||||
agentGraphVersion=graph.version,
|
||||
owningUserId=user.id,
|
||||
hasApprovedVersion=random.choice([True, False]),
|
||||
slug=slug,
|
||||
)
|
||||
)
|
||||
store_listings.append(listing)
|
||||
|
||||
@@ -408,26 +414,26 @@ async def main():
|
||||
for listing in store_listings:
|
||||
graph = [g for g in agent_graphs if g.id == listing.agentGraphId][0]
|
||||
version = await db.storelistingversion.create(
|
||||
data={
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"name": graph.name or faker.sentence(nb_words=3),
|
||||
"subHeading": faker.sentence(),
|
||||
"videoUrl": get_video_url() if random.random() < 0.3 else None,
|
||||
"imageUrls": [get_image() for _ in range(3)],
|
||||
"description": faker.text(),
|
||||
"categories": [faker.word() for _ in range(3)],
|
||||
"isFeatured": random.choice([True, False]),
|
||||
"isAvailable": True,
|
||||
"storeListingId": listing.id,
|
||||
"submissionStatus": random.choice(
|
||||
data=StoreListingVersionCreateInput(
|
||||
agentGraphId=graph.id,
|
||||
agentGraphVersion=graph.version,
|
||||
name=graph.name or faker.sentence(nb_words=3),
|
||||
subHeading=faker.sentence(),
|
||||
videoUrl=get_video_url() if random.random() < 0.3 else None,
|
||||
imageUrls=[get_image() for _ in range(3)],
|
||||
description=faker.text(),
|
||||
categories=[faker.word() for _ in range(3)],
|
||||
isFeatured=random.choice([True, False]),
|
||||
isAvailable=True,
|
||||
storeListingId=listing.id,
|
||||
submissionStatus=random.choice(
|
||||
[
|
||||
prisma.enums.SubmissionStatus.PENDING,
|
||||
prisma.enums.SubmissionStatus.APPROVED,
|
||||
prisma.enums.SubmissionStatus.REJECTED,
|
||||
]
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
store_listing_versions.append(version)
|
||||
|
||||
@@ -469,51 +475,47 @@ async def main():
|
||||
|
||||
try:
|
||||
await db.useronboarding.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"completedSteps": completed_steps,
|
||||
"walletShown": random.choice([True, False]),
|
||||
"notified": (
|
||||
data=UserOnboardingCreateInput(
|
||||
userId=user.id,
|
||||
completedSteps=completed_steps,
|
||||
walletShown=random.choice([True, False]),
|
||||
notified=(
|
||||
random.sample(completed_steps, k=min(3, len(completed_steps)))
|
||||
if completed_steps
|
||||
else []
|
||||
),
|
||||
"rewardedFor": (
|
||||
rewardedFor=(
|
||||
random.sample(completed_steps, k=min(2, len(completed_steps)))
|
||||
if completed_steps
|
||||
else []
|
||||
),
|
||||
"usageReason": (
|
||||
usageReason=(
|
||||
random.choice(["personal", "business", "research", "learning"])
|
||||
if random.random() < 0.7
|
||||
else None
|
||||
),
|
||||
"integrations": random.sample(
|
||||
integrations=random.sample(
|
||||
["github", "google", "discord", "slack"], k=random.randint(0, 2)
|
||||
),
|
||||
"otherIntegrations": (
|
||||
faker.word() if random.random() < 0.2 else None
|
||||
),
|
||||
"selectedStoreListingVersionId": (
|
||||
otherIntegrations=(faker.word() if random.random() < 0.2 else None),
|
||||
selectedStoreListingVersionId=(
|
||||
random.choice(store_listing_versions).id
|
||||
if store_listing_versions and random.random() < 0.5
|
||||
else None
|
||||
),
|
||||
"onboardingAgentExecutionId": (
|
||||
onboardingAgentExecutionId=(
|
||||
random.choice(agent_graph_executions).id
|
||||
if agent_graph_executions and random.random() < 0.3
|
||||
else None
|
||||
),
|
||||
"agentRuns": random.randint(0, 10),
|
||||
}
|
||||
agentRuns=random.randint(0, 10),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error creating onboarding for user {user.id}: {e}")
|
||||
# Try simpler version
|
||||
await db.useronboarding.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
}
|
||||
data=UserOnboardingCreateInput(userId=user.id)
|
||||
)
|
||||
|
||||
# Insert IntegrationWebhooks for some users
|
||||
@@ -544,20 +546,20 @@ async def main():
|
||||
for user in users:
|
||||
api_key = APIKeySmith().generate_key()
|
||||
await db.apikey.create(
|
||||
data={
|
||||
"name": faker.word(),
|
||||
"head": api_key.head,
|
||||
"tail": api_key.tail,
|
||||
"hash": api_key.hash,
|
||||
"salt": api_key.salt,
|
||||
"status": prisma.enums.APIKeyStatus.ACTIVE,
|
||||
"permissions": [
|
||||
data=APIKeyCreateInput(
|
||||
name=faker.word(),
|
||||
head=api_key.head,
|
||||
tail=api_key.tail,
|
||||
hash=api_key.hash,
|
||||
salt=api_key.salt,
|
||||
status=prisma.enums.APIKeyStatus.ACTIVE,
|
||||
permissions=[
|
||||
prisma.enums.APIKeyPermission.EXECUTE_GRAPH,
|
||||
prisma.enums.APIKeyPermission.READ_GRAPH,
|
||||
],
|
||||
"description": faker.text(),
|
||||
"userId": user.id,
|
||||
}
|
||||
description=faker.text(),
|
||||
userId=user.id,
|
||||
)
|
||||
)
|
||||
|
||||
# Refresh materialized views
|
||||
|
||||
@@ -16,6 +16,7 @@ from datetime import datetime, timedelta
|
||||
import prisma.enums
|
||||
from faker import Faker
|
||||
from prisma import Json, Prisma
|
||||
from prisma.types import CreditTransactionCreateInput, StoreListingReviewCreateInput
|
||||
|
||||
faker = Faker()
|
||||
|
||||
@@ -166,16 +167,16 @@ async def main():
|
||||
score = random.choices([1, 2, 3, 4, 5], weights=[5, 10, 20, 40, 25])[0]
|
||||
|
||||
await db.storelistingreview.create(
|
||||
data={
|
||||
"storeListingVersionId": version.id,
|
||||
"reviewByUserId": reviewer.id,
|
||||
"score": score,
|
||||
"comments": (
|
||||
data=StoreListingReviewCreateInput(
|
||||
storeListingVersionId=version.id,
|
||||
reviewByUserId=reviewer.id,
|
||||
score=score,
|
||||
comments=(
|
||||
faker.text(max_nb_chars=200)
|
||||
if random.random() < 0.7
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
new_reviews_count += 1
|
||||
|
||||
@@ -244,17 +245,17 @@ async def main():
|
||||
)
|
||||
|
||||
await db.credittransaction.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"amount": amount,
|
||||
"type": transaction_type,
|
||||
"metadata": Json(
|
||||
data=CreditTransactionCreateInput(
|
||||
userId=user.id,
|
||||
amount=amount,
|
||||
type=transaction_type,
|
||||
metadata=Json(
|
||||
{
|
||||
"source": "test_updater",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
transaction_count += 1
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ import { shouldShowOnboarding } from "@/app/api/helpers";
|
||||
export async function GET(request: Request) {
|
||||
const { searchParams, origin } = new URL(request.url);
|
||||
const code = searchParams.get("code");
|
||||
const oauthSession = searchParams.get("oauth_session");
|
||||
const connectSession = searchParams.get("connect_session");
|
||||
|
||||
let next = "/marketplace";
|
||||
|
||||
@@ -25,6 +27,22 @@ export async function GET(request: Request) {
|
||||
const api = new BackendAPI();
|
||||
await api.createUser();
|
||||
|
||||
// Handle oauth_session redirect - resume OAuth flow after login
|
||||
// Redirect to a frontend page that will handle the OAuth resume with proper auth
|
||||
if (oauthSession) {
|
||||
return NextResponse.redirect(
|
||||
`${origin}/auth/oauth-resume?session_id=${encodeURIComponent(oauthSession)}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Handle connect_session redirect - resume connect flow after login
|
||||
// Redirect to a frontend page that will handle the connect resume with proper auth
|
||||
if (connectSession) {
|
||||
return NextResponse.redirect(
|
||||
`${origin}/auth/connect-resume?session_id=${encodeURIComponent(connectSession)}`,
|
||||
);
|
||||
}
|
||||
|
||||
if (await shouldShowOnboarding()) {
|
||||
next = "/onboarding";
|
||||
revalidatePath("/onboarding", "layout");
|
||||
|
||||
@@ -0,0 +1,400 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useState, useRef, useCallback } from "react";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { getWebSocketToken } from "@/lib/supabase/actions";
|
||||
|
||||
// Module-level flag to prevent duplicate requests across React StrictMode re-renders
|
||||
const attemptedSessions = new Set<string>();
|
||||
|
||||
interface ScopeInfo {
|
||||
scope: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
interface CredentialInfo {
|
||||
id: string;
|
||||
title: string;
|
||||
username: string;
|
||||
}
|
||||
|
||||
interface ClientInfo {
|
||||
name: string;
|
||||
logo_url: string | null;
|
||||
}
|
||||
|
||||
interface ConnectData {
|
||||
connect_token: string;
|
||||
client: ClientInfo;
|
||||
provider: string;
|
||||
scopes: ScopeInfo[];
|
||||
credentials: CredentialInfo[];
|
||||
action_url: string;
|
||||
}
|
||||
|
||||
interface ErrorData {
|
||||
error: string;
|
||||
error_description: string;
|
||||
}
|
||||
|
||||
type ResumeResponse = ConnectData | ErrorData;
|
||||
|
||||
function isConnectData(data: ResumeResponse): data is ConnectData {
|
||||
return "connect_token" in data;
|
||||
}
|
||||
|
||||
function isErrorData(data: ResumeResponse): data is ErrorData {
|
||||
return "error" in data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Connect Consent Form Component
|
||||
*
|
||||
* Renders a proper React component for the integration connect consent form
|
||||
*/
|
||||
function ConnectForm({
|
||||
client,
|
||||
provider,
|
||||
scopes,
|
||||
credentials,
|
||||
connectToken,
|
||||
actionUrl,
|
||||
}: {
|
||||
client: ClientInfo;
|
||||
provider: string;
|
||||
scopes: ScopeInfo[];
|
||||
credentials: CredentialInfo[];
|
||||
connectToken: string;
|
||||
actionUrl: string;
|
||||
}) {
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [selectedCredential, setSelectedCredential] = useState<string>(
|
||||
credentials.length > 0 ? credentials[0].id : "",
|
||||
);
|
||||
|
||||
const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
|
||||
const backendOrigin = backendUrl
|
||||
? new URL(backendUrl).origin
|
||||
: "http://localhost:8006";
|
||||
|
||||
const fullActionUrl = `${backendOrigin}${actionUrl}`;
|
||||
|
||||
function handleSubmit() {
|
||||
setIsSubmitting(true);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800 p-5">
|
||||
<div className="w-full max-w-md rounded-2xl bg-zinc-800 p-8 shadow-2xl">
|
||||
{/* Header */}
|
||||
<div className="mb-6 text-center">
|
||||
<h1 className="text-xl font-semibold text-zinc-100">
|
||||
Connect{" "}
|
||||
<span className="rounded bg-zinc-700 px-2 py-1 text-sm capitalize">
|
||||
{provider}
|
||||
</span>
|
||||
</h1>
|
||||
<p className="mt-2 text-sm text-zinc-400">
|
||||
<span className="font-semibold text-cyan-400">{client.name}</span>{" "}
|
||||
wants to use your {provider} integration
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Divider */}
|
||||
<div className="my-6 h-px bg-zinc-700" />
|
||||
|
||||
{/* Scopes Section */}
|
||||
<div className="mb-6">
|
||||
<h2 className="mb-4 text-sm font-medium text-zinc-400">
|
||||
This will allow {client.name} to:
|
||||
</h2>
|
||||
<div className="space-y-2">
|
||||
{scopes.map((scope) => (
|
||||
<div key={scope.scope} className="flex items-start gap-2 py-2">
|
||||
<span className="flex-shrink-0 text-cyan-400">✓</span>
|
||||
<span className="text-sm text-zinc-300">
|
||||
{scope.description}
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Divider */}
|
||||
<div className="my-6 h-px bg-zinc-700" />
|
||||
|
||||
{/* Form */}
|
||||
<form method="POST" action={fullActionUrl} onSubmit={handleSubmit}>
|
||||
<input type="hidden" name="connect_token" value={connectToken} />
|
||||
|
||||
{/* Existing credentials selection */}
|
||||
{credentials.length > 0 && (
|
||||
<>
|
||||
<h3 className="mb-3 text-sm font-medium text-zinc-400">
|
||||
Select an existing credential:
|
||||
</h3>
|
||||
<div className="mb-4 space-y-2">
|
||||
{credentials.map((cred) => (
|
||||
<label
|
||||
key={cred.id}
|
||||
className={`flex cursor-pointer items-center gap-3 rounded-lg border p-3 transition-colors ${
|
||||
selectedCredential === cred.id
|
||||
? "border-cyan-400 bg-cyan-400/10"
|
||||
: "border-zinc-700 hover:border-cyan-400/50"
|
||||
}`}
|
||||
>
|
||||
<input
|
||||
type="radio"
|
||||
name="credential_id"
|
||||
value={cred.id}
|
||||
checked={selectedCredential === cred.id}
|
||||
onChange={() => setSelectedCredential(cred.id)}
|
||||
className="hidden"
|
||||
/>
|
||||
<div>
|
||||
<div className="text-sm font-medium text-zinc-200">
|
||||
{cred.title}
|
||||
</div>
|
||||
{cred.username && (
|
||||
<div className="text-xs text-zinc-500">
|
||||
{cred.username}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</label>
|
||||
))}
|
||||
</div>
|
||||
<div className="my-4 h-px bg-zinc-700" />
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Connect new account */}
|
||||
<div className="mb-4">
|
||||
{credentials.length > 0 ? (
|
||||
<h3 className="mb-3 text-sm font-medium text-zinc-400">
|
||||
Or connect a new account:
|
||||
</h3>
|
||||
) : (
|
||||
<p className="mb-3 text-sm text-zinc-400">
|
||||
You don't have any {provider} credentials yet.
|
||||
</p>
|
||||
)}
|
||||
<button
|
||||
type="submit"
|
||||
name="action"
|
||||
value="connect_new"
|
||||
disabled={isSubmitting}
|
||||
className="w-full rounded-lg bg-blue-500 px-6 py-3 text-sm font-medium text-white transition-colors hover:bg-blue-400 disabled:cursor-not-allowed disabled:opacity-50"
|
||||
>
|
||||
Connect {provider.charAt(0).toUpperCase() + provider.slice(1)}{" "}
|
||||
Account
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Action buttons */}
|
||||
<div className="flex gap-3">
|
||||
<button
|
||||
type="submit"
|
||||
name="action"
|
||||
value="deny"
|
||||
disabled={isSubmitting}
|
||||
className="flex-1 rounded-lg bg-zinc-700 px-6 py-3 text-sm font-medium text-zinc-200 transition-colors hover:bg-zinc-600 disabled:cursor-not-allowed disabled:opacity-50"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
{credentials.length > 0 && (
|
||||
<button
|
||||
type="submit"
|
||||
name="action"
|
||||
value="approve"
|
||||
disabled={isSubmitting}
|
||||
className="flex-1 rounded-lg bg-cyan-400 px-6 py-3 text-sm font-medium text-slate-900 transition-colors hover:bg-cyan-300 disabled:cursor-not-allowed disabled:opacity-50"
|
||||
>
|
||||
{isSubmitting ? "Approving..." : "Approve"}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Connect Resume Page
|
||||
*
|
||||
* This page handles resuming the integration connect flow after a user logs in.
|
||||
* It fetches the connect data from the backend via JSON API and renders the consent form.
|
||||
*/
|
||||
export default function ConnectResumePage() {
|
||||
const searchParams = useSearchParams();
|
||||
const sessionId = searchParams.get("session_id");
|
||||
const { isUserLoading, refreshSession } = useSupabase();
|
||||
|
||||
const [connectData, setConnectData] = useState<ConnectData | null>(null);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const retryCountRef = useRef(0);
|
||||
const maxRetries = 5;
|
||||
|
||||
const resumeConnectFlow = useCallback(async () => {
|
||||
if (!sessionId) {
|
||||
setError(
|
||||
"Missing session ID. Please start the connection process again.",
|
||||
);
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (attemptedSessions.has(sessionId)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (isUserLoading) {
|
||||
return;
|
||||
}
|
||||
|
||||
attemptedSessions.add(sessionId);
|
||||
|
||||
try {
|
||||
let tokenResult = await getWebSocketToken();
|
||||
let accessToken = tokenResult.token;
|
||||
|
||||
while (!accessToken && retryCountRef.current < maxRetries) {
|
||||
retryCountRef.current += 1;
|
||||
console.log(
|
||||
`Retrying to get access token (attempt ${retryCountRef.current}/${maxRetries})...`,
|
||||
);
|
||||
await refreshSession();
|
||||
await new Promise((resolve) => setTimeout(resolve, 1000));
|
||||
tokenResult = await getWebSocketToken();
|
||||
accessToken = tokenResult.token;
|
||||
}
|
||||
|
||||
if (!accessToken) {
|
||||
setError(
|
||||
"Unable to retrieve authentication token. Please log in again.",
|
||||
);
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
|
||||
if (!backendUrl) {
|
||||
setError("Backend URL not configured.");
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
let backendOrigin: string;
|
||||
try {
|
||||
const url = new URL(backendUrl);
|
||||
backendOrigin = url.origin;
|
||||
} catch {
|
||||
setError("Invalid backend URL configuration.");
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await fetch(
|
||||
`${backendOrigin}/connect/resume?session_id=${encodeURIComponent(sessionId)}`,
|
||||
{
|
||||
method: "GET",
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
Accept: "application/json",
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
const data: ResumeResponse = await response.json();
|
||||
|
||||
if (!response.ok) {
|
||||
if (isErrorData(data)) {
|
||||
setError(data.error_description || data.error);
|
||||
} else {
|
||||
setError(`Connection failed (${response.status}). Please try again.`);
|
||||
}
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (isConnectData(data)) {
|
||||
setConnectData(data);
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
setError("Unexpected response from server. Please try again.");
|
||||
setIsLoading(false);
|
||||
} catch (err) {
|
||||
console.error("Connect resume error:", err);
|
||||
setError(
|
||||
"An error occurred while resuming connection. Please try again.",
|
||||
);
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, [sessionId, isUserLoading, refreshSession]);
|
||||
|
||||
useEffect(() => {
|
||||
resumeConnectFlow();
|
||||
}, [resumeConnectFlow]);
|
||||
|
||||
if (isLoading || isUserLoading) {
|
||||
return (
|
||||
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800">
|
||||
<div className="text-center">
|
||||
<div className="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-zinc-600 border-t-cyan-400"></div>
|
||||
<p className="text-zinc-400">Resuming connection...</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800">
|
||||
<div className="mx-auto max-w-md rounded-2xl bg-zinc-800 p-8 text-center shadow-2xl">
|
||||
<div className="mx-auto mb-4 h-16 w-16 text-red-500">
|
||||
<svg
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
>
|
||||
<circle cx="12" cy="12" r="10" />
|
||||
<line x1="15" y1="9" x2="9" y2="15" />
|
||||
<line x1="9" y1="9" x2="15" y2="15" />
|
||||
</svg>
|
||||
</div>
|
||||
<h1 className="mb-2 text-xl font-semibold text-red-400">
|
||||
Connection Error
|
||||
</h1>
|
||||
<p className="mb-6 text-zinc-400">{error}</p>
|
||||
<button
|
||||
onClick={() => window.close()}
|
||||
className="rounded-lg bg-zinc-700 px-6 py-3 text-sm font-medium text-zinc-200 transition-colors hover:bg-zinc-600"
|
||||
>
|
||||
Close
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (connectData) {
|
||||
return (
|
||||
<ConnectForm
|
||||
client={connectData.client}
|
||||
provider={connectData.provider}
|
||||
scopes={connectData.scopes}
|
||||
credentials={connectData.credentials}
|
||||
connectToken={connectData.connect_token}
|
||||
actionUrl={connectData.action_url}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
@@ -22,20 +22,28 @@ export async function GET(request: Request) {
|
||||
|
||||
console.debug("Sending message to opener:", message);
|
||||
|
||||
// Escape JSON to prevent XSS attacks via </script> injection
|
||||
const safeJson = JSON.stringify(message)
|
||||
.replace(/</g, "\\u003c")
|
||||
.replace(/>/g, "\\u003e");
|
||||
|
||||
// Return a response with the message as JSON and a script to close the window
|
||||
return new NextResponse(
|
||||
`
|
||||
<html>
|
||||
<body>
|
||||
<script>
|
||||
window.opener.postMessage(${JSON.stringify(message)});
|
||||
window.close();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
`,
|
||||
`<!DOCTYPE html>
|
||||
<html>
|
||||
<body>
|
||||
<script>
|
||||
window.opener.postMessage(${safeJson}, '*');
|
||||
window.close();
|
||||
</script>
|
||||
</body>
|
||||
</html>`,
|
||||
{
|
||||
headers: { "Content-Type": "text/html" },
|
||||
headers: {
|
||||
"Content-Type": "text/html",
|
||||
"Content-Security-Policy":
|
||||
"default-src 'none'; script-src 'unsafe-inline'",
|
||||
},
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,399 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useState, useRef, useCallback } from "react";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { getWebSocketToken } from "@/lib/supabase/actions";
|
||||
|
||||
// Module-level flag to prevent duplicate requests across React StrictMode re-renders
|
||||
// This is keyed by session_id to allow different sessions
|
||||
const attemptedSessions = new Set<string>();
|
||||
|
||||
interface ScopeInfo {
|
||||
scope: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
interface ClientInfo {
|
||||
name: string;
|
||||
logo_url: string | null;
|
||||
privacy_policy_url: string | null;
|
||||
terms_url: string | null;
|
||||
}
|
||||
|
||||
interface ConsentData {
|
||||
needs_consent: true;
|
||||
consent_token: string;
|
||||
client: ClientInfo;
|
||||
scopes: ScopeInfo[];
|
||||
action_url: string;
|
||||
}
|
||||
|
||||
interface RedirectData {
|
||||
redirect_url: string;
|
||||
needs_consent: false;
|
||||
}
|
||||
|
||||
interface ErrorData {
|
||||
error: string;
|
||||
error_description: string;
|
||||
redirect_url?: string;
|
||||
}
|
||||
|
||||
type ResumeResponse = ConsentData | RedirectData | ErrorData;
|
||||
|
||||
function isConsentData(data: ResumeResponse): data is ConsentData {
|
||||
return "needs_consent" in data && data.needs_consent === true;
|
||||
}
|
||||
|
||||
function isRedirectData(data: ResumeResponse): data is RedirectData {
|
||||
return "redirect_url" in data && !("error" in data);
|
||||
}
|
||||
|
||||
function isErrorData(data: ResumeResponse): data is ErrorData {
|
||||
return "error" in data;
|
||||
}
|
||||
|
||||
/**
|
||||
* OAuth Consent Form Component
|
||||
*
|
||||
* Renders a proper React component for the consent form instead of dangerouslySetInnerHTML
|
||||
*/
|
||||
function ConsentForm({
|
||||
client,
|
||||
scopes,
|
||||
consentToken,
|
||||
actionUrl,
|
||||
}: {
|
||||
client: ClientInfo;
|
||||
scopes: ScopeInfo[];
|
||||
consentToken: string;
|
||||
actionUrl: string;
|
||||
}) {
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
|
||||
const backendOrigin = backendUrl
|
||||
? new URL(backendUrl).origin
|
||||
: "http://localhost:8006";
|
||||
|
||||
// Full action URL for form submission
|
||||
const fullActionUrl = `${backendOrigin}${actionUrl}`;
|
||||
|
||||
function handleSubmit() {
|
||||
setIsSubmitting(true);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800 p-5">
|
||||
<div className="w-full max-w-md rounded-2xl bg-zinc-800 p-8 shadow-2xl">
|
||||
{/* Header */}
|
||||
<div className="mb-6 text-center">
|
||||
<div className="mx-auto mb-4 flex h-16 w-16 items-center justify-center rounded-xl bg-zinc-700">
|
||||
{client.logo_url ? (
|
||||
<img
|
||||
src={client.logo_url}
|
||||
alt={client.name}
|
||||
className="h-12 w-12 rounded-lg"
|
||||
/>
|
||||
) : (
|
||||
<span className="text-3xl text-zinc-400">
|
||||
{client.name.charAt(0).toUpperCase()}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<h1 className="text-xl font-semibold text-zinc-100">
|
||||
Authorize <span className="text-cyan-400">{client.name}</span>
|
||||
</h1>
|
||||
<p className="mt-2 text-sm text-zinc-400">
|
||||
wants to access your AutoGPT account
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Divider */}
|
||||
<div className="my-6 h-px bg-zinc-700" />
|
||||
|
||||
{/* Scopes Section */}
|
||||
<div className="mb-6">
|
||||
<h2 className="mb-4 text-sm font-medium text-zinc-400">
|
||||
This will allow {client.name} to:
|
||||
</h2>
|
||||
<div className="space-y-3">
|
||||
{scopes.map((scope) => (
|
||||
<div
|
||||
key={scope.scope}
|
||||
className="flex items-start gap-3 border-b border-zinc-700 pb-3 last:border-0"
|
||||
>
|
||||
<svg
|
||||
className="mt-0.5 h-5 w-5 flex-shrink-0 text-cyan-400"
|
||||
viewBox="0 0 20 20"
|
||||
fill="currentColor"
|
||||
>
|
||||
<path
|
||||
fillRule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clipRule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
<span className="text-sm leading-relaxed text-zinc-300">
|
||||
{scope.description}
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Form */}
|
||||
<form method="POST" action={fullActionUrl} onSubmit={handleSubmit}>
|
||||
<input type="hidden" name="consent_token" value={consentToken} />
|
||||
<div className="flex gap-3">
|
||||
<button
|
||||
type="submit"
|
||||
name="authorize"
|
||||
value="false"
|
||||
disabled={isSubmitting}
|
||||
className="flex-1 rounded-lg bg-zinc-700 px-6 py-3 text-sm font-medium text-zinc-200 transition-colors hover:bg-zinc-600 disabled:cursor-not-allowed disabled:opacity-50"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
type="submit"
|
||||
name="authorize"
|
||||
value="true"
|
||||
disabled={isSubmitting}
|
||||
className="flex-1 rounded-lg bg-cyan-400 px-6 py-3 text-sm font-medium text-slate-900 transition-colors hover:bg-cyan-300 disabled:cursor-not-allowed disabled:opacity-50"
|
||||
>
|
||||
{isSubmitting ? "Authorizing..." : "Allow"}
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
{/* Footer Links */}
|
||||
{(client.privacy_policy_url || client.terms_url) && (
|
||||
<div className="mt-6 text-center text-xs text-zinc-500">
|
||||
{client.privacy_policy_url && (
|
||||
<a
|
||||
href={client.privacy_policy_url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-zinc-400 hover:underline"
|
||||
>
|
||||
Privacy Policy
|
||||
</a>
|
||||
)}
|
||||
{client.privacy_policy_url && client.terms_url && (
|
||||
<span className="mx-2">•</span>
|
||||
)}
|
||||
{client.terms_url && (
|
||||
<a
|
||||
href={client.terms_url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-zinc-400 hover:underline"
|
||||
>
|
||||
Terms of Service
|
||||
</a>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* OAuth Resume Page
|
||||
*
|
||||
* This page handles resuming the OAuth authorization flow after a user logs in.
|
||||
* It fetches the consent data from the backend via JSON API and renders the consent form.
|
||||
*/
|
||||
export default function OAuthResumePage() {
|
||||
const searchParams = useSearchParams();
|
||||
const sessionId = searchParams.get("session_id");
|
||||
const { isUserLoading, refreshSession } = useSupabase();
|
||||
|
||||
const [consentData, setConsentData] = useState<ConsentData | null>(null);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const retryCountRef = useRef(0);
|
||||
const maxRetries = 5;
|
||||
|
||||
const resumeOAuthFlow = useCallback(async () => {
|
||||
// Prevent multiple attempts for the same session (handles React StrictMode)
|
||||
if (!sessionId) {
|
||||
setError(
|
||||
"Missing session ID. Please start the authorization process again.",
|
||||
);
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (attemptedSessions.has(sessionId)) {
|
||||
// Already attempted this session, don't retry
|
||||
return;
|
||||
}
|
||||
|
||||
if (isUserLoading) {
|
||||
return; // Wait for auth state to load
|
||||
}
|
||||
|
||||
// Mark this session as attempted IMMEDIATELY to prevent duplicate requests
|
||||
attemptedSessions.add(sessionId);
|
||||
|
||||
try {
|
||||
// Get the access token from server action (which reads cookies properly)
|
||||
let tokenResult = await getWebSocketToken();
|
||||
let accessToken = tokenResult.token;
|
||||
|
||||
// If no token, retry a few times with delays
|
||||
while (!accessToken && retryCountRef.current < maxRetries) {
|
||||
retryCountRef.current += 1;
|
||||
console.log(
|
||||
`Retrying to get access token (attempt ${retryCountRef.current}/${maxRetries})...`,
|
||||
);
|
||||
|
||||
// Try refreshing the session
|
||||
await refreshSession();
|
||||
await new Promise((resolve) => setTimeout(resolve, 1000));
|
||||
|
||||
tokenResult = await getWebSocketToken();
|
||||
accessToken = tokenResult.token;
|
||||
}
|
||||
|
||||
if (!accessToken) {
|
||||
setError(
|
||||
"Unable to retrieve authentication token. Please log in again.",
|
||||
);
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Call the backend resume endpoint with JSON accept header
|
||||
const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
|
||||
if (!backendUrl) {
|
||||
setError("Backend URL not configured.");
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Extract the origin from the backend URL
|
||||
let backendOrigin: string;
|
||||
try {
|
||||
const url = new URL(backendUrl);
|
||||
backendOrigin = url.origin;
|
||||
} catch {
|
||||
setError("Invalid backend URL configuration.");
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Use Accept: application/json to get JSON response instead of HTML
|
||||
// This solves the CORS/redirect issue by letting us handle redirects client-side
|
||||
const response = await fetch(
|
||||
`${backendOrigin}/oauth/authorize/resume?session_id=${encodeURIComponent(sessionId)}`,
|
||||
{
|
||||
method: "GET",
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
Accept: "application/json",
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
const data: ResumeResponse = await response.json();
|
||||
|
||||
if (!response.ok) {
|
||||
if (isErrorData(data)) {
|
||||
setError(data.error_description || data.error);
|
||||
} else {
|
||||
setError(
|
||||
`Authorization failed (${response.status}). Please try again.`,
|
||||
);
|
||||
}
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle redirect response (user already authorized these scopes)
|
||||
if (isRedirectData(data)) {
|
||||
window.location.href = data.redirect_url;
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle consent required
|
||||
if (isConsentData(data)) {
|
||||
setConsentData(data);
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Unexpected response
|
||||
setError("Unexpected response from server. Please try again.");
|
||||
setIsLoading(false);
|
||||
} catch (err) {
|
||||
console.error("OAuth resume error:", err);
|
||||
setError(
|
||||
"An error occurred while resuming authorization. Please try again.",
|
||||
);
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, [sessionId, isUserLoading, refreshSession]);
|
||||
|
||||
useEffect(() => {
|
||||
resumeOAuthFlow();
|
||||
}, [resumeOAuthFlow]);
|
||||
|
||||
if (isLoading || isUserLoading) {
|
||||
return (
|
||||
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800">
|
||||
<div className="text-center">
|
||||
<div className="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-zinc-600 border-t-cyan-400"></div>
|
||||
<p className="text-zinc-400">Resuming authorization...</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800">
|
||||
<div className="mx-auto max-w-md rounded-2xl bg-zinc-800 p-8 text-center shadow-2xl">
|
||||
<div className="mx-auto mb-4 h-16 w-16 text-red-500">
|
||||
<svg
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
>
|
||||
<circle cx="12" cy="12" r="10" />
|
||||
<line x1="15" y1="9" x2="9" y2="15" />
|
||||
<line x1="9" y1="9" x2="15" y2="15" />
|
||||
</svg>
|
||||
</div>
|
||||
<h1 className="mb-2 text-xl font-semibold text-red-400">
|
||||
Authorization Error
|
||||
</h1>
|
||||
<p className="mb-6 text-zinc-400">{error}</p>
|
||||
<button
|
||||
onClick={() => window.close()}
|
||||
className="rounded-lg bg-zinc-700 px-6 py-3 text-sm font-medium text-zinc-200 transition-colors hover:bg-zinc-600"
|
||||
>
|
||||
Close
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (consentData) {
|
||||
return (
|
||||
<ConsentForm
|
||||
client={consentData.client}
|
||||
scopes={consentData.scopes}
|
||||
consentToken={consentData.consent_token}
|
||||
actionUrl={consentData.action_url}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
@@ -3,7 +3,7 @@ import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { environment } from "@/services/environment";
|
||||
import { loginFormSchema, LoginProvider } from "@/types/auth";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useForm } from "react-hook-form";
|
||||
import z from "zod";
|
||||
@@ -13,6 +13,7 @@ export function useLoginPage() {
|
||||
const { supabase, user, isUserLoading, isLoggedIn } = useSupabase();
|
||||
const [feedback, setFeedback] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const { toast } = useToast();
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [isLoggingIn, setIsLoggingIn] = useState(false);
|
||||
@@ -20,11 +21,59 @@ export function useLoginPage() {
|
||||
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
|
||||
const isCloudEnv = environment.isCloud();
|
||||
|
||||
// Get returnUrl, oauth_session, and connect_session from query params
|
||||
const returnUrl = searchParams.get("returnUrl");
|
||||
const oauthSession = searchParams.get("oauth_session");
|
||||
const connectSession = searchParams.get("connect_session");
|
||||
|
||||
function getRedirectUrl(onboarding: boolean): string {
|
||||
// OAuth session takes priority - redirect to frontend oauth-resume page
|
||||
// which will handle the backend call with proper authentication
|
||||
if (oauthSession) {
|
||||
return `/auth/oauth-resume?session_id=${encodeURIComponent(oauthSession)}`;
|
||||
}
|
||||
|
||||
// Connect session - redirect to frontend connect-resume page
|
||||
// for integration credential connection flow
|
||||
if (connectSession) {
|
||||
return `/auth/connect-resume?session_id=${encodeURIComponent(connectSession)}`;
|
||||
}
|
||||
|
||||
// If returnUrl is set and is a valid URL, redirect there after login
|
||||
if (returnUrl) {
|
||||
try {
|
||||
const url = new URL(returnUrl, window.location.origin);
|
||||
const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
|
||||
|
||||
// Same origin - return normalized path only (prevents open redirect)
|
||||
if (url.origin === window.location.origin) {
|
||||
return url.pathname + url.search;
|
||||
}
|
||||
|
||||
// Backend URL - strict origin match (not startsWith to prevent prefix attacks)
|
||||
if (backendUrl) {
|
||||
try {
|
||||
const backendOrigin = new URL(backendUrl).origin;
|
||||
if (url.origin === backendOrigin) {
|
||||
return url.href;
|
||||
}
|
||||
} catch {
|
||||
// Invalid backend URL config, fall through to default
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Invalid URL, fall through to default
|
||||
}
|
||||
}
|
||||
return onboarding ? "/onboarding" : "/marketplace";
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
if (isLoggedIn && !isLoggingIn) {
|
||||
router.push("/marketplace");
|
||||
const redirectTo = getRedirectUrl(false);
|
||||
router.push(redirectTo);
|
||||
}
|
||||
}, [isLoggedIn, isLoggingIn]);
|
||||
}, [isLoggedIn, isLoggingIn, returnUrl, oauthSession, connectSession]);
|
||||
|
||||
const form = useForm<z.infer<typeof loginFormSchema>>({
|
||||
resolver: zodResolver(loginFormSchema),
|
||||
@@ -39,10 +88,20 @@ export function useLoginPage() {
|
||||
setIsLoggingIn(true);
|
||||
|
||||
try {
|
||||
// Build redirect URL that preserves oauth_session or connect_session through the OAuth flow
|
||||
let callbackUrl: string | undefined;
|
||||
const origin = window.location.origin;
|
||||
|
||||
if (oauthSession) {
|
||||
callbackUrl = `${origin}/auth/callback?oauth_session=${encodeURIComponent(oauthSession)}`;
|
||||
} else if (connectSession) {
|
||||
callbackUrl = `${origin}/auth/callback?connect_session=${encodeURIComponent(connectSession)}`;
|
||||
}
|
||||
|
||||
const response = await fetch("/api/auth/provider", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ provider }),
|
||||
body: JSON.stringify({ provider, redirectTo: callbackUrl }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
@@ -83,11 +142,7 @@ export function useLoginPage() {
|
||||
throw new Error(result.error || "Login failed");
|
||||
}
|
||||
|
||||
if (result.onboarding) {
|
||||
router.replace("/onboarding");
|
||||
} else {
|
||||
router.replace("/marketplace");
|
||||
}
|
||||
router.replace(getRedirectUrl(result.onboarding ?? false));
|
||||
} catch (error) {
|
||||
toast({
|
||||
title:
|
||||
|
||||
@@ -0,0 +1,287 @@
|
||||
"use client";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from "@/components/__legacy__/ui/dialog";
|
||||
import { Copy } from "@phosphor-icons/react";
|
||||
import { Label } from "@/components/__legacy__/ui/label";
|
||||
import { Input } from "@/components/__legacy__/ui/input";
|
||||
import { Textarea } from "@/components/__legacy__/ui/textarea";
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/__legacy__/ui/select";
|
||||
|
||||
import { useOAuthClientModals } from "./useOAuthClientModals";
|
||||
|
||||
export function OAuthClientModals() {
|
||||
const {
|
||||
isCreateOpen,
|
||||
setIsCreateOpen,
|
||||
isSecretDialogOpen,
|
||||
setIsSecretDialogOpen,
|
||||
formState,
|
||||
setFormState,
|
||||
newClientSecret,
|
||||
isCreating,
|
||||
handleCreateClient,
|
||||
handleCopyClientId,
|
||||
handleCopyClientSecret,
|
||||
handleCopyWebhookSecret,
|
||||
resetForm,
|
||||
} = useOAuthClientModals();
|
||||
|
||||
return (
|
||||
<div className="mb-4 flex justify-end">
|
||||
<Dialog
|
||||
open={isCreateOpen}
|
||||
onOpenChange={(open) => {
|
||||
setIsCreateOpen(open);
|
||||
if (!open) resetForm();
|
||||
}}
|
||||
>
|
||||
<DialogTrigger asChild>
|
||||
<Button>Register OAuth Client</Button>
|
||||
</DialogTrigger>
|
||||
<DialogContent className="max-h-[90vh] overflow-y-auto sm:max-w-[525px]">
|
||||
<DialogHeader>
|
||||
<DialogTitle>Register New OAuth Client</DialogTitle>
|
||||
<DialogDescription>
|
||||
Register a new OAuth client to integrate with the AutoGPT
|
||||
Platform. For confidential clients, the client secret will only be
|
||||
shown once.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="grid gap-4 py-4">
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="name">
|
||||
Name <span className="text-destructive">*</span>
|
||||
</Label>
|
||||
<Input
|
||||
id="name"
|
||||
value={formState.name}
|
||||
onChange={(e) =>
|
||||
setFormState((prev) => ({
|
||||
...prev,
|
||||
name: e.target.value,
|
||||
}))
|
||||
}
|
||||
placeholder="My Application"
|
||||
maxLength={100}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="description">Description</Label>
|
||||
<Input
|
||||
id="description"
|
||||
value={formState.description}
|
||||
onChange={(e) =>
|
||||
setFormState((prev) => ({
|
||||
...prev,
|
||||
description: e.target.value,
|
||||
}))
|
||||
}
|
||||
placeholder="A brief description of your application"
|
||||
maxLength={500}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="redirectUris">
|
||||
Redirect URIs <span className="text-destructive">*</span>
|
||||
</Label>
|
||||
<Textarea
|
||||
id="redirectUris"
|
||||
value={formState.redirectUris}
|
||||
onChange={(e) =>
|
||||
setFormState((prev) => ({
|
||||
...prev,
|
||||
redirectUris: e.target.value,
|
||||
}))
|
||||
}
|
||||
placeholder="https://myapp.com/callback https://localhost:3000/callback"
|
||||
rows={3}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Enter one URI per line or separate with commas
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="clientType">Client Type</Label>
|
||||
<Select
|
||||
value={formState.clientType}
|
||||
onValueChange={(value: "public" | "confidential") =>
|
||||
setFormState((prev) => ({
|
||||
...prev,
|
||||
clientType: value,
|
||||
}))
|
||||
}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="Select client type" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="public">
|
||||
Public (SPA, Mobile apps)
|
||||
</SelectItem>
|
||||
<SelectItem value="confidential">
|
||||
Confidential (Server-side apps)
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Public clients cannot securely store secrets. Confidential
|
||||
clients receive a client secret.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="homepageUrl">Homepage URL</Label>
|
||||
<Input
|
||||
id="homepageUrl"
|
||||
type="url"
|
||||
value={formState.homepageUrl}
|
||||
onChange={(e) =>
|
||||
setFormState((prev) => ({
|
||||
...prev,
|
||||
homepageUrl: e.target.value,
|
||||
}))
|
||||
}
|
||||
placeholder="https://myapp.com"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="privacyPolicyUrl">Privacy Policy URL</Label>
|
||||
<Input
|
||||
id="privacyPolicyUrl"
|
||||
type="url"
|
||||
value={formState.privacyPolicyUrl}
|
||||
onChange={(e) =>
|
||||
setFormState((prev) => ({
|
||||
...prev,
|
||||
privacyPolicyUrl: e.target.value,
|
||||
}))
|
||||
}
|
||||
placeholder="https://myapp.com/privacy"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="termsOfServiceUrl">Terms of Service URL</Label>
|
||||
<Input
|
||||
id="termsOfServiceUrl"
|
||||
type="url"
|
||||
value={formState.termsOfServiceUrl}
|
||||
onChange={(e) =>
|
||||
setFormState((prev) => ({
|
||||
...prev,
|
||||
termsOfServiceUrl: e.target.value,
|
||||
}))
|
||||
}
|
||||
placeholder="https://myapp.com/terms"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={() => {
|
||||
setIsCreateOpen(false);
|
||||
resetForm();
|
||||
}}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={handleCreateClient} disabled={isCreating}>
|
||||
{isCreating ? "Creating..." : "Create Client"}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
<Dialog open={isSecretDialogOpen} onOpenChange={setIsSecretDialogOpen}>
|
||||
<DialogContent className="sm:max-w-[525px]">
|
||||
<DialogHeader>
|
||||
<DialogTitle>OAuth Client Created</DialogTitle>
|
||||
<DialogDescription>
|
||||
Please copy your client credentials now. These secrets will not be
|
||||
shown again!
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-2">
|
||||
<Label>Client ID</Label>
|
||||
<div className="flex items-center space-x-2">
|
||||
<code className="flex-1 rounded-md bg-secondary p-2 font-mono text-sm">
|
||||
{newClientSecret?.client_id}
|
||||
</code>
|
||||
<Button
|
||||
size="icon"
|
||||
variant="outline"
|
||||
onClick={handleCopyClientId}
|
||||
>
|
||||
<Copy className="h-4 w-4" weight="bold" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
{newClientSecret?.client_secret && (
|
||||
<div className="space-y-2">
|
||||
<Label>Client Secret</Label>
|
||||
<div className="flex items-center space-x-2">
|
||||
<code className="flex-1 break-all rounded-md bg-secondary p-2 font-mono text-sm">
|
||||
{newClientSecret.client_secret}
|
||||
</code>
|
||||
<Button
|
||||
size="icon"
|
||||
variant="outline"
|
||||
onClick={handleCopyClientSecret}
|
||||
>
|
||||
<Copy className="h-4 w-4" weight="bold" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{newClientSecret?.webhook_secret && (
|
||||
<div className="space-y-2">
|
||||
<Label>Webhook Secret</Label>
|
||||
<div className="flex items-center space-x-2">
|
||||
<code className="flex-1 break-all rounded-md bg-secondary p-2 font-mono text-sm">
|
||||
{newClientSecret.webhook_secret}
|
||||
</code>
|
||||
<Button
|
||||
size="icon"
|
||||
variant="outline"
|
||||
onClick={handleCopyWebhookSecret}
|
||||
>
|
||||
<Copy className="h-4 w-4" weight="bold" />
|
||||
</Button>
|
||||
</div>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Use this secret to verify webhook signatures (HMAC-SHA256)
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
<p className="text-xs text-destructive">
|
||||
These secrets will only be shown once. Store them securely!
|
||||
</p>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button onClick={() => setIsSecretDialogOpen(false)}>Close</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,232 @@
|
||||
"use client";
|
||||
import {
|
||||
getGetOauthClientsListClientsQueryKey,
|
||||
usePostOauthClientsRegisterClient,
|
||||
usePostOauthClientsRotateClientSecret,
|
||||
} from "@/app/api/__generated__/endpoints/oauth-clients/oauth-clients";
|
||||
import { ClientSecretResponse } from "@/app/api/__generated__/models/clientSecretResponse";
|
||||
import { RegisterClientRequest } from "@/app/api/__generated__/models/registerClientRequest";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { getQueryClient } from "@/lib/react-query/queryClient";
|
||||
import { validateRedirectUri } from "@/lib/utils";
|
||||
import { useState } from "react";
|
||||
|
||||
type ClientType = "public" | "confidential";
|
||||
|
||||
// ClientSecretResponse already includes webhook_secret from the generated types
|
||||
type ClientSecretResponseWithWebhook = ClientSecretResponse;
|
||||
|
||||
interface ClientFormState {
|
||||
name: string;
|
||||
description: string;
|
||||
redirectUris: string;
|
||||
clientType: ClientType;
|
||||
homepageUrl: string;
|
||||
privacyPolicyUrl: string;
|
||||
termsOfServiceUrl: string;
|
||||
}
|
||||
|
||||
const initialFormState: ClientFormState = {
|
||||
name: "",
|
||||
description: "",
|
||||
redirectUris: "",
|
||||
clientType: "public",
|
||||
homepageUrl: "",
|
||||
privacyPolicyUrl: "",
|
||||
termsOfServiceUrl: "",
|
||||
};
|
||||
|
||||
export function useOAuthClientModals() {
|
||||
const [isCreateOpen, setIsCreateOpen] = useState(false);
|
||||
const [isSecretDialogOpen, setIsSecretDialogOpen] = useState(false);
|
||||
const [formState, setFormState] = useState<ClientFormState>(initialFormState);
|
||||
const [newClientSecret, setNewClientSecret] =
|
||||
useState<ClientSecretResponseWithWebhook | null>(null);
|
||||
|
||||
const queryClient = getQueryClient();
|
||||
const { toast } = useToast();
|
||||
|
||||
const { mutateAsync: registerClient, isPending: isCreating } =
|
||||
usePostOauthClientsRegisterClient({
|
||||
mutation: {
|
||||
onSettled: () => {
|
||||
return queryClient.invalidateQueries({
|
||||
queryKey: getGetOauthClientsListClientsQueryKey(),
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const { mutateAsync: rotateSecret, isPending: isRotating } =
|
||||
usePostOauthClientsRotateClientSecret({
|
||||
mutation: {
|
||||
onSettled: () => {
|
||||
return queryClient.invalidateQueries({
|
||||
queryKey: getGetOauthClientsListClientsQueryKey(),
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
function resetForm() {
|
||||
setFormState(initialFormState);
|
||||
}
|
||||
|
||||
async function handleCreateClient() {
|
||||
// Parse redirect URIs (comma or newline separated)
|
||||
const redirectUris = formState.redirectUris
|
||||
.split(/[,\n]/)
|
||||
.map((uri) => uri.trim())
|
||||
.filter((uri) => uri.length > 0);
|
||||
|
||||
if (redirectUris.length === 0) {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "At least one redirect URI is required",
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Validate each redirect URI
|
||||
for (const uri of redirectUris) {
|
||||
const validation = validateRedirectUri(uri);
|
||||
if (!validation.valid) {
|
||||
toast({
|
||||
title: "Invalid Redirect URI",
|
||||
description: `"${uri}": ${validation.error}`,
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!formState.name.trim()) {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Client name is required",
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const requestData: RegisterClientRequest = {
|
||||
name: formState.name.trim(),
|
||||
redirect_uris: redirectUris,
|
||||
client_type: formState.clientType,
|
||||
};
|
||||
|
||||
if (formState.description.trim()) {
|
||||
requestData.description = formState.description.trim();
|
||||
}
|
||||
if (formState.homepageUrl.trim()) {
|
||||
requestData.homepage_url = formState.homepageUrl.trim();
|
||||
}
|
||||
if (formState.privacyPolicyUrl.trim()) {
|
||||
requestData.privacy_policy_url = formState.privacyPolicyUrl.trim();
|
||||
}
|
||||
if (formState.termsOfServiceUrl.trim()) {
|
||||
requestData.terms_of_service_url = formState.termsOfServiceUrl.trim();
|
||||
}
|
||||
|
||||
const response = await registerClient({
|
||||
data: requestData,
|
||||
});
|
||||
|
||||
if (response.status === 200) {
|
||||
const secretData = response.data as ClientSecretResponseWithWebhook;
|
||||
setNewClientSecret(secretData);
|
||||
setIsCreateOpen(false);
|
||||
setIsSecretDialogOpen(true);
|
||||
resetForm();
|
||||
toast({
|
||||
title: "Success",
|
||||
description: "OAuth client created successfully",
|
||||
});
|
||||
}
|
||||
} catch {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to create OAuth client",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async function handleRotateSecret(clientId: string) {
|
||||
try {
|
||||
const response = await rotateSecret({ clientId });
|
||||
if (response.status === 200) {
|
||||
const secretData = response.data as ClientSecretResponseWithWebhook;
|
||||
setNewClientSecret(secretData);
|
||||
setIsSecretDialogOpen(true);
|
||||
toast({
|
||||
title: "Success",
|
||||
description: "Client secret rotated successfully",
|
||||
});
|
||||
}
|
||||
} catch {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to rotate client secret",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function handleCopyClientId() {
|
||||
if (newClientSecret?.client_id) {
|
||||
navigator.clipboard.writeText(newClientSecret.client_id);
|
||||
toast({
|
||||
title: "Copied",
|
||||
description: "Client ID copied to clipboard",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function handleCopyClientSecret() {
|
||||
if (newClientSecret?.client_secret) {
|
||||
navigator.clipboard.writeText(newClientSecret.client_secret);
|
||||
toast({
|
||||
title: "Copied",
|
||||
description: "Client secret copied to clipboard",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function handleCopyWebhookSecret() {
|
||||
if (newClientSecret?.webhook_secret) {
|
||||
navigator.clipboard.writeText(newClientSecret.webhook_secret);
|
||||
toast({
|
||||
title: "Copied",
|
||||
description: "Webhook secret copied to clipboard",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function handleSecretDialogChange(open: boolean) {
|
||||
setIsSecretDialogOpen(open);
|
||||
if (!open) {
|
||||
setNewClientSecret(null);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
isCreateOpen,
|
||||
setIsCreateOpen,
|
||||
isSecretDialogOpen,
|
||||
setIsSecretDialogOpen: handleSecretDialogChange,
|
||||
formState,
|
||||
setFormState,
|
||||
newClientSecret,
|
||||
isCreating,
|
||||
isRotating,
|
||||
handleCreateClient,
|
||||
handleRotateSecret,
|
||||
handleCopyClientId,
|
||||
handleCopyClientSecret,
|
||||
handleCopyWebhookSecret,
|
||||
resetForm,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,381 @@
|
||||
"use client";
|
||||
|
||||
import { CircleNotch, Copy, DotsThreeVertical } from "@phosphor-icons/react";
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
import { Badge } from "@/components/__legacy__/ui/badge";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/__legacy__/ui/dropdown-menu";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/__legacy__/ui/dialog";
|
||||
import { Label } from "@/components/__legacy__/ui/label";
|
||||
import { Input } from "@/components/__legacy__/ui/input";
|
||||
import { Textarea } from "@/components/__legacy__/ui/textarea";
|
||||
import { useOAuthClientSection } from "./useOAuthClientSection";
|
||||
|
||||
export function OAuthClientSection() {
|
||||
const {
|
||||
oauthClients,
|
||||
isLoading,
|
||||
isDeleting,
|
||||
isSuspending,
|
||||
isActivating,
|
||||
isRotatingWebhookSecret,
|
||||
isUpdating,
|
||||
handleDeleteClient,
|
||||
handleSuspendClient,
|
||||
handleActivateClient,
|
||||
handleRotateWebhookSecret,
|
||||
handleCopyWebhookSecret,
|
||||
handleEditClient,
|
||||
handleSaveClient,
|
||||
webhookSecretDialogOpen,
|
||||
setWebhookSecretDialogOpen,
|
||||
newWebhookSecret,
|
||||
editDialogOpen,
|
||||
setEditDialogOpen,
|
||||
editingClient,
|
||||
editFormState,
|
||||
setEditFormState,
|
||||
} = useOAuthClientSection();
|
||||
|
||||
const isActionPending =
|
||||
isDeleting ||
|
||||
isSuspending ||
|
||||
isActivating ||
|
||||
isRotatingWebhookSecret ||
|
||||
isUpdating;
|
||||
|
||||
return (
|
||||
<>
|
||||
{isLoading ? (
|
||||
<div className="flex justify-center p-4">
|
||||
<CircleNotch className="h-6 w-6 animate-spin" weight="bold" />
|
||||
</div>
|
||||
) : oauthClients && oauthClients.length > 0 ? (
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>Client ID</TableHead>
|
||||
<TableHead>Type</TableHead>
|
||||
<TableHead>Status</TableHead>
|
||||
<TableHead>Created</TableHead>
|
||||
<TableHead></TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{oauthClients.map((client) => (
|
||||
<TableRow key={client.id} data-testid="oauth-client-row">
|
||||
<TableCell>
|
||||
<div className="flex flex-col">
|
||||
<span className="font-medium">{client.name}</span>
|
||||
{client.description && (
|
||||
<span className="text-xs text-muted-foreground">
|
||||
{client.description}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell data-testid="oauth-client-id">
|
||||
<div className="rounded-md border p-1 px-2 font-mono text-xs">
|
||||
{client.client_id}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Badge variant="outline">
|
||||
{client.client_type === "confidential"
|
||||
? "Confidential"
|
||||
: "Public"}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Badge
|
||||
variant={
|
||||
client.status === "ACTIVE" ? "default" : "destructive"
|
||||
}
|
||||
className={
|
||||
client.status === "ACTIVE"
|
||||
? "border-green-600 bg-green-100 text-green-800"
|
||||
: "border-red-600 bg-red-100 text-red-800"
|
||||
}
|
||||
>
|
||||
{client.status}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{new Date(client.created_at).toLocaleDateString()}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button
|
||||
data-testid="oauth-client-actions"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
>
|
||||
<DotsThreeVertical className="h-4 w-4" weight="bold" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
onClick={() => handleEditClient(client)}
|
||||
disabled={isActionPending}
|
||||
>
|
||||
Edit
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem
|
||||
onClick={() =>
|
||||
handleRotateWebhookSecret(client.client_id)
|
||||
}
|
||||
disabled={isActionPending}
|
||||
>
|
||||
Rotate Webhook Secret
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuSeparator />
|
||||
{client.status === "ACTIVE" ? (
|
||||
<DropdownMenuItem
|
||||
onClick={() => handleSuspendClient(client.client_id)}
|
||||
disabled={isActionPending}
|
||||
>
|
||||
Suspend
|
||||
</DropdownMenuItem>
|
||||
) : (
|
||||
<DropdownMenuItem
|
||||
onClick={() => handleActivateClient(client.client_id)}
|
||||
disabled={isActionPending}
|
||||
>
|
||||
Activate
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuItem
|
||||
className="text-destructive"
|
||||
onClick={() => handleDeleteClient(client.client_id)}
|
||||
disabled={isActionPending}
|
||||
>
|
||||
Delete
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
) : (
|
||||
<div className="py-8 text-center text-muted-foreground">
|
||||
No OAuth clients registered yet. Create one to get started.
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog
|
||||
open={webhookSecretDialogOpen}
|
||||
onOpenChange={setWebhookSecretDialogOpen}
|
||||
>
|
||||
<DialogContent className="sm:max-w-[525px]">
|
||||
<DialogHeader>
|
||||
<DialogTitle>Webhook Secret Rotated</DialogTitle>
|
||||
<DialogDescription>
|
||||
Your new webhook secret has been generated. Please copy it now as
|
||||
it will not be shown again.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-2">
|
||||
<Label>New Webhook Secret</Label>
|
||||
<div className="flex items-center space-x-2">
|
||||
<code className="flex-1 break-all rounded-md bg-secondary p-2 font-mono text-sm">
|
||||
{newWebhookSecret}
|
||||
</code>
|
||||
<Button
|
||||
size="icon"
|
||||
variant="outline"
|
||||
onClick={handleCopyWebhookSecret}
|
||||
>
|
||||
<Copy className="h-4 w-4" weight="bold" />
|
||||
</Button>
|
||||
</div>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Use this secret to verify webhook signatures (HMAC-SHA256)
|
||||
</p>
|
||||
</div>
|
||||
<p className="text-xs text-destructive">
|
||||
This secret will only be shown once. Store it securely!
|
||||
</p>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button onClick={() => setWebhookSecretDialogOpen(false)}>
|
||||
Close
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
<Dialog open={editDialogOpen} onOpenChange={setEditDialogOpen}>
|
||||
<DialogContent className="max-h-[90vh] overflow-y-auto sm:max-w-[525px]">
|
||||
<DialogHeader>
|
||||
<DialogTitle>Edit OAuth Client</DialogTitle>
|
||||
<DialogDescription>
|
||||
Update your OAuth client settings. Changes will take effect
|
||||
immediately.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="grid gap-4 py-4">
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="edit-name">Name</Label>
|
||||
<Input
|
||||
id="edit-name"
|
||||
value={editFormState.name ?? ""}
|
||||
onChange={(e) =>
|
||||
setEditFormState((prev) => ({
|
||||
...prev,
|
||||
name: e.target.value,
|
||||
}))
|
||||
}
|
||||
placeholder="My Application"
|
||||
maxLength={100}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="edit-description">Description</Label>
|
||||
<Input
|
||||
id="edit-description"
|
||||
value={editFormState.description ?? ""}
|
||||
onChange={(e) =>
|
||||
setEditFormState((prev) => ({
|
||||
...prev,
|
||||
description: e.target.value,
|
||||
}))
|
||||
}
|
||||
placeholder="A brief description of your application"
|
||||
maxLength={500}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="edit-redirectUris">Redirect URIs</Label>
|
||||
<Textarea
|
||||
id="edit-redirectUris"
|
||||
value={editFormState.redirect_uris?.join("\n") ?? ""}
|
||||
onChange={(e) =>
|
||||
setEditFormState((prev) => ({
|
||||
...prev,
|
||||
redirect_uris: e.target.value
|
||||
.split(/[\n,]/)
|
||||
.map((uri) => uri.trim())
|
||||
.filter(Boolean),
|
||||
}))
|
||||
}
|
||||
placeholder="https://myapp.com/callback https://localhost:3000/callback"
|
||||
rows={3}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Enter one URI per line or separate with commas
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="edit-webhookDomains">Webhook Domains</Label>
|
||||
<Textarea
|
||||
id="edit-webhookDomains"
|
||||
value={editFormState.webhook_domains?.join("\n") ?? ""}
|
||||
onChange={(e) =>
|
||||
setEditFormState((prev) => ({
|
||||
...prev,
|
||||
webhook_domains: e.target.value
|
||||
.split(/[\n,]/)
|
||||
.map((domain) => domain.trim())
|
||||
.filter(Boolean),
|
||||
}))
|
||||
}
|
||||
placeholder="https://myapp.com https://api.myapp.com"
|
||||
rows={3}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Domains that can receive webhook notifications
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="edit-homepageUrl">Homepage URL</Label>
|
||||
<Input
|
||||
id="edit-homepageUrl"
|
||||
type="url"
|
||||
value={editFormState.homepage_url ?? ""}
|
||||
onChange={(e) =>
|
||||
setEditFormState((prev) => ({
|
||||
...prev,
|
||||
homepage_url: e.target.value || undefined,
|
||||
}))
|
||||
}
|
||||
placeholder="https://myapp.com"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="edit-privacyPolicyUrl">Privacy Policy URL</Label>
|
||||
<Input
|
||||
id="edit-privacyPolicyUrl"
|
||||
type="url"
|
||||
value={editFormState.privacy_policy_url ?? ""}
|
||||
onChange={(e) =>
|
||||
setEditFormState((prev) => ({
|
||||
...prev,
|
||||
privacy_policy_url: e.target.value || undefined,
|
||||
}))
|
||||
}
|
||||
placeholder="https://myapp.com/privacy"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="edit-termsOfServiceUrl">
|
||||
Terms of Service URL
|
||||
</Label>
|
||||
<Input
|
||||
id="edit-termsOfServiceUrl"
|
||||
type="url"
|
||||
value={editFormState.terms_of_service_url ?? ""}
|
||||
onChange={(e) =>
|
||||
setEditFormState((prev) => ({
|
||||
...prev,
|
||||
terms_of_service_url: e.target.value || undefined,
|
||||
}))
|
||||
}
|
||||
placeholder="https://myapp.com/terms"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button variant="outline" onClick={() => setEditDialogOpen(false)}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={handleSaveClient} disabled={isUpdating}>
|
||||
{isUpdating ? "Saving..." : "Save Changes"}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,242 @@
|
||||
"use client";
|
||||
import { useState } from "react";
|
||||
import {
|
||||
getGetOauthClientsListClientsQueryKey,
|
||||
useDeleteOauthClientsDeleteClient,
|
||||
useGetOauthClientsListClients,
|
||||
usePatchOauthClientsUpdateClient,
|
||||
usePostOauthClientsActivateClient,
|
||||
usePostOauthClientsRotateWebhookSecret,
|
||||
usePostOauthClientsSuspendClient,
|
||||
} from "@/app/api/__generated__/endpoints/oauth-clients/oauth-clients";
|
||||
import type { ClientResponse } from "@/app/api/__generated__/models/clientResponse";
|
||||
import type { UpdateClientRequest } from "@/app/api/__generated__/models/updateClientRequest";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { getQueryClient } from "@/lib/react-query/queryClient";
|
||||
import { validateRedirectUri } from "@/lib/utils";
|
||||
|
||||
export function useOAuthClientSection() {
|
||||
const queryClient = getQueryClient();
|
||||
const { toast } = useToast();
|
||||
|
||||
const [webhookSecretDialogOpen, setWebhookSecretDialogOpen] = useState(false);
|
||||
const [newWebhookSecret, setNewWebhookSecret] = useState<string | null>(null);
|
||||
const [editDialogOpen, setEditDialogOpen] = useState(false);
|
||||
const [editingClient, setEditingClient] = useState<ClientResponse | null>(
|
||||
null,
|
||||
);
|
||||
const [editFormState, setEditFormState] = useState<UpdateClientRequest>({});
|
||||
|
||||
const { data: oauthClients, isLoading } = useGetOauthClientsListClients({
|
||||
query: {
|
||||
select: (res) => {
|
||||
if (res.status !== 200) return undefined;
|
||||
return res.data;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const { mutateAsync: deleteClient, isPending: isDeleting } =
|
||||
useDeleteOauthClientsDeleteClient({
|
||||
mutation: {
|
||||
onSettled: () => {
|
||||
return queryClient.invalidateQueries({
|
||||
queryKey: getGetOauthClientsListClientsQueryKey(),
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const { mutateAsync: suspendClient, isPending: isSuspending } =
|
||||
usePostOauthClientsSuspendClient({
|
||||
mutation: {
|
||||
onSettled: () => {
|
||||
return queryClient.invalidateQueries({
|
||||
queryKey: getGetOauthClientsListClientsQueryKey(),
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const { mutateAsync: activateClient, isPending: isActivating } =
|
||||
usePostOauthClientsActivateClient({
|
||||
mutation: {
|
||||
onSettled: () => {
|
||||
return queryClient.invalidateQueries({
|
||||
queryKey: getGetOauthClientsListClientsQueryKey(),
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const {
|
||||
mutateAsync: rotateWebhookSecret,
|
||||
isPending: isRotatingWebhookSecret,
|
||||
} = usePostOauthClientsRotateWebhookSecret();
|
||||
|
||||
const { mutateAsync: updateClient, isPending: isUpdating } =
|
||||
usePatchOauthClientsUpdateClient({
|
||||
mutation: {
|
||||
onSettled: () => {
|
||||
return queryClient.invalidateQueries({
|
||||
queryKey: getGetOauthClientsListClientsQueryKey(),
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
async function handleDeleteClient(clientId: string) {
|
||||
try {
|
||||
await deleteClient({ clientId });
|
||||
toast({
|
||||
title: "Success",
|
||||
description: "OAuth client deleted successfully",
|
||||
});
|
||||
} catch {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to delete OAuth client",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async function handleSuspendClient(clientId: string) {
|
||||
try {
|
||||
await suspendClient({ clientId });
|
||||
toast({
|
||||
title: "Success",
|
||||
description: "OAuth client suspended successfully",
|
||||
});
|
||||
} catch {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to suspend OAuth client",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async function handleActivateClient(clientId: string) {
|
||||
try {
|
||||
await activateClient({ clientId });
|
||||
toast({
|
||||
title: "Success",
|
||||
description: "OAuth client activated successfully",
|
||||
});
|
||||
} catch {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to activate OAuth client",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async function handleRotateWebhookSecret(clientId: string) {
|
||||
try {
|
||||
const response = await rotateWebhookSecret({ clientId });
|
||||
if (response.status === 200) {
|
||||
setNewWebhookSecret(response.data.webhook_secret);
|
||||
setWebhookSecretDialogOpen(true);
|
||||
toast({
|
||||
title: "Success",
|
||||
description: "Webhook secret rotated successfully",
|
||||
});
|
||||
}
|
||||
} catch {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to rotate webhook secret",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function handleCopyWebhookSecret() {
|
||||
if (newWebhookSecret) {
|
||||
navigator.clipboard.writeText(newWebhookSecret);
|
||||
toast({
|
||||
title: "Copied",
|
||||
description: "Webhook secret copied to clipboard",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function handleEditClient(client: ClientResponse) {
|
||||
setEditingClient(client);
|
||||
setEditFormState({
|
||||
name: client.name,
|
||||
description: client.description ?? undefined,
|
||||
homepage_url: client.homepage_url ?? undefined,
|
||||
privacy_policy_url: client.privacy_policy_url ?? undefined,
|
||||
terms_of_service_url: client.terms_of_service_url ?? undefined,
|
||||
redirect_uris: client.redirect_uris,
|
||||
webhook_domains: client.webhook_domains,
|
||||
});
|
||||
setEditDialogOpen(true);
|
||||
}
|
||||
|
||||
async function handleSaveClient() {
|
||||
if (!editingClient) return;
|
||||
|
||||
// Validate redirect URIs before saving
|
||||
if (editFormState.redirect_uris) {
|
||||
for (const uri of editFormState.redirect_uris) {
|
||||
const validation = validateRedirectUri(uri);
|
||||
if (!validation.valid) {
|
||||
toast({
|
||||
title: "Invalid Redirect URI",
|
||||
description: `"${uri}": ${validation.error}`,
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
await updateClient({
|
||||
clientId: editingClient.client_id,
|
||||
data: editFormState,
|
||||
});
|
||||
toast({
|
||||
title: "Success",
|
||||
description: "OAuth client updated successfully",
|
||||
});
|
||||
setEditDialogOpen(false);
|
||||
setEditingClient(null);
|
||||
} catch {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to update OAuth client",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
oauthClients,
|
||||
isLoading,
|
||||
isDeleting,
|
||||
isSuspending,
|
||||
isActivating,
|
||||
isRotatingWebhookSecret,
|
||||
isUpdating,
|
||||
handleDeleteClient,
|
||||
handleSuspendClient,
|
||||
handleActivateClient,
|
||||
handleRotateWebhookSecret,
|
||||
handleCopyWebhookSecret,
|
||||
handleEditClient,
|
||||
handleSaveClient,
|
||||
webhookSecretDialogOpen,
|
||||
setWebhookSecretDialogOpen,
|
||||
newWebhookSecret,
|
||||
editDialogOpen,
|
||||
setEditDialogOpen,
|
||||
editingClient,
|
||||
editFormState,
|
||||
setEditFormState,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
import { Metadata } from "next/types";
|
||||
import { OAuthClientSection } from "@/app/(platform)/profile/(user)/developer/components/OAuthClientSection/OAuthClientSection";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/__legacy__/ui/card";
|
||||
import { OAuthClientModals } from "./components/OAuthClientModals/OAuthClientModals";
|
||||
|
||||
export const metadata: Metadata = {
|
||||
title: "Developer Settings - AutoGPT Platform",
|
||||
};
|
||||
|
||||
function DeveloperPage() {
|
||||
return (
|
||||
<div className="w-full pr-4 pt-24 md:pt-0">
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>OAuth Applications</CardTitle>
|
||||
<CardDescription>
|
||||
Register and manage OAuth clients to integrate third-party
|
||||
applications with the AutoGPT Platform. OAuth clients allow external
|
||||
applications to access AutoGPT APIs on behalf of users.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<OAuthClientModals />
|
||||
<OAuthClientSection />
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default DeveloperPage;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -883,6 +883,30 @@ export const IconUploadCloud = createIcon((props) => (
|
||||
</svg>
|
||||
));
|
||||
|
||||
/**
|
||||
* Code icon component for developer settings.
|
||||
*
|
||||
* @component IconCode
|
||||
* @param {IconProps} props - The props object containing additional attributes and event handlers for the icon.
|
||||
* @returns {JSX.Element} - The code icon.
|
||||
*/
|
||||
export const IconCode = createIcon((props) => (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
aria-label="Code Icon"
|
||||
{...props}
|
||||
>
|
||||
<polyline points="16 18 22 12 16 6" />
|
||||
<polyline points="8 6 2 12 8 18" />
|
||||
</svg>
|
||||
));
|
||||
|
||||
/**
|
||||
* Chevron up icon component.
|
||||
*
|
||||
@@ -1838,6 +1862,7 @@ export enum IconType {
|
||||
AutoGPTLogo,
|
||||
Sliders,
|
||||
Chat,
|
||||
Code,
|
||||
}
|
||||
|
||||
export function getIconForSocial(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import {
|
||||
IconBuilder,
|
||||
IconCode,
|
||||
IconEdit,
|
||||
IconLibrary,
|
||||
IconLogOut,
|
||||
@@ -130,10 +131,15 @@ export function getAccountMenuItems(userRole?: string): MenuItemGroup[] {
|
||||
});
|
||||
}
|
||||
|
||||
// Add settings and logout
|
||||
// Add developer settings and settings
|
||||
baseMenuItems.push(
|
||||
{
|
||||
items: [
|
||||
{
|
||||
icon: IconType.Code,
|
||||
text: "Developer",
|
||||
href: "/profile/developer",
|
||||
},
|
||||
{
|
||||
icon: IconType.Settings,
|
||||
text: "Settings",
|
||||
@@ -177,6 +183,8 @@ export function getAccountMenuOptionIcon(icon: IconType) {
|
||||
return <IconSliders className={iconClass} />;
|
||||
case IconType.Chat:
|
||||
return <ChatsIcon className={iconClass} />;
|
||||
case IconType.Code:
|
||||
return <IconCode className={iconClass} />;
|
||||
default:
|
||||
return <IconRefresh className={iconClass} />;
|
||||
}
|
||||
|
||||
@@ -445,3 +445,31 @@ export function validateYouTubeUrl(val: string): boolean {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate OAuth redirect URI.
|
||||
* Allows HTTPS URLs or http://localhost for development.
|
||||
*/
|
||||
export function validateRedirectUri(uri: string): {
|
||||
valid: boolean;
|
||||
error?: string;
|
||||
} {
|
||||
try {
|
||||
const url = new URL(uri);
|
||||
if (url.protocol === "https:") {
|
||||
return { valid: true };
|
||||
}
|
||||
if (
|
||||
url.protocol === "http:" &&
|
||||
(url.hostname === "localhost" || url.hostname === "127.0.0.1")
|
||||
) {
|
||||
return { valid: true };
|
||||
}
|
||||
return {
|
||||
valid: false,
|
||||
error: "Must be HTTPS (or http://localhost for development)",
|
||||
};
|
||||
} catch {
|
||||
return { valid: false, error: "Invalid URL format" };
|
||||
}
|
||||
}
|
||||
|
||||
701
docs/content/platform/external-api-integration.md
Normal file
701
docs/content/platform/external-api-integration.md
Normal file
@@ -0,0 +1,701 @@
|
||||
# External API Integration Guide
|
||||
|
||||
This guide explains how third-party applications can integrate with AutoGPT Platform to execute agents on behalf of users using the OAuth Provider and Credential Broker system.
|
||||
|
||||
## Overview
|
||||
|
||||
The AutoGPT External API allows your application to:
|
||||
|
||||
- **Execute agents** - Run user-owned or marketplace agents with user-granted credentials
|
||||
- **Access integrations** - Use third-party service credentials (Google, GitHub, etc.) that users have connected
|
||||
- **Receive webhooks** - Get notified when agent executions complete
|
||||
|
||||
The integration uses standard OAuth 2.0 with PKCE for secure authentication, with API key-based user identification during the authorization flow.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐
|
||||
│ Your App │────▶│ AutoGPT OAuth │────▶│ External API │
|
||||
│ │ │ Provider │ │ │
|
||||
└─────────────────┘ └──────────────────┘ └─────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────┐
|
||||
│ Credential │
|
||||
│ Broker │
|
||||
└──────────────────┘
|
||||
```
|
||||
|
||||
**Key concepts:**
|
||||
|
||||
1. **OAuth Client** - Your registered application with AutoGPT
|
||||
2. **API Key** - User's AutoGPT API key for identifying the user during authorization
|
||||
3. **OAuth Tokens** - Access/refresh tokens for API authentication
|
||||
4. **Credential Grants** - User permissions to use their connected integrations
|
||||
5. **Integration Scopes** - Specific permissions for each integration (e.g., `google:gmail.readonly`)
|
||||
|
||||
## Getting Started
|
||||
|
||||
### 1. Register Your OAuth Client
|
||||
|
||||
Register your application to get a `client_id` and `client_secret`:
|
||||
|
||||
```bash
|
||||
# Requires user authentication (JWT token)
|
||||
curl -X POST https://platform.agpt.co/oauth/clients/ \
|
||||
-H "Authorization: Bearer <user_jwt>" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"name": "My App",
|
||||
"description": "Description of your app",
|
||||
"client_type": "confidential",
|
||||
"redirect_uris": ["https://myapp.com/oauth/callback"],
|
||||
"webhook_domains": ["myapp.com", "*.myapp.com"]
|
||||
}'
|
||||
```
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"client_id": "app_abc123xyz",
|
||||
"client_secret": "secret_xyz789..."
|
||||
}
|
||||
```
|
||||
|
||||
> **Important:** Store the `client_secret` securely - it's only shown once!
|
||||
|
||||
**Client types:**
|
||||
- `confidential` - Server-side apps that can securely store secrets
|
||||
- `public` - Browser/mobile apps (no client secret)
|
||||
|
||||
### 2. OAuth Authorization Flow
|
||||
|
||||
Use the standard OAuth 2.0 Authorization Code flow with PKCE to get user consent and tokens.
|
||||
|
||||
#### Authentication Methods
|
||||
|
||||
The authorization endpoint supports two authentication methods:
|
||||
|
||||
1. **API Key (Recommended for server-side apps)**: Pass the user's AutoGPT API key via `X-API-Key` header. This shows the consent page directly without requiring a browser login.
|
||||
|
||||
2. **Login Flow (For browser-based apps)**: If no API key is provided, the user is redirected to the AutoGPT login page, which then continues the OAuth flow automatically after login.
|
||||
|
||||
#### Generate PKCE Parameters
|
||||
|
||||
```javascript
|
||||
// Generate code verifier and challenge
|
||||
function generateCodeVerifier() {
|
||||
const array = new Uint8Array(32);
|
||||
crypto.getRandomValues(array);
|
||||
return base64UrlEncode(array);
|
||||
}
|
||||
|
||||
async function generateCodeChallenge(verifier) {
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(verifier);
|
||||
const digest = await crypto.subtle.digest('SHA-256', data);
|
||||
return base64UrlEncode(new Uint8Array(digest));
|
||||
}
|
||||
```
|
||||
|
||||
#### Request Authorization
|
||||
|
||||
**Option 1: With API Key (Server-side apps)**
|
||||
|
||||
If you have the user's API key, make a server-side request to get the consent page directly:
|
||||
|
||||
```javascript
|
||||
const state = crypto.randomUUID(); // Store this for validation
|
||||
const codeVerifier = generateCodeVerifier(); // Store this securely
|
||||
const codeChallenge = await generateCodeChallenge(codeVerifier);
|
||||
|
||||
const authUrl = 'https://platform.agpt.co/oauth/authorize?' + new URLSearchParams({
|
||||
response_type: 'code',
|
||||
client_id: CLIENT_ID,
|
||||
redirect_uri: REDIRECT_URI,
|
||||
state: state,
|
||||
code_challenge: codeChallenge,
|
||||
code_challenge_method: 'S256',
|
||||
scope: 'openid profile email agents:execute integrations:connect',
|
||||
});
|
||||
|
||||
// Server-side request with user's API key
|
||||
const response = await fetch(authUrl, {
|
||||
headers: {
|
||||
'X-API-Key': userApiKey, // User's AutoGPT API key (optional)
|
||||
},
|
||||
});
|
||||
|
||||
// Response is HTML consent page - render it to the user
|
||||
const consentHtml = await response.text();
|
||||
```
|
||||
|
||||
**Option 2: Browser Login Flow (No API Key)**
|
||||
|
||||
If you don't have the user's API key, redirect them to the authorization URL. They will be prompted to log in to AutoGPT, and the OAuth flow will continue automatically after login:
|
||||
|
||||
```javascript
|
||||
const state = crypto.randomUUID(); // Store this for validation
|
||||
const codeVerifier = generateCodeVerifier(); // Store this securely
|
||||
const codeChallenge = await generateCodeChallenge(codeVerifier);
|
||||
|
||||
const authUrl = 'https://platform.agpt.co/oauth/authorize?' + new URLSearchParams({
|
||||
response_type: 'code',
|
||||
client_id: CLIENT_ID,
|
||||
redirect_uri: REDIRECT_URI,
|
||||
state: state,
|
||||
code_challenge: codeChallenge,
|
||||
code_challenge_method: 'S256',
|
||||
scope: 'openid profile email agents:execute integrations:connect',
|
||||
});
|
||||
|
||||
// Redirect user to AutoGPT login page
|
||||
// After login, they'll see the consent page and be redirected to your callback
|
||||
window.location.href = authUrl;
|
||||
```
|
||||
|
||||
**Authentication methods (in order of preference):**
|
||||
|
||||
| Method | Header | Description |
|
||||
|--------|--------|-------------|
|
||||
| API Key | `X-API-Key: agpt_xxx` | User's AutoGPT API key (preferred for external apps) |
|
||||
| JWT Token | `Authorization: Bearer <jwt>` | Supabase JWT token |
|
||||
| Cookie | `access_token` cookie | Browser-based authentication |
|
||||
|
||||
**Available scopes:**
|
||||
|
||||
| Scope | Description |
|
||||
|-------|-------------|
|
||||
| `openid` | Required for OIDC |
|
||||
| `profile` | Access user profile (name) |
|
||||
| `email` | Access user email |
|
||||
| `agents:execute` | Execute agents and check status |
|
||||
| `integrations:list` | List user's credential grants |
|
||||
| `integrations:connect` | Request new credential grants |
|
||||
| `integrations:delete` | Delete credentials via grants |
|
||||
|
||||
#### Handle OAuth Callback
|
||||
|
||||
```javascript
|
||||
// Your callback endpoint receives: ?code=xxx&state=xxx
|
||||
app.get('/oauth/callback', async (req, res) => {
|
||||
const { code, state } = req.query;
|
||||
|
||||
// Verify state matches what you stored
|
||||
if (state !== storedState) {
|
||||
return res.status(400).send('Invalid state');
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
const response = await fetch('https://platform.agpt.co/oauth/token', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: new URLSearchParams({
|
||||
grant_type: 'authorization_code',
|
||||
code,
|
||||
redirect_uri: REDIRECT_URI,
|
||||
client_id: CLIENT_ID,
|
||||
client_secret: CLIENT_SECRET,
|
||||
code_verifier: storedCodeVerifier,
|
||||
}),
|
||||
});
|
||||
|
||||
const tokens = await response.json();
|
||||
// { access_token, refresh_token, token_type, expires_in }
|
||||
});
|
||||
```
|
||||
|
||||
### 3. Request Credential Grants (Connect Flow)
|
||||
|
||||
Before executing agents that require integrations (like Gmail, Google Sheets, GitHub), you need credential grants from the user.
|
||||
|
||||
#### Open Connect Popup
|
||||
|
||||
```javascript
|
||||
function requestCredentialGrant(provider, scopes) {
|
||||
const nonce = crypto.randomUUID();
|
||||
|
||||
// Store nonce to validate response
|
||||
sessionStorage.setItem('connect_nonce', nonce);
|
||||
|
||||
const connectUrl = new URL(`https://platform.agpt.co/connect/${provider}`);
|
||||
connectUrl.searchParams.set('client_id', CLIENT_ID);
|
||||
connectUrl.searchParams.set('scopes', scopes.join(','));
|
||||
connectUrl.searchParams.set('nonce', nonce);
|
||||
connectUrl.searchParams.set('redirect_origin', window.location.origin);
|
||||
|
||||
// Open popup (user must be logged into AutoGPT)
|
||||
const popup = window.open(
|
||||
connectUrl.toString(),
|
||||
'AutoGPT Connect',
|
||||
'width=500,height=600,popup=true'
|
||||
);
|
||||
|
||||
// Listen for result
|
||||
window.addEventListener('message', handleConnectResult, { once: true });
|
||||
}
|
||||
|
||||
function handleConnectResult(event) {
|
||||
// Verify origin
|
||||
if (event.origin !== 'https://platform.agpt.co') return;
|
||||
|
||||
const data = event.data;
|
||||
if (data.type !== 'autogpt_connect_result') return;
|
||||
|
||||
// Verify nonce
|
||||
if (data.nonce !== sessionStorage.getItem('connect_nonce')) return;
|
||||
|
||||
if (data.success) {
|
||||
console.log('Grant created:', data.grant_id);
|
||||
console.log('Credential ID:', data.credential_id);
|
||||
console.log('Provider:', data.provider);
|
||||
} else {
|
||||
console.error('Connect failed:', data.error, data.error_description);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Integration scopes by provider:**
|
||||
|
||||
| Provider | Available Scopes |
|
||||
|----------|------------------|
|
||||
| Google | `google:gmail.readonly`, `google:gmail.send`, `google:sheets.read`, `google:sheets.write`, `google:calendar.read`, `google:calendar.write`, `google:drive.read`, `google:drive.write` |
|
||||
| GitHub | `github:repo.read`, `github:repo.write`, `github:user.read` |
|
||||
| Twitter/X | `twitter:tweet.read`, `twitter:tweet.write`, `twitter:user.read` |
|
||||
| Linear | `linear:read`, `linear:write` |
|
||||
| Notion | `notion:read`, `notion:write` |
|
||||
| Slack | `slack:read`, `slack:write` |
|
||||
|
||||
### 4. Execute Agents
|
||||
|
||||
With an OAuth token and credential grants, you can execute agents:
|
||||
|
||||
```javascript
|
||||
async function executeAgent(agentId, inputs, grantIds = null, webhookUrl = null) {
|
||||
const response = await fetch(
|
||||
`https://platform.agpt.co/api/external/v1/executions/agents/${agentId}/execute`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Authorization': `Bearer ${accessToken}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
inputs,
|
||||
grant_ids: grantIds, // Optional: specific grants to use
|
||||
webhook_url: webhookUrl, // Optional: receive completion webhook
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
const result = await response.json();
|
||||
// { execution_id, status: "queued", message: "..." }
|
||||
return result;
|
||||
}
|
||||
```
|
||||
|
||||
### 5. Check Execution Status
|
||||
|
||||
Poll for execution status or use webhooks:
|
||||
|
||||
```javascript
|
||||
async function getExecutionStatus(executionId) {
|
||||
const response = await fetch(
|
||||
`https://platform.agpt.co/api/external/v1/executions/${executionId}`,
|
||||
{
|
||||
headers: { 'Authorization': `Bearer ${accessToken}` },
|
||||
}
|
||||
);
|
||||
|
||||
return await response.json();
|
||||
// {
|
||||
// execution_id,
|
||||
// status: "queued" | "running" | "completed" | "failed",
|
||||
// started_at,
|
||||
// completed_at,
|
||||
// outputs, // Present when completed
|
||||
// error, // Present when failed
|
||||
// }
|
||||
}
|
||||
```
|
||||
|
||||
### 6. Handle Webhooks
|
||||
|
||||
If you provided a `webhook_url`, you'll receive POST requests with execution events:
|
||||
|
||||
```javascript
|
||||
app.post('/webhooks/autogpt', (req, res) => {
|
||||
// Verify webhook signature (if configured)
|
||||
const signature = req.headers['x-webhook-signature'];
|
||||
const timestamp = req.headers['x-webhook-timestamp'];
|
||||
|
||||
if (signature) {
|
||||
const expectedSignature = crypto
|
||||
.createHmac('sha256', WEBHOOK_SECRET)
|
||||
.update(JSON.stringify(req.body))
|
||||
.digest('hex');
|
||||
|
||||
if (signature !== `sha256=${expectedSignature}`) {
|
||||
return res.status(401).send('Invalid signature');
|
||||
}
|
||||
}
|
||||
|
||||
const { event, timestamp, data } = req.body;
|
||||
|
||||
switch (event) {
|
||||
case 'execution.started':
|
||||
console.log(`Execution ${data.execution_id} started`);
|
||||
break;
|
||||
case 'execution.completed':
|
||||
console.log(`Execution ${data.execution_id} completed`, data.outputs);
|
||||
break;
|
||||
case 'execution.failed':
|
||||
console.error(`Execution ${data.execution_id} failed:`, data.error);
|
||||
break;
|
||||
case 'grant.revoked':
|
||||
console.log(`Grant ${data.grant_id} was revoked`);
|
||||
break;
|
||||
}
|
||||
|
||||
res.status(200).send('OK');
|
||||
});
|
||||
```
|
||||
|
||||
> **Note:** Webhook URLs must match domains registered in your OAuth client's `webhook_domains`.
|
||||
|
||||
## API Reference
|
||||
|
||||
### External API Endpoints
|
||||
|
||||
Base URL: `https://platform.agpt.co/api/external/v1`
|
||||
|
||||
| Method | Endpoint | Scope Required | Description |
|
||||
|--------|----------|----------------|-------------|
|
||||
| GET | `/executions/capabilities` | `agents:execute` | Get available grants and scopes |
|
||||
| POST | `/executions/agents/{agent_id}/execute` | `agents:execute` | Execute an agent |
|
||||
| GET | `/executions/{execution_id}` | `agents:execute` | Get execution status |
|
||||
| POST | `/executions/{execution_id}/cancel` | `agents:execute` | Cancel execution |
|
||||
| GET | `/grants/` | `integrations:list` | List credential grants |
|
||||
| GET | `/grants/{grant_id}` | `integrations:list` | Get grant details |
|
||||
| DELETE | `/grants/{grant_id}/credential` | `integrations:delete` | Delete credential via grant |
|
||||
|
||||
### OAuth Endpoints
|
||||
|
||||
Base URL: `https://platform.agpt.co`
|
||||
|
||||
| Method | Endpoint | Auth Required | Description |
|
||||
|--------|----------|---------------|-------------|
|
||||
| GET | `/oauth/authorize` | API Key or JWT | Authorization endpoint |
|
||||
| POST | `/oauth/token` | None | Token endpoint |
|
||||
| GET | `/oauth/userinfo` | OAuth Token | OIDC UserInfo |
|
||||
| POST | `/oauth/revoke` | OAuth Token | Revoke tokens |
|
||||
| GET | `/.well-known/openid-configuration` | None | OIDC Discovery |
|
||||
|
||||
### Client Management Endpoints
|
||||
|
||||
| Method | Endpoint | Description |
|
||||
|--------|----------|-------------|
|
||||
| POST | `/oauth/clients/` | Register new client |
|
||||
| GET | `/oauth/clients/` | List your clients |
|
||||
| GET | `/oauth/clients/{client_id}` | Get client details |
|
||||
| PATCH | `/oauth/clients/{client_id}` | Update client |
|
||||
| DELETE | `/oauth/clients/{client_id}` | Delete client |
|
||||
| POST | `/oauth/clients/{client_id}/rotate-secret` | Rotate client secret |
|
||||
|
||||
## Rate Limits
|
||||
|
||||
| Endpoint Type | Limit |
|
||||
|--------------|-------|
|
||||
| OAuth endpoints | 20-30 requests/minute |
|
||||
| Agent execution | 10 requests/minute, 100/hour |
|
||||
| Read endpoints | 60 requests/minute, 1000/hour |
|
||||
|
||||
Rate limit headers are included in responses:
|
||||
- `X-RateLimit-Remaining` - Requests remaining in current window
|
||||
- `X-RateLimit-Reset` - Unix timestamp when limit resets
|
||||
- `Retry-After` - Seconds to wait (when rate limited)
|
||||
|
||||
## Error Handling
|
||||
|
||||
### OAuth Errors
|
||||
|
||||
```json
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": "Authorization code has expired"
|
||||
}
|
||||
```
|
||||
|
||||
Common OAuth errors:
|
||||
- `invalid_client` - Unknown or invalid client
|
||||
- `invalid_grant` - Expired/invalid authorization code
|
||||
- `access_denied` - User denied consent
|
||||
- `invalid_scope` - Requested scope not allowed
|
||||
- `login_required` - User authentication required (provide API key or JWT)
|
||||
|
||||
### API Errors
|
||||
|
||||
```json
|
||||
{
|
||||
"detail": "Grant validation failed: No valid grants found for requested integrations"
|
||||
}
|
||||
```
|
||||
|
||||
HTTP status codes:
|
||||
- `400` - Bad request (invalid parameters)
|
||||
- `401` - Unauthorized (invalid/expired token or missing API key)
|
||||
- `403` - Forbidden (insufficient scopes or grants)
|
||||
- `404` - Resource not found
|
||||
- `429` - Rate limited
|
||||
- `500` - Internal server error
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
1. **Store secrets securely** - Never expose `client_secret` in client-side code
|
||||
2. **Protect API keys** - User API keys should be stored securely and transmitted over HTTPS
|
||||
3. **Validate state parameter** - Prevent CSRF attacks
|
||||
4. **Use PKCE** - Required for all authorization flows
|
||||
5. **Verify webhook signatures** - Prevent spoofed webhooks
|
||||
6. **Request minimal scopes** - Only request what you need
|
||||
7. **Handle token refresh** - Refresh tokens before they expire
|
||||
8. **Validate redirect origins** - Only accept messages from expected origins
|
||||
|
||||
## Complete Integration Example
|
||||
|
||||
```javascript
|
||||
class AutoGPTClient {
|
||||
constructor(clientId, clientSecret, redirectUri) {
|
||||
this.clientId = clientId;
|
||||
this.clientSecret = clientSecret;
|
||||
this.redirectUri = redirectUri;
|
||||
this.baseUrl = 'https://platform.agpt.co';
|
||||
}
|
||||
|
||||
// Step 1: Build authorization URL
|
||||
async buildAuthorizationUrl(scopes) {
|
||||
const state = crypto.randomUUID();
|
||||
const codeVerifier = this.generateCodeVerifier();
|
||||
const codeChallenge = await this.generateCodeChallenge(codeVerifier);
|
||||
|
||||
// Store for callback
|
||||
this.pendingAuth = { state, codeVerifier };
|
||||
|
||||
const params = new URLSearchParams({
|
||||
response_type: 'code',
|
||||
client_id: this.clientId,
|
||||
redirect_uri: this.redirectUri,
|
||||
state: state,
|
||||
code_challenge: codeChallenge,
|
||||
code_challenge_method: 'S256',
|
||||
scope: scopes.join(' '),
|
||||
});
|
||||
|
||||
return `${this.baseUrl}/oauth/authorize?${params}`;
|
||||
}
|
||||
|
||||
// Step 1a: Start authorization with user's API key (server-side apps)
|
||||
async startAuthorizationWithApiKey(userApiKey, scopes) {
|
||||
const authUrl = await this.buildAuthorizationUrl(scopes);
|
||||
|
||||
// Request authorization with user's API key
|
||||
const response = await fetch(authUrl, {
|
||||
headers: {
|
||||
'X-API-Key': userApiKey,
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
throw new Error(`Authorization failed: ${error}`);
|
||||
}
|
||||
|
||||
// Return consent page HTML for rendering to user
|
||||
return await response.text();
|
||||
}
|
||||
|
||||
// Step 1b: Start authorization via login flow (browser-based apps)
|
||||
async startAuthorizationWithLogin(scopes) {
|
||||
const authUrl = await this.buildAuthorizationUrl(scopes);
|
||||
// Redirect user to AutoGPT login - they'll be redirected back after login
|
||||
window.location.href = authUrl;
|
||||
}
|
||||
|
||||
// Step 2: Exchange code for tokens (after user consents)
|
||||
async exchangeCode(code, state) {
|
||||
if (state !== this.pendingAuth?.state) {
|
||||
throw new Error('Invalid state');
|
||||
}
|
||||
|
||||
const response = await fetch(`${this.baseUrl}/oauth/token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: new URLSearchParams({
|
||||
grant_type: 'authorization_code',
|
||||
code,
|
||||
redirect_uri: this.redirectUri,
|
||||
client_id: this.clientId,
|
||||
client_secret: this.clientSecret,
|
||||
code_verifier: this.pendingAuth.codeVerifier,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(await response.text());
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
// Step 3: Request credential grant via popup
|
||||
requestGrant(provider, scopes) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const nonce = crypto.randomUUID();
|
||||
|
||||
const url = new URL(`${this.baseUrl}/connect/${provider}`);
|
||||
url.searchParams.set('client_id', this.clientId);
|
||||
url.searchParams.set('scopes', scopes.join(','));
|
||||
url.searchParams.set('nonce', nonce);
|
||||
url.searchParams.set('redirect_origin', window.location.origin);
|
||||
|
||||
const popup = window.open(url.toString(), 'connect', 'width=500,height=600');
|
||||
|
||||
const handler = (event) => {
|
||||
if (event.origin !== this.baseUrl) return;
|
||||
if (event.data?.type !== 'autogpt_connect_result') return;
|
||||
if (event.data?.nonce !== nonce) return;
|
||||
|
||||
window.removeEventListener('message', handler);
|
||||
|
||||
if (event.data.success) {
|
||||
resolve(event.data);
|
||||
} else {
|
||||
reject(new Error(event.data.error_description));
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('message', handler);
|
||||
});
|
||||
}
|
||||
|
||||
// Step 4: Execute agent
|
||||
async executeAgent(accessToken, agentId, inputs, options = {}) {
|
||||
const response = await fetch(
|
||||
`${this.baseUrl}/api/external/v1/executions/agents/${agentId}/execute`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Authorization': `Bearer ${accessToken}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
inputs,
|
||||
grant_ids: options.grantIds,
|
||||
webhook_url: options.webhookUrl,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(await response.text());
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
// Step 5: Poll for completion
|
||||
async waitForCompletion(accessToken, executionId, timeoutMs = 300000) {
|
||||
const startTime = Date.now();
|
||||
|
||||
while (Date.now() - startTime < timeoutMs) {
|
||||
const response = await fetch(
|
||||
`${this.baseUrl}/api/external/v1/executions/${executionId}`,
|
||||
{ headers: { 'Authorization': `Bearer ${accessToken}` } }
|
||||
);
|
||||
|
||||
const status = await response.json();
|
||||
|
||||
if (status.status === 'completed') {
|
||||
return status.outputs;
|
||||
}
|
||||
|
||||
if (status.status === 'failed') {
|
||||
throw new Error(status.error || 'Execution failed');
|
||||
}
|
||||
|
||||
// Wait before polling again
|
||||
await new Promise(resolve => setTimeout(resolve, 2000));
|
||||
}
|
||||
|
||||
throw new Error('Execution timeout');
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
generateCodeVerifier() {
|
||||
const array = new Uint8Array(32);
|
||||
crypto.getRandomValues(array);
|
||||
return this.base64UrlEncode(array);
|
||||
}
|
||||
|
||||
async generateCodeChallenge(verifier) {
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(verifier);
|
||||
const digest = await crypto.subtle.digest('SHA-256', data);
|
||||
return this.base64UrlEncode(new Uint8Array(digest));
|
||||
}
|
||||
|
||||
base64UrlEncode(buffer) {
|
||||
return btoa(String.fromCharCode(...buffer))
|
||||
.replace(/\+/g, '-')
|
||||
.replace(/\//g, '_')
|
||||
.replace(/=+$/, '');
|
||||
}
|
||||
}
|
||||
|
||||
// Usage
|
||||
const client = new AutoGPTClient(
|
||||
'app_abc123',
|
||||
'secret_xyz789',
|
||||
'https://myapp.com/oauth/callback'
|
||||
);
|
||||
|
||||
const scopes = ['openid', 'profile', 'agents:execute', 'integrations:connect'];
|
||||
|
||||
// Option A: Start authorization with user's API key (server-side apps)
|
||||
// User provides their AutoGPT API key (from Settings > Developer)
|
||||
const userApiKey = 'agpt_xxx...';
|
||||
const consentHtml = await client.startAuthorizationWithApiKey(userApiKey, scopes);
|
||||
// Render consent page HTML to user in popup/iframe
|
||||
|
||||
// Option B: Start authorization via login flow (browser-based apps)
|
||||
// No API key needed - user will log in to AutoGPT
|
||||
await client.startAuthorizationWithLogin(scopes);
|
||||
// User is redirected to AutoGPT login, then back to your callback
|
||||
|
||||
// 2. After user consents, your callback receives ?code=xxx&state=xxx
|
||||
// Exchange code for tokens
|
||||
const tokens = await client.exchangeCode(code, state);
|
||||
|
||||
// 3. Request Google credentials (user must be logged into AutoGPT in browser)
|
||||
const grant = await client.requestGrant('google', ['google:gmail.readonly']);
|
||||
|
||||
// 4. Execute an agent
|
||||
const execution = await client.executeAgent(
|
||||
tokens.access_token,
|
||||
'agent-uuid-here',
|
||||
{ query: 'Search my emails for invoices' },
|
||||
{ grantIds: [grant.grant_id] }
|
||||
);
|
||||
|
||||
// 5. Wait for results
|
||||
const outputs = await client.waitForCompletion(
|
||||
tokens.access_token,
|
||||
execution.execution_id
|
||||
);
|
||||
console.log('Agent outputs:', outputs);
|
||||
```
|
||||
|
||||
## Support
|
||||
|
||||
- [GitHub Issues](https://github.com/Significant-Gravitas/AutoGPT/issues) - Bug reports and feature requests
|
||||
- [Discord Community](https://discord.gg/autogpt) - Community support
|
||||
1251
docs/content/platform/integrations/nextjs.md
Normal file
1251
docs/content/platform/integrations/nextjs.md
Normal file
File diff suppressed because it is too large
Load Diff
1166
docs/content/platform/integrations/rails.md
Normal file
1166
docs/content/platform/integrations/rails.md
Normal file
File diff suppressed because it is too large
Load Diff
@@ -23,6 +23,10 @@ nav:
|
||||
- Using AI/ML API: platform/aimlapi.md
|
||||
- Using D-ID: platform/d_id.md
|
||||
- Blocks: platform/blocks/blocks.md
|
||||
- External API Integration: platform/external-api-integration.md
|
||||
- Framework Integrations:
|
||||
- Next.js: platform/integrations/nextjs.md
|
||||
- Ruby on Rails: platform/integrations/rails.md
|
||||
- Contributing:
|
||||
- Tests: platform/contributing/tests.md
|
||||
- OAuth Flows: platform/contributing/oauth-integration-flow.md
|
||||
|
||||
Reference in New Issue
Block a user