refactor(copilot): extract shared file-ID sanitization helper

Extract `_resolve_workspace_files(user_id, file_ids)` helper from the
duplicated UUID-filter + workspace-DB-lookup logic in both
`stream_chat_post` and `queue_pending_message`.

Both endpoints now call the single helper; callers map the returned
`list[UserWorkspaceFile]` to IDs or file-description strings as before.

Also removes the redundant `if user_id:` guard from `stream_chat_post`'s
file-ID block — `Security(auth.get_user_id)` guarantees a non-empty string.

Addresses autogpt-pr-reviewer "Should Fix: Duplicated file-ID sanitization"
and coderabbitai nit on the if user_id guard.
This commit is contained in:
majdyz
2026-04-11 00:31:03 +07:00
parent 3ef24b3234
commit 9da0dd111f

View File

@@ -102,6 +102,29 @@ async def _validate_and_get_session(
return session
async def _resolve_workspace_files(
user_id: str,
file_ids: list[str],
) -> list[UserWorkspaceFile]:
"""Filter *file_ids* to UUID-valid entries that exist in the caller's workspace.
Returns the matching ``UserWorkspaceFile`` records (empty list if none pass).
Used by both the stream and pending-message endpoints to prevent callers from
referencing other users' files.
"""
valid_ids = [fid for fid in file_ids if _UUID_RE.match(fid)]
if not valid_ids:
return []
workspace = await get_or_create_workspace(user_id)
return await UserWorkspaceFile.prisma().find_many(
where={
"id": {"in": valid_ids},
"workspaceId": workspace.id,
"isDeleted": False,
}
)
router = APIRouter(
tags=["chat"],
)
@@ -850,33 +873,21 @@ async def stream_chat_post(
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
sanitized_file_ids: list[str] | None = None
if request.file_ids and user_id:
# Filter to valid UUIDs only to prevent DB abuse
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
if valid_ids:
workspace = await get_or_create_workspace(user_id)
# Batch query instead of N+1
files = await UserWorkspaceFile.prisma().find_many(
where={
"id": {"in": valid_ids},
"workspaceId": workspace.id,
"isDeleted": False,
}
if request.file_ids:
files = await _resolve_workspace_files(user_id, request.file_ids)
# Only keep IDs that actually exist in the user's workspace
sanitized_file_ids = [wf.id for wf in files] or None
file_lines: list[str] = [
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
for wf in files
]
if file_lines:
files_block = (
"\n\n[Attached files]\n"
+ "\n".join(file_lines)
+ "\nUse read_workspace_file with the file_id to access file contents."
)
# Only keep IDs that actually exist in the user's workspace
sanitized_file_ids = [wf.id for wf in files] or None
file_lines: list[str] = [
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
for wf in files
]
if file_lines:
files_block = (
"\n\n[Attached files]\n"
+ "\n".join(file_lines)
+ "\nUse read_workspace_file with the file_id to access file contents."
)
request.message += files_block
request.message += files_block
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
@@ -1128,28 +1139,21 @@ async def queue_pending_message(
message_length=len(request.message),
)
# Sanitise file IDs to the user's own workspace (same logic as
# stream_chat_post) so injection doesn't surface other users' files.
# Sanitise file IDs to the user's own workspace so injection doesn't
# surface other users' files. _resolve_workspace_files handles UUID
# filtering and the workspace-scoped DB lookup.
sanitized_file_ids: list[str] = []
if request.file_ids:
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
if valid_ids:
workspace = await get_or_create_workspace(user_id)
files = await UserWorkspaceFile.prisma().find_many(
where={
"id": {"in": valid_ids},
"workspaceId": workspace.id,
"isDeleted": False,
}
valid_id_count = sum(1 for fid in request.file_ids if _UUID_RE.match(fid))
files = await _resolve_workspace_files(user_id, request.file_ids)
sanitized_file_ids = [wf.id for wf in files]
if len(sanitized_file_ids) != valid_id_count:
logger.warning(
"queue_pending_message: dropped %d file id(s) not in "
"caller's workspace (session=%s)",
valid_id_count - len(sanitized_file_ids),
session_id,
)
sanitized_file_ids = [wf.id for wf in files]
if len(sanitized_file_ids) != len(valid_ids):
logger.warning(
"queue_pending_message: dropped %d file id(s) not in "
"caller's workspace (session=%s)",
len(valid_ids) - len(sanitized_file_ids),
session_id,
)
# Redis is the single source of truth for pending messages. We do
# NOT persist to ``session.messages`` here — the drain-at-start