Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/rate-limit-tiering

This commit is contained in:
Zamil Majdy
2026-03-31 15:17:31 +02:00
38 changed files with 4665 additions and 392 deletions

View File

@@ -83,13 +83,13 @@ The AutoGPT frontend is where users interact with our powerful AI automation pla
**Agent Builder:** For those who want to customize, our intuitive, low-code interface allows you to design and configure your own AI agents.
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
**Deployment Controls:** Manage the lifecycle of your agents, from testing to production.
**Ready-to-Use Agents:** Don't want to build? Simply select from our library of pre-configured agents and put them to work immediately.
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
**Monitoring and Analytics:** Keep track of your agents' performance and gain insights to continually improve your automation processes.

View File

@@ -1,5 +1,6 @@
import asyncio
import base64
import re
from abc import ABC
from email import encoders
from email.mime.base import MIMEBase
@@ -8,7 +9,7 @@ from email.mime.text import MIMEText
from email.policy import SMTP
from email.utils import getaddresses, parseaddr
from pathlib import Path
from typing import List, Literal, Optional
from typing import List, Literal, Optional, Protocol, runtime_checkable
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
@@ -42,8 +43,52 @@ NO_WRAP_POLICY = SMTP.clone(max_line_length=0)
def serialize_email_recipients(recipients: list[str]) -> str:
"""Serialize recipients list to comma-separated string."""
return ", ".join(recipients)
"""Serialize recipients list to comma-separated string.
Strips leading/trailing whitespace from each address to keep MIME
headers clean (mirrors the strip done in ``validate_email_recipients``).
"""
return ", ".join(addr.strip() for addr in recipients)
# RFC 5322 simplified pattern: local@domain where domain has at least one dot
_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")


def validate_email_recipients(recipients: list[str], field_name: str = "to") -> None:
    """Validate that every recipient looks like a plausible email address.

    All offending entries are collected and reported together in a single
    ``ValueError`` so the caller (or LLM) can correct them in one pass.
    Addresses are stripped of surrounding whitespace before matching.
    """
    bad_entries = [a for a in recipients if _EMAIL_RE.match(a.strip()) is None]
    if not bad_entries:
        return
    listing = ", ".join(f"'{entry}'" for entry in bad_entries)
    raise ValueError(
        f"Invalid email address(es) in '{field_name}': {listing}. "
        f"Each entry must be a valid email address (e.g. user@example.com)."
    )
@runtime_checkable
class HasRecipients(Protocol):
    """Structural type for block inputs that carry to/cc/bcc recipient lists.

    Used by ``validate_all_recipients`` so any input namespace exposing
    these three list attributes can be validated without inheritance.
    """

    # Recipient address lists; callers treat `to` as required, cc/bcc as optional.
    to: list[str]
    cc: list[str]
    bcc: list[str]
def validate_all_recipients(input_data: HasRecipients) -> None:
    """Validate the ``to``, ``cc`` and ``bcc`` recipient fields.

    ``to`` is always validated; ``cc`` and ``bcc`` only when non-empty.
    Raises ``ValueError`` on the first field containing an invalid address.
    """
    for label in ("to", "cc", "bcc"):
        recipients = getattr(input_data, label)
        # `to` is checked even when empty; cc/bcc are skipped when empty.
        if label == "to" or recipients:
            validate_email_recipients(recipients, label)
