feat(frontend,backend): fix google auth and add gmail, sheets (#8236)

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
This commit is contained in:
Nicholas Tindle
2024-10-03 11:36:30 -05:00
committed by GitHub
parent 4989e3c282
commit 6dbc0f7270
14 changed files with 990 additions and 46 deletions

View File

@@ -81,12 +81,17 @@ class SupabaseIntegrationCredentialsStore:
]
self._set_user_integration_creds(user_id, filtered_credentials)
async def store_state_token(self, user_id: str, provider: str) -> str:
async def store_state_token(
self, user_id: str, provider: str, scopes: list[str]
) -> str:
token = secrets.token_urlsafe(32)
expires_at = datetime.now(timezone.utc) + timedelta(minutes=10)
state = OAuthState(
token=token, provider=provider, expires_at=int(expires_at.timestamp())
token=token,
provider=provider,
expires_at=int(expires_at.timestamp()),
scopes=scopes,
)
user_metadata = self._get_user_metadata(user_id)
@@ -100,6 +105,36 @@ class SupabaseIntegrationCredentialsStore:
return token
async def get_any_valid_scopes_from_state_token(
self, user_id: str, token: str, provider: str
) -> list[str]:
"""
Get the valid scopes from the OAuth state token. This will return any valid scopes
from any OAuth state token for the given provider. If no valid scopes are found,
an empty list is returned. DO NOT RELY ON THIS TOKEN TO AUTHENTICATE A USER, AS IT
IS TO CHECK IF THE USER HAS GIVEN PERMISSIONS TO THE APPLICATION BEFORE EXCHANGING
THE CODE FOR TOKENS.
"""
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])
now = datetime.now(timezone.utc)
valid_state = next(
(
state
for state in oauth_states
if state["token"] == token
and state["provider"] == provider
and state["expires_at"] > now.timestamp()
),
None,
)
if valid_state:
return valid_state.get("scopes", [])
return []
async def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])

View File

@@ -36,6 +36,15 @@ SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
GITHUB_CLIENT_ID=
GITHUB_CLIENT_SECRET=
# Google OAuth App server credentials - https://console.cloud.google.com/apis/credentials, and enable gmail api and set scopes
# https://console.cloud.google.com/apis/credentials/consent ?project=<your_project_id>
# You'll need to add/enable the following scopes (minimum):
# https://console.developers.google.com/apis/api/gmail.googleapis.com/overview ?project=<your_project_id>
# https://console.cloud.google.com/apis/library/sheets.googleapis.com/ ?project=<your_project_id>
GOOGLE_CLIENT_ID=
GOOGLE_CLIENT_SECRET=
## ===== OPTIONAL API KEYS ===== ##
# LLM

View File

@@ -24,7 +24,7 @@ class ReadCsvBlock(Block):
output_schema=ReadCsvBlock.Output,
description="Reads a CSV file and outputs the data as a list of dictionaries and individual rows via rows.",
contributors=[ContributorDetails(name="Nicholas Tindle")],
categories={BlockCategory.TEXT},
categories={BlockCategory.TEXT, BlockCategory.DATA},
test_input={
"contents": "a, b, c\n1,2,3\n4,5,6",
},

View File

@@ -0,0 +1,53 @@
from typing import Literal
from autogpt_libs.supabase_integration_credentials_store.types import OAuth2Credentials
from pydantic import SecretStr
from backend.data.model import CredentialsField, CredentialsMetaInput
from backend.util.settings import Secrets
secrets = Secrets()
GOOGLE_OAUTH_IS_CONFIGURED = bool(
secrets.google_client_id and secrets.google_client_secret
)
GoogleCredentials = OAuth2Credentials
GoogleCredentialsInput = CredentialsMetaInput[Literal["google"], Literal["oauth2"]]
def GoogleCredentialsField(scopes: list[str]) -> GoogleCredentialsInput:
"""
Creates a Google credentials input on a block.
Params:
scopes: The authorization scopes needed for the block to work.
"""
return CredentialsField(
provider="google",
supported_credential_types={"oauth2"},
required_scopes=set(scopes),
description="The Google integration requires OAuth2 authentication.",
)
TEST_CREDENTIALS = OAuth2Credentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="google",
access_token=SecretStr("mock-google-access-token"),
refresh_token=SecretStr("mock-google-refresh-token"),
access_token_expires_at=1234567890,
scopes=[
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.send",
],
title="Mock Google OAuth2 Credentials",
username="mock-google-username",
refresh_token_expires_at=1234567890,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}

