update service to process chunks

This commit is contained in:
Swifty
2025-10-24 11:26:20 +02:00
parent 821d4cde21
commit b26b0cd23e
6 changed files with 252 additions and 111 deletions

View File

@@ -22,7 +22,7 @@ class ChatConfig(BaseSettings):
# Session TTL Configuration - 12 hours
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# System Prompt Configuration
system_prompt_path: str = Field(
default="prompts/chat_system.md",
@@ -33,7 +33,7 @@ class ChatConfig(BaseSettings):
max_context_messages: int = Field(
default=50, ge=1, le=200, description="Maximum context messages"
)
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
@field_validator("api_key", mode="before")
@@ -110,4 +110,4 @@ class ChatConfig(BaseSettings):
env_file = ".env"
env_file_encoding = "utf-8"
extra = "ignore" # Ignore extra environment variables
extra = "ignore" # Ignore extra environment variables

View File

@@ -1,24 +1,25 @@
import json
from pydantic import BaseModel
from datetime import datetime, UTC
from backend.util.cache import async_redis
from backend.server.v2.chat.config import ChatConfig
from backend.util.exceptions import NotFoundError
import uuid
from datetime import UTC, datetime
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionDeveloperMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
ChatCompletionAssistantMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionDeveloperMessageParam,
ChatCompletionFunctionMessageParam,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionUserMessageParam,
)
from openai.types.chat.chat_completion_message_tool_call_param import Function, ChatCompletionMessageToolCallParam
from openai.types.chat.chat_completion_assistant_message_param import FunctionCall
import uuid
from openai.types.chat.chat_completion_message_tool_call_param import (
ChatCompletionMessageToolCallParam,
Function,
)
from pydantic import BaseModel
from backend.server.v2.chat.config import ChatConfig
from backend.util.cache import async_redis
config = ChatConfig()
@@ -38,9 +39,9 @@ class ChatSession(BaseModel):
messages: list[ChatMessage]
started_at: datetime
updated_at: datetime
@staticmethod
def new(user_id: str | None) -> 'ChatSession':
def new(user_id: str | None) -> "ChatSession":
return ChatSession(
session_id=str(uuid.uuid4()),
user_id=user_id,
@@ -83,22 +84,24 @@ class ChatSession(BaseModel):
)
if message.function_call:
m["function_call"] = FunctionCall(
arguments=message.function_call["arguments"],
name=message.function_call["name"],
)
arguments=message.function_call["arguments"],
name=message.function_call["name"],
)
if message.refusal:
m["refusal"] = message.refusal
if message.tool_calls:
t: list[ChatCompletionMessageToolCallParam] = []
for tool_call in message.tool_calls:
t.append(ChatCompletionMessageToolCallParam(
id=tool_call["id"],
type="function",
function=Function(
arguments=tool_call["arguments"],
name=tool_call["name"],
),
))
t.append(
ChatCompletionMessageToolCallParam(
id=tool_call["id"],
type="function",
function=Function(
arguments=tool_call["arguments"],
name=tool_call["name"],
),
)
)
m["tool_calls"] = t
if message.name:
m["name"] = message.name
@@ -112,11 +115,13 @@ class ChatSession(BaseModel):
)
)
elif message.role == "function":
messages.append(ChatCompletionFunctionMessageParam(
role="function",
content=message.content,
name=message.name or "",
))
messages.append(
ChatCompletionFunctionMessageParam(
role="function",
content=message.content,
name=message.name or "",
)
)
return messages
@@ -151,7 +156,7 @@ async def upsert_chat_session(
redis_key, config.session_ttl, session.model_dump_json()
)
if resp != True:
if not resp:
raise Exception(f"Failed to update chat session: {resp}")
return session

View File

