mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
fix(backend): address review findings for pending-message endpoint
- Fix off-by-one in rate limit: use >= instead of > for call count check - Move track_user_message() after push_pending_message() so analytics only fires on successful push - Add logger.warning in rate-limiter except-Exception catch instead of silent pass - Use fullmatch instead of match for UUID regex validation - Add extra="forbid" to PendingMessageContext to reject unexpected fields
This commit is contained in:
@@ -133,7 +133,7 @@ async def _resolve_workspace_files(
|
||||
Used by both the stream and pending-message endpoints to prevent callers from
|
||||
referencing other users' files.
|
||||
"""
|
||||
valid_ids = [fid for fid in file_ids if _UUID_RE.match(fid)]
|
||||
valid_ids = [fid for fid in file_ids if _UUID_RE.fullmatch(fid)]
|
||||
if not valid_ids:
|
||||
return []
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
@@ -1172,7 +1172,7 @@ async def queue_pending_message(
|
||||
),
|
||||
)
|
||||
)
|
||||
if _call_count > _PENDING_CALL_LIMIT:
|
||||
if _call_count >= _PENDING_CALL_LIMIT:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Too many pending messages: limit is {_PENDING_CALL_LIMIT} per {_PENDING_CALL_WINDOW_SECONDS}s",
|
||||
@@ -1180,20 +1180,14 @@ async def queue_pending_message(
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass # Redis failure is non-fatal; fail open
|
||||
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
)
|
||||
logger.warning("queue_pending_message: rate-limit check failed, failing open") # non-fatal
|
||||
|
||||
# Sanitise file IDs to the user's own workspace so injection doesn't
|
||||
# surface other users' files. _resolve_workspace_files handles UUID
|
||||
# filtering and the workspace-scoped DB lookup.
|
||||
sanitized_file_ids: list[str] = []
|
||||
if request.file_ids:
|
||||
valid_id_count = sum(1 for fid in request.file_ids if _UUID_RE.match(fid))
|
||||
valid_id_count = sum(1 for fid in request.file_ids if _UUID_RE.fullmatch(fid))
|
||||
files = await _resolve_workspace_files(user_id, request.file_ids)
|
||||
sanitized_file_ids = [wf.id for wf in files]
|
||||
if len(sanitized_file_ids) != valid_id_count:
|
||||
@@ -1220,6 +1214,12 @@ async def queue_pending_message(
|
||||
)
|
||||
buffer_length = await push_pending_message(session_id, pending)
|
||||
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
)
|
||||
|
||||
# Check whether a turn is currently running for UX feedback.
|
||||
active_session = await stream_registry.get_session(session_id)
|
||||
turn_in_flight = bool(active_session and active_session.status == "running")
|
||||
|
||||
@@ -49,7 +49,7 @@ _PENDING_TTL_SECONDS = 3600 # 1 hour — matches stream_ttl default
|
||||
_NOTIFY_PAYLOAD = "1"
|
||||
|
||||
|
||||
class PendingMessageContext(BaseModel):
|
||||
class PendingMessageContext(BaseModel, extra="forbid"):
|
||||
"""Structured page context attached to a pending message."""
|
||||
|
||||
url: str | None = None
|
||||
|
||||
Reference in New Issue
Block a user