From 9da0dd111f3540d62a9da3c60bf0518969da2cac Mon Sep 17 00:00:00 2001 From: majdyz Date: Sat, 11 Apr 2026 00:31:03 +0700 Subject: [PATCH] refactor(copilot): extract shared file-ID sanitization helper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract `_resolve_workspace_files(user_id, file_ids)` helper from the duplicated UUID-filter + workspace-DB-lookup logic in both `stream_chat_post` and `queue_pending_message`. Both endpoints now call the single helper; callers map the returned `list[UserWorkspaceFile]` to IDs or file-description strings as before. Also removes the redundant `if user_id:` guard from `stream_chat_post`'s file-ID block — `Security(auth.get_user_id)` guarantees a non-empty string. Addresses autogpt-pr-reviewer "Should Fix: Duplicated file-ID sanitization" and coderabbitai nit on the if user_id guard. --- .../backend/api/features/chat/routes.py | 94 ++++++++++--------- 1 file changed, 49 insertions(+), 45 deletions(-) diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index d7ebe04507..11d9ebf90f 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -102,6 +102,29 @@ async def _validate_and_get_session( return session +async def _resolve_workspace_files( + user_id: str, + file_ids: list[str], +) -> list[UserWorkspaceFile]: + """Filter *file_ids* to UUID-valid entries that exist in the caller's workspace. + + Returns the matching ``UserWorkspaceFile`` records (empty list if none pass). + Used by both the stream and pending-message endpoints to prevent callers from + referencing other users' files. + """ + valid_ids = [fid for fid in file_ids if _UUID_RE.match(fid)] + if not valid_ids: + return [] + workspace = await get_or_create_workspace(user_id) + return await UserWorkspaceFile.prisma().find_many( + where={ + "id": {"in": valid_ids}, + "workspaceId": workspace.id, + "isDeleted": False, + } + ) + + router = APIRouter( tags=["chat"], ) @@ -850,33 +873,21 @@ async def stream_chat_post( # Also sanitise file_ids so only validated, workspace-scoped IDs are # forwarded downstream (e.g. to the executor via enqueue_copilot_turn). sanitized_file_ids: list[str] | None = None - if request.file_ids and user_id: - # Filter to valid UUIDs only to prevent DB abuse - valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)] - - if valid_ids: - workspace = await get_or_create_workspace(user_id) - # Batch query instead of N+1 - files = await UserWorkspaceFile.prisma().find_many( - where={ - "id": {"in": valid_ids}, - "workspaceId": workspace.id, - "isDeleted": False, - } + if request.file_ids: + files = await _resolve_workspace_files(user_id, request.file_ids) + # Only keep IDs that actually exist in the user's workspace + sanitized_file_ids = [wf.id for wf in files] or None + file_lines: list[str] = [ + f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}" + for wf in files + ] + if file_lines: + files_block = ( + "\n\n[Attached files]\n" + + "\n".join(file_lines) + + "\nUse read_workspace_file with the file_id to access file contents." ) - # Only keep IDs that actually exist in the user's workspace - sanitized_file_ids = [wf.id for wf in files] or None - file_lines: list[str] = [ - f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}" - for wf in files - ] - if file_lines: - files_block = ( - "\n\n[Attached files]\n" - + "\n".join(file_lines) - + "\nUse read_workspace_file with the file_id to access file contents." - ) - request.message += files_block + request.message += files_block # Atomically append user message to session BEFORE creating task to avoid # race condition where GET_SESSION sees task as "running" but message isn't @@ -1128,28 +1139,21 @@ async def queue_pending_message( message_length=len(request.message), ) - # Sanitise file IDs to the user's own workspace (same logic as - # stream_chat_post) so injection doesn't surface other users' files. + # Sanitise file IDs to the user's own workspace so injection doesn't + # surface other users' files. _resolve_workspace_files handles UUID + # filtering and the workspace-scoped DB lookup. sanitized_file_ids: list[str] = [] if request.file_ids: - valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)] - if valid_ids: - workspace = await get_or_create_workspace(user_id) - files = await UserWorkspaceFile.prisma().find_many( - where={ - "id": {"in": valid_ids}, - "workspaceId": workspace.id, - "isDeleted": False, - } + valid_id_count = sum(1 for fid in request.file_ids if _UUID_RE.match(fid)) + files = await _resolve_workspace_files(user_id, request.file_ids) + sanitized_file_ids = [wf.id for wf in files] + if len(sanitized_file_ids) != valid_id_count: + logger.warning( + "queue_pending_message: dropped %d file id(s) not in " + "caller's workspace (session=%s)", + valid_id_count - len(sanitized_file_ids), + session_id, ) - sanitized_file_ids = [wf.id for wf in files] - if len(sanitized_file_ids) != len(valid_ids): - logger.warning( - "queue_pending_message: dropped %d file id(s) not in " - "caller's workspace (session=%s)", - len(valid_ids) - len(sanitized_file_ids), - session_id, - ) # Redis is the single source of truth for pending messages. We do # NOT persist to ``session.messages`` here — the drain-at-start