@@ -1,30 +1,34 @@
from backend.server.v2.chat.data import ChatMessage, ChatSession, get_chat_session, upsert_chat_session
import pytest
from datetime import datetime, UTC
from backend.server.v2.chat.data import (
ChatMessage,
ChatSession,
get_chat_session,
upsert_chat_session,
)
# Shared fixture: a three-message conversation exercising the user,
# assistant (with a function tool call), and tool roles so that the
# serialization round-trip covers tool_calls and tool_call_id handling.
messages = [
    ChatMessage(content="Hello, how are you?", role="user"),
    ChatMessage(
        content="I'm fine, thank you!",
        role="assistant",
        tool_calls=[
            {
                "id": "t123",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": '{"city": "New York"}',
                },
            }
        ],
    ),
    ChatMessage(
        content="I'm using the tool to get the weather",
        role="tool",
        # Must match the assistant tool call id above so the OpenAI
        # message sequence stays consistent.
        tool_call_id="t123",
    ),
]
@pytest.mark.asyncio
@@ -38,26 +42,27 @@ async def test_chatsession_serialization_deserialization():
@pytest.mark.asyncio
async def test_chatsession_redis_storage():
    """Round-trip a session through Redis and verify it comes back unchanged."""
    session = ChatSession.new(user_id=None)
    session.messages = messages
    stored = await upsert_chat_session(session)
    loaded = await get_chat_session(
        session_id=stored.session_id,
        user_id=stored.user_id,
    )
    assert loaded == stored
@pytest.mark.asyncio
async def test_chatsession_redis_storage_user_id_mismatch():
    """Fetching a stored session with a mismatched user_id must return None.

    Guards against one user reading another user's chat session.
    """
    s = ChatSession.new(user_id="abc123")
    s.messages = messages
    s = await upsert_chat_session(s)
    # Request with user_id=None while the session belongs to "abc123".
    s2 = await get_chat_session(s.session_id, None)
    assert s2 is None

View File

@@ -21,11 +21,11 @@ class StreamBaseResponse(BaseModel):
type: ResponseType
timestamp: str | None = None
def to_sse(self) -> str:
    """Serialize this response as a Server-Sent Events `data:` frame."""
    payload = self.model_dump_json()
    return "data: " + payload + "\n\n"
class StreamTextChunk(StreamBaseResponse):
"""Streaming text content from the assistant."""
@@ -44,7 +44,7 @@ class StreamToolCall(StreamBaseResponse):
)
class StreamToolResponse(StreamBaseResponse):
class StreamToolExecutionResult(StreamBaseResponse):
"""Tool execution result."""
type: ResponseType = ResponseType.TOOL_RESPONSE
@@ -80,10 +80,11 @@ class StreamError(StreamBaseResponse):
default=None, description="Additional error details"
)
class StreamEnd(StreamBaseResponse):
    """End of stream marker."""

    type: ResponseType = ResponseType.STREAM_END
    # Optional aggregate statistics for the finished stream; None when the
    # producer emits no summary.
    summary: dict[str, Any] | None = Field(
        default=None, description="Stream summary statistics"
    )

View File

