mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
feat(chat): Enhance chat session management and error handling
- Added chat routes to the REST API for improved session management. - Introduced RedisError exception for better error handling when interacting with Redis. - Updated ChatSession model to include credentials and improved session validation. - Enhanced error logging and handling in chat streaming functions to ensure robustness. - Removed unused LOGIN_NEEDED response type from models for cleaner code.
This commit is contained in:
@@ -27,6 +27,7 @@ import backend.server.v2.admin.credit_admin_routes
|
||||
import backend.server.v2.admin.store_admin_routes
|
||||
import backend.server.v2.builder
|
||||
import backend.server.v2.builder.routes
|
||||
import backend.server.v2.chat.routes as chat_routes
|
||||
import backend.server.v2.library.db
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.library.routes
|
||||
@@ -284,6 +285,7 @@ app.include_router(
|
||||
tags=["v1", "email"],
|
||||
prefix="/api/email",
|
||||
)
|
||||
app.include_router(chat_routes.router, tags=["v2", "chat"], prefix="/api/chat")
|
||||
|
||||
app.mount("/external-api", external_app)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.server.v2.chat.config import ChatConfig
|
||||
from backend.util.cache import async_redis
|
||||
from backend.util.exceptions import RedisError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
@@ -46,6 +47,7 @@ class ChatSession(BaseModel):
|
||||
user_id: str | None
|
||||
messages: list[ChatMessage]
|
||||
usage: list[Usage]
|
||||
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
|
||||
started_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@@ -56,6 +58,7 @@ class ChatSession(BaseModel):
|
||||
user_id=user_id,
|
||||
messages=[],
|
||||
usage=[],
|
||||
credentials={},
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
@@ -102,13 +105,27 @@ class ChatSession(BaseModel):
|
||||
if message.tool_calls:
|
||||
t: list[ChatCompletionMessageToolCallParam] = []
|
||||
for tool_call in message.tool_calls:
|
||||
# Tool calls are stored with nested structure: {id, type, function: {name, arguments}}
|
||||
function_data = tool_call.get("function", {})
|
||||
|
||||
# Skip tool calls that are missing required fields
|
||||
if "id" not in tool_call or "name" not in function_data:
|
||||
logger.warning(
|
||||
f"Skipping invalid tool call: missing required fields. "
|
||||
f"Got: {tool_call.keys()}, function keys: {function_data.keys()}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Arguments are stored as a JSON string
|
||||
arguments_str = function_data.get("arguments", "{}")
|
||||
|
||||
t.append(
|
||||
ChatCompletionMessageToolCallParam(
|
||||
id=tool_call["id"],
|
||||
type="function",
|
||||
function=Function(
|
||||
arguments=tool_call["arguments"],
|
||||
name=tool_call["name"],
|
||||
arguments=arguments_str,
|
||||
name=function_data["name"],
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -142,15 +159,19 @@ async def get_chat_session(
|
||||
"""Get a chat session by ID."""
|
||||
redis_key = f"chat:session:{session_id}"
|
||||
|
||||
raw_session: bytes = await async_redis.get(redis_key)
|
||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||
|
||||
if not raw_session:
|
||||
logger.warning(f"Session {session_id} not found")
|
||||
if raw_session is None:
|
||||
logger.warning(f"Session {session_id} not found in Redis")
|
||||
return None
|
||||
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
||||
raise RedisError(f"Corrupted session data for {session_id}") from e
|
||||
|
||||
if session.user_id != user_id:
|
||||
if session.user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||
)
|
||||
@@ -171,6 +192,8 @@ async def upsert_chat_session(
|
||||
)
|
||||
|
||||
if not resp:
|
||||
raise Exception(f"Failed to update chat session: {resp}")
|
||||
raise RedisError(
|
||||
f"Failed to persist chat session {session.session_id} to Redis: {resp}"
|
||||
)
|
||||
|
||||
return session
|
||||
|
||||
@@ -10,7 +10,6 @@ class ResponseType(str, Enum):
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TOOL_CALL = "tool_call"
|
||||
TOOL_RESPONSE = "tool_response"
|
||||
LOGIN_NEEDED = "login_needed"
|
||||
ERROR = "error"
|
||||
USAGE = "usage"
|
||||
STREAM_END = "stream_end"
|
||||
@@ -57,22 +56,8 @@ class StreamToolExecutionResult(StreamBaseResponse):
|
||||
)
|
||||
|
||||
|
||||
class StreamLoginNeeded(StreamBaseResponse):
|
||||
"""Authentication required notification."""
|
||||
|
||||
type: ResponseType = ResponseType.LOGIN_NEEDED
|
||||
message: str = Field(..., description="Message explaining why login is needed")
|
||||
session_id: str = Field(..., description="Current session ID to preserve")
|
||||
agent_info: dict[str, Any] | None = Field(
|
||||
default=None, description="Agent context if applicable"
|
||||
)
|
||||
required_action: str = Field(
|
||||
default="login", description="Required action (login/signup)"
|
||||
)
|
||||
|
||||
|
||||
class StreamUsage(StreamBaseResponse):
|
||||
"""Error response."""
|
||||
"""Token usage statistics."""
|
||||
|
||||
type: ResponseType = ResponseType.USAGE
|
||||
prompt_tokens: int
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Query, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -149,15 +151,36 @@ async def stream_chat(
|
||||
message: The user's new message to process.
|
||||
user_id: Optional authenticated user ID.
|
||||
|
||||
Yields:
|
||||
SSE-formatted response chunks.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
# Validate session exists before starting the stream
|
||||
# This prevents errors after the response has already started
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id, message, user_id
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
f"Session {session_id} not found. "
|
||||
)
|
||||
if session.user_id is None and user_id is not None:
|
||||
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id, message, user_id
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.patch(
|
||||
|
||||
@@ -37,9 +37,11 @@ async def create_chat_session(
|
||||
user_id: str | None = None,
|
||||
) -> ChatSession:
|
||||
"""
|
||||
Create a new chat session.
|
||||
Create a new chat session and persist it to the database.
|
||||
"""
|
||||
return ChatSession.new(user_id)
|
||||
session = ChatSession.new(user_id)
|
||||
# Persist the session immediately so it can be used for streaming
|
||||
return await upsert_chat_session(session)
|
||||
|
||||
|
||||
async def get_session(
|
||||
@@ -74,17 +76,19 @@ async def stream_chat_completion(
|
||||
"""Main entry point for streaming chat completions with database handling.
|
||||
|
||||
This function handles all database operations and delegates streaming
|
||||
to the pure stream_chat_response function.
|
||||
to the internal _stream_chat_chunks function.
|
||||
|
||||
Args:
|
||||
session_id: Chat session ID
|
||||
user_message: User's input message
|
||||
user_id: User ID for authentication
|
||||
model: OpenAI model to use
|
||||
max_messages: Maximum context messages to include
|
||||
user_id: User ID for authentication (None for anonymous)
|
||||
|
||||
Yields:
|
||||
SSE formatted JSON strings with response data
|
||||
StreamBaseResponse objects formatted as SSE
|
||||
|
||||
Raises:
|
||||
NotFoundError: If session_id is invalid
|
||||
ValueError: If max_context_messages is exceeded
|
||||
|
||||
"""
|
||||
logger.info(
|
||||
@@ -109,7 +113,7 @@ async def stream_chat_completion(
|
||||
session = await upsert_chat_session(session)
|
||||
assert session, "Session not found"
|
||||
|
||||
assistant_repsonse = ChatMessage(
|
||||
assistant_response = ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
)
|
||||
@@ -123,9 +127,15 @@ async def stream_chat_completion(
|
||||
):
|
||||
|
||||
if isinstance(chunk, StreamTextChunk):
|
||||
assistant_repsonse.content += chunk.content
|
||||
assistant_response.content += chunk.content
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamToolCall):
|
||||
# Convert arguments dict to JSON string for consistent storage
|
||||
arguments_str = (
|
||||
orjson.dumps(chunk.arguments).decode("utf-8")
|
||||
if chunk.arguments
|
||||
else "{}"
|
||||
)
|
||||
tool_call_response = ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
@@ -135,7 +145,7 @@ async def stream_chat_completion(
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": chunk.tool_name,
|
||||
"arguments": chunk.arguments,
|
||||
"arguments": arguments_str,
|
||||
},
|
||||
}
|
||||
],
|
||||
@@ -167,13 +177,16 @@ async def stream_chat_completion(
|
||||
else:
|
||||
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {e!s}", exc_info=True)
|
||||
logger.error(f"Error during stream: {e!s}", exc_info=True)
|
||||
# Always yield error response if we haven't already
|
||||
if not has_yielded_error:
|
||||
error_response = StreamError(
|
||||
message=str(e),
|
||||
timestamp=datetime.now(UTC).isoformat(),
|
||||
)
|
||||
yield error_response
|
||||
has_yielded_error = True
|
||||
# Always yield end marker after error
|
||||
if not has_yielded_end:
|
||||
yield StreamEnd(
|
||||
timestamp=datetime.now(UTC).isoformat(),
|
||||
@@ -185,7 +198,10 @@ async def stream_chat_completion(
|
||||
logger.info(
|
||||
f"Upserting session: {session.session_id} with user id {session.user_id}"
|
||||
)
|
||||
session.messages.append(assistant_repsonse)
|
||||
# Only append assistant response if it has content or tool calls
|
||||
# to avoid saving empty messages on errors
|
||||
if assistant_response.content or assistant_response.tool_calls:
|
||||
session.messages.append(assistant_response)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
|
||||
@@ -225,10 +241,9 @@ async def _stream_chat_chunks(
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Variables to accumulate the response
|
||||
assistant_message: str = ""
|
||||
# Variables to accumulate tool calls
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
active_tool_call_idx = None
|
||||
active_tool_call_idx: int | None = None
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Process the stream
|
||||
@@ -253,7 +268,6 @@ async def _stream_chat_chunks(
|
||||
|
||||
# Handle content streaming
|
||||
if delta.content:
|
||||
assistant_message += delta.content
|
||||
# Stream the text chunk
|
||||
text_response = StreamTextChunk(
|
||||
content=delta.content,
|
||||
@@ -268,12 +282,17 @@ async def _stream_chat_chunks(
|
||||
if active_tool_call_idx is None:
|
||||
active_tool_call_idx = idx
|
||||
|
||||
# When we start receiving a new tool call (higher index),
|
||||
# yield the previous one since it's now complete
|
||||
# (OpenAI streams tool calls with incrementing indices)
|
||||
if active_tool_call_idx != idx:
|
||||
yield_idx = idx - 1
|
||||
async for tc in _yield_tool_call(
|
||||
tool_calls, yield_idx, session
|
||||
):
|
||||
yield tc
|
||||
# Update to track the new active tool call
|
||||
active_tool_call_idx = idx
|
||||
|
||||
# Ensure we have a tool call object at this index
|
||||
while len(tool_calls) <= idx:
|
||||
@@ -302,11 +321,19 @@ async def _stream_chat_chunks(
|
||||
] += tc_chunk.function.arguments
|
||||
logger.info(f"Stream complete. Finish reason: {finish_reason}")
|
||||
|
||||
if active_tool_call_idx is not None:
|
||||
# Yield the final tool call if any were accumulated
|
||||
if active_tool_call_idx is not None and active_tool_call_idx < len(
|
||||
tool_calls
|
||||
):
|
||||
async for tc in _yield_tool_call(
|
||||
tool_calls, active_tool_call_idx, session
|
||||
):
|
||||
yield tc
|
||||
elif active_tool_call_idx is not None:
|
||||
logger.warning(
|
||||
f"Active tool call index {active_tool_call_idx} out of bounds "
|
||||
f"(tool_calls length: {len(tool_calls)})"
|
||||
)
|
||||
|
||||
yield StreamEnd(
|
||||
timestamp=datetime.now(UTC).isoformat(),
|
||||
@@ -334,16 +361,34 @@ async def _yield_tool_call(
|
||||
Yield a tool call.
|
||||
"""
|
||||
logger.info(f"Yielding tool call: {tool_calls[yield_idx]}")
|
||||
|
||||
# Parse tool call arguments with error handling
|
||||
try:
|
||||
arguments = orjson.loads(tool_calls[yield_idx]["function"]["arguments"])
|
||||
except (orjson.JSONDecodeError, KeyError, TypeError) as e:
|
||||
logger.error(
|
||||
f"Failed to parse tool call arguments: {e}",
|
||||
exc_info=True,
|
||||
extra={
|
||||
"tool_call": tool_calls[yield_idx],
|
||||
},
|
||||
)
|
||||
yield StreamError(
|
||||
message=f"Invalid tool call arguments: {e}",
|
||||
timestamp=datetime.now(UTC).isoformat(),
|
||||
)
|
||||
return
|
||||
|
||||
yield StreamToolCall(
|
||||
tool_id=tool_calls[yield_idx]["id"],
|
||||
tool_name=tool_calls[yield_idx]["function"]["name"],
|
||||
arguments=orjson.loads(tool_calls[yield_idx]["function"]["arguments"]),
|
||||
arguments=arguments,
|
||||
timestamp=datetime.now(UTC).isoformat(),
|
||||
)
|
||||
|
||||
tool_execution_response: StreamToolExecutionResult = await execute_tool(
|
||||
tool_name=tool_calls[yield_idx]["function"]["name"],
|
||||
parameters=orjson.loads(tool_calls[yield_idx]["function"]["arguments"]),
|
||||
parameters=arguments,
|
||||
tool_call_id=tool_calls[yield_idx]["id"],
|
||||
user_id=session.user_id,
|
||||
session_id=session.session_id,
|
||||
|
||||
@@ -98,3 +98,9 @@ class DatabaseError(Exception):
|
||||
"""Raised when there is an error interacting with the database"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RedisError(Exception):
|
||||
"""Raised when there is an error interacting with Redis"""
|
||||
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user