mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-16 17:55:55 -05:00
don't pass Prisma models through DatabaseManager
This commit is contained in:
@@ -11,7 +11,7 @@ import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
from fastapi.responses import Response
|
||||
|
||||
from backend.data.workspace import get_workspace, get_workspace_file
|
||||
from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
|
||||
@@ -44,11 +44,11 @@ router = fastapi.APIRouter(
|
||||
)
|
||||
|
||||
|
||||
def _create_streaming_response(content: bytes, file) -> Response:
|
||||
def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
|
||||
"""Create a streaming response for file content."""
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=file.mimeType,
|
||||
media_type=file.mime_type,
|
||||
headers={
|
||||
"Content-Disposition": _sanitize_filename_for_header(file.name),
|
||||
"Content-Length": str(len(content)),
|
||||
@@ -56,7 +56,7 @@ def _create_streaming_response(content: bytes, file) -> Response:
|
||||
)
|
||||
|
||||
|
||||
async def _create_file_download_response(file) -> Response:
|
||||
async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
||||
"""
|
||||
Create a download response for a workspace file.
|
||||
|
||||
@@ -66,33 +66,33 @@ async def _create_file_download_response(file) -> Response:
|
||||
storage = await get_workspace_storage()
|
||||
|
||||
# For local storage, stream the file directly
|
||||
if file.storagePath.startswith("local://"):
|
||||
content = await storage.retrieve(file.storagePath)
|
||||
if file.storage_path.startswith("local://"):
|
||||
content = await storage.retrieve(file.storage_path)
|
||||
return _create_streaming_response(content, file)
|
||||
|
||||
# For GCS, try to redirect to signed URL, fall back to streaming
|
||||
try:
|
||||
url = await storage.get_download_url(file.storagePath, expires_in=300)
|
||||
url = await storage.get_download_url(file.storage_path, expires_in=300)
|
||||
# If we got back an API path (fallback), stream directly instead
|
||||
if url.startswith("/api/"):
|
||||
content = await storage.retrieve(file.storagePath)
|
||||
content = await storage.retrieve(file.storage_path)
|
||||
return _create_streaming_response(content, file)
|
||||
return fastapi.responses.RedirectResponse(url=url, status_code=302)
|
||||
except Exception as e:
|
||||
# Log the signed URL failure with context
|
||||
logger.error(
|
||||
f"Failed to get signed URL for file {file.id} "
|
||||
f"(storagePath={file.storagePath}): {e}",
|
||||
f"(storagePath={file.storage_path}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Fall back to streaming directly from GCS
|
||||
try:
|
||||
content = await storage.retrieve(file.storagePath)
|
||||
content = await storage.retrieve(file.storage_path)
|
||||
return _create_streaming_response(content, file)
|
||||
except Exception as fallback_error:
|
||||
logger.error(
|
||||
f"Fallback streaming also failed for file {file.id} "
|
||||
f"(storagePath={file.storagePath}): {fallback_error}",
|
||||
f"(storagePath={file.storage_path}): {fallback_error}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -167,8 +167,8 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
file_id=f.id,
|
||||
name=f.name,
|
||||
path=f.path,
|
||||
mime_type=f.mimeType,
|
||||
size_bytes=f.sizeBytes,
|
||||
mime_type=f.mime_type,
|
||||
size_bytes=f.size_bytes,
|
||||
)
|
||||
for f in files
|
||||
]
|
||||
@@ -309,8 +309,8 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
target_file_id = file_info.id
|
||||
|
||||
# Decide whether to return inline content or metadata+URL
|
||||
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
|
||||
is_text_file = self._is_text_mime_type(file_info.mimeType)
|
||||
is_small_file = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
|
||||
is_text_file = self._is_text_mime_type(file_info.mime_type)
|
||||
|
||||
# Return inline content for small text files (unless force_download_url)
|
||||
if is_small_file and is_text_file and not force_download_url:
|
||||
@@ -321,7 +321,7 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type=file_info.mimeType,
|
||||
mime_type=file_info.mime_type,
|
||||
content_base64=content_b64,
|
||||
message=f"Successfully read file: {file_info.name}",
|
||||
session_id=session_id,
|
||||
@@ -350,11 +350,11 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type=file_info.mimeType,
|
||||
size_bytes=file_info.sizeBytes,
|
||||
mime_type=file_info.mime_type,
|
||||
size_bytes=file_info.size_bytes,
|
||||
download_url=download_url,
|
||||
preview=preview,
|
||||
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
|
||||
message=f"File: {file_info.name} ({file_info.size_bytes} bytes). Use download_url to retrieve content.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -500,7 +500,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
file_id=file_record.id,
|
||||
name=file_record.name,
|
||||
path=file_record.path,
|
||||
size_bytes=file_record.sizeBytes,
|
||||
size_bytes=file_record.size_bytes,
|
||||
message=f"Successfully wrote file: {file_record.name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import pydantic
|
||||
from prisma.models import UserWorkspace, UserWorkspaceFile
|
||||
from prisma.types import UserWorkspaceFileWhereInput
|
||||
|
||||
@@ -16,7 +17,61 @@ from backend.util.json import SafeJson
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_or_create_workspace(user_id: str) -> UserWorkspace:
|
||||
class Workspace(pydantic.BaseModel):
|
||||
"""Pydantic model for UserWorkspace, safe for RPC transport."""
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@staticmethod
|
||||
def from_db(workspace: "UserWorkspace") -> "Workspace":
|
||||
return Workspace(
|
||||
id=workspace.id,
|
||||
user_id=workspace.userId,
|
||||
created_at=workspace.createdAt,
|
||||
updated_at=workspace.updatedAt,
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceFile(pydantic.BaseModel):
|
||||
"""Pydantic model for UserWorkspaceFile, safe for RPC transport."""
|
||||
|
||||
id: str
|
||||
workspace_id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
name: str
|
||||
path: str
|
||||
storage_path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
checksum: Optional[str] = None
|
||||
is_deleted: bool = False
|
||||
deleted_at: Optional[datetime] = None
|
||||
metadata: dict = pydantic.Field(default_factory=dict)
|
||||
|
||||
@staticmethod
|
||||
def from_db(file: "UserWorkspaceFile") -> "WorkspaceFile":
|
||||
return WorkspaceFile(
|
||||
id=file.id,
|
||||
workspace_id=file.workspaceId,
|
||||
created_at=file.createdAt,
|
||||
updated_at=file.updatedAt,
|
||||
name=file.name,
|
||||
path=file.path,
|
||||
storage_path=file.storagePath,
|
||||
mime_type=file.mimeType,
|
||||
size_bytes=file.sizeBytes,
|
||||
checksum=file.checksum,
|
||||
is_deleted=file.isDeleted,
|
||||
deleted_at=file.deletedAt,
|
||||
metadata=file.metadata if isinstance(file.metadata, dict) else {},
|
||||
)
|
||||
|
||||
|
||||
async def get_or_create_workspace(user_id: str) -> Workspace:
|
||||
"""
|
||||
Get user's workspace, creating one if it doesn't exist.
|
||||
|
||||
@@ -27,7 +82,7 @@ async def get_or_create_workspace(user_id: str) -> UserWorkspace:
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
UserWorkspace instance
|
||||
Workspace instance
|
||||
"""
|
||||
workspace = await UserWorkspace.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
@@ -37,10 +92,10 @@ async def get_or_create_workspace(user_id: str) -> UserWorkspace:
|
||||
},
|
||||
)
|
||||
|
||||
return workspace
|
||||
return Workspace.from_db(workspace)
|
||||
|
||||
|
||||
async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
|
||||
async def get_workspace(user_id: str) -> Optional[Workspace]:
|
||||
"""
|
||||
Get user's workspace if it exists.
|
||||
|
||||
@@ -48,9 +103,10 @@ async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
UserWorkspace instance or None
|
||||
Workspace instance or None
|
||||
"""
|
||||
return await UserWorkspace.prisma().find_unique(where={"userId": user_id})
|
||||
workspace = await UserWorkspace.prisma().find_unique(where={"userId": user_id})
|
||||
return Workspace.from_db(workspace) if workspace else None
|
||||
|
||||
|
||||
async def create_workspace_file(
|
||||
@@ -63,7 +119,7 @@ async def create_workspace_file(
|
||||
size_bytes: int,
|
||||
checksum: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> UserWorkspaceFile:
|
||||
) -> WorkspaceFile:
|
||||
"""
|
||||
Create a new workspace file record.
|
||||
|
||||
@@ -79,7 +135,7 @@ async def create_workspace_file(
|
||||
metadata: Optional additional metadata
|
||||
|
||||
Returns:
|
||||
Created UserWorkspaceFile instance
|
||||
Created WorkspaceFile instance
|
||||
"""
|
||||
# Normalize path to start with /
|
||||
if not path.startswith("/"):
|
||||
@@ -103,13 +159,13 @@ async def create_workspace_file(
|
||||
f"Created workspace file {file.id} at path {path} "
|
||||
f"in workspace {workspace_id}"
|
||||
)
|
||||
return file
|
||||
return WorkspaceFile.from_db(file)
|
||||
|
||||
|
||||
async def get_workspace_file(
|
||||
file_id: str,
|
||||
workspace_id: Optional[str] = None,
|
||||
) -> Optional[UserWorkspaceFile]:
|
||||
) -> Optional[WorkspaceFile]:
|
||||
"""
|
||||
Get a workspace file by ID.
|
||||
|
||||
@@ -118,19 +174,20 @@ async def get_workspace_file(
|
||||
workspace_id: Optional workspace ID for validation
|
||||
|
||||
Returns:
|
||||
UserWorkspaceFile instance or None
|
||||
WorkspaceFile instance or None
|
||||
"""
|
||||
where_clause: dict = {"id": file_id, "isDeleted": False}
|
||||
if workspace_id:
|
||||
where_clause["workspaceId"] = workspace_id
|
||||
|
||||
return await UserWorkspaceFile.prisma().find_first(where=where_clause)
|
||||
file = await UserWorkspaceFile.prisma().find_first(where=where_clause)
|
||||
return WorkspaceFile.from_db(file) if file else None
|
||||
|
||||
|
||||
async def get_workspace_file_by_path(
|
||||
workspace_id: str,
|
||||
path: str,
|
||||
) -> Optional[UserWorkspaceFile]:
|
||||
) -> Optional[WorkspaceFile]:
|
||||
"""
|
||||
Get a workspace file by its virtual path.
|
||||
|
||||
@@ -139,19 +196,20 @@ async def get_workspace_file_by_path(
|
||||
path: Virtual path
|
||||
|
||||
Returns:
|
||||
UserWorkspaceFile instance or None
|
||||
WorkspaceFile instance or None
|
||||
"""
|
||||
# Normalize path
|
||||
if not path.startswith("/"):
|
||||
path = f"/{path}"
|
||||
|
||||
return await UserWorkspaceFile.prisma().find_first(
|
||||
file = await UserWorkspaceFile.prisma().find_first(
|
||||
where={
|
||||
"workspaceId": workspace_id,
|
||||
"path": path,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
return WorkspaceFile.from_db(file) if file else None
|
||||
|
||||
|
||||
async def list_workspace_files(
|
||||
@@ -160,7 +218,7 @@ async def list_workspace_files(
|
||||
include_deleted: bool = False,
|
||||
limit: Optional[int] = None,
|
||||
offset: int = 0,
|
||||
) -> list[UserWorkspaceFile]:
|
||||
) -> list[WorkspaceFile]:
|
||||
"""
|
||||
List files in a workspace.
|
||||
|
||||
@@ -172,7 +230,7 @@ async def list_workspace_files(
|
||||
offset: Number of files to skip
|
||||
|
||||
Returns:
|
||||
List of UserWorkspaceFile instances
|
||||
List of WorkspaceFile instances
|
||||
"""
|
||||
where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
|
||||
|
||||
@@ -185,12 +243,13 @@ async def list_workspace_files(
|
||||
path_prefix = f"/{path_prefix}"
|
||||
where_clause["path"] = {"startswith": path_prefix}
|
||||
|
||||
return await UserWorkspaceFile.prisma().find_many(
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where=where_clause,
|
||||
order={"createdAt": "desc"},
|
||||
take=limit,
|
||||
skip=offset,
|
||||
)
|
||||
return [WorkspaceFile.from_db(f) for f in files]
|
||||
|
||||
|
||||
async def count_workspace_files(
|
||||
@@ -225,7 +284,7 @@ async def count_workspace_files(
|
||||
async def soft_delete_workspace_file(
|
||||
file_id: str,
|
||||
workspace_id: Optional[str] = None,
|
||||
) -> Optional[UserWorkspaceFile]:
|
||||
) -> Optional[WorkspaceFile]:
|
||||
"""
|
||||
Soft-delete a workspace file.
|
||||
|
||||
@@ -237,7 +296,7 @@ async def soft_delete_workspace_file(
|
||||
workspace_id: Optional workspace ID for validation
|
||||
|
||||
Returns:
|
||||
Updated UserWorkspaceFile instance or None if not found
|
||||
Updated WorkspaceFile instance or None if not found
|
||||
"""
|
||||
# First verify the file exists and belongs to workspace
|
||||
file = await get_workspace_file(file_id, workspace_id)
|
||||
@@ -259,7 +318,7 @@ async def soft_delete_workspace_file(
|
||||
)
|
||||
|
||||
logger.info(f"Soft-deleted workspace file {file_id}")
|
||||
return updated
|
||||
return WorkspaceFile.from_db(updated) if updated else None
|
||||
|
||||
|
||||
async def get_workspace_total_size(workspace_id: str) -> int:
|
||||
@@ -273,4 +332,4 @@ async def get_workspace_total_size(workspace_id: str) -> int:
|
||||
Total size in bytes
|
||||
"""
|
||||
files = await list_workspace_files(workspace_id)
|
||||
return sum(file.sizeBytes for file in files)
|
||||
return sum(file.size_bytes for file in files)
|
||||
|
||||
@@ -383,7 +383,7 @@ async def store_media_file(
|
||||
else:
|
||||
info = await workspace_manager.get_file_info(ws.file_ref)
|
||||
if info:
|
||||
return MediaFileType(f"{file}#{info.mimeType}")
|
||||
return MediaFileType(f"{file}#{info.mime_type}")
|
||||
except Exception:
|
||||
pass
|
||||
return MediaFileType(file)
|
||||
@@ -397,7 +397,7 @@ async def store_media_file(
|
||||
filename=filename,
|
||||
overwrite=True,
|
||||
)
|
||||
return MediaFileType(f"workspace://{file_record.id}#{file_record.mimeType}")
|
||||
return MediaFileType(f"workspace://{file_record.id}#{file_record.mime_type}")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid return_format: {return_format}")
|
||||
|
||||
@@ -11,9 +11,9 @@ import uuid
|
||||
from typing import Optional
|
||||
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import UserWorkspaceFile
|
||||
|
||||
from backend.data.workspace import (
|
||||
WorkspaceFile,
|
||||
count_workspace_files,
|
||||
create_workspace_file,
|
||||
get_workspace_file,
|
||||
@@ -131,7 +131,7 @@ class WorkspaceManager:
|
||||
raise FileNotFoundError(f"File not found at path: {resolved_path}")
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
return await storage.retrieve(file.storagePath)
|
||||
return await storage.retrieve(file.storage_path)
|
||||
|
||||
async def read_file_by_id(self, file_id: str) -> bytes:
|
||||
"""
|
||||
@@ -151,7 +151,7 @@ class WorkspaceManager:
|
||||
raise FileNotFoundError(f"File not found: {file_id}")
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
return await storage.retrieve(file.storagePath)
|
||||
return await storage.retrieve(file.storage_path)
|
||||
|
||||
async def write_file(
|
||||
self,
|
||||
@@ -160,7 +160,7 @@ class WorkspaceManager:
|
||||
path: Optional[str] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
overwrite: bool = False,
|
||||
) -> UserWorkspaceFile:
|
||||
) -> WorkspaceFile:
|
||||
"""
|
||||
Write file to workspace.
|
||||
|
||||
@@ -175,7 +175,7 @@ class WorkspaceManager:
|
||||
overwrite: Whether to overwrite existing file at path
|
||||
|
||||
Returns:
|
||||
Created UserWorkspaceFile instance
|
||||
Created WorkspaceFile instance
|
||||
|
||||
Raises:
|
||||
ValueError: If file exceeds size limit or path already exists
|
||||
@@ -296,7 +296,7 @@ class WorkspaceManager:
|
||||
limit: Optional[int] = None,
|
||||
offset: int = 0,
|
||||
include_all_sessions: bool = False,
|
||||
) -> list[UserWorkspaceFile]:
|
||||
) -> list[WorkspaceFile]:
|
||||
"""
|
||||
List files in workspace.
|
||||
|
||||
@@ -311,7 +311,7 @@ class WorkspaceManager:
|
||||
If False (default), only list current session's files.
|
||||
|
||||
Returns:
|
||||
List of UserWorkspaceFile instances
|
||||
List of WorkspaceFile instances
|
||||
"""
|
||||
effective_path = self._get_effective_path(path, include_all_sessions)
|
||||
|
||||
@@ -339,7 +339,7 @@ class WorkspaceManager:
|
||||
# Delete from storage
|
||||
storage = await get_workspace_storage()
|
||||
try:
|
||||
await storage.delete(file.storagePath)
|
||||
await storage.delete(file.storage_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete file from storage: {e}")
|
||||
# Continue with database soft-delete even if storage delete fails
|
||||
@@ -367,9 +367,9 @@ class WorkspaceManager:
|
||||
raise FileNotFoundError(f"File not found: {file_id}")
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
return await storage.get_download_url(file.storagePath, expires_in)
|
||||
return await storage.get_download_url(file.storage_path, expires_in)
|
||||
|
||||
async def get_file_info(self, file_id: str) -> Optional[UserWorkspaceFile]:
|
||||
async def get_file_info(self, file_id: str) -> Optional[WorkspaceFile]:
|
||||
"""
|
||||
Get file metadata.
|
||||
|
||||
@@ -377,11 +377,11 @@ class WorkspaceManager:
|
||||
file_id: The file's ID
|
||||
|
||||
Returns:
|
||||
UserWorkspaceFile instance or None
|
||||
WorkspaceFile instance or None
|
||||
"""
|
||||
return await get_workspace_file(file_id, self.workspace_id)
|
||||
|
||||
async def get_file_info_by_path(self, path: str) -> Optional[UserWorkspaceFile]:
|
||||
async def get_file_info_by_path(self, path: str) -> Optional[WorkspaceFile]:
|
||||
"""
|
||||
Get file metadata by path.
|
||||
|
||||
@@ -392,7 +392,7 @@ class WorkspaceManager:
|
||||
path: Virtual path
|
||||
|
||||
Returns:
|
||||
UserWorkspaceFile instance or None
|
||||
WorkspaceFile instance or None
|
||||
"""
|
||||
resolved_path = self._resolve_path(path)
|
||||
return await get_workspace_file_by_path(self.workspace_id, resolved_path)
|
||||
|
||||
Reference in New Issue
Block a user