@@ -1,25 +1,28 @@
import logging
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from typing import Any
from datetime import datetime, UTC
import orjson
from openai import AsyncOpenAI
from openai.types.chat import (
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionToolParam,
)
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
import backend.server.v2.chat.config
from backend.server.v2.chat.data import (
ChatMessage,
ChatSession,
get_chat_session,
upsert_chat_session,
)
from backend.server.v2.chat.models import (
StreamTextChunk,
StreamToolCall,
StreamToolResponse,
StreamError,
StreamEnd,
ResponseType,
StreamBaseResponse,
StreamEnd,
StreamError,
StreamTextChunk,
StreamToolCall,
StreamToolExecutionResult,
)
import logging
logger = logging.getLogger(__name__)
@@ -32,11 +35,11 @@ async def execute_tool(
parameters: dict[str, Any],
user_id: str | None,
session_id: str,
) -> StreamToolResponse:
) -> StreamToolExecutionResult:
"""
TODO: Implement tool execution.
"""
return StreamToolResponse(
return StreamToolExecutionResult(
type=ResponseType.TOOL_RESPONSE,
tool_id=tool_name,
tool_name=tool_name,
@@ -51,7 +54,7 @@ async def stream_chat_completion(
user_message: str,
user_id: str | None,
max_messages: int = 50,
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Main entry point for streaming chat completions with database handling.
This function handles all database operations and delegates streaming
@@ -68,21 +71,84 @@ async def stream_chat_completion(
SSE formatted JSON strings with response data
"""
# TODO: Implement this function once db operations are implemented
async for chunk in stream_chat_response(
messages=[],
tools=[],
session_id=session_id,
user_id=user_id,
):
yield chunk.to_sse()
session = await get_chat_session(session_id, user_id)
if not session:
session = ChatSession.new(user_id)
assistant_repsonse = ChatMessage(
role="assistant",
content="",
)
has_yielded_end = False
has_yielded_error = False
try:
async for chunk in stream_chat_response(
session=session,
tools=[],
):
if isinstance(chunk, StreamTextChunk):
assistant_repsonse.content += chunk.content
yield chunk
elif isinstance(chunk, StreamToolCall):
tool_call_response = ChatMessage(
role="assistant",
content="",
tool_calls=[
{
"id": chunk.tool_id,
"type": "function",
"function": {
"name": chunk.tool_name,
"arguments": chunk.arguments,
},
}
],
)
session.messages.append(tool_call_response)
elif isinstance(chunk, StreamToolExecutionResult):
session.messages.append(
ChatMessage(
role="tool",
content=orjson.dumps(chunk.result).decode("utf-8"),
tool_call_id=chunk.tool_id,
)
)
elif isinstance(chunk, StreamEnd):
has_yielded_end = True
yield chunk
elif isinstance(chunk, StreamError):
has_yielded_error = True
yield chunk
else:
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
except Exception as e:
logger.error(f"Error in stream: {e!s}", exc_info=True)
if not has_yielded_error:
error_response = StreamError(
message=str(e),
timestamp=datetime.now(UTC).isoformat(),
)
yield error_response
if not has_yielded_end:
yield StreamEnd(
timestamp=datetime.now(UTC).isoformat(),
)
has_yielded_end = True
finally:
# We always upsert the session even if an error occurs
# So we dont lose track of tool call executions
session.messages.append(assistant_repsonse)
await upsert_chat_session(session)
async def stream_chat_response(
messages: list[ChatCompletionMessageParam],
session: ChatSession,
tools: list[ChatCompletionToolParam],
session_id: str,
user_id: str | None,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""
Pure streaming function for OpenAI chat completions with tool calling.
@@ -110,7 +176,7 @@ async def stream_chat_response(
# Create the stream with proper types
stream = await client.chat.completions.create(
model=model,
messages=messages,
messages=session.to_openai_messages(),
tools=tools,
tool_choice="auto",
stream=True,
@@ -183,12 +249,12 @@ async def stream_chat_response(
timestamp=datetime.now(UTC).isoformat(),
)
tool_execution_response: StreamToolResponse = (
tool_execution_response: StreamToolExecutionResult = (
await execute_tool(
tool_calls[idx]["function"]["name"],
tool_calls[idx]["function"]["arguments"],
user_id=user_id,
session_id=session_id or "",
user_id=session.user_id,
session_id=session.session_id,
)
)
yield tool_execution_response

View File

@@ -20,7 +20,8 @@ from functools import wraps
from typing import Any, Callable, ParamSpec, Protocol, TypeVar, cast, runtime_checkable
from redis import ConnectionPool, Redis
from redis.asyncio import Redis as AsyncRedis, ConnectionPool as AsyncConnectionPool
from redis.asyncio import ConnectionPool as AsyncConnectionPool
from redis.asyncio import Redis as AsyncRedis
from backend.util.retry import conn_retry
from backend.util.settings import Settings
@@ -62,16 +63,79 @@ def _get_cache_pool() -> ConnectionPool:
redis = Redis(connection_pool=_get_cache_pool())
async_redis = AsyncRedis(connection_pool=AsyncConnectionPool(
host=settings.config.redis_host,
port=settings.config.redis_port,
password=settings.config.redis_password or None,
decode_responses=False, # Binary mode for pickle
max_connections=50,
socket_keepalive=True,
socket_connect_timeout=5,
retry_on_timeout=True,
))
_async_cache_pools: dict[asyncio.AbstractEventLoop, AsyncConnectionPool] = {}
@conn_retry("Redis", "Acquiring async cache connection pool")
async def _get_async_cache_pool() -> AsyncConnectionPool:
    """Get or create the async Redis connection pool for the current event loop.

    Pools are cached per event loop so connections created on one loop are
    never reused from another (asyncio primitives are loop-bound).

    Returns:
        The AsyncConnectionPool bound to the running event loop.

    Raises:
        RuntimeError: If called outside a running event loop (raised by
            asyncio.get_running_loop itself).
    """
    # No try/except needed: get_running_loop() already raises RuntimeError
    # with a clear message when there is no running loop.
    loop = asyncio.get_running_loop()

    pool = _async_cache_pools.get(loop)
    if pool is None:
        # There is no await between the lookup above and this assignment,
        # so the check-then-set is atomic within a single event loop.
        pool = AsyncConnectionPool(
            host=settings.config.redis_host,
            port=settings.config.redis_port,
            password=settings.config.redis_password or None,
            decode_responses=False,  # Binary mode for pickle
            max_connections=50,
            socket_keepalive=True,
            socket_connect_timeout=5,
            retry_on_timeout=True,
        )
        _async_cache_pools[loop] = pool
    return pool
# Store async Redis clients per event loop to avoid event loop conflicts
_async_redis_clients: dict[asyncio.AbstractEventLoop, AsyncRedis] = {}
_async_redis_locks: dict[asyncio.AbstractEventLoop, asyncio.Lock] = {}
async def get_async_redis() -> AsyncRedis:
    """Get or create an async Redis client for the current event loop.

    Clients are cached per event loop to avoid cross-loop connection reuse.
    Creation is guarded by a per-loop lock with double-checked locking so
    concurrent awaiters share a single client.

    Returns:
        The AsyncRedis client bound to the running event loop.

    Raises:
        RuntimeError: If called outside a running event loop (raised by
            asyncio.get_running_loop itself).
    """
    # No try/except needed: get_running_loop() already raises RuntimeError
    # when there is no running loop.
    loop = asyncio.get_running_loop()

    if loop not in _async_redis_clients:
        # Lock creation itself needs no guard: there is no await between
        # the membership test inside setdefault and the assignment.
        lock = _async_redis_locks.setdefault(loop, asyncio.Lock())
        async with lock:
            # Double-checked locking: another coroutine may have created
            # the client while we were waiting on the lock.
            if loop not in _async_redis_clients:
                pool = await _get_async_cache_pool()
                _async_redis_clients[loop] = AsyncRedis(connection_pool=pool)
    return _async_redis_clients[loop]
# For backward compatibility, create a proxy object that lazily initializes
class AsyncRedisProxy:
    """Proxy for async Redis that lazily initializes the connection.

    Every attribute access returns an async wrapper; awaiting the wrapper
    resolves the real client for the current event loop and forwards the
    call with the given arguments.
    """

    def __getattr__(self, name):
        # Called for any attribute not found on the proxy itself, so every
        # client method goes through this lazy path.
        import inspect

        async def async_method(*args, **kwargs):
            client = await get_async_redis()
            method = getattr(client, name)
            result = method(*args, **kwargs)
            # Some client methods are synchronous (e.g. pipeline() returns
            # a Pipeline object directly) — only await when the call
            # actually produced an awaitable, otherwise `await` would
            # raise TypeError.
            if inspect.isawaitable(result):
                return await result
            return result

        return async_method


async_redis = AsyncRedisProxy()
@dataclass