async def _resolve_workspace_files(
    user_id: str | None,
    file_ids: list[str],
) -> list[UserWorkspaceFile]:
    """Return the caller's own workspace files matching *file_ids*.

    Shared by the stream and pending-message endpoints so that a caller can
    never reference a file that lives in another user's workspace.

    Args:
        user_id: Owner whose workspace scopes the lookup. A missing or empty
            value yields an empty result without any workspace lookup.
        file_ids: Untrusted IDs supplied by the client.

    Returns:
        The ``UserWorkspaceFile`` records whose ID matches ``_UUID_RE``,
        whose ``workspaceId`` is the caller's workspace and whose
        ``isDeleted`` flag is ``False``. Empty list if none pass.
    """
    # Callers used to guard on ``user_id`` themselves; keep that safety here
    # so a request without a user never reaches get_or_create_workspace.
    if not user_id or not file_ids:
        return []

    # Drop malformed IDs before hitting the DB and collapse duplicates
    # (order-preserving) so the ``in`` filter carries each ID only once.
    valid_ids = list(
        dict.fromkeys(fid for fid in file_ids if _UUID_RE.match(fid))
    )
    if not valid_ids:
        return []

    workspace = await get_or_create_workspace(user_id)
    # One batched query, scoped to the caller's workspace.
    return await UserWorkspaceFile.prisma().find_many(
        where={
            "id": {"in": valid_ids},
            "workspaceId": workspace.id,
            "isDeleted": False,
        }
    )
sanitized_file_ids: list[str] | None = None - if request.file_ids and user_id: - # Filter to valid UUIDs only to prevent DB abuse - valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)] - - if valid_ids: - workspace = await get_or_create_workspace(user_id) - # Batch query instead of N+1 - files = await UserWorkspaceFile.prisma().find_many( - where={ - "id": {"in": valid_ids}, - "workspaceId": workspace.id, - "isDeleted": False, - } + if request.file_ids: + files = await _resolve_workspace_files(user_id, request.file_ids) + # Only keep IDs that actually exist in the user's workspace + sanitized_file_ids = [wf.id for wf in files] or None + file_lines: list[str] = [ + f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}" + for wf in files + ] + if file_lines: + files_block = ( + "\n\n[Attached files]\n" + + "\n".join(file_lines) + + "\nUse read_workspace_file with the file_id to access file contents." ) - # Only keep IDs that actually exist in the user's workspace - sanitized_file_ids = [wf.id for wf in files] or None - file_lines: list[str] = [ - f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}" - for wf in files - ] - if file_lines: - files_block = ( - "\n\n[Attached files]\n" - + "\n".join(file_lines) - + "\nUse read_workspace_file with the file_id to access file contents." - ) - request.message += files_block + request.message += files_block # Atomically append user message to session BEFORE creating task to avoid # race condition where GET_SESSION sees task as "running" but message isn't @@ -1128,28 +1139,21 @@ async def queue_pending_message( message_length=len(request.message), ) - # Sanitise file IDs to the user's own workspace (same logic as - # stream_chat_post) so injection doesn't surface other users' files. + # Sanitise file IDs to the user's own workspace so injection doesn't + # surface other users' files. 
_resolve_workspace_files handles UUID + # filtering and the workspace-scoped DB lookup. sanitized_file_ids: list[str] = [] if request.file_ids: - valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)] - if valid_ids: - workspace = await get_or_create_workspace(user_id) - files = await UserWorkspaceFile.prisma().find_many( - where={ - "id": {"in": valid_ids}, - "workspaceId": workspace.id, - "isDeleted": False, - } + valid_id_count = sum(1 for fid in request.file_ids if _UUID_RE.match(fid)) + files = await _resolve_workspace_files(user_id, request.file_ids) + sanitized_file_ids = [wf.id for wf in files] + if len(sanitized_file_ids) != valid_id_count: + logger.warning( + "queue_pending_message: dropped %d file id(s) not in " + "caller's workspace (session=%s)", + valid_id_count - len(sanitized_file_ids), + session_id, ) - sanitized_file_ids = [wf.id for wf in files] - if len(sanitized_file_ids) != len(valid_ids): - logger.warning( - "queue_pending_message: dropped %d file id(s) not in " - "caller's workspace (session=%s)", - len(valid_ids) - len(sanitized_file_ids), - session_id, - ) # Redis is the single source of truth for pending messages. We do # NOT persist to ``session.messages`` here — the drain-at-start