View File

@@ -0,0 +1,522 @@
import base64
from email.utils import parseaddr
from typing import List
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from pydantic import BaseModel
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._auth import (
GOOGLE_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GoogleCredentials,
GoogleCredentialsField,
GoogleCredentialsInput,
)
class Attachment(BaseModel):
filename: str
content_type: str
size: int
attachment_id: str
class Email(BaseModel):
id: str
subject: str
snippet: str
from_: str
to: str
date: str
body: str = "" # Default to an empty string
sizeEstimate: int
attachments: List[Attachment]
class GmailReadBlock(Block):
class Input(BlockSchema):
credentials: GoogleCredentialsInput = GoogleCredentialsField(
["https://www.googleapis.com/auth/gmail.readonly"]
)
query: str = SchemaField(
description="Search query for reading emails",
default="is:unread",
)
max_results: int = SchemaField(
description="Maximum number of emails to retrieve",
default=10,
)
class Output(BlockSchema):
email: Email = SchemaField(
description="Email data",
)
emails: list[Email] = SchemaField(
description="List of email data",
)
error: str = SchemaField(
description="Error message if any",
)
def __init__(self):
super().__init__(
id="25310c70-b89b-43ba-b25c-4dfa7e2a481c",
description="This block reads emails from Gmail.",
categories={BlockCategory.COMMUNICATION},
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
input_schema=GmailReadBlock.Input,
output_schema=GmailReadBlock.Output,
test_input={
"query": "is:unread",
"max_results": 5,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"result",
[
{
"id": "1",
"subject": "Test Email",
"snippet": "This is a test email",
}
],
),
],
test_mock={
"_read_emails": lambda *args, **kwargs: [
{
"id": "1",
"subject": "Test Email",
"snippet": "This is a test email",
}
],
"_send_email": lambda *args, **kwargs: {"id": "1", "status": "sent"},
},
)
def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
) -> BlockOutput:
try:
service = self._build_service(credentials, **kwargs)
messages = self._read_emails(
service, input_data.query, input_data.max_results
)
for email in messages:
yield "email", email
yield "emails", messages
except Exception as e:
yield "error", str(e)
@staticmethod
def _build_service(credentials: GoogleCredentials, **kwargs):
creds = Credentials(
token=(
credentials.access_token.get_secret_value()
if credentials.access_token
else None
),
refresh_token=(
credentials.refresh_token.get_secret_value()
if credentials.refresh_token
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=kwargs.get("client_id"),
client_secret=kwargs.get("client_secret"),
scopes=credentials.scopes,
)
return build("gmail", "v1", credentials=creds)
def _read_emails(
self, service, query: str | None, max_results: int | None
) -> list[Email]:
results = (
service.users()
.messages()
.list(userId="me", q=query or "", maxResults=max_results or 10)
.execute()
)
messages = results.get("messages", [])
email_data = []
for message in messages:
msg = (
service.users()
.messages()
.get(userId="me", id=message["id"], format="full")
.execute()
)
headers = {
header["name"].lower(): header["value"]
for header in msg["payload"]["headers"]
}
attachments = self._get_attachments(service, msg)
email = Email(
id=msg["id"],
subject=headers.get("subject", "No Subject"),
snippet=msg["snippet"],
from_=parseaddr(headers.get("from", ""))[1],
to=parseaddr(headers.get("to", ""))[1],
date=headers.get("date", ""),
body=self._get_email_body(msg),
sizeEstimate=msg["sizeEstimate"],
attachments=attachments,
)
email_data.append(email)
return email_data
def _get_email_body(self, msg):
if "parts" in msg["payload"]:
for part in msg["payload"]["parts"]:
if part["mimeType"] == "text/plain":
return base64.urlsafe_b64decode(part["body"]["data"]).decode(
"utf-8"
)
elif msg["payload"]["mimeType"] == "text/plain":
return base64.urlsafe_b64decode(msg["payload"]["body"]["data"]).decode(
"utf-8"
)
return "This email does not contain a text body."
def _get_attachments(self, service, message):
attachments = []
if "parts" in message["payload"]:
for part in message["payload"]["parts"]:
if part["filename"]:
attachment = Attachment(
filename=part["filename"],
content_type=part["mimeType"],
size=int(part["body"].get("size", 0)),
attachment_id=part["body"]["attachmentId"],
)
attachments.append(attachment)
return attachments
# Add a new method to download attachment content
def download_attachment(self, service, message_id: str, attachment_id: str):
attachment = (
service.users()
.messages()
.attachments()
.get(userId="me", messageId=message_id, id=attachment_id)
.execute()
)
file_data = base64.urlsafe_b64decode(attachment["data"].encode("UTF-8"))
return file_data
class GmailSendBlock(Block):
class Input(BlockSchema):
credentials: GoogleCredentialsInput = GoogleCredentialsField(
["https://www.googleapis.com/auth/gmail.send"]
)
to: str = SchemaField(
description="Recipient email address",
)
subject: str = SchemaField(
description="Email subject",
)
body: str = SchemaField(
description="Email body",
)
class Output(BlockSchema):
result: dict = SchemaField(
description="Send confirmation",
)
error: str = SchemaField(
description="Error message if any",
)
def __init__(self):
super().__init__(
id="6c27abc2-e51d-499e-a85f-5a0041ba94f0",
description="This block sends an email using Gmail.",
categories={BlockCategory.COMMUNICATION},
input_schema=GmailSendBlock.Input,
output_schema=GmailSendBlock.Output,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
test_input={
"to": "recipient@example.com",
"subject": "Test Email",
"body": "This is a test email sent from GmailSendBlock.",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("result", {"id": "1", "status": "sent"}),
],
test_mock={
"_send_email": lambda *args, **kwargs: {"id": "1", "status": "sent"},
},
)
def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
) -> BlockOutput:
try:
service = GmailReadBlock._build_service(credentials, **kwargs)
send_result = self._send_email(
service, input_data.to, input_data.subject, input_data.body
)
yield "result", send_result
except Exception as e:
yield "error", str(e)
def _send_email(self, service, to: str, subject: str, body: str) -> dict:
if not to or not subject or not body:
raise ValueError("To, subject, and body are required for sending an email")
message = self._create_message(to, subject, body)
sent_message = (
service.users().messages().send(userId="me", body=message).execute()
)
return {"id": sent_message["id"], "status": "sent"}
def _create_message(self, to: str, subject: str, body: str) -> dict:
import base64
from email.mime.text import MIMEText
message = MIMEText(body)
message["to"] = to
message["subject"] = subject
raw_message = base64.urlsafe_b64encode(message.as_bytes()).decode("utf-8")
return {"raw": raw_message}
class GmailListLabelsBlock(Block):
class Input(BlockSchema):
credentials: GoogleCredentialsInput = GoogleCredentialsField(
["https://www.googleapis.com/auth/gmail.labels"]
)
class Output(BlockSchema):
result: list[dict] = SchemaField(
description="List of labels",
)
error: str = SchemaField(
description="Error message if any",
)
def __init__(self):
super().__init__(
id="3e1c2c1c-c689-4520-b956-1f3bf4e02bb7",
description="This block lists all labels in Gmail.",
categories={BlockCategory.COMMUNICATION},
input_schema=GmailListLabelsBlock.Input,
output_schema=GmailListLabelsBlock.Output,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"result",
[
{"id": "Label_1", "name": "Important"},
{"id": "Label_2", "name": "Work"},
],
),
],
test_mock={
"_list_labels": lambda *args, **kwargs: [
{"id": "Label_1", "name": "Important"},
{"id": "Label_2", "name": "Work"},
],
},
)
def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
) -> BlockOutput:
try:
service = GmailReadBlock._build_service(credentials, **kwargs)
labels = self._list_labels(service)
yield "result", labels
except Exception as e:
yield "error", str(e)
def _list_labels(self, service) -> list[dict]:
results = service.users().labels().list(userId="me").execute()
labels = results.get("labels", [])
return [{"id": label["id"], "name": label["name"]} for label in labels]
class GmailAddLabelBlock(Block):
class Input(BlockSchema):
credentials: GoogleCredentialsInput = GoogleCredentialsField(
["https://www.googleapis.com/auth/gmail.modify"]
)
message_id: str = SchemaField(
description="Message ID to add label to",
)
label_name: str = SchemaField(
description="Label name to add",
)
class Output(BlockSchema):
result: dict = SchemaField(
description="Label addition result",
)
error: str = SchemaField(
description="Error message if any",
)
def __init__(self):
super().__init__(
id="f884b2fb-04f4-4265-9658-14f433926ac9",
description="This block adds a label to a Gmail message.",
categories={BlockCategory.COMMUNICATION},
input_schema=GmailAddLabelBlock.Input,
output_schema=GmailAddLabelBlock.Output,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
test_input={
"message_id": "12345",
"label_name": "Important",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"result",
{"status": "Label added successfully", "label_id": "Label_1"},
),
],
test_mock={
"_add_label": lambda *args, **kwargs: {
"status": "Label added successfully",
"label_id": "Label_1",
},
},
)
def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
) -> BlockOutput:
try:
service = GmailReadBlock._build_service(credentials, **kwargs)
result = self._add_label(
service, input_data.message_id, input_data.label_name
)
yield "result", result
except Exception as e:
yield "error", str(e)
def _add_label(self, service, message_id: str, label_name: str) -> dict:
label_id = self._get_or_create_label(service, label_name)
service.users().messages().modify(
userId="me", id=message_id, body={"addLabelIds": [label_id]}
).execute()
return {"status": "Label added successfully", "label_id": label_id}
def _get_or_create_label(self, service, label_name: str) -> str:
label_id = self._get_label_id(service, label_name)
if not label_id:
label = (
service.users()
.labels()
.create(userId="me", body={"name": label_name})
.execute()
)
label_id = label["id"]
return label_id
def _get_label_id(self, service, label_name: str) -> str | None:
results = service.users().labels().list(userId="me").execute()
labels = results.get("labels", [])
for label in labels:
if label["name"] == label_name:
return label["id"]
return None
class GmailRemoveLabelBlock(Block):
class Input(BlockSchema):
credentials: GoogleCredentialsInput = GoogleCredentialsField(
["https://www.googleapis.com/auth/gmail.modify"]
)
message_id: str = SchemaField(
description="Message ID to remove label from",
)
label_name: str = SchemaField(
description="Label name to remove",
)
class Output(BlockSchema):
result: dict = SchemaField(
description="Label removal result",
)
error: str = SchemaField(
description="Error message if any",
)
def __init__(self):
super().__init__(
id="0afc0526-aba1-4b2b-888e-a22b7c3f359d",
description="This block removes a label from a Gmail message.",
categories={BlockCategory.COMMUNICATION},
input_schema=GmailRemoveLabelBlock.Input,
output_schema=GmailRemoveLabelBlock.Output,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
test_input={
"message_id": "12345",
"label_name": "Important",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"result",
{"status": "Label removed successfully", "label_id": "Label_1"},
),
],
test_mock={
"_remove_label": lambda *args, **kwargs: {
"status": "Label removed successfully",
"label_id": "Label_1",
},
},
)
def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
) -> BlockOutput:
try:
service = GmailReadBlock._build_service(credentials, **kwargs)
result = self._remove_label(
service, input_data.message_id, input_data.label_name
)
yield "result", result
except Exception as e:
yield "error", str(e)
def _remove_label(self, service, message_id: str, label_name: str) -> dict:
label_id = self._get_label_id(service, label_name)
if label_id:
service.users().messages().modify(
userId="me", id=message_id, body={"removeLabelIds": [label_id]}
).execute()
return {"status": "Label removed successfully", "label_id": label_id}
else:
return {"status": "Label not found", "label_name": label_name}
def _get_label_id(self, service, label_name: str) -> str | None:
results = service.users().labels().list(userId="me").execute()
labels = results.get("labels", [])
for label in labels:
if label["name"] == label_name:
return label["id"]
return None

