mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
update service to process chunks
This commit is contained in:
@@ -22,7 +22,7 @@ class ChatConfig(BaseSettings):
|
||||
|
||||
# Session TTL Configuration - 12 hours
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
|
||||
# System Prompt Configuration
|
||||
system_prompt_path: str = Field(
|
||||
default="prompts/chat_system.md",
|
||||
@@ -33,7 +33,7 @@ class ChatConfig(BaseSettings):
|
||||
max_context_messages: int = Field(
|
||||
default=50, ge=1, le=200, description="Maximum context messages"
|
||||
)
|
||||
|
||||
|
||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@@ -110,4 +110,4 @@ class ChatConfig(BaseSettings):
|
||||
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore" # Ignore extra environment variables
|
||||
extra = "ignore" # Ignore extra environment variables
|
||||
|
||||
@@ -1,24 +1,25 @@
|
||||
import json
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime, UTC
|
||||
|
||||
|
||||
from backend.util.cache import async_redis
|
||||
from backend.server.v2.chat.config import ChatConfig
|
||||
from backend.util.exceptions import NotFoundError
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionDeveloperMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionDeveloperMessageParam,
|
||||
ChatCompletionFunctionMessageParam,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import Function, ChatCompletionMessageToolCallParam
|
||||
from openai.types.chat.chat_completion_assistant_message_param import FunctionCall
|
||||
import uuid
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
ChatCompletionMessageToolCallParam,
|
||||
Function,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.server.v2.chat.config import ChatConfig
|
||||
from backend.util.cache import async_redis
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
@@ -38,9 +39,9 @@ class ChatSession(BaseModel):
|
||||
messages: list[ChatMessage]
|
||||
started_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
@staticmethod
|
||||
def new(user_id: str | None) -> 'ChatSession':
|
||||
def new(user_id: str | None) -> "ChatSession":
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
@@ -83,22 +84,24 @@ class ChatSession(BaseModel):
|
||||
)
|
||||
if message.function_call:
|
||||
m["function_call"] = FunctionCall(
|
||||
arguments=message.function_call["arguments"],
|
||||
name=message.function_call["name"],
|
||||
)
|
||||
arguments=message.function_call["arguments"],
|
||||
name=message.function_call["name"],
|
||||
)
|
||||
if message.refusal:
|
||||
m["refusal"] = message.refusal
|
||||
if message.tool_calls:
|
||||
t: list[ChatCompletionMessageToolCallParam] = []
|
||||
for tool_call in message.tool_calls:
|
||||
t.append(ChatCompletionMessageToolCallParam(
|
||||
id=tool_call["id"],
|
||||
type="function",
|
||||
function=Function(
|
||||
arguments=tool_call["arguments"],
|
||||
name=tool_call["name"],
|
||||
),
|
||||
))
|
||||
t.append(
|
||||
ChatCompletionMessageToolCallParam(
|
||||
id=tool_call["id"],
|
||||
type="function",
|
||||
function=Function(
|
||||
arguments=tool_call["arguments"],
|
||||
name=tool_call["name"],
|
||||
),
|
||||
)
|
||||
)
|
||||
m["tool_calls"] = t
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
@@ -112,11 +115,13 @@ class ChatSession(BaseModel):
|
||||
)
|
||||
)
|
||||
elif message.role == "function":
|
||||
messages.append(ChatCompletionFunctionMessageParam(
|
||||
role="function",
|
||||
content=message.content,
|
||||
name=message.name or "",
|
||||
))
|
||||
messages.append(
|
||||
ChatCompletionFunctionMessageParam(
|
||||
role="function",
|
||||
content=message.content,
|
||||
name=message.name or "",
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
@@ -151,7 +156,7 @@ async def upsert_chat_session(
|
||||
redis_key, config.session_ttl, session.model_dump_json()
|
||||
)
|
||||
|
||||
if resp != True:
|
||||
if not resp:
|
||||
raise Exception(f"Failed to update chat session: {resp}")
|
||||
|
||||
return session
|
||||
|
||||
@@ -1,30 +1,34 @@
|
||||
from backend.server.v2.chat.data import ChatMessage, ChatSession, get_chat_session, upsert_chat_session
|
||||
import pytest
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from backend.server.v2.chat.data import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
|
||||
messages = [
|
||||
ChatMessage(
|
||||
content="Hello, how are you?",
|
||||
role="user"
|
||||
),
|
||||
ChatMessage(
|
||||
content="I'm fine, thank you!",
|
||||
role="assistant",
|
||||
tool_calls=[{
|
||||
ChatMessage(content="Hello, how are you?", role="user"),
|
||||
ChatMessage(
|
||||
content="I'm fine, thank you!",
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "t123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\": \"New York\"}"
|
||||
}
|
||||
}]
|
||||
),
|
||||
ChatMessage(
|
||||
content="I'm using the tool to get the weather",
|
||||
role="tool",
|
||||
tool_call_id="t123"
|
||||
),
|
||||
]
|
||||
|
||||
"arguments": '{"city": "New York"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
ChatMessage(
|
||||
content="I'm using the tool to get the weather",
|
||||
role="tool",
|
||||
tool_call_id="t123",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -38,26 +42,27 @@ async def test_chatsession_serialization_deserialization():
|
||||
|
||||
@pytest.mark.asyncio
async def test_chatsession_redis_storage():
    """A session written to Redis must be read back equal to what was stored."""
    stored = ChatSession.new(user_id=None)
    stored.messages = messages
    stored = await upsert_chat_session(stored)

    # Look the session up again with the same identifiers it was saved under.
    lookup = {"session_id": stored.session_id, "user_id": stored.user_id}
    loaded = await get_chat_session(**lookup)

    assert loaded == stored
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chatsession_redis_storage_user_id_mismatch():
    """A session stored for one user must not be returned for a different user id."""
    owned = ChatSession.new(user_id="abc123")
    owned.messages = messages
    owned = await upsert_chat_session(owned)

    # Fetch with a user id (None) that does not match the stored owner.
    fetched = await get_chat_session(owned.session_id, None)

    assert fetched is None
|
||||
|
||||
@@ -21,11 +21,11 @@ class StreamBaseResponse(BaseModel):
|
||||
type: ResponseType
|
||||
timestamp: str | None = None
|
||||
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format."""
|
||||
return f"data: {self.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
class StreamTextChunk(StreamBaseResponse):
|
||||
"""Streaming text content from the assistant."""
|
||||
|
||||
@@ -44,7 +44,7 @@ class StreamToolCall(StreamBaseResponse):
|
||||
)
|
||||
|
||||
|
||||
class StreamToolResponse(StreamBaseResponse):
|
||||
class StreamToolExecutionResult(StreamBaseResponse):
|
||||
"""Tool execution result."""
|
||||
|
||||
type: ResponseType = ResponseType.TOOL_RESPONSE
|
||||
@@ -80,10 +80,11 @@ class StreamError(StreamBaseResponse):
|
||||
default=None, description="Additional error details"
|
||||
)
|
||||
|
||||
|
||||
class StreamEnd(StreamBaseResponse):
|
||||
"""End of stream marker."""
|
||||
|
||||
type: ResponseType = ResponseType.STREAM_END
|
||||
summary: dict[str, Any] | None = Field(
|
||||
default=None, description="Stream summary statistics"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,25 +1,28 @@
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from datetime import datetime, UTC
|
||||
|
||||
import orjson
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import (
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionToolParam,
|
||||
)
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
||||
|
||||
import backend.server.v2.chat.config
|
||||
from backend.server.v2.chat.data import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.server.v2.chat.models import (
|
||||
StreamTextChunk,
|
||||
StreamToolCall,
|
||||
StreamToolResponse,
|
||||
StreamError,
|
||||
StreamEnd,
|
||||
ResponseType,
|
||||
StreamBaseResponse,
|
||||
StreamEnd,
|
||||
StreamError,
|
||||
StreamTextChunk,
|
||||
StreamToolCall,
|
||||
StreamToolExecutionResult,
|
||||
)
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -32,11 +35,11 @@ async def execute_tool(
|
||||
parameters: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
) -> StreamToolResponse:
|
||||
) -> StreamToolExecutionResult:
|
||||
"""
|
||||
TODO: Implement tool execution.
|
||||
"""
|
||||
return StreamToolResponse(
|
||||
return StreamToolExecutionResult(
|
||||
type=ResponseType.TOOL_RESPONSE,
|
||||
tool_id=tool_name,
|
||||
tool_name=tool_name,
|
||||
@@ -51,7 +54,7 @@ async def stream_chat_completion(
|
||||
user_message: str,
|
||||
user_id: str | None,
|
||||
max_messages: int = 50,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Main entry point for streaming chat completions with database handling.
|
||||
|
||||
This function handles all database operations and delegates streaming
|
||||
@@ -68,21 +71,84 @@ async def stream_chat_completion(
|
||||
SSE formatted JSON strings with response data
|
||||
|
||||
"""
|
||||
# TODO: Implement this function once db operations are implemented
|
||||
async for chunk in stream_chat_response(
|
||||
messages=[],
|
||||
tools=[],
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
session = ChatSession.new(user_id)
|
||||
|
||||
assistant_repsonse = ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
)
|
||||
|
||||
has_yielded_end = False
|
||||
has_yielded_error = False
|
||||
try:
|
||||
async for chunk in stream_chat_response(
|
||||
session=session,
|
||||
tools=[],
|
||||
):
|
||||
|
||||
if isinstance(chunk, StreamTextChunk):
|
||||
assistant_repsonse.content += chunk.content
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamToolCall):
|
||||
tool_call_response = ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": chunk.tool_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": chunk.tool_name,
|
||||
"arguments": chunk.arguments,
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
session.messages.append(tool_call_response)
|
||||
elif isinstance(chunk, StreamToolExecutionResult):
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=orjson.dumps(chunk.result).decode("utf-8"),
|
||||
tool_call_id=chunk.tool_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(chunk, StreamEnd):
|
||||
has_yielded_end = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamError):
|
||||
has_yielded_error = True
|
||||
yield chunk
|
||||
else:
|
||||
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {e!s}", exc_info=True)
|
||||
if not has_yielded_error:
|
||||
error_response = StreamError(
|
||||
message=str(e),
|
||||
timestamp=datetime.now(UTC).isoformat(),
|
||||
)
|
||||
yield error_response
|
||||
if not has_yielded_end:
|
||||
yield StreamEnd(
|
||||
timestamp=datetime.now(UTC).isoformat(),
|
||||
)
|
||||
has_yielded_end = True
|
||||
|
||||
finally:
|
||||
# We always upsert the session even if an error occurs
|
||||
# So we dont lose track of tool call executions
|
||||
session.messages.append(assistant_repsonse)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
|
||||
async def stream_chat_response(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
session: ChatSession,
|
||||
tools: list[ChatCompletionToolParam],
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""
|
||||
Pure streaming function for OpenAI chat completions with tool calling.
|
||||
@@ -110,7 +176,7 @@ async def stream_chat_response(
|
||||
# Create the stream with proper types
|
||||
stream = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
messages=session.to_openai_messages(),
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
stream=True,
|
||||
@@ -183,12 +249,12 @@ async def stream_chat_response(
|
||||
timestamp=datetime.now(UTC).isoformat(),
|
||||
)
|
||||
|
||||
tool_execution_response: StreamToolResponse = (
|
||||
tool_execution_response: StreamToolExecutionResult = (
|
||||
await execute_tool(
|
||||
tool_calls[idx]["function"]["name"],
|
||||
tool_calls[idx]["function"]["arguments"],
|
||||
user_id=user_id,
|
||||
session_id=session_id or "",
|
||||
user_id=session.user_id,
|
||||
session_id=session.session_id,
|
||||
)
|
||||
)
|
||||
yield tool_execution_response
|
||||
|
||||
@@ -20,7 +20,8 @@ from functools import wraps
|
||||
from typing import Any, Callable, ParamSpec, Protocol, TypeVar, cast, runtime_checkable
|
||||
|
||||
from redis import ConnectionPool, Redis
|
||||
from redis.asyncio import Redis as AsyncRedis, ConnectionPool as AsyncConnectionPool
|
||||
from redis.asyncio import ConnectionPool as AsyncConnectionPool
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
|
||||
from backend.util.retry import conn_retry
|
||||
from backend.util.settings import Settings
|
||||
@@ -62,16 +63,79 @@ def _get_cache_pool() -> ConnectionPool:
|
||||
|
||||
redis = Redis(connection_pool=_get_cache_pool())
|
||||
|
||||
async_redis = AsyncRedis(connection_pool=AsyncConnectionPool(
|
||||
host=settings.config.redis_host,
|
||||
port=settings.config.redis_port,
|
||||
password=settings.config.redis_password or None,
|
||||
decode_responses=False, # Binary mode for pickle
|
||||
max_connections=50,
|
||||
socket_keepalive=True,
|
||||
socket_connect_timeout=5,
|
||||
retry_on_timeout=True,
|
||||
))
|
||||
_async_cache_pools: dict[asyncio.AbstractEventLoop, AsyncConnectionPool] = {}
|
||||
|
||||
|
||||
@conn_retry("Redis", "Acquiring async cache connection pool")
async def _get_async_cache_pool() -> AsyncConnectionPool:
    """Get or create the async connection pool for the current event loop.

    Pools are cached in ``_async_cache_pools`` keyed by the running event
    loop, so each loop gets its own pool.

    Returns:
        The connection pool associated with the running event loop.

    Raises:
        RuntimeError: If there is no running event loop.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # Same exception type and message as before; ``from None`` only drops
        # the redundant "during handling of the above exception" traceback.
        raise RuntimeError("No running event loop") from None

    # A closed loop can never be the running loop again, so its pool can never
    # be handed out. Drop such entries so the cache does not grow without
    # bound when many short-lived loops are used (e.g. one loop per test).
    for stale_loop in [cached for cached in _async_cache_pools if cached.is_closed()]:
        _async_cache_pools.pop(stale_loop, None)

    pool = _async_cache_pools.get(loop)
    if pool is None:
        pool = AsyncConnectionPool(
            host=settings.config.redis_host,
            port=settings.config.redis_port,
            password=settings.config.redis_password or None,
            decode_responses=False,  # Binary mode for pickle
            max_connections=50,
            socket_keepalive=True,
            socket_connect_timeout=5,
            retry_on_timeout=True,
        )
        _async_cache_pools[loop] = pool
    return pool
|
||||
|
||||
|
||||
# Store async Redis clients per event loop to avoid event loop conflicts
|
||||
_async_redis_clients: dict[asyncio.AbstractEventLoop, AsyncRedis] = {}
|
||||
_async_redis_locks: dict[asyncio.AbstractEventLoop, asyncio.Lock] = {}
|
||||
|
||||
|
||||
async def get_async_redis() -> AsyncRedis:
    """Get or create the async Redis client for the current event loop.

    Clients are cached in ``_async_redis_clients`` keyed by the running event
    loop. Creation is guarded by a per-loop ``asyncio.Lock`` because building
    the client awaits ``_get_async_cache_pool()``, during which another task
    on the same loop could start creating a second client.

    Returns:
        The Redis client associated with the running event loop.

    Raises:
        RuntimeError: If there is no running event loop.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # Same exception type and message as before; ``from None`` only drops
        # the redundant chained traceback.
        raise RuntimeError("No running event loop") from None

    # Entries for closed loops can never be used again (a closed loop cannot
    # be the running loop); evict them so the caches do not grow forever.
    stale_loops = [
        cached
        for cached in {*_async_redis_locks, *_async_redis_clients}
        if cached.is_closed()
    ]
    for stale_loop in stale_loops:
        _async_redis_locks.pop(stale_loop, None)
        _async_redis_clients.pop(stale_loop, None)

    # Fast path: a client already exists for this loop.
    client = _async_redis_clients.get(loop)
    if client is not None:
        return client

    # Get or create the lock for this event loop. No await happens between
    # the lookup and the assignment, so this cannot interleave with another
    # task on the same loop.
    lock = _async_redis_locks.get(loop)
    if lock is None:
        lock = asyncio.Lock()
        _async_redis_locks[loop] = lock

    async with lock:
        # Re-check under the lock: another awaiter may have created the
        # client while this task was waiting.
        client = _async_redis_clients.get(loop)
        if client is None:
            pool = await _get_async_cache_pool()
            client = AsyncRedis(connection_pool=pool)
            _async_redis_clients[loop] = client

    return client
|
||||
|
||||
|
||||
# For backward compatibility, create a proxy object that lazily initializes
|
||||
class AsyncRedisProxy:
    """Proxy for async Redis that lazily initializes the connection.

    Any non-dunder attribute access returns an ``async`` wrapper. When the
    wrapper is awaited it obtains the client for the running event loop via
    ``get_async_redis()`` and calls the same-named method on it with the
    given arguments.
    """

    def __getattr__(self, name):
        """Return an async wrapper that invokes ``name`` on the real client.

        Raises:
            AttributeError: For dunder names. These are protocol probes
                (``hasattr``, ``copy``, ``functools`` introspection, ...) and
                must not be answered with a Redis command wrapper.
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)

        # Local import keeps this class self-contained; only needed here.
        import inspect

        async def async_method(*args, **kwargs):
            client = await get_async_redis()
            result = getattr(client, name)(*args, **kwargs)
            # Not every client method returns an awaitable; awaiting a plain
            # object would raise TypeError, so only await when it is valid.
            if inspect.isawaitable(result):
                return await result
            return result

        # Makes tracebacks and debug output show the Redis command name.
        async_method.__name__ = name
        return async_method
|
||||
|
||||
|
||||
async_redis = AsyncRedisProxy()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user