def _make_mime_text(
@@ -100,14 +145,16 @@ async def create_mime_message(
) -> str:
"""Create a MIME message with attachments and return base64-encoded raw message."""
validate_all_recipients(input_data)
message = MIMEMultipart()
message["to"] = serialize_email_recipients(input_data.to)
message["subject"] = input_data.subject
if input_data.cc:
message["cc"] = ", ".join(input_data.cc)
message["cc"] = serialize_email_recipients(input_data.cc)
if input_data.bcc:
message["bcc"] = ", ".join(input_data.bcc)
message["bcc"] = serialize_email_recipients(input_data.bcc)
# Use the new helper function with content_type if available
content_type = getattr(input_data, "content_type", None)
@@ -1167,13 +1214,15 @@ async def _build_reply_message(
references.append(headers["message-id"])
# Create MIME message
validate_all_recipients(input_data)
msg = MIMEMultipart()
if input_data.to:
msg["To"] = ", ".join(input_data.to)
msg["To"] = serialize_email_recipients(input_data.to)
if input_data.cc:
msg["Cc"] = ", ".join(input_data.cc)
msg["Cc"] = serialize_email_recipients(input_data.cc)
if input_data.bcc:
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
msg["Subject"] = subject
if headers.get("message-id"):
msg["In-Reply-To"] = headers["message-id"]
@@ -1685,13 +1734,16 @@ To: {original_to}
else:
body = f"{forward_header}\n\n{original_body}"
# Validate all recipient lists before building the MIME message
validate_all_recipients(input_data)
# Create MIME message
msg = MIMEMultipart()
msg["To"] = ", ".join(input_data.to)
msg["To"] = serialize_email_recipients(input_data.to)
if input_data.cc:
msg["Cc"] = ", ".join(input_data.cc)
msg["Cc"] = serialize_email_recipients(input_data.cc)
if input_data.bcc:
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
msg["Subject"] = subject
# Add body with proper content type

View File

@@ -724,6 +724,9 @@ def convert_openai_tool_fmt_to_anthropic(
def extract_openai_reasoning(response) -> str | None:
"""Extract reasoning from OpenAI-compatible response if available."""
"""Note: This will likely not working since the reasoning is not present in another Response API"""
if not response.choices:
logger.warning("LLM response has empty choices in extract_openai_reasoning")
return None
reasoning = None
choice = response.choices[0]
if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
@@ -739,6 +742,9 @@ def extract_openai_reasoning(response) -> str | None:
def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
"""Extract tool calls from OpenAI-compatible response."""
if not response.choices:
logger.warning("LLM response has empty choices in extract_openai_tool_calls")
return None
if response.choices[0].message.tool_calls:
return [
ToolContentBlock(
@@ -972,6 +978,8 @@ async def llm_call(
response_format=response_format, # type: ignore
max_tokens=max_tokens,
)
if not response.choices:
raise ValueError("Groq returned empty choices in response")
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
@@ -1031,12 +1039,8 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
# If there's no response, raise an error
if not response.choices:
if response:
raise ValueError(f"OpenRouter error: {response}")
else:
raise ValueError("No response from OpenRouter.")
raise ValueError(f"OpenRouter returned empty choices: {response}")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -1073,12 +1077,8 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
# If there's no response, raise an error
if not response.choices:
if response:
raise ValueError(f"Llama API error: {response}")
else:
raise ValueError("No response from Llama API.")
raise ValueError(f"Llama API returned empty choices: {response}")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -1108,6 +1108,8 @@ async def llm_call(
messages=prompt, # type: ignore
max_tokens=max_tokens,
)
if not completion.choices:
raise ValueError("AI/ML API returned empty choices in response")
return LLMResponse(
raw_response=completion.choices[0].message,
@@ -1144,6 +1146,9 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
if not response.choices:
raise ValueError(f"v0 API returned empty choices: {response}")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,87 @@
"""Tests for empty-choices guard in extract_openai_tool_calls() and extract_openai_reasoning()."""
from unittest.mock import MagicMock
from backend.blocks.llm import extract_openai_reasoning, extract_openai_tool_calls
class TestExtractOpenaiToolCallsEmptyChoices:
    """extract_openai_tool_calls() must return None when choices is empty."""

    @staticmethod
    def _response_with(choices):
        # Build a minimal response stub carrying only the `choices` attribute.
        resp = MagicMock()
        resp.choices = choices
        return resp

    def test_returns_none_for_empty_choices(self):
        assert extract_openai_tool_calls(self._response_with([])) is None

    def test_returns_none_for_none_choices(self):
        assert extract_openai_tool_calls(self._response_with(None)) is None

    def test_returns_tool_calls_when_choices_present(self):
        call = MagicMock()
        call.id = "call_1"
        call.type = "function"
        call.function.name = "my_func"
        call.function.arguments = '{"a": 1}'
        msg = MagicMock()
        msg.tool_calls = [call]
        wrapper = MagicMock()
        wrapper.message = msg
        parsed = extract_openai_tool_calls(self._response_with([wrapper]))
        assert parsed is not None
        assert len(parsed) == 1
        assert parsed[0].function.name == "my_func"

    def test_returns_none_when_no_tool_calls(self):
        msg = MagicMock()
        msg.tool_calls = None
        wrapper = MagicMock()
        wrapper.message = msg
        assert extract_openai_tool_calls(self._response_with([wrapper])) is None
class TestExtractOpenaiReasoningEmptyChoices:
    """extract_openai_reasoning() must return None when choices is empty."""

    def test_returns_none_for_empty_choices(self):
        resp = MagicMock()
        resp.choices = []
        assert extract_openai_reasoning(resp) is None

    def test_returns_none_for_none_choices(self):
        resp = MagicMock()
        resp.choices = None
        assert extract_openai_reasoning(resp) is None

    def test_returns_reasoning_from_choice(self):
        # spec=[] blocks auto-created attrs, so only the choice carries reasoning.
        first_choice = MagicMock()
        first_choice.reasoning = "Step-by-step reasoning"
        first_choice.message = MagicMock(spec=[])  # no 'reasoning' attr on message
        resp = MagicMock(spec=[])  # no 'reasoning' attr on response
        resp.choices = [first_choice]
        assert extract_openai_reasoning(resp) == "Step-by-step reasoning"

    def test_returns_none_when_no_reasoning(self):
        bare_choice = MagicMock(spec=[])  # no 'reasoning' attr
        bare_choice.message = MagicMock(spec=[])  # no 'reasoning' attr
        resp = MagicMock(spec=[])  # no 'reasoning' attr
        resp.choices = [bare_choice]
        assert extract_openai_reasoning(resp) is None

View File

@@ -1074,6 +1074,7 @@ async def test_orchestrator_uses_customized_name_for_blocks():
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
mock_node.block = StoreValueBlock()
mock_node.input_default = {}
# Create a mock link
mock_link = MagicMock(spec=Link)
@@ -1105,6 +1106,7 @@ async def test_orchestrator_falls_back_to_block_name():
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {} # No customized_name
mock_node.block = StoreValueBlock()
mock_node.input_default = {}
# Create a mock link
mock_link = MagicMock(spec=Link)

View File

@@ -0,0 +1,202 @@
"""Tests for ExecutionMode enum and provider validation in the orchestrator.
Covers:
- ExecutionMode enum members exist and have stable values
- EXTENDED_THINKING provider validation (anthropic/open_router allowed, others rejected)
- EXTENDED_THINKING model-name validation (must start with "claude")
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks.llm import LlmModel
from backend.blocks.orchestrator import ExecutionMode, OrchestratorBlock
# ---------------------------------------------------------------------------
# ExecutionMode enum integrity
# ---------------------------------------------------------------------------
class TestExecutionModeEnum:
    """Guard against accidental renames or removals of enum members."""

    def test_built_in_exists(self):
        member = getattr(ExecutionMode, "BUILT_IN", None)
        assert member is not None
        assert member.value == "built_in"

    def test_extended_thinking_exists(self):
        member = getattr(ExecutionMode, "EXTENDED_THINKING", None)
        assert member is not None
        assert member.value == "extended_thinking"

    def test_exactly_two_members(self):
        """If a new mode is added, this test should be updated intentionally."""
        assert set(ExecutionMode.__members__) == {"BUILT_IN", "EXTENDED_THINKING"}

    def test_string_enum(self):
        """ExecutionMode is a str enum so it serialises cleanly to JSON."""
        for member in (ExecutionMode.BUILT_IN, ExecutionMode.EXTENDED_THINKING):
            assert isinstance(member, str)

    def test_round_trip_from_value(self):
        """Constructing from the string value should return the same member."""
        for member in (ExecutionMode.BUILT_IN, ExecutionMode.EXTENDED_THINKING):
            assert ExecutionMode(member.value) is member
# ---------------------------------------------------------------------------
# Provider validation (inline in OrchestratorBlock.run)
# ---------------------------------------------------------------------------
def _make_model_stub(provider: str, value: str):
"""Create a lightweight stub that behaves like LlmModel for validation."""
metadata = MagicMock()
metadata.provider = provider
stub = MagicMock()
stub.metadata = metadata
stub.value = value
return stub
class TestExtendedThinkingProviderValidation:
    """The orchestrator rejects EXTENDED_THINKING for non-Anthropic providers."""

    # Providers that may serve Claude models with extended thinking.
    _ALLOWED = ("anthropic", "open_router")

    def test_anthropic_provider_accepted(self):
        """provider='anthropic' + claude model should not raise."""
        stub = _make_model_stub("anthropic", "claude-opus-4-6")
        assert stub.metadata.provider in self._ALLOWED
        assert stub.value.startswith("claude")

    def test_open_router_provider_accepted(self):
        """provider='open_router' + claude model should not raise."""
        stub = _make_model_stub("open_router", "claude-sonnet-4-6")
        assert stub.metadata.provider in self._ALLOWED
        assert stub.value.startswith("claude")

    def test_openai_provider_rejected(self):
        """provider='openai' should be rejected for EXTENDED_THINKING."""
        stub = _make_model_stub("openai", "gpt-4o")
        assert stub.metadata.provider not in self._ALLOWED

    def test_groq_provider_rejected(self):
        stub = _make_model_stub("groq", "llama-3.3-70b-versatile")
        assert stub.metadata.provider not in self._ALLOWED

    def test_non_claude_model_rejected_even_if_anthropic_provider(self):
        """A hypothetical non-Claude model with provider='anthropic' is rejected."""
        stub = _make_model_stub("anthropic", "not-a-claude-model")
        assert not stub.value.startswith("claude")

    def test_real_gpt4o_model_rejected(self):
        """Verify a real LlmModel enum member (GPT4O) fails the provider check."""
        assert LlmModel.GPT4O.metadata.provider not in self._ALLOWED

    def test_real_claude_model_passes(self):
        """Verify a real LlmModel enum member (CLAUDE_4_6_SONNET) passes."""
        model = LlmModel.CLAUDE_4_6_SONNET
        assert model.metadata.provider in self._ALLOWED
        assert model.value.startswith("claude")
# ---------------------------------------------------------------------------
# Integration-style: exercise the validation branch via OrchestratorBlock.run
# ---------------------------------------------------------------------------
def _make_input_data(model, execution_mode=ExecutionMode.EXTENDED_THINKING):
    """Build a minimal MagicMock that satisfies OrchestratorBlock.run's early path."""
    data = MagicMock()
    data.configure_mock(
        execution_mode=execution_mode,
        model=model,
        prompt="test",
        sys_prompt="",
        conversation_history=[],
        last_tool_output=None,
        prompt_values={},
    )
    return data
async def _collect_run_outputs(block, input_data, **kwargs):
"""Exhaust the OrchestratorBlock.run async generator, collecting outputs."""
outputs = []
async for item in block.run(input_data, **kwargs):
outputs.append(item)
return outputs
class TestExtendedThinkingValidationRaisesInBlock:
    """Call OrchestratorBlock.run far enough to trigger the ValueError.

    Integration-style: exercises the real validation branch inside
    ``OrchestratorBlock.run`` rather than re-asserting the check inline.
    """

    @pytest.mark.asyncio
    async def test_non_anthropic_provider_raises_valueerror(self):
        """EXTENDED_THINKING + openai provider raises ValueError."""
        block = OrchestratorBlock()
        input_data = _make_input_data(model=LlmModel.GPT4O)
        with (
            # Patch tool-signature discovery so run() reaches the provider
            # check without needing a real graph/node setup.
            patch.object(
                block,
                "_create_tool_node_signatures",
                new_callable=AsyncMock,
                return_value=[],
            ),
            pytest.raises(ValueError, match="Anthropic-compatible"),
        ):
            await _collect_run_outputs(
                block,
                input_data,
                credentials=MagicMock(),
                graph_id="g",
                node_id="n",
                graph_exec_id="ge",
                node_exec_id="ne",
                user_id="u",
                graph_version=1,
                execution_context=MagicMock(),
                execution_processor=MagicMock(),
            )

    @pytest.mark.asyncio
    async def test_non_claude_model_with_anthropic_provider_raises(self):
        """A model with anthropic provider but non-claude name raises ValueError."""
        block = OrchestratorBlock()
        fake_model = _make_model_stub("anthropic", "not-a-claude-model")
        input_data = _make_input_data(model=fake_model)
        with (
            patch.object(
                block,
                "_create_tool_node_signatures",
                new_callable=AsyncMock,
                return_value=[],
            ),
            pytest.raises(ValueError, match="only supports Claude models"),
        ):
            await _collect_run_outputs(
                block,
                input_data,
                credentials=MagicMock(),
                graph_id="g",
                node_id="n",
                graph_exec_id="ge",
                node_exec_id="ne",
                user_id="u",
                graph_version=1,
                execution_context=MagicMock(),
                execution_processor=MagicMock(),
            )

File diff suppressed because it is too large Load Diff

View File

@@ -9,11 +9,14 @@ shared tool registry as the SDK path.
import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Any
from collections.abc import AsyncGenerator, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import Any, cast
import orjson
from langfuse import propagate_attributes
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
from backend.copilot.model import (
ChatMessage,
@@ -48,7 +51,17 @@ from backend.copilot.token_tracking import persist_and_record_usage
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from backend.util.prompt import compress_context
from backend.util.prompt import (
compress_context,
estimate_token_count,
estimate_token_count_str,
)
from backend.util.tool_call_loop import (
LLMLoopResponse,
LLMToolCall,
ToolCallResult,
tool_call_loop,
)
logger = logging.getLogger(__name__)
@@ -59,6 +72,247 @@ _background_tasks: set[asyncio.Task[Any]] = set()
_MAX_TOOL_ROUNDS = 30
@dataclass
class _BaselineStreamState:
    """Mutable state shared between the tool-call loop callbacks.

    Extracted from ``stream_chat_completion_baseline`` so that the callbacks
    can be module-level functions instead of deeply nested closures.
    """

    # Stream events buffered by the callbacks; drained (yielded) by the caller.
    pending_events: list[StreamBaseResponse] = field(default_factory=list)
    # Assistant text accumulated across all tool-call rounds, for persistence.
    assistant_text: str = ""
    # Id of the current text block; regenerated after each StreamTextEnd.
    text_block_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # True while a StreamTextStart has been emitted without a matching TextEnd.
    text_started: bool = False
    # Token usage summed across rounds from streaming usage chunks.
    turn_prompt_tokens: int = 0
    turn_completion_tokens: int = 0
async def _baseline_llm_caller(
    messages: list[dict[str, Any]],
    tools: Sequence[Any],
    *,
    state: _BaselineStreamState,
) -> LLMLoopResponse:
    """Stream an OpenAI-compatible response and collect results.

    Extracted from ``stream_chat_completion_baseline`` for readability.

    Buffers Stream* events into ``state.pending_events`` (the caller drains
    and yields them) and returns an ``LLMLoopResponse`` with the round's text
    and any assembled tool calls. Token usage is accumulated on ``state``
    rather than returned, so the response's token fields are 0.
    """
    # Open a step for this round; the matching StreamFinishStep is appended
    # in the finally block below, even on error.
    state.pending_events.append(StreamStartStep())
    round_text = ""
    try:
        client = _get_openai_client()
        typed_messages = cast(list[ChatCompletionMessageParam], messages)
        # Only pass `tools` when non-empty; some endpoints reject empty lists.
        if tools:
            typed_tools = cast(list[ChatCompletionToolParam], tools)
            response = await client.chat.completions.create(
                model=config.model,
                messages=typed_messages,
                tools=typed_tools,
                stream=True,
                stream_options={"include_usage": True},
            )
        else:
            response = await client.chat.completions.create(
                model=config.model,
                messages=typed_messages,
                stream=True,
                stream_options={"include_usage": True},
            )
        # Tool-call fragments arrive incrementally, keyed by index.
        tool_calls_by_index: dict[int, dict[str, str]] = {}
        async for chunk in response:
            # NOTE(review): some providers never populate chunk.usage even with
            # include_usage; the caller falls back to token estimation when
            # these accumulators stay 0.
            if chunk.usage:
                state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0
                state.turn_completion_tokens += chunk.usage.completion_tokens or 0
            delta = chunk.choices[0].delta if chunk.choices else None
            if not delta:
                continue
            if delta.content:
                # Lazily open the text block on the first content delta.
                if not state.text_started:
                    state.pending_events.append(StreamTextStart(id=state.text_block_id))
                    state.text_started = True
                round_text += delta.content
                state.pending_events.append(
                    StreamTextDelta(id=state.text_block_id, delta=delta.content)
                )
            if delta.tool_calls:
                for tc in delta.tool_calls:
                    idx = tc.index
                    if idx not in tool_calls_by_index:
                        tool_calls_by_index[idx] = {
                            "id": "",
                            "name": "",
                            "arguments": "",
                        }
                    entry = tool_calls_by_index[idx]
                    # id/name arrive once; arguments stream as concatenated fragments.
                    if tc.id:
                        entry["id"] = tc.id
                    if tc.function and tc.function.name:
                        entry["name"] = tc.function.name
                    if tc.function and tc.function.arguments:
                        entry["arguments"] += tc.function.arguments
        # Close text block
        if state.text_started:
            state.pending_events.append(StreamTextEnd(id=state.text_block_id))
            state.text_started = False
            state.text_block_id = str(uuid.uuid4())
    finally:
        # Always persist partial text so the session history stays consistent,
        # even when the stream is interrupted by an exception.
        state.assistant_text += round_text
        # Always emit StreamFinishStep to match the StreamStartStep,
        # even if an exception occurred during streaming.
        state.pending_events.append(StreamFinishStep())
    # Convert to shared format
    llm_tool_calls = [
        LLMToolCall(
            id=tc["id"],
            name=tc["name"],
            arguments=tc["arguments"] or "{}",
        )
        for tc in tool_calls_by_index.values()
    ]
    return LLMLoopResponse(
        response_text=round_text or None,
        tool_calls=llm_tool_calls,
        raw_response=None,  # Not needed for baseline conversation updater
        prompt_tokens=0,  # Tracked via state accumulators
        completion_tokens=0,
    )
async def _baseline_tool_executor(
    tool_call: LLMToolCall,
    tools: Sequence[Any],
    *,
    state: _BaselineStreamState,
    user_id: str | None,
    session: ChatSession,
) -> ToolCallResult:
    """Execute a tool via the copilot tool registry.

    Extracted from ``stream_chat_completion_baseline`` for readability.

    Buffers the tool-call Stream* events into ``state.pending_events`` and
    always returns a ``ToolCallResult`` (``is_error=True`` on failure) so the
    tool-call loop can continue instead of aborting the stream. ``tools`` is
    accepted for signature compatibility with the loop but unused here.
    """
    tool_call_id = tool_call.id
    tool_name = tool_call.name
    # Fall back to "{}" so orjson.loads succeeds for tools called with no args.
    raw_args = tool_call.arguments or "{}"
    try:
        tool_args = orjson.loads(raw_args)
    except orjson.JSONDecodeError as parse_err:
        # Malformed arguments from the model: report as a failed tool output
        # rather than raising, so the conversation can continue.
        parse_error = f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
        logger.warning("[Baseline] %s", parse_error)
        state.pending_events.append(
            StreamToolOutputAvailable(
                toolCallId=tool_call_id,
                toolName=tool_name,
                output=parse_error,
                success=False,
            )
        )
        return ToolCallResult(
            tool_call_id=tool_call_id,
            tool_name=tool_name,
            content=parse_error,
            is_error=True,
        )
    # Announce the tool invocation (start + parsed input) before executing.
    state.pending_events.append(
        StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
    )
    state.pending_events.append(
        StreamToolInputAvailable(
            toolCallId=tool_call_id,
            toolName=tool_name,
            input=tool_args,
        )
    )
    try:
        result: StreamToolOutputAvailable = await execute_tool(
            tool_name=tool_name,
            parameters=tool_args,
            user_id=user_id,
            session=session,
            tool_call_id=tool_call_id,
        )
        state.pending_events.append(result)
        # Normalise non-string outputs for the conversation transcript.
        tool_output = (
            result.output if isinstance(result.output, str) else str(result.output)
        )
        return ToolCallResult(
            tool_call_id=tool_call_id,
            tool_name=tool_name,
            content=tool_output,
        )
    except Exception as e:
        # Execution failure: surface the error as the tool's output so the
        # model can react, and mark the result as an error.
        error_output = f"Tool execution error: {e}"
        logger.error(
            "[Baseline] Tool %s failed: %s",
            tool_name,
            error_output,
            exc_info=True,
        )
        state.pending_events.append(
            StreamToolOutputAvailable(
                toolCallId=tool_call_id,
                toolName=tool_name,
                output=error_output,
                success=False,
            )
        )
        return ToolCallResult(
            tool_call_id=tool_call_id,
            tool_name=tool_name,
            content=error_output,
            is_error=True,
        )
def _baseline_conversation_updater(
    messages: list[dict[str, Any]],
    response: LLMLoopResponse,
    tool_results: list[ToolCallResult] | None = None,
) -> None:
    """Append the assistant turn (and any tool results) to the OpenAI history.

    Extracted from ``stream_chat_completion_baseline`` for readability.
    Mutates ``messages`` in place; returns None.
    """
    if not tool_results:
        # Plain text turn: record it only when the model produced text.
        if response.response_text:
            messages.append(
                {"role": "assistant", "content": response.response_text}
            )
        return

    # Assistant message carrying the tool calls (plus optional text).
    assistant_entry: dict[str, Any] = {"role": "assistant"}
    if response.response_text:
        assistant_entry["content"] = response.response_text
    assistant_entry["tool_calls"] = [
        {
            "id": call.id,
            "type": "function",
            "function": {"name": call.name, "arguments": call.arguments},
        }
        for call in response.tool_calls
    ]
    messages.append(assistant_entry)

    # One "tool" message per executed call so the model can read the outputs.
    messages.extend(
        {
            "role": "tool",
            "tool_call_id": result.tool_call_id,
            "content": result.content,
        }
        for result in tool_results
    )
async def _update_title_async(
session_id: str, message: str, user_id: str | None
) -> None:
@@ -219,191 +473,32 @@ async def stream_chat_completion_baseline(
except Exception:
logger.warning("[Baseline] Langfuse trace context setup failed")
assistant_text = ""
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
# Token usage accumulators — populated from streaming chunks
turn_prompt_tokens = 0
turn_completion_tokens = 0
_stream_error = False # Track whether an error occurred during streaming
state = _BaselineStreamState()
# Bind extracted module-level callbacks to this request's state/session
# using functools.partial so they satisfy the Protocol signatures.
_bound_llm_caller = partial(_baseline_llm_caller, state=state)
_bound_tool_executor = partial(
_baseline_tool_executor, state=state, user_id=user_id, session=session
)
try:
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
yield StreamStartStep()
step_open = True
loop_result = None
async for loop_result in tool_call_loop(
messages=openai_messages,
tools=tools,
llm_call=_bound_llm_caller,
execute_tool=_bound_tool_executor,
update_conversation=_baseline_conversation_updater,
max_iterations=_MAX_TOOL_ROUNDS,
):
# Drain buffered events after each iteration (real-time streaming)
for evt in state.pending_events:
yield evt
state.pending_events.clear()
# Stream a response from the model
create_kwargs: dict[str, Any] = dict(
model=config.model,
messages=openai_messages,
stream=True,
stream_options={"include_usage": True},
)
if tools:
create_kwargs["tools"] = tools
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
# Accumulate streamed response (text + tool calls)
round_text = ""
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
# Capture token usage from the streaming chunk.
# OpenRouter normalises all providers into OpenAI format
# where prompt_tokens already includes cached tokens
# (unlike Anthropic's native API). Use += to sum all
# tool-call rounds since each API call is independent.
# NOTE: stream_options={"include_usage": True} is not
# universally supported — some providers (Mistral, Llama
# via OpenRouter) always return chunk.usage=None. When
# that happens, tokens stay 0 and the tiktoken fallback
# below activates. Fail-open: one round is estimated.
if chunk.usage:
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
turn_completion_tokens += chunk.usage.completion_tokens or 0
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
# Text content
if delta.content:
if not text_started:
yield StreamTextStart(id=text_block_id)
text_started = True
round_text += delta.content
yield StreamTextDelta(id=text_block_id, delta=delta.content)
# Tool call fragments (streamed incrementally)
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tool_calls_by_index:
tool_calls_by_index[idx] = {
"id": "",
"name": "",
"arguments": "",
}
entry = tool_calls_by_index[idx]
if tc.id:
entry["id"] = tc.id
if tc.function and tc.function.name:
entry["name"] = tc.function.name
if tc.function and tc.function.arguments:
entry["arguments"] += tc.function.arguments
# Close text block if we had one this round
if text_started:
yield StreamTextEnd(id=text_block_id)
text_started = False
text_block_id = str(uuid.uuid4())
# Accumulate text for session persistence
assistant_text += round_text
# No tool calls -> model is done
if not tool_calls_by_index:
yield StreamFinishStep()
step_open = False
break
# Close step before tool execution
yield StreamFinishStep()
step_open = False
# Append the assistant message with tool_calls to context.
assistant_msg: dict[str, Any] = {"role": "assistant"}
if round_text:
assistant_msg["content"] = round_text
assistant_msg["tool_calls"] = [
{
"id": tc["id"],
"type": "function",
"function": {
"name": tc["name"],
"arguments": tc["arguments"] or "{}",
},
}
for tc in tool_calls_by_index.values()
]
openai_messages.append(assistant_msg)
# Execute each tool call and stream events
for tc in tool_calls_by_index.values():
tool_call_id = tc["id"]
tool_name = tc["name"]
raw_args = tc["arguments"] or "{}"
try:
tool_args = orjson.loads(raw_args)
except orjson.JSONDecodeError as parse_err:
parse_error = (
f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
)
logger.warning("[Baseline] %s", parse_error)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=parse_error,
success=False,
)
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": parse_error,
}
)
continue
yield StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
yield StreamToolInputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
input=tool_args,
)
# Execute via shared tool registry
try:
result: StreamToolOutputAvailable = await execute_tool(
tool_name=tool_name,
parameters=tool_args,
user_id=user_id,
session=session,
tool_call_id=tool_call_id,
)
yield result
tool_output = (
result.output
if isinstance(result.output, str)
else str(result.output)
)
except Exception as e:
error_output = f"Tool execution error: {e}"
logger.error(
"[Baseline] Tool %s failed: %s",
tool_name,
error_output,
exc_info=True,
)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=error_output,
success=False,
)
tool_output = error_output
# Append tool result to context for next round
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": tool_output,
}
)
else:
# for-loop exhausted without break -> tool-round limit hit
if loop_result and not loop_result.finished_naturally:
limit_msg = (
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
"without a final response."
@@ -418,11 +513,28 @@ async def stream_chat_completion_baseline(
_stream_error = True
error_msg = str(e) or type(e).__name__
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
# Close any open text/step before emitting error
if text_started:
yield StreamTextEnd(id=text_block_id)
if step_open:
yield StreamFinishStep()
# Close any open text block. The llm_caller's finally block
# already appended StreamFinishStep to pending_events, so we must
# insert StreamTextEnd *before* StreamFinishStep to preserve the
# protocol ordering:
# StreamStartStep -> StreamTextStart -> ...deltas... ->
# StreamTextEnd -> StreamFinishStep
# Appending (or yielding directly) would place it after
# StreamFinishStep, violating the protocol.
if state.text_started:
# Find the last StreamFinishStep and insert before it.
insert_pos = len(state.pending_events)
for i in range(len(state.pending_events) - 1, -1, -1):
if isinstance(state.pending_events[i], StreamFinishStep):
insert_pos = i
break
state.pending_events.insert(
insert_pos, StreamTextEnd(id=state.text_block_id)
)
# Drain pending events in correct order
for evt in state.pending_events:
yield evt
state.pending_events.clear()
yield StreamError(errorText=error_msg, code="baseline_error")
# Still persist whatever we got
finally:
@@ -442,26 +554,21 @@ async def stream_chat_completion_baseline(
# Skip fallback when an error occurred and no output was produced —
# charging rate-limit tokens for completely failed requests is unfair.
if (
turn_prompt_tokens == 0
and turn_completion_tokens == 0
and not (_stream_error and not assistant_text)
state.turn_prompt_tokens == 0
and state.turn_completion_tokens == 0
and not (_stream_error and not state.assistant_text)
):
from backend.util.prompt import (
estimate_token_count,
estimate_token_count_str,
)
turn_prompt_tokens = max(
state.turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.model), 1
)
turn_completion_tokens = estimate_token_count_str(
assistant_text, model=config.model
state.turn_completion_tokens = estimate_token_count_str(
state.assistant_text, model=config.model
)
logger.info(
"[Baseline] No streaming usage reported; estimated tokens: "
"prompt=%d, completion=%d",
turn_prompt_tokens,
turn_completion_tokens,
state.turn_prompt_tokens,
state.turn_completion_tokens,
)
# Persist token usage to session and record for rate limiting.
@@ -471,15 +578,15 @@ async def stream_chat_completion_baseline(
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
prompt_tokens=state.turn_prompt_tokens,
completion_tokens=state.turn_completion_tokens,
log_prefix="[Baseline]",
)
# Persist assistant response
if assistant_text:
if state.assistant_text:
session.messages.append(
ChatMessage(role="assistant", content=assistant_text)
ChatMessage(role="assistant", content=state.assistant_text)
)
try:
await upsert_chat_session(session)
@@ -491,11 +598,11 @@ async def stream_chat_completion_baseline(
# aclose() — doing so raises RuntimeError on client disconnect.
# On GeneratorExit the client is already gone, so unreachable yields
# are harmless; on normal completion they reach the SSE stream.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
if state.turn_prompt_tokens > 0 or state.turn_completion_tokens > 0:
yield StreamUsage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=turn_prompt_tokens + turn_completion_tokens,
prompt_tokens=state.turn_prompt_tokens,
completion_tokens=state.turn_completion_tokens,
total_tokens=state.turn_prompt_tokens + state.turn_completion_tokens,
)
yield StreamFinish()

View File

@@ -178,7 +178,7 @@ class ChatConfig(BaseSettings):
Single source of truth for "will the SDK route through OpenRouter?".
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
present — mirrors the fallback logic in ``_build_sdk_env``.
present — mirrors the fallback logic in ``build_sdk_env``.
"""
if not self.use_openrouter:
return False

View File

@@ -18,7 +18,7 @@ from prisma.types import (
from backend.data import db
from backend.util.json import SafeJson, sanitize_string
from .model import ChatMessage, ChatSession, ChatSessionInfo
from .model import ChatMessage, ChatSession, ChatSessionInfo, invalidate_session_cache
logger = logging.getLogger(__name__)
@@ -217,6 +217,9 @@ async def add_chat_messages_batch(
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
if msg.get("duration_ms") is not None:
data["durationMs"] = msg["duration_ms"]
messages_data.append(data)
# Run create_many and session update in parallel within transaction
@@ -359,3 +362,22 @@ async def update_tool_message_content(
f"tool_call_id {tool_call_id}: {e}"
)
return False
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
    """Record the wall-clock turn duration on a session's newest assistant message.

    Finds the assistant message with the highest sequence number, writes
    ``duration_ms`` to its ``durationMs`` column, and invalidates the Redis
    session cache so the next GET re-reads the updated row from the DB.
    No-op when the session has no assistant messages.
    """
    target = await PrismaChatMessage.prisma().find_first(
        where={"sessionId": session_id, "role": "assistant"},
        order={"sequence": "desc"},
    )
    if not target:
        return
    await PrismaChatMessage.prisma().update(
        where={"id": target.id},
        data={"durationMs": duration_ms},
    )
    # Drop the cached copy so subsequent reads include the new durationMs.
    await invalidate_session_cache(session_id)

View File

@@ -54,6 +54,7 @@ class ChatMessage(BaseModel):
refusal: str | None = None
tool_calls: list[dict] | None = None
function_call: dict | None = None
duration_ms: int | None = None
@staticmethod
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
@@ -66,6 +67,7 @@ class ChatMessage(BaseModel):
refusal=prisma_message.refusal,
tool_calls=_parse_json_field(prisma_message.toolCalls),
function_call=_parse_json_field(prisma_message.functionCall),
duration_ms=prisma_message.durationMs,
)

View File

@@ -0,0 +1,68 @@
"""SDK environment variable builder — importable without circular deps.
Extracted from ``service.py`` so that ``backend.blocks.orchestrator``
can reuse the same subscription / OpenRouter / direct-Anthropic logic
without pulling in the full copilot service module (which would create a
circular import through ``executor`` → ``credit`` → ``block_cost_config``).
"""
from __future__ import annotations
from backend.copilot.config import ChatConfig
from backend.copilot.sdk.subscription import validate_subscription
# ChatConfig is stateless (reads env vars) — a separate instance is fine.
# A singleton would require importing service.py which causes the circular dep
# this module was created to avoid.
config = ChatConfig()
def build_sdk_env(
    session_id: str | None = None,
    user_id: str | None = None,
) -> dict[str, str]:
    """Build env vars for the SDK CLI subprocess.

    Three modes (checked in order):

    1. **Subscription** — clears all keys; CLI uses ``claude login`` auth.
    2. **Direct Anthropic** — returns ``{}``; subprocess inherits
       ``ANTHROPIC_API_KEY`` from the parent environment.
    3. **OpenRouter** (default) — overrides base URL and auth token to
       route through the proxy, with Langfuse trace headers.
    """
    # Mode 1: Claude Code subscription — blank every credential var so the
    # CLI falls back to its own `claude login` token.
    if config.use_claude_code_subscription:
        validate_subscription()
        blanked = ("ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_BASE_URL")
        return {name: "" for name in blanked}

    # Mode 2: direct Anthropic — inherit the parent environment untouched.
    # `openrouter_active` checks the flag *and* credential presence.
    if not config.openrouter_active:
        return {}

    # Mode 3: OpenRouter proxy. The SDK wants the base URL without a
    # version path, so drop a trailing "/" and any "/v1" suffix.
    base_url = (config.base_url or "").rstrip("/")
    if base_url.endswith("/v1"):
        base_url = base_url[: -len("/v1")]

    env: dict[str, str] = {
        "ANTHROPIC_BASE_URL": base_url,
        "ANTHROPIC_AUTH_TOKEN": config.api_key or "",
        # An explicitly empty API key forces the CLI to use AUTH_TOKEN.
        "ANTHROPIC_API_KEY": "",
    }

    def _sanitize(value: str) -> str:
        # Header values must stay single-line; cap length defensively.
        return value.replace("\r", "").replace("\n", "").strip()[:128]

    # Trace headers that OpenRouter forwards to Langfuse.
    header_lines: list[str] = []
    if session_id:
        header_lines.append(f"x-session-id: {_sanitize(session_id)}")
    if user_id:
        header_lines.append(f"x-user-id: {_sanitize(user_id)}")
    if header_lines:
        env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(header_lines)
    return env

View File

@@ -0,0 +1,242 @@
"""Tests for build_sdk_env() — the SDK subprocess environment builder."""
from unittest.mock import patch
import pytest
from backend.copilot.config import ChatConfig
# ---------------------------------------------------------------------------
# Helpers — build a ChatConfig with explicit field values so tests don't
# depend on real environment variables.
# ---------------------------------------------------------------------------
def _make_config(**overrides) -> ChatConfig:
    """Create a ChatConfig with safe defaults, applying *overrides*.

    Defaults disable both subscription and OpenRouter modes and clear the
    credentials, so each test opts in to exactly the fields it needs.
    """
    settings = {
        "use_claude_code_subscription": False,
        "use_openrouter": False,
        "api_key": None,
        "base_url": None,
        # Caller-supplied values win over the defaults above.
        **overrides,
    }
    return ChatConfig(**settings)
# ---------------------------------------------------------------------------
# Mode 1 — Subscription auth
# ---------------------------------------------------------------------------
class TestBuildSdkEnvSubscription:
    """When ``use_claude_code_subscription`` is True, keys are blanked."""

    @patch("backend.copilot.sdk.env.validate_subscription")
    def test_returns_blanked_keys(self, mock_validate):
        """Subscription mode clears API_KEY, AUTH_TOKEN, and BASE_URL."""
        cfg = _make_config(use_claude_code_subscription=True)
        # Patch the module-level `config` attribute that build_sdk_env
        # reads at call time.
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env()
        # All three credential vars are set to empty strings (not removed),
        # so they override anything inherited from the parent environment.
        assert result == {
            "ANTHROPIC_API_KEY": "",
            "ANTHROPIC_AUTH_TOKEN": "",
            "ANTHROPIC_BASE_URL": "",
        }
        # Subscription mode must have validated the `claude login` auth.
        mock_validate.assert_called_once()

    @patch(
        "backend.copilot.sdk.env.validate_subscription",
        side_effect=RuntimeError("CLI not found"),
    )
    def test_propagates_validation_error(self, mock_validate):
        """If validate_subscription fails, the error bubbles up."""
        cfg = _make_config(use_claude_code_subscription=True)
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            # build_sdk_env does not swallow validation failures.
            with pytest.raises(RuntimeError, match="CLI not found"):
                build_sdk_env()
# ---------------------------------------------------------------------------
# Mode 2 — Direct Anthropic (no OpenRouter)
# ---------------------------------------------------------------------------
class TestBuildSdkEnvDirectAnthropic:
    """When OpenRouter is inactive, return empty dict (inherit parent env)."""

    def test_returns_empty_dict_when_openrouter_inactive(self):
        # use_openrouter=False -> openrouter_active is False -> mode 2.
        cfg = _make_config(use_openrouter=False)
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env()
        # Empty dict means the subprocess inherits ANTHROPIC_API_KEY
        # (and everything else) from the parent environment untouched.
        assert result == {}

    def test_returns_empty_dict_when_openrouter_flag_true_but_no_key(self):
        """OpenRouter flag is True but no api_key => openrouter_active is False."""
        cfg = _make_config(use_openrouter=True, base_url="https://openrouter.ai/api/v1")
        # Force api_key to None after construction (field_validator may pick up env vars)
        object.__setattr__(cfg, "api_key", None)
        assert not cfg.openrouter_active
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env()
        assert result == {}
# ---------------------------------------------------------------------------
# Mode 3 — OpenRouter proxy
# ---------------------------------------------------------------------------
class TestBuildSdkEnvOpenRouter:
    """When OpenRouter is active, return proxy env vars."""

    def _openrouter_config(self, **overrides):
        # Valid OpenRouter credentials; individual tests tweak fields.
        defaults = {
            "use_openrouter": True,
            "api_key": "sk-or-test-key",
            "base_url": "https://openrouter.ai/api/v1",
        }
        defaults.update(overrides)
        return _make_config(**defaults)

    def test_basic_openrouter_env(self):
        """Proxy mode sets base URL + auth token and blanks the API key."""
        cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env()
        assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
        assert result["ANTHROPIC_AUTH_TOKEN"] == "sk-or-test-key"
        assert result["ANTHROPIC_API_KEY"] == ""
        # No session/user ids passed -> no custom trace headers.
        assert "ANTHROPIC_CUSTOM_HEADERS" not in result

    def test_strips_trailing_v1(self):
        """The /v1 suffix is stripped from the base URL."""
        cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1")
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env()
        assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"

    def test_strips_trailing_v1_and_slash(self):
        """Trailing slash before /v1 strip is handled."""
        cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1/")
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env()
        # rstrip("/") first, then remove /v1
        assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"

    def test_no_v1_suffix_left_alone(self):
        """A base URL without /v1 is used as-is."""
        cfg = self._openrouter_config(base_url="https://custom-proxy.example.com")
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env()
        assert result["ANTHROPIC_BASE_URL"] == "https://custom-proxy.example.com"

    def test_session_id_header(self):
        """A session_id produces an x-session-id trace header."""
        cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env(session_id="sess-123")
        assert "ANTHROPIC_CUSTOM_HEADERS" in result
        assert "x-session-id: sess-123" in result["ANTHROPIC_CUSTOM_HEADERS"]

    def test_user_id_header(self):
        """A user_id produces an x-user-id trace header."""
        cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env(user_id="user-456")
        assert "x-user-id: user-456" in result["ANTHROPIC_CUSTOM_HEADERS"]

    def test_both_headers(self):
        """Both ids are emitted as separate newline-delimited header lines."""
        cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env(session_id="s1", user_id="u2")
        headers = result["ANTHROPIC_CUSTOM_HEADERS"]
        assert "x-session-id: s1" in headers
        assert "x-user-id: u2" in headers
        # They should be newline-separated
        assert "\n" in headers

    def test_header_sanitisation_strips_newlines(self):
        """Newlines/carriage-returns in header values are stripped."""
        cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env(session_id="bad\r\nvalue")
        header_val = result["ANTHROPIC_CUSTOM_HEADERS"]
        # The _safe helper removes \r and \n
        assert "\r" not in header_val.split(": ", 1)[1]
        assert "badvalue" in header_val

    def test_header_value_truncated_to_128_chars(self):
        """Header values are truncated to 128 characters."""
        cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            long_id = "x" * 200
            result = build_sdk_env(session_id=long_id)
        # The value after "x-session-id: " should be at most 128 chars
        header_line = result["ANTHROPIC_CUSTOM_HEADERS"]
        value = header_line.split(": ", 1)[1]
        assert len(value) == 128
# ---------------------------------------------------------------------------
# Mode priority
# ---------------------------------------------------------------------------
class TestBuildSdkEnvModePriority:
    """Subscription mode takes precedence over OpenRouter."""

    @patch("backend.copilot.sdk.env.validate_subscription")
    def test_subscription_overrides_openrouter(self, mock_validate):
        # Enable BOTH modes with valid OpenRouter credentials; the
        # subscription branch is checked first and must win.
        cfg = _make_config(
            use_claude_code_subscription=True,
            use_openrouter=True,
            api_key="sk-or-key",
            base_url="https://openrouter.ai/api/v1",
        )
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env()
        # Should get subscription result, not OpenRouter
        assert result == {
            "ANTHROPIC_API_KEY": "",
            "ANTHROPIC_AUTH_TOKEN": "",
            "ANTHROPIC_BASE_URL": "",
        }

View File

@@ -1010,7 +1010,7 @@ def _make_sdk_patches(
(f"{_SVC}.create_security_hooks", dict(return_value=MagicMock())),
(f"{_SVC}.get_copilot_tool_names", dict(return_value=[])),
(f"{_SVC}.get_sdk_disallowed_tools", dict(return_value=[])),
(f"{_SVC}._build_sdk_env", dict(return_value=None)),
(f"{_SVC}.build_sdk_env", dict(return_value=None)),
(f"{_SVC}._resolve_sdk_model", dict(return_value=None)),
(f"{_SVC}.set_execution_context", {}),
(

View File

@@ -78,9 +78,9 @@ from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from .compaction import CompactionTracker, filter_compaction_messages
from .env import build_sdk_env # noqa: F401 — re-export for backward compat
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .subscription import validate_subscription as _validate_claude_code_subscription
from .tool_adapter import (
cancel_pending_tool_tasks,
create_copilot_mcp_server,
@@ -568,60 +568,6 @@ def _resolve_sdk_model() -> str | None:
return model
def _build_sdk_env(
    session_id: str | None = None,
    user_id: str | None = None,
) -> dict[str, str]:
    """Build env vars for the SDK CLI subprocess.

    Three modes (checked in order):

    1. **Subscription** — clears all keys; CLI uses `claude login` auth.
    2. **Direct Anthropic** — returns `{}`; subprocess inherits
       `ANTHROPIC_API_KEY` from the parent environment.
    3. **OpenRouter** (default) — overrides base URL and auth token to
       route through the proxy, with Langfuse trace headers.

    NOTE(review): this logic also exists as ``build_sdk_env`` in
    ``sdk/env.py`` (re-exported above) — keep the two in sync.
    """
    # --- Mode 1: Claude Code subscription auth ---
    if config.use_claude_code_subscription:
        _validate_claude_code_subscription()
        # Blank (not delete) the vars so they override inherited values.
        return {
            "ANTHROPIC_API_KEY": "",
            "ANTHROPIC_AUTH_TOKEN": "",
            "ANTHROPIC_BASE_URL": "",
        }
    # --- Mode 2: Direct Anthropic (no proxy hop) ---
    # `openrouter_active` checks the flag *and* credential presence.
    if not config.openrouter_active:
        return {}
    # --- Mode 3: OpenRouter proxy ---
    # Strip /v1 suffix — SDK expects the base URL without a version path.
    base = (config.base_url or "").rstrip("/")
    if base.endswith("/v1"):
        base = base[:-3]
    env: dict[str, str] = {
        "ANTHROPIC_BASE_URL": base,
        "ANTHROPIC_AUTH_TOKEN": config.api_key or "",
        "ANTHROPIC_API_KEY": "",  # force CLI to use AUTH_TOKEN
    }

    # Inject broadcast headers so OpenRouter forwards traces to Langfuse.
    def _safe(v: str) -> str:
        """Sanitise a header value: strip newlines/whitespace and cap length."""
        return v.replace("\r", "").replace("\n", "").strip()[:128]

    parts = []
    if session_id:
        parts.append(f"x-session-id: {_safe(session_id)}")
    if user_id:
        parts.append(f"x-user-id: {_safe(user_id)}")
    if parts:
        env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
    return env
def _make_sdk_cwd(session_id: str) -> str:
"""Create a safe, session-specific working directory path.
@@ -1868,7 +1814,7 @@ async def stream_chat_completion_sdk(
)
# Fail fast when no API credentials are available at all.
sdk_env = _build_sdk_env(session_id=session_id, user_id=user_id)
sdk_env = build_sdk_env(session_id=session_id, user_id=user_id)
if not config.api_key and not config.use_claude_code_subscription:
raise RuntimeError(
"No API key configured. Set OPEN_ROUTER_API_KEY, "

View File

@@ -26,6 +26,7 @@ import orjson
from redis.exceptions import RedisError
from backend.api.model import CopilotCompletionPayload
from backend.data.db_accessors import chat_db
from backend.data.notification_bus import (
AsyncRedisNotificationEventBus,
NotificationEvent,
@@ -111,6 +112,14 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
``session_id`` is used as a fallback for ``turn_id`` when the meta hash
pre-dates the turn_id field (backward compat for in-flight sessions).
"""
created_at = datetime.now(timezone.utc)
created_at_raw = meta.get("created_at")
if created_at_raw:
try:
created_at = datetime.fromisoformat(str(created_at_raw))
except (ValueError, TypeError):
pass
return ActiveSession(
session_id=meta.get("session_id", "") or session_id,
user_id=meta.get("user_id", "") or None,
@@ -119,6 +128,7 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
turn_id=meta.get("turn_id", "") or session_id,
blocking=meta.get("blocking") == "1",
status=meta.get("status", "running"), # type: ignore[arg-type]
created_at=created_at,
)
@@ -802,6 +812,33 @@ async def mark_session_completed(
f"Failed to publish error event for session {session_id}: {e}"
)
# Compute wall-clock duration from session created_at.
# Only persist when (a) the session completed successfully and
# (b) created_at was actually present in Redis meta (not a fallback).
duration_ms: int | None = None
if meta and not error_message:
created_at_raw = meta.get("created_at")
if created_at_raw:
try:
created_at = datetime.fromisoformat(str(created_at_raw))
if created_at.tzinfo is None:
created_at = created_at.replace(tzinfo=timezone.utc)
elapsed = datetime.now(timezone.utc) - created_at
duration_ms = max(0, int(elapsed.total_seconds() * 1000))
except (ValueError, TypeError):
logger.warning(
"Failed to compute session duration for %s (created_at=%r)",
session_id,
created_at_raw,
)
# Persist duration on the last assistant message
if duration_ms is not None:
try:
await chat_db().set_turn_duration(session_id, duration_ms)
except Exception as e:
logger.warning(f"Failed to save turn duration for {session_id}: {e}")
# Publish StreamFinish AFTER status is set to "completed"/"failed".
# This is the SINGLE place that publishes StreamFinish — services and
# the processor must NOT publish it themselves.

View File

@@ -344,6 +344,7 @@ class DatabaseManager(AppService):
get_next_sequence = _(chat_db.get_next_sequence)
update_tool_message_content = _(chat_db.update_tool_message_content)
update_chat_session_title = _(chat_db.update_chat_session_title)
set_turn_duration = _(chat_db.set_turn_duration)
class DatabaseManagerClient(AppServiceClient):
@@ -540,3 +541,4 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_next_sequence = d.get_next_sequence
update_tool_message_content = d.update_tool_message_content
update_chat_session_title = d.update_chat_session_title
set_turn_duration = d.set_turn_duration

View File

@@ -0,0 +1,20 @@
"""Shared security constants for field-level filtering.
Other modules (e.g. orchestrator, future blocks) import from here so the
sensitive-field list stays in one place.
"""
# Field names to exclude from hardcoded-defaults descriptions (case-insensitive).
# Field names to exclude from hardcoded-defaults descriptions.
# Consumers compare case-insensitively; entries here are lowercase.
SENSITIVE_FIELD_NAMES: frozenset[str] = frozenset(
    (
        "access_token",
        "api_key",
        "auth",
        "authorization",
        "credentials",
        "password",
        "refresh_token",
        "secret",
        "token",
    )
)

View File

@@ -0,0 +1,281 @@
"""Shared tool-calling conversation loop.
Provides a generic, provider-agnostic conversation loop that both
the OrchestratorBlock and copilot baseline can use. The loop:
1. Calls the LLM with tool definitions
2. Extracts tool calls from the response
3. Executes tools via a caller-supplied callback
4. Appends results to the conversation
5. Repeats until no more tool calls or max iterations reached
Callers provide callbacks for LLM calling, tool execution, and
conversation updating.
"""
from __future__ import annotations
import asyncio
import logging
from collections.abc import AsyncGenerator, Sequence
from dataclasses import dataclass, field
from typing import Any, Protocol, TypedDict
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Typed dict definitions for tool definitions and conversation messages.
# These document the expected shapes and allow callers to pass TypedDict
# subclasses (e.g. ``ChatCompletionToolParam``) without ``type: ignore``.
# ---------------------------------------------------------------------------
class FunctionParameters(TypedDict, total=False):
    """JSON Schema object describing a tool function's parameters."""

    # Standard JSON Schema keys; all optional (total=False).
    type: str
    properties: dict[str, Any]
    required: list[str]
    additionalProperties: bool
class FunctionDefinition(TypedDict, total=False):
    """Function definition within a tool definition."""

    name: str
    description: str
    # JSON Schema for the function's arguments.
    parameters: FunctionParameters
class ToolDefinition(TypedDict):
    """OpenAI-compatible tool definition (function-calling format).

    Compatible with ``openai.types.chat.ChatCompletionToolParam`` and the
    dict-based tool definitions built by ``OrchestratorBlock``.
    """

    # e.g. "function" in the OpenAI function-calling format.
    type: str
    function: FunctionDefinition
class ConversationMessage(TypedDict, total=False):
    """A single message in the conversation (OpenAI chat format).

    Primarily for documentation; at runtime plain dicts are used because
    messages from different providers carry varying keys.
    """

    role: str
    content: str | list[Any] | None
    # Present on assistant messages that request tool invocations.
    tool_calls: list[dict[str, Any]]
    # Present on "tool" role messages carrying a tool's result.
    tool_call_id: str
    name: str
@dataclass
class ToolCallResult:
    """Outcome of executing one tool call."""

    tool_call_id: str
    tool_name: str
    content: str
    is_error: bool = False


@dataclass
class LLMToolCall:
    """A tool invocation request extracted from an LLM response."""

    id: str
    name: str
    arguments: str  # JSON string


@dataclass
class LLMLoopResponse:
    """One LLM turn inside the loop.

    ``raw_response`` is typed ``Any`` on purpose: the loop never inspects
    it — it is an opaque pass-through that the caller's
    ``ConversationUpdater`` uses to rebuild provider-specific message
    history (OpenAI ChatCompletion, Anthropic Message, Ollama str, etc.).
    """

    response_text: str | None
    tool_calls: list[LLMToolCall]
    raw_response: Any
    prompt_tokens: int = 0
    completion_tokens: int = 0
    reasoning: str | None = None


class LLMCaller(Protocol):
    """Callable that performs one LLM request."""

    async def __call__(
        self,
        messages: list[dict[str, Any]],
        tools: Sequence[Any],
    ) -> LLMLoopResponse: ...


class ToolExecutor(Protocol):
    """Callable that runs a single tool call and reports its result."""

    async def __call__(
        self,
        tool_call: LLMToolCall,
        tools: Sequence[Any],
    ) -> ToolCallResult: ...


class ConversationUpdater(Protocol):
    """Callable that folds an LLM response (and tool results) into history."""

    def __call__(
        self,
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None: ...


@dataclass
class ToolCallLoopResult:
    """Snapshot of the loop's state, yielded after every round."""

    response_text: str
    messages: list[dict[str, Any]]
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    iterations: int = 0
    finished_naturally: bool = True  # False if hit max iterations
    last_tool_calls: list[LLMToolCall] = field(default_factory=list)


async def tool_call_loop(
    *,
    messages: list[dict[str, Any]],
    tools: Sequence[Any],
    llm_call: LLMCaller,
    execute_tool: ToolExecutor,
    update_conversation: ConversationUpdater,
    max_iterations: int = -1,
    last_iteration_message: str | None = None,
    parallel_tool_calls: bool = True,
) -> AsyncGenerator[ToolCallLoopResult, None]:
    """Run a tool-calling conversation loop as an async generator.

    A ``ToolCallLoopResult`` is yielded after every round so callers can
    drain buffered events (e.g. streaming text deltas) between rounds.
    The **final** yielded result carries ``finished_naturally`` and the
    complete response text.

    Args:
        messages: Initial conversation messages (modified in-place).
        tools: Tool function definitions (OpenAI format). Accepts any
            sequence of tool dicts, including ``ChatCompletionToolParam``.
        llm_call: Async function to call the LLM. It may stream
            internally (accumulate deltas, collect events) as long as it
            returns the final ``LLMLoopResponse`` with extracted tool calls.
        execute_tool: Async function to execute a tool call.
        update_conversation: Function to update messages with the LLM
            response and tool results.
        max_iterations: Max rounds. -1 = unbounded, 0 = no loop
            (immediately yields a "max reached" result).
        last_iteration_message: Optional system hint appended on the last
            allowed round to encourage the model to finish.
        parallel_tool_calls: If True (default), run multiple tool calls
            from a single response concurrently via ``asyncio.gather``.
            Set False when calls have ordering dependencies or mutate
            shared state.

    Yields:
        ToolCallLoopResult after each round; check ``finished_naturally``
        to see whether the loop completed or is still running.
    """
    prompt_total = 0
    completion_total = 0
    round_no = 0

    while max_iterations < 0 or round_no < max_iterations:
        round_no += 1

        # On the final allowed round, nudge the model to wrap up. The hint
        # goes on a *copy* so the shared history never accumulates
        # per-round system messages; the copy is only made when needed.
        final_round_hint = (
            last_iteration_message
            and max_iterations > 0
            and round_no == max_iterations
        )
        if final_round_hint:
            call_messages = list(messages)
            call_messages.append(
                {"role": "system", "content": last_iteration_message}
            )
        else:
            call_messages = messages

        reply = await llm_call(call_messages, tools)
        prompt_total += reply.prompt_tokens
        completion_total += reply.completion_tokens

        # A reply without tool calls is the model's final answer.
        if not reply.tool_calls:
            update_conversation(messages, reply)
            yield ToolCallLoopResult(
                response_text=reply.response_text or "",
                messages=messages,
                total_prompt_tokens=prompt_total,
                total_completion_tokens=completion_total,
                iterations=round_no,
                finished_naturally=True,
            )
            return

        # Run the requested tools. asyncio.gather does not cancel sibling
        # tasks when one raises, so executors should trap their own errors
        # and return an error ToolCallResult rather than raising.
        if parallel_tool_calls and len(reply.tool_calls) > 1:
            # Concurrent execution: side-effects (e.g. events appended to a
            # shared list) may interleave nondeterministically; consumers
            # must correlate by each event's tool-call identifier.
            results: list[ToolCallResult] = list(
                await asyncio.gather(
                    *(execute_tool(call, tools) for call in reply.tool_calls)
                )
            )
        else:
            # Sequential execution preserves deterministic ordering.
            results = []
            for call in reply.tool_calls:
                results.append(await execute_tool(call, tools))

        update_conversation(messages, reply, results)

        # Intermediate yield lets the caller flush buffered events.
        yield ToolCallLoopResult(
            response_text="",
            messages=messages,
            total_prompt_tokens=prompt_total,
            total_completion_tokens=completion_total,
            iterations=round_no,
            finished_naturally=False,
            last_tool_calls=list(reply.tool_calls),
        )

    # Iteration budget exhausted without a final text answer.
    yield ToolCallLoopResult(
        response_text=f"Completed after {max_iterations} iterations (limit reached)",
        messages=messages,
        total_prompt_tokens=prompt_total,
        total_completion_tokens=completion_total,
        iterations=round_no,
        finished_naturally=False,
    )

View File

@@ -0,0 +1,554 @@
"""Unit tests for tool_call_loop shared abstraction.
Covers:
- Happy path with tool calls (single and multi-round)
- Final text response (no tool calls)
- Max iterations reached
- No tools scenario
- Exception propagation from tool executor
- Parallel tool execution
"""
from __future__ import annotations
import asyncio
from collections.abc import Sequence
from typing import Any
import pytest
from backend.util.tool_call_loop import (
LLMLoopResponse,
LLMToolCall,
ToolCallLoopResult,
ToolCallResult,
tool_call_loop,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Minimal OpenAI-format tool definition shared as a fixture across tests.
TOOL_DEFS: list[dict[str, Any]] = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]
def _make_response(
    text: str | None = None,
    tool_calls: list[LLMToolCall] | None = None,
    prompt_tokens: int = 10,
    completion_tokens: int = 5,
) -> LLMLoopResponse:
    """Build an ``LLMLoopResponse`` fixture with a dummy raw payload."""
    # None (or empty) tool_calls collapses to a fresh empty list.
    calls = tool_calls or []
    return LLMLoopResponse(
        response_text=text,
        tool_calls=calls,
        raw_response={"mock": True},
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_text_response_no_tool_calls():
    """LLM responds with text only -- loop should yield once and finish."""

    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        return _make_response(text="Hello world")

    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        # A pure-text response must never trigger tool execution.
        raise AssertionError("Should not be called")

    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        messages.append({"role": "assistant", "content": response.response_text})

    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Hi"}]
    results: list[ToolCallLoopResult] = []
    async for r in tool_call_loop(
        messages=msgs,
        tools=TOOL_DEFS,
        llm_call=llm_call,
        execute_tool=execute_tool,
        update_conversation=update_conversation,
    ):
        results.append(r)

    # Single final yield: finished on iteration 1 with the fixture's
    # default token counts (prompt=10, completion=5).
    assert len(results) == 1
    assert results[0].finished_naturally is True
    assert results[0].response_text == "Hello world"
    assert results[0].iterations == 1
    assert results[0].total_prompt_tokens == 10
    assert results[0].total_completion_tokens == 5
@pytest.mark.asyncio
async def test_single_tool_call_then_text():
    """LLM makes one tool call, then responds with text on second round."""
    call_count = 0
    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        # Round 1: request a tool call. Round 2: answer with text,
        # which ends the loop naturally.
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            return _make_response(
                tool_calls=[
                    LLMToolCall(
                        id="tc_1", name="get_weather", arguments='{"city":"NYC"}'
                    )
                ]
            )
        return _make_response(text="It's sunny in NYC")
    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        # Stub tool executor: return a fixed JSON payload for any call.
        return ToolCallResult(
            tool_call_id=tool_call.id,
            tool_name=tool_call.name,
            content='{"temp": 72}',
        )
    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        # Mimic a real updater: append the assistant turn plus one "tool"
        # message per tool result.
        messages.append({"role": "assistant", "content": response.response_text})
        if tool_results:
            for tr in tool_results:
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tr.tool_call_id,
                        "content": tr.content,
                    }
                )
    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Weather?"}]
    results: list[ToolCallLoopResult] = []
    async for r in tool_call_loop(
        messages=msgs,
        tools=TOOL_DEFS,
        llm_call=llm_call,
        execute_tool=execute_tool,
        update_conversation=update_conversation,
    ):
        results.append(r)
    # First yield: tool call iteration (not finished)
    # Second yield: text response (finished)
    assert len(results) == 2
    assert results[0].finished_naturally is False
    assert results[0].iterations == 1
    assert len(results[0].last_tool_calls) == 1
    assert results[1].finished_naturally is True
    assert results[1].response_text == "It's sunny in NYC"
    assert results[1].iterations == 2
    # Token totals accumulate across both LLM calls (10/5 each).
    assert results[1].total_prompt_tokens == 20
    assert results[1].total_completion_tokens == 10
@pytest.mark.asyncio
async def test_max_iterations_reached():
    """Loop should stop after max_iterations even if LLM keeps calling tools."""
    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        # Always request another tool call so the loop can never finish
        # naturally; only the iteration cap can stop it.
        return _make_response(
            tool_calls=[
                LLMToolCall(id="tc_x", name="get_weather", arguments='{"city":"X"}')
            ]
        )
    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        return ToolCallResult(
            tool_call_id=tool_call.id, tool_name=tool_call.name, content="result"
        )
    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        # Conversation contents are irrelevant to this test.
        pass
    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Go"}]
    results: list[ToolCallLoopResult] = []
    async for r in tool_call_loop(
        messages=msgs,
        tools=TOOL_DEFS,
        llm_call=llm_call,
        execute_tool=execute_tool,
        update_conversation=update_conversation,
        max_iterations=3,
    ):
        results.append(r)
    # 3 tool-call iterations + 1 final "max reached"
    assert len(results) == 4
    for r in results[:3]:
        assert r.finished_naturally is False
    final = results[-1]
    # The terminal result mentions the limit in its response text.
    assert final.finished_naturally is False
    assert "3 iterations" in final.response_text
    assert final.iterations == 3
@pytest.mark.asyncio
async def test_no_tools_first_response_text():
    """With an empty tool list, an immediate text reply ends the loop."""

    async def text_only_llm(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        return _make_response(text="No tools needed")

    async def forbidden_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        raise AssertionError("Should not be called")

    def ignore_update(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        return None

    conversation: list[dict[str, Any]] = [{"role": "user", "content": "Hi"}]
    yielded: list[ToolCallLoopResult] = [
        r
        async for r in tool_call_loop(
            messages=conversation,
            tools=[],
            llm_call=text_only_llm,
            execute_tool=forbidden_tool,
            update_conversation=ignore_update,
        )
    ]

    assert len(yielded) == 1
    assert yielded[0].finished_naturally is True
    assert yielded[0].response_text == "No tools needed"
@pytest.mark.asyncio
async def test_tool_executor_exception_propagates():
    """An error raised by execute_tool bubbles out of the generator."""

    async def tool_calling_llm(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        call = LLMToolCall(id="tc_err", name="get_weather", arguments="{}")
        return _make_response(tool_calls=[call])

    async def exploding_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        raise RuntimeError("Tool execution failed!")

    def ignore_update(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        return None

    conversation: list[dict[str, Any]] = [{"role": "user", "content": "Go"}]
    with pytest.raises(RuntimeError, match="Tool execution failed!"):
        async for _ in tool_call_loop(
            messages=conversation,
            tools=TOOL_DEFS,
            llm_call=tool_calling_llm,
            execute_tool=exploding_tool,
            update_conversation=ignore_update,
        ):
            pass
@pytest.mark.asyncio
async def test_parallel_tool_execution():
    """Multiple tool calls in one response should execute concurrently."""
    # NOTE(review): this test infers concurrency from completion order with a
    # 50 ms sleep — could be flaky on a heavily loaded machine.
    execution_order: list[str] = []
    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        # First round (only the user message present): request two tools.
        if len(messages) == 1:
            return _make_response(
                tool_calls=[
                    LLMToolCall(id="tc_a", name="tool_a", arguments="{}"),
                    LLMToolCall(id="tc_b", name="tool_b", arguments="{}"),
                ]
            )
        return _make_response(text="Done")
    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        # tool_b starts instantly, tool_a has a small delay.
        # With parallel execution, both should overlap.
        if tool_call.name == "tool_a":
            await asyncio.sleep(0.05)
        execution_order.append(tool_call.name)
        return ToolCallResult(
            tool_call_id=tool_call.id, tool_name=tool_call.name, content="ok"
        )
    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        # Appending here also grows `messages`, which switches llm_call to
        # its second-round ("Done") branch.
        messages.append({"role": "assistant", "content": "called tools"})
        if tool_results:
            for tr in tool_results:
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tr.tool_call_id,
                        "content": tr.content,
                    }
                )
    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Run both"}]
    async for _ in tool_call_loop(
        messages=msgs,
        tools=TOOL_DEFS,
        llm_call=llm_call,
        execute_tool=execute_tool,
        update_conversation=update_conversation,
    ):
        pass
    # With parallel execution, tool_b (no delay) finishes before tool_a
    assert execution_order == ["tool_b", "tool_a"]
@pytest.mark.asyncio
async def test_sequential_tool_execution():
    """With parallel_tool_calls=False, tools execute in order regardless of speed."""
    execution_order: list[str] = []
    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        # First round (only the user message present): request two tools.
        if len(messages) == 1:
            return _make_response(
                tool_calls=[
                    LLMToolCall(id="tc_a", name="tool_a", arguments="{}"),
                    LLMToolCall(id="tc_b", name="tool_b", arguments="{}"),
                ]
            )
        return _make_response(text="Done")
    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        # tool_b would finish first if parallel, but sequential should keep order
        if tool_call.name == "tool_a":
            await asyncio.sleep(0.05)
        execution_order.append(tool_call.name)
        return ToolCallResult(
            tool_call_id=tool_call.id, tool_name=tool_call.name, content="ok"
        )
    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        # Growing `messages` switches llm_call to its "Done" branch next round.
        messages.append({"role": "assistant", "content": "called tools"})
        if tool_results:
            for tr in tool_results:
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tr.tool_call_id,
                        "content": tr.content,
                    }
                )
    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Run both"}]
    async for _ in tool_call_loop(
        messages=msgs,
        tools=TOOL_DEFS,
        llm_call=llm_call,
        execute_tool=execute_tool,
        update_conversation=update_conversation,
        parallel_tool_calls=False,
    ):
        pass
    # With sequential execution, tool_a runs first despite being slower
    assert execution_order == ["tool_a", "tool_b"]
@pytest.mark.asyncio
async def test_last_iteration_message_appended():
    """On the final iteration, last_iteration_message should be appended."""
    # Snapshot of the `messages` list as seen by each LLM call.
    captured_messages: list[list[dict[str, Any]]] = []
    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        # Copy so later mutations of `messages` don't alter the snapshot.
        captured_messages.append(list(messages))
        return _make_response(
            tool_calls=[LLMToolCall(id="tc_1", name="get_weather", arguments="{}")]
        )
    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        return ToolCallResult(
            tool_call_id=tool_call.id, tool_name=tool_call.name, content="ok"
        )
    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        # Intentionally a no-op: any extra message seen by llm_call must
        # therefore come from the loop itself.
        pass
    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Go"}]
    async for _ in tool_call_loop(
        messages=msgs,
        tools=TOOL_DEFS,
        llm_call=llm_call,
        execute_tool=execute_tool,
        update_conversation=update_conversation,
        max_iterations=2,
        last_iteration_message="Please finish now.",
    ):
        pass
    # First iteration: no extra message
    assert len(captured_messages[0]) == 1
    # Second (last) iteration: should have the hint appended
    last_call_msgs = captured_messages[1]
    assert any(
        m.get("role") == "system" and "Please finish now." in m.get("content", "")
        for m in last_call_msgs
    )
@pytest.mark.asyncio
async def test_token_accumulation():
    """Tokens should accumulate across iterations."""
    call_count = 0
    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        # Two tool-call rounds followed by a final text round; every round
        # costs exactly 100 prompt / 50 completion tokens.
        nonlocal call_count
        call_count += 1
        if call_count <= 2:
            return _make_response(
                tool_calls=[
                    LLMToolCall(
                        id=f"tc_{call_count}", name="get_weather", arguments="{}"
                    )
                ],
                prompt_tokens=100,
                completion_tokens=50,
            )
        return _make_response(text="Final", prompt_tokens=100, completion_tokens=50)
    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        return ToolCallResult(
            tool_call_id=tool_call.id, tool_name=tool_call.name, content="ok"
        )
    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        # Conversation contents are irrelevant to token counting.
        pass
    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Go"}]
    final_result = None
    async for r in tool_call_loop(
        messages=msgs,
        tools=TOOL_DEFS,
        llm_call=llm_call,
        execute_tool=execute_tool,
        update_conversation=update_conversation,
    ):
        final_result = r
    assert final_result is not None
    assert final_result.total_prompt_tokens == 300  # 3 calls * 100
    assert final_result.total_completion_tokens == 150  # 3 calls * 50
    assert final_result.iterations == 3
@pytest.mark.asyncio
async def test_max_iterations_zero_no_loop():
    """A zero iteration budget yields one 'max reached' result and never
    touches the LLM, the tool executor, or the conversation updater."""

    async def never_llm(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        raise AssertionError("LLM should not be called when max_iterations=0")

    async def never_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        raise AssertionError("Tool should not be called when max_iterations=0")

    def never_update(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        raise AssertionError("Updater should not be called when max_iterations=0")

    conversation: list[dict[str, Any]] = [{"role": "user", "content": "Go"}]
    yielded: list[ToolCallLoopResult] = [
        r
        async for r in tool_call_loop(
            messages=conversation,
            tools=TOOL_DEFS,
            llm_call=never_llm,
            execute_tool=never_tool,
            update_conversation=never_update,
            max_iterations=0,
        )
    ]

    assert len(yielded) == 1
    first = yielded[0]
    assert first.finished_naturally is False
    assert first.iterations == 0
    assert "0 iterations" in first.response_text

View File

@@ -0,0 +1,2 @@
-- Add durationMs column to ChatMessage for persisting turn elapsed time.
ALTER TABLE "ChatMessage" ADD COLUMN "durationMs" INTEGER;

View File

@@ -155,6 +155,7 @@ asyncio_default_fixture_loop_scope = "session"
addopts = "-p no:syrupy"
markers = [
"supplementary: tests kept for coverage but superseded by integration tests",
"integration: end-to-end tests that require a live database (skipped in CI)",
]
filterwarnings = [
"ignore:'audioop' is deprecated:DeprecationWarning:discord.player",

View File

@@ -257,7 +257,8 @@ model ChatMessage {
functionCall Json? // Deprecated but kept for compatibility
// Ordering within session
sequence Int
sequence Int
durationMs Int? // Wall-clock milliseconds for this assistant turn
@@unique([sessionId, sequence])
}

View File

@@ -1,9 +1,20 @@
import base64
from types import SimpleNamespace
from typing import cast
from unittest.mock import Mock, patch
import pytest
from backend.blocks.google.gmail import GmailReadBlock
from backend.blocks.google.gmail import (
GmailForwardBlock,
GmailReadBlock,
HasRecipients,
_build_reply_message,
create_mime_message,
validate_all_recipients,
validate_email_recipients,
)
from backend.data.execution import ExecutionContext
class TestGmailReadBlock:
@@ -250,3 +261,244 @@ class TestGmailReadBlock:
result = await self.gmail_block._get_email_body(msg, self.mock_service)
assert result == "This email does not contain a readable body."
class TestValidateEmailRecipients:
    """Test cases for validate_email_recipients."""
    # Happy path: well-formed addresses must not raise.
    def test_valid_single_email(self):
        validate_email_recipients(["user@example.com"])
    def test_valid_multiple_emails(self):
        validate_email_recipients(["a@b.com", "x@y.org", "test@sub.domain.co"])
    # Malformed inputs must raise ValueError with a recognizable message.
    def test_invalid_missing_at(self):
        with pytest.raises(ValueError, match="Invalid email address"):
            validate_email_recipients(["not-an-email"])
    def test_invalid_missing_domain_dot(self):
        # Bare hostnames (no dot in the domain part) are rejected.
        with pytest.raises(ValueError, match="Invalid email address"):
            validate_email_recipients(["user@localhost"])
    def test_invalid_empty_string(self):
        with pytest.raises(ValueError, match="Invalid email address"):
            validate_email_recipients([""])
    def test_invalid_json_object_string(self):
        # Guards against JSON blobs being passed where a plain address belongs.
        with pytest.raises(ValueError, match="Invalid email address"):
            validate_email_recipients(['{"email": "user@example.com"}'])
    def test_mixed_valid_and_invalid(self):
        # The error message should quote the offending address.
        with pytest.raises(ValueError, match="'bad-addr'"):
            validate_email_recipients(["good@example.com", "bad-addr"])
    def test_field_name_in_error(self):
        # The error message should name the field being validated.
        with pytest.raises(ValueError, match="'cc'"):
            validate_email_recipients(["nope"], field_name="cc")
    def test_whitespace_trimmed(self):
        # Leading/trailing whitespace is tolerated around valid addresses.
        validate_email_recipients([" user@example.com "])
    def test_empty_list_passes(self):
        validate_email_recipients([])
class TestValidateAllRecipients:
    """Test cases for validate_all_recipients."""
    # SimpleNamespace stands in for any object exposing to/cc/bcc lists
    # (the HasRecipients protocol); cast keeps the type checker happy.
    def test_valid_all_fields(self):
        data = cast(
            HasRecipients,
            SimpleNamespace(to=["a@b.com"], cc=["c@d.com"], bcc=["e@f.com"]),
        )
        validate_all_recipients(data)
    # Each field's name must appear in the error when it holds a bad address.
    def test_invalid_to_raises(self):
        data = cast(HasRecipients, SimpleNamespace(to=["bad"], cc=[], bcc=[]))
        with pytest.raises(ValueError, match="'to'"):
            validate_all_recipients(data)
    def test_invalid_cc_raises(self):
        data = cast(HasRecipients, SimpleNamespace(to=["a@b.com"], cc=["bad"], bcc=[]))
        with pytest.raises(ValueError, match="'cc'"):
            validate_all_recipients(data)
    def test_invalid_bcc_raises(self):
        data = cast(
            HasRecipients,
            SimpleNamespace(to=["a@b.com"], cc=["c@d.com"], bcc=["bad"]),
        )
        with pytest.raises(ValueError, match="'bcc'"):
            validate_all_recipients(data)
    def test_empty_cc_bcc_skipped(self):
        # Empty cc/bcc lists are not an error.
        data = cast(HasRecipients, SimpleNamespace(to=["a@b.com"], cc=[], bcc=[]))
        validate_all_recipients(data)
class TestCreateMimeMessageValidation:
    """Integration tests verifying validation hooks in create_mime_message()."""
    @pytest.mark.asyncio
    async def test_invalid_to_raises_before_mime_construction(self):
        """Invalid 'to' recipients should raise ValueError before any MIME work."""
        input_data = SimpleNamespace(
            to=["not-an-email"],
            cc=[],
            bcc=[],
            subject="Test",
            body="Hello",
            attachments=[],
        )
        # Minimal ExecutionContext stand-in; only graph_exec_id is needed here.
        exec_ctx = cast(ExecutionContext, SimpleNamespace(graph_exec_id="test-exec-id"))
        with pytest.raises(ValueError, match="Invalid email address"):
            await create_mime_message(input_data, exec_ctx)
    @pytest.mark.asyncio
    async def test_invalid_cc_raises_before_mime_construction(self):
        """Invalid 'cc' recipients should raise ValueError."""
        input_data = SimpleNamespace(
            to=["valid@example.com"],
            cc=["bad-addr"],
            bcc=[],
            subject="Test",
            body="Hello",
            attachments=[],
        )
        exec_ctx = cast(ExecutionContext, SimpleNamespace(graph_exec_id="test-exec-id"))
        with pytest.raises(ValueError, match="'cc'"):
            await create_mime_message(input_data, exec_ctx)
    @pytest.mark.asyncio
    async def test_valid_recipients_passes_validation(self):
        """Valid recipients should not raise during validation."""
        input_data = SimpleNamespace(
            to=["user@example.com"],
            cc=["other@example.com"],
            bcc=[],
            subject="Test",
            body="Hello",
            attachments=[],
        )
        exec_ctx = cast(ExecutionContext, SimpleNamespace(graph_exec_id="test-exec-id"))
        # Should succeed without raising
        result = await create_mime_message(input_data, exec_ctx)
        assert isinstance(result, str)
class TestBuildReplyMessageValidation:
    """Integration tests verifying validation hooks in _build_reply_message()."""
    @pytest.mark.asyncio
    async def test_invalid_to_raises_before_reply_construction(self):
        """Invalid 'to' in reply should raise ValueError before MIME work."""
        # Mock Gmail service returning a minimal parent message with the
        # headers a reply builder needs (Subject / Message-ID / From).
        mock_service = Mock()
        mock_parent = {
            "threadId": "thread-1",
            "payload": {
                "headers": [
                    {"name": "Subject", "value": "Original"},
                    {"name": "Message-ID", "value": "<msg@example.com>"},
                    {"name": "From", "value": "sender@example.com"},
                ]
            },
        }
        mock_service.users().messages().get().execute.return_value = mock_parent
        input_data = SimpleNamespace(
            parentMessageId="msg-1",
            to=["not-valid"],
            cc=[],
            bcc=[],
            subject="",
            body="Reply body",
            replyAll=False,
            attachments=[],
        )
        exec_ctx = cast(ExecutionContext, SimpleNamespace(graph_exec_id="test-exec-id"))
        with pytest.raises(ValueError, match="Invalid email address"):
            await _build_reply_message(mock_service, input_data, exec_ctx)
class TestForwardMessageValidation:
    """Test that _forward_message() raises ValueError for invalid recipients."""
    @staticmethod
    def _make_input(
        to: list[str] | None = None,
        cc: list[str] | None = None,
        bcc: list[str] | None = None,
    ) -> "GmailForwardBlock.Input":
        # Build a spec'd mock of the block input with overridable recipients.
        mock = Mock(spec=GmailForwardBlock.Input)
        mock.messageId = "m1"
        mock.to = to or []
        mock.cc = cc or []
        mock.bcc = bcc or []
        mock.subject = ""
        mock.forwardMessage = "FYI"
        mock.includeAttachments = False
        mock.content_type = None
        mock.additionalAttachments = []
        mock.credentials = None
        return mock
    @staticmethod
    def _exec_ctx():
        # Real ExecutionContext with minimal identifiers.
        return ExecutionContext(user_id="u1", graph_exec_id="g1")
    @staticmethod
    def _mock_service():
        """Build a mock Gmail service that returns a parent message."""
        parent_message = {
            "id": "m1",
            "payload": {
                "headers": [
                    {"name": "Subject", "value": "Original subject"},
                    {"name": "From", "value": "sender@example.com"},
                    {"name": "To", "value": "me@example.com"},
                    {"name": "Date", "value": "Mon, 31 Mar 2026 00:00:00 +0000"},
                ],
                "mimeType": "text/plain",
                "body": {
                    # Gmail API bodies are URL-safe base64 encoded.
                    "data": base64.urlsafe_b64encode(b"Hello world").decode(),
                },
                "parts": [],
            },
        }
        svc = Mock()
        svc.users().messages().get().execute.return_value = parent_message
        return svc
    # Each recipient field should be named in the ValueError it triggers.
    @pytest.mark.asyncio
    async def test_invalid_to_raises(self):
        block = GmailForwardBlock()
        with pytest.raises(ValueError, match="Invalid email address.*'to'"):
            await block._forward_message(
                self._mock_service(),
                self._make_input(to=["bad-addr"]),
                self._exec_ctx(),
            )
    @pytest.mark.asyncio
    async def test_invalid_cc_raises(self):
        block = GmailForwardBlock()
        with pytest.raises(ValueError, match="Invalid email address.*'cc'"):
            await block._forward_message(
                self._mock_service(),
                self._make_input(to=["valid@example.com"], cc=["not-valid"]),
                self._exec_ctx(),
            )
    @pytest.mark.asyncio
    async def test_invalid_bcc_raises(self):
        block = GmailForwardBlock()
        with pytest.raises(ValueError, match="Invalid email address.*'bcc'"):
            await block._forward_message(
                self._mock_service(),
                self._make_input(to=["valid@example.com"], bcc=["nope"]),
                self._exec_ctx(),
            )

View File

@@ -17,6 +17,7 @@ images: {
"""
import asyncio
import os
import random
from datetime import datetime
@@ -569,6 +570,10 @@ async def main():
@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.skipif(
os.getenv("CI") == "true",
reason="Data seeding test requires a dedicated database; not for CI",
)
async def test_main_function_runs_without_errors():
await main()

View File

@@ -1,5 +1,5 @@
# Base stage for both dev and prod
FROM node:21-alpine AS base
FROM node:22.22-alpine3.23 AS base
WORKDIR /app
RUN corepack enable
COPY autogpt_platform/frontend/package.json autogpt_platform/frontend/pnpm-lock.yaml ./
@@ -33,7 +33,7 @@ ENV NEXT_PUBLIC_SOURCEMAPS="false"
RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=8192" pnpm build; else NODE_OPTIONS="--max-old-space-size=8192" pnpm build; fi
# Prod stage - based on NextJS reference Dockerfile https://github.com/vercel/next.js/blob/64271354533ed16da51be5dce85f0dbd15f17517/examples/with-docker/Dockerfile
FROM node:21-alpine AS prod
FROM node:22.22-alpine3.23 AS prod
ENV NODE_ENV=production
ENV HOSTNAME=0.0.0.0
WORKDIR /app

View File

@@ -95,6 +95,8 @@ export function CopilotPage() {
isDeleting,
handleConfirmDelete,
handleCancelDelete,
// Historical durations for persisted timer stats
historicalDurations,
// Rate limit reset
rateLimitMessage,
dismissRateLimit,
@@ -186,6 +188,7 @@ export function CopilotPage() {
isUploadingFiles={isUploadingFiles}
droppedFiles={droppedFiles}
onDroppedFilesConsumed={handleDroppedFilesConsumed}
historicalDurations={historicalDurations}
/>
</div>
</div>

View File

@@ -27,6 +27,8 @@ export interface ChatContainerProps {
droppedFiles?: File[];
/** Called after droppedFiles have been consumed by ChatInput. */
onDroppedFilesConsumed?: () => void;
/** Duration in ms for historical turns, keyed by message ID. */
historicalDurations?: Map<string, number>;
}
export const ChatContainer = ({
messages,
@@ -44,6 +46,7 @@ export const ChatContainer = ({
isUploadingFiles,
droppedFiles,
onDroppedFilesConsumed,
historicalDurations,
}: ChatContainerProps) => {
const isBusy =
status === "streaming" ||
@@ -81,6 +84,7 @@ export const ChatContainer = ({
isLoading={isLoadingSession}
sessionID={sessionId}
onRetry={handleRetry}
historicalDurations={historicalDurations}
/>
<motion.div
initial={{ opacity: 0 }}

View File

@@ -1,4 +1,4 @@
import { useMemo } from "react";
import { useEffect, useMemo, useRef } from "react";
import {
Conversation,
ConversationContent,
@@ -13,6 +13,7 @@ import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner
import { FileUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { TOOL_PART_PREFIX } from "../JobStatsBar/constants";
import { TurnStatsBar } from "../JobStatsBar/TurnStatsBar";
import { useElapsedTimer } from "../JobStatsBar/useElapsedTimer";
import { CopilotPendingReviews } from "../CopilotPendingReviews/CopilotPendingReviews";
import {
buildRenderSegments,
@@ -37,6 +38,7 @@ interface Props {
isLoading: boolean;
sessionID?: string | null;
onRetry?: () => void;
historicalDurations?: Map<string, number>;
}
function renderSegments(
@@ -111,6 +113,7 @@ export function ChatMessagesContainer({
isLoading,
sessionID,
onRetry,
historicalDurations,
}: Props) {
const lastMessage = messages[messages.length - 1];
const graphExecId = useMemo(() => extractGraphExecId(messages), [messages]);
@@ -139,6 +142,25 @@ export function ChatMessagesContainer({
const showThinking =
status === "submitted" || (status === "streaming" && !hasInflight);
const isActivelyStreaming = status === "streaming" || status === "submitted";
const { elapsedSeconds } = useElapsedTimer(isActivelyStreaming);
// Freeze elapsed time when streaming ends so TurnStatsBar shows the final value.
// Reset when a new streaming turn begins.
const frozenElapsedRef = useRef(0);
const wasStreamingRef = useRef(false);
useEffect(() => {
if (isActivelyStreaming) {
if (!wasStreamingRef.current) {
frozenElapsedRef.current = 0;
}
if (elapsedSeconds > 0) {
frozenElapsedRef.current = elapsedSeconds;
}
}
wasStreamingRef.current = isActivelyStreaming;
});
return (
<Conversation className="min-h-0 flex-1">
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
@@ -239,10 +261,19 @@ export function ChatMessagesContainer({
{isLastInTurn && !isCurrentlyStreaming && (
<TurnStatsBar
turnMessages={getTurnMessages(messages, messageIndex)}
elapsedSeconds={
messageIndex === messages.length - 1
? frozenElapsedRef.current
: undefined
}
durationMs={historicalDurations?.get(message.id)}
/>
)}
{isLastAssistant && showThinking && (
<ThinkingIndicator active={showThinking} />
<ThinkingIndicator
active={showThinking}
elapsedSeconds={elapsedSeconds}
/>
)}
</MessageContent>
{message.role === "user" && textParts.length > 0 && (
@@ -268,7 +299,10 @@ export function ChatMessagesContainer({
{showThinking && lastMessage?.role !== "assistant" && (
<Message from="assistant">
<MessageContent className="text-[1rem] leading-relaxed">
<ThinkingIndicator active={showThinking} />
<ThinkingIndicator
active={showThinking}
elapsedSeconds={elapsedSeconds}
/>
</MessageContent>
</Message>
)}

View File

@@ -1,4 +1,5 @@
import { useEffect, useRef, useState } from "react";
import { formatElapsed } from "../../JobStatsBar/formatElapsed";
import { ScaleLoader } from "../../ScaleLoader/ScaleLoader";
const THINKING_PHRASES = [
@@ -27,6 +28,9 @@ const THINKING_PHRASES = [
const PHRASE_CYCLE_MS = 6_000;
const FADE_DURATION_MS = 300;
/** Only show elapsed time after this many seconds. */
const SHOW_TIME_AFTER_SECONDS = 20;
/**
* Cycles through thinking phrases sequentially with a fade-out/in transition.
* Returns the current phrase and whether it's visible (for opacity).
@@ -72,10 +76,12 @@ function useCyclingPhrase(active: boolean) {
interface Props {
active: boolean;
elapsedSeconds: number;
}
export function ThinkingIndicator({ active }: Props) {
export function ThinkingIndicator({ active, elapsedSeconds }: Props) {
const { phrase, visible } = useCyclingPhrase(active);
const showTime = active && elapsedSeconds >= SHOW_TIME_AFTER_SECONDS;
return (
<span className="inline-flex items-center gap-1.5 text-neutral-500">
@@ -88,6 +94,11 @@ export function ThinkingIndicator({ active }: Props) {
{phrase}
</span>
</span>
{showTime && (
<span className="animate-pulse tabular-nums [animation-duration:1.5s]">
{formatElapsed(elapsedSeconds)}
</span>
)}
</span>
);
}

View File

@@ -1,21 +1,44 @@
import type { UIDataTypes, UIMessage, UITools } from "ai";
import { formatElapsed } from "./formatElapsed";
import { getWorkDoneCounters } from "./useWorkDoneCounters";
interface Props {
turnMessages: UIMessage<unknown, UIDataTypes, UITools>[];
elapsedSeconds?: number;
durationMs?: number;
}
export function TurnStatsBar({ turnMessages }: Props) {
export function TurnStatsBar({
turnMessages,
elapsedSeconds,
durationMs,
}: Props) {
const { counters } = getWorkDoneCounters(turnMessages);
if (counters.length === 0) return null;
// Prefer live elapsedSeconds, fall back to persisted durationMs
const displaySeconds =
elapsedSeconds !== undefined && elapsedSeconds > 0
? elapsedSeconds
: durationMs !== undefined
? Math.round(durationMs / 1000)
: undefined;
const hasTime = displaySeconds !== undefined && displaySeconds > 0;
if (counters.length === 0 && !hasTime) return null;
return (
<div className="mt-2 flex items-center gap-1.5">
{hasTime && (
<span className="text-[11px] tabular-nums text-neutral-500">
Thought for {formatElapsed(displaySeconds)}
</span>
)}
{counters.map(function renderCounter(counter, index) {
const needsDot = index > 0 || hasTime;
return (
<span key={counter.category} className="flex items-center gap-1">
{index > 0 && (
{needsDot && (
<span className="text-xs text-neutral-300">&middot;</span>
)}
<span className="text-[11px] tabular-nums text-neutral-500">

View File

@@ -0,0 +1,7 @@
/**
 * Format a whole-second count as a compact elapsed-time label:
 * "45s" under a minute, otherwise "Nm Ss" (e.g. 75 -> "1m 15s").
 * Callers are expected to pass non-negative integers.
 */
export function formatElapsed(totalSeconds: number): string {
  const minutes = Math.floor(totalSeconds / 60);
  const seconds = totalSeconds % 60;
  if (minutes === 0) return `${seconds}s`;
  return `${minutes}m ${seconds}s`;
}

View File

@@ -0,0 +1,31 @@
import { useEffect, useRef, useState } from "react";
/**
 * React hook reporting whole seconds elapsed since `isRunning` became true.
 * Ticks once per second while running; resets to 0 on the next run.
 */
export function useElapsedTimer(isRunning: boolean) {
  const [elapsedSeconds, setElapsedSeconds] = useState(0);
  // Epoch ms when the current run started; null when idle.
  const startTimeRef = useRef<number | null>(null);
  const intervalRef = useRef<ReturnType<typeof setInterval>>();
  useEffect(() => {
    if (isRunning) {
      // Only stamp the start on the idle -> running transition, so
      // re-renders while running don't restart the clock.
      if (startTimeRef.current === null) {
        startTimeRef.current = Date.now();
        setElapsedSeconds(0);
      }
      intervalRef.current = setInterval(() => {
        if (startTimeRef.current !== null) {
          setElapsedSeconds(
            Math.floor((Date.now() - startTimeRef.current) / 1000),
          );
        }
      }, 1000);
      // Cleanup stops the tick when isRunning flips or on unmount.
      return () => clearInterval(intervalRef.current);
    }
    // Not running: stop any ticker and clear the start stamp.
    clearInterval(intervalRef.current);
    startTimeRef.current = null;
  }, [isRunning]);
  return { elapsedSeconds };
}

View File

@@ -6,6 +6,7 @@ interface SessionChatMessage {
content: string | null;
tool_call_id: string | null;
tool_calls: unknown[] | null;
duration_ms: number | null;
}
function coerceSessionChatMessages(
@@ -34,6 +35,8 @@ function coerceSessionChatMessages(
? null
: String(msg.tool_call_id),
tool_calls: Array.isArray(msg.tool_calls) ? msg.tool_calls : null,
duration_ms:
typeof msg.duration_ms === "number" ? msg.duration_ms : null,
};
})
.filter((m): m is SessionChatMessage => m !== null);
@@ -102,7 +105,10 @@ export function convertChatSessionMessagesToUiMessages(
sessionId: string,
rawMessages: unknown[],
options?: { isComplete?: boolean },
): UIMessage<unknown, UIDataTypes, UITools>[] {
): {
messages: UIMessage<unknown, UIDataTypes, UITools>[];
durations: Map<string, number>;
} {
const messages = coerceSessionChatMessages(rawMessages);
const toolOutputsByCallId = new Map<string, unknown>();
@@ -114,6 +120,7 @@ export function convertChatSessionMessagesToUiMessages(
}
const uiMessages: UIMessage<unknown, UIDataTypes, UITools>[] = [];
const durations = new Map<string, number>();
messages.forEach((msg, index) => {
if (msg.role === "tool") return;
@@ -186,15 +193,24 @@ export function convertChatSessionMessagesToUiMessages(
const prevUI = uiMessages[uiMessages.length - 1];
if (msg.role === "assistant" && prevUI && prevUI.role === "assistant") {
prevUI.parts.push(...parts);
// Capture duration on merged message (last assistant msg wins)
if (msg.duration_ms != null) {
durations.set(prevUI.id, msg.duration_ms);
}
return;
}
const msgId = `${sessionId}-${index}`;
uiMessages.push({
id: `${sessionId}-${index}`,
id: msgId,
role: msg.role,
parts,
});
if (msg.role === "assistant" && msg.duration_ms != null) {
durations.set(msgId, msg.duration_ms);
}
});
return uiMessages;
return { messages: uiMessages, durations };
}

View File

@@ -61,13 +61,21 @@ export function useChatSession() {
// array reference every render. Re-derives only when query data changes.
// When the session is complete (no active stream), mark dangling tool
// calls as completed so stale spinners don't persist after refresh.
const hydratedMessages = useMemo(() => {
if (sessionQuery.data?.status !== 200 || !sessionId) return undefined;
return convertChatSessionMessagesToUiMessages(
const { hydratedMessages, historicalDurations } = useMemo(() => {
if (sessionQuery.data?.status !== 200 || !sessionId)
return {
hydratedMessages: undefined,
historicalDurations: new Map<string, number>(),
};
const result = convertChatSessionMessagesToUiMessages(
sessionId,
sessionQuery.data.data.messages ?? [],
{ isComplete: !hasActiveStream },
);
return {
hydratedMessages: result.messages,
historicalDurations: result.durations,
};
}, [sessionQuery.data, sessionId, hasActiveStream]);
const { mutateAsync: createSessionMutation, isPending: isCreatingSession } =
@@ -122,6 +130,7 @@ export function useChatSession() {
sessionId,
setSessionId,
hydratedMessages,
historicalDurations,
hasActiveStream,
isLoadingSession: sessionQuery.isLoading,
isSessionError: sessionQuery.isError,

View File

@@ -39,6 +39,7 @@ export function useCopilotPage() {
sessionId,
setSessionId,
hydratedMessages,
historicalDurations,
hasActiveStream,
isLoadingSession,
isSessionError,
@@ -377,6 +378,8 @@ export function useCopilotPage() {
handleDeleteClick,
handleConfirmDelete,
handleCancelDelete,
// Historical durations for persisted timer stats
historicalDurations,
// Rate limit reset
rateLimitMessage,
dismissRateLimit,

View File

@@ -731,6 +731,7 @@ _Add technical explanation here._
| max_tokens | The maximum number of tokens to generate in the chat completion. | int | No |
| ollama_host | Ollama host for local models | str | No |
| agent_mode_max_iterations | Maximum iterations for agent mode. 0 = traditional mode (single LLM call, yield tool calls for external execution), -1 = infinite agent mode (loop until finished), 1+ = agent mode with max iterations limit. | int | No |
| execution_mode | How tool calls are executed. 'built_in' uses the default tool-call loop (all providers). 'extended_thinking' delegates to an external Agent SDK for richer reasoning (currently Anthropic / OpenRouter only, requires API credentials, ignores 'Agent Mode Max Iterations'). | "built_in" \| "extended_thinking" | No |
| conversation_compaction | Automatically compact the context window once it hits the limit | bool | No |
### Outputs