mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
refactor(copilot): extract shared file-ID sanitization helper
Extract `_resolve_workspace_files(user_id, file_ids)` helper from the duplicated UUID-filter + workspace-DB-lookup logic in both `stream_chat_post` and `queue_pending_message`. Both endpoints now call the single helper; callers map the returned `list[UserWorkspaceFile]` to IDs or file-description strings as before. Also removes the redundant `if user_id:` guard from `stream_chat_post`'s file-ID block — `Security(auth.get_user_id)` guarantees a non-empty string. Addresses autogpt-pr-reviewer "Should Fix: Duplicated file-ID sanitization" and coderabbitai nit on the if user_id guard.
This commit is contained in:
@@ -102,6 +102,29 @@ async def _validate_and_get_session(
|
||||
return session
|
||||
|
||||
|
||||
async def _resolve_workspace_files(
|
||||
user_id: str,
|
||||
file_ids: list[str],
|
||||
) -> list[UserWorkspaceFile]:
|
||||
"""Filter *file_ids* to UUID-valid entries that exist in the caller's workspace.
|
||||
|
||||
Returns the matching ``UserWorkspaceFile`` records (empty list if none pass).
|
||||
Used by both the stream and pending-message endpoints to prevent callers from
|
||||
referencing other users' files.
|
||||
"""
|
||||
valid_ids = [fid for fid in file_ids if _UUID_RE.match(fid)]
|
||||
if not valid_ids:
|
||||
return []
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
return await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": valid_ids},
|
||||
"workspaceId": workspace.id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
tags=["chat"],
|
||||
)
|
||||
@@ -850,33 +873,21 @@ async def stream_chat_post(
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
sanitized_file_ids: list[str] | None = None
|
||||
if request.file_ids and user_id:
|
||||
# Filter to valid UUIDs only to prevent DB abuse
|
||||
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
|
||||
|
||||
if valid_ids:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Batch query instead of N+1
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": valid_ids},
|
||||
"workspaceId": workspace.id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
if request.file_ids:
|
||||
files = await _resolve_workspace_files(user_id, request.file_ids)
|
||||
# Only keep IDs that actually exist in the user's workspace
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
file_lines: list[str] = [
|
||||
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
for wf in files
|
||||
]
|
||||
if file_lines:
|
||||
files_block = (
|
||||
"\n\n[Attached files]\n"
|
||||
+ "\n".join(file_lines)
|
||||
+ "\nUse read_workspace_file with the file_id to access file contents."
|
||||
)
|
||||
# Only keep IDs that actually exist in the user's workspace
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
file_lines: list[str] = [
|
||||
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
for wf in files
|
||||
]
|
||||
if file_lines:
|
||||
files_block = (
|
||||
"\n\n[Attached files]\n"
|
||||
+ "\n".join(file_lines)
|
||||
+ "\nUse read_workspace_file with the file_id to access file contents."
|
||||
)
|
||||
request.message += files_block
|
||||
request.message += files_block
|
||||
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
@@ -1128,28 +1139,21 @@ async def queue_pending_message(
|
||||
message_length=len(request.message),
|
||||
)
|
||||
|
||||
# Sanitise file IDs to the user's own workspace (same logic as
|
||||
# stream_chat_post) so injection doesn't surface other users' files.
|
||||
# Sanitise file IDs to the user's own workspace so injection doesn't
|
||||
# surface other users' files. _resolve_workspace_files handles UUID
|
||||
# filtering and the workspace-scoped DB lookup.
|
||||
sanitized_file_ids: list[str] = []
|
||||
if request.file_ids:
|
||||
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
|
||||
if valid_ids:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": valid_ids},
|
||||
"workspaceId": workspace.id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
valid_id_count = sum(1 for fid in request.file_ids if _UUID_RE.match(fid))
|
||||
files = await _resolve_workspace_files(user_id, request.file_ids)
|
||||
sanitized_file_ids = [wf.id for wf in files]
|
||||
if len(sanitized_file_ids) != valid_id_count:
|
||||
logger.warning(
|
||||
"queue_pending_message: dropped %d file id(s) not in "
|
||||
"caller's workspace (session=%s)",
|
||||
valid_id_count - len(sanitized_file_ids),
|
||||
session_id,
|
||||
)
|
||||
sanitized_file_ids = [wf.id for wf in files]
|
||||
if len(sanitized_file_ids) != len(valid_ids):
|
||||
logger.warning(
|
||||
"queue_pending_message: dropped %d file id(s) not in "
|
||||
"caller's workspace (session=%s)",
|
||||
len(valid_ids) - len(sanitized_file_ids),
|
||||
session_id,
|
||||
)
|
||||
|
||||
# Redis is the single source of truth for pending messages. We do
|
||||
# NOT persist to ``session.messages`` here — the drain-at-start
|
||||
|
||||
Reference in New Issue
Block a user