Compare commits

...

6 Commits

Author                SHA1        Message                                                                   Date
Reinier van der Leer  5c62abf694  add google-auth-oauthlib for backend                                      2024-08-19 16:23:44 +02:00
Reinier van der Leer  4ae914b999  draft LinkGoogleDriveButton                                               2024-08-19 16:22:53 +02:00
Reinier van der Leer  43ab7035ba  implement POST /api/integrations/auth/google                              2024-08-19 16:20:51 +02:00
Reinier van der Leer  12e4dd3b84  move get_user_id from server.py to utils.py                               2024-08-19 16:19:57 +02:00
Reinier van der Leer  b67d1886fa  feat(server): Add Read File block with support for Google Drive and GCS  2024-08-13 18:22:09 +02:00
Reinier van der Leer  3d40815018  feat(forge/file_storage): Add Google Drive support                        2024-08-13 18:20:42 +02:00
10 changed files with 480 additions and 20 deletions

View File

@@ -8,6 +8,7 @@ class FileStorageBackendName(str, enum.Enum):
    LOCAL = "local"
    GCS = "gcs"
    S3 = "s3"
    GOOGLE_DRIVE = "google_drive"


def get_storage(
@@ -35,3 +36,12 @@ def get_storage(
            config = GCSFileStorageConfiguration.from_env()
            config.root = root_path
            return GCSFileStorage(config)
        case FileStorageBackendName.GOOGLE_DRIVE:
            from .google_drive import (
                GoogleDriveFileStorage,
                GoogleDriveFileStorageConfiguration,
            )

            config = GoogleDriveFileStorageConfiguration.from_env()
            config.root = root_path
            return GoogleDriveFileStorage(config)
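
For orientation, a minimal sketch of selecting the new backend through this factory. It assumes this hunk lives in forge.file_storage's package init and that get_storage keeps its existing backend_name/root_path parameters; the FILE_STORAGE_BACKEND variable name is illustrative, not part of this change.

# Illustrative only: environment variable name and call shape are assumptions.
import os
from pathlib import Path

from forge.file_storage import FileStorageBackendName, get_storage

backend = FileStorageBackendName(os.getenv("FILE_STORAGE_BACKEND", "google_drive"))
storage = get_storage(backend, root_path=Path("/workspace"))
storage.initialize()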

View File

@@ -0,0 +1,253 @@
"""
The GoogleDriveFileStorage class provides an interface for interacting with a
file workspace, and stores the files in Google Drive.
"""
from __future__ import annotations
import inspect
import io
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Literal, overload
from google.oauth2.credentials import Credentials as GoogleCredentials
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload, MediaIoBaseUpload
from forge.models.config import SystemConfiguration, UserConfigurable
from .base import FileStorage, FileStorageConfiguration
if TYPE_CHECKING:
from googleapiclient._apis.drive.v3 import File as GoogleDriveFile
logger = logging.getLogger(__name__)
class GoogleDriveFileStorageConfiguration(FileStorageConfiguration):
class Credentials(SystemConfiguration):
token: str = UserConfigurable(from_env="GOOGLE_DRIVE_TOKEN")
refresh_token: str = UserConfigurable(from_env="GOOGLE_DRIVE_REFRESH_TOKEN")
client_id: str = UserConfigurable(from_env="GOOGLE_DRIVE_CLIENT_ID")
client_secret: str = UserConfigurable(from_env="GOOGLE_DRIVE_CLIENT_SECRET")
token_uri: str = UserConfigurable(
from_env="GOOGLE_DRIVE_TOKEN_URI",
default="https://oauth2.googleapis.com/token",
)
credentials: Credentials
root_folder_id: str = UserConfigurable(from_env="GOOGLE_DRIVE_ROOT_FOLDER_ID")
class GoogleDriveFileStorage(FileStorage):
"""A class that represents Google Drive storage."""
def __init__(self, config: GoogleDriveFileStorageConfiguration):
self._root = config.root
self._credentials = config.credentials
self._root_folder_id = config.root_folder_id
self._drive = build(
"drive",
"v3",
credentials=GoogleCredentials(**self._credentials.model_dump()),
)
super().__init__()
@property
def root(self) -> Path:
"""The root directory of the file storage."""
return self._root
@property
def restrict_to_root(self) -> bool:
"""Whether to restrict generated paths to the root."""
return True
@property
def is_local(self) -> bool:
"""Whether the storage is local (i.e. on the same machine, not cloud-based)."""
return False
def initialize(self) -> None:
logger.debug(f"Initializing {repr(self)}...")
# Check if root folder exists, create if it doesn't
if not self._root_folder_id:
folder_info: GoogleDriveFile = {
"name": "AutoGPT Root",
"mimeType": "application/vnd.google-apps.folder",
}
folder = self._drive.files().create(body=folder_info, fields="id").execute()
self._root_folder_id = folder.get("id")
def get_path(self, relative_path: str | Path) -> Path:
return super().get_path(relative_path)
def _get_file_id(self, path: str | Path) -> str:
path = self.get_path(path)
query = (
f"name='{path.name}' "
f"and '{self._root_folder_id}' in parents "
f"and trashed=false"
)
results = self._drive.files().list(q=query, fields="files(id)").execute()
files = results.get("files", [])
if not files:
raise ValueError(f"No file or folder '{path.name}' in workspace")
return files[0]["id"]
@overload
def open_file(
self,
path: str | Path,
mode: Literal["r", "w"] = "r",
binary: Literal[False] = False,
) -> io.TextIOWrapper:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["r", "w"], binary: Literal[True]
) -> io.BytesIO:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
) -> io.TextIOWrapper | io.BytesIO:
...
def open_file(
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
) -> io.TextIOWrapper | io.BytesIO:
"""Open a file in the storage."""
file_id = self._get_file_id(path)
if mode == "r":
request = self._drive.files().get_media(fileId=file_id)
fh = io.BytesIO()
downloader = MediaIoBaseDownload(fh, request)
done = False
while done is False:
_, done = downloader.next_chunk()
fh.seek(0)
return fh if binary else io.TextIOWrapper(fh)
elif mode == "w":
return io.BytesIO() if binary else io.StringIO()
@overload
def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
"""Read a file in the storage as text."""
...
@overload
def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
"""Read a file in the storage as binary."""
...
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
"""Read a file in the storage."""
with self.open_file(path, "r", binary) as f:
return f.read()
async def write_file(self, path: str | Path, content: str | bytes) -> None:
"""Write to a file in the storage."""
path = self.get_path(path)
media = MediaIoBaseUpload(
io.BytesIO(content.encode() if isinstance(content, str) else content),
mimetype="application/octet-stream",
)
file_metadata: GoogleDriveFile = {
"name": path.name,
"parents": [self._root_folder_id],
}
self._drive.files().create(
body=file_metadata,
media_body=media,
fields="id",
).execute()
if self.on_write_file:
path = Path(path)
if path.is_absolute():
path = path.relative_to(self.root)
res = self.on_write_file(path)
if inspect.isawaitable(res):
await res
def list_files(self, path: str | Path = ".") -> list[Path]:
"""List all files (recursively) in a directory in the storage."""
query = f"'{self._root_folder_id}' in parents and trashed=false"
results = self._drive.files().list(q=query, fields="files(name)").execute()
return [Path(item["name"]) for item in results.get("files", [])]
def list_folders(
self, path: str | Path = ".", recursive: bool = False
) -> list[Path]:
"""List 'directories' directly in a given path or recursively in the storage."""
query = (
f"'{self._root_folder_id}' in parents "
f"and mimeType='application/vnd.google-apps.folder' "
f"and trashed=false"
)
results = self._drive.files().list(q=query, fields="files(name)").execute()
return [Path(item["name"]) for item in results.get("files", [])]
def delete_file(self, path: str | Path) -> None:
"""Delete a file in the storage."""
file_id = self._get_file_id(path)
if file_id:
self._drive.files().delete(fileId=file_id).execute()
def delete_dir(self, path: str | Path) -> None:
"""Delete an empty folder in the storage."""
folder_id = self._get_file_id(path)
if folder_id:
self._drive.files().delete(fileId=folder_id).execute()
def exists(self, path: str | Path) -> bool:
"""Check if a file or folder exists in Google Drive storage."""
return bool(self._get_file_id(path))
def make_dir(self, path: str | Path) -> None:
"""Create a directory in the storage if doesn't exist."""
        path = self.get_path(path)
        folder_metadata: GoogleDriveFile = {
            "name": path.name,
            "mimeType": "application/vnd.google-apps.folder",
            "parents": [self._root_folder_id],
        }
        self._drive.files().create(body=folder_metadata, fields="id").execute()

    def rename(self, old_path: str | Path, new_path: str | Path) -> None:
        """Rename a file or folder in the storage."""
        file_id = self._get_file_id(old_path)
        new_path = self.get_path(new_path)
        file_metadata: GoogleDriveFile = {"name": new_path.name}
        self._drive.files().update(fileId=file_id, body=file_metadata).execute()

    def copy(self, source: str | Path, destination: str | Path) -> None:
        """Copy a file or folder with all contents in the storage."""
        file_id = self._get_file_id(source)
        destination = self.get_path(destination)
        file_metadata: GoogleDriveFile = {
            "name": destination.name,
            "parents": [self._root_folder_id],
        }
        self._drive.files().copy(fileId=file_id, body=file_metadata).execute()

    def clone_with_subroot(self, subroot: str | Path) -> GoogleDriveFileStorage:
        """Create a new GoogleDriveFileStorage with a subroot of the current storage."""
        subroot_path = self.get_path(subroot)
        subroot_id = self._get_file_id(subroot_path)
        if not subroot_id:
            raise ValueError(f"Subroot {subroot} does not exist")
        config = GoogleDriveFileStorageConfiguration(
            root=subroot_path,
            root_folder_id=subroot_id,
            credentials=self._credentials,
        )
        return GoogleDriveFileStorage(config)

    def __repr__(self) -> str:
        return f"{__class__.__name__}(root={self._root})"

View File

@@ -0,0 +1,41 @@
import { FC } from "react";

import { Button } from "@/components/ui/button";
import { useSupabase } from "@/components/SupabaseProvider";

export const LinkGoogleDriveButton: FC = () => {
  const { supabase, isLoading } = useSupabase();

  const linkGoogleDrive = async () => {
    if (isLoading || !supabase) {
      return;
    }

    const popup = window.open(
      '/api/auth/google',
      'googleOAuth',
      'width=500,height=600'
    );
    if (!popup || popup.closed || typeof popup.closed === 'undefined') {
      console.error('Popup blocked or not created');
      return;
    }

    // Polling to check when the popup is closed
    const pollTimer = window.setInterval(async () => {
      if (popup.closed) {
        window.clearInterval(pollTimer);
        // Optionally, you can refresh tokens or check the state after OAuth is done
        const { data, error } = await supabase.auth.getSession();
        if (data.session) {
          // Tokens should now be stored in the Supabase database
          console.log('Google Drive linked successfully!');
        } else if (error) {
          console.error('Error fetching session:', error);
        }
      }
    }, 500);
  };

  return <Button onClick={linkGoogleDrive} disabled={isLoading}>Link Google Drive</Button>;
}

View File

@@ -0,0 +1,72 @@
from typing import Literal

from forge.file_storage import FileStorage
from forge.file_storage.gcs import GCSFileStorage, GCSFileStorageConfiguration
from forge.file_storage.google_drive import (
    GoogleDriveFileStorage,
    GoogleDriveFileStorageConfiguration,
)

from autogpt_server.data.block import Block, BlockOutput, BlockSchema
from autogpt_server.data.model import SchemaField


class _GCSFileStorageConfig(GCSFileStorageConfiguration):
    provider: Literal["gcs"]


class _GoogleDriveFileStorageConfig(GoogleDriveFileStorageConfiguration):
    provider: Literal["google_drive"]


_FileStorageConfig = _GCSFileStorageConfig | _GoogleDriveFileStorageConfig


def _get_storage(config: _FileStorageConfig) -> FileStorage:
    if config.provider == "google_drive":
        return GoogleDriveFileStorage(config)
    if config.provider == "gcs":
        return GCSFileStorage(config)
    raise TypeError(f"Invalid storage configuration: {config}")


class ReadFileBlock(Block):
    class Input(BlockSchema):
        file_storage: _FileStorageConfig = SchemaField(
            description="Configuration for the file storage to use",
            json_schema_extra={"resource_type": "file_storage"},
        )
        path: str = SchemaField(
            description="The path of the file to read",
            placeholder="example.txt",
        )
        type: Literal["text", "bytes"] = SchemaField(
            description="The type of the file content",
            default="text",
        )

    class Output(BlockSchema):
        content: str | bytes = SchemaField(description="The content of the read file")
        length: int = SchemaField(
            description="The length/size of the file content (bytes)"
        )
        error: str = SchemaField(
            description="Any error message if the file can't be read"
        )

    def __init__(self):
        super().__init__(
            id="e58cdb7c-f2d2-42ea-8c79-d6eaabd7df3b",
            input_schema=ReadFileBlock.Input,
            output_schema=ReadFileBlock.Output,
        )

    def run(self, input_data: Input) -> BlockOutput:
        try:
            storage = _get_storage(input_data.file_storage)
            content = storage.read_file(input_data.path, input_data.type == "bytes")
            yield "content", content
            yield "length", len(content)
        except Exception as e:
            yield "error", str(e)

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from typing import Any, Callable, ClassVar, Optional, TypeVar

from pydantic import BaseModel, Field, GetCoreSchemaHandler
from pydantic.config import JsonDict
from pydantic_core import (
    CoreSchema,
    PydanticUndefined,
@@ -98,13 +99,15 @@ def SchemaField(
    placeholder: Optional[str] = None,
    secret: bool = False,
    exclude: bool = False,
    json_schema_extra: Optional[JsonDict] = None,
    **kwargs,
) -> T:
    json_extra: dict[str, Any] = {}
    if json_schema_extra is None:
        json_schema_extra = {}
    if placeholder:
        json_extra["placeholder"] = placeholder
        json_schema_extra["placeholder"] = placeholder
    if secret:
        json_extra["secret"] = True
        json_schema_extra["secret"] = True

    return Field(
        default,
@@ -113,7 +116,7 @@ def SchemaField(
        title=title,
        description=description,
        exclude=exclude,
        json_schema_extra=json_extra,
        json_schema_extra=json_schema_extra,
        **kwargs,
    )
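
The change lets callers attach arbitrary JSON-schema metadata (as the Read File block does with resource_type), with placeholder and secret merged into the same dict instead of a separate one. A minimal sketch of the effect, using a plain pydantic model:

# Illustrative only: placeholder and custom extras now land in the same json_schema_extra dict.
from pydantic import BaseModel


class Example(BaseModel):
    path: str = SchemaField(
        description="The path of the file to read",
        placeholder="example.txt",
        json_schema_extra={"resource_type": "file_storage"},
    )


print(Example.model_json_schema()["properties"]["path"])
# expected to include description, placeholder, and resource_type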

View File

@@ -0,0 +1,57 @@
from pathlib import Path
from typing import Annotated

from fastapi import APIRouter, Depends
from google.oauth2.credentials import Credentials
from google.auth.transport.requests import Request
from google_auth_oauthlib.flow import Flow
from pydantic import BaseModel

from .utils import get_user_id

integrations_api = APIRouter()


class GoogleAuthExchangeRequestBody(BaseModel):
    code: str


class GoogleAuthExchangeResponse(BaseModel):
    access_token: str
    refresh_token: str
    token_uri: str
    client_id: str
    client_secret: str
    scopes: list[str]
@integrations_api.post("/auth/google")
def exchange_google_auth_code(
body: GoogleAuthExchangeRequestBody,
user_id: Annotated[str, Depends(get_user_id)],
) -> GoogleAuthExchangeResponse:
# Set up the OAuth 2.0 flow
flow = Flow.from_client_secrets_file(
Path(__file__).parent.parent.parent / "google_client_secret.json",
scopes=[], # irrelevant since requesting scopes is done by the front end
)
# Exchange the authorization code for credentials
flow.fetch_token(code=body.code)
# Get the credentials
credentials = flow.credentials
# Refresh the token if it's expired
if credentials.expired and credentials.refresh_token:
credentials.refresh(Request())
# Return the tokens
return GoogleAuthExchangeResponse(
access_token=credentials.token,
refresh_token=credentials.refresh_token,
token_uri=credentials.token_uri,
client_id=credentials.client_id,
client_secret=credentials.client_secret,
scopes=credentials.scopes,
)
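
For reference, a hedged sketch of exercising the new endpoint once the front end has obtained an authorization code. The host, port, bearer token, and the /api/integrations mount prefix are assumptions based on the commit message "implement POST /api/integrations/auth/google", not confirmed by this diff.

# Illustrative only: URL prefix, port, and auth header are assumptions.
import requests

resp = requests.post(
    "http://localhost:8000/api/integrations/auth/google",
    json={"code": "<authorization code returned by Google's consent screen>"},
    headers={"Authorization": "Bearer <user JWT>"},  # consumed by get_user_id via auth_middleware
)
resp.raise_for_status()
tokens = resp.json()
print(tokens["access_token"], tokens["token_uri"])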

View File

@@ -30,7 +30,7 @@ from autogpt_server.data.execution import (
    get_execution_results,
    list_executions,
)
from autogpt_server.data.user import DEFAULT_USER_ID, get_or_create_user
from autogpt_server.data.user import get_or_create_user
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
from autogpt_server.server.conn_manager import ConnectionManager
from autogpt_server.server.model import (
@@ -43,20 +43,11 @@ from autogpt_server.util.lock import KeyedMutex
from autogpt_server.util.service import AppService, expose, get_service_client
from autogpt_server.util.settings import Settings
from .utils import get_user_id
settings = Settings()
def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
    if not payload:
        # This handles the case when authentication is disabled
        return DEFAULT_USER_ID

    user_id = payload.get("sub")
    if not user_id:
        raise HTTPException(status_code=401, detail="User ID not found in token")
    return user_id
class AgentServer(AppService):
    event_queue: asyncio.Queue[ExecutionResult] = asyncio.Queue()
    manager = ConnectionManager()

View File

@@ -0,0 +1,15 @@
from autogpt_libs.auth.middleware import auth_middleware
from fastapi import Depends, HTTPException

from autogpt_server.data.user import DEFAULT_USER_ID


def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
    if not payload:
        # This handles the case when authentication is disabled
        return DEFAULT_USER_ID

    user_id = payload.get("sub")
    if not user_id:
        raise HTTPException(status_code=401, detail="User ID not found in token")
    return user_id
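
With get_user_id in a shared module, any router in the package can declare it as a dependency, as the integrations router above does. A minimal sketch (the route path and response shape are illustrative only):

# Illustrative only: reusing the shared dependency in another router.
from typing import Annotated

from fastapi import APIRouter, Depends

from .utils import get_user_id

example_api = APIRouter()


@example_api.get("/whoami")
def whoami(user_id: Annotated[str, Depends(get_user_id)]) -> dict[str, str]:
    return {"user_id": user_id}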

View File

@@ -1550,6 +1550,24 @@ files = [
google-auth = "*"
httplib2 = ">=0.19.0"
[[package]]
name = "google-auth-oauthlib"
version = "1.2.1"
description = "Google Authentication Library"
optional = false
python-versions = ">=3.6"
files = [
{file = "google_auth_oauthlib-1.2.1-py2.py3-none-any.whl", hash = "sha256:2d58a27262d55aa1b87678c3ba7142a080098cbc2024f903c62355deb235d91f"},
{file = "google_auth_oauthlib-1.2.1.tar.gz", hash = "sha256:afd0cad092a2eaa53cd8e8298557d6de1034c6cb4a740500b5357b648af97263"},
]
[package.dependencies]
google-auth = ">=2.15.0"
requests-oauthlib = ">=0.7.0"
[package.extras]
tool = ["click (>=6.0.0)"]
[[package]]
name = "google-cloud-appengine-logging"
version = "1.4.4"
@@ -2420,9 +2438,6 @@ files = [
{file = "lief-0.14.1-cp312-cp312-manylinux_2_28_x86_64.manylinux_2_27_x86_64.whl", hash = "sha256:497b88f9c9aaae999766ba188744ee35c5f38b4b64016f7dbb7037e9bf325382"},
{file = "lief-0.14.1-cp312-cp312-win32.whl", hash = "sha256:08bad88083f696915f8dcda4042a3bfc514e17462924ec8984085838b2261921"},
{file = "lief-0.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:e131d6158a085f8a72124136816fefc29405c725cd3695ce22a904e471f0f815"},
{file = "lief-0.14.1-cp313-cp313-manylinux_2_28_x86_64.manylinux_2_27_x86_64.whl", hash = "sha256:f9ff9a6959fb6d0e553cca41cd1027b609d27c5073e98d9fad8b774fbb5746c2"},
{file = "lief-0.14.1-cp313-cp313-win32.whl", hash = "sha256:95f295a7cc68f4e14ce7ea4ff8082a04f5313c2e5e63cc2bbe9d059190b7e4d5"},
{file = "lief-0.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:cdc1123c2e27970f8c8353505fd578e634ab33193c8d1dff36dc159e25599a40"},
{file = "lief-0.14.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:df650fa05ca131e4dfeb42c77985e1eb239730af9944bc0aadb1dfac8576e0e8"},
{file = "lief-0.14.1-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:b4e76eeb48ca2925c6ca6034d408582615f2faa855f9bb11482e7acbdecc4803"},
{file = "lief-0.14.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:016e4fac91303466024154dd3c4b599e8b7c52882f72038b62a2be386d98c8f9"},
@@ -3479,6 +3494,8 @@ files = [
{file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:960db0e31c4e52fa0fc3ecbaea5b2d3b58f379e32a95ae6b0ebeaa25b93dfd34"},
{file = "orjson-3.10.6-cp312-none-win32.whl", hash = "sha256:a6ea7afb5b30b2317e0bee03c8d34c8181bc5a36f2afd4d0952f378972c4efd5"},
{file = "orjson-3.10.6-cp312-none-win_amd64.whl", hash = "sha256:874ce88264b7e655dde4aeaacdc8fd772a7962faadfb41abe63e2a4861abc3dc"},
{file = "orjson-3.10.6-cp313-none-win32.whl", hash = "sha256:efdf2c5cde290ae6b83095f03119bdc00303d7a03b42b16c54517baa3c4ca3d0"},
{file = "orjson-3.10.6-cp313-none-win_amd64.whl", hash = "sha256:8e190fe7888e2e4392f52cafb9626113ba135ef53aacc65cd13109eb9746c43e"},
{file = "orjson-3.10.6-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:66680eae4c4e7fc193d91cfc1353ad6d01b4801ae9b5314f17e11ba55e934183"},
{file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caff75b425db5ef8e8f23af93c80f072f97b4fb3afd4af44482905c9f588da28"},
{file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3722fddb821b6036fd2a3c814f6bd9b57a89dc6337b9924ecd614ebce3271394"},
@@ -6419,4 +6436,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "003a4c89682abbf72c67631367f57e56d91d72b44f95e972b2326440199045e7"
content-hash = "309ff5939c9f558b8fd2e834743a852e4d7c26f745368fb22dd08dfbeaa67057"

View File

@@ -19,6 +19,7 @@ pytest = "^8.2.1"
uvicorn = { extras = ["standard"], version = "^0.30.1" }
fastapi = "^0.109.0"
flake8 = "^7.0.0"
google-auth-oauthlib = "^1.2.1"
jsonschema = "^4.22.0"
psutil = "^5.9.8"
pyro5 = "^5.15"