View File

@@ -0,0 +1,192 @@
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._auth import (
GOOGLE_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GoogleCredentials,
GoogleCredentialsField,
GoogleCredentialsInput,
)
class GoogleSheetsReadBlock(Block):
class Input(BlockSchema):
credentials: GoogleCredentialsInput = GoogleCredentialsField(
["https://www.googleapis.com/auth/spreadsheets.readonly"]
)
spreadsheet_id: str = SchemaField(
description="The ID of the spreadsheet to read from",
)
range: str = SchemaField(
description="The A1 notation of the range to read",
)
class Output(BlockSchema):
result: list[list[str]] = SchemaField(
description="The data read from the spreadsheet",
)
error: str = SchemaField(
description="Error message if any",
)
def __init__(self):
super().__init__(
id="5724e902-3635-47e9-a108-aaa0263a4988",
description="This block reads data from a Google Sheets spreadsheet.",
categories={BlockCategory.DATA},
input_schema=GoogleSheetsReadBlock.Input,
output_schema=GoogleSheetsReadBlock.Output,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
test_input={
"spreadsheet_id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms",
"range": "Sheet1!A1:B2",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"result",
[
["Name", "Score"],
["Alice", "85"],
],
),
],
test_mock={
"_read_sheet": lambda *args, **kwargs: [
["Name", "Score"],
["Alice", "85"],
],
},
)
def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
) -> BlockOutput:
try:
service = self._build_service(credentials, **kwargs)
data = self._read_sheet(
service, input_data.spreadsheet_id, input_data.range
)
yield "result", data
except Exception as e:
yield "error", str(e)
@staticmethod
def _build_service(credentials: GoogleCredentials, **kwargs):
creds = Credentials(
token=(
credentials.access_token.get_secret_value()
if credentials.access_token
else None
),
refresh_token=(
credentials.refresh_token.get_secret_value()
if credentials.refresh_token
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=kwargs.get("client_id"),
client_secret=kwargs.get("client_secret"),
scopes=credentials.scopes,
)
return build("sheets", "v4", credentials=creds)
def _read_sheet(self, service, spreadsheet_id: str, range: str) -> list[list[str]]:
sheet = service.spreadsheets()
result = sheet.values().get(spreadsheetId=spreadsheet_id, range=range).execute()
return result.get("values", [])
class GoogleSheetsWriteBlock(Block):
class Input(BlockSchema):
credentials: GoogleCredentialsInput = GoogleCredentialsField(
["https://www.googleapis.com/auth/spreadsheets"]
)
spreadsheet_id: str = SchemaField(
description="The ID of the spreadsheet to write to",
)
range: str = SchemaField(
description="The A1 notation of the range to write",
)
values: list[list[str]] = SchemaField(
description="The data to write to the spreadsheet",
)
class Output(BlockSchema):
result: dict = SchemaField(
description="The result of the write operation",
)
error: str = SchemaField(
description="Error message if any",
)
def __init__(self):
super().__init__(
id="d9291e87-301d-47a8-91fe-907fb55460e5",
description="This block writes data to a Google Sheets spreadsheet.",
categories={BlockCategory.DATA},
input_schema=GoogleSheetsWriteBlock.Input,
output_schema=GoogleSheetsWriteBlock.Output,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
test_input={
"spreadsheet_id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms",
"range": "Sheet1!A1:B2",
"values": [
["Name", "Score"],
["Bob", "90"],
],
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"result",
{"updatedCells": 4, "updatedColumns": 2, "updatedRows": 2},
),
],
test_mock={
"_write_sheet": lambda *args, **kwargs: {
"updatedCells": 4,
"updatedColumns": 2,
"updatedRows": 2,
},
},
)
def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
) -> BlockOutput:
try:
service = GoogleSheetsReadBlock._build_service(credentials, **kwargs)
result = self._write_sheet(
service,
input_data.spreadsheet_id,
input_data.range,
input_data.values,
)
yield "result", result
except Exception as e:
yield "error", str(e)
def _write_sheet(
self, service, spreadsheet_id: str, range: str, values: list[list[str]]
) -> dict:
body = {"values": values}
result = (
service.spreadsheets()
.values()
.update(
spreadsheetId=spreadsheet_id,
range=range,
valueInputOption="USER_ENTERED",
body=body,
)
.execute()
)
return result

