Compare commits

..

1 Commits

Author SHA1 Message Date
Zamil Majdy
2fdc311293 Test check 2026-01-23 10:37:37 -05:00
13 changed files with 352 additions and 277 deletions

View File

@@ -5,9 +5,9 @@ from asyncio import CancelledError
from collections.abc import AsyncGenerator
from typing import Any
import openai
import orjson
from langfuse import get_client
from langfuse import get_client, propagate_attributes
from langfuse.openai import openai # type: ignore
from openai import (
APIConnectionError,
APIError,
@@ -276,301 +276,347 @@ async def stream_chat_completion(
# Build system prompt with business understanding
system_prompt, understanding = await _build_system_prompt(user_id)
# Initialize variables for streaming
assistant_response = ChatMessage(
role="assistant",
content="",
)
accumulated_tool_calls: list[dict[str, Any]] = []
has_saved_assistant_message = False
has_appended_streaming_message = False
last_cache_time = 0.0
last_cache_content_len = 0
# Create Langfuse trace for this LLM call (each call gets its own trace, grouped by session_id)
# Using v3 SDK: start_observation creates a root span, update_trace sets trace-level attributes
input = message
if not message and tool_call_response:
input = tool_call_response
has_yielded_end = False
has_yielded_error = False
has_done_tool_call = False
has_received_text = False
text_streaming_ended = False
tool_response_messages: list[ChatMessage] = []
should_retry = False
# Generate unique IDs for AI SDK protocol
import uuid as uuid_module
message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4())
# Yield message start
yield StreamStart(messageId=message_id)
try:
async for chunk in _stream_chat_chunks(
session=session,
tools=tools,
system_prompt=system_prompt,
text_block_id=text_block_id,
langfuse = get_client()
with langfuse.start_as_current_observation(
as_type="span",
name="user-copilot-request",
input=input,
) as span:
with propagate_attributes(
session_id=session_id,
user_id=user_id,
tags=["copilot"],
metadata={
"users_information": format_understanding_for_prompt(understanding)[
:200
] # langfuse only accepts upto to 200 chars
},
):
if isinstance(chunk, StreamTextStart):
# Emit text-start before first text delta
if not has_received_text:
yield chunk
elif isinstance(chunk, StreamTextDelta):
delta = chunk.delta or ""
assert assistant_response.content is not None
assistant_response.content += delta
has_received_text = True
if not has_appended_streaming_message:
session.messages.append(assistant_response)
has_appended_streaming_message = True
current_time = time.monotonic()
content_len = len(assistant_response.content)
if (
current_time - last_cache_time >= 1.0
and content_len > last_cache_content_len
# Initialize variables that will be used in finally block (must be defined before try)
assistant_response = ChatMessage(
role="assistant",
content="",
)
accumulated_tool_calls: list[dict[str, Any]] = []
has_saved_assistant_message = False
has_appended_streaming_message = False
last_cache_time = 0.0
last_cache_content_len = 0
# Wrap main logic in try/finally to ensure Langfuse observations are always ended
has_yielded_end = False
has_yielded_error = False
has_done_tool_call = False
has_received_text = False
text_streaming_ended = False
tool_response_messages: list[ChatMessage] = []
should_retry = False
# Generate unique IDs for AI SDK protocol
import uuid as uuid_module
message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4())
# Yield message start
yield StreamStart(messageId=message_id)
try:
async for chunk in _stream_chat_chunks(
session=session,
tools=tools,
system_prompt=system_prompt,
text_block_id=text_block_id,
):
if isinstance(chunk, StreamTextStart):
# Emit text-start before first text delta
if not has_received_text:
yield chunk
elif isinstance(chunk, StreamTextDelta):
delta = chunk.delta or ""
assert assistant_response.content is not None
assistant_response.content += delta
has_received_text = True
if not has_appended_streaming_message:
session.messages.append(assistant_response)
has_appended_streaming_message = True
current_time = time.monotonic()
content_len = len(assistant_response.content)
if (
current_time - last_cache_time >= 1.0
and content_len > last_cache_content_len
):
try:
await cache_chat_session(session)
except Exception as e:
logger.warning(
f"Failed to cache partial session {session.session_id}: {e}"
)
last_cache_time = current_time
last_cache_content_len = content_len
yield chunk
elif isinstance(chunk, StreamTextEnd):
# Emit text-end after text completes
if has_received_text and not text_streaming_ended:
text_streaming_ended = True
if assistant_response.content:
logger.warn(
f"StreamTextEnd: Attempting to set output {assistant_response.content}"
)
span.update_trace(output=assistant_response.content)
span.update(output=assistant_response.content)
yield chunk
elif isinstance(chunk, StreamToolInputStart):
# Emit text-end before first tool call, but only if we've received text
if has_received_text and not text_streaming_ended:
yield StreamTextEnd(id=text_block_id)
text_streaming_ended = True
yield chunk
elif isinstance(chunk, StreamToolInputAvailable):
# Accumulate tool calls in OpenAI format
accumulated_tool_calls.append(
{
"id": chunk.toolCallId,
"type": "function",
"function": {
"name": chunk.toolName,
"arguments": orjson.dumps(chunk.input).decode(
"utf-8"
),
},
}
)
elif isinstance(chunk, StreamToolOutputAvailable):
result_content = (
chunk.output
if isinstance(chunk.output, str)
else orjson.dumps(chunk.output).decode("utf-8")
)
tool_response_messages.append(
ChatMessage(
role="tool",
content=result_content,
tool_call_id=chunk.toolCallId,
)
)
has_done_tool_call = True
# Track if any tool execution failed
if not chunk.success:
logger.warning(
f"Tool {chunk.toolName} (ID: {chunk.toolCallId}) execution failed"
)
yield chunk
elif isinstance(chunk, StreamFinish):
if not has_done_tool_call:
# Emit text-end before finish if we received text but haven't closed it
if has_received_text and not text_streaming_ended:
yield StreamTextEnd(id=text_block_id)
text_streaming_ended = True
# Save assistant message before yielding finish to ensure it's persisted
# even if client disconnects immediately after receiving StreamFinish
if not has_saved_assistant_message:
messages_to_save_early: list[ChatMessage] = []
if accumulated_tool_calls:
assistant_response.tool_calls = (
accumulated_tool_calls
)
if not has_appended_streaming_message and (
assistant_response.content
or assistant_response.tool_calls
):
messages_to_save_early.append(assistant_response)
messages_to_save_early.extend(tool_response_messages)
if messages_to_save_early:
session.messages.extend(messages_to_save_early)
logger.info(
f"Saving assistant message before StreamFinish: "
f"content_len={len(assistant_response.content or '')}, "
f"tool_calls={len(assistant_response.tool_calls or [])}, "
f"tool_responses={len(tool_response_messages)}"
)
if (
messages_to_save_early
or has_appended_streaming_message
):
await upsert_chat_session(session)
has_saved_assistant_message = True
has_yielded_end = True
yield chunk
elif isinstance(chunk, StreamError):
has_yielded_error = True
yield chunk
elif isinstance(chunk, StreamUsage):
session.usage.append(
Usage(
prompt_tokens=chunk.promptTokens,
completion_tokens=chunk.completionTokens,
total_tokens=chunk.totalTokens,
)
)
else:
logger.error(
f"Unknown chunk type: {type(chunk)}", exc_info=True
)
if assistant_response.content:
langfuse.update_current_trace(output=assistant_response.content)
langfuse.update_current_span(output=assistant_response.content)
elif tool_response_messages:
langfuse.update_current_trace(output=str(tool_response_messages))
langfuse.update_current_span(output=str(tool_response_messages))
except CancelledError:
if not has_saved_assistant_message:
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
if assistant_response.content:
assistant_response.content = (
f"{assistant_response.content}\n\n[interrupted]"
)
else:
assistant_response.content = "[interrupted]"
if not has_appended_streaming_message:
session.messages.append(assistant_response)
if tool_response_messages:
session.messages.extend(tool_response_messages)
try:
await cache_chat_session(session)
await upsert_chat_session(session)
except Exception as e:
logger.warning(
f"Failed to cache partial session {session.session_id}: {e}"
f"Failed to save interrupted session {session.session_id}: {e}"
)
last_cache_time = current_time
last_cache_content_len = content_len
yield chunk
elif isinstance(chunk, StreamTextEnd):
# Emit text-end after text completes
if has_received_text and not text_streaming_ended:
text_streaming_ended = True
yield chunk
elif isinstance(chunk, StreamToolInputStart):
# Emit text-end before first tool call, but only if we've received text
if has_received_text and not text_streaming_ended:
yield StreamTextEnd(id=text_block_id)
text_streaming_ended = True
yield chunk
elif isinstance(chunk, StreamToolInputAvailable):
# Accumulate tool calls in OpenAI format
accumulated_tool_calls.append(
{
"id": chunk.toolCallId,
"type": "function",
"function": {
"name": chunk.toolName,
"arguments": orjson.dumps(chunk.input).decode("utf-8"),
},
}
)
elif isinstance(chunk, StreamToolOutputAvailable):
result_content = (
chunk.output
if isinstance(chunk.output, str)
else orjson.dumps(chunk.output).decode("utf-8")
)
tool_response_messages.append(
ChatMessage(
role="tool",
content=result_content,
tool_call_id=chunk.toolCallId,
)
)
has_done_tool_call = True
# Track if any tool execution failed
if not chunk.success:
logger.warning(
f"Tool {chunk.toolName} (ID: {chunk.toolCallId}) execution failed"
)
yield chunk
elif isinstance(chunk, StreamFinish):
if not has_done_tool_call:
# Emit text-end before finish if we received text but haven't closed it
if has_received_text and not text_streaming_ended:
yield StreamTextEnd(id=text_block_id)
text_streaming_ended = True
# Save assistant message before yielding finish to ensure it's persisted
# even if client disconnects immediately after receiving StreamFinish
if not has_saved_assistant_message:
messages_to_save_early: list[ChatMessage] = []
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_streaming_message and (
assistant_response.content or assistant_response.tool_calls
):
messages_to_save_early.append(assistant_response)
messages_to_save_early.extend(tool_response_messages)
if messages_to_save_early:
session.messages.extend(messages_to_save_early)
logger.info(
f"Saving assistant message before StreamFinish: "
f"content_len={len(assistant_response.content or '')}, "
f"tool_calls={len(assistant_response.tool_calls or [])}, "
f"tool_responses={len(tool_response_messages)}"
)
if messages_to_save_early or has_appended_streaming_message:
await upsert_chat_session(session)
has_saved_assistant_message = True
has_yielded_end = True
yield chunk
elif isinstance(chunk, StreamError):
has_yielded_error = True
yield chunk
elif isinstance(chunk, StreamUsage):
session.usage.append(
Usage(
prompt_tokens=chunk.promptTokens,
completion_tokens=chunk.completionTokens,
total_tokens=chunk.totalTokens,
)
)
else:
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
except CancelledError:
if not has_saved_assistant_message:
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
if assistant_response.content:
assistant_response.content = (
f"{assistant_response.content}\n\n[interrupted]"
)
else:
assistant_response.content = "[interrupted]"
if not has_appended_streaming_message:
session.messages.append(assistant_response)
if tool_response_messages:
session.messages.extend(tool_response_messages)
try:
await upsert_chat_session(session)
raise
except Exception as e:
logger.warning(
f"Failed to save interrupted session {session.session_id}: {e}"
logger.error(f"Error during stream: {e!s}", exc_info=True)
# Check if this is a retryable error (JSON parsing, incomplete tool calls, etc.)
is_retryable = isinstance(
e, (orjson.JSONDecodeError, KeyError, TypeError)
)
raise
except Exception as e:
logger.error(f"Error during stream: {e!s}", exc_info=True)
# Check if this is a retryable error (JSON parsing, incomplete tool calls, etc.)
is_retryable = isinstance(e, (orjson.JSONDecodeError, KeyError, TypeError))
if is_retryable and retry_count < config.max_retries:
logger.info(
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
)
should_retry = True
else:
# Non-retryable error or max retries exceeded
# Save any partial progress before reporting error
messages_to_save: list[ChatMessage] = []
if is_retryable and retry_count < config.max_retries:
logger.info(
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
)
should_retry = True
else:
# Non-retryable error or max retries exceeded
# Save any partial progress before reporting error
messages_to_save: list[ChatMessage] = []
# Add assistant message if it has content or tool calls
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_streaming_message and (
assistant_response.content or assistant_response.tool_calls
):
messages_to_save.append(assistant_response)
# Add assistant message if it has content or tool calls
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_streaming_message and (
assistant_response.content or assistant_response.tool_calls
):
messages_to_save.append(assistant_response)
# Add tool response messages after assistant message
messages_to_save.extend(tool_response_messages)
# Add tool response messages after assistant message
messages_to_save.extend(tool_response_messages)
if not has_saved_assistant_message:
if messages_to_save:
session.messages.extend(messages_to_save)
if messages_to_save or has_appended_streaming_message:
await upsert_chat_session(session)
if not has_yielded_error:
error_message = str(e)
if not is_retryable:
error_message = f"Non-retryable error: {error_message}"
elif retry_count >= config.max_retries:
error_message = f"Max retries ({config.max_retries}) exceeded: {error_message}"
error_response = StreamError(errorText=error_message)
yield error_response
if not has_yielded_end:
yield StreamFinish()
return
# Handle retry outside of exception handler to avoid nesting
if should_retry and retry_count < config.max_retries:
logger.info(
f"Retrying stream_chat_completion for session {session_id}, attempt {retry_count + 1}"
)
async for chunk in stream_chat_completion(
session_id=session.session_id,
user_id=user_id,
retry_count=retry_count + 1,
session=session,
context=context,
):
yield chunk
return # Exit after retry to avoid double-saving in finally block
# Normal completion path - save session and handle tool call continuation
# Only save if we haven't already saved when StreamFinish was received
if not has_saved_assistant_message:
logger.info(
f"Normal completion path: session={session.session_id}, "
f"current message_count={len(session.messages)}"
)
# Build the messages list in the correct order
messages_to_save: list[ChatMessage] = []
# Add assistant message with tool_calls if any
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
logger.info(
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
)
if not has_appended_streaming_message and (
assistant_response.content or assistant_response.tool_calls
):
messages_to_save.append(assistant_response)
logger.info(
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
)
# Add tool response messages after assistant message
messages_to_save.extend(tool_response_messages)
logger.info(
f"Saving {len(tool_response_messages)} tool response messages, "
f"total_to_save={len(messages_to_save)}"
)
if messages_to_save:
session.messages.extend(messages_to_save)
logger.info(
f"Extended session messages, new message_count={len(session.messages)}"
)
if messages_to_save or has_appended_streaming_message:
await upsert_chat_session(session)
else:
logger.info(
"Assistant message already saved when StreamFinish was received, "
"skipping duplicate save"
)
if not has_yielded_error:
error_message = str(e)
if not is_retryable:
error_message = f"Non-retryable error: {error_message}"
elif retry_count >= config.max_retries:
error_message = (
f"Max retries ({config.max_retries}) exceeded: {error_message}"
)
error_response = StreamError(errorText=error_message)
yield error_response
if not has_yielded_end:
yield StreamFinish()
return
# Handle retry outside of exception handler to avoid nesting
if should_retry and retry_count < config.max_retries:
logger.info(
f"Retrying stream_chat_completion for session {session_id}, attempt {retry_count + 1}"
)
async for chunk in stream_chat_completion(
session_id=session.session_id,
user_id=user_id,
retry_count=retry_count + 1,
session=session,
context=context,
):
yield chunk
return # Exit after retry to avoid double-saving in finally block
# Normal completion path - save session and handle tool call continuation
# Only save if we haven't already saved when StreamFinish was received
if not has_saved_assistant_message:
logger.info(
f"Normal completion path: session={session.session_id}, "
f"current message_count={len(session.messages)}"
)
# Build the messages list in the correct order
messages_to_save: list[ChatMessage] = []
# Add assistant message with tool_calls if any
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
logger.info(
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
)
if not has_appended_streaming_message and (
assistant_response.content or assistant_response.tool_calls
):
messages_to_save.append(assistant_response)
logger.info(
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
)
# Add tool response messages after assistant message
messages_to_save.extend(tool_response_messages)
logger.info(
f"Saving {len(tool_response_messages)} tool response messages, "
f"total_to_save={len(messages_to_save)}"
)
if messages_to_save:
session.messages.extend(messages_to_save)
logger.info(
f"Extended session messages, new message_count={len(session.messages)}"
)
if messages_to_save or has_appended_streaming_message:
await upsert_chat_session(session)
else:
logger.info(
"Assistant message already saved when StreamFinish was received, "
"skipping duplicate save"
)
# If we did a tool call, stream the chat completion again to get the next response
if has_done_tool_call:
logger.info(
"Tool call executed, streaming chat completion again to get assistant response"
)
async for chunk in stream_chat_completion(
session_id=session.session_id,
user_id=user_id,
session=session, # Pass session object to avoid Redis refetch
context=context,
tool_call_response=str(tool_response_messages),
):
yield chunk
# If we did a tool call, stream the chat completion again to get the next response
if has_done_tool_call:
logger.info(
"Tool call executed, streaming chat completion again to get assistant response"
)
async for chunk in stream_chat_completion(
session_id=session.session_id,
user_id=user_id,
session=session, # Pass session object to avoid Redis refetch
context=context,
tool_call_response=str(tool_response_messages),
):
yield chunk
# Retry configuration for OpenAI API calls

View File

@@ -3,6 +3,8 @@
import logging
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from backend.data.understanding import (
BusinessUnderstandingInput,
@@ -59,6 +61,7 @@ and automations for the user's specific needs."""
"""Requires authentication to store user-specific data."""
return True
@observe(as_type="tool", name="add_understanding")
async def _execute(
self,
user_id: str | None,

View File

@@ -5,6 +5,7 @@ import re
from datetime import datetime, timedelta, timezone
from typing import Any
from langfuse import observe
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
@@ -328,6 +329,7 @@ class AgentOutputTool(BaseTool):
total_executions=len(available_executions) if available_executions else 1,
)
@observe(as_type="tool", name="view_agent_output")
async def _execute(
self,
user_id: str | None,

View File

@@ -3,6 +3,8 @@
import logging
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
@@ -78,6 +80,7 @@ class CreateAgentTool(BaseTool):
"required": ["description"],
}
@observe(as_type="tool", name="create_agent")
async def _execute(
self,
user_id: str | None,

View File

@@ -3,6 +3,8 @@
import logging
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
@@ -85,6 +87,7 @@ class EditAgentTool(BaseTool):
"required": ["agent_id", "changes"],
}
@observe(as_type="tool", name="edit_agent")
async def _execute(
self,
user_id: str | None,

View File

@@ -2,6 +2,8 @@
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
@@ -35,6 +37,7 @@ class FindAgentTool(BaseTool):
"required": ["query"],
}
@observe(as_type="tool", name="find_agent")
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:

View File

@@ -1,6 +1,7 @@
import logging
from typing import Any
from langfuse import observe
from prisma.enums import ContentType
from backend.api.features.chat.model import ChatSession
@@ -55,6 +56,7 @@ class FindBlockTool(BaseTool):
def requires_auth(self) -> bool:
return True
@observe(as_type="tool", name="find_block")
async def _execute(
self,
user_id: str | None,

View File

@@ -2,6 +2,8 @@
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
@@ -41,6 +43,7 @@ class FindLibraryAgentTool(BaseTool):
def requires_auth(self) -> bool:
return True
@observe(as_type="tool", name="find_library_agent")
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:

View File

@@ -4,6 +4,8 @@ import logging
from pathlib import Path
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
@@ -71,6 +73,7 @@ class GetDocPageTool(BaseTool):
url_path = path.rsplit(".", 1)[0] if "." in path else path
return f"{DOCS_BASE_URL}/{url_path}"
@observe(as_type="tool", name="get_doc_page")
async def _execute(
self,
user_id: str | None,

View File

@@ -3,6 +3,7 @@
import logging
from typing import Any
from langfuse import observe
from pydantic import BaseModel, Field, field_validator
from backend.api.features.chat.config import ChatConfig
@@ -154,6 +155,7 @@ class RunAgentTool(BaseTool):
"""All operations require authentication."""
return True
@observe(as_type="tool", name="run_agent")
async def _execute(
self,
user_id: str | None,

View File

@@ -4,6 +4,8 @@ import logging
from collections import defaultdict
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
@@ -128,6 +130,7 @@ class RunBlockTool(BaseTool):
return matched_credentials, missing_credentials
@observe(as_type="tool", name="run_block")
async def _execute(
self,
user_id: str | None,

View File

@@ -3,6 +3,7 @@
import logging
from typing import Any
from langfuse import observe
from prisma.enums import ContentType
from backend.api.features.chat.model import ChatSession
@@ -87,6 +88,7 @@ class SearchDocsTool(BaseTool):
url_path = path.rsplit(".", 1)[0] if "." in path else path
return f"{DOCS_BASE_URL}/{url_path}"
@observe(as_type="tool", name="search_docs")
async def _execute(
self,
user_id: str | None,

View File

@@ -165,7 +165,7 @@ async def client(server, test_user: str) -> AsyncGenerator[httpx.AsyncClient, No
@pytest.mark.asyncio(loop_scope="session")
async def test_authorize_creates_code_in_database(
async def test_authorize_creates_code_in_database_test(
client: httpx.AsyncClient,
test_user: str,
test_oauth_app: dict,