feat(chat): Enhance chat session management and error handling

- Added chat routes to the REST API for improved session management.
- Introduced RedisError exception for better error handling when interacting with Redis.
- Updated ChatSession model to include credentials and improved session validation.
- Enhanced error logging and handling in chat streaming functions to ensure robustness.
- Removed unused LOGIN_NEEDED response type from models for cleaner code.
This commit is contained in:
Swifty
2025-10-30 15:09:22 +01:00
parent 8d9bcb620d
commit c136e08321
6 changed files with 133 additions and 49 deletions

View File

@@ -27,6 +27,7 @@ import backend.server.v2.admin.credit_admin_routes
import backend.server.v2.admin.store_admin_routes
import backend.server.v2.builder
import backend.server.v2.builder.routes
import backend.server.v2.chat.routes as chat_routes
import backend.server.v2.library.db
import backend.server.v2.library.model
import backend.server.v2.library.routes
@@ -284,6 +285,7 @@ app.include_router(
tags=["v1", "email"],
prefix="/api/email",
)
app.include_router(chat_routes.router, tags=["v2", "chat"], prefix="/api/chat")
app.mount("/external-api", external_app)

View File

@@ -20,6 +20,7 @@ from pydantic import BaseModel
from backend.server.v2.chat.config import ChatConfig
from backend.util.cache import async_redis
from backend.util.exceptions import RedisError
logger = logging.getLogger(__name__)
config = ChatConfig()
@@ -46,6 +47,7 @@ class ChatSession(BaseModel):
user_id: str | None
messages: list[ChatMessage]
usage: list[Usage]
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
started_at: datetime
updated_at: datetime
@@ -56,6 +58,7 @@ class ChatSession(BaseModel):
user_id=user_id,
messages=[],
usage=[],
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
@@ -102,13 +105,27 @@ class ChatSession(BaseModel):
if message.tool_calls:
t: list[ChatCompletionMessageToolCallParam] = []
for tool_call in message.tool_calls:
# Tool calls are stored with nested structure: {id, type, function: {name, arguments}}
function_data = tool_call.get("function", {})
# Skip tool calls that are missing required fields
if "id" not in tool_call or "name" not in function_data:
logger.warning(
f"Skipping invalid tool call: missing required fields. "
f"Got: {tool_call.keys()}, function keys: {function_data.keys()}"
)
continue
# Arguments are stored as a JSON string
arguments_str = function_data.get("arguments", "{}")
t.append(
ChatCompletionMessageToolCallParam(
id=tool_call["id"],
type="function",
function=Function(
arguments=tool_call["arguments"],
name=tool_call["name"],
arguments=arguments_str,
name=function_data["name"],
),
)
)
@@ -142,15 +159,19 @@ async def get_chat_session(
"""Get a chat session by ID."""
redis_key = f"chat:session:{session_id}"
raw_session: bytes = await async_redis.get(redis_key)
raw_session: bytes | None = await async_redis.get(redis_key)
if not raw_session:
logger.warning(f"Session {session_id} not found")
if raw_session is None:
logger.warning(f"Session {session_id} not found in Redis")
return None
session = ChatSession.model_validate_json(raw_session)
try:
session = ChatSession.model_validate_json(raw_session)
except Exception as e:
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
raise RedisError(f"Corrupted session data for {session_id}") from e
if session.user_id != user_id:
if session.user_id is not None and session.user_id != user_id:
logger.warning(
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
)
@@ -171,6 +192,8 @@ async def upsert_chat_session(
)
if not resp:
raise Exception(f"Failed to update chat session: {resp}")
raise RedisError(
f"Failed to persist chat session {session.session_id} to Redis: {resp}"
)
return session

View File

@@ -10,7 +10,6 @@ class ResponseType(str, Enum):
TEXT_CHUNK = "text_chunk"
TOOL_CALL = "tool_call"
TOOL_RESPONSE = "tool_response"
LOGIN_NEEDED = "login_needed"
ERROR = "error"
USAGE = "usage"
STREAM_END = "stream_end"
@@ -57,22 +56,8 @@ class StreamToolExecutionResult(StreamBaseResponse):
)
class StreamLoginNeeded(StreamBaseResponse):
"""Authentication required notification."""
type: ResponseType = ResponseType.LOGIN_NEEDED
message: str = Field(..., description="Message explaining why login is needed")
session_id: str = Field(..., description="Current session ID to preserve")
agent_info: dict[str, Any] | None = Field(
default=None, description="Agent context if applicable"
)
required_action: str = Field(
default="login", description="Required action (login/signup)"
)
class StreamUsage(StreamBaseResponse):
"""Error response."""
"""Token usage statistics."""
type: ResponseType = ResponseType.USAGE
prompt_tokens: int