View File

@@ -45,7 +45,9 @@ class BlockCategory(Enum):
INPUT = "Block that interacts with input of the graph."
OUTPUT = "Block that interacts with output of the graph."
LOGIC = "Programming logic to control the flow of your agent"
COMMUNICATION = "Block that interacts with communication platforms."
DEVELOPER_TOOLS = "Developer tools such as GitHub blocks."
DATA = "Block that interacts with structured data."
def dict(self) -> dict[str, str]:
return {"category": self.name, "description": self.value}

View File

@@ -1,12 +1,16 @@
import logging
import time
from abc import ABC, abstractmethod
from typing import ClassVar
from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials
logger = logging.getLogger(__name__)
class BaseOAuthHandler(ABC):
PROVIDER_NAME: ClassVar[str]
DEFAULT_SCOPES: ClassVar[list[str]] = []
@abstractmethod
def __init__(self, client_id: str, client_secret: str, redirect_uri: str): ...
@@ -17,7 +21,9 @@ class BaseOAuthHandler(ABC):
...
@abstractmethod
def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
def exchange_code_for_tokens(
self, code: str, scopes: list[str]
) -> OAuth2Credentials:
"""Exchanges the acquired authorization code from login for a set of tokens"""
...
@@ -46,3 +52,11 @@ class BaseOAuthHandler(ABC):
credentials.access_token_expires_at is not None
and credentials.access_token_expires_at < int(time.time()) + 300
)
def handle_default_scopes(self, scopes: list[str]) -> list[str]:
"""Handles the default scopes for the provider"""
# If scopes are empty, use the default scopes for the provider
if not scopes:
logger.debug(f"Using default scopes for provider {self.PROVIDER_NAME}")
scopes = self.DEFAULT_SCOPES
return scopes

