mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
add external auth flows
This commit is contained in:
327
autogpt_platform/backend/backend/data/credential_grants.py
Normal file
327
autogpt_platform/backend/backend/data/credential_grants.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""
|
||||
Credential Grant data layer.
|
||||
|
||||
Handles database operations for credential grants which allow OAuth clients
|
||||
to use credentials on behalf of users.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from prisma.enums import CredentialGrantPermission
|
||||
from prisma.models import CredentialGrant
|
||||
|
||||
from backend.data.db import prisma
|
||||
|
||||
|
||||
async def create_credential_grant(
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
credential_id: str,
|
||||
provider: str,
|
||||
granted_scopes: list[str],
|
||||
permissions: list[CredentialGrantPermission],
|
||||
expires_at: Optional[datetime] = None,
|
||||
) -> CredentialGrant:
|
||||
"""
|
||||
Create a new credential grant.
|
||||
|
||||
Args:
|
||||
user_id: ID of the user granting access
|
||||
client_id: Database ID of the OAuth client
|
||||
credential_id: ID of the credential being granted
|
||||
provider: Provider name (e.g., "google", "github")
|
||||
granted_scopes: List of integration scopes granted
|
||||
permissions: List of permissions (USE, DELETE)
|
||||
expires_at: Optional expiration datetime
|
||||
|
||||
Returns:
|
||||
Created CredentialGrant
|
||||
"""
|
||||
return await prisma.credentialgrant.create(
|
||||
data={
|
||||
"userId": user_id,
|
||||
"clientId": client_id,
|
||||
"credentialId": credential_id,
|
||||
"provider": provider,
|
||||
"grantedScopes": granted_scopes,
|
||||
"permissions": permissions,
|
||||
"expiresAt": expires_at,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def get_credential_grant(
|
||||
grant_id: str,
|
||||
user_id: Optional[str] = None,
|
||||
client_id: Optional[str] = None,
|
||||
) -> Optional[CredentialGrant]:
|
||||
"""
|
||||
Get a credential grant by ID.
|
||||
|
||||
Args:
|
||||
grant_id: Grant ID
|
||||
user_id: Optional user ID filter
|
||||
client_id: Optional client database ID filter
|
||||
|
||||
Returns:
|
||||
CredentialGrant or None
|
||||
"""
|
||||
where: dict[str, str] = {"id": grant_id}
|
||||
if user_id:
|
||||
where["userId"] = user_id
|
||||
if client_id:
|
||||
where["clientId"] = client_id
|
||||
|
||||
return await prisma.credentialgrant.find_first(where=where) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def get_grants_for_user_client(
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
include_revoked: bool = False,
|
||||
include_expired: bool = False,
|
||||
) -> list[CredentialGrant]:
|
||||
"""
|
||||
Get all credential grants for a user-client pair.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
client_id: Client database ID
|
||||
include_revoked: Include revoked grants
|
||||
include_expired: Include expired grants
|
||||
|
||||
Returns:
|
||||
List of CredentialGrant objects
|
||||
"""
|
||||
where: dict[str, str | None] = {
|
||||
"userId": user_id,
|
||||
"clientId": client_id,
|
||||
}
|
||||
|
||||
if not include_revoked:
|
||||
where["revokedAt"] = None
|
||||
|
||||
grants = await prisma.credentialgrant.find_many(
|
||||
where=where, # type: ignore[arg-type]
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
|
||||
# Filter expired if needed
|
||||
if not include_expired:
|
||||
now = datetime.now(timezone.utc)
|
||||
grants = [g for g in grants if g.expiresAt is None or g.expiresAt > now]
|
||||
|
||||
return grants
|
||||
|
||||
|
||||
async def get_grants_for_credential(
|
||||
user_id: str,
|
||||
credential_id: str,
|
||||
) -> list[CredentialGrant]:
|
||||
"""
|
||||
Get all active grants for a specific credential.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
credential_id: Credential ID
|
||||
|
||||
Returns:
|
||||
List of active CredentialGrant objects
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
grants = await prisma.credentialgrant.find_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"credentialId": credential_id,
|
||||
"revokedAt": None,
|
||||
},
|
||||
include={"Client": True},
|
||||
)
|
||||
|
||||
# Filter expired
|
||||
return [g for g in grants if g.expiresAt is None or g.expiresAt > now]
|
||||
|
||||
|
||||
async def get_grant_by_credential_and_client(
|
||||
user_id: str,
|
||||
credential_id: str,
|
||||
client_id: str,
|
||||
) -> Optional[CredentialGrant]:
|
||||
"""
|
||||
Get the grant for a specific credential and client.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
credential_id: Credential ID
|
||||
client_id: Client database ID
|
||||
|
||||
Returns:
|
||||
CredentialGrant or None
|
||||
"""
|
||||
return await prisma.credentialgrant.find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"credentialId": credential_id,
|
||||
"clientId": client_id,
|
||||
"revokedAt": None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def update_grant_scopes(
|
||||
grant_id: str,
|
||||
granted_scopes: list[str],
|
||||
) -> CredentialGrant:
|
||||
"""
|
||||
Update the granted scopes for a credential grant.
|
||||
|
||||
Args:
|
||||
grant_id: Grant ID
|
||||
granted_scopes: New list of granted scopes
|
||||
|
||||
Returns:
|
||||
Updated CredentialGrant
|
||||
"""
|
||||
result = await prisma.credentialgrant.update(
|
||||
where={"id": grant_id},
|
||||
data={"grantedScopes": granted_scopes},
|
||||
)
|
||||
if result is None:
|
||||
raise ValueError(f"Grant {grant_id} not found")
|
||||
return result
|
||||
|
||||
|
||||
async def update_grant_last_used(grant_id: str) -> None:
|
||||
"""
|
||||
Update the lastUsedAt timestamp for a grant.
|
||||
|
||||
Args:
|
||||
grant_id: Grant ID
|
||||
"""
|
||||
await prisma.credentialgrant.update(
|
||||
where={"id": grant_id},
|
||||
data={"lastUsedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
async def revoke_grant(grant_id: str) -> CredentialGrant:
|
||||
"""
|
||||
Revoke a credential grant.
|
||||
|
||||
Args:
|
||||
grant_id: Grant ID
|
||||
|
||||
Returns:
|
||||
Revoked CredentialGrant
|
||||
"""
|
||||
result = await prisma.credentialgrant.update(
|
||||
where={"id": grant_id},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
if result is None:
|
||||
raise ValueError(f"Grant {grant_id} not found")
|
||||
return result
|
||||
|
||||
|
||||
async def revoke_grants_for_credential(
|
||||
user_id: str,
|
||||
credential_id: str,
|
||||
) -> int:
|
||||
"""
|
||||
Revoke all grants for a specific credential.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
credential_id: Credential ID
|
||||
|
||||
Returns:
|
||||
Number of grants revoked
|
||||
"""
|
||||
return await prisma.credentialgrant.update_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"credentialId": credential_id,
|
||||
"revokedAt": None,
|
||||
},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
async def revoke_grants_for_client(
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
) -> int:
|
||||
"""
|
||||
Revoke all grants for a specific client.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
client_id: Client database ID
|
||||
|
||||
Returns:
|
||||
Number of grants revoked
|
||||
"""
|
||||
return await prisma.credentialgrant.update_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"clientId": client_id,
|
||||
"revokedAt": None,
|
||||
},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
async def delete_grant(grant_id: str) -> None:
|
||||
"""
|
||||
Permanently delete a credential grant.
|
||||
|
||||
Args:
|
||||
grant_id: Grant ID
|
||||
"""
|
||||
await prisma.credentialgrant.delete(where={"id": grant_id})
|
||||
|
||||
|
||||
async def check_grant_permission(
|
||||
grant_id: str,
|
||||
required_permission: CredentialGrantPermission,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a grant has a specific permission.
|
||||
|
||||
Args:
|
||||
grant_id: Grant ID
|
||||
required_permission: Permission to check
|
||||
|
||||
Returns:
|
||||
True if grant has the permission
|
||||
"""
|
||||
grant = await prisma.credentialgrant.find_unique(where={"id": grant_id})
|
||||
if not grant:
|
||||
return False
|
||||
|
||||
return required_permission in grant.permissions
|
||||
|
||||
|
||||
async def is_grant_valid(grant_id: str) -> bool:
|
||||
"""
|
||||
Check if a grant is valid (not revoked and not expired).
|
||||
|
||||
Args:
|
||||
grant_id: Grant ID
|
||||
|
||||
Returns:
|
||||
True if grant is valid
|
||||
"""
|
||||
grant = await prisma.credentialgrant.find_unique(where={"id": grant_id})
|
||||
if not grant:
|
||||
return False
|
||||
|
||||
if grant.revokedAt:
|
||||
return False
|
||||
|
||||
if grant.expiresAt and grant.expiresAt < datetime.now(timezone.utc):
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -71,6 +71,13 @@ logger = logging.getLogger(__name__)
|
||||
config = Config()
|
||||
|
||||
|
||||
class GrantResolverContext(BaseModel):
|
||||
"""Context for grant-based credential resolution in external API executions."""
|
||||
|
||||
client_db_id: str # The OAuth client database UUID
|
||||
grant_ids: list[str] # List of grant IDs to use for credential resolution
|
||||
|
||||
|
||||
class ExecutionContext(BaseModel):
|
||||
"""
|
||||
Unified context that carries execution-level data throughout the entire execution flow.
|
||||
@@ -81,6 +88,8 @@ class ExecutionContext(BaseModel):
|
||||
user_timezone: str = "UTC"
|
||||
root_execution_id: Optional[str] = None
|
||||
parent_execution_id: Optional[str] = None
|
||||
# For external API executions using credential grants
|
||||
grant_resolver_context: Optional[GrantResolverContext] = None
|
||||
|
||||
|
||||
# -------------------------- Models -------------------------- #
|
||||
|
||||
302
autogpt_platform/backend/backend/data/integration_scopes.py
Normal file
302
autogpt_platform/backend/backend/data/integration_scopes.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""
|
||||
Integration scopes mapping.
|
||||
|
||||
Maps AutoGPT's fine-grained integration scopes to provider-specific OAuth scopes.
|
||||
These scopes are used to request granular permissions when connecting integrations
|
||||
through the Credential Broker.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
|
||||
class IntegrationScope(str, Enum):
|
||||
"""
|
||||
Fine-grained integration scopes for credential grants.
|
||||
|
||||
Format: {provider}:{resource}.{permission}
|
||||
"""
|
||||
|
||||
# Google scopes
|
||||
GOOGLE_EMAIL_READ = "google:email.read"
|
||||
GOOGLE_GMAIL_READONLY = "google:gmail.readonly"
|
||||
GOOGLE_GMAIL_SEND = "google:gmail.send"
|
||||
GOOGLE_GMAIL_MODIFY = "google:gmail.modify"
|
||||
GOOGLE_DRIVE_READONLY = "google:drive.readonly"
|
||||
GOOGLE_DRIVE_FILE = "google:drive.file"
|
||||
GOOGLE_CALENDAR_READONLY = "google:calendar.readonly"
|
||||
GOOGLE_CALENDAR_EVENTS = "google:calendar.events"
|
||||
GOOGLE_SHEETS_READONLY = "google:sheets.readonly"
|
||||
GOOGLE_SHEETS = "google:sheets"
|
||||
GOOGLE_DOCS_READONLY = "google:docs.readonly"
|
||||
GOOGLE_DOCS = "google:docs"
|
||||
|
||||
# GitHub scopes
|
||||
GITHUB_REPOS_READ = "github:repos.read"
|
||||
GITHUB_REPOS_WRITE = "github:repos.write"
|
||||
GITHUB_ISSUES_READ = "github:issues.read"
|
||||
GITHUB_ISSUES_WRITE = "github:issues.write"
|
||||
GITHUB_USER_READ = "github:user.read"
|
||||
GITHUB_GISTS = "github:gists"
|
||||
GITHUB_NOTIFICATIONS = "github:notifications"
|
||||
|
||||
# Discord scopes
|
||||
DISCORD_IDENTIFY = "discord:identify"
|
||||
DISCORD_EMAIL = "discord:email"
|
||||
DISCORD_GUILDS = "discord:guilds"
|
||||
DISCORD_MESSAGES_READ = "discord:messages.read"
|
||||
|
||||
# Twitter scopes
|
||||
TWITTER_READ = "twitter:read"
|
||||
TWITTER_WRITE = "twitter:write"
|
||||
TWITTER_DM = "twitter:dm"
|
||||
|
||||
# Notion scopes
|
||||
NOTION_READ = "notion:read"
|
||||
NOTION_WRITE = "notion:write"
|
||||
|
||||
# Todoist scopes
|
||||
TODOIST_READ = "todoist:read"
|
||||
TODOIST_WRITE = "todoist:write"
|
||||
|
||||
|
||||
# Scope descriptions for consent UI
|
||||
INTEGRATION_SCOPE_DESCRIPTIONS: dict[str, str] = {
|
||||
# Google
|
||||
IntegrationScope.GOOGLE_EMAIL_READ.value: "Read your email address",
|
||||
IntegrationScope.GOOGLE_GMAIL_READONLY.value: "Read your Gmail messages",
|
||||
IntegrationScope.GOOGLE_GMAIL_SEND.value: "Send emails on your behalf",
|
||||
IntegrationScope.GOOGLE_GMAIL_MODIFY.value: "Read, send, and manage your emails",
|
||||
IntegrationScope.GOOGLE_DRIVE_READONLY.value: "View files in your Google Drive",
|
||||
IntegrationScope.GOOGLE_DRIVE_FILE.value: "Create and edit files in Google Drive",
|
||||
IntegrationScope.GOOGLE_CALENDAR_READONLY.value: "View your calendar",
|
||||
IntegrationScope.GOOGLE_CALENDAR_EVENTS.value: "Create and edit calendar events",
|
||||
IntegrationScope.GOOGLE_SHEETS_READONLY.value: "View your spreadsheets",
|
||||
IntegrationScope.GOOGLE_SHEETS.value: "Create and edit spreadsheets",
|
||||
IntegrationScope.GOOGLE_DOCS_READONLY.value: "View your documents",
|
||||
IntegrationScope.GOOGLE_DOCS.value: "Create and edit documents",
|
||||
# GitHub
|
||||
IntegrationScope.GITHUB_REPOS_READ.value: "Read repository information",
|
||||
IntegrationScope.GITHUB_REPOS_WRITE.value: "Create and manage repositories",
|
||||
IntegrationScope.GITHUB_ISSUES_READ.value: "Read issues and pull requests",
|
||||
IntegrationScope.GITHUB_ISSUES_WRITE.value: "Create and manage issues",
|
||||
IntegrationScope.GITHUB_USER_READ.value: "Read your GitHub profile",
|
||||
IntegrationScope.GITHUB_GISTS.value: "Create and manage gists",
|
||||
IntegrationScope.GITHUB_NOTIFICATIONS.value: "Access notifications",
|
||||
# Discord
|
||||
IntegrationScope.DISCORD_IDENTIFY.value: "Access your Discord username",
|
||||
IntegrationScope.DISCORD_EMAIL.value: "Access your Discord email",
|
||||
IntegrationScope.DISCORD_GUILDS.value: "View your server list",
|
||||
IntegrationScope.DISCORD_MESSAGES_READ.value: "Read messages",
|
||||
# Twitter
|
||||
IntegrationScope.TWITTER_READ.value: "Read tweets and profile",
|
||||
IntegrationScope.TWITTER_WRITE.value: "Post tweets on your behalf",
|
||||
IntegrationScope.TWITTER_DM.value: "Send and read direct messages",
|
||||
# Notion
|
||||
IntegrationScope.NOTION_READ.value: "View Notion pages",
|
||||
IntegrationScope.NOTION_WRITE.value: "Create and edit Notion pages",
|
||||
# Todoist
|
||||
IntegrationScope.TODOIST_READ.value: "View your tasks",
|
||||
IntegrationScope.TODOIST_WRITE.value: "Create and manage tasks",
|
||||
}
|
||||
|
||||
|
||||
# Mapping from integration scopes to provider OAuth scopes
|
||||
INTEGRATION_SCOPE_MAPPING: dict[str, dict[str, list[str]]] = {
|
||||
ProviderName.GOOGLE.value: {
|
||||
IntegrationScope.GOOGLE_EMAIL_READ.value: [
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"openid",
|
||||
],
|
||||
IntegrationScope.GOOGLE_GMAIL_READONLY.value: [
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
],
|
||||
IntegrationScope.GOOGLE_GMAIL_SEND.value: [
|
||||
"https://www.googleapis.com/auth/gmail.send",
|
||||
],
|
||||
IntegrationScope.GOOGLE_GMAIL_MODIFY.value: [
|
||||
"https://www.googleapis.com/auth/gmail.modify",
|
||||
],
|
||||
IntegrationScope.GOOGLE_DRIVE_READONLY.value: [
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
],
|
||||
IntegrationScope.GOOGLE_DRIVE_FILE.value: [
|
||||
"https://www.googleapis.com/auth/drive.file",
|
||||
],
|
||||
IntegrationScope.GOOGLE_CALENDAR_READONLY.value: [
|
||||
"https://www.googleapis.com/auth/calendar.readonly",
|
||||
],
|
||||
IntegrationScope.GOOGLE_CALENDAR_EVENTS.value: [
|
||||
"https://www.googleapis.com/auth/calendar.events",
|
||||
],
|
||||
IntegrationScope.GOOGLE_SHEETS_READONLY.value: [
|
||||
"https://www.googleapis.com/auth/spreadsheets.readonly",
|
||||
],
|
||||
IntegrationScope.GOOGLE_SHEETS.value: [
|
||||
"https://www.googleapis.com/auth/spreadsheets",
|
||||
],
|
||||
IntegrationScope.GOOGLE_DOCS_READONLY.value: [
|
||||
"https://www.googleapis.com/auth/documents.readonly",
|
||||
],
|
||||
IntegrationScope.GOOGLE_DOCS.value: [
|
||||
"https://www.googleapis.com/auth/documents",
|
||||
],
|
||||
},
|
||||
ProviderName.GITHUB.value: {
|
||||
IntegrationScope.GITHUB_REPOS_READ.value: [
|
||||
"repo:status",
|
||||
"public_repo",
|
||||
],
|
||||
IntegrationScope.GITHUB_REPOS_WRITE.value: [
|
||||
"repo",
|
||||
],
|
||||
IntegrationScope.GITHUB_ISSUES_READ.value: [
|
||||
"repo:status",
|
||||
],
|
||||
IntegrationScope.GITHUB_ISSUES_WRITE.value: [
|
||||
"repo",
|
||||
],
|
||||
IntegrationScope.GITHUB_USER_READ.value: [
|
||||
"read:user",
|
||||
"user:email",
|
||||
],
|
||||
IntegrationScope.GITHUB_GISTS.value: [
|
||||
"gist",
|
||||
],
|
||||
IntegrationScope.GITHUB_NOTIFICATIONS.value: [
|
||||
"notifications",
|
||||
],
|
||||
},
|
||||
ProviderName.DISCORD.value: {
|
||||
IntegrationScope.DISCORD_IDENTIFY.value: [
|
||||
"identify",
|
||||
],
|
||||
IntegrationScope.DISCORD_EMAIL.value: [
|
||||
"email",
|
||||
],
|
||||
IntegrationScope.DISCORD_GUILDS.value: [
|
||||
"guilds",
|
||||
],
|
||||
IntegrationScope.DISCORD_MESSAGES_READ.value: [
|
||||
"messages.read",
|
||||
],
|
||||
},
|
||||
ProviderName.TWITTER.value: {
|
||||
IntegrationScope.TWITTER_READ.value: [
|
||||
"tweet.read",
|
||||
"users.read",
|
||||
],
|
||||
IntegrationScope.TWITTER_WRITE.value: [
|
||||
"tweet.write",
|
||||
],
|
||||
IntegrationScope.TWITTER_DM.value: [
|
||||
"dm.read",
|
||||
"dm.write",
|
||||
],
|
||||
},
|
||||
ProviderName.NOTION.value: {
|
||||
IntegrationScope.NOTION_READ.value: [], # Notion uses workspace-level access
|
||||
IntegrationScope.NOTION_WRITE.value: [],
|
||||
},
|
||||
ProviderName.TODOIST.value: {
|
||||
IntegrationScope.TODOIST_READ.value: [
|
||||
"data:read",
|
||||
],
|
||||
IntegrationScope.TODOIST_WRITE.value: [
|
||||
"data:read_write",
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_provider_scopes(
|
||||
provider: ProviderName | str, integration_scopes: list[str]
|
||||
) -> list[str]:
|
||||
"""
|
||||
Convert integration scopes to provider-specific OAuth scopes.
|
||||
|
||||
Args:
|
||||
provider: The provider name
|
||||
integration_scopes: List of integration scope strings
|
||||
|
||||
Returns:
|
||||
List of provider-specific OAuth scopes
|
||||
"""
|
||||
provider_value = provider.value if isinstance(provider, ProviderName) else provider
|
||||
provider_mapping = INTEGRATION_SCOPE_MAPPING.get(provider_value, {})
|
||||
|
||||
oauth_scopes: set[str] = set()
|
||||
for scope in integration_scopes:
|
||||
if scope in provider_mapping:
|
||||
oauth_scopes.update(provider_mapping[scope])
|
||||
|
||||
return list(oauth_scopes)
|
||||
|
||||
|
||||
def get_provider_for_scope(scope: str) -> Optional[ProviderName]:
|
||||
"""
|
||||
Get the provider for an integration scope.
|
||||
|
||||
Args:
|
||||
scope: Integration scope string (e.g., "google:gmail.readonly")
|
||||
|
||||
Returns:
|
||||
ProviderName or None if not recognized
|
||||
"""
|
||||
if ":" not in scope:
|
||||
return None
|
||||
|
||||
provider_prefix = scope.split(":")[0]
|
||||
|
||||
# Map prefixes to providers
|
||||
prefix_mapping = {
|
||||
"google": ProviderName.GOOGLE,
|
||||
"github": ProviderName.GITHUB,
|
||||
"discord": ProviderName.DISCORD,
|
||||
"twitter": ProviderName.TWITTER,
|
||||
"notion": ProviderName.NOTION,
|
||||
"todoist": ProviderName.TODOIST,
|
||||
}
|
||||
|
||||
return prefix_mapping.get(provider_prefix)
|
||||
|
||||
|
||||
def validate_integration_scopes(scopes: list[str]) -> tuple[bool, list[str]]:
|
||||
"""
|
||||
Validate a list of integration scopes.
|
||||
|
||||
Args:
|
||||
scopes: List of integration scope strings
|
||||
|
||||
Returns:
|
||||
Tuple of (valid, invalid_scopes)
|
||||
"""
|
||||
valid_scopes = {s.value for s in IntegrationScope}
|
||||
invalid = [s for s in scopes if s not in valid_scopes]
|
||||
return len(invalid) == 0, invalid
|
||||
|
||||
|
||||
def group_scopes_by_provider(
|
||||
scopes: list[str],
|
||||
) -> dict[ProviderName, list[str]]:
|
||||
"""
|
||||
Group integration scopes by their provider.
|
||||
|
||||
Args:
|
||||
scopes: List of integration scope strings
|
||||
|
||||
Returns:
|
||||
Dictionary mapping providers to their scopes
|
||||
"""
|
||||
grouped: dict[ProviderName, list[str]] = {}
|
||||
|
||||
for scope in scopes:
|
||||
provider = get_provider_for_scope(scope)
|
||||
if provider:
|
||||
if provider not in grouped:
|
||||
grouped[provider] = []
|
||||
grouped[provider].append(scope)
|
||||
|
||||
return grouped
|
||||
176
autogpt_platform/backend/backend/data/oauth_audit.py
Normal file
176
autogpt_platform/backend/backend/data/oauth_audit.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
OAuth Audit Logging.
|
||||
|
||||
Logs all OAuth-related operations for security auditing and compliance.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from backend.data.db import prisma
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthEventType(str, Enum):
|
||||
"""Types of OAuth events to audit."""
|
||||
|
||||
# Client events
|
||||
CLIENT_REGISTERED = "client.registered"
|
||||
CLIENT_UPDATED = "client.updated"
|
||||
CLIENT_DELETED = "client.deleted"
|
||||
CLIENT_SECRET_ROTATED = "client.secret_rotated"
|
||||
CLIENT_SUSPENDED = "client.suspended"
|
||||
CLIENT_ACTIVATED = "client.activated"
|
||||
|
||||
# Authorization events
|
||||
AUTHORIZATION_REQUESTED = "authorization.requested"
|
||||
AUTHORIZATION_GRANTED = "authorization.granted"
|
||||
AUTHORIZATION_DENIED = "authorization.denied"
|
||||
AUTHORIZATION_REVOKED = "authorization.revoked"
|
||||
|
||||
# Token events
|
||||
TOKEN_ISSUED = "token.issued"
|
||||
TOKEN_REFRESHED = "token.refreshed"
|
||||
TOKEN_REVOKED = "token.revoked"
|
||||
TOKEN_EXPIRED = "token.expired"
|
||||
|
||||
# Grant events
|
||||
GRANT_CREATED = "grant.created"
|
||||
GRANT_UPDATED = "grant.updated"
|
||||
GRANT_REVOKED = "grant.revoked"
|
||||
GRANT_USED = "grant.used"
|
||||
|
||||
# Credential events
|
||||
CREDENTIAL_CONNECTED = "credential.connected"
|
||||
CREDENTIAL_DELETED = "credential.deleted"
|
||||
|
||||
# Execution events
|
||||
EXECUTION_STARTED = "execution.started"
|
||||
EXECUTION_COMPLETED = "execution.completed"
|
||||
EXECUTION_FAILED = "execution.failed"
|
||||
EXECUTION_CANCELLED = "execution.cancelled"
|
||||
|
||||
|
||||
async def log_oauth_event(
|
||||
event_type: OAuthEventType,
|
||||
user_id: Optional[str] = None,
|
||||
client_id: Optional[str] = None,
|
||||
grant_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
details: Optional[dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Log an OAuth audit event.
|
||||
|
||||
Args:
|
||||
event_type: Type of event
|
||||
user_id: User ID involved (if any)
|
||||
client_id: OAuth client ID involved (if any)
|
||||
grant_id: Grant ID involved (if any)
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
details: Additional event details
|
||||
|
||||
Returns:
|
||||
ID of the created audit log entry
|
||||
"""
|
||||
try:
|
||||
from prisma import Json
|
||||
|
||||
audit_entry = await prisma.oauthauditlog.create(
|
||||
data={
|
||||
"eventType": event_type.value,
|
||||
"userId": user_id,
|
||||
"clientId": client_id,
|
||||
"grantId": grant_id,
|
||||
"ipAddress": ip_address,
|
||||
"userAgent": user_agent,
|
||||
"details": Json(details or {}),
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"OAuth audit: {event_type.value} - "
|
||||
f"user={user_id}, client={client_id}, grant={grant_id}"
|
||||
)
|
||||
|
||||
return audit_entry.id
|
||||
|
||||
except Exception as e:
|
||||
# Log but don't fail the operation if audit logging fails
|
||||
logger.error(f"Failed to create OAuth audit log: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
async def get_audit_logs(
|
||||
user_id: Optional[str] = None,
|
||||
client_id: Optional[str] = None,
|
||||
event_type: Optional[OAuthEventType] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list:
|
||||
"""
|
||||
Query OAuth audit logs.
|
||||
|
||||
Args:
|
||||
user_id: Filter by user ID
|
||||
client_id: Filter by client ID
|
||||
event_type: Filter by event type
|
||||
start_date: Filter by start date
|
||||
end_date: Filter by end date
|
||||
limit: Maximum number of results
|
||||
offset: Offset for pagination
|
||||
|
||||
Returns:
|
||||
List of audit log entries
|
||||
"""
|
||||
where: dict[str, Any] = {}
|
||||
|
||||
if user_id:
|
||||
where["userId"] = user_id
|
||||
if client_id:
|
||||
where["clientId"] = client_id
|
||||
if event_type:
|
||||
where["eventType"] = event_type.value
|
||||
if start_date:
|
||||
where["createdAt"] = {"gte": start_date}
|
||||
if end_date:
|
||||
if "createdAt" in where:
|
||||
where["createdAt"]["lte"] = end_date
|
||||
else:
|
||||
where["createdAt"] = {"lte": end_date}
|
||||
|
||||
return await prisma.oauthauditlog.find_many(
|
||||
where=where if where else None, # type: ignore[arg-type]
|
||||
order={"createdAt": "desc"},
|
||||
take=limit,
|
||||
skip=offset,
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_old_audit_logs(days_to_keep: int = 90) -> int:
|
||||
"""
|
||||
Delete audit logs older than the specified number of days.
|
||||
|
||||
Args:
|
||||
days_to_keep: Number of days of logs to retain
|
||||
|
||||
Returns:
|
||||
Number of logs deleted
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
||||
|
||||
result = await prisma.oauthauditlog.delete_many(
|
||||
where={"createdAt": {"lt": cutoff_date}}
|
||||
)
|
||||
|
||||
logger.info(f"Cleaned up {result} OAuth audit logs older than {days_to_keep} days")
|
||||
return result
|
||||
@@ -221,11 +221,31 @@ async def execute_node(
|
||||
creds_locks: list[AsyncRedisLock] = []
|
||||
input_model = cast(type[BlockSchema], node_block.input_schema)
|
||||
|
||||
# Check if this is an external API execution using grant-based credential resolution
|
||||
grant_resolver = None
|
||||
if execution_context and execution_context.grant_resolver_context:
|
||||
from backend.integrations.grant_resolver import GrantBasedCredentialResolver
|
||||
|
||||
grant_ctx = execution_context.grant_resolver_context
|
||||
grant_resolver = GrantBasedCredentialResolver(
|
||||
user_id=user_id,
|
||||
client_id=grant_ctx.client_db_id,
|
||||
grant_ids=grant_ctx.grant_ids,
|
||||
)
|
||||
await grant_resolver.initialize()
|
||||
|
||||
# Handle regular credentials fields
|
||||
for field_name, input_type in input_model.get_credentials_fields().items():
|
||||
credentials_meta = input_type(**input_data[field_name])
|
||||
credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id)
|
||||
creds_locks.append(lock)
|
||||
if grant_resolver:
|
||||
# External API execution - use grant resolver (no locking needed)
|
||||
credentials = await grant_resolver.resolve_credential(credentials_meta.id)
|
||||
else:
|
||||
# Normal execution - use credentials manager with locking
|
||||
credentials, lock = await creds_manager.acquire(
|
||||
user_id, credentials_meta.id
|
||||
)
|
||||
creds_locks.append(lock)
|
||||
extra_exec_kwargs[field_name] = credentials
|
||||
|
||||
# Handle auto-generated credentials (e.g., from GoogleDriveFileInput)
|
||||
@@ -243,10 +263,17 @@ async def execute_node(
|
||||
)
|
||||
file_name = field_data.get("name", "selected file")
|
||||
try:
|
||||
credentials, lock = await creds_manager.acquire(
|
||||
user_id, cred_id
|
||||
)
|
||||
creds_locks.append(lock)
|
||||
if grant_resolver:
|
||||
# External API execution - use grant resolver
|
||||
credentials = await grant_resolver.resolve_credential(
|
||||
cred_id
|
||||
)
|
||||
else:
|
||||
# Normal execution - use credentials manager
|
||||
credentials, lock = await creds_manager.acquire(
|
||||
user_id, cred_id
|
||||
)
|
||||
creds_locks.append(lock)
|
||||
extra_exec_kwargs[kwarg_name] = credentials
|
||||
except ValueError:
|
||||
# Credential was deleted or doesn't exist
|
||||
|
||||
278
autogpt_platform/backend/backend/integrations/grant_resolver.py
Normal file
278
autogpt_platform/backend/backend/integrations/grant_resolver.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
Grant-Based Credential Resolver.
|
||||
|
||||
Resolves credentials during agent execution based on credential grants.
|
||||
External applications can only use credentials they have been granted access to,
|
||||
and only for the scopes that were granted.
|
||||
|
||||
Credentials are NEVER exposed to external applications - this resolver
|
||||
provides the credentials to the execution engine internally.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from prisma.enums import CredentialGrantPermission
|
||||
from prisma.models import CredentialGrant
|
||||
|
||||
from backend.data import credential_grants as grants_db
|
||||
from backend.data.db import prisma
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GrantValidationError(Exception):
|
||||
"""Raised when a grant is invalid or lacks required permissions."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CredentialNotFoundError(Exception):
|
||||
"""Raised when a credential referenced by a grant is not found."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ScopeMismatchError(Exception):
|
||||
"""Raised when the grant doesn't cover required scopes."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GrantBasedCredentialResolver:
|
||||
"""
|
||||
Resolves credentials for agent execution based on credential grants.
|
||||
|
||||
This resolver validates that:
|
||||
1. The grant exists and is valid (not revoked/expired)
|
||||
2. The grant has USE permission
|
||||
3. The grant covers the required scopes (if specified)
|
||||
4. The underlying credential exists
|
||||
|
||||
Then it provides the credential to the execution engine internally.
|
||||
The credential value is NEVER exposed to external applications.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
grant_ids: list[str],
|
||||
):
|
||||
"""
|
||||
Initialize the resolver.
|
||||
|
||||
Args:
|
||||
user_id: User ID who owns the credentials
|
||||
client_id: Database ID of the OAuth client
|
||||
grant_ids: List of grant IDs the client is using for this execution
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.client_id = client_id
|
||||
self.grant_ids = grant_ids
|
||||
self._grants: dict[str, CredentialGrant] = {}
|
||||
self._credentials_manager = IntegrationCredentialsManager()
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""
|
||||
Load and validate all grants.
|
||||
|
||||
This should be called before any credential resolution.
|
||||
|
||||
Raises:
|
||||
GrantValidationError: If any grant is invalid
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
for grant_id in self.grant_ids:
|
||||
grant = await grants_db.get_credential_grant(
|
||||
grant_id=grant_id,
|
||||
user_id=self.user_id,
|
||||
client_id=self.client_id,
|
||||
)
|
||||
|
||||
if not grant:
|
||||
raise GrantValidationError(f"Grant {grant_id} not found")
|
||||
|
||||
# Check if revoked
|
||||
if grant.revokedAt:
|
||||
raise GrantValidationError(f"Grant {grant_id} has been revoked")
|
||||
|
||||
# Check if expired
|
||||
if grant.expiresAt and grant.expiresAt < now:
|
||||
raise GrantValidationError(f"Grant {grant_id} has expired")
|
||||
|
||||
# Check USE permission
|
||||
if CredentialGrantPermission.USE not in grant.permissions:
|
||||
raise GrantValidationError(
|
||||
f"Grant {grant_id} does not have USE permission"
|
||||
)
|
||||
|
||||
self._grants[grant_id] = grant
|
||||
|
||||
self._initialized = True
|
||||
logger.info(
|
||||
f"Initialized grant resolver with {len(self._grants)} grants "
|
||||
f"for user {self.user_id}, client {self.client_id}"
|
||||
)
|
||||
|
||||
async def resolve_credential(
|
||||
self,
|
||||
credential_id: str,
|
||||
required_scopes: Optional[list[str]] = None,
|
||||
) -> Credentials:
|
||||
"""
|
||||
Resolve a credential for agent execution.
|
||||
|
||||
This method:
|
||||
1. Finds a grant that covers this credential
|
||||
2. Validates the grant covers required scopes
|
||||
3. Retrieves the actual credential
|
||||
4. Updates grant usage tracking
|
||||
|
||||
Args:
|
||||
credential_id: ID of the credential to resolve
|
||||
required_scopes: Optional list of scopes the credential must have
|
||||
|
||||
Returns:
|
||||
The resolved Credentials object
|
||||
|
||||
Raises:
|
||||
GrantValidationError: If no valid grant covers this credential
|
||||
ScopeMismatchError: If the grant doesn't cover required scopes
|
||||
CredentialNotFoundError: If the underlying credential doesn't exist
|
||||
"""
|
||||
if not self._initialized:
|
||||
raise RuntimeError("Resolver not initialized. Call initialize() first.")
|
||||
|
||||
# Find a grant that covers this credential
|
||||
matching_grant: Optional[CredentialGrant] = None
|
||||
for grant in self._grants.values():
|
||||
if grant.credentialId == credential_id:
|
||||
matching_grant = grant
|
||||
break
|
||||
|
||||
if not matching_grant:
|
||||
raise GrantValidationError(f"No grant found for credential {credential_id}")
|
||||
|
||||
# Validate scopes if required
|
||||
if required_scopes:
|
||||
granted_scopes = set(matching_grant.grantedScopes)
|
||||
required_scopes_set = set(required_scopes)
|
||||
|
||||
missing_scopes = required_scopes_set - granted_scopes
|
||||
if missing_scopes:
|
||||
raise ScopeMismatchError(
|
||||
f"Grant {matching_grant.id} is missing required scopes: "
|
||||
f"{', '.join(missing_scopes)}"
|
||||
)
|
||||
|
||||
# Get the actual credential
|
||||
credentials = await self._credentials_manager.get(
|
||||
user_id=self.user_id,
|
||||
credentials_id=credential_id,
|
||||
lock=True,
|
||||
)
|
||||
|
||||
if not credentials:
|
||||
raise CredentialNotFoundError(
|
||||
f"Credential {credential_id} not found for user {self.user_id}"
|
||||
)
|
||||
|
||||
# Update last used timestamp for the grant
|
||||
await grants_db.update_grant_last_used(matching_grant.id)
|
||||
|
||||
logger.debug(
|
||||
f"Resolved credential {credential_id} via grant {matching_grant.id} "
|
||||
f"for client {self.client_id}"
|
||||
)
|
||||
|
||||
return credentials
|
||||
|
||||
async def get_available_credentials(self) -> list[dict]:
|
||||
"""
|
||||
Get list of available credentials based on grants.
|
||||
|
||||
Returns a list of credential metadata (NOT the actual credential values).
|
||||
|
||||
Returns:
|
||||
List of dicts with credential metadata
|
||||
"""
|
||||
if not self._initialized:
|
||||
raise RuntimeError("Resolver not initialized. Call initialize() first.")
|
||||
|
||||
credentials_info = []
|
||||
for grant in self._grants.values():
|
||||
credentials_info.append(
|
||||
{
|
||||
"grant_id": grant.id,
|
||||
"credential_id": grant.credentialId,
|
||||
"provider": grant.provider,
|
||||
"granted_scopes": grant.grantedScopes,
|
||||
}
|
||||
)
|
||||
|
||||
return credentials_info
|
||||
|
||||
def get_grant_for_credential(self, credential_id: str) -> Optional[CredentialGrant]:
|
||||
"""
|
||||
Get the grant for a specific credential.
|
||||
|
||||
Args:
|
||||
credential_id: ID of the credential
|
||||
|
||||
Returns:
|
||||
CredentialGrant or None if not found
|
||||
"""
|
||||
for grant in self._grants.values():
|
||||
if grant.credentialId == credential_id:
|
||||
return grant
|
||||
return None
|
||||
|
||||
|
||||
async def create_resolver_from_oauth_token(
|
||||
user_id: str,
|
||||
client_public_id: str,
|
||||
grant_ids: Optional[list[str]] = None,
|
||||
) -> GrantBasedCredentialResolver:
|
||||
"""
|
||||
Create a credential resolver from OAuth token context.
|
||||
|
||||
This is a convenience function for creating a resolver from
|
||||
the context available in OAuth-authenticated requests.
|
||||
|
||||
Args:
|
||||
user_id: User ID from the OAuth token
|
||||
client_public_id: Public client ID from the OAuth token
|
||||
grant_ids: Optional list of grant IDs to use
|
||||
|
||||
Returns:
|
||||
Initialized GrantBasedCredentialResolver
|
||||
"""
|
||||
# Look up the OAuth client database ID from the public client ID
|
||||
client = await prisma.oauthclient.find_unique(where={"clientId": client_public_id})
|
||||
if not client:
|
||||
raise GrantValidationError(f"OAuth client {client_public_id} not found")
|
||||
|
||||
# If no grant IDs specified, get all grants for this client+user
|
||||
if grant_ids is None:
|
||||
grants = await grants_db.get_grants_for_user_client(
|
||||
user_id=user_id,
|
||||
client_id=client.id,
|
||||
include_revoked=False,
|
||||
include_expired=False,
|
||||
)
|
||||
grant_ids = [g.id for g in grants]
|
||||
|
||||
resolver = GrantBasedCredentialResolver(
|
||||
user_id=user_id,
|
||||
client_id=client.id,
|
||||
grant_ids=grant_ids,
|
||||
)
|
||||
await resolver.initialize()
|
||||
|
||||
return resolver
|
||||
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
Webhook Notification System for External API.
|
||||
|
||||
Sends webhook notifications to external applications for execution events.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import weakref
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Coroutine, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Webhook delivery settings
|
||||
WEBHOOK_TIMEOUT_SECONDS = 30
|
||||
WEBHOOK_MAX_RETRIES = 3
|
||||
WEBHOOK_RETRY_DELAYS = [5, 30, 300] # seconds: 5s, 30s, 5min
|
||||
|
||||
|
||||
class WebhookDeliveryError(Exception):
|
||||
"""Raised when webhook delivery fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def sign_webhook_payload(payload: dict[str, Any], secret: str) -> str:
|
||||
"""
|
||||
Create HMAC-SHA256 signature for webhook payload.
|
||||
|
||||
Args:
|
||||
payload: The webhook payload to sign
|
||||
secret: The webhook secret key
|
||||
|
||||
Returns:
|
||||
Hex-encoded HMAC-SHA256 signature
|
||||
"""
|
||||
payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode()
|
||||
signature = hmac.new(
|
||||
secret.encode(),
|
||||
payload_bytes,
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
return signature
|
||||
|
||||
|
||||
def verify_webhook_signature(
|
||||
payload: dict[str, Any],
|
||||
signature: str,
|
||||
secret: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Verify a webhook signature.
|
||||
|
||||
Args:
|
||||
payload: The webhook payload
|
||||
signature: The signature to verify
|
||||
secret: The webhook secret key
|
||||
|
||||
Returns:
|
||||
True if signature is valid
|
||||
"""
|
||||
expected = sign_webhook_payload(payload, secret)
|
||||
return hmac.compare_digest(expected, signature)
|
||||
|
||||
|
||||
def validate_webhook_url(url: str, allowed_domains: list[str]) -> bool:
|
||||
"""
|
||||
Validate that a webhook URL is allowed.
|
||||
|
||||
Args:
|
||||
url: The webhook URL to validate
|
||||
allowed_domains: List of allowed domains (from OAuth client config)
|
||||
|
||||
Returns:
|
||||
True if URL is valid and allowed
|
||||
"""
|
||||
from backend.util.url import hostname_matches_any_domain
|
||||
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
|
||||
# Must be HTTPS (except for localhost in development)
|
||||
if parsed.scheme != "https":
|
||||
if not (
|
||||
parsed.scheme == "http"
|
||||
and parsed.hostname in ["localhost", "127.0.0.1"]
|
||||
):
|
||||
return False
|
||||
|
||||
# Must have a host
|
||||
if not parsed.hostname:
|
||||
return False
|
||||
|
||||
# Check against allowed domains
|
||||
return hostname_matches_any_domain(parsed.hostname, allowed_domains)
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
async def send_webhook(
|
||||
url: str,
|
||||
payload: dict[str, Any],
|
||||
secret: Optional[str] = None,
|
||||
timeout: int = WEBHOOK_TIMEOUT_SECONDS,
|
||||
) -> bool:
|
||||
"""
|
||||
Send a webhook notification.
|
||||
|
||||
Args:
|
||||
url: Webhook URL
|
||||
payload: Payload to send
|
||||
secret: Optional secret for signature
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
True if webhook was delivered successfully
|
||||
"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "AutoGPT-Webhook/1.0",
|
||||
"X-Webhook-Timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
if secret:
|
||||
signature = sign_webhook_payload(payload, secret)
|
||||
headers["X-Webhook-Signature"] = f"sha256={signature}"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if response.status_code >= 200 and response.status_code < 300:
|
||||
logger.debug(f"Webhook delivered successfully to {url}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
f"Webhook delivery failed: {url} returned {response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.warning(f"Webhook delivery timed out: {url}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Webhook delivery error: {url} - {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def send_webhook_with_retry(
|
||||
url: str,
|
||||
payload: dict[str, Any],
|
||||
secret: Optional[str] = None,
|
||||
max_retries: int = WEBHOOK_MAX_RETRIES,
|
||||
) -> bool:
|
||||
"""
|
||||
Send a webhook with automatic retries.
|
||||
|
||||
Args:
|
||||
url: Webhook URL
|
||||
payload: Payload to send
|
||||
secret: Optional secret for signature
|
||||
max_retries: Maximum number of retry attempts
|
||||
|
||||
Returns:
|
||||
True if webhook was eventually delivered successfully
|
||||
"""
|
||||
for attempt in range(max_retries + 1):
|
||||
if await send_webhook(url, payload, secret):
|
||||
return True
|
||||
|
||||
if attempt < max_retries:
|
||||
delay = WEBHOOK_RETRY_DELAYS[min(attempt, len(WEBHOOK_RETRY_DELAYS) - 1)]
|
||||
logger.info(
|
||||
f"Webhook delivery failed, retrying in {delay}s (attempt {attempt + 1})"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
logger.error(f"Webhook delivery failed after {max_retries} retries: {url}")
|
||||
return False
|
||||
|
||||
|
||||
# Track pending webhook tasks to prevent garbage collection
|
||||
# Using WeakSet so tasks are automatically removed when they complete and are dereferenced
|
||||
_pending_webhook_tasks: weakref.WeakSet[asyncio.Task[Any]] = weakref.WeakSet()
|
||||
|
||||
|
||||
def _create_tracked_task(coro: Coroutine[Any, Any, bool]) -> asyncio.Task[bool]:
|
||||
"""Create a task that is tracked to prevent garbage collection."""
|
||||
task = asyncio.create_task(coro)
|
||||
_pending_webhook_tasks.add(task)
|
||||
# No explicit done callback needed - WeakSet automatically removes
|
||||
# references when tasks are garbage collected after completion
|
||||
return task
|
||||
|
||||
|
||||
class WebhookNotifier:
|
||||
"""
|
||||
Service for sending webhook notifications to external applications.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def notify_execution_started(
|
||||
self,
|
||||
execution_id: str,
|
||||
agent_id: str,
|
||||
client_id: str,
|
||||
webhook_url: str,
|
||||
webhook_secret: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Notify external app that an execution has started.
|
||||
"""
|
||||
payload = {
|
||||
"event": "execution.started",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"data": {
|
||||
"execution_id": execution_id,
|
||||
"agent_id": agent_id,
|
||||
"status": "running",
|
||||
},
|
||||
}
|
||||
|
||||
_create_tracked_task(
|
||||
send_webhook_with_retry(webhook_url, payload, webhook_secret)
|
||||
)
|
||||
|
||||
async def notify_execution_completed(
|
||||
self,
|
||||
execution_id: str,
|
||||
agent_id: str,
|
||||
client_id: str,
|
||||
webhook_url: str,
|
||||
outputs: dict[str, Any],
|
||||
webhook_secret: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Notify external app that an execution has completed successfully.
|
||||
"""
|
||||
payload = {
|
||||
"event": "execution.completed",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"data": {
|
||||
"execution_id": execution_id,
|
||||
"agent_id": agent_id,
|
||||
"status": "completed",
|
||||
"outputs": outputs,
|
||||
},
|
||||
}
|
||||
|
||||
_create_tracked_task(
|
||||
send_webhook_with_retry(webhook_url, payload, webhook_secret)
|
||||
)
|
||||
|
||||
async def notify_execution_failed(
|
||||
self,
|
||||
execution_id: str,
|
||||
agent_id: str,
|
||||
client_id: str,
|
||||
webhook_url: str,
|
||||
error: str,
|
||||
webhook_secret: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Notify external app that an execution has failed.
|
||||
"""
|
||||
payload = {
|
||||
"event": "execution.failed",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"data": {
|
||||
"execution_id": execution_id,
|
||||
"agent_id": agent_id,
|
||||
"status": "failed",
|
||||
"error": error,
|
||||
},
|
||||
}
|
||||
|
||||
_create_tracked_task(
|
||||
send_webhook_with_retry(webhook_url, payload, webhook_secret)
|
||||
)
|
||||
|
||||
async def notify_grant_revoked(
|
||||
self,
|
||||
grant_id: str,
|
||||
credential_id: str,
|
||||
provider: str,
|
||||
client_id: str,
|
||||
webhook_url: str,
|
||||
webhook_secret: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Notify external app that a credential grant has been revoked.
|
||||
"""
|
||||
payload = {
|
||||
"event": "grant.revoked",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"data": {
|
||||
"grant_id": grant_id,
|
||||
"credential_id": credential_id,
|
||||
"provider": provider,
|
||||
},
|
||||
}
|
||||
|
||||
_create_tracked_task(
|
||||
send_webhook_with_retry(webhook_url, payload, webhook_secret)
|
||||
)
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
_webhook_notifier: Optional[WebhookNotifier] = None
|
||||
|
||||
|
||||
def get_webhook_notifier() -> WebhookNotifier:
|
||||
"""Get the singleton webhook notifier instance."""
|
||||
global _webhook_notifier
|
||||
if _webhook_notifier is None:
|
||||
_webhook_notifier = WebhookNotifier()
|
||||
return _webhook_notifier
|
||||
@@ -3,6 +3,8 @@ from fastapi import FastAPI
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
|
||||
from .routes.execution import execution_router
|
||||
from .routes.grants import grants_router
|
||||
from .routes.integrations import integrations_router
|
||||
from .routes.tools import tools_router
|
||||
from .routes.v1 import v1_router
|
||||
@@ -18,6 +20,8 @@ external_app.add_middleware(SecurityHeadersMiddleware)
|
||||
external_app.include_router(v1_router, prefix="/v1")
|
||||
external_app.include_router(tools_router, prefix="/v1")
|
||||
external_app.include_router(integrations_router, prefix="/v1")
|
||||
external_app.include_router(grants_router, prefix="/v1")
|
||||
external_app.include_router(execution_router, prefix="/v1")
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
|
||||
164
autogpt_platform/backend/backend/server/external/oauth_middleware.py
vendored
Normal file
164
autogpt_platform/backend/backend/server/external/oauth_middleware.py
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
OAuth Access Token middleware for external API.
|
||||
|
||||
Validates OAuth access tokens and provides user/client context
|
||||
for external API endpoints that use OAuth authentication.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import jwt
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.db import prisma
|
||||
from backend.server.oauth.token_service import get_token_service
|
||||
|
||||
|
||||
class OAuthTokenInfo(BaseModel):
|
||||
"""Information extracted from a validated OAuth access token."""
|
||||
|
||||
user_id: str
|
||||
client_id: str
|
||||
scopes: list[str]
|
||||
token_id: str
|
||||
|
||||
|
||||
# HTTP Bearer token extractor
|
||||
oauth_bearer = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def require_oauth_token(
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Security(oauth_bearer),
|
||||
) -> OAuthTokenInfo:
|
||||
"""
|
||||
Validate an OAuth access token and return token info.
|
||||
|
||||
Extracts the Bearer token from the Authorization header,
|
||||
validates the JWT signature and claims, and checks that
|
||||
the token hasn't been revoked.
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if token is missing, invalid, or revoked
|
||||
"""
|
||||
if credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing authorization token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
token = credentials.credentials
|
||||
token_service = get_token_service()
|
||||
|
||||
try:
|
||||
# Verify JWT signature and claims
|
||||
claims = token_service.verify_access_token(token)
|
||||
|
||||
# Check if token is in database and not revoked
|
||||
token_hash = token_service.hash_token(token)
|
||||
stored_token = await prisma.oauthaccesstoken.find_unique(
|
||||
where={"tokenHash": token_hash}
|
||||
)
|
||||
|
||||
if not stored_token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Token not found",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if stored_token.revokedAt:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Token has been revoked",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if stored_token.expiresAt < datetime.now(timezone.utc):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Token has expired",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Update last used timestamp (fire and forget)
|
||||
await prisma.oauthaccesstoken.update(
|
||||
where={"id": stored_token.id},
|
||||
data={"lastUsedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
return OAuthTokenInfo(
|
||||
user_id=claims.sub,
|
||||
client_id=claims.client_id,
|
||||
scopes=claims.scope.split() if claims.scope else [],
|
||||
token_id=stored_token.id,
|
||||
)
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Token has expired",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Invalid token: {str(e)}",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def require_scope(required_scope: str):
|
||||
"""
|
||||
Dependency that validates OAuth token and checks for required scope.
|
||||
|
||||
Args:
|
||||
required_scope: The scope required for this endpoint
|
||||
|
||||
Returns:
|
||||
Dependency function that returns OAuthTokenInfo if authorized
|
||||
"""
|
||||
|
||||
async def check_scope(
|
||||
token: OAuthTokenInfo = Security(require_oauth_token),
|
||||
) -> OAuthTokenInfo:
|
||||
if required_scope not in token.scopes:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Token lacks required scope '{required_scope}'",
|
||||
headers={"WWW-Authenticate": f'Bearer scope="{required_scope}"'},
|
||||
)
|
||||
return token
|
||||
|
||||
return check_scope
|
||||
|
||||
|
||||
def require_any_scope(*required_scopes: str):
|
||||
"""
|
||||
Dependency that validates OAuth token and checks for any of the required scopes.
|
||||
|
||||
Args:
|
||||
required_scopes: At least one of these scopes is required
|
||||
|
||||
Returns:
|
||||
Dependency function that returns OAuthTokenInfo if authorized
|
||||
"""
|
||||
|
||||
async def check_scopes(
|
||||
token: OAuthTokenInfo = Security(require_oauth_token),
|
||||
) -> OAuthTokenInfo:
|
||||
for scope in required_scopes:
|
||||
if scope in token.scopes:
|
||||
return token
|
||||
|
||||
scope_list = " ".join(required_scopes)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Token lacks required scopes (need one of: {scope_list})",
|
||||
headers={"WWW-Authenticate": f'Bearer scope="{scope_list}"'},
|
||||
)
|
||||
|
||||
return check_scopes
|
||||
376
autogpt_platform/backend/backend/server/external/routes/execution.py
vendored
Normal file
376
autogpt_platform/backend/backend/server/external/routes/execution.py
vendored
Normal file
@@ -0,0 +1,376 @@
|
||||
"""
|
||||
Agent Execution endpoints for external OAuth clients.
|
||||
|
||||
Allows external applications to:
|
||||
- Execute agents using granted credentials
|
||||
- Poll execution status
|
||||
- Cancel running executions
|
||||
- Get available capabilities
|
||||
|
||||
External apps can only use credentials they have been granted access to.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Security
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.db import prisma
|
||||
from backend.data.execution import ExecutionContext, GrantResolverContext
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.grant_resolver import (
|
||||
GrantValidationError,
|
||||
create_resolver_from_oauth_token,
|
||||
)
|
||||
from backend.integrations.webhook_notifier import validate_webhook_url
|
||||
from backend.server.external.oauth_middleware import OAuthTokenInfo, require_scope
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
execution_router = APIRouter(prefix="/executions", tags=["executions"])
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request/Response Models
|
||||
# ================================================================
|
||||
|
||||
|
||||
class ExecuteAgentRequest(BaseModel):
|
||||
"""Request to execute an agent."""
|
||||
|
||||
inputs: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Input values for the agent",
|
||||
)
|
||||
grant_ids: Optional[list[str]] = Field(
|
||||
default=None,
|
||||
description="Specific grant IDs to use. If not provided, uses all available grants.",
|
||||
)
|
||||
webhook_url: Optional[str] = Field(
|
||||
default=None,
|
||||
description="URL to receive execution status webhooks",
|
||||
)
|
||||
|
||||
|
||||
class ExecuteAgentResponse(BaseModel):
|
||||
"""Response from starting an agent execution."""
|
||||
|
||||
execution_id: str
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
class ExecutionStatusResponse(BaseModel):
|
||||
"""Response with execution status."""
|
||||
|
||||
execution_id: str
|
||||
status: str
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class GrantInfo(BaseModel):
|
||||
"""Summary of a credential grant for capabilities."""
|
||||
|
||||
grant_id: str
|
||||
provider: str
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
class CapabilitiesResponse(BaseModel):
|
||||
"""Response describing what the client can do."""
|
||||
|
||||
user_id: str
|
||||
client_id: str
|
||||
grants: list[GrantInfo]
|
||||
available_scopes: list[str]
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
|
||||
@execution_router.get("/capabilities", response_model=CapabilitiesResponse)
|
||||
async def get_capabilities(
|
||||
token: OAuthTokenInfo = Security(require_scope("agents:execute")),
|
||||
) -> CapabilitiesResponse:
|
||||
"""
|
||||
Get the capabilities available to this client for the authenticated user.
|
||||
|
||||
Returns information about:
|
||||
- Available credential grants (NOT credential values)
|
||||
- Scopes the client has access to
|
||||
"""
|
||||
try:
|
||||
resolver = await create_resolver_from_oauth_token(
|
||||
user_id=token.user_id,
|
||||
client_public_id=token.client_id,
|
||||
)
|
||||
credentials_info = await resolver.get_available_credentials()
|
||||
|
||||
grants = [
|
||||
GrantInfo(
|
||||
grant_id=info["grant_id"],
|
||||
provider=info["provider"],
|
||||
scopes=info["granted_scopes"],
|
||||
)
|
||||
for info in credentials_info
|
||||
]
|
||||
|
||||
return CapabilitiesResponse(
|
||||
user_id=token.user_id,
|
||||
client_id=token.client_id,
|
||||
grants=grants,
|
||||
available_scopes=token.scopes,
|
||||
)
|
||||
except GrantValidationError:
|
||||
# No grants available is not an error, just empty capabilities
|
||||
return CapabilitiesResponse(
|
||||
user_id=token.user_id,
|
||||
client_id=token.client_id,
|
||||
grants=[],
|
||||
available_scopes=token.scopes,
|
||||
)
|
||||
|
||||
|
||||
@execution_router.post(
|
||||
"/agents/{agent_id}/execute",
|
||||
response_model=ExecuteAgentResponse,
|
||||
)
|
||||
async def execute_agent(
|
||||
agent_id: str,
|
||||
request: ExecuteAgentRequest,
|
||||
token: OAuthTokenInfo = Security(require_scope("agents:execute")),
|
||||
) -> ExecuteAgentResponse:
|
||||
"""
|
||||
Execute an agent using granted credentials.
|
||||
|
||||
The agent must be accessible to the user, and the client must have
|
||||
valid credential grants that satisfy the agent's requirements.
|
||||
|
||||
Args:
|
||||
agent_id: The agent (graph) ID to execute
|
||||
request: Execution parameters including inputs and optional grant IDs
|
||||
"""
|
||||
# Verify the agent exists and user has access
|
||||
# First try to get the latest version
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=agent_id,
|
||||
version=None,
|
||||
user_id=token.user_id,
|
||||
)
|
||||
|
||||
if not graph:
|
||||
# Try to find it in the store (public agents)
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=agent_id,
|
||||
version=None,
|
||||
user_id=None,
|
||||
skip_access_check=True,
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Agent {agent_id} not found or not accessible",
|
||||
)
|
||||
|
||||
# Initialize the grant resolver to validate grants exist
|
||||
# The resolver context will be passed to the execution engine
|
||||
grant_resolver_context = None
|
||||
try:
|
||||
resolver = await create_resolver_from_oauth_token(
|
||||
user_id=token.user_id,
|
||||
client_public_id=token.client_id,
|
||||
grant_ids=request.grant_ids,
|
||||
)
|
||||
# Get available credentials info to build resolver context
|
||||
credentials_info = await resolver.get_available_credentials()
|
||||
grant_resolver_context = GrantResolverContext(
|
||||
client_db_id=resolver.client_id,
|
||||
grant_ids=[c["grant_id"] for c in credentials_info],
|
||||
)
|
||||
except GrantValidationError as e:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Grant validation failed: {str(e)}",
|
||||
)
|
||||
|
||||
try:
|
||||
# Build execution context with grant resolver info
|
||||
execution_context = ExecutionContext(
|
||||
grant_resolver_context=grant_resolver_context,
|
||||
)
|
||||
|
||||
# Execute the agent with grant resolver context
|
||||
graph_exec = await add_graph_execution(
|
||||
graph_id=agent_id,
|
||||
user_id=token.user_id,
|
||||
inputs=request.inputs,
|
||||
graph_version=graph.version,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
|
||||
# Log the execution for audit
|
||||
logger.info(
|
||||
f"External execution started: agent={agent_id}, "
|
||||
f"execution={graph_exec.id}, client={token.client_id}, "
|
||||
f"user={token.user_id}"
|
||||
)
|
||||
|
||||
# Register webhook if provided
|
||||
if request.webhook_url:
|
||||
# Get client to check webhook domains
|
||||
client = await prisma.oauthclient.find_unique(
|
||||
where={"clientId": token.client_id}
|
||||
)
|
||||
if client:
|
||||
if not validate_webhook_url(request.webhook_url, client.webhookDomains):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Webhook URL not in allowed domains for this client",
|
||||
)
|
||||
|
||||
# Store webhook registration
|
||||
await prisma.executionwebhook.create(
|
||||
data={ # type: ignore[typeddict-item]
|
||||
"executionId": graph_exec.id,
|
||||
"webhookUrl": request.webhook_url,
|
||||
"clientId": client.id,
|
||||
"userId": token.user_id,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Registered webhook for execution {graph_exec.id}: {request.webhook_url}"
|
||||
)
|
||||
|
||||
return ExecuteAgentResponse(
|
||||
execution_id=graph_exec.id,
|
||||
status="queued",
|
||||
message="Agent execution has been queued",
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
# Client error - invalid input or configuration
|
||||
logger.warning(
|
||||
f"Invalid execution request: agent={agent_id}, "
|
||||
f"client={token.client_id}, error={str(e)}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid request: {str(e)}",
|
||||
)
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions as-is
|
||||
raise
|
||||
except Exception:
|
||||
# Server error - log full exception but don't expose details to client
|
||||
logger.exception(
|
||||
f"Unexpected error starting execution: agent={agent_id}, "
|
||||
f"client={token.client_id}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An internal error occurred while starting execution",
|
||||
)
|
||||
|
||||
|
||||
@execution_router.get(
|
||||
"/{execution_id}",
|
||||
response_model=ExecutionStatusResponse,
|
||||
)
|
||||
async def get_execution_status(
|
||||
execution_id: str,
|
||||
token: OAuthTokenInfo = Security(require_scope("agents:execute")),
|
||||
) -> ExecutionStatusResponse:
|
||||
"""
|
||||
Get the status of an agent execution.
|
||||
|
||||
Returns current status, outputs (if completed), and any error messages.
|
||||
"""
|
||||
graph_exec = await execution_db.get_graph_execution(
|
||||
user_id=token.user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Execution {execution_id} not found",
|
||||
)
|
||||
|
||||
# Build response
|
||||
outputs = None
|
||||
error = None
|
||||
|
||||
if graph_exec.status == AgentExecutionStatus.COMPLETED:
|
||||
outputs = graph_exec.outputs
|
||||
elif graph_exec.status == AgentExecutionStatus.FAILED:
|
||||
# Get error from execution stats
|
||||
# Note: Currently no standard error field in stats, but could be added
|
||||
error = "Execution failed"
|
||||
|
||||
return ExecutionStatusResponse(
|
||||
execution_id=execution_id,
|
||||
status=graph_exec.status.value,
|
||||
started_at=graph_exec.started_at,
|
||||
completed_at=graph_exec.ended_at,
|
||||
outputs=outputs,
|
||||
error=error,
|
||||
)
|
||||
|
||||
|
||||
@execution_router.post("/{execution_id}/cancel")
|
||||
async def cancel_execution(
|
||||
execution_id: str,
|
||||
token: OAuthTokenInfo = Security(require_scope("agents:execute")),
|
||||
) -> dict:
|
||||
"""
|
||||
Cancel a running agent execution.
|
||||
|
||||
Only executions in QUEUED or RUNNING status can be cancelled.
|
||||
"""
|
||||
graph_exec = await execution_db.get_graph_execution(
|
||||
user_id=token.user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Execution {execution_id} not found",
|
||||
)
|
||||
|
||||
# Check if execution can be cancelled
|
||||
if graph_exec.status not in [
|
||||
AgentExecutionStatus.QUEUED,
|
||||
AgentExecutionStatus.RUNNING,
|
||||
]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot cancel execution with status {graph_exec.status.value}",
|
||||
)
|
||||
|
||||
# Update execution status to TERMINATED
|
||||
# Note: This is a simplified implementation. A full implementation would
|
||||
# need to signal the executor to stop processing.
|
||||
await prisma.agentgraphexecution.update(
|
||||
where={"id": execution_id},
|
||||
data={"executionStatus": AgentExecutionStatus.TERMINATED},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Execution terminated: execution={execution_id}, "
|
||||
f"client={token.client_id}, user={token.user_id}"
|
||||
)
|
||||
|
||||
return {"message": "Execution terminated", "execution_id": execution_id}
|
||||
207
autogpt_platform/backend/backend/server/external/routes/grants.py
vendored
Normal file
207
autogpt_platform/backend/backend/server/external/routes/grants.py
vendored
Normal file
@@ -0,0 +1,207 @@
|
||||
"""
|
||||
Credential Grants endpoints for external OAuth clients.
|
||||
|
||||
Allows external applications to:
|
||||
- List their credential grants (metadata only, NOT credential values)
|
||||
- Get grant details
|
||||
- Delete credentials via grants (if permitted)
|
||||
|
||||
Credentials are NEVER returned to external applications.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Security
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data import credential_grants as grants_db
|
||||
from backend.data.db import prisma
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.server.external.oauth_middleware import OAuthTokenInfo, require_scope
|
||||
|
||||
grants_router = APIRouter(prefix="/grants", tags=["grants"])
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Response Models
|
||||
# ================================================================
|
||||
|
||||
|
||||
class GrantSummary(BaseModel):
|
||||
"""Summary of a credential grant (returned in list endpoints)."""
|
||||
|
||||
id: str
|
||||
provider: str
|
||||
granted_scopes: list[str]
|
||||
permissions: list[str]
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class GrantDetail(BaseModel):
|
||||
"""Detailed grant information."""
|
||||
|
||||
id: str
|
||||
provider: str
|
||||
credential_id: str
|
||||
granted_scopes: list[str]
|
||||
permissions: list[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_used_at: Optional[datetime] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
revoked_at: Optional[datetime] = None
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
|
||||
@grants_router.get("/", response_model=list[GrantSummary])
|
||||
async def list_grants(
|
||||
token: OAuthTokenInfo = Security(require_scope("integrations:list")),
|
||||
) -> list[GrantSummary]:
|
||||
"""
|
||||
List all active credential grants for this client and user.
|
||||
|
||||
Returns grant metadata but NOT credential values.
|
||||
Credentials are never exposed to external applications.
|
||||
"""
|
||||
# Get the OAuth client's database ID from the public client_id
|
||||
client = await prisma.oauthclient.find_unique(where={"clientId": token.client_id})
|
||||
if not client:
|
||||
raise HTTPException(status_code=400, detail="Invalid client")
|
||||
|
||||
grants = await grants_db.get_grants_for_user_client(
|
||||
user_id=token.user_id,
|
||||
client_id=client.id,
|
||||
include_revoked=False,
|
||||
include_expired=False,
|
||||
)
|
||||
|
||||
return [
|
||||
GrantSummary(
|
||||
id=grant.id,
|
||||
provider=grant.provider,
|
||||
granted_scopes=grant.grantedScopes,
|
||||
permissions=[p.value for p in grant.permissions],
|
||||
created_at=grant.createdAt,
|
||||
last_used_at=grant.lastUsedAt,
|
||||
expires_at=grant.expiresAt,
|
||||
)
|
||||
for grant in grants
|
||||
]
|
||||
|
||||
|
||||
@grants_router.get("/{grant_id}", response_model=GrantDetail)
|
||||
async def get_grant(
|
||||
grant_id: str,
|
||||
token: OAuthTokenInfo = Security(require_scope("integrations:list")),
|
||||
) -> GrantDetail:
|
||||
"""
|
||||
Get detailed information about a specific grant.
|
||||
|
||||
Returns grant metadata including scopes and permissions.
|
||||
Does NOT return the credential value.
|
||||
"""
|
||||
# Get the OAuth client's database ID
|
||||
client = await prisma.oauthclient.find_unique(where={"clientId": token.client_id})
|
||||
if not client:
|
||||
raise HTTPException(status_code=400, detail="Invalid client")
|
||||
|
||||
grant = await grants_db.get_credential_grant(
|
||||
grant_id=grant_id,
|
||||
user_id=token.user_id,
|
||||
client_id=client.id,
|
||||
)
|
||||
|
||||
if not grant:
|
||||
raise HTTPException(status_code=404, detail="Grant not found")
|
||||
|
||||
# Check if expired
|
||||
if grant.expiresAt and grant.expiresAt < datetime.now(timezone.utc):
|
||||
raise HTTPException(status_code=404, detail="Grant has expired")
|
||||
|
||||
# Check if revoked
|
||||
if grant.revokedAt:
|
||||
raise HTTPException(status_code=404, detail="Grant has been revoked")
|
||||
|
||||
return GrantDetail(
|
||||
id=grant.id,
|
||||
provider=grant.provider,
|
||||
credential_id=grant.credentialId,
|
||||
granted_scopes=grant.grantedScopes,
|
||||
permissions=[p.value for p in grant.permissions],
|
||||
created_at=grant.createdAt,
|
||||
updated_at=grant.updatedAt,
|
||||
last_used_at=grant.lastUsedAt,
|
||||
expires_at=grant.expiresAt,
|
||||
revoked_at=grant.revokedAt,
|
||||
)
|
||||
|
||||
|
||||
@grants_router.delete("/{grant_id}/credential")
|
||||
async def delete_credential_via_grant(
|
||||
grant_id: str,
|
||||
token: OAuthTokenInfo = Security(require_scope("integrations:delete")),
|
||||
) -> dict:
|
||||
"""
|
||||
Delete the underlying credential associated with a grant.
|
||||
|
||||
This requires the grant to have the DELETE permission.
|
||||
Deleting the credential also invalidates all grants for that credential.
|
||||
"""
|
||||
from prisma.enums import CredentialGrantPermission
|
||||
|
||||
# Get the OAuth client's database ID
|
||||
client = await prisma.oauthclient.find_unique(where={"clientId": token.client_id})
|
||||
if not client:
|
||||
raise HTTPException(status_code=400, detail="Invalid client")
|
||||
|
||||
# Get the grant
|
||||
grant = await grants_db.get_credential_grant(
|
||||
grant_id=grant_id,
|
||||
user_id=token.user_id,
|
||||
client_id=client.id,
|
||||
)
|
||||
|
||||
if not grant:
|
||||
raise HTTPException(status_code=404, detail="Grant not found")
|
||||
|
||||
# Check if grant is valid
|
||||
if grant.revokedAt:
|
||||
raise HTTPException(status_code=400, detail="Grant has been revoked")
|
||||
|
||||
if grant.expiresAt and grant.expiresAt < datetime.now(timezone.utc):
|
||||
raise HTTPException(status_code=400, detail="Grant has expired")
|
||||
|
||||
# Check DELETE permission
|
||||
if CredentialGrantPermission.DELETE not in grant.permissions:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Grant does not have DELETE permission for this credential",
|
||||
)
|
||||
|
||||
# Delete the credential using the credentials store
|
||||
try:
|
||||
creds_store = IntegrationCredentialsStore()
|
||||
await creds_store.delete_creds_by_id(
|
||||
user_id=token.user_id,
|
||||
credentials_id=grant.credentialId,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete credential: {str(e)}",
|
||||
)
|
||||
|
||||
# Revoke all grants for this credential
|
||||
await grants_db.revoke_grants_for_credential(
|
||||
user_id=token.user_id,
|
||||
credential_id=grant.credentialId,
|
||||
)
|
||||
|
||||
return {"message": "Credential deleted successfully"}
|
||||
@@ -0,0 +1,912 @@
|
||||
"""
|
||||
Integration Connect popup endpoints.
|
||||
|
||||
Implements the popup flow for external applications to connect integrations
|
||||
on behalf of users through AutoGPT's Credential Broker.
|
||||
|
||||
Flow:
|
||||
1. External app opens popup to /connect/{provider}
|
||||
2. User sees consent page with existing credentials or option to connect new
|
||||
3. User approves, grant is created
|
||||
4. Popup sends postMessage with grant_id back to opener
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import APIRouter, Form, Query, Request, Security
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from prisma.enums import CredentialGrantPermission
|
||||
|
||||
from backend.data.credential_grants import (
|
||||
create_credential_grant,
|
||||
get_grant_by_credential_and_client,
|
||||
update_grant_scopes,
|
||||
)
|
||||
from backend.data.db import prisma
|
||||
from backend.data.integration_scopes import (
|
||||
INTEGRATION_SCOPE_DESCRIPTIONS,
|
||||
get_provider_for_scope,
|
||||
get_provider_scopes,
|
||||
validate_integration_scopes,
|
||||
)
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.integrations.connect_security import (
|
||||
consume_connect_continuation,
|
||||
consume_connect_state,
|
||||
create_post_message_data,
|
||||
store_connect_continuation,
|
||||
store_connect_state,
|
||||
validate_nonce,
|
||||
validate_redirect_origin,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
connect_router = APIRouter(prefix="/connect", tags=["integration-connect"])
|
||||
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
settings = Settings()
|
||||
|
||||
|
||||
async def _create_or_update_grant(
|
||||
user_id: str,
|
||||
credential_id: str,
|
||||
client_db_id: str,
|
||||
provider: str,
|
||||
requested_scopes: list[str],
|
||||
) -> str:
|
||||
"""
|
||||
Create a new credential grant or update existing one with merged scopes.
|
||||
|
||||
Args:
|
||||
user_id: User who owns the credential
|
||||
credential_id: ID of the credential to grant access to
|
||||
client_db_id: Database UUID of the OAuth client
|
||||
provider: Integration provider name
|
||||
requested_scopes: Scopes being requested
|
||||
|
||||
Returns:
|
||||
The grant ID (either existing or newly created)
|
||||
"""
|
||||
existing_grant = await get_grant_by_credential_and_client(
|
||||
user_id=user_id,
|
||||
credential_id=credential_id,
|
||||
client_id=client_db_id,
|
||||
)
|
||||
|
||||
if existing_grant:
|
||||
# Update scopes if needed (merge with existing)
|
||||
merged_scopes = list(set(existing_grant.grantedScopes) | set(requested_scopes))
|
||||
if set(merged_scopes) != set(existing_grant.grantedScopes):
|
||||
await update_grant_scopes(existing_grant.id, merged_scopes)
|
||||
return existing_grant.id
|
||||
|
||||
# Create new grant
|
||||
grant = await create_credential_grant(
|
||||
user_id=user_id,
|
||||
client_id=client_db_id,
|
||||
credential_id=credential_id,
|
||||
provider=provider,
|
||||
granted_scopes=requested_scopes,
|
||||
permissions=[
|
||||
CredentialGrantPermission.USE,
|
||||
CredentialGrantPermission.DELETE,
|
||||
],
|
||||
)
|
||||
return grant.id
|
||||
|
||||
|
||||
def _base_styles() -> str:
|
||||
"""Common CSS styles for connect pages."""
|
||||
return """
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 20px;
|
||||
color: #e4e4e7;
|
||||
}
|
||||
.container {
|
||||
background: #27272a;
|
||||
border-radius: 16px;
|
||||
box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.5);
|
||||
max-width: 450px;
|
||||
width: 100%;
|
||||
padding: 32px;
|
||||
}
|
||||
h1 {
|
||||
font-size: 20px;
|
||||
font-weight: 600;
|
||||
margin-bottom: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
.subtitle {
|
||||
color: #a1a1aa;
|
||||
font-size: 14px;
|
||||
text-align: center;
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
.divider {
|
||||
height: 1px;
|
||||
background: #3f3f46;
|
||||
margin: 20px 0;
|
||||
}
|
||||
.section-title {
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
color: #a1a1aa;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
.credential-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
padding: 12px;
|
||||
border: 1px solid #3f3f46;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
margin-bottom: 8px;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
.credential-item:hover {
|
||||
border-color: #22d3ee;
|
||||
background: rgba(34, 211, 238, 0.1);
|
||||
}
|
||||
.credential-item.selected {
|
||||
border-color: #22d3ee;
|
||||
background: rgba(34, 211, 238, 0.15);
|
||||
}
|
||||
.credential-item input[type="radio"] {
|
||||
display: none;
|
||||
}
|
||||
.credential-info {
|
||||
flex: 1;
|
||||
}
|
||||
.credential-title {
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
}
|
||||
.credential-meta {
|
||||
font-size: 12px;
|
||||
color: #71717a;
|
||||
}
|
||||
.scope-item {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
gap: 8px;
|
||||
padding: 8px 0;
|
||||
font-size: 14px;
|
||||
}
|
||||
.scope-icon {
|
||||
color: #22d3ee;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
.buttons {
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
margin-top: 24px;
|
||||
}
|
||||
.btn {
|
||||
flex: 1;
|
||||
padding: 12px 24px;
|
||||
border-radius: 8px;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
border: none;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
.btn-cancel {
|
||||
background: #3f3f46;
|
||||
color: #e4e4e7;
|
||||
}
|
||||
.btn-cancel:hover {
|
||||
background: #52525b;
|
||||
}
|
||||
.btn-primary {
|
||||
background: #22d3ee;
|
||||
color: #0f172a;
|
||||
}
|
||||
.btn-primary:hover {
|
||||
background: #06b6d4;
|
||||
}
|
||||
.btn-connect {
|
||||
background: #3b82f6;
|
||||
color: white;
|
||||
}
|
||||
.btn-connect:hover {
|
||||
background: #2563eb;
|
||||
}
|
||||
.error-message {
|
||||
background: rgba(239, 68, 68, 0.1);
|
||||
border: 1px solid rgba(239, 68, 68, 0.3);
|
||||
border-radius: 8px;
|
||||
padding: 12px;
|
||||
color: #ef4444;
|
||||
font-size: 14px;
|
||||
text-align: center;
|
||||
}
|
||||
.app-name {
|
||||
color: #22d3ee;
|
||||
font-weight: 600;
|
||||
}
|
||||
.provider-badge {
|
||||
display: inline-block;
|
||||
background: #3f3f46;
|
||||
padding: 2px 8px;
|
||||
border-radius: 4px;
|
||||
font-size: 12px;
|
||||
text-transform: capitalize;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def _render_connect_page(
|
||||
client_name: str,
|
||||
provider: str,
|
||||
scopes: list[str],
|
||||
credentials: list[OAuth2Credentials],
|
||||
connect_token: str,
|
||||
action_url: str,
|
||||
) -> str:
|
||||
"""Render the connect consent page."""
|
||||
# Build scopes HTML
|
||||
scopes_html = ""
|
||||
for scope in scopes:
|
||||
description = INTEGRATION_SCOPE_DESCRIPTIONS.get(scope, scope)
|
||||
scopes_html += f"""
|
||||
<div class="scope-item">
|
||||
<span class="scope-icon">✓</span>
|
||||
<span>{description}</span>
|
||||
</div>
|
||||
"""
|
||||
|
||||
# Build credentials selection HTML
|
||||
creds_html = ""
|
||||
if credentials:
|
||||
creds_html = '<div class="section-title">Select an existing credential:</div>'
|
||||
for i, cred in enumerate(credentials):
|
||||
checked = "checked" if i == 0 else ""
|
||||
selected = "selected" if i == 0 else ""
|
||||
creds_html += f"""
|
||||
<label class="credential-item {selected}" onclick="selectCredential(this)">
|
||||
<input type="radio" name="credential_id" value="{cred.id}" {checked}>
|
||||
<div class="credential-info">
|
||||
<div class="credential-title">{cred.title or cred.username or 'Credential'}</div>
|
||||
<div class="credential-meta">{cred.username or ''}</div>
|
||||
</div>
|
||||
</label>
|
||||
"""
|
||||
creds_html += '<div class="divider"></div>'
|
||||
|
||||
return f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Connect {provider.title()} - AutoGPT</title>
|
||||
<style>{_base_styles()}</style>
|
||||
<script>
|
||||
function selectCredential(element) {{
|
||||
document.querySelectorAll('.credential-item').forEach(el => el.classList.remove('selected'));
|
||||
element.classList.add('selected');
|
||||
element.querySelector('input[type="radio"]').checked = true;
|
||||
}}
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>Connect <span class="provider-badge">{provider}</span></h1>
|
||||
<p class="subtitle">
|
||||
<span class="app-name">{client_name}</span> wants to use your {provider.title()} integration
|
||||
</p>
|
||||
|
||||
<div class="divider"></div>
|
||||
|
||||
<div class="section-title">This will allow {client_name} to:</div>
|
||||
{scopes_html}
|
||||
|
||||
<div class="divider"></div>
|
||||
|
||||
<form method="POST" action="{action_url}">
|
||||
<input type="hidden" name="connect_token" value="{connect_token}">
|
||||
|
||||
{creds_html}
|
||||
|
||||
{'''<div class="section-title">Or connect a new account:</div>
|
||||
<button type="submit" name="action" value="connect_new" class="btn btn-connect" style="width: 100%; margin-bottom: 16px;">
|
||||
Connect New {0} Account
|
||||
</button>'''.format(provider.title()) if credentials else '''
|
||||
<p class="subtitle" style="margin-bottom: 16px;">
|
||||
You don't have any {0} credentials yet.
|
||||
</p>
|
||||
<button type="submit" name="action" value="connect_new" class="btn btn-connect" style="width: 100%; margin-bottom: 16px;">
|
||||
Connect {0} Account
|
||||
</button>
|
||||
'''.format(provider.title())}
|
||||
|
||||
<div class="buttons">
|
||||
<button type="submit" name="action" value="deny" class="btn btn-cancel">
|
||||
Cancel
|
||||
</button>
|
||||
{'''<button type="submit" name="action" value="approve" class="btn btn-primary">
|
||||
Approve
|
||||
</button>''' if credentials else ''}
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def _render_error_page(error: str, error_description: str) -> str:
|
||||
"""Render an error page."""
|
||||
return f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Connection Error - AutoGPT</title>
|
||||
<style>{_base_styles()}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1 style="color: #ef4444;">Connection Failed</h1>
|
||||
<div class="error-message" style="margin-top: 24px;">
|
||||
{error_description}
|
||||
</div>
|
||||
<p class="subtitle" style="margin-top: 16px;">
|
||||
Error code: {error}
|
||||
</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def _render_result_page(
|
||||
success: bool,
|
||||
redirect_origin: str,
|
||||
post_message_data: dict,
|
||||
) -> str:
|
||||
"""Render a result page that sends postMessage to opener."""
|
||||
import json
|
||||
|
||||
status_class = "color: #22c55e;" if success else "color: #ef4444;"
|
||||
status_text = "Connected Successfully!" if success else "Connection Failed"
|
||||
message = (
|
||||
"You can close this window."
|
||||
if success
|
||||
else post_message_data.get("error_description", "An error occurred")
|
||||
)
|
||||
|
||||
return f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>{'Connected' if success else 'Error'} - AutoGPT</title>
|
||||
<style>{_base_styles()}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1 style="{status_class}">{status_text}</h1>
|
||||
<p class="subtitle" style="margin-top: 16px;">
|
||||
{message}
|
||||
</p>
|
||||
<p class="subtitle" style="margin-top: 8px; font-size: 12px;">
|
||||
This window will close automatically...
|
||||
</p>
|
||||
</div>
|
||||
<script>
|
||||
(function() {{
|
||||
var targetOrigin = {json.dumps(redirect_origin)};
|
||||
var message = {json.dumps(post_message_data)};
|
||||
if (window.opener) {{
|
||||
window.opener.postMessage(message, targetOrigin);
|
||||
setTimeout(function() {{ window.close(); }}, 1500);
|
||||
}}
|
||||
}})();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
@connect_router.get("/{provider}", response_model=None)
|
||||
async def connect_page(
|
||||
provider: ProviderName,
|
||||
client_id: Annotated[str, Query(description="OAuth client ID")],
|
||||
scopes: Annotated[str, Query(description="Comma-separated integration scopes")],
|
||||
nonce: Annotated[str, Query(description="Nonce for replay protection")],
|
||||
redirect_origin: Annotated[str, Query(description="Origin for postMessage")],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> HTMLResponse:
|
||||
"""
|
||||
Render the connect consent page.
|
||||
|
||||
This page allows users to select an existing credential or connect a new one
|
||||
for use by an external application.
|
||||
"""
|
||||
# Validate client
|
||||
client = await prisma.oauthclient.find_unique(where={"clientId": client_id})
|
||||
if not client:
|
||||
return HTMLResponse(
|
||||
_render_error_page("invalid_client", "Unknown application"),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
if client.status.value != "ACTIVE":
|
||||
return HTMLResponse(
|
||||
_render_error_page("invalid_client", "Application is not active"),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Validate redirect origin
|
||||
if not validate_redirect_origin(redirect_origin, client):
|
||||
return HTMLResponse(
|
||||
_render_error_page(
|
||||
"invalid_request", "Invalid redirect origin for this application"
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Validate nonce
|
||||
if not await validate_nonce(client_id, nonce):
|
||||
return HTMLResponse(
|
||||
_render_error_page("invalid_request", "Nonce has already been used"),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Parse and validate scopes
|
||||
requested_scopes = [s.strip() for s in scopes.split(",") if s.strip()]
|
||||
valid, invalid = validate_integration_scopes(requested_scopes)
|
||||
|
||||
if not valid:
|
||||
return HTMLResponse(
|
||||
_render_error_page(
|
||||
"invalid_scope", f"Invalid scopes requested: {', '.join(invalid)}"
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Verify all scopes are for the requested provider
|
||||
for scope in requested_scopes:
|
||||
scope_provider = get_provider_for_scope(scope)
|
||||
if scope_provider != provider:
|
||||
return HTMLResponse(
|
||||
_render_error_page(
|
||||
"invalid_scope",
|
||||
f"Scope '{scope}' is not for provider '{provider.value}'",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Get user's existing credentials for this provider
|
||||
user_credentials = await creds_manager.store.get_creds_by_provider(
|
||||
user_id, provider
|
||||
)
|
||||
oauth_credentials = [
|
||||
c for c in user_credentials if isinstance(c, OAuth2Credentials)
|
||||
]
|
||||
|
||||
# Store connect state
|
||||
connect_token = await store_connect_state(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
provider=provider.value,
|
||||
requested_scopes=requested_scopes,
|
||||
redirect_origin=redirect_origin,
|
||||
nonce=nonce,
|
||||
)
|
||||
|
||||
return HTMLResponse(
|
||||
_render_connect_page(
|
||||
client_name=client.name,
|
||||
provider=provider.value,
|
||||
scopes=requested_scopes,
|
||||
credentials=oauth_credentials,
|
||||
connect_token=connect_token,
|
||||
action_url=f"/connect/{provider.value}/approve",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@connect_router.post("/{provider}/approve", response_model=None)
|
||||
async def approve_connect(
|
||||
provider: ProviderName,
|
||||
request: Request,
|
||||
connect_token: Annotated[str, Form()],
|
||||
action: Annotated[str, Form()],
|
||||
credential_id: Annotated[Optional[str], Form()] = None,
|
||||
user_id: Annotated[str, Security(get_user_id)] = "",
|
||||
) -> HTMLResponse | RedirectResponse:
|
||||
"""
|
||||
Process the connect form submission.
|
||||
|
||||
Creates a credential grant and returns a page that sends postMessage to opener.
|
||||
"""
|
||||
# Consume state (one-time use)
|
||||
state = await consume_connect_state(connect_token)
|
||||
if not state:
|
||||
return HTMLResponse(
|
||||
_render_error_page("invalid_request", "Invalid or expired connect session"),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Verify user
|
||||
if state.user_id != user_id:
|
||||
return HTMLResponse(
|
||||
_render_error_page("access_denied", "User mismatch"),
|
||||
status_code=403,
|
||||
)
|
||||
|
||||
redirect_origin = state.redirect_origin
|
||||
nonce = state.nonce
|
||||
requested_scopes = state.requested_scopes
|
||||
client_id = state.client_id
|
||||
|
||||
# Handle denial
|
||||
if action == "deny":
|
||||
post_data = create_post_message_data(
|
||||
success=False,
|
||||
error="access_denied",
|
||||
error_description="User denied the connection request",
|
||||
nonce=nonce,
|
||||
)
|
||||
return HTMLResponse(_render_result_page(False, redirect_origin, post_data))
|
||||
|
||||
# Handle connect new - redirect to OAuth login
|
||||
if action == "connect_new":
|
||||
# Get client database ID for continuation state
|
||||
client = await prisma.oauthclient.find_unique(where={"clientId": client_id})
|
||||
if not client:
|
||||
post_data = create_post_message_data(
|
||||
success=False,
|
||||
error="invalid_client",
|
||||
error_description="Client not found",
|
||||
nonce=nonce,
|
||||
)
|
||||
return HTMLResponse(_render_result_page(False, redirect_origin, post_data))
|
||||
|
||||
# Get OAuth handler for this provider
|
||||
handler = _get_provider_oauth_handler(request, provider)
|
||||
if not handler:
|
||||
post_data = create_post_message_data(
|
||||
success=False,
|
||||
error="unsupported_provider",
|
||||
error_description=f"Provider '{provider.value}' does not support OAuth",
|
||||
nonce=nonce,
|
||||
)
|
||||
return HTMLResponse(_render_result_page(False, redirect_origin, post_data))
|
||||
|
||||
# Store continuation state for after OAuth completes
|
||||
continuation_token = await store_connect_continuation(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
client_db_id=client.id,
|
||||
provider=provider.value,
|
||||
requested_scopes=requested_scopes,
|
||||
redirect_origin=redirect_origin,
|
||||
nonce=nonce,
|
||||
)
|
||||
|
||||
# Convert integration scopes to provider OAuth scopes
|
||||
provider_scopes = get_provider_scopes(provider, requested_scopes)
|
||||
|
||||
# Store OAuth state with continuation token in metadata
|
||||
state_token, code_challenge = await creds_manager.store.store_state_token(
|
||||
user_id=user_id,
|
||||
provider=provider.value,
|
||||
scopes=provider_scopes,
|
||||
use_pkce=True,
|
||||
state_metadata={"connect_continuation": continuation_token},
|
||||
)
|
||||
|
||||
# Build OAuth URL and redirect
|
||||
login_url = handler.get_login_url(
|
||||
provider_scopes, state_token, code_challenge=code_challenge
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Redirecting to OAuth for connect_new: provider={provider.value}, "
|
||||
f"user={user_id}, client={client_id}"
|
||||
)
|
||||
|
||||
return RedirectResponse(url=login_url, status_code=302)
|
||||
|
||||
# Handle approval with existing credential
|
||||
if action == "approve":
|
||||
if not credential_id:
|
||||
post_data = create_post_message_data(
|
||||
success=False,
|
||||
error="invalid_request",
|
||||
error_description="No credential selected",
|
||||
nonce=nonce,
|
||||
)
|
||||
return HTMLResponse(_render_result_page(False, redirect_origin, post_data))
|
||||
|
||||
# Verify credential belongs to user and provider
|
||||
credential = await creds_manager.get(user_id, credential_id)
|
||||
if not credential:
|
||||
post_data = create_post_message_data(
|
||||
success=False,
|
||||
error="invalid_request",
|
||||
error_description="Credential not found",
|
||||
nonce=nonce,
|
||||
)
|
||||
return HTMLResponse(_render_result_page(False, redirect_origin, post_data))
|
||||
|
||||
if credential.provider != provider.value:
|
||||
post_data = create_post_message_data(
|
||||
success=False,
|
||||
error="invalid_request",
|
||||
error_description="Credential provider mismatch",
|
||||
nonce=nonce,
|
||||
)
|
||||
return HTMLResponse(_render_result_page(False, redirect_origin, post_data))
|
||||
|
||||
# Get client database ID
|
||||
client = await prisma.oauthclient.find_unique(where={"clientId": client_id})
|
||||
if not client:
|
||||
post_data = create_post_message_data(
|
||||
success=False,
|
||||
error="invalid_client",
|
||||
error_description="Client not found",
|
||||
nonce=nonce,
|
||||
)
|
||||
return HTMLResponse(_render_result_page(False, redirect_origin, post_data))
|
||||
|
||||
# Create or update grant
|
||||
grant_id = await _create_or_update_grant(
|
||||
user_id=user_id,
|
||||
credential_id=credential_id,
|
||||
client_db_id=client.id,
|
||||
provider=provider.value,
|
||||
requested_scopes=requested_scopes,
|
||||
)
|
||||
|
||||
post_data = create_post_message_data(
|
||||
success=True,
|
||||
grant_id=grant_id,
|
||||
credential_id=credential_id,
|
||||
provider=provider.value,
|
||||
nonce=nonce,
|
||||
)
|
||||
return HTMLResponse(_render_result_page(True, redirect_origin, post_data))
|
||||
|
||||
# Unknown action
|
||||
post_data = create_post_message_data(
|
||||
success=False,
|
||||
error="invalid_request",
|
||||
error_description="Unknown action",
|
||||
nonce=nonce,
|
||||
)
|
||||
return HTMLResponse(_render_result_page(False, redirect_origin, post_data))
|
||||
|
||||
|
||||
@connect_router.get("/{provider}/callback", response_model=None)
|
||||
async def connect_oauth_callback(
|
||||
provider: ProviderName,
|
||||
request: Request,
|
||||
code: str,
|
||||
state: str,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> HTMLResponse:
|
||||
"""
|
||||
Handle OAuth callback after user authorizes a new connection.
|
||||
|
||||
This endpoint is called after the OAuth provider redirects back with an
|
||||
authorization code. It exchanges the code for tokens, creates the credential,
|
||||
creates the grant, and returns a page that sends postMessage to the opener.
|
||||
"""
|
||||
# Get OAuth handler
|
||||
handler = _get_provider_oauth_handler(request, provider)
|
||||
if not handler:
|
||||
return HTMLResponse(
|
||||
_render_error_page(
|
||||
"unsupported_provider",
|
||||
f"Provider '{provider.value}' does not support OAuth",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Verify the state token and get the associated state
|
||||
valid_state = await creds_manager.store.verify_state_token(
|
||||
user_id, state, provider.value
|
||||
)
|
||||
|
||||
if not valid_state:
|
||||
return HTMLResponse(
|
||||
_render_error_page(
|
||||
"invalid_state",
|
||||
"Invalid or expired OAuth state. Please try again.",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Check for continuation token in state metadata
|
||||
continuation_token = valid_state.state_metadata.get("connect_continuation")
|
||||
if not continuation_token:
|
||||
return HTMLResponse(
|
||||
_render_error_page(
|
||||
"invalid_request",
|
||||
"Missing continuation token. Please try again.",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Get continuation state
|
||||
continuation = await consume_connect_continuation(continuation_token)
|
||||
if not continuation:
|
||||
return HTMLResponse(
|
||||
_render_error_page(
|
||||
"invalid_request",
|
||||
"Invalid or expired continuation state. Please try again.",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Verify user matches
|
||||
if continuation.user_id != user_id:
|
||||
return HTMLResponse(
|
||||
_render_error_page("access_denied", "User mismatch"),
|
||||
status_code=403,
|
||||
)
|
||||
|
||||
redirect_origin = continuation.redirect_origin
|
||||
nonce = continuation.nonce
|
||||
requested_scopes = continuation.requested_scopes
|
||||
client_db_id = continuation.client_db_id
|
||||
|
||||
try:
|
||||
# Handle default scopes
|
||||
scopes = handler.handle_default_scopes(valid_state.scopes)
|
||||
|
||||
# Exchange code for tokens
|
||||
credentials = await handler.exchange_code_for_tokens(
|
||||
code, scopes, valid_state.code_verifier
|
||||
)
|
||||
|
||||
# Linear returns scopes as a single string with spaces
|
||||
if len(credentials.scopes) == 1 and " " in credentials.scopes[0]:
|
||||
credentials.scopes = credentials.scopes[0].split(" ")
|
||||
|
||||
# Store the new credentials
|
||||
await creds_manager.create(user_id, credentials)
|
||||
|
||||
logger.info(
|
||||
f"Created new credential via connect flow: provider={provider.value}, "
|
||||
f"user={user_id}, credential={credentials.id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth token exchange failed: {e}")
|
||||
post_data = create_post_message_data(
|
||||
success=False,
|
||||
error="oauth_error",
|
||||
error_description=f"Failed to complete OAuth: {str(e)}",
|
||||
nonce=nonce,
|
||||
)
|
||||
return HTMLResponse(_render_result_page(False, redirect_origin, post_data))
|
||||
|
||||
# Create the credential grant
|
||||
try:
|
||||
grant_id = await _create_or_update_grant(
|
||||
user_id=user_id,
|
||||
credential_id=credentials.id,
|
||||
client_db_id=client_db_id,
|
||||
provider=provider.value,
|
||||
requested_scopes=requested_scopes,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created grant via connect flow: grant={grant_id}, "
|
||||
f"credential={credentials.id}, client={client_db_id}"
|
||||
)
|
||||
|
||||
post_data = create_post_message_data(
|
||||
success=True,
|
||||
grant_id=grant_id,
|
||||
credential_id=credentials.id,
|
||||
provider=provider.value,
|
||||
nonce=nonce,
|
||||
)
|
||||
return HTMLResponse(_render_result_page(True, redirect_origin, post_data))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create grant: {e}")
|
||||
post_data = create_post_message_data(
|
||||
success=False,
|
||||
error="grant_error",
|
||||
error_description=f"Failed to create grant: {str(e)}",
|
||||
nonce=nonce,
|
||||
)
|
||||
return HTMLResponse(_render_result_page(False, redirect_origin, post_data))
|
||||
|
||||
|
||||
def _get_provider_oauth_handler(request: Request, provider_name: ProviderName):
|
||||
"""Get the OAuth handler for a provider."""
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER
|
||||
|
||||
# Ensure blocks are loaded so SDK providers are available
|
||||
try:
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
load_all_blocks() # This is cached, so it only runs once
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load blocks: {e}")
|
||||
|
||||
# Convert provider_name to string for lookup
|
||||
provider_key = (
|
||||
provider_name.value if hasattr(provider_name, "value") else str(provider_name)
|
||||
)
|
||||
|
||||
if provider_key not in HANDLERS_BY_NAME:
|
||||
return None
|
||||
|
||||
# Check if this provider has custom OAuth credentials
|
||||
oauth_credentials = CREDENTIALS_BY_PROVIDER.get(provider_key)
|
||||
|
||||
if oauth_credentials and not oauth_credentials.use_secrets:
|
||||
# SDK provider with custom env vars
|
||||
import os
|
||||
|
||||
client_id = (
|
||||
os.getenv(oauth_credentials.client_id_env_var)
|
||||
if oauth_credentials.client_id_env_var
|
||||
else None
|
||||
)
|
||||
client_secret = (
|
||||
os.getenv(oauth_credentials.client_secret_env_var)
|
||||
if oauth_credentials.client_secret_env_var
|
||||
else None
|
||||
)
|
||||
else:
|
||||
# Original provider using settings.secrets
|
||||
client_id = getattr(settings.secrets, f"{provider_name.value}_client_id", None)
|
||||
client_secret = getattr(
|
||||
settings.secrets, f"{provider_name.value}_client_secret", None
|
||||
)
|
||||
|
||||
if not (client_id and client_secret):
|
||||
logger.warning(
|
||||
f"OAuth integration not configured for provider {provider_name.value}"
|
||||
)
|
||||
return None
|
||||
|
||||
handler_class = HANDLERS_BY_NAME[provider_key]
|
||||
frontend_base_url = settings.config.frontend_base_url
|
||||
|
||||
if not frontend_base_url:
|
||||
logger.error("Frontend base URL is not configured")
|
||||
return None
|
||||
|
||||
return handler_class(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
|
||||
)
|
||||
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
Security utilities for the integration connect popup flow.
|
||||
|
||||
Handles state management, nonce validation, and origin verification
|
||||
for the OAuth-style popup flow when connecting integrations.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from prisma.models import OAuthClient
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# State expiration time
|
||||
STATE_EXPIRATION_SECONDS = 600 # 10 minutes
|
||||
NONCE_EXPIRATION_SECONDS = 3600 # 1 hour (nonces valid for longer to prevent races)
|
||||
|
||||
|
||||
class ConnectState(BaseModel):
|
||||
"""Pydantic model for connect state stored in Redis."""
|
||||
|
||||
user_id: str
|
||||
client_id: str
|
||||
provider: str
|
||||
requested_scopes: list[str]
|
||||
redirect_origin: str
|
||||
nonce: str
|
||||
credential_id: Optional[str] = None
|
||||
created_at: str
|
||||
expires_at: str
|
||||
|
||||
|
||||
class ConnectContinuationState(BaseModel):
|
||||
"""
|
||||
State for continuing the connect flow after OAuth completes.
|
||||
|
||||
When a user chooses to "connect new" during the connect flow,
|
||||
we store this state so we can complete the grant creation after
|
||||
the OAuth callback.
|
||||
"""
|
||||
|
||||
user_id: str
|
||||
client_id: str # Public client ID
|
||||
client_db_id: str # Database UUID of the OAuth client
|
||||
provider: str
|
||||
requested_scopes: list[str] # Integration scopes (e.g., "google:gmail.readonly")
|
||||
redirect_origin: str
|
||||
nonce: str
|
||||
created_at: str
|
||||
|
||||
|
||||
# Continuation state expiration (same as regular state)
|
||||
CONTINUATION_EXPIRATION_SECONDS = 600 # 10 minutes
|
||||
|
||||
|
||||
async def store_connect_continuation(
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
client_db_id: str,
|
||||
provider: str,
|
||||
requested_scopes: list[str],
|
||||
redirect_origin: str,
|
||||
nonce: str,
|
||||
) -> str:
|
||||
"""
|
||||
Store continuation state for completing connect flow after OAuth.
|
||||
|
||||
Args:
|
||||
user_id: User initiating the connection
|
||||
client_id: Public OAuth client ID
|
||||
client_db_id: Database UUID of the OAuth client
|
||||
provider: Integration provider name
|
||||
requested_scopes: Requested integration scopes
|
||||
redirect_origin: Origin to send postMessage to
|
||||
nonce: Client-provided nonce for replay protection
|
||||
|
||||
Returns:
|
||||
Continuation token to be stored in OAuth state metadata
|
||||
"""
|
||||
token = generate_connect_token()
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
state = ConnectContinuationState(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
client_db_id=client_db_id,
|
||||
provider=provider,
|
||||
requested_scopes=requested_scopes,
|
||||
redirect_origin=redirect_origin,
|
||||
nonce=nonce,
|
||||
created_at=now.isoformat(),
|
||||
)
|
||||
|
||||
redis = await get_redis_async()
|
||||
key = f"connect_continuation:{token}"
|
||||
await redis.setex(key, CONTINUATION_EXPIRATION_SECONDS, state.model_dump_json())
|
||||
|
||||
logger.debug(f"Stored connect continuation state for token {token[:8]}...")
|
||||
return token
|
||||
|
||||
|
||||
async def get_connect_continuation(token: str) -> Optional[ConnectContinuationState]:
|
||||
"""
|
||||
Get continuation state without consuming it.
|
||||
|
||||
Args:
|
||||
token: Continuation token
|
||||
|
||||
Returns:
|
||||
ConnectContinuationState or None if not found/expired
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = f"connect_continuation:{token}"
|
||||
data = await redis.get(key)
|
||||
|
||||
if not data:
|
||||
return None
|
||||
|
||||
return ConnectContinuationState.model_validate_json(data)
|
||||
|
||||
|
||||
async def consume_connect_continuation(
|
||||
token: str,
|
||||
) -> Optional[ConnectContinuationState]:
|
||||
"""
|
||||
Get and consume (delete) continuation state.
|
||||
|
||||
This ensures the token can only be used once.
|
||||
|
||||
Args:
|
||||
token: Continuation token
|
||||
|
||||
Returns:
|
||||
ConnectContinuationState or None if not found/expired
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = f"connect_continuation:{token}"
|
||||
|
||||
# Atomic get-and-delete to prevent race conditions
|
||||
data = await redis.getdel(key)
|
||||
if not data:
|
||||
return None
|
||||
|
||||
state = ConnectContinuationState.model_validate_json(data)
|
||||
logger.debug(f"Consumed connect continuation state for token {token[:8]}...")
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def generate_connect_token() -> str:
|
||||
"""Generate a secure random token for connect state."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
async def store_connect_state(
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
provider: str,
|
||||
requested_scopes: list[str],
|
||||
redirect_origin: str,
|
||||
nonce: str,
|
||||
credential_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Store connect state in Redis and return a state token.
|
||||
|
||||
Args:
|
||||
user_id: User initiating the connection
|
||||
client_id: OAuth client ID (public identifier)
|
||||
provider: Integration provider name
|
||||
requested_scopes: Requested integration scopes
|
||||
redirect_origin: Origin to send postMessage to
|
||||
nonce: Client-provided nonce for replay protection
|
||||
credential_id: Optional existing credential to grant access to
|
||||
|
||||
Returns:
|
||||
State token to be used in the connect flow
|
||||
"""
|
||||
token = generate_connect_token()
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now.timestamp() + STATE_EXPIRATION_SECONDS
|
||||
|
||||
state = ConnectState(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
provider=provider,
|
||||
requested_scopes=requested_scopes,
|
||||
redirect_origin=redirect_origin,
|
||||
nonce=nonce,
|
||||
credential_id=credential_id,
|
||||
created_at=now.isoformat(),
|
||||
expires_at=datetime.fromtimestamp(expires_at, tz=timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
redis = await get_redis_async()
|
||||
key = f"connect_state:{token}"
|
||||
await redis.setex(key, STATE_EXPIRATION_SECONDS, state.model_dump_json())
|
||||
|
||||
logger.debug(f"Stored connect state for token {token[:8]}...")
|
||||
return token
|
||||
|
||||
|
||||
async def get_connect_state(token: str) -> Optional[ConnectState]:
|
||||
"""
|
||||
Get connect state without consuming it.
|
||||
|
||||
Args:
|
||||
token: State token
|
||||
|
||||
Returns:
|
||||
ConnectState or None if not found/expired
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = f"connect_state:{token}"
|
||||
data = await redis.get(key)
|
||||
|
||||
if not data:
|
||||
return None
|
||||
|
||||
return ConnectState.model_validate_json(data)
|
||||
|
||||
|
||||
async def consume_connect_state(token: str) -> Optional[ConnectState]:
|
||||
"""
|
||||
Get and consume (delete) connect state.
|
||||
|
||||
This ensures the token can only be used once.
|
||||
|
||||
Args:
|
||||
token: State token
|
||||
|
||||
Returns:
|
||||
ConnectState or None if not found/expired
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = f"connect_state:{token}"
|
||||
|
||||
# Atomic get-and-delete to prevent race conditions
|
||||
data = await redis.getdel(key)
|
||||
if not data:
|
||||
return None
|
||||
|
||||
state = ConnectState.model_validate_json(data)
|
||||
logger.debug(f"Consumed connect state for token {token[:8]}...")
|
||||
|
||||
return state
|
||||
|
||||
|
||||
async def validate_nonce(client_id: str, nonce: str) -> bool:
|
||||
"""
|
||||
Validate that a nonce hasn't been used before (replay protection).
|
||||
|
||||
Uses atomic SET NX EX for check-and-set with automatic TTL expiry.
|
||||
|
||||
Args:
|
||||
client_id: OAuth client ID
|
||||
nonce: Client-provided nonce
|
||||
|
||||
Returns:
|
||||
True if nonce is valid (not replayed)
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
|
||||
# Create a hash of the nonce for storage
|
||||
nonce_hash = hashlib.sha256(nonce.encode()).hexdigest()
|
||||
key = f"nonce:{client_id}:{nonce_hash}"
|
||||
|
||||
# Atomic set-if-not-exists with expiration (prevents race condition)
|
||||
was_set = await redis.set(key, "1", nx=True, ex=NONCE_EXPIRATION_SECONDS)
|
||||
if was_set:
|
||||
return True
|
||||
|
||||
logger.warning(f"Nonce replay detected for client {client_id}")
|
||||
return False
|
||||
|
||||
|
||||
def validate_redirect_origin(origin: str, client: OAuthClient) -> bool:
|
||||
"""
|
||||
Validate that a redirect origin is allowed for the client.
|
||||
|
||||
The origin must match one of the client's registered redirect URIs
|
||||
or webhook domains.
|
||||
|
||||
Args:
|
||||
origin: Origin URL to validate
|
||||
client: OAuth client to check against
|
||||
|
||||
Returns:
|
||||
True if origin is allowed
|
||||
"""
|
||||
from backend.util.url import hostname_matches_any_domain
|
||||
|
||||
try:
|
||||
parsed_origin = urlparse(origin)
|
||||
origin_host = parsed_origin.netloc.lower()
|
||||
|
||||
# Check against redirect URIs
|
||||
for redirect_uri in client.redirectUris:
|
||||
parsed_redirect = urlparse(redirect_uri)
|
||||
if parsed_redirect.netloc.lower() == origin_host:
|
||||
return True
|
||||
|
||||
# Check against webhook domains
|
||||
if hostname_matches_any_domain(origin_host, client.webhookDomains):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def create_post_message_data(
|
||||
success: bool,
|
||||
grant_id: Optional[str] = None,
|
||||
credential_id: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
error_description: Optional[str] = None,
|
||||
nonce: Optional[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create the postMessage data to send back to the opener.
|
||||
|
||||
Args:
|
||||
success: Whether the operation succeeded
|
||||
grant_id: ID of the created grant (if successful)
|
||||
credential_id: ID of the credential (if successful)
|
||||
provider: Provider name
|
||||
error: Error code (if failed)
|
||||
error_description: Human-readable error description
|
||||
nonce: Original nonce for correlation
|
||||
|
||||
Returns:
|
||||
Dictionary to be sent via postMessage
|
||||
"""
|
||||
data: dict[str, Any] = {
|
||||
"type": "autogpt_connect_result",
|
||||
"success": success,
|
||||
}
|
||||
|
||||
if nonce:
|
||||
data["nonce"] = nonce
|
||||
|
||||
if success:
|
||||
data["grant_id"] = grant_id
|
||||
data["credential_id"] = credential_id
|
||||
data["provider"] = provider
|
||||
else:
|
||||
data["error"] = error
|
||||
data["error_description"] = error_description
|
||||
|
||||
return data
|
||||
20
autogpt_platform/backend/backend/server/oauth/__init__.py
Normal file
20
autogpt_platform/backend/backend/server/oauth/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
OAuth 2.0 Provider module for AutoGPT Platform.
|
||||
|
||||
This module implements AutoGPT as an OAuth 2.0 Authorization Server,
|
||||
allowing external applications to authenticate users and access
|
||||
platform resources with user consent.
|
||||
|
||||
Key components:
|
||||
- router.py: OAuth authorization and token endpoints
|
||||
- discovery_router.py: OIDC discovery endpoints
|
||||
- client_router.py: OAuth client management
|
||||
- token_service.py: JWT generation and validation
|
||||
- service.py: Core OAuth business logic
|
||||
"""
|
||||
|
||||
from backend.server.oauth.client_router import client_router
|
||||
from backend.server.oauth.discovery_router import discovery_router
|
||||
from backend.server.oauth.router import oauth_router
|
||||
|
||||
__all__ = ["oauth_router", "discovery_router", "client_router"]
|
||||
311
autogpt_platform/backend/backend/server/oauth/client_router.py
Normal file
311
autogpt_platform/backend/backend/server/oauth/client_router.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
OAuth Client Management endpoints.
|
||||
|
||||
Implements self-service client registration and management:
|
||||
- POST /oauth/clients - Register a new client
|
||||
- GET /oauth/clients - List owned clients
|
||||
- GET /oauth/clients/{client_id} - Get client details
|
||||
- PATCH /oauth/clients/{client_id} - Update client
|
||||
- DELETE /oauth/clients/{client_id} - Delete client
|
||||
- POST /oauth/clients/{client_id}/rotate-secret - Rotate client secret
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import APIRouter, HTTPException, Security
|
||||
from prisma.enums import OAuthClientStatus
|
||||
|
||||
from backend.data.db import prisma
|
||||
from backend.server.oauth.models import (
|
||||
ClientResponse,
|
||||
ClientSecretResponse,
|
||||
OAuthScope,
|
||||
RegisterClientRequest,
|
||||
UpdateClientRequest,
|
||||
)
|
||||
|
||||
client_router = APIRouter(prefix="/oauth/clients", tags=["oauth-clients"])
|
||||
|
||||
|
||||
def _generate_client_id() -> str:
|
||||
"""Generate a unique client ID."""
|
||||
return f"app_{secrets.token_urlsafe(16)}"
|
||||
|
||||
|
||||
def _generate_client_secret() -> str:
|
||||
"""Generate a secure client secret."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def _hash_secret(secret: str, salt: str) -> str:
|
||||
"""Hash a client secret with salt."""
|
||||
return hashlib.sha256(f"{salt}{secret}".encode()).hexdigest()
|
||||
|
||||
|
||||
def _client_to_response(client) -> ClientResponse:
|
||||
"""Convert Prisma client to response model."""
|
||||
return ClientResponse(
|
||||
id=client.id,
|
||||
client_id=client.clientId,
|
||||
client_type=client.clientType,
|
||||
name=client.name,
|
||||
description=client.description,
|
||||
logo_url=client.logoUrl,
|
||||
homepage_url=client.homepageUrl,
|
||||
privacy_policy_url=client.privacyPolicyUrl,
|
||||
terms_of_service_url=client.termsOfServiceUrl,
|
||||
redirect_uris=client.redirectUris,
|
||||
allowed_scopes=client.allowedScopes,
|
||||
webhook_domains=client.webhookDomains,
|
||||
status=client.status.value,
|
||||
created_at=client.createdAt,
|
||||
updated_at=client.updatedAt,
|
||||
)
|
||||
|
||||
|
||||
# Default allowed scopes for new clients
|
||||
DEFAULT_ALLOWED_SCOPES = [
|
||||
OAuthScope.OPENID.value,
|
||||
OAuthScope.PROFILE.value,
|
||||
OAuthScope.EMAIL.value,
|
||||
OAuthScope.INTEGRATIONS_LIST.value,
|
||||
OAuthScope.INTEGRATIONS_CONNECT.value,
|
||||
OAuthScope.INTEGRATIONS_DELETE.value,
|
||||
OAuthScope.AGENTS_EXECUTE.value,
|
||||
]
|
||||
|
||||
|
||||
@client_router.post("/", response_model=ClientSecretResponse)
|
||||
async def register_client(
|
||||
request: RegisterClientRequest,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> ClientSecretResponse:
|
||||
"""
|
||||
Register a new OAuth client.
|
||||
|
||||
The client is immediately active (no admin approval required).
|
||||
For confidential clients, the client_secret is returned only once.
|
||||
"""
|
||||
# Generate client credentials
|
||||
client_id = _generate_client_id()
|
||||
client_secret = None
|
||||
client_secret_hash = None
|
||||
client_secret_salt = None
|
||||
|
||||
if request.client_type == "confidential":
|
||||
client_secret = _generate_client_secret()
|
||||
client_secret_salt = secrets.token_urlsafe(16)
|
||||
client_secret_hash = _hash_secret(client_secret, client_secret_salt)
|
||||
|
||||
# Create client
|
||||
await prisma.oauthclient.create(
|
||||
data={
|
||||
"clientId": client_id,
|
||||
"clientSecretHash": client_secret_hash,
|
||||
"clientSecretSalt": client_secret_salt,
|
||||
"clientType": request.client_type,
|
||||
"name": request.name,
|
||||
"description": request.description,
|
||||
"logoUrl": str(request.logo_url) if request.logo_url else None,
|
||||
"homepageUrl": str(request.homepage_url) if request.homepage_url else None,
|
||||
"privacyPolicyUrl": (
|
||||
str(request.privacy_policy_url) if request.privacy_policy_url else None
|
||||
),
|
||||
"termsOfServiceUrl": (
|
||||
str(request.terms_of_service_url)
|
||||
if request.terms_of_service_url
|
||||
else None
|
||||
),
|
||||
"redirectUris": request.redirect_uris,
|
||||
"allowedScopes": DEFAULT_ALLOWED_SCOPES,
|
||||
"webhookDomains": request.webhook_domains,
|
||||
"status": OAuthClientStatus.ACTIVE,
|
||||
"ownerId": user_id,
|
||||
}
|
||||
)
|
||||
|
||||
return ClientSecretResponse(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret or "",
|
||||
)
|
||||
|
||||
|
||||
@client_router.get("/", response_model=list[ClientResponse])
|
||||
async def list_clients(
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> list[ClientResponse]:
|
||||
"""List all OAuth clients owned by the current user."""
|
||||
clients = await prisma.oauthclient.find_many(
|
||||
where={"ownerId": user_id},
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
return [_client_to_response(c) for c in clients]
|
||||
|
||||
|
||||
@client_router.get("/{client_id}", response_model=ClientResponse)
|
||||
async def get_client(
|
||||
client_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> ClientResponse:
|
||||
"""Get details of a specific OAuth client."""
|
||||
client = await prisma.oauthclient.find_first(
|
||||
where={"clientId": client_id, "ownerId": user_id}
|
||||
)
|
||||
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="Client not found")
|
||||
|
||||
return _client_to_response(client)
|
||||
|
||||
|
||||
@client_router.patch("/{client_id}", response_model=ClientResponse)
|
||||
async def update_client(
|
||||
client_id: str,
|
||||
request: UpdateClientRequest,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> ClientResponse:
|
||||
"""Update an OAuth client."""
|
||||
client = await prisma.oauthclient.find_first(
|
||||
where={"clientId": client_id, "ownerId": user_id}
|
||||
)
|
||||
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="Client not found")
|
||||
|
||||
# Build update data
|
||||
update_data: dict[str, str | list[str] | None] = {}
|
||||
if request.name is not None:
|
||||
update_data["name"] = request.name
|
||||
if request.description is not None:
|
||||
update_data["description"] = request.description
|
||||
if request.logo_url is not None:
|
||||
update_data["logoUrl"] = str(request.logo_url)
|
||||
if request.homepage_url is not None:
|
||||
update_data["homepageUrl"] = str(request.homepage_url)
|
||||
if request.privacy_policy_url is not None:
|
||||
update_data["privacyPolicyUrl"] = str(request.privacy_policy_url)
|
||||
if request.terms_of_service_url is not None:
|
||||
update_data["termsOfServiceUrl"] = str(request.terms_of_service_url)
|
||||
if request.redirect_uris is not None:
|
||||
update_data["redirectUris"] = request.redirect_uris
|
||||
if request.webhook_domains is not None:
|
||||
update_data["webhookDomains"] = request.webhook_domains
|
||||
|
||||
if not update_data:
|
||||
return _client_to_response(client)
|
||||
|
||||
updated = await prisma.oauthclient.update(
|
||||
where={"id": client.id},
|
||||
data=update_data, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
return _client_to_response(updated)
|
||||
|
||||
|
||||
@client_router.delete("/{client_id}")
|
||||
async def delete_client(
|
||||
client_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> dict:
|
||||
"""
|
||||
Delete an OAuth client.
|
||||
|
||||
This will also revoke all tokens and authorizations for this client.
|
||||
"""
|
||||
client = await prisma.oauthclient.find_first(
|
||||
where={"clientId": client_id, "ownerId": user_id}
|
||||
)
|
||||
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="Client not found")
|
||||
|
||||
# Delete cascades will handle tokens, codes, and authorizations
|
||||
await prisma.oauthclient.delete(where={"id": client.id})
|
||||
|
||||
return {"status": "deleted", "client_id": client_id}
|
||||
|
||||
|
||||
@client_router.post("/{client_id}/rotate-secret", response_model=ClientSecretResponse)
|
||||
async def rotate_client_secret(
|
||||
client_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> ClientSecretResponse:
|
||||
"""
|
||||
Rotate the client secret for a confidential client.
|
||||
|
||||
The new secret is returned only once. All existing tokens remain valid.
|
||||
"""
|
||||
client = await prisma.oauthclient.find_first(
|
||||
where={"clientId": client_id, "ownerId": user_id}
|
||||
)
|
||||
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="Client not found")
|
||||
|
||||
if client.clientType != "confidential":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot rotate secret for public clients",
|
||||
)
|
||||
|
||||
# Generate new secret
|
||||
new_secret = _generate_client_secret()
|
||||
new_salt = secrets.token_urlsafe(16)
|
||||
new_hash = _hash_secret(new_secret, new_salt)
|
||||
|
||||
await prisma.oauthclient.update(
|
||||
where={"id": client.id},
|
||||
data={
|
||||
"clientSecretHash": new_hash,
|
||||
"clientSecretSalt": new_salt,
|
||||
},
|
||||
)
|
||||
|
||||
return ClientSecretResponse(
|
||||
client_id=client_id,
|
||||
client_secret=new_secret,
|
||||
)
|
||||
|
||||
|
||||
@client_router.post("/{client_id}/suspend")
|
||||
async def suspend_client(
|
||||
client_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> ClientResponse:
|
||||
"""Suspend an OAuth client (prevents new authorizations)."""
|
||||
client = await prisma.oauthclient.find_first(
|
||||
where={"clientId": client_id, "ownerId": user_id}
|
||||
)
|
||||
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="Client not found")
|
||||
|
||||
updated = await prisma.oauthclient.update(
|
||||
where={"id": client.id},
|
||||
data={"status": OAuthClientStatus.SUSPENDED},
|
||||
)
|
||||
|
||||
return _client_to_response(updated)
|
||||
|
||||
|
||||
@client_router.post("/{client_id}/activate")
|
||||
async def activate_client(
|
||||
client_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> ClientResponse:
|
||||
"""Reactivate a suspended OAuth client."""
|
||||
client = await prisma.oauthclient.find_first(
|
||||
where={"clientId": client_id, "ownerId": user_id}
|
||||
)
|
||||
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="Client not found")
|
||||
|
||||
updated = await prisma.oauthclient.update(
|
||||
where={"id": client.id},
|
||||
data={"status": OAuthClientStatus.ACTIVE},
|
||||
)
|
||||
|
||||
return _client_to_response(updated)
|
||||
@@ -0,0 +1,450 @@
|
||||
"""
|
||||
Server-rendered HTML templates for OAuth consent UI.
|
||||
|
||||
These templates are used for the OAuth authorization flow
|
||||
when the user needs to approve access for an external application.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.server.oauth.models import SCOPE_DESCRIPTIONS
|
||||
|
||||
|
||||
def _base_styles() -> str:
|
||||
"""Common CSS styles for all OAuth pages."""
|
||||
return """
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 20px;
|
||||
color: #e4e4e7;
|
||||
}
|
||||
.container {
|
||||
background: #27272a;
|
||||
border-radius: 16px;
|
||||
box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.5);
|
||||
max-width: 420px;
|
||||
width: 100%;
|
||||
padding: 32px;
|
||||
}
|
||||
.header {
|
||||
text-align: center;
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
.logo {
|
||||
width: 64px;
|
||||
height: 64px;
|
||||
border-radius: 12px;
|
||||
margin-bottom: 16px;
|
||||
background: #3f3f46;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
}
|
||||
.logo img {
|
||||
max-width: 48px;
|
||||
max-height: 48px;
|
||||
border-radius: 8px;
|
||||
}
|
||||
.logo-placeholder {
|
||||
font-size: 28px;
|
||||
color: #a1a1aa;
|
||||
}
|
||||
h1 {
|
||||
font-size: 20px;
|
||||
font-weight: 600;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.subtitle {
|
||||
color: #a1a1aa;
|
||||
font-size: 14px;
|
||||
}
|
||||
.app-name {
|
||||
color: #22d3ee;
|
||||
font-weight: 600;
|
||||
}
|
||||
.divider {
|
||||
height: 1px;
|
||||
background: #3f3f46;
|
||||
margin: 24px 0;
|
||||
}
|
||||
.scopes-section h2 {
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
color: #a1a1aa;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.scope-item {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
gap: 12px;
|
||||
padding: 12px 0;
|
||||
border-bottom: 1px solid #3f3f46;
|
||||
}
|
||||
.scope-item:last-child {
|
||||
border-bottom: none;
|
||||
}
|
||||
.scope-icon {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
color: #22d3ee;
|
||||
flex-shrink: 0;
|
||||
margin-top: 2px;
|
||||
}
|
||||
.scope-text {
|
||||
font-size: 14px;
|
||||
line-height: 1.5;
|
||||
}
|
||||
.buttons {
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
margin-top: 24px;
|
||||
}
|
||||
.btn {
|
||||
flex: 1;
|
||||
padding: 12px 24px;
|
||||
border-radius: 8px;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
border: none;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
.btn-cancel {
|
||||
background: #3f3f46;
|
||||
color: #e4e4e7;
|
||||
}
|
||||
.btn-cancel:hover {
|
||||
background: #52525b;
|
||||
}
|
||||
.btn-allow {
|
||||
background: #22d3ee;
|
||||
color: #0f172a;
|
||||
}
|
||||
.btn-allow:hover {
|
||||
background: #06b6d4;
|
||||
}
|
||||
.footer {
|
||||
margin-top: 24px;
|
||||
text-align: center;
|
||||
font-size: 12px;
|
||||
color: #71717a;
|
||||
}
|
||||
.footer a {
|
||||
color: #a1a1aa;
|
||||
text-decoration: none;
|
||||
}
|
||||
.footer a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
.error-container {
|
||||
text-align: center;
|
||||
}
|
||||
.error-icon {
|
||||
width: 64px;
|
||||
height: 64px;
|
||||
margin: 0 auto 16px;
|
||||
color: #ef4444;
|
||||
}
|
||||
.error-title {
|
||||
color: #ef4444;
|
||||
font-size: 18px;
|
||||
font-weight: 600;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.error-message {
|
||||
color: #a1a1aa;
|
||||
font-size: 14px;
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
.success-icon {
|
||||
width: 64px;
|
||||
height: 64px;
|
||||
margin: 0 auto 16px;
|
||||
color: #22c55e;
|
||||
}
|
||||
.success-title {
|
||||
color: #22c55e;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def _check_icon() -> str:
|
||||
"""SVG checkmark icon."""
|
||||
return """
|
||||
<svg class="scope-icon" viewBox="0 0 20 20" fill="currentColor">
|
||||
<path fill-rule="evenodd" d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z" clip-rule="evenodd"/>
|
||||
</svg>
|
||||
"""
|
||||
|
||||
|
||||
def _error_icon() -> str:
|
||||
"""SVG error icon."""
|
||||
return """
|
||||
<svg class="error-icon" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
|
||||
<circle cx="12" cy="12" r="10"/>
|
||||
<line x1="15" y1="9" x2="9" y2="15"/>
|
||||
<line x1="9" y1="9" x2="15" y2="15"/>
|
||||
</svg>
|
||||
"""
|
||||
|
||||
|
||||
def _success_icon() -> str:
|
||||
"""SVG success icon."""
|
||||
return """
|
||||
<svg class="success-icon" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
|
||||
<circle cx="12" cy="12" r="10"/>
|
||||
<path d="M9 12l2 2 4-4"/>
|
||||
</svg>
|
||||
"""
|
||||
|
||||
|
||||
def render_consent_page(
|
||||
client_name: str,
|
||||
client_logo: Optional[str],
|
||||
scopes: list[str],
|
||||
consent_token: str,
|
||||
action_url: str,
|
||||
privacy_policy_url: Optional[str] = None,
|
||||
terms_url: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Render the OAuth consent page.
|
||||
|
||||
Args:
|
||||
client_name: Name of the requesting application
|
||||
client_logo: URL to the client's logo (optional)
|
||||
scopes: List of requested scopes
|
||||
consent_token: CSRF token for the consent form
|
||||
action_url: URL to submit the consent form
|
||||
privacy_policy_url: Client's privacy policy URL (optional)
|
||||
terms_url: Client's terms of service URL (optional)
|
||||
|
||||
Returns:
|
||||
HTML string for the consent page
|
||||
"""
|
||||
# Build logo HTML
|
||||
if client_logo:
|
||||
logo_html = f'<img src="{client_logo}" alt="{client_name}">'
|
||||
else:
|
||||
logo_html = f'<span class="logo-placeholder">{client_name[0].upper()}</span>'
|
||||
|
||||
# Build scopes HTML
|
||||
scopes_html = ""
|
||||
for scope in scopes:
|
||||
description = SCOPE_DESCRIPTIONS.get(scope, scope)
|
||||
scopes_html += f"""
|
||||
<div class="scope-item">
|
||||
{_check_icon()}
|
||||
<span class="scope-text">{description}</span>
|
||||
</div>
|
||||
"""
|
||||
|
||||
# Build footer links
|
||||
footer_links = []
|
||||
if privacy_policy_url:
|
||||
footer_links.append(
|
||||
f'<a href="{privacy_policy_url}" target="_blank">Privacy Policy</a>'
|
||||
)
|
||||
if terms_url:
|
||||
footer_links.append(
|
||||
f'<a href="{terms_url}" target="_blank">Terms of Service</a>'
|
||||
)
|
||||
footer_html = " • ".join(footer_links) if footer_links else ""
|
||||
|
||||
return f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authorize {client_name} - AutoGPT</title>
|
||||
<style>{_base_styles()}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<div class="logo">{logo_html}</div>
|
||||
<h1>Authorize <span class="app-name">{client_name}</span></h1>
|
||||
<p class="subtitle">wants to access your AutoGPT account</p>
|
||||
</div>
|
||||
|
||||
<div class="divider"></div>
|
||||
|
||||
<div class="scopes-section">
|
||||
<h2>This will allow {client_name} to:</h2>
|
||||
{scopes_html}
|
||||
</div>
|
||||
|
||||
<form method="POST" action="{action_url}">
|
||||
<input type="hidden" name="consent_token" value="{consent_token}">
|
||||
<div class="buttons">
|
||||
<button type="submit" name="authorize" value="false" class="btn btn-cancel">
|
||||
Cancel
|
||||
</button>
|
||||
<button type="submit" name="authorize" value="true" class="btn btn-allow">
|
||||
Allow
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
{f'<div class="footer">{footer_html}</div>' if footer_html else ''}
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def render_error_page(
|
||||
error: str,
|
||||
error_description: str,
|
||||
redirect_url: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Render an OAuth error page.
|
||||
|
||||
Args:
|
||||
error: Error code
|
||||
error_description: Human-readable error description
|
||||
redirect_url: Optional URL to redirect back (if safe)
|
||||
|
||||
Returns:
|
||||
HTML string for the error page
|
||||
"""
|
||||
redirect_html = ""
|
||||
if redirect_url:
|
||||
redirect_html = f"""
|
||||
<a href="{redirect_url}" class="btn btn-cancel" style="display: inline-block; text-decoration: none;">
|
||||
Go Back
|
||||
</a>
|
||||
"""
|
||||
|
||||
return f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authorization Error - AutoGPT</title>
|
||||
<style>{_base_styles()}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="error-container">
|
||||
{_error_icon()}
|
||||
<h1 class="error-title">Authorization Failed</h1>
|
||||
<p class="error-message">{error_description}</p>
|
||||
<p class="error-message" style="font-size: 12px; color: #52525b;">
|
||||
Error code: {error}
|
||||
</p>
|
||||
{redirect_html}
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def render_success_page(
|
||||
message: str,
|
||||
redirect_origin: Optional[str] = None,
|
||||
post_message_data: Optional[dict] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Render a success page, optionally with postMessage for popup flows.
|
||||
|
||||
Args:
|
||||
message: Success message to display
|
||||
redirect_origin: Origin for postMessage (popup flows)
|
||||
post_message_data: Data to send via postMessage (popup flows)
|
||||
|
||||
Returns:
|
||||
HTML string for the success page
|
||||
"""
|
||||
# PostMessage script for popup flows
|
||||
post_message_script = ""
|
||||
if redirect_origin and post_message_data:
|
||||
import json
|
||||
|
||||
post_message_script = f"""
|
||||
<script>
|
||||
(function() {{
|
||||
var targetOrigin = {json.dumps(redirect_origin)};
|
||||
var message = {json.dumps(post_message_data)};
|
||||
if (window.opener) {{
|
||||
window.opener.postMessage(message, targetOrigin);
|
||||
setTimeout(function() {{ window.close(); }}, 1000);
|
||||
}}
|
||||
}})();
|
||||
</script>
|
||||
"""
|
||||
|
||||
return f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authorization Successful - AutoGPT</title>
|
||||
<style>{_base_styles()}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="error-container">
|
||||
{_success_icon()}
|
||||
<h1 class="success-title">Success!</h1>
|
||||
<p class="error-message">{message}</p>
|
||||
<p class="error-message" style="font-size: 12px;">
|
||||
This window will close automatically...
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{post_message_script}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def render_login_redirect_page(login_url: str) -> str:
|
||||
"""
|
||||
Render a page that redirects to login.
|
||||
|
||||
Args:
|
||||
login_url: URL to redirect to for login
|
||||
|
||||
Returns:
|
||||
HTML string with auto-redirect
|
||||
"""
|
||||
return f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<meta http-equiv="refresh" content="0;url={login_url}">
|
||||
<title>Login Required - AutoGPT</title>
|
||||
<style>{_base_styles()}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="error-container">
|
||||
<p class="error-message">Redirecting to login...</p>
|
||||
<a href="{login_url}" class="btn btn-allow" style="display: inline-block; text-decoration: none;">
|
||||
Click here if not redirected
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
OIDC Discovery endpoints.
|
||||
|
||||
Implements:
|
||||
- GET /.well-known/openid-configuration - OIDC Discovery Document
|
||||
- GET /.well-known/jwks.json - JSON Web Key Set
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from backend.server.oauth.models import JWKS, OpenIDConfiguration
|
||||
from backend.server.oauth.token_service import get_token_service
|
||||
from backend.util.settings import Settings
|
||||
|
||||
discovery_router = APIRouter(tags=["oidc-discovery"])
|
||||
|
||||
|
||||
@discovery_router.get(
|
||||
"/.well-known/openid-configuration",
|
||||
response_model=OpenIDConfiguration,
|
||||
)
|
||||
async def openid_configuration() -> OpenIDConfiguration:
|
||||
"""
|
||||
OIDC Discovery Document.
|
||||
|
||||
Returns metadata about the OAuth 2.0 authorization server including
|
||||
endpoints, supported features, and algorithms.
|
||||
"""
|
||||
settings = Settings()
|
||||
base_url = settings.config.platform_base_url or "https://platform.agpt.co"
|
||||
|
||||
return OpenIDConfiguration(
|
||||
issuer=base_url,
|
||||
authorization_endpoint=f"{base_url}/oauth/authorize",
|
||||
token_endpoint=f"{base_url}/oauth/token",
|
||||
userinfo_endpoint=f"{base_url}/oauth/userinfo",
|
||||
revocation_endpoint=f"{base_url}/oauth/revoke",
|
||||
jwks_uri=f"{base_url}/.well-known/jwks.json",
|
||||
scopes_supported=[
|
||||
"openid",
|
||||
"profile",
|
||||
"email",
|
||||
"integrations:list",
|
||||
"integrations:connect",
|
||||
"integrations:delete",
|
||||
"agents:execute",
|
||||
],
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code", "refresh_token"],
|
||||
token_endpoint_auth_methods_supported=[
|
||||
"client_secret_post",
|
||||
"client_secret_basic",
|
||||
"none", # For public clients with PKCE
|
||||
],
|
||||
code_challenge_methods_supported=["S256"],
|
||||
subject_types_supported=["public"],
|
||||
id_token_signing_alg_values_supported=["RS256"],
|
||||
)
|
||||
|
||||
|
||||
@discovery_router.get("/.well-known/jwks.json", response_model=JWKS)
|
||||
async def jwks() -> dict:
|
||||
"""
|
||||
JSON Web Key Set (JWKS).
|
||||
|
||||
Returns the public key(s) used to verify JWT signatures.
|
||||
External applications can use these keys to verify access tokens
|
||||
and ID tokens issued by this authorization server.
|
||||
"""
|
||||
token_service = get_token_service()
|
||||
return token_service.get_jwks()
|
||||
162
autogpt_platform/backend/backend/server/oauth/errors.py
Normal file
162
autogpt_platform/backend/backend/server/oauth/errors.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
OAuth 2.0 Error Responses (RFC 6749 Section 5.2).
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OAuthErrorCode(str, Enum):
|
||||
"""Standard OAuth 2.0 error codes."""
|
||||
|
||||
# Authorization endpoint errors (RFC 6749 Section 4.1.2.1)
|
||||
INVALID_REQUEST = "invalid_request"
|
||||
UNAUTHORIZED_CLIENT = "unauthorized_client"
|
||||
ACCESS_DENIED = "access_denied"
|
||||
UNSUPPORTED_RESPONSE_TYPE = "unsupported_response_type"
|
||||
INVALID_SCOPE = "invalid_scope"
|
||||
SERVER_ERROR = "server_error"
|
||||
TEMPORARILY_UNAVAILABLE = "temporarily_unavailable"
|
||||
|
||||
# Token endpoint errors (RFC 6749 Section 5.2)
|
||||
INVALID_CLIENT = "invalid_client"
|
||||
INVALID_GRANT = "invalid_grant"
|
||||
UNSUPPORTED_GRANT_TYPE = "unsupported_grant_type"
|
||||
|
||||
# Extension errors
|
||||
LOGIN_REQUIRED = "login_required"
|
||||
CONSENT_REQUIRED = "consent_required"
|
||||
|
||||
|
||||
class OAuthErrorResponse(BaseModel):
|
||||
"""OAuth error response model."""
|
||||
|
||||
error: str
|
||||
error_description: Optional[str] = None
|
||||
error_uri: Optional[str] = None
|
||||
|
||||
|
||||
class OAuthError(Exception):
|
||||
"""Base OAuth error exception."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error: OAuthErrorCode,
|
||||
description: Optional[str] = None,
|
||||
uri: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
):
|
||||
self.error = error
|
||||
self.description = description
|
||||
self.uri = uri
|
||||
self.state = state
|
||||
super().__init__(description or error.value)
|
||||
|
||||
def to_response(self) -> OAuthErrorResponse:
|
||||
"""Convert to response model."""
|
||||
return OAuthErrorResponse(
|
||||
error=self.error.value,
|
||||
error_description=self.description,
|
||||
error_uri=self.uri,
|
||||
)
|
||||
|
||||
def to_redirect(self, redirect_uri: str) -> RedirectResponse:
|
||||
"""Convert to redirect response with error in query params."""
|
||||
params = {"error": self.error.value}
|
||||
if self.description:
|
||||
params["error_description"] = self.description
|
||||
if self.uri:
|
||||
params["error_uri"] = self.uri
|
||||
if self.state:
|
||||
params["state"] = self.state
|
||||
|
||||
separator = "&" if "?" in redirect_uri else "?"
|
||||
url = f"{redirect_uri}{separator}{urlencode(params)}"
|
||||
return RedirectResponse(url=url, status_code=302)
|
||||
|
||||
def to_http_exception(self, status_code: int = 400) -> HTTPException:
|
||||
"""Convert to FastAPI HTTPException."""
|
||||
return HTTPException(
|
||||
status_code=status_code,
|
||||
detail=self.to_response().model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
# Convenience error classes
|
||||
class InvalidRequestError(OAuthError):
|
||||
"""The request is missing a required parameter or is otherwise malformed."""
|
||||
|
||||
def __init__(self, description: str, state: Optional[str] = None):
|
||||
super().__init__(OAuthErrorCode.INVALID_REQUEST, description, state=state)
|
||||
|
||||
|
||||
class UnauthorizedClientError(OAuthError):
|
||||
"""The client is not authorized to request an authorization code."""
|
||||
|
||||
def __init__(self, description: str, state: Optional[str] = None):
|
||||
super().__init__(OAuthErrorCode.UNAUTHORIZED_CLIENT, description, state=state)
|
||||
|
||||
|
||||
class AccessDeniedError(OAuthError):
|
||||
"""The resource owner denied the request."""
|
||||
|
||||
def __init__(self, description: str = "Access denied", state: Optional[str] = None):
|
||||
super().__init__(OAuthErrorCode.ACCESS_DENIED, description, state=state)
|
||||
|
||||
|
||||
class InvalidScopeError(OAuthError):
|
||||
"""The requested scope is invalid, unknown, or malformed."""
|
||||
|
||||
def __init__(self, description: str, state: Optional[str] = None):
|
||||
super().__init__(OAuthErrorCode.INVALID_SCOPE, description, state=state)
|
||||
|
||||
|
||||
class InvalidClientError(OAuthError):
|
||||
"""Client authentication failed."""
|
||||
|
||||
def __init__(self, description: str = "Invalid client"):
|
||||
super().__init__(OAuthErrorCode.INVALID_CLIENT, description)
|
||||
|
||||
|
||||
class InvalidGrantError(OAuthError):
|
||||
"""The provided authorization code or refresh token is invalid."""
|
||||
|
||||
def __init__(self, description: str = "Invalid grant"):
|
||||
super().__init__(OAuthErrorCode.INVALID_GRANT, description)
|
||||
|
||||
|
||||
class UnsupportedGrantTypeError(OAuthError):
|
||||
"""The authorization grant type is not supported."""
|
||||
|
||||
def __init__(self, grant_type: str):
|
||||
super().__init__(
|
||||
OAuthErrorCode.UNSUPPORTED_GRANT_TYPE,
|
||||
f"Grant type '{grant_type}' is not supported",
|
||||
)
|
||||
|
||||
|
||||
class LoginRequiredError(OAuthError):
|
||||
"""User must be logged in to complete the request."""
|
||||
|
||||
def __init__(self, state: Optional[str] = None):
|
||||
super().__init__(
|
||||
OAuthErrorCode.LOGIN_REQUIRED,
|
||||
"User authentication required",
|
||||
state=state,
|
||||
)
|
||||
|
||||
|
||||
class ConsentRequiredError(OAuthError):
|
||||
"""User consent is required for the requested scopes."""
|
||||
|
||||
def __init__(self, state: Optional[str] = None):
|
||||
super().__init__(
|
||||
OAuthErrorCode.CONSENT_REQUIRED,
|
||||
"User consent required",
|
||||
state=state,
|
||||
)
|
||||
284
autogpt_platform/backend/backend/server/oauth/models.py
Normal file
284
autogpt_platform/backend/backend/server/oauth/models.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
Pydantic models for OAuth 2.0 requests and responses.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
# ============================================================
|
||||
# Enums and Constants
|
||||
# ============================================================
|
||||
|
||||
|
||||
class OAuthScope(str, Enum):
|
||||
"""Supported OAuth scopes."""
|
||||
|
||||
# OpenID Connect standard scopes
|
||||
OPENID = "openid"
|
||||
PROFILE = "profile"
|
||||
EMAIL = "email"
|
||||
|
||||
# AutoGPT-specific scopes
|
||||
INTEGRATIONS_LIST = "integrations:list"
|
||||
INTEGRATIONS_CONNECT = "integrations:connect"
|
||||
INTEGRATIONS_DELETE = "integrations:delete"
|
||||
AGENTS_EXECUTE = "agents:execute"
|
||||
|
||||
|
||||
SCOPE_DESCRIPTIONS: dict[str, str] = {
|
||||
OAuthScope.OPENID.value: "Access your user ID",
|
||||
OAuthScope.PROFILE.value: "Access your profile information (name)",
|
||||
OAuthScope.EMAIL.value: "Access your email address",
|
||||
OAuthScope.INTEGRATIONS_LIST.value: "View your connected integrations",
|
||||
OAuthScope.INTEGRATIONS_CONNECT.value: "Connect new integrations on your behalf",
|
||||
OAuthScope.INTEGRATIONS_DELETE.value: "Delete integrations on your behalf",
|
||||
OAuthScope.AGENTS_EXECUTE.value: "Run agents on your behalf",
|
||||
}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Authorization Request/Response Models
|
||||
# ============================================================
|
||||
|
||||
|
||||
class AuthorizationRequest(BaseModel):
|
||||
"""OAuth 2.0 Authorization Request (RFC 6749 Section 4.1.1)."""
|
||||
|
||||
response_type: Literal["code"] = Field(
|
||||
..., description="Must be 'code' for authorization code flow"
|
||||
)
|
||||
client_id: str = Field(..., description="Client identifier")
|
||||
redirect_uri: str = Field(..., description="Redirect URI after authorization")
|
||||
scope: str = Field(default="", description="Space-separated list of scopes")
|
||||
state: str = Field(..., description="CSRF protection token (required)")
|
||||
code_challenge: str = Field(..., description="PKCE code challenge (required)")
|
||||
code_challenge_method: Literal["S256"] = Field(
|
||||
default="S256", description="PKCE method (only S256 supported)"
|
||||
)
|
||||
nonce: Optional[str] = Field(None, description="OIDC nonce for replay protection")
|
||||
prompt: Optional[Literal["consent", "login", "none"]] = Field(
|
||||
None, description="Prompt behavior"
|
||||
)
|
||||
|
||||
|
||||
class ConsentFormData(BaseModel):
|
||||
"""Consent form submission data."""
|
||||
|
||||
consent_token: str = Field(..., description="CSRF token for consent")
|
||||
authorize: bool = Field(..., description="Whether user authorized")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Token Request/Response Models
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TokenRequest(BaseModel):
|
||||
"""OAuth 2.0 Token Request (RFC 6749 Section 4.1.3)."""
|
||||
|
||||
grant_type: Literal["authorization_code", "refresh_token"] = Field(
|
||||
..., description="Grant type"
|
||||
)
|
||||
code: Optional[str] = Field(
|
||||
None, description="Authorization code (for authorization_code grant)"
|
||||
)
|
||||
redirect_uri: Optional[str] = Field(
|
||||
None, description="Must match authorization request"
|
||||
)
|
||||
client_id: str = Field(..., description="Client identifier")
|
||||
client_secret: Optional[str] = Field(
|
||||
None, description="Client secret (for confidential clients)"
|
||||
)
|
||||
code_verifier: Optional[str] = Field(
|
||||
None, description="PKCE code verifier (for authorization_code grant)"
|
||||
)
|
||||
refresh_token: Optional[str] = Field(
|
||||
None, description="Refresh token (for refresh_token grant)"
|
||||
)
|
||||
scope: Optional[str] = Field(
|
||||
None, description="Requested scopes (for refresh_token grant)"
|
||||
)
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""OAuth 2.0 Token Response (RFC 6749 Section 5.1)."""
|
||||
|
||||
access_token: str = Field(..., description="Access token")
|
||||
token_type: Literal["Bearer"] = Field(default="Bearer", description="Token type")
|
||||
expires_in: int = Field(..., description="Token lifetime in seconds")
|
||||
refresh_token: Optional[str] = Field(None, description="Refresh token")
|
||||
scope: Optional[str] = Field(None, description="Granted scopes")
|
||||
id_token: Optional[str] = Field(None, description="OIDC ID token")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# UserInfo Response Model
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UserInfoResponse(BaseModel):
|
||||
"""OIDC UserInfo Response."""
|
||||
|
||||
sub: str = Field(..., description="User ID (subject)")
|
||||
email: Optional[str] = Field(None, description="User email")
|
||||
email_verified: Optional[bool] = Field(
|
||||
None, description="Whether email is verified"
|
||||
)
|
||||
name: Optional[str] = Field(None, description="User display name")
|
||||
updated_at: Optional[int] = Field(None, description="Last profile update timestamp")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# OIDC Discovery Models
|
||||
# ============================================================
|
||||
|
||||
|
||||
class OpenIDConfiguration(BaseModel):
|
||||
"""OIDC Discovery Document."""
|
||||
|
||||
issuer: str
|
||||
authorization_endpoint: str
|
||||
token_endpoint: str
|
||||
userinfo_endpoint: str
|
||||
revocation_endpoint: str
|
||||
jwks_uri: str
|
||||
scopes_supported: list[str]
|
||||
response_types_supported: list[str]
|
||||
grant_types_supported: list[str]
|
||||
token_endpoint_auth_methods_supported: list[str]
|
||||
code_challenge_methods_supported: list[str]
|
||||
subject_types_supported: list[str]
|
||||
id_token_signing_alg_values_supported: list[str]
|
||||
|
||||
|
||||
class JWK(BaseModel):
|
||||
"""JSON Web Key."""
|
||||
|
||||
kty: str = Field(..., description="Key type (RSA)")
|
||||
use: str = Field(default="sig", description="Key use (signature)")
|
||||
kid: str = Field(..., description="Key ID")
|
||||
alg: str = Field(default="RS256", description="Algorithm")
|
||||
n: str = Field(..., description="RSA modulus")
|
||||
e: str = Field(..., description="RSA exponent")
|
||||
|
||||
|
||||
class JWKS(BaseModel):
|
||||
"""JSON Web Key Set."""
|
||||
|
||||
keys: list[JWK]
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Client Management Models
|
||||
# ============================================================
|
||||
|
||||
|
||||
class RegisterClientRequest(BaseModel):
|
||||
"""Request to register a new OAuth client."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=100, description="Client name")
|
||||
description: Optional[str] = Field(
|
||||
None, max_length=500, description="Client description"
|
||||
)
|
||||
logo_url: Optional[HttpUrl] = Field(None, description="Logo URL")
|
||||
homepage_url: Optional[HttpUrl] = Field(None, description="Homepage URL")
|
||||
privacy_policy_url: Optional[HttpUrl] = Field(
|
||||
None, description="Privacy policy URL"
|
||||
)
|
||||
terms_of_service_url: Optional[HttpUrl] = Field(
|
||||
None, description="Terms of service URL"
|
||||
)
|
||||
redirect_uris: list[str] = Field(
|
||||
..., min_length=1, description="Allowed redirect URIs"
|
||||
)
|
||||
client_type: Literal["public", "confidential"] = Field(
|
||||
default="public", description="Client type"
|
||||
)
|
||||
webhook_domains: list[str] = Field(
|
||||
default_factory=list, description="Allowed webhook domains"
|
||||
)
|
||||
|
||||
|
||||
class UpdateClientRequest(BaseModel):
|
||||
"""Request to update an OAuth client."""
|
||||
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
description: Optional[str] = Field(None, max_length=500)
|
||||
logo_url: Optional[HttpUrl] = None
|
||||
homepage_url: Optional[HttpUrl] = None
|
||||
privacy_policy_url: Optional[HttpUrl] = None
|
||||
terms_of_service_url: Optional[HttpUrl] = None
|
||||
redirect_uris: Optional[list[str]] = None
|
||||
webhook_domains: Optional[list[str]] = None
|
||||
|
||||
|
||||
class ClientResponse(BaseModel):
|
||||
"""OAuth client response."""
|
||||
|
||||
id: str
|
||||
client_id: str
|
||||
client_type: str
|
||||
name: str
|
||||
description: Optional[str]
|
||||
logo_url: Optional[str]
|
||||
homepage_url: Optional[str]
|
||||
privacy_policy_url: Optional[str]
|
||||
terms_of_service_url: Optional[str]
|
||||
redirect_uris: list[str]
|
||||
allowed_scopes: list[str]
|
||||
webhook_domains: list[str]
|
||||
status: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class ClientSecretResponse(BaseModel):
|
||||
"""Response containing newly generated client secret."""
|
||||
|
||||
client_id: str
|
||||
client_secret: str = Field(
|
||||
..., description="Client secret (only shown once, store securely)"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Token Introspection/Revocation Models
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TokenRevocationRequest(BaseModel):
|
||||
"""Token revocation request (RFC 7009)."""
|
||||
|
||||
token: str = Field(..., description="Token to revoke")
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Field(
|
||||
None, description="Hint about token type"
|
||||
)
|
||||
|
||||
|
||||
class TokenIntrospectionRequest(BaseModel):
|
||||
"""Token introspection request (RFC 7662)."""
|
||||
|
||||
token: str = Field(..., description="Token to introspect")
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Field(
|
||||
None, description="Hint about token type"
|
||||
)
|
||||
|
||||
|
||||
class TokenIntrospectionResponse(BaseModel):
|
||||
"""Token introspection response."""
|
||||
|
||||
active: bool = Field(..., description="Whether the token is active")
|
||||
scope: Optional[str] = Field(None, description="Token scopes")
|
||||
client_id: Optional[str] = Field(
|
||||
None, description="Client that token was issued to"
|
||||
)
|
||||
username: Optional[str] = Field(None, description="User identifier")
|
||||
token_type: Optional[str] = Field(None, description="Token type")
|
||||
exp: Optional[int] = Field(None, description="Expiration timestamp")
|
||||
iat: Optional[int] = Field(None, description="Issued at timestamp")
|
||||
sub: Optional[str] = Field(None, description="Subject (user ID)")
|
||||
aud: Optional[str] = Field(None, description="Audience")
|
||||
iss: Optional[str] = Field(None, description="Issuer")
|
||||
66
autogpt_platform/backend/backend/server/oauth/pkce.py
Normal file
66
autogpt_platform/backend/backend/server/oauth/pkce.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
PKCE (Proof Key for Code Exchange) implementation for OAuth 2.0.
|
||||
|
||||
RFC 7636: https://tools.ietf.org/html/rfc7636
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
|
||||
def generate_code_verifier(length: int = 64) -> str:
|
||||
"""
|
||||
Generate a cryptographically random code verifier.
|
||||
|
||||
Args:
|
||||
length: Length of the verifier (43-128 characters, default 64)
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded random string
|
||||
"""
|
||||
if not 43 <= length <= 128:
|
||||
raise ValueError("Code verifier length must be between 43 and 128")
|
||||
return secrets.token_urlsafe(length)[:length]
|
||||
|
||||
|
||||
def generate_code_challenge(verifier: str, method: str = "S256") -> str:
|
||||
"""
|
||||
Generate a code challenge from the verifier.
|
||||
|
||||
Args:
|
||||
verifier: The code verifier string
|
||||
method: Challenge method ("S256" or "plain")
|
||||
|
||||
Returns:
|
||||
The code challenge string
|
||||
"""
|
||||
if method == "S256":
|
||||
digest = hashlib.sha256(verifier.encode("ascii")).digest()
|
||||
# URL-safe base64 encoding without padding
|
||||
return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
|
||||
elif method == "plain":
|
||||
return verifier
|
||||
else:
|
||||
raise ValueError(f"Unsupported code challenge method: {method}")
|
||||
|
||||
|
||||
def verify_code_challenge(
|
||||
verifier: str,
|
||||
challenge: str,
|
||||
method: str = "S256",
|
||||
) -> bool:
|
||||
"""
|
||||
Verify that a code verifier matches the stored challenge.
|
||||
|
||||
Args:
|
||||
verifier: The code verifier from the token request
|
||||
challenge: The code challenge stored during authorization
|
||||
method: The challenge method used
|
||||
|
||||
Returns:
|
||||
True if the verifier matches the challenge
|
||||
"""
|
||||
expected = generate_code_challenge(verifier, method)
|
||||
# Use constant-time comparison to prevent timing attacks
|
||||
return secrets.compare_digest(expected, challenge)
|
||||
417
autogpt_platform/backend/backend/server/oauth/router.py
Normal file
417
autogpt_platform/backend/backend/server/oauth/router.py
Normal file
@@ -0,0 +1,417 @@
|
||||
"""
|
||||
OAuth 2.0 Authorization Server endpoints.
|
||||
|
||||
Implements:
|
||||
- GET /oauth/authorize - Authorization endpoint
|
||||
- POST /oauth/authorize/consent - Consent form submission
|
||||
- POST /oauth/token - Token endpoint
|
||||
- GET /oauth/userinfo - OIDC UserInfo endpoint
|
||||
- POST /oauth/revoke - Token revocation endpoint
|
||||
"""
|
||||
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from autogpt_libs.auth import get_optional_user_id
|
||||
from fastapi import APIRouter, Form, HTTPException, Query, Request, Security
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
||||
|
||||
from backend.data.db import prisma
|
||||
from backend.server.oauth.consent_templates import (
|
||||
render_consent_page,
|
||||
render_error_page,
|
||||
render_login_redirect_page,
|
||||
)
|
||||
from backend.server.oauth.errors import (
|
||||
InvalidGrantError,
|
||||
InvalidRequestError,
|
||||
LoginRequiredError,
|
||||
OAuthError,
|
||||
UnsupportedGrantTypeError,
|
||||
)
|
||||
from backend.server.oauth.models import TokenResponse, UserInfoResponse
|
||||
from backend.server.oauth.service import get_oauth_service
|
||||
from backend.server.oauth.token_service import get_token_service
|
||||
from backend.util.settings import Settings
|
||||
|
||||
oauth_router = APIRouter(prefix="/oauth", tags=["oauth"])
|
||||
|
||||
# Consent state storage (in production, use Redis)
|
||||
_consent_states: dict[str, dict] = {}
|
||||
|
||||
|
||||
def _parse_scopes(scope_str: str) -> list[str]:
|
||||
"""Parse space-separated scope string into list."""
|
||||
if not scope_str:
|
||||
return []
|
||||
return [s.strip() for s in scope_str.split() if s.strip()]
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
"""Get client IP address from request."""
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Authorization Endpoint
|
||||
# ================================================================
|
||||
|
||||
|
||||
@oauth_router.get("/authorize", response_model=None)
|
||||
async def authorize(
|
||||
request: Request,
|
||||
response_type: str = Query(..., description="Must be 'code'"),
|
||||
client_id: str = Query(..., description="Client identifier"),
|
||||
redirect_uri: str = Query(..., description="Redirect URI"),
|
||||
state: str = Query(..., description="CSRF state parameter"),
|
||||
code_challenge: str = Query(..., description="PKCE code challenge"),
|
||||
code_challenge_method: str = Query("S256", description="PKCE method"),
|
||||
scope: str = Query("", description="Space-separated scopes"),
|
||||
nonce: Optional[str] = Query(None, description="OIDC nonce"),
|
||||
prompt: Optional[str] = Query(None, description="Prompt behavior"),
|
||||
# User authentication (via JWT token)
|
||||
user_id: Optional[str] = Security(get_optional_user_id),
|
||||
) -> HTMLResponse | RedirectResponse:
|
||||
"""
|
||||
OAuth 2.0 Authorization Endpoint.
|
||||
|
||||
Validates the request, checks user authentication, and either:
|
||||
- Redirects to login if user is not authenticated
|
||||
- Shows consent page if user hasn't authorized these scopes
|
||||
- Redirects with authorization code if already authorized
|
||||
"""
|
||||
oauth_service = get_oauth_service()
|
||||
settings = Settings()
|
||||
|
||||
try:
|
||||
# Validate response_type
|
||||
if response_type != "code":
|
||||
raise InvalidRequestError(
|
||||
"Only 'code' response_type is supported", state=state
|
||||
)
|
||||
|
||||
# Validate PKCE method
|
||||
if code_challenge_method != "S256":
|
||||
raise InvalidRequestError(
|
||||
"Only 'S256' code_challenge_method is supported", state=state
|
||||
)
|
||||
|
||||
# Parse scopes
|
||||
scopes = _parse_scopes(scope)
|
||||
|
||||
# Validate client and redirect URI
|
||||
client = await oauth_service.validate_client(client_id, redirect_uri, scopes)
|
||||
|
||||
# Check if user is authenticated
|
||||
if not user_id:
|
||||
if prompt == "none":
|
||||
# Cannot prompt, return error
|
||||
raise LoginRequiredError(state=state)
|
||||
|
||||
# Redirect to login with return URL
|
||||
login_url = settings.config.frontend_base_url or "http://localhost:3000"
|
||||
return_url = str(request.url)
|
||||
login_redirect = (
|
||||
f"{login_url}/login?returnUrl={urlencode({'': return_url})[1:]}"
|
||||
)
|
||||
return HTMLResponse(render_login_redirect_page(login_redirect))
|
||||
|
||||
# Check if user has already authorized these scopes
|
||||
if prompt != "consent":
|
||||
has_auth = await oauth_service.has_valid_authorization(
|
||||
user_id, client_id, scopes
|
||||
)
|
||||
if has_auth:
|
||||
# Skip consent, issue code directly
|
||||
code = await oauth_service.create_authorization_code(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scopes=scopes,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
nonce=nonce,
|
||||
)
|
||||
redirect_url = f"{redirect_uri}?code={code}&state={state}"
|
||||
return RedirectResponse(url=redirect_url, status_code=302)
|
||||
|
||||
# Generate consent token and store state
|
||||
consent_token = secrets.token_urlsafe(32)
|
||||
_consent_states[consent_token] = {
|
||||
"user_id": user_id,
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scopes": scopes,
|
||||
"state": state,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": code_challenge_method,
|
||||
"nonce": nonce,
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(minutes=10),
|
||||
}
|
||||
|
||||
# Render consent page
|
||||
return HTMLResponse(
|
||||
render_consent_page(
|
||||
client_name=client.name,
|
||||
client_logo=client.logoUrl,
|
||||
scopes=scopes,
|
||||
consent_token=consent_token,
|
||||
action_url="/oauth/authorize/consent",
|
||||
privacy_policy_url=client.privacyPolicyUrl,
|
||||
terms_url=client.termsOfServiceUrl,
|
||||
)
|
||||
)
|
||||
|
||||
except OAuthError as e:
|
||||
# If we have a valid redirect_uri, redirect with error
|
||||
# Otherwise show error page
|
||||
try:
|
||||
client = await oauth_service.get_client(client_id)
|
||||
if client and redirect_uri in client.redirectUris:
|
||||
return e.to_redirect(redirect_uri)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return HTMLResponse(
|
||||
render_error_page(e.error.value, e.description or "An error occurred"),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
|
||||
@oauth_router.post("/authorize/consent", response_model=None)
|
||||
async def submit_consent(
|
||||
request: Request,
|
||||
consent_token: str = Form(...),
|
||||
authorize: str = Form(...),
|
||||
) -> HTMLResponse | RedirectResponse:
|
||||
"""
|
||||
Process consent form submission.
|
||||
|
||||
Creates authorization code and redirects to client's redirect_uri.
|
||||
"""
|
||||
oauth_service = get_oauth_service()
|
||||
|
||||
# Validate consent token
|
||||
consent_state = _consent_states.pop(consent_token, None)
|
||||
if not consent_state:
|
||||
return HTMLResponse(
|
||||
render_error_page("invalid_request", "Invalid or expired consent token"),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Check expiration
|
||||
if consent_state["expires_at"] < datetime.now(timezone.utc):
|
||||
return HTMLResponse(
|
||||
render_error_page("invalid_request", "Consent session expired"),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
redirect_uri = consent_state["redirect_uri"]
|
||||
state = consent_state["state"]
|
||||
|
||||
# Check if user denied
|
||||
if authorize.lower() != "true":
|
||||
error_params = urlencode(
|
||||
{
|
||||
"error": "access_denied",
|
||||
"error_description": "User denied the authorization request",
|
||||
"state": state,
|
||||
}
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{error_params}",
|
||||
status_code=302,
|
||||
)
|
||||
|
||||
try:
|
||||
# Create authorization code
|
||||
code = await oauth_service.create_authorization_code(
|
||||
user_id=consent_state["user_id"],
|
||||
client_id=consent_state["client_id"],
|
||||
redirect_uri=redirect_uri,
|
||||
scopes=consent_state["scopes"],
|
||||
code_challenge=consent_state["code_challenge"],
|
||||
code_challenge_method=consent_state["code_challenge_method"],
|
||||
nonce=consent_state["nonce"],
|
||||
)
|
||||
|
||||
# Redirect with code
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?code={code}&state={state}",
|
||||
status_code=302,
|
||||
)
|
||||
|
||||
except OAuthError as e:
|
||||
return e.to_redirect(redirect_uri)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Token Endpoint
|
||||
# ================================================================
|
||||
|
||||
|
||||
@oauth_router.post("/token", response_model=TokenResponse)
|
||||
async def token(
|
||||
request: Request,
|
||||
grant_type: str = Form(...),
|
||||
code: Optional[str] = Form(None),
|
||||
redirect_uri: Optional[str] = Form(None),
|
||||
client_id: str = Form(...),
|
||||
client_secret: Optional[str] = Form(None),
|
||||
code_verifier: Optional[str] = Form(None),
|
||||
refresh_token: Optional[str] = Form(None),
|
||||
scope: Optional[str] = Form(None),
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
OAuth 2.0 Token Endpoint.
|
||||
|
||||
Supports:
|
||||
- authorization_code grant (with PKCE)
|
||||
- refresh_token grant
|
||||
"""
|
||||
oauth_service = get_oauth_service()
|
||||
|
||||
try:
|
||||
# Validate client authentication
|
||||
await oauth_service.validate_client_secret(client_id, client_secret)
|
||||
|
||||
if grant_type == "authorization_code":
|
||||
# Validate required parameters
|
||||
if not code:
|
||||
raise InvalidRequestError("'code' is required")
|
||||
if not redirect_uri:
|
||||
raise InvalidRequestError("'redirect_uri' is required")
|
||||
if not code_verifier:
|
||||
raise InvalidRequestError("'code_verifier' is required for PKCE")
|
||||
|
||||
return await oauth_service.exchange_authorization_code(
|
||||
code=code,
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=code_verifier,
|
||||
)
|
||||
|
||||
elif grant_type == "refresh_token":
|
||||
if not refresh_token:
|
||||
raise InvalidRequestError("'refresh_token' is required")
|
||||
|
||||
requested_scopes = _parse_scopes(scope) if scope else None
|
||||
return await oauth_service.refresh_access_token(
|
||||
refresh_token=refresh_token,
|
||||
client_id=client_id,
|
||||
requested_scopes=requested_scopes,
|
||||
)
|
||||
|
||||
else:
|
||||
raise UnsupportedGrantTypeError(grant_type)
|
||||
|
||||
except OAuthError as e:
|
||||
raise e.to_http_exception(400 if isinstance(e, InvalidGrantError) else 401)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# UserInfo Endpoint
|
||||
# ================================================================
|
||||
|
||||
|
||||
@oauth_router.get("/userinfo", response_model=UserInfoResponse)
|
||||
async def userinfo(request: Request) -> UserInfoResponse:
|
||||
"""
|
||||
OIDC UserInfo Endpoint.
|
||||
|
||||
Returns user profile information based on the granted scopes.
|
||||
"""
|
||||
token_service = get_token_service()
|
||||
|
||||
# Extract bearer token
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Bearer token required",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
token = auth_header[7:]
|
||||
|
||||
try:
|
||||
# Verify token
|
||||
claims = token_service.verify_access_token(token)
|
||||
|
||||
# Check token is not revoked
|
||||
token_hash = token_service.hash_token(token)
|
||||
stored_token = await prisma.oauthaccesstoken.find_unique(
|
||||
where={"tokenHash": token_hash}
|
||||
)
|
||||
|
||||
if not stored_token or stored_token.revokedAt:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Token has been revoked",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Update last used
|
||||
await prisma.oauthaccesstoken.update(
|
||||
where={"id": stored_token.id},
|
||||
data={"lastUsedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
# Get user info based on scopes
|
||||
user = await prisma.user.find_unique(where={"id": claims.sub})
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
scopes = claims.scope.split()
|
||||
|
||||
# Build response based on scopes
|
||||
email = user.email if "email" in scopes else None
|
||||
email_verified = user.emailVerified if "email" in scopes else None
|
||||
name = user.name if "profile" in scopes else None
|
||||
updated_at = int(user.updatedAt.timestamp()) if "profile" in scopes else None
|
||||
|
||||
return UserInfoResponse(
|
||||
sub=claims.sub,
|
||||
email=email,
|
||||
email_verified=email_verified,
|
||||
name=name,
|
||||
updated_at=updated_at,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Invalid token: {str(e)}",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Token Revocation Endpoint
|
||||
# ================================================================
|
||||
|
||||
|
||||
@oauth_router.post("/revoke")
|
||||
async def revoke(
|
||||
request: Request,
|
||||
token: str = Form(...),
|
||||
token_type_hint: Optional[str] = Form(None),
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
OAuth 2.0 Token Revocation Endpoint (RFC 7009).
|
||||
|
||||
Revokes an access token or refresh token.
|
||||
"""
|
||||
oauth_service = get_oauth_service()
|
||||
|
||||
# Note: Per RFC 7009, always return 200 even if token not found
|
||||
await oauth_service.revoke_token(token, token_type_hint)
|
||||
|
||||
return JSONResponse(content={}, status_code=200)
|
||||
625
autogpt_platform/backend/backend/server/oauth/service.py
Normal file
625
autogpt_platform/backend/backend/server/oauth/service.py
Normal file
@@ -0,0 +1,625 @@
|
||||
"""
|
||||
Core OAuth 2.0 service logic.
|
||||
|
||||
Handles:
|
||||
- Client validation and lookup
|
||||
- Authorization code generation and exchange
|
||||
- Token issuance and refresh
|
||||
- User consent management
|
||||
- Audit logging
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from prisma.enums import OAuthClientStatus
|
||||
from prisma.models import OAuthAuthorization, OAuthClient, User
|
||||
|
||||
from backend.data.db import prisma
|
||||
from backend.server.oauth.errors import (
|
||||
InvalidClientError,
|
||||
InvalidGrantError,
|
||||
InvalidRequestError,
|
||||
InvalidScopeError,
|
||||
)
|
||||
from backend.server.oauth.models import TokenResponse
|
||||
from backend.server.oauth.pkce import verify_code_challenge
|
||||
from backend.server.oauth.token_service import OAuthTokenService, get_token_service
|
||||
|
||||
|
||||
class OAuthService:
|
||||
"""Core OAuth 2.0 service."""
|
||||
|
||||
def __init__(self, token_service: Optional[OAuthTokenService] = None):
|
||||
self.token_service = token_service or get_token_service()
|
||||
|
||||
# ================================================================
|
||||
# Client Operations
|
||||
# ================================================================
|
||||
|
||||
async def get_client(self, client_id: str) -> Optional[OAuthClient]:
|
||||
"""Get an OAuth client by client_id."""
|
||||
return await prisma.oauthclient.find_unique(where={"clientId": client_id})
|
||||
|
||||
async def validate_client(
|
||||
self,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
scopes: list[str],
|
||||
) -> OAuthClient:
|
||||
"""
|
||||
Validate a client for authorization.
|
||||
|
||||
Args:
|
||||
client_id: Client identifier
|
||||
redirect_uri: Requested redirect URI
|
||||
scopes: Requested scopes
|
||||
|
||||
Returns:
|
||||
Validated OAuthClient
|
||||
|
||||
Raises:
|
||||
InvalidClientError: Client not found or inactive
|
||||
InvalidRequestError: Invalid redirect URI
|
||||
InvalidScopeError: Invalid scopes requested
|
||||
"""
|
||||
client = await self.get_client(client_id)
|
||||
|
||||
if not client:
|
||||
raise InvalidClientError(f"Client '{client_id}' not found")
|
||||
|
||||
if client.status != OAuthClientStatus.ACTIVE:
|
||||
raise InvalidClientError(f"Client '{client_id}' is not active")
|
||||
|
||||
# Validate redirect URI (exact match required)
|
||||
if redirect_uri not in client.redirectUris:
|
||||
raise InvalidRequestError(
|
||||
f"Redirect URI '{redirect_uri}' is not registered for this client"
|
||||
)
|
||||
|
||||
# Validate scopes
|
||||
invalid_scopes = set(scopes) - set(client.allowedScopes)
|
||||
if invalid_scopes:
|
||||
raise InvalidScopeError(
|
||||
f"Scopes not allowed for this client: {', '.join(invalid_scopes)}"
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
async def validate_client_secret(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: Optional[str],
|
||||
) -> OAuthClient:
|
||||
"""
|
||||
Validate client authentication for token endpoint.
|
||||
|
||||
Args:
|
||||
client_id: Client identifier
|
||||
client_secret: Client secret (for confidential clients)
|
||||
|
||||
Returns:
|
||||
Validated OAuthClient
|
||||
|
||||
Raises:
|
||||
InvalidClientError: Invalid client or credentials
|
||||
"""
|
||||
client = await self.get_client(client_id)
|
||||
|
||||
if not client:
|
||||
raise InvalidClientError(f"Client '{client_id}' not found")
|
||||
|
||||
if client.status != OAuthClientStatus.ACTIVE:
|
||||
raise InvalidClientError(f"Client '{client_id}' is not active")
|
||||
|
||||
# Confidential clients must provide secret
|
||||
if client.clientType == "confidential":
|
||||
if not client_secret:
|
||||
raise InvalidClientError("Client secret required")
|
||||
|
||||
# Hash and compare
|
||||
secret_hash = self._hash_secret(
|
||||
client_secret, client.clientSecretSalt or ""
|
||||
)
|
||||
if not secrets.compare_digest(secret_hash, client.clientSecretHash or ""):
|
||||
raise InvalidClientError("Invalid client credentials")
|
||||
|
||||
return client
|
||||
|
||||
@staticmethod
|
||||
def _hash_secret(secret: str, salt: str) -> str:
|
||||
"""Hash a client secret with salt."""
|
||||
return hashlib.sha256(f"{salt}{secret}".encode()).hexdigest()
|
||||
|
||||
# ================================================================
|
||||
# Authorization Code Operations
|
||||
# ================================================================
|
||||
|
||||
async def create_authorization_code(
|
||||
self,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
scopes: list[str],
|
||||
code_challenge: str,
|
||||
code_challenge_method: str = "S256",
|
||||
nonce: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a new authorization code.
|
||||
|
||||
Args:
|
||||
user_id: User who authorized
|
||||
client_id: Client being authorized
|
||||
redirect_uri: Redirect URI for callback
|
||||
scopes: Granted scopes
|
||||
code_challenge: PKCE code challenge
|
||||
code_challenge_method: PKCE method (S256)
|
||||
nonce: OIDC nonce (optional)
|
||||
|
||||
Returns:
|
||||
Authorization code string
|
||||
"""
|
||||
code = secrets.token_urlsafe(32)
|
||||
code_hash = self.token_service.hash_token(code)
|
||||
|
||||
# Get the OAuthClient to link
|
||||
client = await self.get_client(client_id)
|
||||
if not client:
|
||||
raise InvalidClientError(f"Client '{client_id}' not found")
|
||||
|
||||
await prisma.oauthauthorizationcode.create(
|
||||
data={
|
||||
"codeHash": code_hash,
|
||||
"userId": user_id,
|
||||
"clientId": client.id,
|
||||
"redirectUri": redirect_uri,
|
||||
"scopes": scopes,
|
||||
"codeChallenge": code_challenge,
|
||||
"codeChallengeMethod": code_challenge_method,
|
||||
"nonce": nonce,
|
||||
"expiresAt": datetime.now(timezone.utc) + timedelta(minutes=10),
|
||||
}
|
||||
)
|
||||
|
||||
return code
|
||||
|
||||
async def exchange_authorization_code(
|
||||
self,
|
||||
code: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
code_verifier: str,
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
Exchange an authorization code for tokens.
|
||||
|
||||
Args:
|
||||
code: Authorization code
|
||||
client_id: Client identifier
|
||||
redirect_uri: Must match original redirect URI
|
||||
code_verifier: PKCE code verifier
|
||||
|
||||
Returns:
|
||||
TokenResponse with access token, refresh token, etc.
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: Invalid or expired code
|
||||
InvalidRequestError: PKCE verification failed
|
||||
"""
|
||||
code_hash = self.token_service.hash_token(code)
|
||||
|
||||
# Find the authorization code
|
||||
auth_code = await prisma.oauthauthorizationcode.find_unique(
|
||||
where={"codeHash": code_hash},
|
||||
include={"Client": True, "User": True},
|
||||
)
|
||||
|
||||
if not auth_code:
|
||||
raise InvalidGrantError("Authorization code not found")
|
||||
|
||||
# Ensure Client relation is loaded
|
||||
if not auth_code.Client:
|
||||
raise InvalidGrantError("Authorization code client not found")
|
||||
|
||||
# Check if already used
|
||||
if auth_code.usedAt:
|
||||
# Code reuse is a security incident - revoke all tokens for this authorization
|
||||
await self._revoke_tokens_for_client_user(
|
||||
auth_code.Client.clientId, auth_code.userId
|
||||
)
|
||||
raise InvalidGrantError("Authorization code has already been used")
|
||||
|
||||
# Check expiration
|
||||
if auth_code.expiresAt < datetime.now(timezone.utc):
|
||||
raise InvalidGrantError("Authorization code has expired")
|
||||
|
||||
# Validate client
|
||||
if auth_code.Client.clientId != client_id:
|
||||
raise InvalidGrantError("Client ID mismatch")
|
||||
|
||||
# Validate redirect URI
|
||||
if auth_code.redirectUri != redirect_uri:
|
||||
raise InvalidGrantError("Redirect URI mismatch")
|
||||
|
||||
# Verify PKCE
|
||||
if not verify_code_challenge(
|
||||
code_verifier, auth_code.codeChallenge, auth_code.codeChallengeMethod
|
||||
):
|
||||
raise InvalidRequestError("PKCE verification failed")
|
||||
|
||||
# Mark code as used
|
||||
await prisma.oauthauthorizationcode.update(
|
||||
where={"id": auth_code.id},
|
||||
data={"usedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
# Create or update authorization record
|
||||
await self._upsert_authorization(
|
||||
auth_code.userId, auth_code.Client.id, auth_code.scopes
|
||||
)
|
||||
|
||||
# Generate tokens
|
||||
return await self._create_tokens(
|
||||
user_id=auth_code.userId,
|
||||
client=auth_code.Client,
|
||||
scopes=auth_code.scopes,
|
||||
nonce=auth_code.nonce,
|
||||
user=auth_code.User,
|
||||
)
|
||||
|
||||
async def refresh_access_token(
|
||||
self,
|
||||
refresh_token: str,
|
||||
client_id: str,
|
||||
requested_scopes: Optional[list[str]] = None,
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
Refresh an access token using a refresh token.
|
||||
|
||||
Args:
|
||||
refresh_token: Refresh token string
|
||||
client_id: Client identifier
|
||||
requested_scopes: Optionally request fewer scopes
|
||||
|
||||
Returns:
|
||||
New TokenResponse
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: Invalid or expired refresh token
|
||||
"""
|
||||
token_hash = self.token_service.hash_token(refresh_token)
|
||||
|
||||
# Find the refresh token
|
||||
stored_token = await prisma.oauthrefreshtoken.find_unique(
|
||||
where={"tokenHash": token_hash},
|
||||
include={"Client": True, "User": True},
|
||||
)
|
||||
|
||||
if not stored_token:
|
||||
raise InvalidGrantError("Refresh token not found")
|
||||
|
||||
# Ensure Client relation is loaded
|
||||
if not stored_token.Client:
|
||||
raise InvalidGrantError("Refresh token client not found")
|
||||
|
||||
# Check if revoked
|
||||
if stored_token.revokedAt:
|
||||
raise InvalidGrantError("Refresh token has been revoked")
|
||||
|
||||
# Check expiration
|
||||
if stored_token.expiresAt < datetime.now(timezone.utc):
|
||||
raise InvalidGrantError("Refresh token has expired")
|
||||
|
||||
# Validate client
|
||||
if stored_token.Client.clientId != client_id:
|
||||
raise InvalidGrantError("Client ID mismatch")
|
||||
|
||||
# Determine scopes
|
||||
scopes = stored_token.scopes
|
||||
if requested_scopes:
|
||||
# Can only request a subset of original scopes
|
||||
invalid = set(requested_scopes) - set(stored_token.scopes)
|
||||
if invalid:
|
||||
raise InvalidScopeError(
|
||||
f"Cannot request scopes not in original grant: {', '.join(invalid)}"
|
||||
)
|
||||
scopes = requested_scopes
|
||||
|
||||
# Generate new tokens (rotates refresh token)
|
||||
return await self._create_tokens(
|
||||
user_id=stored_token.userId,
|
||||
client=stored_token.Client,
|
||||
scopes=scopes,
|
||||
user=stored_token.User,
|
||||
old_refresh_token_id=stored_token.id,
|
||||
)
|
||||
|
||||
# ================================================================
|
||||
# Token Operations
|
||||
# ================================================================
|
||||
|
||||
async def _create_tokens(
|
||||
self,
|
||||
user_id: str,
|
||||
client: OAuthClient,
|
||||
scopes: list[str],
|
||||
user: Optional[User] = None,
|
||||
nonce: Optional[str] = None,
|
||||
old_refresh_token_id: Optional[str] = None,
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
Create access and refresh tokens.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
client: OAuth client
|
||||
scopes: Granted scopes
|
||||
user: User object (for ID token claims)
|
||||
nonce: OIDC nonce
|
||||
old_refresh_token_id: ID of refresh token being rotated
|
||||
|
||||
Returns:
|
||||
TokenResponse
|
||||
"""
|
||||
# Generate access token
|
||||
access_token, access_expires_at = self.token_service.generate_access_token(
|
||||
user_id=user_id,
|
||||
client_id=client.clientId,
|
||||
scopes=scopes,
|
||||
expires_in=client.tokenLifetimeSecs,
|
||||
)
|
||||
|
||||
# Store access token hash
|
||||
await prisma.oauthaccesstoken.create(
|
||||
data={
|
||||
"tokenHash": self.token_service.hash_token(access_token),
|
||||
"userId": user_id,
|
||||
"clientId": client.id,
|
||||
"scopes": scopes,
|
||||
"expiresAt": access_expires_at,
|
||||
}
|
||||
)
|
||||
|
||||
# Generate refresh token
|
||||
refresh_token = self.token_service.generate_refresh_token()
|
||||
refresh_expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=client.refreshTokenLifetimeSecs
|
||||
)
|
||||
|
||||
await prisma.oauthrefreshtoken.create(
|
||||
data={
|
||||
"tokenHash": self.token_service.hash_token(refresh_token),
|
||||
"userId": user_id,
|
||||
"clientId": client.id,
|
||||
"scopes": scopes,
|
||||
"expiresAt": refresh_expires_at,
|
||||
}
|
||||
)
|
||||
|
||||
# Revoke old refresh token if rotating
|
||||
if old_refresh_token_id:
|
||||
await prisma.oauthrefreshtoken.update(
|
||||
where={"id": old_refresh_token_id},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
# Generate ID token if openid scope requested
|
||||
id_token = None
|
||||
if "openid" in scopes and user:
|
||||
email = user.email if "email" in scopes else None
|
||||
name = user.name if "profile" in scopes else None
|
||||
id_token = self.token_service.generate_id_token(
|
||||
user_id=user_id,
|
||||
client_id=client.clientId,
|
||||
email=email,
|
||||
name=name,
|
||||
nonce=nonce,
|
||||
)
|
||||
|
||||
# Audit log
|
||||
await self._audit_log(
|
||||
event_type="token.issued",
|
||||
user_id=user_id,
|
||||
client_id=client.clientId,
|
||||
details={"scopes": scopes},
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
expires_in=client.tokenLifetimeSecs,
|
||||
refresh_token=refresh_token,
|
||||
scope=" ".join(scopes),
|
||||
id_token=id_token,
|
||||
)
|
||||
|
||||
async def revoke_token(
|
||||
self,
|
||||
token: str,
|
||||
token_type_hint: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Revoke an access or refresh token.
|
||||
|
||||
Args:
|
||||
token: Token to revoke
|
||||
token_type_hint: Hint about token type
|
||||
|
||||
Returns:
|
||||
True if token was found and revoked
|
||||
"""
|
||||
token_hash = self.token_service.hash_token(token)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Try refresh token first if hinted or no hint
|
||||
if token_type_hint in (None, "refresh_token"):
|
||||
result = await prisma.oauthrefreshtoken.update_many(
|
||||
where={"tokenHash": token_hash, "revokedAt": None},
|
||||
data={"revokedAt": now},
|
||||
)
|
||||
if result > 0:
|
||||
return True
|
||||
|
||||
# Try access token
|
||||
if token_type_hint in (None, "access_token"):
|
||||
result = await prisma.oauthaccesstoken.update_many(
|
||||
where={"tokenHash": token_hash, "revokedAt": None},
|
||||
data={"revokedAt": now},
|
||||
)
|
||||
if result > 0:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _revoke_tokens_for_client_user(
|
||||
self,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
) -> None:
|
||||
"""Revoke all tokens for a client-user pair (security incident response)."""
|
||||
client = await self.get_client(client_id)
|
||||
if not client:
|
||||
return
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
await prisma.oauthaccesstoken.update_many(
|
||||
where={"clientId": client.id, "userId": user_id, "revokedAt": None},
|
||||
data={"revokedAt": now},
|
||||
)
|
||||
|
||||
await prisma.oauthrefreshtoken.update_many(
|
||||
where={"clientId": client.id, "userId": user_id, "revokedAt": None},
|
||||
data={"revokedAt": now},
|
||||
)
|
||||
|
||||
await self._audit_log(
|
||||
event_type="tokens.revoked.security",
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
details={"reason": "authorization_code_reuse"},
|
||||
)
|
||||
|
||||
# ================================================================
|
||||
# Authorization (Consent) Operations
|
||||
# ================================================================
|
||||
|
||||
async def get_authorization(
|
||||
self,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
) -> Optional[OAuthAuthorization]:
|
||||
"""Get existing authorization for user-client pair."""
|
||||
client = await self.get_client(client_id)
|
||||
if not client:
|
||||
return None
|
||||
|
||||
return await prisma.oauthauthorization.find_unique(
|
||||
where={
|
||||
"userId_clientId": {
|
||||
"userId": user_id,
|
||||
"clientId": client.id,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
async def has_valid_authorization(
|
||||
self,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
scopes: list[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user has already authorized these scopes for this client.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
client_id: Client identifier
|
||||
scopes: Requested scopes
|
||||
|
||||
Returns:
|
||||
True if user has already authorized all requested scopes
|
||||
"""
|
||||
auth = await self.get_authorization(user_id, client_id)
|
||||
if not auth or auth.revokedAt:
|
||||
return False
|
||||
|
||||
# Check if all requested scopes are already authorized
|
||||
return set(scopes).issubset(set(auth.scopes))
|
||||
|
||||
async def _upsert_authorization(
|
||||
self,
|
||||
user_id: str,
|
||||
client_db_id: str,
|
||||
scopes: list[str],
|
||||
) -> None:
|
||||
"""Create or update an authorization record."""
|
||||
existing = await prisma.oauthauthorization.find_unique(
|
||||
where={
|
||||
"userId_clientId": {
|
||||
"userId": user_id,
|
||||
"clientId": client_db_id,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if existing:
|
||||
# Merge scopes
|
||||
merged_scopes = list(set(existing.scopes) | set(scopes))
|
||||
await prisma.oauthauthorization.update(
|
||||
where={"id": existing.id},
|
||||
data={"scopes": merged_scopes, "revokedAt": None},
|
||||
)
|
||||
else:
|
||||
await prisma.oauthauthorization.create(
|
||||
data={
|
||||
"userId": user_id,
|
||||
"clientId": client_db_id,
|
||||
"scopes": scopes,
|
||||
}
|
||||
)
|
||||
|
||||
# ================================================================
|
||||
# Audit Logging
|
||||
# ================================================================
|
||||
|
||||
async def _audit_log(
|
||||
self,
|
||||
event_type: str,
|
||||
user_id: Optional[str] = None,
|
||||
client_id: Optional[str] = None,
|
||||
grant_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
details: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Create an audit log entry."""
|
||||
# Convert details to JSON for Prisma's Json field
|
||||
details_json = json.dumps(details or {})
|
||||
await prisma.oauthauditlog.create(
|
||||
data={
|
||||
"eventType": event_type,
|
||||
"userId": user_id,
|
||||
"clientId": client_id,
|
||||
"grantId": grant_id,
|
||||
"ipAddress": ip_address,
|
||||
"userAgent": user_agent,
|
||||
"details": json.loads(details_json), # type: ignore[arg-type]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
_oauth_service: Optional[OAuthService] = None
|
||||
|
||||
|
||||
def get_oauth_service() -> OAuthService:
|
||||
"""Get the singleton OAuth service instance."""
|
||||
global _oauth_service
|
||||
if _oauth_service is None:
|
||||
_oauth_service = OAuthService()
|
||||
return _oauth_service
|
||||
298
autogpt_platform/backend/backend/server/oauth/token_service.py
Normal file
298
autogpt_platform/backend/backend/server/oauth/token_service.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""
|
||||
JWT Token Service for OAuth 2.0 Provider.
|
||||
|
||||
Handles generation and validation of:
|
||||
- Access tokens (JWT)
|
||||
- Refresh tokens (opaque)
|
||||
- ID tokens (JWT, OIDC)
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
import jwt
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import (
|
||||
RSAPrivateKey,
|
||||
RSAPublicKey,
|
||||
generate_private_key,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
|
||||
class TokenClaims(BaseModel):
|
||||
"""Decoded token claims."""
|
||||
|
||||
iss: str # Issuer
|
||||
sub: str # Subject (user ID)
|
||||
aud: str # Audience (client ID)
|
||||
exp: int # Expiration timestamp
|
||||
iat: int # Issued at timestamp
|
||||
jti: str # JWT ID
|
||||
scope: str # Space-separated scopes
|
||||
client_id: str # Client ID
|
||||
|
||||
|
||||
class OAuthTokenService:
|
||||
"""
|
||||
Service for generating and validating OAuth tokens.
|
||||
|
||||
Uses RS256 (RSA with SHA-256) for JWT signing.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: Optional[Settings] = None):
|
||||
self._settings = settings or Settings()
|
||||
self._private_key: Optional[RSAPrivateKey] = None
|
||||
self._public_key: Optional[RSAPublicKey] = None
|
||||
self._algorithm = "RS256"
|
||||
|
||||
@property
|
||||
def issuer(self) -> str:
|
||||
"""Get the token issuer URL."""
|
||||
return self._settings.config.platform_base_url or "https://platform.agpt.co"
|
||||
|
||||
@property
|
||||
def key_id(self) -> str:
|
||||
"""Get the key ID for JWKS."""
|
||||
return self._settings.secrets.oauth_jwt_key_id or "default-key-id"
|
||||
|
||||
def _get_private_key(self) -> RSAPrivateKey:
|
||||
"""Load or generate the private key."""
|
||||
if self._private_key is not None:
|
||||
return self._private_key
|
||||
|
||||
key_pem = self._settings.secrets.oauth_jwt_private_key
|
||||
if key_pem:
|
||||
loaded_key = serialization.load_pem_private_key(
|
||||
key_pem.encode(), password=None
|
||||
)
|
||||
if not isinstance(loaded_key, RSAPrivateKey):
|
||||
raise ValueError("OAuth JWT private key must be RSA")
|
||||
self._private_key = loaded_key
|
||||
else:
|
||||
# Generate a key for development (should not be used in production)
|
||||
self._private_key = generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048,
|
||||
)
|
||||
return self._private_key
|
||||
|
||||
def _get_public_key(self) -> RSAPublicKey:
|
||||
"""Get the public key from the private key."""
|
||||
if self._public_key is not None:
|
||||
return self._public_key
|
||||
|
||||
key_pem = self._settings.secrets.oauth_jwt_public_key
|
||||
if key_pem:
|
||||
loaded_key = serialization.load_pem_public_key(key_pem.encode())
|
||||
if not isinstance(loaded_key, RSAPublicKey):
|
||||
raise ValueError("OAuth JWT public key must be RSA")
|
||||
self._public_key = loaded_key
|
||||
else:
|
||||
self._public_key = self._get_private_key().public_key()
|
||||
return self._public_key
|
||||
|
||||
def generate_access_token(
|
||||
self,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
scopes: list[str],
|
||||
expires_in: int = 3600,
|
||||
) -> tuple[str, datetime]:
|
||||
"""
|
||||
Generate a JWT access token.
|
||||
|
||||
Args:
|
||||
user_id: User ID (subject)
|
||||
client_id: Client ID (audience)
|
||||
scopes: List of granted scopes
|
||||
expires_in: Token lifetime in seconds
|
||||
|
||||
Returns:
|
||||
Tuple of (token string, expiration datetime)
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + timedelta(seconds=expires_in)
|
||||
|
||||
payload = {
|
||||
"iss": self.issuer,
|
||||
"sub": user_id,
|
||||
"aud": client_id,
|
||||
"exp": int(expires_at.timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": secrets.token_urlsafe(16),
|
||||
"scope": " ".join(scopes),
|
||||
"client_id": client_id,
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
self._get_private_key(),
|
||||
algorithm=self._algorithm,
|
||||
headers={"kid": self.key_id},
|
||||
)
|
||||
return token, expires_at
|
||||
|
||||
def generate_refresh_token(self) -> str:
|
||||
"""
|
||||
Generate an opaque refresh token.
|
||||
|
||||
Returns:
|
||||
URL-safe random token string
|
||||
"""
|
||||
return secrets.token_urlsafe(48)
|
||||
|
||||
def generate_id_token(
|
||||
self,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
email: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
nonce: Optional[str] = None,
|
||||
expires_in: int = 3600,
|
||||
) -> str:
|
||||
"""
|
||||
Generate an OIDC ID token.
|
||||
|
||||
Args:
|
||||
user_id: User ID (subject)
|
||||
client_id: Client ID (audience)
|
||||
email: User's email (optional)
|
||||
name: User's name (optional)
|
||||
nonce: OIDC nonce for replay protection (optional)
|
||||
expires_in: Token lifetime in seconds
|
||||
|
||||
Returns:
|
||||
JWT ID token string
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + timedelta(seconds=expires_in)
|
||||
|
||||
payload = {
|
||||
"iss": self.issuer,
|
||||
"sub": user_id,
|
||||
"aud": client_id,
|
||||
"exp": int(expires_at.timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"auth_time": int(now.timestamp()),
|
||||
}
|
||||
|
||||
if email:
|
||||
payload["email"] = email
|
||||
payload["email_verified"] = True
|
||||
if name:
|
||||
payload["name"] = name
|
||||
if nonce:
|
||||
payload["nonce"] = nonce
|
||||
|
||||
return jwt.encode(
|
||||
payload,
|
||||
self._get_private_key(),
|
||||
algorithm=self._algorithm,
|
||||
headers={"kid": self.key_id},
|
||||
)
|
||||
|
||||
def verify_access_token(
|
||||
self,
|
||||
token: str,
|
||||
expected_client_id: Optional[str] = None,
|
||||
) -> TokenClaims:
|
||||
"""
|
||||
Verify and decode a JWT access token.
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
expected_client_id: Expected client ID (audience)
|
||||
|
||||
Returns:
|
||||
Decoded token claims
|
||||
|
||||
Raises:
|
||||
jwt.ExpiredSignatureError: Token has expired
|
||||
jwt.InvalidTokenError: Token is invalid
|
||||
"""
|
||||
options = {}
|
||||
if expected_client_id:
|
||||
options["audience"] = expected_client_id
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
self._get_public_key(),
|
||||
algorithms=[self._algorithm],
|
||||
issuer=self.issuer,
|
||||
options={"verify_aud": bool(expected_client_id)},
|
||||
**options,
|
||||
)
|
||||
|
||||
return TokenClaims(
|
||||
iss=payload["iss"],
|
||||
sub=payload["sub"],
|
||||
aud=payload.get("aud", payload.get("client_id", "")),
|
||||
exp=payload["exp"],
|
||||
iat=payload["iat"],
|
||||
jti=payload["jti"],
|
||||
scope=payload.get("scope", ""),
|
||||
client_id=payload.get("client_id", payload.get("aud", "")),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def hash_token(token: str) -> str:
|
||||
"""
|
||||
Hash a token for secure storage.
|
||||
|
||||
Args:
|
||||
token: Token string to hash
|
||||
|
||||
Returns:
|
||||
SHA-256 hash of the token
|
||||
"""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
def get_jwks(self) -> dict:
|
||||
"""
|
||||
Get the JSON Web Key Set (JWKS) for public key distribution.
|
||||
|
||||
Returns:
|
||||
JWKS dictionary with public key(s)
|
||||
"""
|
||||
public_key = self._get_public_key()
|
||||
public_numbers = public_key.public_numbers()
|
||||
|
||||
# Convert to base64url encoding without padding
|
||||
def int_to_base64url(n: int, length: int) -> str:
|
||||
data = n.to_bytes(length, byteorder="big")
|
||||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
|
||||
|
||||
# RSA modulus and exponent
|
||||
n = int_to_base64url(public_numbers.n, (public_numbers.n.bit_length() + 7) // 8)
|
||||
e = int_to_base64url(public_numbers.e, 3)
|
||||
|
||||
return {
|
||||
"keys": [
|
||||
{
|
||||
"kty": "RSA",
|
||||
"use": "sig",
|
||||
"kid": self.key_id,
|
||||
"alg": self._algorithm,
|
||||
"n": n,
|
||||
"e": e,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
_token_service: Optional[OAuthTokenService] = None
|
||||
|
||||
|
||||
def get_token_service() -> OAuthTokenService:
|
||||
"""Get the singleton token service instance."""
|
||||
global _token_service
|
||||
if _token_service is None:
|
||||
_token_service = OAuthTokenService()
|
||||
return _token_service
|
||||
@@ -21,6 +21,7 @@ import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.integrations.connect_router
|
||||
import backend.server.routers.postmark.postmark
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.admin.credit_admin_routes
|
||||
@@ -44,6 +45,7 @@ from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.server.external.api import external_app
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.server.oauth import client_router, discovery_router, oauth_router
|
||||
from backend.server.utils.cors import build_cors_params
|
||||
from backend.util import json
|
||||
from backend.util.cloud_storage import shutdown_cloud_storage_handler
|
||||
@@ -300,6 +302,18 @@ app.include_router(
|
||||
|
||||
app.mount("/external-api", external_app)
|
||||
|
||||
# OAuth Provider routes
|
||||
app.include_router(oauth_router, tags=["oauth"], prefix="")
|
||||
app.include_router(discovery_router, tags=["oidc-discovery"], prefix="")
|
||||
app.include_router(client_router, tags=["oauth-clients"], prefix="")
|
||||
|
||||
# Integration Connect popup routes (for Credential Broker)
|
||||
app.include_router(
|
||||
backend.server.integrations.connect_router.connect_router,
|
||||
tags=["integration-connect"],
|
||||
prefix="",
|
||||
)
|
||||
|
||||
|
||||
@app.get(path="/health", tags=["health"], dependencies=[])
|
||||
async def health():
|
||||
|
||||
227
autogpt_platform/backend/backend/util/rate_limiter.py
Normal file
227
autogpt_platform/backend/backend/util/rate_limiter.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""
|
||||
Rate Limiting for External API.
|
||||
|
||||
Implements sliding window rate limiting using Redis for distributed systems.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitResult:
|
||||
"""Result of a rate limit check."""
|
||||
|
||||
allowed: bool
|
||||
remaining: int
|
||||
reset_at: float
|
||||
retry_after: Optional[float] = None
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
Redis-based sliding window rate limiter.
|
||||
|
||||
Supports multiple limit tiers (per-minute, per-hour, per-day).
|
||||
"""
|
||||
|
||||
def __init__(self, prefix: str = "ratelimit"):
|
||||
self.prefix = prefix
|
||||
|
||||
def _make_key(self, identifier: str, window: str) -> str:
|
||||
"""Create a Redis key for the rate limit counter."""
|
||||
return f"{self.prefix}:{identifier}:{window}"
|
||||
|
||||
async def check_and_increment(
|
||||
self,
|
||||
identifier: str,
|
||||
limits: dict[str, tuple[int, int]], # window_name -> (limit, window_seconds)
|
||||
) -> RateLimitResult:
|
||||
"""
|
||||
Check rate limits and increment counters if allowed.
|
||||
|
||||
Uses atomic increment-first approach to prevent race conditions:
|
||||
1. Increment all counters atomically
|
||||
2. Check if any limit exceeded
|
||||
3. If exceeded, decrement and return rate limit error
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier (e.g., client_id, client_id:user_id)
|
||||
limits: Dictionary of limit configurations
|
||||
e.g., {"minute": (60, 60), "hour": (1000, 3600)}
|
||||
|
||||
Returns:
|
||||
RateLimitResult with allowed status and remaining quota
|
||||
"""
|
||||
if not limits:
|
||||
# No limits configured, allow request
|
||||
return RateLimitResult(
|
||||
allowed=True,
|
||||
remaining=999999,
|
||||
reset_at=time.time() + 60,
|
||||
)
|
||||
|
||||
redis = await get_redis_async()
|
||||
current_time = time.time()
|
||||
|
||||
# Increment all counters atomically first
|
||||
incremented_keys: list[tuple[str, int, int, int]] = (
|
||||
[]
|
||||
) # (key, new_count, limit, window_seconds)
|
||||
|
||||
for window_name, (limit, window_seconds) in limits.items():
|
||||
key = self._make_key(identifier, window_name)
|
||||
|
||||
# Atomic increment
|
||||
new_count = await redis.incr(key)
|
||||
|
||||
# Set expiry if this is a new key
|
||||
if new_count == 1:
|
||||
await redis.expire(key, window_seconds)
|
||||
|
||||
incremented_keys.append((key, new_count, limit, window_seconds))
|
||||
|
||||
# Check if any limit exceeded
|
||||
for key, new_count, limit, window_seconds in incremented_keys:
|
||||
if new_count > limit:
|
||||
# Rate limit exceeded - decrement all counters we just incremented
|
||||
for decr_key, _, _, _ in incremented_keys:
|
||||
await redis.decr(decr_key)
|
||||
|
||||
ttl = await redis.ttl(key)
|
||||
reset_at = current_time + (ttl if ttl > 0 else window_seconds)
|
||||
|
||||
return RateLimitResult(
|
||||
allowed=False,
|
||||
remaining=0,
|
||||
reset_at=reset_at,
|
||||
retry_after=ttl if ttl > 0 else window_seconds,
|
||||
)
|
||||
|
||||
# All limits passed
|
||||
min_remaining = float("inf")
|
||||
earliest_reset = current_time
|
||||
|
||||
for key, new_count, limit, window_seconds in incremented_keys:
|
||||
remaining = max(0, limit - new_count)
|
||||
min_remaining = min(min_remaining, remaining)
|
||||
|
||||
ttl = await redis.ttl(key)
|
||||
reset_at = current_time + (ttl if ttl > 0 else window_seconds)
|
||||
earliest_reset = max(earliest_reset, reset_at)
|
||||
|
||||
return RateLimitResult(
|
||||
allowed=True,
|
||||
remaining=int(min_remaining),
|
||||
reset_at=earliest_reset,
|
||||
)
|
||||
|
||||
async def get_remaining(
|
||||
self,
|
||||
identifier: str,
|
||||
limits: dict[str, tuple[int, int]],
|
||||
) -> dict[str, int]:
|
||||
"""
|
||||
Get remaining quota for all windows without incrementing.
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier
|
||||
limits: Dictionary of limit configurations
|
||||
|
||||
Returns:
|
||||
Dictionary of remaining quota per window
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
remaining = {}
|
||||
|
||||
for window_name, (limit, _) in limits.items():
|
||||
key = self._make_key(identifier, window_name)
|
||||
count = await redis.get(key)
|
||||
current_count = int(count) if count else 0
|
||||
remaining[window_name] = max(0, limit - current_count)
|
||||
|
||||
return remaining
|
||||
|
||||
async def reset(self, identifier: str, window: Optional[str] = None) -> None:
|
||||
"""
|
||||
Reset rate limit counters.
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier
|
||||
window: Optional specific window to reset (resets all if None)
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
|
||||
if window:
|
||||
key = self._make_key(identifier, window)
|
||||
await redis.delete(key)
|
||||
else:
|
||||
# Delete known window keys instead of scanning
|
||||
# This avoids potentially slow scan operations with many keys
|
||||
known_windows = ["minute", "hour", "day"]
|
||||
keys_to_delete = [self._make_key(identifier, w) for w in known_windows]
|
||||
# Delete all in one call (Redis handles non-existent keys gracefully)
|
||||
if keys_to_delete:
|
||||
await redis.delete(*keys_to_delete)
|
||||
|
||||
|
||||
# Default rate limits for different endpoints
|
||||
DEFAULT_RATE_LIMITS = {
|
||||
# OAuth endpoints
|
||||
"oauth_authorize": {"minute": (30, 60)}, # 30/min per IP
|
||||
"oauth_token": {"minute": (20, 60)}, # 20/min per client
|
||||
# External API endpoints
|
||||
"api_execute": {
|
||||
"minute": (10, 60),
|
||||
"hour": (100, 3600),
|
||||
}, # 10/min, 100/hour per client+user
|
||||
"api_read": {
|
||||
"minute": (60, 60),
|
||||
"hour": (1000, 3600),
|
||||
}, # 60/min, 1000/hour per client+user
|
||||
}
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
_rate_limiter: Optional[RateLimiter] = None
|
||||
|
||||
|
||||
def get_rate_limiter() -> RateLimiter:
|
||||
"""Get the singleton rate limiter instance."""
|
||||
global _rate_limiter
|
||||
if _rate_limiter is None:
|
||||
_rate_limiter = RateLimiter()
|
||||
return _rate_limiter
|
||||
|
||||
|
||||
async def check_rate_limit(
|
||||
identifier: str,
|
||||
limit_type: str,
|
||||
) -> RateLimitResult:
|
||||
"""
|
||||
Convenience function to check rate limits.
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier for the rate limit
|
||||
limit_type: Type of limit from DEFAULT_RATE_LIMITS
|
||||
|
||||
Returns:
|
||||
RateLimitResult
|
||||
"""
|
||||
limits = DEFAULT_RATE_LIMITS.get(limit_type)
|
||||
if not limits:
|
||||
# No rate limit configured, allow
|
||||
return RateLimitResult(
|
||||
allowed=True,
|
||||
remaining=999999,
|
||||
reset_at=time.time() + 60,
|
||||
)
|
||||
|
||||
rate_limiter = get_rate_limiter()
|
||||
return await rate_limiter.check_and_increment(identifier, limits)
|
||||
@@ -651,6 +651,23 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
|
||||
ayrshare_api_key: str = Field(default="", description="Ayrshare API Key")
|
||||
ayrshare_jwt_key: str = Field(default="", description="Ayrshare private Key")
|
||||
|
||||
# OAuth Provider JWT keys
|
||||
oauth_jwt_private_key: str = Field(
|
||||
default="",
|
||||
description="RSA private key for signing OAuth tokens (PEM format). "
|
||||
"If not set, a development key will be auto-generated.",
|
||||
)
|
||||
oauth_jwt_public_key: str = Field(
|
||||
default="",
|
||||
description="RSA public key for verifying OAuth tokens (PEM format). "
|
||||
"If not set, derived from private key.",
|
||||
)
|
||||
oauth_jwt_key_id: str = Field(
|
||||
default="autogpt-oauth-key-1",
|
||||
description="Key ID (kid) for JWKS. Used to identify the signing key.",
|
||||
)
|
||||
|
||||
# Add more secret fields as needed
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
|
||||
43
autogpt_platform/backend/backend/util/time.py
Normal file
43
autogpt_platform/backend/backend/util/time.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Time utilities for the backend.
|
||||
|
||||
Common datetime operations used across the codebase.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
|
||||
def expiration_datetime(seconds: int) -> datetime:
|
||||
"""
|
||||
Calculate an expiration datetime from now.
|
||||
|
||||
Args:
|
||||
seconds: Number of seconds until expiration
|
||||
|
||||
Returns:
|
||||
Datetime when the item will expire (UTC)
|
||||
"""
|
||||
return datetime.now(timezone.utc) + timedelta(seconds=seconds)
|
||||
|
||||
|
||||
def is_expired(dt: datetime) -> bool:
|
||||
"""
|
||||
Check if a datetime has passed.
|
||||
|
||||
Args:
|
||||
dt: The datetime to check (should be timezone-aware)
|
||||
|
||||
Returns:
|
||||
True if the datetime is in the past
|
||||
"""
|
||||
return dt < datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
"""
|
||||
Get the current UTC time.
|
||||
|
||||
Returns:
|
||||
Current datetime in UTC
|
||||
"""
|
||||
return datetime.now(timezone.utc)
|
||||
46
autogpt_platform/backend/backend/util/url.py
Normal file
46
autogpt_platform/backend/backend/util/url.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
URL and domain validation utilities.
|
||||
|
||||
Common URL validation operations used across the codebase.
|
||||
"""
|
||||
|
||||
|
||||
def matches_domain_pattern(hostname: str, domain_pattern: str) -> bool:
|
||||
"""
|
||||
Check if a hostname matches a domain pattern.
|
||||
|
||||
Supports wildcard patterns (*.example.com) which match:
|
||||
- The base domain (example.com)
|
||||
- Any subdomain (sub.example.com, deep.sub.example.com)
|
||||
|
||||
Args:
|
||||
hostname: The hostname to check (e.g., "api.example.com")
|
||||
domain_pattern: The pattern to match against (e.g., "*.example.com" or "example.com")
|
||||
|
||||
Returns:
|
||||
True if the hostname matches the pattern
|
||||
"""
|
||||
hostname = hostname.lower()
|
||||
domain_pattern = domain_pattern.lower()
|
||||
|
||||
if domain_pattern.startswith("*."):
|
||||
# Wildcard domain - matches base and any subdomains
|
||||
base_domain = domain_pattern[2:]
|
||||
return hostname == base_domain or hostname.endswith("." + base_domain)
|
||||
|
||||
# Exact match
|
||||
return hostname == domain_pattern
|
||||
|
||||
|
||||
def hostname_matches_any_domain(hostname: str, allowed_domains: list[str]) -> bool:
|
||||
"""
|
||||
Check if a hostname matches any of the allowed domain patterns.
|
||||
|
||||
Args:
|
||||
hostname: The hostname to check
|
||||
allowed_domains: List of allowed domain patterns (supports wildcards)
|
||||
|
||||
Returns:
|
||||
True if the hostname matches any pattern
|
||||
"""
|
||||
return any(matches_domain_pattern(hostname, domain) for domain in allowed_domains)
|
||||
@@ -60,6 +60,14 @@ model User {
|
||||
IntegrationWebhooks IntegrationWebhook[]
|
||||
NotificationBatches UserNotificationBatch[]
|
||||
PendingHumanReviews PendingHumanReview[]
|
||||
|
||||
// OAuth Provider relations
|
||||
OAuthClientsOwned OAuthClient[] @relation("OAuthClientOwner")
|
||||
OAuthAuthorizations OAuthAuthorization[]
|
||||
OAuthAuthorizationCodes OAuthAuthorizationCode[]
|
||||
OAuthAccessTokens OAuthAccessToken[]
|
||||
OAuthRefreshTokens OAuthRefreshToken[]
|
||||
CredentialGrants CredentialGrant[]
|
||||
}
|
||||
|
||||
enum OnboardingStep {
|
||||
@@ -961,3 +969,225 @@ enum APIKeyStatus {
|
||||
REVOKED
|
||||
SUSPENDED
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// OAuth Provider & Credential Broker Models
|
||||
// ============================================================
|
||||
|
||||
enum OAuthClientStatus {
|
||||
ACTIVE
|
||||
SUSPENDED
|
||||
}
|
||||
|
||||
enum CredentialGrantPermission {
|
||||
USE // Can use credential for agent execution
|
||||
DELETE // Can delete the credential
|
||||
}
|
||||
|
||||
// OAuth Client - Registered external applications
|
||||
model OAuthClient {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
// Client identification
|
||||
clientId String @unique // Public identifier (e.g., "app_abc123")
|
||||
clientSecretHash String? // Hashed (null for public clients)
|
||||
clientSecretSalt String?
|
||||
clientType String // "public" or "confidential"
|
||||
|
||||
// Metadata (shown on consent screen)
|
||||
name String
|
||||
description String?
|
||||
logoUrl String?
|
||||
homepageUrl String?
|
||||
privacyPolicyUrl String?
|
||||
termsOfServiceUrl String?
|
||||
|
||||
// Configuration
|
||||
redirectUris String[]
|
||||
allowedScopes String[]
|
||||
webhookDomains String[] // For webhook URL validation
|
||||
|
||||
// Security
|
||||
requirePkce Boolean @default(true)
|
||||
tokenLifetimeSecs Int @default(3600)
|
||||
refreshTokenLifetimeSecs Int @default(2592000) // 30 days
|
||||
|
||||
// Status
|
||||
status OAuthClientStatus @default(ACTIVE)
|
||||
|
||||
// Owner
|
||||
ownerId String
|
||||
Owner User @relation("OAuthClientOwner", fields: [ownerId], references: [id], onDelete: Cascade)
|
||||
|
||||
// Relations
|
||||
Authorizations OAuthAuthorization[]
|
||||
AuthorizationCodes OAuthAuthorizationCode[]
|
||||
AccessTokens OAuthAccessToken[]
|
||||
RefreshTokens OAuthRefreshToken[]
|
||||
CredentialGrants CredentialGrant[]
|
||||
|
||||
@@index([clientId])
|
||||
@@index([ownerId])
|
||||
@@index([status])
|
||||
}
|
||||
|
||||
// OAuth Authorization - User consent record
|
||||
model OAuthAuthorization {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
clientId String
|
||||
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
|
||||
|
||||
scopes String[]
|
||||
revokedAt DateTime?
|
||||
|
||||
@@unique([userId, clientId])
|
||||
@@index([userId])
|
||||
@@index([clientId])
|
||||
}
|
||||
|
||||
// OAuth Authorization Code - Short-lived, single-use
|
||||
model OAuthAuthorizationCode {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
codeHash String @unique
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
clientId String
|
||||
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
|
||||
|
||||
redirectUri String
|
||||
scopes String[]
|
||||
nonce String? // OIDC nonce
|
||||
|
||||
// PKCE
|
||||
codeChallenge String
|
||||
codeChallengeMethod String @default("S256")
|
||||
|
||||
expiresAt DateTime // 10 minutes
|
||||
usedAt DateTime?
|
||||
|
||||
@@index([codeHash])
|
||||
@@index([expiresAt])
|
||||
}
|
||||
|
||||
// OAuth Access Token
|
||||
model OAuthAccessToken {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
tokenHash String @unique // SHA256 of token
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
clientId String
|
||||
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
|
||||
|
||||
scopes String[]
|
||||
expiresAt DateTime
|
||||
revokedAt DateTime?
|
||||
lastUsedAt DateTime?
|
||||
|
||||
@@index([tokenHash])
|
||||
@@index([userId, clientId])
|
||||
@@index([expiresAt])
|
||||
}
|
||||
|
||||
// OAuth Refresh Token
|
||||
model OAuthRefreshToken {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
tokenHash String @unique
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
clientId String
|
||||
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
|
||||
|
||||
scopes String[]
|
||||
expiresAt DateTime
|
||||
revokedAt DateTime?
|
||||
|
||||
@@index([tokenHash])
|
||||
@@index([expiresAt])
|
||||
}
|
||||
|
||||
// Credential Grant - Links external app to user's credential with scoped access
|
||||
model CredentialGrant {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
clientId String
|
||||
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
|
||||
|
||||
credentialId String // Reference to credential in User.integrations
|
||||
provider String
|
||||
|
||||
// Fine-grained integration scopes (e.g., "google:gmail.readonly")
|
||||
grantedScopes String[]
|
||||
|
||||
// Permissions for the credential itself
|
||||
permissions CredentialGrantPermission[]
|
||||
|
||||
expiresAt DateTime?
|
||||
revokedAt DateTime?
|
||||
lastUsedAt DateTime?
|
||||
|
||||
@@unique([userId, clientId, credentialId])
|
||||
@@index([userId, clientId])
|
||||
@@index([clientId])
|
||||
}
|
||||
|
||||
// OAuth Audit Log
|
||||
model OAuthAuditLog {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
eventType String // e.g., "token.issued", "grant.created"
|
||||
|
||||
userId String?
|
||||
clientId String?
|
||||
grantId String?
|
||||
|
||||
ipAddress String?
|
||||
userAgent String?
|
||||
|
||||
details Json @default("{}")
|
||||
|
||||
@@index([createdAt])
|
||||
@@index([eventType])
|
||||
@@index([userId])
|
||||
@@index([clientId])
|
||||
}
|
||||
|
||||
// Execution Webhook - Webhook registration for external API executions
|
||||
model ExecutionWebhook {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
executionId String // The graph execution ID
|
||||
webhookUrl String // URL to send notifications to
|
||||
clientId String // The OAuth client database ID
|
||||
userId String // The user who started the execution
|
||||
secret String? // Optional webhook secret for HMAC signing
|
||||
|
||||
@@index([executionId])
|
||||
@@index([clientId])
|
||||
}
|
||||
|
||||
610
docs/content/platform/external-api-integration.md
Normal file
610
docs/content/platform/external-api-integration.md
Normal file
@@ -0,0 +1,610 @@
|
||||
# External API Integration Guide
|
||||
|
||||
This guide explains how third-party applications can integrate with AutoGPT Platform to execute agents on behalf of users using the OAuth Provider and Credential Broker system.
|
||||
|
||||
## Overview
|
||||
|
||||
The AutoGPT External API allows your application to:
|
||||
|
||||
- **Execute agents** - Run user-owned or marketplace agents with user-granted credentials
|
||||
- **Access integrations** - Use third-party service credentials (Google, GitHub, etc.) that users have connected
|
||||
- **Receive webhooks** - Get notified when agent executions complete
|
||||
|
||||
The integration uses standard OAuth 2.0 with PKCE for secure authentication and a popup-based "Connect" flow for credential access.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐
|
||||
│ Your App │────▶│ AutoGPT OAuth │────▶│ External API │
|
||||
│ │ │ Provider │ │ │
|
||||
└─────────────────┘ └──────────────────┘ └─────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────┐
|
||||
│ Credential │
|
||||
│ Broker │
|
||||
└──────────────────┘
|
||||
```
|
||||
|
||||
**Key concepts:**
|
||||
|
||||
1. **OAuth Client** - Your registered application with AutoGPT
|
||||
2. **OAuth Tokens** - Access/refresh tokens for API authentication
|
||||
3. **Credential Grants** - User permissions to use their connected integrations
|
||||
4. **Integration Scopes** - Specific permissions for each integration (e.g., `google:gmail.readonly`)
|
||||
|
||||
## Getting Started
|
||||
|
||||
### 1. Register Your OAuth Client
|
||||
|
||||
Register your application to get a `client_id` and `client_secret`:
|
||||
|
||||
```bash
|
||||
# Requires user authentication (JWT token)
|
||||
curl -X POST https://platform.agpt.co/oauth/clients/ \
|
||||
-H "Authorization: Bearer <user_jwt>" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"name": "My App",
|
||||
"description": "Description of your app",
|
||||
"client_type": "confidential",
|
||||
"redirect_uris": ["https://myapp.com/oauth/callback"],
|
||||
"webhook_domains": ["myapp.com", "*.myapp.com"]
|
||||
}'
|
||||
```
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"client_id": "app_abc123xyz",
|
||||
"client_secret": "secret_xyz789..."
|
||||
}
|
||||
```
|
||||
|
||||
> **Important:** Store the `client_secret` securely - it's only shown once!
|
||||
|
||||
**Client types:**
|
||||
- `confidential` - Server-side apps that can securely store secrets
|
||||
- `public` - Browser/mobile apps (no client secret)
|
||||
|
||||
### 2. OAuth Authorization Flow
|
||||
|
||||
Use the standard OAuth 2.0 Authorization Code flow with PKCE to get user consent and tokens.
|
||||
|
||||
#### Generate PKCE Parameters
|
||||
|
||||
```javascript
|
||||
// Generate code verifier and challenge
|
||||
function generateCodeVerifier() {
|
||||
const array = new Uint8Array(32);
|
||||
crypto.getRandomValues(array);
|
||||
return base64UrlEncode(array);
|
||||
}
|
||||
|
||||
async function generateCodeChallenge(verifier) {
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(verifier);
|
||||
const digest = await crypto.subtle.digest('SHA-256', data);
|
||||
return base64UrlEncode(new Uint8Array(digest));
|
||||
}
|
||||
```
|
||||
|
||||
#### Redirect User to Authorization
|
||||
|
||||
```javascript
|
||||
const state = crypto.randomUUID(); // Store this for validation
|
||||
const codeVerifier = generateCodeVerifier(); // Store this securely
|
||||
const codeChallenge = await generateCodeChallenge(codeVerifier);
|
||||
|
||||
const authUrl = new URL('https://platform.agpt.co/oauth/authorize');
|
||||
authUrl.searchParams.set('response_type', 'code');
|
||||
authUrl.searchParams.set('client_id', CLIENT_ID);
|
||||
authUrl.searchParams.set('redirect_uri', REDIRECT_URI);
|
||||
authUrl.searchParams.set('state', state);
|
||||
authUrl.searchParams.set('code_challenge', codeChallenge);
|
||||
authUrl.searchParams.set('code_challenge_method', 'S256');
|
||||
authUrl.searchParams.set('scope', 'openid profile email agents:execute integrations:connect');
|
||||
|
||||
window.location.href = authUrl.toString();
|
||||
```
|
||||
|
||||
**Available scopes:**
|
||||
|
||||
| Scope | Description |
|
||||
|-------|-------------|
|
||||
| `openid` | Required for OIDC |
|
||||
| `profile` | Access user profile (name) |
|
||||
| `email` | Access user email |
|
||||
| `agents:execute` | Execute agents and check status |
|
||||
| `integrations:list` | List user's credential grants |
|
||||
| `integrations:connect` | Request new credential grants |
|
||||
| `integrations:delete` | Delete credentials via grants |
|
||||
|
||||
#### Handle OAuth Callback
|
||||
|
||||
```javascript
|
||||
// Your callback endpoint receives: ?code=xxx&state=xxx
|
||||
app.get('/oauth/callback', async (req, res) => {
|
||||
const { code, state } = req.query;
|
||||
|
||||
// Verify state matches what you stored
|
||||
if (state !== storedState) {
|
||||
return res.status(400).send('Invalid state');
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
const response = await fetch('https://platform.agpt.co/oauth/token', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: new URLSearchParams({
|
||||
grant_type: 'authorization_code',
|
||||
code,
|
||||
redirect_uri: REDIRECT_URI,
|
||||
client_id: CLIENT_ID,
|
||||
client_secret: CLIENT_SECRET,
|
||||
code_verifier: storedCodeVerifier,
|
||||
}),
|
||||
});
|
||||
|
||||
const tokens = await response.json();
|
||||
// { access_token, refresh_token, token_type, expires_in }
|
||||
});
|
||||
```
|
||||
|
||||
### 3. Request Credential Grants (Connect Flow)
|
||||
|
||||
Before executing agents that require integrations (like Gmail, Google Sheets, GitHub), you need credential grants from the user.
|
||||
|
||||
#### Open Connect Popup
|
||||
|
||||
```javascript
|
||||
function requestCredentialGrant(provider, scopes) {
|
||||
const nonce = crypto.randomUUID();
|
||||
|
||||
// Store nonce to validate response
|
||||
sessionStorage.setItem('connect_nonce', nonce);
|
||||
|
||||
const connectUrl = new URL(`https://platform.agpt.co/connect/${provider}`);
|
||||
connectUrl.searchParams.set('client_id', CLIENT_ID);
|
||||
connectUrl.searchParams.set('scopes', scopes.join(','));
|
||||
connectUrl.searchParams.set('nonce', nonce);
|
||||
connectUrl.searchParams.set('redirect_origin', window.location.origin);
|
||||
|
||||
// Open popup (user must be logged into AutoGPT)
|
||||
const popup = window.open(
|
||||
connectUrl.toString(),
|
||||
'AutoGPT Connect',
|
||||
'width=500,height=600,popup=true'
|
||||
);
|
||||
|
||||
// Listen for result
|
||||
window.addEventListener('message', handleConnectResult, { once: true });
|
||||
}
|
||||
|
||||
function handleConnectResult(event) {
|
||||
// Verify origin
|
||||
if (event.origin !== 'https://platform.agpt.co') return;
|
||||
|
||||
const data = event.data;
|
||||
if (data.type !== 'autogpt_connect_result') return;
|
||||
|
||||
// Verify nonce
|
||||
if (data.nonce !== sessionStorage.getItem('connect_nonce')) return;
|
||||
|
||||
if (data.success) {
|
||||
console.log('Grant created:', data.grant_id);
|
||||
console.log('Credential ID:', data.credential_id);
|
||||
console.log('Provider:', data.provider);
|
||||
} else {
|
||||
console.error('Connect failed:', data.error, data.error_description);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Integration scopes by provider:**
|
||||
|
||||
| Provider | Available Scopes |
|
||||
|----------|------------------|
|
||||
| Google | `google:gmail.readonly`, `google:gmail.send`, `google:sheets.read`, `google:sheets.write`, `google:calendar.read`, `google:calendar.write`, `google:drive.read`, `google:drive.write` |
|
||||
| GitHub | `github:repo.read`, `github:repo.write`, `github:user.read` |
|
||||
| Twitter/X | `twitter:tweet.read`, `twitter:tweet.write`, `twitter:user.read` |
|
||||
| Linear | `linear:read`, `linear:write` |
|
||||
| Notion | `notion:read`, `notion:write` |
|
||||
| Slack | `slack:read`, `slack:write` |
|
||||
|
||||
### 4. Execute Agents
|
||||
|
||||
With an OAuth token and credential grants, you can execute agents:
|
||||
|
||||
```javascript
|
||||
async function executeAgent(agentId, inputs, grantIds = null, webhookUrl = null) {
|
||||
const response = await fetch(
|
||||
`https://platform.agpt.co/api/external/v1/executions/agents/${agentId}/execute`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Authorization': `Bearer ${accessToken}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
inputs,
|
||||
grant_ids: grantIds, // Optional: specific grants to use
|
||||
webhook_url: webhookUrl, // Optional: receive completion webhook
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
const result = await response.json();
|
||||
// { execution_id, status: "queued", message: "..." }
|
||||
return result;
|
||||
}
|
||||
```
|
||||
|
||||
### 5. Check Execution Status
|
||||
|
||||
Poll for execution status or use webhooks:
|
||||
|
||||
```javascript
|
||||
async function getExecutionStatus(executionId) {
|
||||
const response = await fetch(
|
||||
`https://platform.agpt.co/api/external/v1/executions/${executionId}`,
|
||||
{
|
||||
headers: { 'Authorization': `Bearer ${accessToken}` },
|
||||
}
|
||||
);
|
||||
|
||||
return await response.json();
|
||||
// {
|
||||
// execution_id,
|
||||
// status: "queued" | "running" | "completed" | "failed",
|
||||
// started_at,
|
||||
// completed_at,
|
||||
// outputs, // Present when completed
|
||||
// error, // Present when failed
|
||||
// }
|
||||
}
|
||||
```
|
||||
|
||||
### 6. Handle Webhooks
|
||||
|
||||
If you provided a `webhook_url`, you'll receive POST requests with execution events:
|
||||
|
||||
```javascript
|
||||
app.post('/webhooks/autogpt', (req, res) => {
|
||||
// Verify webhook signature (if configured)
|
||||
const signature = req.headers['x-webhook-signature'];
|
||||
const timestamp = req.headers['x-webhook-timestamp'];
|
||||
|
||||
if (signature) {
|
||||
const expectedSignature = crypto
|
||||
.createHmac('sha256', WEBHOOK_SECRET)
|
||||
.update(JSON.stringify(req.body))
|
||||
.digest('hex');
|
||||
|
||||
if (signature !== `sha256=${expectedSignature}`) {
|
||||
return res.status(401).send('Invalid signature');
|
||||
}
|
||||
}
|
||||
|
||||
const { event, timestamp, data } = req.body;
|
||||
|
||||
switch (event) {
|
||||
case 'execution.started':
|
||||
console.log(`Execution ${data.execution_id} started`);
|
||||
break;
|
||||
case 'execution.completed':
|
||||
console.log(`Execution ${data.execution_id} completed`, data.outputs);
|
||||
break;
|
||||
case 'execution.failed':
|
||||
console.error(`Execution ${data.execution_id} failed:`, data.error);
|
||||
break;
|
||||
case 'grant.revoked':
|
||||
console.log(`Grant ${data.grant_id} was revoked`);
|
||||
break;
|
||||
}
|
||||
|
||||
res.status(200).send('OK');
|
||||
});
|
||||
```
|
||||
|
||||
> **Note:** Webhook URLs must match domains registered in your OAuth client's `webhook_domains`.
|
||||
|
||||
## API Reference
|
||||
|
||||
### External API Endpoints
|
||||
|
||||
Base URL: `https://platform.agpt.co/api/external/v1`
|
||||
|
||||
| Method | Endpoint | Scope Required | Description |
|
||||
|--------|----------|----------------|-------------|
|
||||
| GET | `/executions/capabilities` | `agents:execute` | Get available grants and scopes |
|
||||
| POST | `/executions/agents/{agent_id}/execute` | `agents:execute` | Execute an agent |
|
||||
| GET | `/executions/{execution_id}` | `agents:execute` | Get execution status |
|
||||
| POST | `/executions/{execution_id}/cancel` | `agents:execute` | Cancel execution |
|
||||
| GET | `/grants/` | `integrations:list` | List credential grants |
|
||||
| GET | `/grants/{grant_id}` | `integrations:list` | Get grant details |
|
||||
| DELETE | `/grants/{grant_id}/credential` | `integrations:delete` | Delete credential via grant |
|
||||
|
||||
### OAuth Endpoints
|
||||
|
||||
Base URL: `https://platform.agpt.co`
|
||||
|
||||
| Method | Endpoint | Description |
|
||||
|--------|----------|-------------|
|
||||
| GET | `/oauth/authorize` | Authorization endpoint |
|
||||
| POST | `/oauth/token` | Token endpoint |
|
||||
| GET | `/oauth/userinfo` | OIDC UserInfo |
|
||||
| POST | `/oauth/revoke` | Revoke tokens |
|
||||
| GET | `/.well-known/openid-configuration` | OIDC Discovery |
|
||||
|
||||
### Client Management Endpoints
|
||||
|
||||
| Method | Endpoint | Description |
|
||||
|--------|----------|-------------|
|
||||
| POST | `/oauth/clients/` | Register new client |
|
||||
| GET | `/oauth/clients/` | List your clients |
|
||||
| GET | `/oauth/clients/{client_id}` | Get client details |
|
||||
| PATCH | `/oauth/clients/{client_id}` | Update client |
|
||||
| DELETE | `/oauth/clients/{client_id}` | Delete client |
|
||||
| POST | `/oauth/clients/{client_id}/rotate-secret` | Rotate client secret |
|
||||
|
||||
## Rate Limits
|
||||
|
||||
| Endpoint Type | Limit |
|
||||
|--------------|-------|
|
||||
| OAuth endpoints | 20-30 requests/minute |
|
||||
| Agent execution | 10 requests/minute, 100/hour |
|
||||
| Read endpoints | 60 requests/minute, 1000/hour |
|
||||
|
||||
Rate limit headers are included in responses:
|
||||
- `X-RateLimit-Remaining` - Requests remaining in current window
|
||||
- `X-RateLimit-Reset` - Unix timestamp when limit resets
|
||||
- `Retry-After` - Seconds to wait (when rate limited)
|
||||
|
||||
## Error Handling
|
||||
|
||||
### OAuth Errors
|
||||
|
||||
```json
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": "Authorization code has expired"
|
||||
}
|
||||
```
|
||||
|
||||
Common OAuth errors:
|
||||
- `invalid_client` - Unknown or invalid client
|
||||
- `invalid_grant` - Expired/invalid authorization code
|
||||
- `access_denied` - User denied consent
|
||||
- `invalid_scope` - Requested scope not allowed
|
||||
|
||||
### API Errors
|
||||
|
||||
```json
|
||||
{
|
||||
"detail": "Grant validation failed: No valid grants found for requested integrations"
|
||||
}
|
||||
```
|
||||
|
||||
HTTP status codes:
|
||||
- `400` - Bad request (invalid parameters)
|
||||
- `401` - Unauthorized (invalid/expired token)
|
||||
- `403` - Forbidden (insufficient scopes or grants)
|
||||
- `404` - Resource not found
|
||||
- `429` - Rate limited
|
||||
- `500` - Internal server error
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
1. **Store secrets securely** - Never expose `client_secret` in client-side code
|
||||
2. **Validate state parameter** - Prevent CSRF attacks
|
||||
3. **Use PKCE** - Required for all authorization flows
|
||||
4. **Verify webhook signatures** - Prevent spoofed webhooks
|
||||
5. **Request minimal scopes** - Only request what you need
|
||||
6. **Handle token refresh** - Refresh tokens before they expire
|
||||
7. **Validate redirect origins** - Only accept messages from expected origins
|
||||
|
||||
## Complete Integration Example
|
||||
|
||||
```javascript
|
||||
class AutoGPTClient {
|
||||
constructor(clientId, clientSecret, redirectUri) {
|
||||
this.clientId = clientId;
|
||||
this.clientSecret = clientSecret;
|
||||
this.redirectUri = redirectUri;
|
||||
this.baseUrl = 'https://platform.agpt.co';
|
||||
}
|
||||
|
||||
// Step 1: Generate authorization URL
|
||||
async getAuthorizationUrl(scopes) {
|
||||
const state = crypto.randomUUID();
|
||||
const codeVerifier = this.generateCodeVerifier();
|
||||
const codeChallenge = await this.generateCodeChallenge(codeVerifier);
|
||||
|
||||
// Store for callback
|
||||
sessionStorage.setItem('oauth_state', state);
|
||||
sessionStorage.setItem('oauth_verifier', codeVerifier);
|
||||
|
||||
const url = new URL(`${this.baseUrl}/oauth/authorize`);
|
||||
url.searchParams.set('response_type', 'code');
|
||||
url.searchParams.set('client_id', this.clientId);
|
||||
url.searchParams.set('redirect_uri', this.redirectUri);
|
||||
url.searchParams.set('state', state);
|
||||
url.searchParams.set('code_challenge', codeChallenge);
|
||||
url.searchParams.set('code_challenge_method', 'S256');
|
||||
url.searchParams.set('scope', scopes.join(' '));
|
||||
|
||||
return url.toString();
|
||||
}
|
||||
|
||||
// Step 2: Exchange code for tokens
|
||||
async exchangeCode(code, state) {
|
||||
if (state !== sessionStorage.getItem('oauth_state')) {
|
||||
throw new Error('Invalid state');
|
||||
}
|
||||
|
||||
const response = await fetch(`${this.baseUrl}/oauth/token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: new URLSearchParams({
|
||||
grant_type: 'authorization_code',
|
||||
code,
|
||||
redirect_uri: this.redirectUri,
|
||||
client_id: this.clientId,
|
||||
client_secret: this.clientSecret,
|
||||
code_verifier: sessionStorage.getItem('oauth_verifier'),
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(await response.text());
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
// Step 3: Request credential grant via popup
|
||||
requestGrant(provider, scopes) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const nonce = crypto.randomUUID();
|
||||
|
||||
const url = new URL(`${this.baseUrl}/connect/${provider}`);
|
||||
url.searchParams.set('client_id', this.clientId);
|
||||
url.searchParams.set('scopes', scopes.join(','));
|
||||
url.searchParams.set('nonce', nonce);
|
||||
url.searchParams.set('redirect_origin', window.location.origin);
|
||||
|
||||
const popup = window.open(url.toString(), 'connect', 'width=500,height=600');
|
||||
|
||||
const handler = (event) => {
|
||||
if (event.origin !== this.baseUrl) return;
|
||||
if (event.data?.type !== 'autogpt_connect_result') return;
|
||||
if (event.data?.nonce !== nonce) return;
|
||||
|
||||
window.removeEventListener('message', handler);
|
||||
|
||||
if (event.data.success) {
|
||||
resolve(event.data);
|
||||
} else {
|
||||
reject(new Error(event.data.error_description));
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('message', handler);
|
||||
});
|
||||
}
|
||||
|
||||
// Step 4: Execute agent
|
||||
async executeAgent(accessToken, agentId, inputs, options = {}) {
|
||||
const response = await fetch(
|
||||
`${this.baseUrl}/api/external/v1/executions/agents/${agentId}/execute`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Authorization': `Bearer ${accessToken}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
inputs,
|
||||
grant_ids: options.grantIds,
|
||||
webhook_url: options.webhookUrl,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(await response.text());
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
// Step 5: Poll for completion
|
||||
async waitForCompletion(accessToken, executionId, timeoutMs = 300000) {
|
||||
const startTime = Date.now();
|
||||
|
||||
while (Date.now() - startTime < timeoutMs) {
|
||||
const response = await fetch(
|
||||
`${this.baseUrl}/api/external/v1/executions/${executionId}`,
|
||||
{ headers: { 'Authorization': `Bearer ${accessToken}` } }
|
||||
);
|
||||
|
||||
const status = await response.json();
|
||||
|
||||
if (status.status === 'completed') {
|
||||
return status.outputs;
|
||||
}
|
||||
|
||||
if (status.status === 'failed') {
|
||||
throw new Error(status.error || 'Execution failed');
|
||||
}
|
||||
|
||||
// Wait before polling again
|
||||
await new Promise(resolve => setTimeout(resolve, 2000));
|
||||
}
|
||||
|
||||
throw new Error('Execution timeout');
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
generateCodeVerifier() {
|
||||
const array = new Uint8Array(32);
|
||||
crypto.getRandomValues(array);
|
||||
return this.base64UrlEncode(array);
|
||||
}
|
||||
|
||||
async generateCodeChallenge(verifier) {
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(verifier);
|
||||
const digest = await crypto.subtle.digest('SHA-256', data);
|
||||
return this.base64UrlEncode(new Uint8Array(digest));
|
||||
}
|
||||
|
||||
base64UrlEncode(buffer) {
|
||||
return btoa(String.fromCharCode(...buffer))
|
||||
.replace(/\+/g, '-')
|
||||
.replace(/\//g, '_')
|
||||
.replace(/=+$/, '');
|
||||
}
|
||||
}
|
||||
|
||||
// Usage
|
||||
const client = new AutoGPTClient(
|
||||
'app_abc123',
|
||||
'secret_xyz789',
|
||||
'https://myapp.com/oauth/callback'
|
||||
);
|
||||
|
||||
// 1. Redirect to authorization
|
||||
const authUrl = await client.getAuthorizationUrl([
|
||||
'openid', 'profile', 'agents:execute', 'integrations:connect'
|
||||
]);
|
||||
window.location.href = authUrl;
|
||||
|
||||
// 2. After callback, exchange code
|
||||
const tokens = await client.exchangeCode(code, state);
|
||||
|
||||
// 3. Request Google credentials
|
||||
const grant = await client.requestGrant('google', ['google:gmail.readonly']);
|
||||
|
||||
// 4. Execute an agent
|
||||
const execution = await client.executeAgent(
|
||||
tokens.access_token,
|
||||
'agent-uuid-here',
|
||||
{ query: 'Search my emails for invoices' },
|
||||
{ grantIds: [grant.grant_id] }
|
||||
);
|
||||
|
||||
// 5. Wait for results
|
||||
const outputs = await client.waitForCompletion(
|
||||
tokens.access_token,
|
||||
execution.execution_id
|
||||
);
|
||||
console.log('Agent outputs:', outputs);
|
||||
```
|
||||
|
||||
## Support
|
||||
|
||||
- [GitHub Issues](https://github.com/Significant-Gravitas/AutoGPT/issues) - Bug reports and feature requests
|
||||
- [Discord Community](https://discord.gg/autogpt) - Community support
|
||||
@@ -23,6 +23,7 @@ nav:
|
||||
- Using AI/ML API: platform/aimlapi.md
|
||||
- Using D-ID: platform/d_id.md
|
||||
- Blocks: platform/blocks/blocks.md
|
||||
- External API Integration: platform/external-api-integration.md
|
||||
- Contributing:
|
||||
- Tests: platform/contributing/tests.md
|
||||
- OAuth Flows: platform/contributing/oauth-integration-flow.md
|
||||
|
||||
Reference in New Issue
Block a user