View File

@@ -1,10 +1,12 @@
"""Chat API routes for chat session management and streaming via SSE."""
import logging
from collections.abc import AsyncGenerator
from typing import Annotated
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Query, Security
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
@@ -149,15 +151,36 @@ async def stream_chat(
message: The user's new message to process.
user_id: Optional authenticated user ID.
Yields:
SSE-formatted response chunks.
Returns:
StreamingResponse: SSE-formatted response chunks.
"""
# Validate session exists before starting the stream
# This prevents errors after the response has already started
session = await chat_service.get_session(session_id, user_id)
async for chunk in chat_service.stream_chat_completion(
session_id, message, user_id
):
yield chunk.to_sse()
if not session:
raise NotFoundError(
f"Session {session_id} not found. "
)
if session.user_id is None and user_id is not None:
session = await chat_service.assign_user_to_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
async for chunk in chat_service.stream_chat_completion(
session_id, message, user_id
):
yield chunk.to_sse()
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Disable nginx buffering
},
)
@router.patch(

View File

@@ -37,9 +37,11 @@ async def create_chat_session(
user_id: str | None = None,
) -> ChatSession:
"""
Create a new chat session.
Create a new chat session and persist it to the database.
"""
return ChatSession.new(user_id)
session = ChatSession.new(user_id)
# Persist the session immediately so it can be used for streaming
return await upsert_chat_session(session)
async def get_session(
@@ -74,17 +76,19 @@ async def stream_chat_completion(
"""Main entry point for streaming chat completions with database handling.
This function handles all database operations and delegates streaming
to the pure stream_chat_response function.
to the internal _stream_chat_chunks function.
Args:
session_id: Chat session ID
user_message: User's input message
user_id: User ID for authentication
model: OpenAI model to use
max_messages: Maximum context messages to include
user_id: User ID for authentication (None for anonymous)
Yields:
SSE formatted JSON strings with response data
StreamBaseResponse objects formatted as SSE
Raises:
NotFoundError: If session_id is invalid
ValueError: If max_context_messages is exceeded
"""
logger.info(
@@ -109,7 +113,7 @@ async def stream_chat_completion(
session = await upsert_chat_session(session)
assert session, "Session not found"
assistant_repsonse = ChatMessage(
assistant_response = ChatMessage(
role="assistant",
content="",
)
@@ -123,9 +127,15 @@ async def stream_chat_completion(
):
if isinstance(chunk, StreamTextChunk):
assistant_repsonse.content += chunk.content
assistant_response.content += chunk.content
yield chunk
elif isinstance(chunk, StreamToolCall):
# Convert arguments dict to JSON string for consistent storage
arguments_str = (
orjson.dumps(chunk.arguments).decode("utf-8")
if chunk.arguments
else "{}"
)
tool_call_response = ChatMessage(
role="assistant",
content="",
@@ -135,7 +145,7 @@ async def stream_chat_completion(
"type": "function",
"function": {
"name": chunk.tool_name,
"arguments": chunk.arguments,
"arguments": arguments_str,
},
}
],
@@ -167,13 +177,16 @@ async def stream_chat_completion(
else:
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
except Exception as e:
logger.error(f"Error in stream: {e!s}", exc_info=True)
logger.error(f"Error during stream: {e!s}", exc_info=True)
# Always yield error response if we haven't already
if not has_yielded_error:
error_response = StreamError(
message=str(e),
timestamp=datetime.now(UTC).isoformat(),
)
yield error_response
has_yielded_error = True
# Always yield end marker after error
if not has_yielded_end:
yield StreamEnd(
timestamp=datetime.now(UTC).isoformat(),
@@ -185,7 +198,10 @@ async def stream_chat_completion(
logger.info(
f"Upserting session: {session.session_id} with user id {session.user_id}"
)
session.messages.append(assistant_repsonse)
# Only append assistant response if it has content or tool calls
# to avoid saving empty messages on errors
if assistant_response.content or assistant_response.tool_calls:
session.messages.append(assistant_response)
await upsert_chat_session(session)
@@ -225,10 +241,9 @@ async def _stream_chat_chunks(
stream=True,
)
# Variables to accumulate the response
assistant_message: str = ""
# Variables to accumulate tool calls
tool_calls: list[dict[str, Any]] = []
active_tool_call_idx = None
active_tool_call_idx: int | None = None
finish_reason: str | None = None
# Process the stream
@@ -253,7 +268,6 @@ async def _stream_chat_chunks(
# Handle content streaming
if delta.content:
assistant_message += delta.content
# Stream the text chunk
text_response = StreamTextChunk(
content=delta.content,
@@ -268,12 +282,17 @@ async def _stream_chat_chunks(
if active_tool_call_idx is None:
active_tool_call_idx = idx
# When we start receiving a new tool call (higher index),
# yield the previous one since it's now complete
# (OpenAI streams tool calls with incrementing indices)
if active_tool_call_idx != idx:
yield_idx = idx - 1
async for tc in _yield_tool_call(
tool_calls, yield_idx, session
):
yield tc
# Update to track the new active tool call
active_tool_call_idx = idx
# Ensure we have a tool call object at this index
while len(tool_calls) <= idx:
@@ -302,11 +321,19 @@ async def _stream_chat_chunks(
] += tc_chunk.function.arguments
logger.info(f"Stream complete. Finish reason: {finish_reason}")
if active_tool_call_idx is not None:
# Yield the final tool call if any were accumulated
if active_tool_call_idx is not None and active_tool_call_idx < len(
tool_calls
):
async for tc in _yield_tool_call(
tool_calls, active_tool_call_idx, session
):
yield tc
elif active_tool_call_idx is not None:
logger.warning(
f"Active tool call index {active_tool_call_idx} out of bounds "
f"(tool_calls length: {len(tool_calls)})"
)
yield StreamEnd(
timestamp=datetime.now(UTC).isoformat(),
@@ -334,16 +361,34 @@ async def _yield_tool_call(
Yield a tool call.
"""
logger.info(f"Yielding tool call: {tool_calls[yield_idx]}")
# Parse tool call arguments with error handling
try:
arguments = orjson.loads(tool_calls[yield_idx]["function"]["arguments"])
except (orjson.JSONDecodeError, KeyError, TypeError) as e:
logger.error(
f"Failed to parse tool call arguments: {e}",
exc_info=True,
extra={
"tool_call": tool_calls[yield_idx],
},
)
yield StreamError(
message=f"Invalid tool call arguments: {e}",
timestamp=datetime.now(UTC).isoformat(),
)
return
yield StreamToolCall(
tool_id=tool_calls[yield_idx]["id"],
tool_name=tool_calls[yield_idx]["function"]["name"],
arguments=orjson.loads(tool_calls[yield_idx]["function"]["arguments"]),
arguments=arguments,
timestamp=datetime.now(UTC).isoformat(),
)
tool_execution_response: StreamToolExecutionResult = await execute_tool(
tool_name=tool_calls[yield_idx]["function"]["name"],
parameters=orjson.loads(tool_calls[yield_idx]["function"]["arguments"]),
parameters=arguments,
tool_call_id=tool_calls[yield_idx]["id"],
user_id=session.user_id,
session_id=session.session_id,

View File

@@ -98,3 +98,9 @@ class DatabaseError(Exception):
"""Raised when there is an error interacting with the database"""
pass
class RedisError(Exception):
    """Raised when an operation against Redis fails (read, write, or deserialization)."""