View File

@@ -41,7 +41,9 @@ class GitHubOAuthHandler(BaseOAuthHandler):
}
return f"{self.auth_base_url}?{urlencode(params)}"
def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
def exchange_code_for_tokens(
self, code: str, scopes: list[str]
) -> OAuth2Credentials:
return self._request_tokens({"code": code, "redirect_uri": self.redirect_uri})
def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:

View File

@@ -1,3 +1,5 @@
import logging
from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials
from google.auth.external_account_authorized_user import (
Credentials as ExternalAccountCredentials,
@@ -9,6 +11,8 @@ from pydantic import SecretStr
from .base import BaseOAuthHandler
logger = logging.getLogger(__name__)
class GoogleOAuthHandler(BaseOAuthHandler):
"""
@@ -17,6 +21,11 @@ class GoogleOAuthHandler(BaseOAuthHandler):
PROVIDER_NAME = "google"
EMAIL_ENDPOINT = "https://www.googleapis.com/oauth2/v2/userinfo"
DEFAULT_SCOPES = [
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"openid",
]
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
@@ -25,7 +34,9 @@ class GoogleOAuthHandler(BaseOAuthHandler):
self.token_uri = "https://oauth2.googleapis.com/token"
def get_login_url(self, scopes: list[str], state: str) -> str:
flow = self._setup_oauth_flow(scopes)
all_scopes = list(set(scopes + self.DEFAULT_SCOPES))
logger.debug(f"Setting up OAuth flow with scopes: {all_scopes}")
flow = self._setup_oauth_flow(all_scopes)
flow.redirect_uri = self.redirect_uri
authorization_url, _ = flow.authorization_url(
access_type="offline",
@@ -35,29 +46,57 @@ class GoogleOAuthHandler(BaseOAuthHandler):
)
return authorization_url
def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
flow = self._setup_oauth_flow(None)
def exchange_code_for_tokens(
self, code: str, scopes: list[str]
) -> OAuth2Credentials:
logger.debug(f"Exchanging code for tokens with scopes: {scopes}")
# Use the scopes from the initial request
flow = self._setup_oauth_flow(scopes)
flow.redirect_uri = self.redirect_uri
flow.fetch_token(code=code)
logger.debug("Fetching token from Google")
# Disable scope check in fetch_token
flow.oauth2session.scope = None
token = flow.fetch_token(code=code)
logger.debug("Token fetched successfully")
# Get the actual scopes granted by Google
granted_scopes: list[str] = token.get("scope", [])
logger.debug(f"Scopes granted by Google: {granted_scopes}")
google_creds = flow.credentials
username = self._request_email(google_creds)
logger.debug(f"Received credentials: {google_creds}")
logger.debug("Requesting user email")
username = self._request_email(google_creds)
logger.debug(f"User email retrieved: {username}")
# Google's OAuth library is poorly typed so we need some of these:
assert google_creds.token
assert google_creds.refresh_token
assert google_creds.expiry
assert google_creds.scopes
return OAuth2Credentials(
assert granted_scopes
# Create OAuth2Credentials with the granted scopes
credentials = OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=None,
username=username,
access_token=SecretStr(google_creds.token),
refresh_token=SecretStr(google_creds.refresh_token),
access_token_expires_at=int(google_creds.expiry.timestamp()),
refresh_token=(SecretStr(google_creds.refresh_token)),
access_token_expires_at=(
int(google_creds.expiry.timestamp()) if google_creds.expiry else None
),
refresh_token_expires_at=None,
scopes=google_creds.scopes,
scopes=granted_scopes,
)
logger.debug(
f"OAuth2Credentials object created successfully with scopes: {credentials.scopes}"
)
return credentials
def _request_email(
self, creds: Credentials | ExternalAccountCredentials
@@ -65,6 +104,9 @@ class GoogleOAuthHandler(BaseOAuthHandler):
session = AuthorizedSession(creds)
response = session.get(self.EMAIL_ENDPOINT)
if not response.ok:
logger.error(
f"Failed to get user email. Status code: {response.status_code}"
)
return None
return response.json()["email"]
@@ -99,7 +141,7 @@ class GoogleOAuthHandler(BaseOAuthHandler):
scopes=google_creds.scopes,
)
def _setup_oauth_flow(self, scopes: list[str] | None) -> Flow:
def _setup_oauth_flow(self, scopes: list[str]) -> Flow:
return Flow.from_client_config(
{
"web": {

View File

@@ -35,7 +35,9 @@ class NotionOAuthHandler(BaseOAuthHandler):
}
return f"{self.auth_base_url}?{urlencode(params)}"
def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
def exchange_code_for_tokens(
self, code: str, scopes: list[str]
) -> OAuth2Credentials:
request_body = {
"grant_type": "authorization_code",
"code": code,

View File

@@ -54,10 +54,11 @@ async def login(
) -> LoginResponse:
handler = _get_provider_oauth_handler(request, provider)
# Generate and store a secure random state token
state_token = await store.store_state_token(user_id, provider)
requested_scopes = scopes.split(",") if scopes else []
# Generate and store a secure random state token along with the scopes
state_token = await store.store_state_token(user_id, provider, requested_scopes)
login_url = handler.get_login_url(requested_scopes, state_token)
return LoginResponse(login_url=login_url, state_token=state_token)
@@ -80,20 +81,44 @@ async def callback(
user_id: Annotated[str, Depends(get_user_id)],
request: Request,
) -> CredentialsMetaResponse:
logger.debug(f"Received OAuth callback for provider: {provider}")
handler = _get_provider_oauth_handler(request, provider)
# Verify the state token
if not await store.verify_state_token(user_id, state_token, provider):
logger.warning(f"Invalid or expired state token for user {user_id}")
raise HTTPException(status_code=400, detail="Invalid or expired state token")
try:
credentials = handler.exchange_code_for_tokens(code)
scopes = await store.get_any_valid_scopes_from_state_token(
user_id, state_token, provider
)
logger.debug(f"Retrieved scopes from state token: {scopes}")
scopes = handler.handle_default_scopes(scopes)
credentials = handler.exchange_code_for_tokens(code, scopes)
logger.debug(f"Received credentials with final scopes: {credentials.scopes}")
# Check if the granted scopes are sufficient for the requested scopes
if not set(scopes).issubset(set(credentials.scopes)):
# For now, we'll just log the warning and continue
logger.warning(
f"Granted scopes {credentials.scopes} for {provider}do not include all requested scopes {scopes}"
)
except Exception as e:
logger.warning(f"Code->Token exchange failed for provider {provider}: {e}")
raise HTTPException(status_code=400, detail=str(e))
logger.error(f"Code->Token exchange failed for provider {provider}: {e}")
raise HTTPException(
status_code=400, detail=f"Failed to exchange code for tokens: {str(e)}"
)
# TODO: Allow specifying `title` to set on `credentials`
store.add_creds(user_id, credentials)
logger.debug(
f"Successfully processed OAuth callback for user {user_id} and provider {provider}"
)
return CredentialsMetaResponse(
id=credentials.id,
type=credentials.type,

View File

@@ -9,7 +9,8 @@ export async function GET(request: Request) {
const code = searchParams.get("code");
const state = searchParams.get("state");
// Send message from popup window to host window
console.debug("OAuth callback received:", { code, state });
const message: OAuthPopupResultMessage =
code && state
? { message_type: "oauth_popup_result", success: true, code, state }
@@ -19,13 +20,15 @@ export async function GET(request: Request) {
message: `Incomplete query: ${searchParams.toString()}`,
};
console.debug("Sending message to opener:", message);
// Return a response with the message as JSON and a script to close the window
return new NextResponse(
`
<html>
<body>
<script>
window.postMessage(${JSON.stringify(message)});
window.opener.postMessage(${JSON.stringify(message)});
window.close();
</script>
</body>

View File

@@ -74,6 +74,7 @@ export const CredentialsInput: FC<{
const [isOAuth2FlowInProgress, setOAuth2FlowInProgress] = useState(false);
const [oAuthPopupController, setOAuthPopupController] =
useState<AbortController | null>(null);
const [oAuthError, setOAuthError] = useState<string | null>(null);
if (!credentials) {
return null;
@@ -95,6 +96,7 @@ export const CredentialsInput: FC<{
} = credentials;
async function handleOAuthLogin() {
setOAuthError(null);
const { login_url, state_token } = await api.oAuthLogin(
provider,
schema.credentials_scopes,
@@ -102,46 +104,81 @@ export const CredentialsInput: FC<{
setOAuth2FlowInProgress(true);
const popup = window.open(login_url, "_blank", "popup=true");
if (!popup) {
throw new Error(
"Failed to open popup window. Please allow popups for this site.",
);
}
const controller = new AbortController();
setOAuthPopupController(controller);
controller.signal.onabort = () => {
console.debug("OAuth flow aborted");
setOAuth2FlowInProgress(false);
popup?.close();
popup.close();
};
popup?.addEventListener(
"message",
async (e: MessageEvent<OAuthPopupResultMessage>) => {
if (
typeof e.data != "object" ||
!(
"message_type" in e.data &&
e.data.message_type == "oauth_popup_result"
)
)
return;
if (!e.data.success) {
console.error("OAuth flow failed:", e.data.message);
return;
}
const handleMessage = async (e: MessageEvent<OAuthPopupResultMessage>) => {
console.debug("Message received:", e.data);
if (
typeof e.data != "object" ||
!("message_type" in e.data) ||
e.data.message_type !== "oauth_popup_result"
) {
console.debug("Ignoring irrelevant message");
return;
}
if (e.data.state !== state_token) return;
if (!e.data.success) {
console.error("OAuth flow failed:", e.data.message);
setOAuthError(`OAuth flow failed: ${e.data.message}`);
setOAuth2FlowInProgress(false);
return;
}
if (e.data.state !== state_token) {
console.error("Invalid state token received");
setOAuthError("Invalid state token received");
setOAuth2FlowInProgress(false);
return;
}
try {
console.debug("Processing OAuth callback");
const credentials = await oAuthCallback(e.data.code, e.data.state);
console.debug("OAuth callback processed successfully");
onSelectCredentials({
id: credentials.id,
type: "oauth2",
title: credentials.title,
provider,
});
} catch (error) {
console.error("Error in OAuth callback:", error);
setOAuthError(
// type of error is unkown so we need to use String(error)
`Error in OAuth callback: ${
error instanceof Error ? error.message : String(error)
}`,
);
} finally {
console.debug("Finalizing OAuth flow");
setOAuth2FlowInProgress(false);
controller.abort("success");
},
{ signal: controller.signal },
);
}
};
console.debug("Adding message event listener");
window.addEventListener("message", handleMessage, {
signal: controller.signal,
});
setTimeout(
() => {
console.debug("OAuth flow timed out");
controller.abort("timeout");
setOAuth2FlowInProgress(false);
setOAuthError("OAuth flow timed out");
},
5 * 60 * 1000,
);
@@ -189,6 +226,9 @@ export const CredentialsInput: FC<{
)}
</div>
{modals}
{oAuthError && (
<div className="mt-2 text-red-500">Error: {oAuthError}</div>
)}
</>
);
}
@@ -251,6 +291,9 @@ export const CredentialsInput: FC<{
</SelectContent>
</Select>
{modals}
{oAuthError && (
<div className="mt-2 text-red-500">Error: {oAuthError}</div>
)}
</>
);
};