diff --git a/autogpt_platform/backend/backend/api/features/chat/db.py b/autogpt_platform/backend/backend/api/features/chat/db.py
index 7d87d9ebd3..8384392246 100644
--- a/autogpt_platform/backend/backend/api/features/chat/db.py
+++ b/autogpt_platform/backend/backend/api/features/chat/db.py
@@ -1,5 +1,6 @@
 """Database operations for chat sessions."""

+import asyncio
 import logging
 from datetime import UTC, datetime
 from typing import Any, cast
@@ -10,6 +11,7 @@ from prisma.types import (
     ChatMessageCreateInput,
     ChatSessionCreateInput,
     ChatSessionUpdateInput,
+    ChatSessionWhereInput,
 )

 from backend.data.db import transaction
@@ -25,7 +27,8 @@ async def get_chat_session(session_id: str) -> PrismaChatSession | None:
         include={"Messages": True},
     )
     if session and session.Messages:
-        # Sort messages by sequence in Python since Prisma doesn't support order_by in include
+        # Sort messages by sequence in Python - Prisma Python client doesn't support
+        # order_by in include clauses (unlike Prisma JS), so we sort after fetching
         session.Messages.sort(key=lambda m: m.sequence)
     return session

@@ -79,6 +82,7 @@ async def update_chat_session(
         include={"Messages": True},
     )
     if session and session.Messages:
+        # Sort in Python - Prisma Python doesn't support order_by in include clauses
         session.Messages.sort(key=lambda m: m.sequence)
     return session

@@ -95,9 +99,9 @@ async def add_chat_message(
     function_call: dict[str, Any] | None = None,
 ) -> PrismaChatMessage:
     """Add a message to a chat session."""
