don't pass Prisma models through DatabaseManager

This commit is contained in:
Reinier van der Leer
2026-02-16 21:31:38 +01:00
parent 639e4d6fd3
commit 8cb5753aa3
5 changed files with 116 additions and 57 deletions

View File

@@ -11,7 +11,7 @@ import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi.responses import Response
from backend.data.workspace import get_workspace, get_workspace_file
from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file
from backend.util.workspace_storage import get_workspace_storage
@@ -44,11 +44,11 @@ router = fastapi.APIRouter(
)
def _create_streaming_response(content: bytes, file) -> Response:
def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
"""Create a streaming response for file content."""
return Response(
content=content,
media_type=file.mimeType,
media_type=file.mime_type,
headers={
"Content-Disposition": _sanitize_filename_for_header(file.name),
"Content-Length": str(len(content)),
@@ -56,7 +56,7 @@ def _create_streaming_response(content: bytes, file) -> Response:
)
async def _create_file_download_response(file) -> Response:
async def _create_file_download_response(file: WorkspaceFile) -> Response:
"""
Create a download response for a workspace file.
@@ -66,33 +66,33 @@ async def _create_file_download_response(file) -> Response:
storage = await get_workspace_storage()
# For local storage, stream the file directly
if file.storagePath.startswith("local://"):
content = await storage.retrieve(file.storagePath)
if file.storage_path.startswith("local://"):
content = await storage.retrieve(file.storage_path)
return _create_streaming_response(content, file)
# For GCS, try to redirect to signed URL, fall back to streaming
try:
url = await storage.get_download_url(file.storagePath, expires_in=300)
url = await storage.get_download_url(file.storage_path, expires_in=300)
# If we got back an API path (fallback), stream directly instead
if url.startswith("/api/"):
content = await storage.retrieve(file.storagePath)
content = await storage.retrieve(file.storage_path)
return _create_streaming_response(content, file)
return fastapi.responses.RedirectResponse(url=url, status_code=302)
except Exception as e:
# Log the signed URL failure with context
logger.error(
f"Failed to get signed URL for file {file.id} "
f"(storagePath={file.storagePath}): {e}",
f"(storagePath={file.storage_path}): {e}",
exc_info=True,
)
# Fall back to streaming directly from GCS
try:
content = await storage.retrieve(file.storagePath)
content = await storage.retrieve(file.storage_path)
return _create_streaming_response(content, file)
except Exception as fallback_error:
logger.error(
f"Fallback streaming also failed for file {file.id} "
f"(storagePath={file.storagePath}): {fallback_error}",
f"(storagePath={file.storage_path}): {fallback_error}",
exc_info=True,
)
raise

View File

@@ -167,8 +167,8 @@ class ListWorkspaceFilesTool(BaseTool):
file_id=f.id,
name=f.name,
path=f.path,
mime_type=f.mimeType,
size_bytes=f.sizeBytes,
mime_type=f.mime_type,
size_bytes=f.size_bytes,
)
for f in files
]
@@ -309,8 +309,8 @@ class ReadWorkspaceFileTool(BaseTool):
target_file_id = file_info.id
# Decide whether to return inline content or metadata+URL
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
is_text_file = self._is_text_mime_type(file_info.mimeType)
is_small_file = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
is_text_file = self._is_text_mime_type(file_info.mime_type)
# Return inline content for small text files (unless force_download_url)
if is_small_file and is_text_file and not force_download_url:
@@ -321,7 +321,7 @@ class ReadWorkspaceFileTool(BaseTool):
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
mime_type=file_info.mime_type,
content_base64=content_b64,
message=f"Successfully read file: {file_info.name}",
session_id=session_id,
@@ -350,11 +350,11 @@ class ReadWorkspaceFileTool(BaseTool):
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
size_bytes=file_info.sizeBytes,
mime_type=file_info.mime_type,
size_bytes=file_info.size_bytes,
download_url=download_url,
preview=preview,
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
message=f"File: {file_info.name} ({file_info.size_bytes} bytes). Use download_url to retrieve content.",
session_id=session_id,
)
@@ -500,7 +500,7 @@ class WriteWorkspaceFileTool(BaseTool):
file_id=file_record.id,
name=file_record.name,
path=file_record.path,
size_bytes=file_record.sizeBytes,
size_bytes=file_record.size_bytes,
message=f"Successfully wrote file: {file_record.name}",
session_id=session_id,
)

View File

@@ -8,6 +8,7 @@ import logging
from datetime import datetime, timezone
from typing import Optional
import pydantic
from prisma.models import UserWorkspace, UserWorkspaceFile
from prisma.types import UserWorkspaceFileWhereInput
@@ -16,7 +17,61 @@ from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
async def get_or_create_workspace(user_id: str) -> UserWorkspace:
class Workspace(pydantic.BaseModel):
"""Pydantic model for UserWorkspace, safe for RPC transport."""
id: str
user_id: str
created_at: datetime
updated_at: datetime
@staticmethod
def from_db(workspace: "UserWorkspace") -> "Workspace":
return Workspace(
id=workspace.id,
user_id=workspace.userId,
created_at=workspace.createdAt,
updated_at=workspace.updatedAt,
)
class WorkspaceFile(pydantic.BaseModel):
"""Pydantic model for UserWorkspaceFile, safe for RPC transport."""
id: str
workspace_id: str
created_at: datetime
updated_at: datetime
name: str
path: str
storage_path: str
mime_type: str
size_bytes: int
checksum: Optional[str] = None
is_deleted: bool = False
deleted_at: Optional[datetime] = None
metadata: dict = pydantic.Field(default_factory=dict)
@staticmethod
def from_db(file: "UserWorkspaceFile") -> "WorkspaceFile":
return WorkspaceFile(
id=file.id,
workspace_id=file.workspaceId,
created_at=file.createdAt,
updated_at=file.updatedAt,
name=file.name,
path=file.path,
storage_path=file.storagePath,
mime_type=file.mimeType,
size_bytes=file.sizeBytes,
checksum=file.checksum,
is_deleted=file.isDeleted,
deleted_at=file.deletedAt,
metadata=file.metadata if isinstance(file.metadata, dict) else {},
)
async def get_or_create_workspace(user_id: str) -> Workspace:
"""
Get user's workspace, creating one if it doesn't exist.
@@ -27,7 +82,7 @@ async def get_or_create_workspace(user_id: str) -> UserWorkspace:
user_id: The user's ID
Returns:
UserWorkspace instance
Workspace instance
"""
workspace = await UserWorkspace.prisma().upsert(
where={"userId": user_id},
@@ -37,10 +92,10 @@ async def get_or_create_workspace(user_id: str) -> UserWorkspace:
},
)
return workspace
return Workspace.from_db(workspace)
async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
async def get_workspace(user_id: str) -> Optional[Workspace]:
"""
Get user's workspace if it exists.
@@ -48,9 +103,10 @@ async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
user_id: The user's ID
Returns:
UserWorkspace instance or None
Workspace instance or None
"""
return await UserWorkspace.prisma().find_unique(where={"userId": user_id})
workspace = await UserWorkspace.prisma().find_unique(where={"userId": user_id})
return Workspace.from_db(workspace) if workspace else None
async def create_workspace_file(
@@ -63,7 +119,7 @@ async def create_workspace_file(
size_bytes: int,
checksum: Optional[str] = None,
metadata: Optional[dict] = None,
) -> UserWorkspaceFile:
) -> WorkspaceFile:
"""
Create a new workspace file record.
@@ -79,7 +135,7 @@ async def create_workspace_file(
metadata: Optional additional metadata
Returns:
Created UserWorkspaceFile instance
Created WorkspaceFile instance
"""
# Normalize path to start with /
if not path.startswith("/"):
@@ -103,13 +159,13 @@ async def create_workspace_file(
f"Created workspace file {file.id} at path {path} "
f"in workspace {workspace_id}"
)
return file
return WorkspaceFile.from_db(file)
async def get_workspace_file(
file_id: str,
workspace_id: Optional[str] = None,
) -> Optional[UserWorkspaceFile]:
) -> Optional[WorkspaceFile]:
"""
Get a workspace file by ID.
@@ -118,19 +174,20 @@ async def get_workspace_file(
workspace_id: Optional workspace ID for validation
Returns:
UserWorkspaceFile instance or None
WorkspaceFile instance or None
"""
where_clause: dict = {"id": file_id, "isDeleted": False}
if workspace_id:
where_clause["workspaceId"] = workspace_id
return await UserWorkspaceFile.prisma().find_first(where=where_clause)
file = await UserWorkspaceFile.prisma().find_first(where=where_clause)
return WorkspaceFile.from_db(file) if file else None
async def get_workspace_file_by_path(
workspace_id: str,
path: str,
) -> Optional[UserWorkspaceFile]:
) -> Optional[WorkspaceFile]:
"""
Get a workspace file by its virtual path.
@@ -139,19 +196,20 @@ async def get_workspace_file_by_path(
path: Virtual path
Returns:
UserWorkspaceFile instance or None
WorkspaceFile instance or None
"""
# Normalize path
if not path.startswith("/"):
path = f"/{path}"
return await UserWorkspaceFile.prisma().find_first(
file = await UserWorkspaceFile.prisma().find_first(
where={
"workspaceId": workspace_id,
"path": path,
"isDeleted": False,
}
)
return WorkspaceFile.from_db(file) if file else None
async def list_workspace_files(
@@ -160,7 +218,7 @@ async def list_workspace_files(
include_deleted: bool = False,
limit: Optional[int] = None,
offset: int = 0,
) -> list[UserWorkspaceFile]:
) -> list[WorkspaceFile]:
"""
List files in a workspace.
@@ -172,7 +230,7 @@ async def list_workspace_files(
offset: Number of files to skip
Returns:
List of UserWorkspaceFile instances
List of WorkspaceFile instances
"""
where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
@@ -185,12 +243,13 @@ async def list_workspace_files(
path_prefix = f"/{path_prefix}"
where_clause["path"] = {"startswith": path_prefix}
return await UserWorkspaceFile.prisma().find_many(
files = await UserWorkspaceFile.prisma().find_many(
where=where_clause,
order={"createdAt": "desc"},
take=limit,
skip=offset,
)
return [WorkspaceFile.from_db(f) for f in files]
async def count_workspace_files(
@@ -225,7 +284,7 @@ async def count_workspace_files(
async def soft_delete_workspace_file(
file_id: str,
workspace_id: Optional[str] = None,
) -> Optional[UserWorkspaceFile]:
) -> Optional[WorkspaceFile]:
"""
Soft-delete a workspace file.
@@ -237,7 +296,7 @@ async def soft_delete_workspace_file(
workspace_id: Optional workspace ID for validation
Returns:
Updated UserWorkspaceFile instance or None if not found
Updated WorkspaceFile instance or None if not found
"""
# First verify the file exists and belongs to workspace
file = await get_workspace_file(file_id, workspace_id)
@@ -259,7 +318,7 @@ async def soft_delete_workspace_file(
)
logger.info(f"Soft-deleted workspace file {file_id}")
return updated
return WorkspaceFile.from_db(updated) if updated else None
async def get_workspace_total_size(workspace_id: str) -> int:
@@ -273,4 +332,4 @@ async def get_workspace_total_size(workspace_id: str) -> int:
Total size in bytes
"""
files = await list_workspace_files(workspace_id)
return sum(file.sizeBytes for file in files)
return sum(file.size_bytes for file in files)

View File

@@ -383,7 +383,7 @@ async def store_media_file(
else:
info = await workspace_manager.get_file_info(ws.file_ref)
if info:
return MediaFileType(f"{file}#{info.mimeType}")
return MediaFileType(f"{file}#{info.mime_type}")
except Exception:
pass
return MediaFileType(file)
@@ -397,7 +397,7 @@ async def store_media_file(
filename=filename,
overwrite=True,
)
return MediaFileType(f"workspace://{file_record.id}#{file_record.mimeType}")
return MediaFileType(f"workspace://{file_record.id}#{file_record.mime_type}")
else:
raise ValueError(f"Invalid return_format: {return_format}")

View File

@@ -11,9 +11,9 @@ import uuid
from typing import Optional
from prisma.errors import UniqueViolationError
from prisma.models import UserWorkspaceFile
from backend.data.workspace import (
WorkspaceFile,
count_workspace_files,
create_workspace_file,
get_workspace_file,
@@ -131,7 +131,7 @@ class WorkspaceManager:
raise FileNotFoundError(f"File not found at path: {resolved_path}")
storage = await get_workspace_storage()
return await storage.retrieve(file.storagePath)
return await storage.retrieve(file.storage_path)
async def read_file_by_id(self, file_id: str) -> bytes:
"""
@@ -151,7 +151,7 @@ class WorkspaceManager:
raise FileNotFoundError(f"File not found: {file_id}")
storage = await get_workspace_storage()
return await storage.retrieve(file.storagePath)
return await storage.retrieve(file.storage_path)
async def write_file(
self,
@@ -160,7 +160,7 @@ class WorkspaceManager:
path: Optional[str] = None,
mime_type: Optional[str] = None,
overwrite: bool = False,
) -> UserWorkspaceFile:
) -> WorkspaceFile:
"""
Write file to workspace.
@@ -175,7 +175,7 @@ class WorkspaceManager:
overwrite: Whether to overwrite existing file at path
Returns:
Created UserWorkspaceFile instance
Created WorkspaceFile instance
Raises:
ValueError: If file exceeds size limit or path already exists
@@ -296,7 +296,7 @@ class WorkspaceManager:
limit: Optional[int] = None,
offset: int = 0,
include_all_sessions: bool = False,
) -> list[UserWorkspaceFile]:
) -> list[WorkspaceFile]:
"""
List files in workspace.
@@ -311,7 +311,7 @@ class WorkspaceManager:
If False (default), only list current session's files.
Returns:
List of UserWorkspaceFile instances
List of WorkspaceFile instances
"""
effective_path = self._get_effective_path(path, include_all_sessions)
@@ -339,7 +339,7 @@ class WorkspaceManager:
# Delete from storage
storage = await get_workspace_storage()
try:
await storage.delete(file.storagePath)
await storage.delete(file.storage_path)
except Exception as e:
logger.warning(f"Failed to delete file from storage: {e}")
# Continue with database soft-delete even if storage delete fails
@@ -367,9 +367,9 @@ class WorkspaceManager:
raise FileNotFoundError(f"File not found: {file_id}")
storage = await get_workspace_storage()
return await storage.get_download_url(file.storagePath, expires_in)
return await storage.get_download_url(file.storage_path, expires_in)
async def get_file_info(self, file_id: str) -> Optional[UserWorkspaceFile]:
async def get_file_info(self, file_id: str) -> Optional[WorkspaceFile]:
"""
Get file metadata.
@@ -377,11 +377,11 @@ class WorkspaceManager:
file_id: The file's ID
Returns:
UserWorkspaceFile instance or None
WorkspaceFile instance or None
"""
return await get_workspace_file(file_id, self.workspace_id)
async def get_file_info_by_path(self, path: str) -> Optional[UserWorkspaceFile]:
async def get_file_info_by_path(self, path: str) -> Optional[WorkspaceFile]:
"""
Get file metadata by path.
@@ -392,7 +392,7 @@ class WorkspaceManager:
path: Virtual path
Returns:
UserWorkspaceFile instance or None
WorkspaceFile instance or None
"""
resolved_path = self._resolve_path(path)
return await get_workspace_file_by_path(self.workspace_id, resolved_path)