Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-09 14:25:25 -05:00)
Compare commits: fix/vector...pwuts/open (4 commits)
Commits:
0432297e97
aa989f077f
1e2f6ea3c3
e23226f153
autogpt_platform/backend/backend/data/_fileio.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from enum import Enum
from typing import TYPE_CHECKING, Generic, TypeVar, get_args

from pydantic import BaseModel
from pydantic.config import ConfigDict, JsonDict
from pydantic_core import ValidationError

if TYPE_CHECKING:
    from backend.data.block import BlockSchema


class MIMEType(Enum):
    # Feel free to add missing MIME types as needed.
    # Just make sure not to make duplicates, and stick to the existing naming pattern.
    TEXT = "text/plain"


MT = TypeVar("MT", bound=MIMEType)


class FileMetaIO(BaseModel, Generic[MT]):
    id: str
    name: str = ""
    content_type: MT

    @classmethod
    def allowed_content_types(cls) -> tuple[MIMEType, ...]:
        return get_args(cls.model_fields["content_type"].annotation)
    @classmethod
    def validate_file_field_schema(cls, model: type["BlockSchema"]):
        """Validates the schema of a file I/O field"""
        # Find the field on `model` whose annotation is this (parameterized) class
        field_name = next(
            name
            for name, field in model.model_fields.items()
            if field.annotation is cls
        )
        field_schema = model.jsonschema()["properties"][field_name]
        try:
            _FileIOFieldSchemaExtra[MT].model_validate(field_schema)
        except ValidationError as e:
            if "Field required [type=missing" not in str(e):
                raise

            raise TypeError(
                f"Field '{field_name}' JSON schema lacks required extra items: "
                f"{field_schema}"
            ) from e

    @staticmethod
    def _add_json_schema_extra(schema: JsonDict, cls: "FileMetaIO"):
        schema["content_types"] = [ct.value for ct in cls.allowed_content_types()]
        # TODO: add file extensions?

    model_config = ConfigDict(
        json_schema_extra=_add_json_schema_extra,  # type: ignore
    )


class _FileIOFieldSchemaExtra(BaseModel, Generic[MT]):
    content_types: list[MT]
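Not part of the diff, but for context: a minimal usage sketch of the generic above. Parameterizing FileMetaIO with a Literal of allowed MIMEType members is what drives both allowed_content_types() and the content_types JSON-schema extra; the names below are hypothetical, and this assumes pydantic v2 accepts the Literal parameterization the code is designed around.

    from typing import Literal

    from backend.data._fileio import FileMetaIO, MIMEType

    # Hypothetical: a file input restricted to plain-text files
    TextFileInput = FileMetaIO[Literal[MIMEType.TEXT]]

    meta = TextFileInput(id="file-123", name="notes.txt", content_type=MIMEType.TEXT)
    print(TextFileInput.allowed_content_types())
    # expected: (<MIMEType.TEXT: 'text/plain'>,)
    print(TextFileInput.model_json_schema()["content_types"])
    # expected: ['text/plain'], injected by the json_schema_extra hook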
autogpt_platform/backend/backend/data/block.py
@@ -38,6 +38,7 @@ from backend.util.exceptions import (
 )
 from backend.util.settings import Config

+from ._fileio import FileMetaIO
 from .model import (
     ContributorDetails,
     Credentials,
@@ -252,6 +253,11 @@ class BlockSchema(BaseModel):
                     "has invalid name: must be 'credentials' or *_credentials"
                 )

+        elif FileMetaIO is get_origin(
+            field_type := cls.model_fields[field_name].annotation
+        ):
+            cast(type[FileMetaIO], field_type).validate_file_field_schema(cls)
+
     @classmethod
     def get_credentials_fields(cls) -> dict[str, type[CredentialsMetaInput]]:
         return {
autogpt_platform/backend/backend/data/files.py (new file, 128 lines)
@@ -0,0 +1,128 @@
import logging
from datetime import datetime
from typing import Sequence

import fastapi
import prisma.models
from google.cloud import storage
from pydantic import Field

from backend.data.db import BaseDbModel
from backend.util.exceptions import MissingConfigError, NotFoundError
from backend.util.settings import Settings

logger = logging.getLogger(__name__)
settings = Settings()

MAX_FILE_SIZE = 50 * 1024 * 1024  # 50 MiB


# ---------------- MODEL ---------------- #


class File(BaseDbModel):
    user_id: str

    name: str
    size: int = Field(..., description="file size in bytes")
    content_type: str = Field(..., description="MIME content type of the file")
    created_at: datetime

    @classmethod
    def from_db(cls, _file: prisma.models.File) -> "File":
        return cls(
            id=_file.id,
            user_id=_file.userID,
            name=_file.name,
            size=_file.size,
            content_type=_file.contentType,
            created_at=_file.createdAt,
        )


# ---------------- CRUD functions ---------------- #


async def list_files(user_id: str) -> Sequence[File]:
    return [
        File.from_db(f)
        for f in await prisma.models.File.prisma().find_many(where={"userID": user_id})
    ]


async def get_file(file_id: str, user_id: str) -> File:
    file = await prisma.models.File.prisma().find_first(
        where={"id": file_id, "userID": user_id}
    )
    if not file:
        raise NotFoundError(f"File #{file_id} does not exist")
    return File.from_db(file)


async def get_file_content(file_id: str, user_id: str) -> tuple[File, storage.Blob]:
    file = await get_file(file_id=file_id, user_id=user_id)

    blob = _user_file_bucket().get_blob(file.id)
    if not (blob and blob.exists()):
        logger.error(f"File #{file_id} of user #{user_id} not found in bucket")
        raise NotFoundError(f"File #{file_id} not found in storage")
    return file, blob


async def create_file(
    user_id: str, content: bytes, content_type: str, name: str = ""
) -> File:
    file = await prisma.models.File.prisma().create(
        data={
            "userID": user_id,
            "name": name,
            "size": len(content),
            "contentType": content_type,
        }
    )
    _user_file_bucket().blob(file.id).upload_from_string(
        content, content_type=content_type
    )
    return File.from_db(file)
async def create_file_from_upload(
    user_id: str, uploaded_file: fastapi.UploadFile
) -> File:
    # Validate file type
    content_type = uploaded_file.content_type
    if content_type is None:
        raise ValueError(
            "File has no type"
        )  # FIXME: graceful fallback to type detection

    # Validate file size
    if uploaded_file.size is None:
        raise ValueError("File has no size")
    if uploaded_file.size > MAX_FILE_SIZE:
        raise ValueError("File is too large: maximum size is 50 MiB")

    file = await prisma.models.File.prisma().create(
        data={
            "userID": user_id,
            "name": uploaded_file.filename or "",
            "size": uploaded_file.size,
            "contentType": content_type,
        }
    )
    # Pass the underlying synchronous file object; UploadFile.read() is async
    # and can't be consumed by the GCS client directly.
    _user_file_bucket().blob(file.id).upload_from_file(
        uploaded_file.file, content_type=content_type
    )
    return File.from_db(file)
# ---------------- UTILITIES ---------------- #


def _user_file_bucket() -> storage.Bucket:
    if not settings.secrets.user_file_gcs_bucket_name:
        raise MissingConfigError("Missing storage bucket configuration")

    # TODO: use S3 API instead to allow use of other cloud storage providers
    storage_client = storage.Client()
    return storage_client.bucket(settings.secrets.user_file_gcs_bucket_name)
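A quick end-to-end sketch of the CRUD helpers above (hypothetical, not part of the diff; assumes a configured user_file_gcs_bucket_name and an existing user record):

    import asyncio

    from backend.data import files as files_db


    async def demo(user_id: str) -> None:
        file = await files_db.create_file(
            user_id=user_id,
            content=b"hello, world",
            content_type="text/plain",
            name="hello.txt",
        )
        meta, blob = await files_db.get_file_content(file.id, user_id)
        print(meta.size)                 # 12
        print(blob.download_as_bytes())  # b'hello, world'


    # asyncio.run(demo("<existing-user-id>"))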
autogpt_platform/backend/backend/server/routers/files.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from typing import Annotated, Sequence

import fastapi

import backend.data.files as files_db
from backend.server.utils import get_user_id

files_api = fastapi.APIRouter()


@files_api.get(path="/", tags=["files"])
async def list_files(
    user_id: Annotated[str, fastapi.Depends(get_user_id)],
) -> Sequence[files_db.File]:
    return await files_db.list_files(user_id=user_id)


@files_api.post(path="/", tags=["files"])
async def upload_file(
    user_id: Annotated[str, fastapi.Depends(get_user_id)],
    file: fastapi.UploadFile,
) -> files_db.File:
    return await files_db.create_file_from_upload(user_id=user_id, uploaded_file=file)


@files_api.get(path="/{file_id}", tags=["files"])
async def get_file_meta(
    user_id: Annotated[str, fastapi.Depends(get_user_id)],
    file_id: Annotated[str, fastapi.Path()],
) -> files_db.File:
    return await files_db.get_file(user_id=user_id, file_id=file_id)
@files_api.get(path="/{file_id}/download", tags=["files"])
|
||||
async def download_file(
|
||||
user_id: Annotated[str, fastapi.Depends(get_user_id)],
|
||||
file_id: Annotated[str, fastapi.Path()],
|
||||
):
|
||||
file, blob = await files_db.get_file_content(user_id=user_id, file_id=file_id)
|
||||
return fastapi.responses.StreamingResponse(
|
||||
content=blob.open(),
|
||||
media_type=file.content_type,
|
||||
headers={"Content-Disposition": f'attachment; filename="{file.name}"'},
|
||||
)
|
||||
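A hypothetical smoke test against these routes using httpx (not part of the diff), assuming the router is mounted at /files and the usual bearer-token auth sits behind get_user_id:

    import httpx


    def upload_and_download(base_url: str, token: str) -> None:
        headers = {"Authorization": f"Bearer {token}"}
        with httpx.Client(base_url=base_url, headers=headers) as client:
            # POST /files/ with multipart form data (matches fastapi.UploadFile)
            created = client.post(
                "/files/",
                files={"file": ("notes.txt", b"some text", "text/plain")},
            )
            created.raise_for_status()
            meta = created.json()

            # GET /files/{id}/download streams the stored blob back
            download = client.get(f"/files/{meta['id']}/download")
            download.raise_for_status()
            assert download.content == b"some text"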
autogpt_platform/backend/backend/util/exceptions.py
@@ -41,7 +41,7 @@ class MissingConfigError(Exception):


 class NotFoundError(ValueError):
-    """The requested record was not found, resulting in an error condition"""
+    """The requested resource was not found in the system"""


 class GraphNotFoundError(ValueError):
autogpt_platform/backend/backend/util/settings.py
@@ -549,6 +549,11 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
         description="The secret key to use for the unsubscribe user by token",
     )

+    user_file_gcs_bucket_name: str = Field(
+        default="",
+        description="The name of the Google Cloud Storage bucket for users' files",
+    )
+
     # OAuth server credentials for integrations
     # --8<-- [start:OAuthServerCredentialsExample]
     github_client_id: str = Field(default="", description="GitHub OAuth client ID")
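A hypothetical sanity check for the new secret (assumes Secrets picks it up from the environment like the other secrets; the env var name below follows pydantic-settings' default field-name mapping and is an assumption):

    import os

    # Assumed env var name for the new secret
    os.environ["USER_FILE_GCS_BUCKET_NAME"] = "my-user-files-bucket"

    from backend.util.settings import Settings

    assert Settings().secrets.user_file_gcs_bucket_name == "my-user-files-bucket"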
autogpt_platform/backend/poetry.lock (generated, 6101 lines)
File diff suppressed because it is too large
autogpt_platform/backend/pyproject.toml
@@ -87,6 +87,7 @@ stagehand = "^0.5.1"
 aiohappyeyeballs = "^2.6.1"
 black = "^24.10.0"
+faker = "^38.2.0"
 google-api-python-client-stubs = "^1.28.0"
 httpx = "^0.28.1"
 isort = "^5.13.2"
 poethepoet = "^0.37.0"
autogpt_platform/backend/schema.prisma
@@ -50,6 +50,7 @@ model User {

   AgentPresets  AgentPreset[]
   LibraryAgents LibraryAgent[]
+  Files         File[]

   Profile        Profile[]
   UserOnboarding UserOnboarding?
@@ -458,6 +459,9 @@ model AgentNodeExecutionInputOutput {
   data Json?
   time DateTime @default(now())

+  fileID String?
+  File   File?   @relation("AgentNodeExecutionIOFile", fields: [fileID], references: [id])
+
   // Prisma requires explicit back-references.
   referencedByInputExecId String?
   ReferencedByInputExec   AgentNodeExecution? @relation("AgentNodeExecutionInput", fields: [referencedByInputExecId], references: [id], onDelete: Cascade)
@@ -519,6 +523,22 @@ model PendingHumanReview {
   @@index([graphExecId, status])
 }

+model File {
+  id     String @id @default(uuid())
+  userID String
+
+  OwnedByUser User @relation(fields: [userID], references: [id])
+
+  name        String
+  size        Int // file size in bytes
+  contentType String
+  createdAt   DateTime @default(now())
+
+  UsedInAgentNodeExecutionIO AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionIOFile")
+
+  @@index([id, userID])
+}
+
 // Webhook that is registered with a provider and propagates to one or more nodes
 model IntegrationWebhook {
   id String @id @default(uuid())
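For illustration, a hypothetical prisma-client-py query exercising the new model and relation (not part of the diff):

    import prisma.models


    async def files_with_io_usage(user_id: str):
        # Fetch a user's files together with the node-execution I/O rows
        # that reference them via the "AgentNodeExecutionIOFile" relation
        return await prisma.models.File.prisma().find_many(
            where={"userID": user_id},
            include={"UsedInAgentNodeExecutionIO": True},
        )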
autogpt_platform/frontend/.../useRunAgentInputs.ts
@@ -6,7 +6,7 @@ export function useRunAgentInputs() {
   const [uploadProgress, setUploadProgress] = useState(0);

   async function handleUploadFile(file: File) {
-    const result = await api.uploadFile(file, "gcs", 24, (progress) =>
+    const result = await api.uploadSignedFile(file, "gcs", 24, (progress) =>
       setUploadProgress(progress),
     );
     return result;
autogpt_platform/frontend/src/lib/autogpt-server-api/client.ts
@@ -18,6 +18,7 @@ import type {
   CreateAPIKeyResponse,
   CreatorDetails,
   CreatorsResponse,
+  FileMeta,
   Credentials,
   CredentialsDeleteNeedConfirmationResponse,
   CredentialsDeleteResponse,
@@ -408,6 +409,26 @@ export default class BackendAPI {
     });
   }

+  ////////////////////////////////////////
+  ///////////////// FILES ////////////////
+  ////////////////////////////////////////
+
+  listFiles(): Promise<FileMeta[]> {
+    return this._get("/files");
+  }
+
+  getFileMeta(fileID: string): Promise<FileMeta> {
+    return this._get(`/files/${fileID}`);
+  }
+
+  downloadFile(fileID: string): Promise<unknown> {
+    return this._get(`/files/${fileID}/download`);
+  }
+
+  uploadFile(file: File): Promise<FileMeta> {
+    return this._uploadFile("/files", file);
+  }
+
   /**
    * @returns `true` if a ping event was received, `false` if provider doesn't support pinging but the webhook exists.
    * @throws `Error` if the webhook does not exist.
@@ -501,7 +522,7 @@ export default class BackendAPI {
     return this._uploadFile("/store/submissions/media", file);
   }

-  uploadFile(
+  uploadSignedFile(
     file: File,
     provider: string = "gcs",
     expiration_hours: number = 24,
@@ -809,7 +830,7 @@ export default class BackendAPI {
     return session?.access_token || "no-token-found";
   }

-  private async _uploadFile(path: string, file: File): Promise<string> {
+  private async _uploadFile(path: string, file: File): Promise<any> {
     const formData = new FormData();
     formData.append("file", file);
@@ -844,7 +865,7 @@ export default class BackendAPI {
   private async _makeClientFileUpload(
     path: string,
     formData: FormData,
-  ): Promise<string> {
+  ): Promise<any> {
     // Dynamic import is required even for client-only functions because helpers.ts
     // has server-only imports (like getServerSupabase) at the top level. Static imports
     // would bundle server-only code into the client bundle, causing runtime errors.
@@ -868,7 +889,7 @@ export default class BackendAPI {
   private async _makeServerFileUpload(
     path: string,
     formData: FormData,
-  ): Promise<string> {
+  ): Promise<any> {
     const { makeAuthenticatedFileUpload, buildServerUrl } = await import(
       "./helpers"
     );
autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts
@@ -659,6 +659,23 @@ export type HostScopedCredentials = BaseCredentials & {
   headers: Record<string, string>;
 };

+/* Mirror of backend/data/files.py:File */
+export type FileMeta = {
+  id: string;
+  user_id: string;
+  name: string;
+  size: number;
+  content_type: string;
+  created_at: Date;
+};
+
+/* Mirror of backend/backend/data/_fileio.py:FileMetaIO */
+export type FileMetaIO = {
+  id: string;
+  name: string;
+  content_type: string;
+};
+
 // Mirror of backend/backend/data/notifications.py:NotificationType
 export type NotificationType =
   | "AGENT_RUN"