-    # Build the input dict dynamically - only include optional fields when they
-    # have values, as Prisma TypedDict validation fails when optional fields
-    # are explicitly set to None
+    # Build input dict dynamically rather than using ChatMessageCreateInput directly
+    # because Prisma's TypedDict validation rejects optional fields set to None.
+    # We only include fields that have values, then cast at the end.
     data: dict[str, Any] = {
         "Session": {"connect": {"id": session_id}},
         "role": role,
@@ -120,15 +124,15 @@
     if function_call is not None:
         data["functionCall"] = SafeJson(function_call)

-    # Update session's updatedAt timestamp
-    await PrismaChatSession.prisma().update(
-        where={"id": session_id},
-        data={"updatedAt": datetime.now(UTC)},
-    )
-
-    return await PrismaChatMessage.prisma().create(
-        data=cast(ChatMessageCreateInput, data)
+    # Run message create and session timestamp update in parallel for lower latency
+    _, message = await asyncio.gather(
+        PrismaChatSession.prisma().update(
+            where={"id": session_id},
+            data={"updatedAt": datetime.now(UTC)},
+        ),
+        PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
     )
+    return message


 async def add_chat_messages_batch(
@@ -148,9 +152,9 @@

     async with transaction() as tx:
         for i, msg in enumerate(messages):
-            # Build the input dict dynamically - only include optional JSON fields
-            # when they have values, as Prisma TypedDict validation fails when
-            # optional fields are explicitly set to None
+            # Build input dict dynamically rather than using ChatMessageCreateInput
+            # directly because Prisma's TypedDict validation rejects optional fields
+            # set to None. We only include fields that have values, then cast.
             data: dict[str, Any] = {
                 "Session": {"connect": {"id": session_id}},
                 "role": msg["role"],
@@ -178,7 +182,9 @@
             )
             created_messages.append(created)

-        # Update session's updatedAt timestamp within the same transaction
+        # Update session's updatedAt timestamp within the same transaction.
+        # Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
+        # separately via update_chat_session() after streaming completes.
         await PrismaChatSession.prisma(tx).update(
             where={"id": session_id},
             data={"updatedAt": datetime.now(UTC)},
@@ -219,8 +225,8 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
         True if deleted successfully, False otherwise.
     """
     try:
-        # Build where clause with optional user_id validation
-        where_clause: dict[str, Any] = {"id": session_id}
+        # Build typed where clause with optional user_id validation
+        where_clause: ChatSessionWhereInput = {"id": session_id}
         if user_id is not None:
             where_clause["userId"] = user_id

diff --git a/autogpt_platform/backend/backend/api/features/chat/model.py b/autogpt_platform/backend/backend/api/features/chat/model.py
index 2c001542be..7fd10723c0 100644
--- a/autogpt_platform/backend/backend/api/features/chat/model.py
+++ b/autogpt_platform/backend/backend/api/features/chat/model.py
@@ -32,9 +32,20 @@ from .config import ChatConfig

 logger = logging.getLogger(__name__)
 config = ChatConfig()

+# Redis cache key prefix for chat sessions
+CHAT_SESSION_CACHE_PREFIX = "chat:session:"
+
+
+def _get_session_cache_key(session_id: str) -> str:
+    """Get the Redis cache key for a chat session."""
+    return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
+
+
 # Session-level locks to prevent race conditions during concurrent upserts.
 # Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
 # preventing unbounded memory growth while maintaining lock semantics for active sessions.
+# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
+# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
 _session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
 _session_locks_mutex = asyncio.Lock()

@@ -265,7 +276,7 @@ class ChatSession(BaseModel):

 async def _get_session_from_cache(session_id: str) -> ChatSession | None:
     """Get a chat session from Redis cache."""
-    redis_key = f"chat:session:{session_id}"
+    redis_key = _get_session_cache_key(session_id)
     async_redis = await get_redis_async()
     raw_session: bytes | None = await async_redis.get(redis_key)

@@ -287,7 +298,7 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:

 async def _cache_session(session: ChatSession) -> None:
     """Cache a chat session in Redis."""
-    redis_key = f"chat:session:{session.session_id}"
+    redis_key = _get_session_cache_key(session.session_id)
     async_redis = await get_redis_async()
     await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())

@@ -547,7 +558,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo

     # Only invalidate cache and clean up lock after DB confirms deletion
     try:
-        redis_key = f"chat:session:{session_id}"
+        redis_key = _get_session_cache_key(session_id)
         async_redis = await get_redis_async()
         await async_redis.delete(redis_key)
     except Exception as e:
@@ -582,7 +593,7 @@ async def update_session_title(session_id: str, title: str) -> bool:

     # Invalidate cache so next fetch gets updated title
     try:
-        redis_key = f"chat:session:{session_id}"
+        redis_key = _get_session_cache_key(session_id)
         async_redis = await get_redis_async()
         await async_redis.delete(redis_key)
     except Exception as e:
diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py
index 40db037dfd..b185b3e129 100644
--- a/autogpt_platform/backend/backend/api/features/chat/service.py
+++ b/autogpt_platform/backend/backend/api/features/chat/service.py
@@ -1,10 +1,17 @@
+import asyncio
 import logging
 from collections.abc import AsyncGenerator
 from typing import Any

 import orjson
 from langfuse import Langfuse
-from openai import AsyncOpenAI
+from openai import (
+    APIConnectionError,
+    APIError,
+    APIStatusError,
+    AsyncOpenAI,
+    RateLimitError,
+)
 from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam

 from backend.data.understanding import (
@@ -655,6 +662,31 @@ async def stream_chat_completion(
             logger.warning(f"Failed to end Langfuse trace: {e}")


+# Retry configuration for OpenAI API calls
+MAX_RETRIES = 3
+BASE_DELAY_SECONDS = 1.0
+MAX_DELAY_SECONDS = 30.0
+
+
+def _is_retryable_error(error: Exception) -> bool:
+    """Determine if an error is retryable."""
+    if isinstance(error, RateLimitError):
+        return True
+    if isinstance(error, APIConnectionError):
+        return True
+    if isinstance(error, APIStatusError):
+        # APIStatusError has a response with status_code
+        # Retry on 5xx status codes (server errors)
+        if error.response.status_code >= 500:
+            return True
+    if isinstance(error, APIError):
+        # Retry on overloaded errors or 500 errors (may not have status code)
+        error_message = str(error).lower()
+        if "overloaded" in error_message or "internal server error" in error_message:
+            return True
+    return False
+
+
 async def _stream_chat_chunks(
     session: ChatSession,
     tools: list[ChatCompletionToolParam],
@@ -665,6 +697,7 @@
     """
     Pure streaming function for OpenAI chat completions with tool calling.
     This function is database-agnostic and focuses only on streaming logic.
+    Implements exponential backoff retry for transient API errors.

     Args:
         session: Chat session with conversation history
@@ -692,136 +725,172 @@

     # Loop to handle tool calls and continue conversation
     while True:
-        try:
-            logger.info("Creating OpenAI chat completion stream...")
+        retry_count = 0
+        last_error: Exception | None = None

-            # Create the stream with proper types
-            stream = await client.chat.completions.create(
-                model=model,
-                messages=messages,
-                tools=tools,
-                tool_choice="auto",
-                stream=True,
-                stream_options={"include_usage": True},
-            )
+        while retry_count <= MAX_RETRIES:
+            try:
+                logger.info(
+                    f"Creating OpenAI chat completion stream..."
+                    f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
+                )

-            # Variables to accumulate tool calls
-            tool_calls: list[dict[str, Any]] = []
-            active_tool_call_idx: int | None = None
-            finish_reason: str | None = None
-            # Track which tool call indices have had their start event emitted
-            emitted_start_for_idx: set[int] = set()
+                # Create the stream with proper types
+                stream = await client.chat.completions.create(
+                    model=model,
+                    messages=messages,
+                    tools=tools,
+                    tool_choice="auto",
+                    stream=True,
+                    stream_options={"include_usage": True},
+                )

-            # Track if we've started the text block
-            text_started = False
+                # Variables to accumulate tool calls
+                tool_calls: list[dict[str, Any]] = []
+                active_tool_call_idx: int | None = None
+                finish_reason: str | None = None
+                # Track which tool call indices have had their start event emitted
+                emitted_start_for_idx: set[int] = set()

-            # Process the stream
-            chunk: ChatCompletionChunk
-            async for chunk in stream:
-                if chunk.usage:
-                    yield StreamUsage(
-                        promptTokens=chunk.usage.prompt_tokens,
-                        completionTokens=chunk.usage.completion_tokens,
-                        totalTokens=chunk.usage.total_tokens,
-                    )
+                # Track if we've started the text block
+                text_started = False

-                if chunk.choices:
-                    choice = chunk.choices[0]
-                    delta = choice.delta
-
-                    # Capture finish reason
-                    if choice.finish_reason:
-                        finish_reason = choice.finish_reason
-                        logger.info(f"Finish reason: {finish_reason}")
-
-                    # Handle content streaming
-                    if delta.content:
-                        # Emit text-start on first text content
-                        if not text_started and text_block_id:
-                            yield StreamTextStart(id=text_block_id)
-                            text_started = True
-                        # Stream the text delta
-                        text_response = StreamTextDelta(
-                            id=text_block_id or "",
-                            delta=delta.content,
+                # Process the stream
+                chunk: ChatCompletionChunk
+                async for chunk in stream:
+                    if chunk.usage:
+                        yield StreamUsage(
+                            promptTokens=chunk.usage.prompt_tokens,
+                            completionTokens=chunk.usage.completion_tokens,
+                            totalTokens=chunk.usage.total_tokens,
                         )
-                        yield text_response
-                    # Handle tool calls
-                    if delta.tool_calls:
-                        for tc_chunk in delta.tool_calls:
-                            idx = tc_chunk.index
+                    if chunk.choices:
+                        choice = chunk.choices[0]
+                        delta = choice.delta

-                            # Update active tool call index if needed
-                            if (
-                                active_tool_call_idx is None
-                                or active_tool_call_idx != idx
-                            ):
-                                active_tool_call_idx = idx
+                        # Capture finish reason
+                        if choice.finish_reason:
+                            finish_reason = choice.finish_reason
+                            logger.info(f"Finish reason: {finish_reason}")

-                            # Ensure we have a tool call object at this index
-                            while len(tool_calls) <= idx:
-                                tool_calls.append(
-                                    {
-                                        "id": "",
-                                        "type": "function",
-                                        "function": {
-                                            "name": "",
-                                            "arguments": "",
+                        # Handle content streaming
+                        if delta.content:
+                            # Emit text-start on first text content
+                            if not text_started and text_block_id:
+                                yield StreamTextStart(id=text_block_id)
+                                text_started = True
+                            # Stream the text delta
+                            text_response = StreamTextDelta(
+                                id=text_block_id or "",
+                                delta=delta.content,
+                            )
+                            yield text_response
+
+                        # Handle tool calls
+                        if delta.tool_calls:
+                            for tc_chunk in delta.tool_calls:
+                                idx = tc_chunk.index
+
+                                # Update active tool call index if needed
+                                if (
+                                    active_tool_call_idx is None
+                                    or active_tool_call_idx != idx
+                                ):
+                                    active_tool_call_idx = idx
+
+                                # Ensure we have a tool call object at this index
+                                while len(tool_calls) <= idx:
+                                    tool_calls.append(
+                                        {
+                                            "id": "",
+                                            "type": "function",
+                                            "function": {
+                                                "name": "",
+                                                "arguments": "",
+                                            },
                                         },
-                                    },
-                                )
+                                    )

-                            # Accumulate the tool call data
-                            if tc_chunk.id:
-                                tool_calls[idx]["id"] = tc_chunk.id
-                            if tc_chunk.function:
-                                if tc_chunk.function.name:
-                                    tool_calls[idx]["function"][
-                                        "name"
-                                    ] = tc_chunk.function.name
-                                if tc_chunk.function.arguments:
-                                    tool_calls[idx]["function"][
-                                        "arguments"
-                                    ] += tc_chunk.function.arguments
+                                # Accumulate the tool call data
+                                if tc_chunk.id:
+                                    tool_calls[idx]["id"] = tc_chunk.id
+                                if tc_chunk.function:
+                                    if tc_chunk.function.name:
+                                        tool_calls[idx]["function"][
+                                            "name"
+                                        ] = tc_chunk.function.name
+                                    if tc_chunk.function.arguments:
+                                        tool_calls[idx]["function"][
+                                            "arguments"
+                                        ] += tc_chunk.function.arguments

-                            # Emit StreamToolInputStart only after we have the tool call ID
-                            if (
-                                idx not in emitted_start_for_idx
-                                and tool_calls[idx]["id"]
-                                and tool_calls[idx]["function"]["name"]
-                            ):
-                                yield StreamToolInputStart(
-                                    toolCallId=tool_calls[idx]["id"],
-                                    toolName=tool_calls[idx]["function"]["name"],
-                                )
-                                emitted_start_for_idx.add(idx)
-            logger.info(f"Stream complete. Finish reason: {finish_reason}")
+                                # Emit StreamToolInputStart only after we have the tool call ID
+                                if (
+                                    idx not in emitted_start_for_idx
+                                    and tool_calls[idx]["id"]
+                                    and tool_calls[idx]["function"]["name"]
+                                ):
+                                    yield StreamToolInputStart(
+                                        toolCallId=tool_calls[idx]["id"],
+                                        toolName=tool_calls[idx]["function"]["name"],
+                                    )
+                                    emitted_start_for_idx.add(idx)
+                logger.info(f"Stream complete. Finish reason: {finish_reason}")

-            # Yield all accumulated tool calls after the stream is complete
-            # This ensures all tool call arguments have been fully received
-            for idx, tool_call in enumerate(tool_calls):
-                try:
-                    async for tc in _yield_tool_call(tool_calls, idx, session):
-                        yield tc
-                except (orjson.JSONDecodeError, KeyError, TypeError) as e:
+                # Yield all accumulated tool calls after the stream is complete
+                # This ensures all tool call arguments have been fully received
+                for idx, tool_call in enumerate(tool_calls):
+                    try:
+                        async for tc in _yield_tool_call(tool_calls, idx, session):
+                            yield tc
+                    except (orjson.JSONDecodeError, KeyError, TypeError) as e:
+                        logger.error(
+                            f"Failed to parse tool call {idx}: {e}",
+                            exc_info=True,
+                            extra={"tool_call": tool_call},
+                        )
+                        yield StreamError(
+                            errorText=f"Invalid tool call arguments for tool {tool_call.get('function', {}).get('name', 'unknown')}: {e}",
+                        )
+                        # Re-raise to trigger retry logic in the parent function
+                        raise
+
+                yield StreamFinish()
+                return
+            except Exception as e:
+                last_error = e
+                if _is_retryable_error(e) and retry_count < MAX_RETRIES:
+                    retry_count += 1
+                    # Calculate delay with exponential backoff
+                    delay = min(
+                        BASE_DELAY_SECONDS * (2 ** (retry_count - 1)),
+                        MAX_DELAY_SECONDS,
+                    )
+                    logger.warning(
+                        f"Retryable error in stream: {e!s}. "
+                        f"Retrying in {delay:.1f}s (attempt {retry_count}/{MAX_RETRIES})"
+                    )
+                    await asyncio.sleep(delay)
+                    continue  # Retry the stream
+                else:
+                    # Non-retryable error or max retries exceeded
                     logger.error(
-                        f"Failed to parse tool call {idx}: {e}",
+                        f"Error in stream (not retrying): {e!s}",
                         exc_info=True,
-                        extra={"tool_call": tool_call},
                     )
-                    yield StreamError(
-                        errorText=f"Invalid tool call arguments for tool {tool_call.get('function', {}).get('name', 'unknown')}: {e}",
-                    )
-                    # Re-raise to trigger retry logic in the parent function
-                    raise
+                    error_response = StreamError(errorText=str(e))
+                    yield error_response
+                    yield StreamFinish()
+                    return

-            yield StreamFinish()
-            return
-        except Exception as e:
-            logger.error(f"Error in stream: {e!s}", exc_info=True)
-            error_response = StreamError(errorText=str(e))
-            yield error_response
+        # If we exit the retry loop without returning, it means we exhausted retries
+        if last_error:
+            logger.error(
+                f"Max retries ({MAX_RETRIES}) exceeded. Last error: {last_error!s}",
+                exc_info=True,
+            )
+            yield StreamError(errorText=f"Max retries exceeded: {last_error!s}")
             yield StreamFinish()
             return