mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/rate-limit-tiering
This commit is contained in:
@@ -83,13 +83,13 @@ The AutoGPT frontend is where users interact with our powerful AI automation pla
|
||||
|
||||
**Agent Builder:** For those who want to customize, our intuitive, low-code interface allows you to design and configure your own AI agents.
|
||||
|
||||
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
|
||||
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
|
||||
|
||||
**Deployment Controls:** Manage the lifecycle of your agents, from testing to production.
|
||||
|
||||
**Ready-to-Use Agents:** Don't want to build? Simply select from our library of pre-configured agents and put them to work immediately.
|
||||
|
||||
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
|
||||
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
|
||||
|
||||
**Monitoring and Analytics:** Keep track of your agents' performance and gain insights to continually improve your automation processes.
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import re
|
||||
from abc import ABC
|
||||
from email import encoders
|
||||
from email.mime.base import MIMEBase
|
||||
@@ -8,7 +9,7 @@ from email.mime.text import MIMEText
|
||||
from email.policy import SMTP
|
||||
from email.utils import getaddresses, parseaddr
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional
|
||||
from typing import List, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
@@ -42,8 +43,52 @@ NO_WRAP_POLICY = SMTP.clone(max_line_length=0)
|
||||
|
||||
|
||||
def serialize_email_recipients(recipients: list[str]) -> str:
|
||||
"""Serialize recipients list to comma-separated string."""
|
||||
return ", ".join(recipients)
|
||||
"""Serialize recipients list to comma-separated string.
|
||||
|
||||
Strips leading/trailing whitespace from each address to keep MIME
|
||||
headers clean (mirrors the strip done in ``validate_email_recipients``).
|
||||
"""
|
||||
return ", ".join(addr.strip() for addr in recipients)
|
||||
|
||||
|
||||
# RFC 5322 simplified pattern: local@domain where domain has at least one dot
|
||||
_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
|
||||
|
||||
|
||||
def validate_email_recipients(recipients: list[str], field_name: str = "to") -> None:
|
||||
"""Validate that all recipients are plausible email addresses.
|
||||
|
||||
Raises ``ValueError`` with a user-friendly message listing every
|
||||
invalid entry so the caller (or LLM) can correct them in one pass.
|
||||
"""
|
||||
invalid = [addr for addr in recipients if not _EMAIL_RE.match(addr.strip())]
|
||||
if invalid:
|
||||
formatted = ", ".join(f"'{a}'" for a in invalid)
|
||||
raise ValueError(
|
||||
f"Invalid email address(es) in '{field_name}': {formatted}. "
|
||||
f"Each entry must be a valid email address (e.g. user@example.com)."
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HasRecipients(Protocol):
|
||||
to: list[str]
|
||||
cc: list[str]
|
||||
bcc: list[str]
|
||||
|
||||
|
||||
def validate_all_recipients(input_data: HasRecipients) -> None:
|
||||
"""Validate to/cc/bcc recipient fields on an input namespace.
|
||||
|
||||
Calls ``validate_email_recipients`` for ``to`` (required) and
|
||||
``cc``/``bcc`` (when non-empty), raising ``ValueError`` on the
|
||||
first field that contains an invalid address.
|
||||
"""
|
||||
validate_email_recipients(input_data.to, "to")
|
||||
if input_data.cc:
|
||||
validate_email_recipients(input_data.cc, "cc")
|
||||
if input_data.bcc:
|
||||
validate_email_recipients(input_data.bcc, "bcc")
|
||||
|
||||
|
||||
def _make_mime_text(
|
||||
@@ -100,14 +145,16 @@ async def create_mime_message(
|
||||
) -> str:
|
||||
"""Create a MIME message with attachments and return base64-encoded raw message."""
|
||||
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
message = MIMEMultipart()
|
||||
message["to"] = serialize_email_recipients(input_data.to)
|
||||
message["subject"] = input_data.subject
|
||||
|
||||
if input_data.cc:
|
||||
message["cc"] = ", ".join(input_data.cc)
|
||||
message["cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
message["bcc"] = ", ".join(input_data.bcc)
|
||||
message["bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
|
||||
# Use the new helper function with content_type if available
|
||||
content_type = getattr(input_data, "content_type", None)
|
||||
@@ -1167,13 +1214,15 @@ async def _build_reply_message(
|
||||
references.append(headers["message-id"])
|
||||
|
||||
# Create MIME message
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
msg = MIMEMultipart()
|
||||
if input_data.to:
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
msg["To"] = serialize_email_recipients(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
msg["Cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
if headers.get("message-id"):
|
||||
msg["In-Reply-To"] = headers["message-id"]
|
||||
@@ -1685,13 +1734,16 @@ To: {original_to}
|
||||
else:
|
||||
body = f"{forward_header}\n\n{original_body}"
|
||||
|
||||
# Validate all recipient lists before building the MIME message
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
# Create MIME message
|
||||
msg = MIMEMultipart()
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
msg["To"] = serialize_email_recipients(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
msg["Cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
|
||||
# Add body with proper content type
|
||||
|
||||
@@ -724,6 +724,9 @@ def convert_openai_tool_fmt_to_anthropic(
|
||||
def extract_openai_reasoning(response) -> str | None:
|
||||
"""Extract reasoning from OpenAI-compatible response if available."""
|
||||
"""Note: This will likely not working since the reasoning is not present in another Response API"""
|
||||
if not response.choices:
|
||||
logger.warning("LLM response has empty choices in extract_openai_reasoning")
|
||||
return None
|
||||
reasoning = None
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
|
||||
@@ -739,6 +742,9 @@ def extract_openai_reasoning(response) -> str | None:
|
||||
|
||||
def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
||||
"""Extract tool calls from OpenAI-compatible response."""
|
||||
if not response.choices:
|
||||
logger.warning("LLM response has empty choices in extract_openai_tool_calls")
|
||||
return None
|
||||
if response.choices[0].message.tool_calls:
|
||||
return [
|
||||
ToolContentBlock(
|
||||
@@ -972,6 +978,8 @@ async def llm_call(
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if not response.choices:
|
||||
raise ValueError("Groq returned empty choices in response")
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
prompt=prompt,
|
||||
@@ -1031,12 +1039,8 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
if not response.choices:
|
||||
if response:
|
||||
raise ValueError(f"OpenRouter error: {response}")
|
||||
else:
|
||||
raise ValueError("No response from OpenRouter.")
|
||||
raise ValueError(f"OpenRouter returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
@@ -1073,12 +1077,8 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
if not response.choices:
|
||||
if response:
|
||||
raise ValueError(f"Llama API error: {response}")
|
||||
else:
|
||||
raise ValueError("No response from Llama API.")
|
||||
raise ValueError(f"Llama API returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
@@ -1108,6 +1108,8 @@ async def llm_call(
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if not completion.choices:
|
||||
raise ValueError("AI/ML API returned empty choices in response")
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=completion.choices[0].message,
|
||||
@@ -1144,6 +1146,9 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
if not response.choices:
|
||||
raise ValueError(f"v0 API returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,87 @@
|
||||
"""Tests for empty-choices guard in extract_openai_tool_calls() and extract_openai_reasoning()."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.llm import extract_openai_reasoning, extract_openai_tool_calls
|
||||
|
||||
|
||||
class TestExtractOpenaiToolCallsEmptyChoices:
|
||||
"""extract_openai_tool_calls() must return None when choices is empty."""
|
||||
|
||||
def test_returns_none_for_empty_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = []
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
def test_returns_none_for_none_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = None
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
def test_returns_tool_calls_when_choices_present(self):
|
||||
tool = MagicMock()
|
||||
tool.id = "call_1"
|
||||
tool.type = "function"
|
||||
tool.function.name = "my_func"
|
||||
tool.function.arguments = '{"a": 1}'
|
||||
|
||||
message = MagicMock()
|
||||
message.tool_calls = [tool]
|
||||
|
||||
choice = MagicMock()
|
||||
choice.message = message
|
||||
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_tool_calls(response)
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0].function.name == "my_func"
|
||||
|
||||
def test_returns_none_when_no_tool_calls(self):
|
||||
message = MagicMock()
|
||||
message.tool_calls = None
|
||||
|
||||
choice = MagicMock()
|
||||
choice.message = message
|
||||
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
|
||||
class TestExtractOpenaiReasoningEmptyChoices:
|
||||
"""extract_openai_reasoning() must return None when choices is empty."""
|
||||
|
||||
def test_returns_none_for_empty_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = []
|
||||
assert extract_openai_reasoning(response) is None
|
||||
|
||||
def test_returns_none_for_none_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = None
|
||||
assert extract_openai_reasoning(response) is None
|
||||
|
||||
def test_returns_reasoning_from_choice(self):
|
||||
choice = MagicMock()
|
||||
choice.reasoning = "Step-by-step reasoning"
|
||||
choice.message = MagicMock(spec=[]) # no 'reasoning' attr on message
|
||||
|
||||
response = MagicMock(spec=[]) # no 'reasoning' attr on response
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_reasoning(response)
|
||||
assert result == "Step-by-step reasoning"
|
||||
|
||||
def test_returns_none_when_no_reasoning(self):
|
||||
choice = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
choice.message = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
|
||||
response = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_reasoning(response)
|
||||
assert result is None
|
||||
@@ -1074,6 +1074,7 @@ async def test_orchestrator_uses_customized_name_for_blocks():
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
|
||||
mock_node.block = StoreValueBlock()
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
@@ -1105,6 +1106,7 @@ async def test_orchestrator_falls_back_to_block_name():
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {} # No customized_name
|
||||
mock_node.block = StoreValueBlock()
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
"""Tests for ExecutionMode enum and provider validation in the orchestrator.
|
||||
|
||||
Covers:
|
||||
- ExecutionMode enum members exist and have stable values
|
||||
- EXTENDED_THINKING provider validation (anthropic/open_router allowed, others rejected)
|
||||
- EXTENDED_THINKING model-name validation (must start with "claude")
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.blocks.orchestrator import ExecutionMode, OrchestratorBlock
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExecutionMode enum integrity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExecutionModeEnum:
|
||||
"""Guard against accidental renames or removals of enum members."""
|
||||
|
||||
def test_built_in_exists(self):
|
||||
assert hasattr(ExecutionMode, "BUILT_IN")
|
||||
assert ExecutionMode.BUILT_IN.value == "built_in"
|
||||
|
||||
def test_extended_thinking_exists(self):
|
||||
assert hasattr(ExecutionMode, "EXTENDED_THINKING")
|
||||
assert ExecutionMode.EXTENDED_THINKING.value == "extended_thinking"
|
||||
|
||||
def test_exactly_two_members(self):
|
||||
"""If a new mode is added, this test should be updated intentionally."""
|
||||
assert set(ExecutionMode.__members__.keys()) == {
|
||||
"BUILT_IN",
|
||||
"EXTENDED_THINKING",
|
||||
}
|
||||
|
||||
def test_string_enum(self):
|
||||
"""ExecutionMode is a str enum so it serialises cleanly to JSON."""
|
||||
assert isinstance(ExecutionMode.BUILT_IN, str)
|
||||
assert isinstance(ExecutionMode.EXTENDED_THINKING, str)
|
||||
|
||||
def test_round_trip_from_value(self):
|
||||
"""Constructing from the string value should return the same member."""
|
||||
assert ExecutionMode("built_in") is ExecutionMode.BUILT_IN
|
||||
assert ExecutionMode("extended_thinking") is ExecutionMode.EXTENDED_THINKING
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider validation (inline in OrchestratorBlock.run)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_model_stub(provider: str, value: str):
|
||||
"""Create a lightweight stub that behaves like LlmModel for validation."""
|
||||
metadata = MagicMock()
|
||||
metadata.provider = provider
|
||||
stub = MagicMock()
|
||||
stub.metadata = metadata
|
||||
stub.value = value
|
||||
return stub
|
||||
|
||||
|
||||
class TestExtendedThinkingProviderValidation:
|
||||
"""The orchestrator rejects EXTENDED_THINKING for non-Anthropic providers."""
|
||||
|
||||
def test_anthropic_provider_accepted(self):
|
||||
"""provider='anthropic' + claude model should not raise."""
|
||||
model = _make_model_stub("anthropic", "claude-opus-4-6")
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
def test_open_router_provider_accepted(self):
|
||||
"""provider='open_router' + claude model should not raise."""
|
||||
model = _make_model_stub("open_router", "claude-sonnet-4-6")
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
def test_openai_provider_rejected(self):
|
||||
"""provider='openai' should be rejected for EXTENDED_THINKING."""
|
||||
model = _make_model_stub("openai", "gpt-4o")
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_groq_provider_rejected(self):
|
||||
model = _make_model_stub("groq", "llama-3.3-70b-versatile")
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_non_claude_model_rejected_even_if_anthropic_provider(self):
|
||||
"""A hypothetical non-Claude model with provider='anthropic' is rejected."""
|
||||
model = _make_model_stub("anthropic", "not-a-claude-model")
|
||||
model_name = model.value
|
||||
assert not model_name.startswith("claude")
|
||||
|
||||
def test_real_gpt4o_model_rejected(self):
|
||||
"""Verify a real LlmModel enum member (GPT4O) fails the provider check."""
|
||||
model = LlmModel.GPT4O
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_real_claude_model_passes(self):
|
||||
"""Verify a real LlmModel enum member (CLAUDE_4_6_SONNET) passes."""
|
||||
model = LlmModel.CLAUDE_4_6_SONNET
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration-style: exercise the validation branch via OrchestratorBlock.run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_input_data(model, execution_mode=ExecutionMode.EXTENDED_THINKING):
|
||||
"""Build a minimal MagicMock that satisfies OrchestratorBlock.run's early path."""
|
||||
inp = MagicMock()
|
||||
inp.execution_mode = execution_mode
|
||||
inp.model = model
|
||||
inp.prompt = "test"
|
||||
inp.sys_prompt = ""
|
||||
inp.conversation_history = []
|
||||
inp.last_tool_output = None
|
||||
inp.prompt_values = {}
|
||||
return inp
|
||||
|
||||
|
||||
async def _collect_run_outputs(block, input_data, **kwargs):
|
||||
"""Exhaust the OrchestratorBlock.run async generator, collecting outputs."""
|
||||
outputs = []
|
||||
async for item in block.run(input_data, **kwargs):
|
||||
outputs.append(item)
|
||||
return outputs
|
||||
|
||||
|
||||
class TestExtendedThinkingValidationRaisesInBlock:
|
||||
"""Call OrchestratorBlock.run far enough to trigger the ValueError."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_anthropic_provider_raises_valueerror(self):
|
||||
"""EXTENDED_THINKING + openai provider raises ValueError."""
|
||||
block = OrchestratorBlock()
|
||||
input_data = _make_input_data(model=LlmModel.GPT4O)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
block,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
pytest.raises(ValueError, match="Anthropic-compatible"),
|
||||
):
|
||||
await _collect_run_outputs(
|
||||
block,
|
||||
input_data,
|
||||
credentials=MagicMock(),
|
||||
graph_id="g",
|
||||
node_id="n",
|
||||
graph_exec_id="ge",
|
||||
node_exec_id="ne",
|
||||
user_id="u",
|
||||
graph_version=1,
|
||||
execution_context=MagicMock(),
|
||||
execution_processor=MagicMock(),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_claude_model_with_anthropic_provider_raises(self):
|
||||
"""A model with anthropic provider but non-claude name raises ValueError."""
|
||||
block = OrchestratorBlock()
|
||||
fake_model = _make_model_stub("anthropic", "not-a-claude-model")
|
||||
input_data = _make_input_data(model=fake_model)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
block,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
pytest.raises(ValueError, match="only supports Claude models"),
|
||||
):
|
||||
await _collect_run_outputs(
|
||||
block,
|
||||
input_data,
|
||||
credentials=MagicMock(),
|
||||
graph_id="g",
|
||||
node_id="n",
|
||||
graph_exec_id="ge",
|
||||
node_exec_id="ne",
|
||||
user_id="u",
|
||||
graph_version=1,
|
||||
execution_context=MagicMock(),
|
||||
execution_processor=MagicMock(),
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,11 +9,14 @@ shared tool registry as the SDK path.
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any, cast
|
||||
|
||||
import orjson
|
||||
from langfuse import propagate_attributes
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
@@ -48,7 +51,17 @@ from backend.copilot.token_tracking import persist_and_record_usage
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.prompt import compress_context
|
||||
from backend.util.prompt import (
|
||||
compress_context,
|
||||
estimate_token_count,
|
||||
estimate_token_count_str,
|
||||
)
|
||||
from backend.util.tool_call_loop import (
|
||||
LLMLoopResponse,
|
||||
LLMToolCall,
|
||||
ToolCallResult,
|
||||
tool_call_loop,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -59,6 +72,247 @@ _background_tasks: set[asyncio.Task[Any]] = set()
|
||||
_MAX_TOOL_ROUNDS = 30
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BaselineStreamState:
|
||||
"""Mutable state shared between the tool-call loop callbacks.
|
||||
|
||||
Extracted from ``stream_chat_completion_baseline`` so that the callbacks
|
||||
can be module-level functions instead of deeply nested closures.
|
||||
"""
|
||||
|
||||
pending_events: list[StreamBaseResponse] = field(default_factory=list)
|
||||
assistant_text: str = ""
|
||||
text_block_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
text_started: bool = False
|
||||
turn_prompt_tokens: int = 0
|
||||
turn_completion_tokens: int = 0
|
||||
|
||||
|
||||
async def _baseline_llm_caller(
|
||||
messages: list[dict[str, Any]],
|
||||
tools: Sequence[Any],
|
||||
*,
|
||||
state: _BaselineStreamState,
|
||||
) -> LLMLoopResponse:
|
||||
"""Stream an OpenAI-compatible response and collect results.
|
||||
|
||||
Extracted from ``stream_chat_completion_baseline`` for readability.
|
||||
"""
|
||||
state.pending_events.append(StreamStartStep())
|
||||
|
||||
round_text = ""
|
||||
try:
|
||||
client = _get_openai_client()
|
||||
typed_messages = cast(list[ChatCompletionMessageParam], messages)
|
||||
if tools:
|
||||
typed_tools = cast(list[ChatCompletionToolParam], tools)
|
||||
response = await client.chat.completions.create(
|
||||
model=config.model,
|
||||
messages=typed_messages,
|
||||
tools=typed_tools,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
else:
|
||||
response = await client.chat.completions.create(
|
||||
model=config.model,
|
||||
messages=typed_messages,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||
|
||||
async for chunk in response:
|
||||
if chunk.usage:
|
||||
state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0
|
||||
state.turn_completion_tokens += chunk.usage.completion_tokens or 0
|
||||
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if not delta:
|
||||
continue
|
||||
|
||||
if delta.content:
|
||||
if not state.text_started:
|
||||
state.pending_events.append(StreamTextStart(id=state.text_block_id))
|
||||
state.text_started = True
|
||||
round_text += delta.content
|
||||
state.pending_events.append(
|
||||
StreamTextDelta(id=state.text_block_id, delta=delta.content)
|
||||
)
|
||||
|
||||
if delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index
|
||||
if idx not in tool_calls_by_index:
|
||||
tool_calls_by_index[idx] = {
|
||||
"id": "",
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
entry = tool_calls_by_index[idx]
|
||||
if tc.id:
|
||||
entry["id"] = tc.id
|
||||
if tc.function and tc.function.name:
|
||||
entry["name"] = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
entry["arguments"] += tc.function.arguments
|
||||
|
||||
# Close text block
|
||||
if state.text_started:
|
||||
state.pending_events.append(StreamTextEnd(id=state.text_block_id))
|
||||
state.text_started = False
|
||||
state.text_block_id = str(uuid.uuid4())
|
||||
finally:
|
||||
# Always persist partial text so the session history stays consistent,
|
||||
# even when the stream is interrupted by an exception.
|
||||
state.assistant_text += round_text
|
||||
# Always emit StreamFinishStep to match the StreamStartStep,
|
||||
# even if an exception occurred during streaming.
|
||||
state.pending_events.append(StreamFinishStep())
|
||||
|
||||
# Convert to shared format
|
||||
llm_tool_calls = [
|
||||
LLMToolCall(
|
||||
id=tc["id"],
|
||||
name=tc["name"],
|
||||
arguments=tc["arguments"] or "{}",
|
||||
)
|
||||
for tc in tool_calls_by_index.values()
|
||||
]
|
||||
|
||||
return LLMLoopResponse(
|
||||
response_text=round_text or None,
|
||||
tool_calls=llm_tool_calls,
|
||||
raw_response=None, # Not needed for baseline conversation updater
|
||||
prompt_tokens=0, # Tracked via state accumulators
|
||||
completion_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
async def _baseline_tool_executor(
|
||||
tool_call: LLMToolCall,
|
||||
tools: Sequence[Any],
|
||||
*,
|
||||
state: _BaselineStreamState,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> ToolCallResult:
|
||||
"""Execute a tool via the copilot tool registry.
|
||||
|
||||
Extracted from ``stream_chat_completion_baseline`` for readability.
|
||||
"""
|
||||
tool_call_id = tool_call.id
|
||||
tool_name = tool_call.name
|
||||
raw_args = tool_call.arguments or "{}"
|
||||
|
||||
try:
|
||||
tool_args = orjson.loads(raw_args)
|
||||
except orjson.JSONDecodeError as parse_err:
|
||||
parse_error = f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
|
||||
logger.warning("[Baseline] %s", parse_error)
|
||||
state.pending_events.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=parse_error,
|
||||
success=False,
|
||||
)
|
||||
)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
content=parse_error,
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
state.pending_events.append(
|
||||
StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
|
||||
)
|
||||
state.pending_events.append(
|
||||
StreamToolInputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
input=tool_args,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
result: StreamToolOutputAvailable = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=tool_args,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
state.pending_events.append(result)
|
||||
tool_output = (
|
||||
result.output if isinstance(result.output, str) else str(result.output)
|
||||
)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
content=tool_output,
|
||||
)
|
||||
except Exception as e:
|
||||
error_output = f"Tool execution error: {e}"
|
||||
logger.error(
|
||||
"[Baseline] Tool %s failed: %s",
|
||||
tool_name,
|
||||
error_output,
|
||||
exc_info=True,
|
||||
)
|
||||
state.pending_events.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=error_output,
|
||||
success=False,
|
||||
)
|
||||
)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
content=error_output,
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
|
||||
def _baseline_conversation_updater(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
) -> None:
|
||||
"""Update OpenAI message list with assistant response + tool results.
|
||||
|
||||
Extracted from ``stream_chat_completion_baseline`` for readability.
|
||||
"""
|
||||
if tool_results:
|
||||
# Build assistant message with tool_calls
|
||||
assistant_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if response.response_text:
|
||||
assistant_msg["content"] = response.response_text
|
||||
assistant_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {"name": tc.name, "arguments": tc.arguments},
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages.append(assistant_msg)
|
||||
for tr in tool_results:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.tool_call_id,
|
||||
"content": tr.content,
|
||||
}
|
||||
)
|
||||
else:
|
||||
if response.response_text:
|
||||
messages.append({"role": "assistant", "content": response.response_text})
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
session_id: str, message: str, user_id: str | None
|
||||
) -> None:
|
||||
@@ -219,191 +473,32 @@ async def stream_chat_completion_baseline(
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context setup failed")
|
||||
|
||||
assistant_text = ""
|
||||
text_block_id = str(uuid.uuid4())
|
||||
text_started = False
|
||||
step_open = False
|
||||
# Token usage accumulators — populated from streaming chunks
|
||||
turn_prompt_tokens = 0
|
||||
turn_completion_tokens = 0
|
||||
_stream_error = False # Track whether an error occurred during streaming
|
||||
state = _BaselineStreamState()
|
||||
|
||||
# Bind extracted module-level callbacks to this request's state/session
|
||||
# using functools.partial so they satisfy the Protocol signatures.
|
||||
_bound_llm_caller = partial(_baseline_llm_caller, state=state)
|
||||
_bound_tool_executor = partial(
|
||||
_baseline_tool_executor, state=state, user_id=user_id, session=session
|
||||
)
|
||||
|
||||
try:
|
||||
for _round in range(_MAX_TOOL_ROUNDS):
|
||||
# Open a new step for each LLM round
|
||||
yield StreamStartStep()
|
||||
step_open = True
|
||||
loop_result = None
|
||||
async for loop_result in tool_call_loop(
|
||||
messages=openai_messages,
|
||||
tools=tools,
|
||||
llm_call=_bound_llm_caller,
|
||||
execute_tool=_bound_tool_executor,
|
||||
update_conversation=_baseline_conversation_updater,
|
||||
max_iterations=_MAX_TOOL_ROUNDS,
|
||||
):
|
||||
# Drain buffered events after each iteration (real-time streaming)
|
||||
for evt in state.pending_events:
|
||||
yield evt
|
||||
state.pending_events.clear()
|
||||
|
||||
# Stream a response from the model
|
||||
create_kwargs: dict[str, Any] = dict(
|
||||
model=config.model,
|
||||
messages=openai_messages,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
if tools:
|
||||
create_kwargs["tools"] = tools
|
||||
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
|
||||
# Accumulate streamed response (text + tool calls)
|
||||
round_text = ""
|
||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||
|
||||
async for chunk in response:
|
||||
# Capture token usage from the streaming chunk.
|
||||
# OpenRouter normalises all providers into OpenAI format
|
||||
# where prompt_tokens already includes cached tokens
|
||||
# (unlike Anthropic's native API). Use += to sum all
|
||||
# tool-call rounds since each API call is independent.
|
||||
# NOTE: stream_options={"include_usage": True} is not
|
||||
# universally supported — some providers (Mistral, Llama
|
||||
# via OpenRouter) always return chunk.usage=None. When
|
||||
# that happens, tokens stay 0 and the tiktoken fallback
|
||||
# below activates. Fail-open: one round is estimated.
|
||||
if chunk.usage:
|
||||
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
|
||||
turn_completion_tokens += chunk.usage.completion_tokens or 0
|
||||
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if not delta:
|
||||
continue
|
||||
|
||||
# Text content
|
||||
if delta.content:
|
||||
if not text_started:
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
text_started = True
|
||||
round_text += delta.content
|
||||
yield StreamTextDelta(id=text_block_id, delta=delta.content)
|
||||
|
||||
# Tool call fragments (streamed incrementally)
|
||||
if delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index
|
||||
if idx not in tool_calls_by_index:
|
||||
tool_calls_by_index[idx] = {
|
||||
"id": "",
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
entry = tool_calls_by_index[idx]
|
||||
if tc.id:
|
||||
entry["id"] = tc.id
|
||||
if tc.function and tc.function.name:
|
||||
entry["name"] = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
entry["arguments"] += tc.function.arguments
|
||||
|
||||
# Close text block if we had one this round
|
||||
if text_started:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_started = False
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
# Accumulate text for session persistence
|
||||
assistant_text += round_text
|
||||
|
||||
# No tool calls -> model is done
|
||||
if not tool_calls_by_index:
|
||||
yield StreamFinishStep()
|
||||
step_open = False
|
||||
break
|
||||
|
||||
# Close step before tool execution
|
||||
yield StreamFinishStep()
|
||||
step_open = False
|
||||
|
||||
# Append the assistant message with tool_calls to context.
|
||||
assistant_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if round_text:
|
||||
assistant_msg["content"] = round_text
|
||||
assistant_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc["id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc["name"],
|
||||
"arguments": tc["arguments"] or "{}",
|
||||
},
|
||||
}
|
||||
for tc in tool_calls_by_index.values()
|
||||
]
|
||||
openai_messages.append(assistant_msg)
|
||||
|
||||
# Execute each tool call and stream events
|
||||
for tc in tool_calls_by_index.values():
|
||||
tool_call_id = tc["id"]
|
||||
tool_name = tc["name"]
|
||||
raw_args = tc["arguments"] or "{}"
|
||||
try:
|
||||
tool_args = orjson.loads(raw_args)
|
||||
except orjson.JSONDecodeError as parse_err:
|
||||
parse_error = (
|
||||
f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
|
||||
)
|
||||
logger.warning("[Baseline] %s", parse_error)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=parse_error,
|
||||
success=False,
|
||||
)
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": parse_error,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
yield StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
input=tool_args,
|
||||
)
|
||||
|
||||
# Execute via shared tool registry
|
||||
try:
|
||||
result: StreamToolOutputAvailable = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=tool_args,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
yield result
|
||||
tool_output = (
|
||||
result.output
|
||||
if isinstance(result.output, str)
|
||||
else str(result.output)
|
||||
)
|
||||
except Exception as e:
|
||||
error_output = f"Tool execution error: {e}"
|
||||
logger.error(
|
||||
"[Baseline] Tool %s failed: %s",
|
||||
tool_name,
|
||||
error_output,
|
||||
exc_info=True,
|
||||
)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=error_output,
|
||||
success=False,
|
||||
)
|
||||
tool_output = error_output
|
||||
|
||||
# Append tool result to context for next round
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": tool_output,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# for-loop exhausted without break -> tool-round limit hit
|
||||
if loop_result and not loop_result.finished_naturally:
|
||||
limit_msg = (
|
||||
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
|
||||
"without a final response."
|
||||
@@ -418,11 +513,28 @@ async def stream_chat_completion_baseline(
|
||||
_stream_error = True
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
|
||||
# Close any open text/step before emitting error
|
||||
if text_started:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
if step_open:
|
||||
yield StreamFinishStep()
|
||||
# Close any open text block. The llm_caller's finally block
|
||||
# already appended StreamFinishStep to pending_events, so we must
|
||||
# insert StreamTextEnd *before* StreamFinishStep to preserve the
|
||||
# protocol ordering:
|
||||
# StreamStartStep -> StreamTextStart -> ...deltas... ->
|
||||
# StreamTextEnd -> StreamFinishStep
|
||||
# Appending (or yielding directly) would place it after
|
||||
# StreamFinishStep, violating the protocol.
|
||||
if state.text_started:
|
||||
# Find the last StreamFinishStep and insert before it.
|
||||
insert_pos = len(state.pending_events)
|
||||
for i in range(len(state.pending_events) - 1, -1, -1):
|
||||
if isinstance(state.pending_events[i], StreamFinishStep):
|
||||
insert_pos = i
|
||||
break
|
||||
state.pending_events.insert(
|
||||
insert_pos, StreamTextEnd(id=state.text_block_id)
|
||||
)
|
||||
# Drain pending events in correct order
|
||||
for evt in state.pending_events:
|
||||
yield evt
|
||||
state.pending_events.clear()
|
||||
yield StreamError(errorText=error_msg, code="baseline_error")
|
||||
# Still persist whatever we got
|
||||
finally:
|
||||
@@ -442,26 +554,21 @@ async def stream_chat_completion_baseline(
|
||||
# Skip fallback when an error occurred and no output was produced —
|
||||
# charging rate-limit tokens for completely failed requests is unfair.
|
||||
if (
|
||||
turn_prompt_tokens == 0
|
||||
and turn_completion_tokens == 0
|
||||
and not (_stream_error and not assistant_text)
|
||||
state.turn_prompt_tokens == 0
|
||||
and state.turn_completion_tokens == 0
|
||||
and not (_stream_error and not state.assistant_text)
|
||||
):
|
||||
from backend.util.prompt import (
|
||||
estimate_token_count,
|
||||
estimate_token_count_str,
|
||||
)
|
||||
|
||||
turn_prompt_tokens = max(
|
||||
state.turn_prompt_tokens = max(
|
||||
estimate_token_count(openai_messages, model=config.model), 1
|
||||
)
|
||||
turn_completion_tokens = estimate_token_count_str(
|
||||
assistant_text, model=config.model
|
||||
state.turn_completion_tokens = estimate_token_count_str(
|
||||
state.assistant_text, model=config.model
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] No streaming usage reported; estimated tokens: "
|
||||
"prompt=%d, completion=%d",
|
||||
turn_prompt_tokens,
|
||||
turn_completion_tokens,
|
||||
state.turn_prompt_tokens,
|
||||
state.turn_completion_tokens,
|
||||
)
|
||||
|
||||
# Persist token usage to session and record for rate limiting.
|
||||
@@ -471,15 +578,15 @@ async def stream_chat_completion_baseline(
|
||||
await persist_and_record_usage(
|
||||
session=session,
|
||||
user_id=user_id,
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
prompt_tokens=state.turn_prompt_tokens,
|
||||
completion_tokens=state.turn_completion_tokens,
|
||||
log_prefix="[Baseline]",
|
||||
)
|
||||
|
||||
# Persist assistant response
|
||||
if assistant_text:
|
||||
if state.assistant_text:
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content=assistant_text)
|
||||
ChatMessage(role="assistant", content=state.assistant_text)
|
||||
)
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
@@ -491,11 +598,11 @@ async def stream_chat_completion_baseline(
|
||||
# aclose() — doing so raises RuntimeError on client disconnect.
|
||||
# On GeneratorExit the client is already gone, so unreachable yields
|
||||
# are harmless; on normal completion they reach the SSE stream.
|
||||
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
|
||||
if state.turn_prompt_tokens > 0 or state.turn_completion_tokens > 0:
|
||||
yield StreamUsage(
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
total_tokens=turn_prompt_tokens + turn_completion_tokens,
|
||||
prompt_tokens=state.turn_prompt_tokens,
|
||||
completion_tokens=state.turn_completion_tokens,
|
||||
total_tokens=state.turn_prompt_tokens + state.turn_completion_tokens,
|
||||
)
|
||||
|
||||
yield StreamFinish()
|
||||
|
||||
@@ -178,7 +178,7 @@ class ChatConfig(BaseSettings):
|
||||
|
||||
Single source of truth for "will the SDK route through OpenRouter?".
|
||||
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
|
||||
present — mirrors the fallback logic in ``_build_sdk_env``.
|
||||
present — mirrors the fallback logic in ``build_sdk_env``.
|
||||
"""
|
||||
if not self.use_openrouter:
|
||||
return False
|
||||
|
||||
@@ -18,7 +18,7 @@ from prisma.types import (
|
||||
from backend.data import db
|
||||
from backend.util.json import SafeJson, sanitize_string
|
||||
|
||||
from .model import ChatMessage, ChatSession, ChatSessionInfo
|
||||
from .model import ChatMessage, ChatSession, ChatSessionInfo, invalidate_session_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -217,6 +217,9 @@ async def add_chat_messages_batch(
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
if msg.get("duration_ms") is not None:
|
||||
data["durationMs"] = msg["duration_ms"]
|
||||
|
||||
messages_data.append(data)
|
||||
|
||||
# Run create_many and session update in parallel within transaction
|
||||
@@ -359,3 +362,22 @@ async def update_tool_message_content(
|
||||
f"tool_call_id {tool_call_id}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
|
||||
"""Set durationMs on the last assistant message in a session.
|
||||
|
||||
Also invalidates the Redis session cache so the next GET returns
|
||||
the updated duration.
|
||||
"""
|
||||
last_msg = await PrismaChatMessage.prisma().find_first(
|
||||
where={"sessionId": session_id, "role": "assistant"},
|
||||
order={"sequence": "desc"},
|
||||
)
|
||||
if last_msg:
|
||||
await PrismaChatMessage.prisma().update(
|
||||
where={"id": last_msg.id},
|
||||
data={"durationMs": duration_ms},
|
||||
)
|
||||
# Invalidate cache so the session is re-fetched from DB with durationMs
|
||||
await invalidate_session_cache(session_id)
|
||||
|
||||
@@ -54,6 +54,7 @@ class ChatMessage(BaseModel):
|
||||
refusal: str | None = None
|
||||
tool_calls: list[dict] | None = None
|
||||
function_call: dict | None = None
|
||||
duration_ms: int | None = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
|
||||
@@ -66,6 +67,7 @@ class ChatMessage(BaseModel):
|
||||
refusal=prisma_message.refusal,
|
||||
tool_calls=_parse_json_field(prisma_message.toolCalls),
|
||||
function_call=_parse_json_field(prisma_message.functionCall),
|
||||
duration_ms=prisma_message.durationMs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
68
autogpt_platform/backend/backend/copilot/sdk/env.py
Normal file
68
autogpt_platform/backend/backend/copilot/sdk/env.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""SDK environment variable builder — importable without circular deps.
|
||||
|
||||
Extracted from ``service.py`` so that ``backend.blocks.orchestrator``
|
||||
can reuse the same subscription / OpenRouter / direct-Anthropic logic
|
||||
without pulling in the full copilot service module (which would create a
|
||||
circular import through ``executor`` → ``credit`` → ``block_cost_config``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.sdk.subscription import validate_subscription
|
||||
|
||||
# ChatConfig is stateless (reads env vars) — a separate instance is fine.
|
||||
# A singleton would require importing service.py which causes the circular dep
|
||||
# this module was created to avoid.
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
def build_sdk_env(
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Build env vars for the SDK CLI subprocess.
|
||||
|
||||
Three modes (checked in order):
|
||||
1. **Subscription** — clears all keys; CLI uses ``claude login`` auth.
|
||||
2. **Direct Anthropic** — returns ``{}``; subprocess inherits
|
||||
``ANTHROPIC_API_KEY`` from the parent environment.
|
||||
3. **OpenRouter** (default) — overrides base URL and auth token to
|
||||
route through the proxy, with Langfuse trace headers.
|
||||
"""
|
||||
# --- Mode 1: Claude Code subscription auth ---
|
||||
if config.use_claude_code_subscription:
|
||||
validate_subscription()
|
||||
return {
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
|
||||
# --- Mode 2: Direct Anthropic (no proxy hop) ---
|
||||
if not config.openrouter_active:
|
||||
return {}
|
||||
|
||||
# --- Mode 3: OpenRouter proxy ---
|
||||
base = (config.base_url or "").rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
env: dict[str, str] = {
|
||||
"ANTHROPIC_BASE_URL": base,
|
||||
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
|
||||
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
|
||||
}
|
||||
|
||||
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
|
||||
def _safe(v: str) -> str:
|
||||
return v.replace("\r", "").replace("\n", "").strip()[:128]
|
||||
|
||||
parts = []
|
||||
if session_id:
|
||||
parts.append(f"x-session-id: {_safe(session_id)}")
|
||||
if user_id:
|
||||
parts.append(f"x-user-id: {_safe(user_id)}")
|
||||
if parts:
|
||||
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
|
||||
|
||||
return env
|
||||
242
autogpt_platform/backend/backend/copilot/sdk/env_test.py
Normal file
242
autogpt_platform/backend/backend/copilot/sdk/env_test.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""Tests for build_sdk_env() — the SDK subprocess environment builder."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — build a ChatConfig with explicit field values so tests don't
|
||||
# depend on real environment variables.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_config(**overrides) -> ChatConfig:
|
||||
"""Create a ChatConfig with safe defaults, applying *overrides*."""
|
||||
defaults = {
|
||||
"use_claude_code_subscription": False,
|
||||
"use_openrouter": False,
|
||||
"api_key": None,
|
||||
"base_url": None,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return ChatConfig(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mode 1 — Subscription auth
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSdkEnvSubscription:
|
||||
"""When ``use_claude_code_subscription`` is True, keys are blanked."""
|
||||
|
||||
@patch("backend.copilot.sdk.env.validate_subscription")
|
||||
def test_returns_blanked_keys(self, mock_validate):
|
||||
"""Subscription mode clears API_KEY, AUTH_TOKEN, and BASE_URL."""
|
||||
cfg = _make_config(use_claude_code_subscription=True)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
assert result == {
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
mock_validate.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"backend.copilot.sdk.env.validate_subscription",
|
||||
side_effect=RuntimeError("CLI not found"),
|
||||
)
|
||||
def test_propagates_validation_error(self, mock_validate):
|
||||
"""If validate_subscription fails, the error bubbles up."""
|
||||
cfg = _make_config(use_claude_code_subscription=True)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
with pytest.raises(RuntimeError, match="CLI not found"):
|
||||
build_sdk_env()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mode 2 — Direct Anthropic (no OpenRouter)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSdkEnvDirectAnthropic:
|
||||
"""When OpenRouter is inactive, return empty dict (inherit parent env)."""
|
||||
|
||||
def test_returns_empty_dict_when_openrouter_inactive(self):
|
||||
cfg = _make_config(use_openrouter=False)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_returns_empty_dict_when_openrouter_flag_true_but_no_key(self):
|
||||
"""OpenRouter flag is True but no api_key => openrouter_active is False."""
|
||||
cfg = _make_config(use_openrouter=True, base_url="https://openrouter.ai/api/v1")
|
||||
# Force api_key to None after construction (field_validator may pick up env vars)
|
||||
object.__setattr__(cfg, "api_key", None)
|
||||
assert not cfg.openrouter_active
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
assert result == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mode 3 — OpenRouter proxy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSdkEnvOpenRouter:
|
||||
"""When OpenRouter is active, return proxy env vars."""
|
||||
|
||||
def _openrouter_config(self, **overrides):
|
||||
defaults = {
|
||||
"use_openrouter": True,
|
||||
"api_key": "sk-or-test-key",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return _make_config(**defaults)
|
||||
|
||||
def test_basic_openrouter_env(self):
|
||||
cfg = self._openrouter_config()
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
|
||||
assert result["ANTHROPIC_AUTH_TOKEN"] == "sk-or-test-key"
|
||||
assert result["ANTHROPIC_API_KEY"] == ""
|
||||
assert "ANTHROPIC_CUSTOM_HEADERS" not in result
|
||||
|
||||
def test_strips_trailing_v1(self):
|
||||
"""The /v1 suffix is stripped from the base URL."""
|
||||
cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1")
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
|
||||
|
||||
def test_strips_trailing_v1_and_slash(self):
|
||||
"""Trailing slash before /v1 strip is handled."""
|
||||
cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1/")
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
# rstrip("/") first, then remove /v1
|
||||
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
|
||||
|
||||
def test_no_v1_suffix_left_alone(self):
|
||||
"""A base URL without /v1 is used as-is."""
|
||||
cfg = self._openrouter_config(base_url="https://custom-proxy.example.com")
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
assert result["ANTHROPIC_BASE_URL"] == "https://custom-proxy.example.com"
|
||||
|
||||
def test_session_id_header(self):
|
||||
cfg = self._openrouter_config()
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env(session_id="sess-123")
|
||||
|
||||
assert "ANTHROPIC_CUSTOM_HEADERS" in result
|
||||
assert "x-session-id: sess-123" in result["ANTHROPIC_CUSTOM_HEADERS"]
|
||||
|
||||
def test_user_id_header(self):
|
||||
cfg = self._openrouter_config()
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env(user_id="user-456")
|
||||
|
||||
assert "x-user-id: user-456" in result["ANTHROPIC_CUSTOM_HEADERS"]
|
||||
|
||||
def test_both_headers(self):
|
||||
cfg = self._openrouter_config()
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env(session_id="s1", user_id="u2")
|
||||
|
||||
headers = result["ANTHROPIC_CUSTOM_HEADERS"]
|
||||
assert "x-session-id: s1" in headers
|
||||
assert "x-user-id: u2" in headers
|
||||
# They should be newline-separated
|
||||
assert "\n" in headers
|
||||
|
||||
def test_header_sanitisation_strips_newlines(self):
|
||||
"""Newlines/carriage-returns in header values are stripped."""
|
||||
cfg = self._openrouter_config()
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env(session_id="bad\r\nvalue")
|
||||
|
||||
header_val = result["ANTHROPIC_CUSTOM_HEADERS"]
|
||||
# The _safe helper removes \r and \n
|
||||
assert "\r" not in header_val.split(": ", 1)[1]
|
||||
assert "badvalue" in header_val
|
||||
|
||||
def test_header_value_truncated_to_128_chars(self):
|
||||
"""Header values are truncated to 128 characters."""
|
||||
cfg = self._openrouter_config()
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
long_id = "x" * 200
|
||||
result = build_sdk_env(session_id=long_id)
|
||||
|
||||
# The value after "x-session-id: " should be at most 128 chars
|
||||
header_line = result["ANTHROPIC_CUSTOM_HEADERS"]
|
||||
value = header_line.split(": ", 1)[1]
|
||||
assert len(value) == 128
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mode priority
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSdkEnvModePriority:
|
||||
"""Subscription mode takes precedence over OpenRouter."""
|
||||
|
||||
@patch("backend.copilot.sdk.env.validate_subscription")
|
||||
def test_subscription_overrides_openrouter(self, mock_validate):
|
||||
cfg = _make_config(
|
||||
use_claude_code_subscription=True,
|
||||
use_openrouter=True,
|
||||
api_key="sk-or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
# Should get subscription result, not OpenRouter
|
||||
assert result == {
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
@@ -1010,7 +1010,7 @@ def _make_sdk_patches(
|
||||
(f"{_SVC}.create_security_hooks", dict(return_value=MagicMock())),
|
||||
(f"{_SVC}.get_copilot_tool_names", dict(return_value=[])),
|
||||
(f"{_SVC}.get_sdk_disallowed_tools", dict(return_value=[])),
|
||||
(f"{_SVC}._build_sdk_env", dict(return_value=None)),
|
||||
(f"{_SVC}.build_sdk_env", dict(return_value=None)),
|
||||
(f"{_SVC}._resolve_sdk_model", dict(return_value=None)),
|
||||
(f"{_SVC}.set_execution_context", {}),
|
||||
(
|
||||
|
||||
@@ -78,9 +78,9 @@ from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
|
||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
from ..tracking import track_user_message
|
||||
from .compaction import CompactionTracker, filter_compaction_messages
|
||||
from .env import build_sdk_env # noqa: F401 — re-export for backward compat
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .security_hooks import create_security_hooks
|
||||
from .subscription import validate_subscription as _validate_claude_code_subscription
|
||||
from .tool_adapter import (
|
||||
cancel_pending_tool_tasks,
|
||||
create_copilot_mcp_server,
|
||||
@@ -568,60 +568,6 @@ def _resolve_sdk_model() -> str | None:
|
||||
return model
|
||||
|
||||
|
||||
def _build_sdk_env(
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Build env vars for the SDK CLI subprocess.
|
||||
|
||||
Three modes (checked in order):
|
||||
1. **Subscription** — clears all keys; CLI uses `claude login` auth.
|
||||
2. **Direct Anthropic** — returns `{}`; subprocess inherits
|
||||
`ANTHROPIC_API_KEY` from the parent environment.
|
||||
3. **OpenRouter** (default) — overrides base URL and auth token to
|
||||
route through the proxy, with Langfuse trace headers.
|
||||
"""
|
||||
# --- Mode 1: Claude Code subscription auth ---
|
||||
if config.use_claude_code_subscription:
|
||||
_validate_claude_code_subscription()
|
||||
return {
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
|
||||
# --- Mode 2: Direct Anthropic (no proxy hop) ---
|
||||
# `openrouter_active` checks the flag *and* credential presence.
|
||||
if not config.openrouter_active:
|
||||
return {}
|
||||
|
||||
# --- Mode 3: OpenRouter proxy ---
|
||||
# Strip /v1 suffix — SDK expects the base URL without a version path.
|
||||
base = (config.base_url or "").rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
env: dict[str, str] = {
|
||||
"ANTHROPIC_BASE_URL": base,
|
||||
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
|
||||
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
|
||||
}
|
||||
|
||||
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
|
||||
def _safe(v: str) -> str:
|
||||
"""Sanitise a header value: strip newlines/whitespace and cap length."""
|
||||
return v.replace("\r", "").replace("\n", "").strip()[:128]
|
||||
|
||||
parts = []
|
||||
if session_id:
|
||||
parts.append(f"x-session-id: {_safe(session_id)}")
|
||||
if user_id:
|
||||
parts.append(f"x-user-id: {_safe(user_id)}")
|
||||
if parts:
|
||||
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def _make_sdk_cwd(session_id: str) -> str:
|
||||
"""Create a safe, session-specific working directory path.
|
||||
|
||||
@@ -1868,7 +1814,7 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
|
||||
# Fail fast when no API credentials are available at all.
|
||||
sdk_env = _build_sdk_env(session_id=session_id, user_id=user_id)
|
||||
sdk_env = build_sdk_env(session_id=session_id, user_id=user_id)
|
||||
if not config.api_key and not config.use_claude_code_subscription:
|
||||
raise RuntimeError(
|
||||
"No API key configured. Set OPEN_ROUTER_API_KEY, "
|
||||
|
||||
@@ -26,6 +26,7 @@ import orjson
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from backend.api.model import CopilotCompletionPayload
|
||||
from backend.data.db_accessors import chat_db
|
||||
from backend.data.notification_bus import (
|
||||
AsyncRedisNotificationEventBus,
|
||||
NotificationEvent,
|
||||
@@ -111,6 +112,14 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
|
||||
``session_id`` is used as a fallback for ``turn_id`` when the meta hash
|
||||
pre-dates the turn_id field (backward compat for in-flight sessions).
|
||||
"""
|
||||
created_at = datetime.now(timezone.utc)
|
||||
created_at_raw = meta.get("created_at")
|
||||
if created_at_raw:
|
||||
try:
|
||||
created_at = datetime.fromisoformat(str(created_at_raw))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return ActiveSession(
|
||||
session_id=meta.get("session_id", "") or session_id,
|
||||
user_id=meta.get("user_id", "") or None,
|
||||
@@ -119,6 +128,7 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
|
||||
turn_id=meta.get("turn_id", "") or session_id,
|
||||
blocking=meta.get("blocking") == "1",
|
||||
status=meta.get("status", "running"), # type: ignore[arg-type]
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
|
||||
@@ -802,6 +812,33 @@ async def mark_session_completed(
|
||||
f"Failed to publish error event for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
# Compute wall-clock duration from session created_at.
|
||||
# Only persist when (a) the session completed successfully and
|
||||
# (b) created_at was actually present in Redis meta (not a fallback).
|
||||
duration_ms: int | None = None
|
||||
if meta and not error_message:
|
||||
created_at_raw = meta.get("created_at")
|
||||
if created_at_raw:
|
||||
try:
|
||||
created_at = datetime.fromisoformat(str(created_at_raw))
|
||||
if created_at.tzinfo is None:
|
||||
created_at = created_at.replace(tzinfo=timezone.utc)
|
||||
elapsed = datetime.now(timezone.utc) - created_at
|
||||
duration_ms = max(0, int(elapsed.total_seconds() * 1000))
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Failed to compute session duration for %s (created_at=%r)",
|
||||
session_id,
|
||||
created_at_raw,
|
||||
)
|
||||
|
||||
# Persist duration on the last assistant message
|
||||
if duration_ms is not None:
|
||||
try:
|
||||
await chat_db().set_turn_duration(session_id, duration_ms)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save turn duration for {session_id}: {e}")
|
||||
|
||||
# Publish StreamFinish AFTER status is set to "completed"/"failed".
|
||||
# This is the SINGLE place that publishes StreamFinish — services and
|
||||
# the processor must NOT publish it themselves.
|
||||
|
||||
@@ -344,6 +344,7 @@ class DatabaseManager(AppService):
|
||||
get_next_sequence = _(chat_db.get_next_sequence)
|
||||
update_tool_message_content = _(chat_db.update_tool_message_content)
|
||||
update_chat_session_title = _(chat_db.update_chat_session_title)
|
||||
set_turn_duration = _(chat_db.set_turn_duration)
|
||||
|
||||
|
||||
class DatabaseManagerClient(AppServiceClient):
|
||||
@@ -540,3 +541,4 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
get_next_sequence = d.get_next_sequence
|
||||
update_tool_message_content = d.update_tool_message_content
|
||||
update_chat_session_title = d.update_chat_session_title
|
||||
set_turn_duration = d.set_turn_duration
|
||||
|
||||
20
autogpt_platform/backend/backend/util/security.py
Normal file
20
autogpt_platform/backend/backend/util/security.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Shared security constants for field-level filtering.
|
||||
|
||||
Other modules (e.g. orchestrator, future blocks) import from here so the
|
||||
sensitive-field list stays in one place.
|
||||
"""
|
||||
|
||||
# Field names to exclude from hardcoded-defaults descriptions (case-insensitive).
|
||||
SENSITIVE_FIELD_NAMES: frozenset[str] = frozenset(
|
||||
{
|
||||
"credentials",
|
||||
"api_key",
|
||||
"password",
|
||||
"secret",
|
||||
"token",
|
||||
"auth",
|
||||
"authorization",
|
||||
"access_token",
|
||||
"refresh_token",
|
||||
}
|
||||
)
|
||||
281
autogpt_platform/backend/backend/util/tool_call_loop.py
Normal file
281
autogpt_platform/backend/backend/util/tool_call_loop.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""Shared tool-calling conversation loop.
|
||||
|
||||
Provides a generic, provider-agnostic conversation loop that both
|
||||
the OrchestratorBlock and copilot baseline can use. The loop:
|
||||
|
||||
1. Calls the LLM with tool definitions
|
||||
2. Extracts tool calls from the response
|
||||
3. Executes tools via a caller-supplied callback
|
||||
4. Appends results to the conversation
|
||||
5. Repeats until no more tool calls or max iterations reached
|
||||
|
||||
Callers provide callbacks for LLM calling, tool execution, and
|
||||
conversation updating.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol, TypedDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Typed dict definitions for tool definitions and conversation messages.
|
||||
# These document the expected shapes and allow callers to pass TypedDict
|
||||
# subclasses (e.g. ``ChatCompletionToolParam``) without ``type: ignore``.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FunctionParameters(TypedDict, total=False):
|
||||
"""JSON Schema object describing a tool function's parameters."""
|
||||
|
||||
type: str
|
||||
properties: dict[str, Any]
|
||||
required: list[str]
|
||||
additionalProperties: bool
|
||||
|
||||
|
||||
class FunctionDefinition(TypedDict, total=False):
|
||||
"""Function definition within a tool definition."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
parameters: FunctionParameters
|
||||
|
||||
|
||||
class ToolDefinition(TypedDict):
|
||||
"""OpenAI-compatible tool definition (function-calling format).
|
||||
|
||||
Compatible with ``openai.types.chat.ChatCompletionToolParam`` and the
|
||||
dict-based tool definitions built by ``OrchestratorBlock``.
|
||||
"""
|
||||
|
||||
type: str
|
||||
function: FunctionDefinition
|
||||
|
||||
|
||||
class ConversationMessage(TypedDict, total=False):
|
||||
"""A single message in the conversation (OpenAI chat format).
|
||||
|
||||
Primarily for documentation; at runtime plain dicts are used because
|
||||
messages from different providers carry varying keys.
|
||||
"""
|
||||
|
||||
role: str
|
||||
content: str | list[Any] | None
|
||||
tool_calls: list[dict[str, Any]]
|
||||
tool_call_id: str
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallResult:
|
||||
"""Result of a single tool execution."""
|
||||
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
content: str
|
||||
is_error: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMToolCall:
|
||||
"""A tool call extracted from an LLM response."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
arguments: str # JSON string
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMLoopResponse:
|
||||
"""Response from a single LLM call in the loop.
|
||||
|
||||
``raw_response`` is typed as ``Any`` intentionally: the loop itself
|
||||
never inspects it — it is an opaque pass-through that the caller's
|
||||
``ConversationUpdater`` uses to rebuild provider-specific message
|
||||
history (OpenAI ChatCompletion, Anthropic Message, Ollama str, etc.).
|
||||
"""
|
||||
|
||||
response_text: str | None
|
||||
tool_calls: list[LLMToolCall]
|
||||
raw_response: Any
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
reasoning: str | None = None
|
||||
|
||||
|
||||
class LLMCaller(Protocol):
|
||||
"""Protocol for LLM call functions."""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: Sequence[Any],
|
||||
) -> LLMLoopResponse: ...
|
||||
|
||||
|
||||
class ToolExecutor(Protocol):
|
||||
"""Protocol for tool execution functions."""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
tool_call: LLMToolCall,
|
||||
tools: Sequence[Any],
|
||||
) -> ToolCallResult: ...
|
||||
|
||||
|
||||
class ConversationUpdater(Protocol):
|
||||
"""Protocol for updating conversation history after an LLM response."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallLoopResult:
|
||||
"""Final result of the tool-calling loop."""
|
||||
|
||||
response_text: str
|
||||
messages: list[dict[str, Any]]
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
iterations: int = 0
|
||||
finished_naturally: bool = True # False if hit max iterations
|
||||
last_tool_calls: list[LLMToolCall] = field(default_factory=list)
|
||||
|
||||
|
||||
async def tool_call_loop(
|
||||
*,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: Sequence[Any],
|
||||
llm_call: LLMCaller,
|
||||
execute_tool: ToolExecutor,
|
||||
update_conversation: ConversationUpdater,
|
||||
max_iterations: int = -1,
|
||||
last_iteration_message: str | None = None,
|
||||
parallel_tool_calls: bool = True,
|
||||
) -> AsyncGenerator[ToolCallLoopResult, None]:
|
||||
"""Run a tool-calling conversation loop as an async generator.
|
||||
|
||||
Yields a ``ToolCallLoopResult`` after each iteration so callers can
|
||||
drain buffered events (e.g. streaming text deltas) between iterations.
|
||||
The **final** yielded result has ``finished_naturally`` set and contains
|
||||
the complete response text.
|
||||
|
||||
Args:
|
||||
messages: Initial conversation messages (modified in-place).
|
||||
tools: Tool function definitions (OpenAI format). Accepts any
|
||||
sequence of tool dicts, including ``ChatCompletionToolParam``.
|
||||
llm_call: Async function to call the LLM. The callback can
|
||||
perform streaming internally (e.g. accumulate text deltas
|
||||
and collect events) — it just needs to return the final
|
||||
``LLMLoopResponse`` with extracted tool calls.
|
||||
execute_tool: Async function to execute a tool call.
|
||||
update_conversation: Function to update messages with LLM
|
||||
response and tool results.
|
||||
max_iterations: Max iterations. -1 = infinite, 0 = no loop
|
||||
(immediately yields a "max reached" result).
|
||||
last_iteration_message: Optional message to append on the last
|
||||
iteration to encourage the model to finish.
|
||||
parallel_tool_calls: If True (default), execute multiple tool
|
||||
calls from a single LLM response concurrently via
|
||||
``asyncio.gather``. Set to False when tool calls may have
|
||||
ordering dependencies or mutate shared state.
|
||||
|
||||
Yields:
|
||||
ToolCallLoopResult after each iteration. Check ``finished_naturally``
|
||||
to determine if the loop completed or is still running.
|
||||
"""
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
iteration = 0
|
||||
|
||||
while max_iterations < 0 or iteration < max_iterations:
|
||||
iteration += 1
|
||||
|
||||
# On last iteration, add a hint to finish. Only copy the list
|
||||
# when the hint needs to be appended to avoid per-iteration overhead
|
||||
# on long conversations.
|
||||
is_last = (
|
||||
last_iteration_message
|
||||
and max_iterations > 0
|
||||
and iteration == max_iterations
|
||||
)
|
||||
if is_last:
|
||||
iteration_messages = list(messages)
|
||||
iteration_messages.append(
|
||||
{"role": "system", "content": last_iteration_message}
|
||||
)
|
||||
else:
|
||||
iteration_messages = messages
|
||||
|
||||
# Call LLM
|
||||
response = await llm_call(iteration_messages, tools)
|
||||
total_prompt_tokens += response.prompt_tokens
|
||||
total_completion_tokens += response.completion_tokens
|
||||
|
||||
# No tool calls = done
|
||||
if not response.tool_calls:
|
||||
update_conversation(messages, response)
|
||||
yield ToolCallLoopResult(
|
||||
response_text=response.response_text or "",
|
||||
messages=messages,
|
||||
total_prompt_tokens=total_prompt_tokens,
|
||||
total_completion_tokens=total_completion_tokens,
|
||||
iterations=iteration,
|
||||
finished_naturally=True,
|
||||
)
|
||||
return
|
||||
|
||||
# Execute tools — parallel or sequential depending on caller preference.
|
||||
# NOTE: asyncio.gather does not cancel sibling tasks when one raises.
|
||||
# Callers should handle errors inside execute_tool (return error
|
||||
# ToolCallResult) rather than letting exceptions propagate.
|
||||
if parallel_tool_calls and len(response.tool_calls) > 1:
|
||||
# Parallel: side-effects from different tool executors (e.g.
|
||||
# streaming events appended to a shared list) may interleave
|
||||
# nondeterministically. Each event carries its own tool-call
|
||||
# identifier, so consumers must correlate by ID.
|
||||
tool_results: list[ToolCallResult] = list(
|
||||
await asyncio.gather(
|
||||
*(execute_tool(tc, tools) for tc in response.tool_calls)
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Sequential: preserves ordering guarantees for callers that
|
||||
# need deterministic execution order.
|
||||
tool_results = [await execute_tool(tc, tools) for tc in response.tool_calls]
|
||||
|
||||
# Update conversation with response + tool results
|
||||
update_conversation(messages, response, tool_results)
|
||||
|
||||
# Yield a fresh result so callers can drain buffered events
|
||||
yield ToolCallLoopResult(
|
||||
response_text="",
|
||||
messages=messages,
|
||||
total_prompt_tokens=total_prompt_tokens,
|
||||
total_completion_tokens=total_completion_tokens,
|
||||
iterations=iteration,
|
||||
finished_naturally=False,
|
||||
last_tool_calls=list(response.tool_calls),
|
||||
)
|
||||
|
||||
# Hit max iterations
|
||||
yield ToolCallLoopResult(
|
||||
response_text=f"Completed after {max_iterations} iterations (limit reached)",
|
||||
messages=messages,
|
||||
total_prompt_tokens=total_prompt_tokens,
|
||||
total_completion_tokens=total_completion_tokens,
|
||||
iterations=iteration,
|
||||
finished_naturally=False,
|
||||
)
|
||||
554
autogpt_platform/backend/backend/util/tool_call_loop_test.py
Normal file
554
autogpt_platform/backend/backend/util/tool_call_loop_test.py
Normal file
@@ -0,0 +1,554 @@
|
||||
"""Unit tests for tool_call_loop shared abstraction.
|
||||
|
||||
Covers:
|
||||
- Happy path with tool calls (single and multi-round)
|
||||
- Final text response (no tool calls)
|
||||
- Max iterations reached
|
||||
- No tools scenario
|
||||
- Exception propagation from tool executor
|
||||
- Parallel tool execution
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.tool_call_loop import (
|
||||
LLMLoopResponse,
|
||||
LLMToolCall,
|
||||
ToolCallLoopResult,
|
||||
ToolCallResult,
|
||||
tool_call_loop,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TOOL_DEFS: list[dict[str, Any]] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather for a city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def _make_response(
|
||||
text: str | None = None,
|
||||
tool_calls: list[LLMToolCall] | None = None,
|
||||
prompt_tokens: int = 10,
|
||||
completion_tokens: int = 5,
|
||||
) -> LLMLoopResponse:
|
||||
return LLMLoopResponse(
|
||||
response_text=text,
|
||||
tool_calls=tool_calls or [],
|
||||
raw_response={"mock": True},
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_response_no_tool_calls():
|
||||
"""LLM responds with text only -- loop should yield once and finish."""
|
||||
|
||||
async def llm_call(
|
||||
messages: list[dict[str, Any]], tools: Sequence[Any]
|
||||
) -> LLMLoopResponse:
|
||||
return _make_response(text="Hello world")
|
||||
|
||||
async def execute_tool(
|
||||
tool_call: LLMToolCall, tools: Sequence[Any]
|
||||
) -> ToolCallResult:
|
||||
raise AssertionError("Should not be called")
|
||||
|
||||
def update_conversation(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
) -> None:
|
||||
messages.append({"role": "assistant", "content": response.response_text})
|
||||
|
||||
msgs: list[dict[str, Any]] = [{"role": "user", "content": "Hi"}]
|
||||
results: list[ToolCallLoopResult] = []
|
||||
async for r in tool_call_loop(
|
||||
messages=msgs,
|
||||
tools=TOOL_DEFS,
|
||||
llm_call=llm_call,
|
||||
execute_tool=execute_tool,
|
||||
update_conversation=update_conversation,
|
||||
):
|
||||
results.append(r)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].finished_naturally is True
|
||||
assert results[0].response_text == "Hello world"
|
||||
assert results[0].iterations == 1
|
||||
assert results[0].total_prompt_tokens == 10
|
||||
assert results[0].total_completion_tokens == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_tool_call_then_text():
|
||||
"""LLM makes one tool call, then responds with text on second round."""
|
||||
call_count = 0
|
||||
|
||||
async def llm_call(
|
||||
messages: list[dict[str, Any]], tools: Sequence[Any]
|
||||
) -> LLMLoopResponse:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _make_response(
|
||||
tool_calls=[
|
||||
LLMToolCall(
|
||||
id="tc_1", name="get_weather", arguments='{"city":"NYC"}'
|
||||
)
|
||||
]
|
||||
)
|
||||
return _make_response(text="It's sunny in NYC")
|
||||
|
||||
async def execute_tool(
|
||||
tool_call: LLMToolCall, tools: Sequence[Any]
|
||||
) -> ToolCallResult:
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call.id,
|
||||
tool_name=tool_call.name,
|
||||
content='{"temp": 72}',
|
||||
)
|
||||
|
||||
def update_conversation(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
) -> None:
|
||||
messages.append({"role": "assistant", "content": response.response_text})
|
||||
if tool_results:
|
||||
for tr in tool_results:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.tool_call_id,
|
||||
"content": tr.content,
|
||||
}
|
||||
)
|
||||
|
||||
msgs: list[dict[str, Any]] = [{"role": "user", "content": "Weather?"}]
|
||||
results: list[ToolCallLoopResult] = []
|
||||
async for r in tool_call_loop(
|
||||
messages=msgs,
|
||||
tools=TOOL_DEFS,
|
||||
llm_call=llm_call,
|
||||
execute_tool=execute_tool,
|
||||
update_conversation=update_conversation,
|
||||
):
|
||||
results.append(r)
|
||||
|
||||
# First yield: tool call iteration (not finished)
|
||||
# Second yield: text response (finished)
|
||||
assert len(results) == 2
|
||||
assert results[0].finished_naturally is False
|
||||
assert results[0].iterations == 1
|
||||
assert len(results[0].last_tool_calls) == 1
|
||||
assert results[1].finished_naturally is True
|
||||
assert results[1].response_text == "It's sunny in NYC"
|
||||
assert results[1].iterations == 2
|
||||
assert results[1].total_prompt_tokens == 20
|
||||
assert results[1].total_completion_tokens == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_iterations_reached():
|
||||
"""Loop should stop after max_iterations even if LLM keeps calling tools."""
|
||||
|
||||
async def llm_call(
|
||||
messages: list[dict[str, Any]], tools: Sequence[Any]
|
||||
) -> LLMLoopResponse:
|
||||
return _make_response(
|
||||
tool_calls=[
|
||||
LLMToolCall(id="tc_x", name="get_weather", arguments='{"city":"X"}')
|
||||
]
|
||||
)
|
||||
|
||||
async def execute_tool(
|
||||
tool_call: LLMToolCall, tools: Sequence[Any]
|
||||
) -> ToolCallResult:
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call.id, tool_name=tool_call.name, content="result"
|
||||
)
|
||||
|
||||
def update_conversation(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
msgs: list[dict[str, Any]] = [{"role": "user", "content": "Go"}]
|
||||
results: list[ToolCallLoopResult] = []
|
||||
async for r in tool_call_loop(
|
||||
messages=msgs,
|
||||
tools=TOOL_DEFS,
|
||||
llm_call=llm_call,
|
||||
execute_tool=execute_tool,
|
||||
update_conversation=update_conversation,
|
||||
max_iterations=3,
|
||||
):
|
||||
results.append(r)
|
||||
|
||||
# 3 tool-call iterations + 1 final "max reached"
|
||||
assert len(results) == 4
|
||||
for r in results[:3]:
|
||||
assert r.finished_naturally is False
|
||||
final = results[-1]
|
||||
assert final.finished_naturally is False
|
||||
assert "3 iterations" in final.response_text
|
||||
assert final.iterations == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tools_first_response_text():
|
||||
"""When LLM immediately responds with text (empty tools list), finishes."""
|
||||
|
||||
async def llm_call(
|
||||
messages: list[dict[str, Any]], tools: Sequence[Any]
|
||||
) -> LLMLoopResponse:
|
||||
return _make_response(text="No tools needed")
|
||||
|
||||
async def execute_tool(
|
||||
tool_call: LLMToolCall, tools: Sequence[Any]
|
||||
) -> ToolCallResult:
|
||||
raise AssertionError("Should not be called")
|
||||
|
||||
def update_conversation(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
msgs: list[dict[str, Any]] = [{"role": "user", "content": "Hi"}]
|
||||
results: list[ToolCallLoopResult] = []
|
||||
async for r in tool_call_loop(
|
||||
messages=msgs,
|
||||
tools=[],
|
||||
llm_call=llm_call,
|
||||
execute_tool=execute_tool,
|
||||
update_conversation=update_conversation,
|
||||
):
|
||||
results.append(r)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].finished_naturally is True
|
||||
assert results[0].response_text == "No tools needed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_executor_exception_propagates():
|
||||
"""Exception in execute_tool should propagate out of the loop."""
|
||||
|
||||
async def llm_call(
|
||||
messages: list[dict[str, Any]], tools: Sequence[Any]
|
||||
) -> LLMLoopResponse:
|
||||
return _make_response(
|
||||
tool_calls=[LLMToolCall(id="tc_err", name="get_weather", arguments="{}")]
|
||||
)
|
||||
|
||||
async def execute_tool(
|
||||
tool_call: LLMToolCall, tools: Sequence[Any]
|
||||
) -> ToolCallResult:
|
||||
raise RuntimeError("Tool execution failed!")
|
||||
|
||||
def update_conversation(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
msgs: list[dict[str, Any]] = [{"role": "user", "content": "Go"}]
|
||||
with pytest.raises(RuntimeError, match="Tool execution failed!"):
|
||||
async for _ in tool_call_loop(
|
||||
messages=msgs,
|
||||
tools=TOOL_DEFS,
|
||||
llm_call=llm_call,
|
||||
execute_tool=execute_tool,
|
||||
update_conversation=update_conversation,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_execution():
|
||||
"""Multiple tool calls in one response should execute concurrently."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def llm_call(
|
||||
messages: list[dict[str, Any]], tools: Sequence[Any]
|
||||
) -> LLMLoopResponse:
|
||||
if len(messages) == 1:
|
||||
return _make_response(
|
||||
tool_calls=[
|
||||
LLMToolCall(id="tc_a", name="tool_a", arguments="{}"),
|
||||
LLMToolCall(id="tc_b", name="tool_b", arguments="{}"),
|
||||
]
|
||||
)
|
||||
return _make_response(text="Done")
|
||||
|
||||
async def execute_tool(
|
||||
tool_call: LLMToolCall, tools: Sequence[Any]
|
||||
) -> ToolCallResult:
|
||||
# tool_b starts instantly, tool_a has a small delay.
|
||||
# With parallel execution, both should overlap.
|
||||
if tool_call.name == "tool_a":
|
||||
await asyncio.sleep(0.05)
|
||||
execution_order.append(tool_call.name)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call.id, tool_name=tool_call.name, content="ok"
|
||||
)
|
||||
|
||||
def update_conversation(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
) -> None:
|
||||
messages.append({"role": "assistant", "content": "called tools"})
|
||||
if tool_results:
|
||||
for tr in tool_results:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.tool_call_id,
|
||||
"content": tr.content,
|
||||
}
|
||||
)
|
||||
|
||||
msgs: list[dict[str, Any]] = [{"role": "user", "content": "Run both"}]
|
||||
async for _ in tool_call_loop(
|
||||
messages=msgs,
|
||||
tools=TOOL_DEFS,
|
||||
llm_call=llm_call,
|
||||
execute_tool=execute_tool,
|
||||
update_conversation=update_conversation,
|
||||
):
|
||||
pass
|
||||
|
||||
# With parallel execution, tool_b (no delay) finishes before tool_a
|
||||
assert execution_order == ["tool_b", "tool_a"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sequential_tool_execution():
|
||||
"""With parallel_tool_calls=False, tools execute in order regardless of speed."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def llm_call(
|
||||
messages: list[dict[str, Any]], tools: Sequence[Any]
|
||||
) -> LLMLoopResponse:
|
||||
if len(messages) == 1:
|
||||
return _make_response(
|
||||
tool_calls=[
|
||||
LLMToolCall(id="tc_a", name="tool_a", arguments="{}"),
|
||||
LLMToolCall(id="tc_b", name="tool_b", arguments="{}"),
|
||||
]
|
||||
)
|
||||
return _make_response(text="Done")
|
||||
|
||||
async def execute_tool(
|
||||
tool_call: LLMToolCall, tools: Sequence[Any]
|
||||
) -> ToolCallResult:
|
||||
# tool_b would finish first if parallel, but sequential should keep order
|
||||
if tool_call.name == "tool_a":
|
||||
await asyncio.sleep(0.05)
|
||||
execution_order.append(tool_call.name)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call.id, tool_name=tool_call.name, content="ok"
|
||||
)
|
||||
|
||||
def update_conversation(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
) -> None:
|
||||
messages.append({"role": "assistant", "content": "called tools"})
|
||||
if tool_results:
|
||||
for tr in tool_results:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.tool_call_id,
|
||||
"content": tr.content,
|
||||
}
|
||||
)
|
||||
|
||||
msgs: list[dict[str, Any]] = [{"role": "user", "content": "Run both"}]
|
||||
async for _ in tool_call_loop(
|
||||
messages=msgs,
|
||||
tools=TOOL_DEFS,
|
||||
llm_call=llm_call,
|
||||
execute_tool=execute_tool,
|
||||
update_conversation=update_conversation,
|
||||
parallel_tool_calls=False,
|
||||
):
|
||||
pass
|
||||
|
||||
# With sequential execution, tool_a runs first despite being slower
|
||||
assert execution_order == ["tool_a", "tool_b"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_last_iteration_message_appended():
|
||||
"""On the final iteration, last_iteration_message should be appended."""
|
||||
captured_messages: list[list[dict[str, Any]]] = []
|
||||
|
||||
async def llm_call(
|
||||
messages: list[dict[str, Any]], tools: Sequence[Any]
|
||||
) -> LLMLoopResponse:
|
||||
captured_messages.append(list(messages))
|
||||
return _make_response(
|
||||
tool_calls=[LLMToolCall(id="tc_1", name="get_weather", arguments="{}")]
|
||||
)
|
||||
|
||||
async def execute_tool(
|
||||
tool_call: LLMToolCall, tools: Sequence[Any]
|
||||
) -> ToolCallResult:
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call.id, tool_name=tool_call.name, content="ok"
|
||||
)
|
||||
|
||||
def update_conversation(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
msgs: list[dict[str, Any]] = [{"role": "user", "content": "Go"}]
|
||||
async for _ in tool_call_loop(
|
||||
messages=msgs,
|
||||
tools=TOOL_DEFS,
|
||||
llm_call=llm_call,
|
||||
execute_tool=execute_tool,
|
||||
update_conversation=update_conversation,
|
||||
max_iterations=2,
|
||||
last_iteration_message="Please finish now.",
|
||||
):
|
||||
pass
|
||||
|
||||
# First iteration: no extra message
|
||||
assert len(captured_messages[0]) == 1
|
||||
# Second (last) iteration: should have the hint appended
|
||||
last_call_msgs = captured_messages[1]
|
||||
assert any(
|
||||
m.get("role") == "system" and "Please finish now." in m.get("content", "")
|
||||
for m in last_call_msgs
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_accumulation():
|
||||
"""Tokens should accumulate across iterations."""
|
||||
call_count = 0
|
||||
|
||||
async def llm_call(
|
||||
messages: list[dict[str, Any]], tools: Sequence[Any]
|
||||
) -> LLMLoopResponse:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count <= 2:
|
||||
return _make_response(
|
||||
tool_calls=[
|
||||
LLMToolCall(
|
||||
id=f"tc_{call_count}", name="get_weather", arguments="{}"
|
||||
)
|
||||
],
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
)
|
||||
return _make_response(text="Final", prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
async def execute_tool(
|
||||
tool_call: LLMToolCall, tools: Sequence[Any]
|
||||
) -> ToolCallResult:
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call.id, tool_name=tool_call.name, content="ok"
|
||||
)
|
||||
|
||||
def update_conversation(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
msgs: list[dict[str, Any]] = [{"role": "user", "content": "Go"}]
|
||||
final_result = None
|
||||
async for r in tool_call_loop(
|
||||
messages=msgs,
|
||||
tools=TOOL_DEFS,
|
||||
llm_call=llm_call,
|
||||
execute_tool=execute_tool,
|
||||
update_conversation=update_conversation,
|
||||
):
|
||||
final_result = r
|
||||
|
||||
assert final_result is not None
|
||||
assert final_result.total_prompt_tokens == 300 # 3 calls * 100
|
||||
assert final_result.total_completion_tokens == 150 # 3 calls * 50
|
||||
assert final_result.iterations == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_iterations_zero_no_loop():
|
||||
"""max_iterations=0 should immediately yield a 'max reached' result without calling LLM."""
|
||||
|
||||
async def llm_call(
|
||||
messages: list[dict[str, Any]], tools: Sequence[Any]
|
||||
) -> LLMLoopResponse:
|
||||
raise AssertionError("LLM should not be called when max_iterations=0")
|
||||
|
||||
async def execute_tool(
|
||||
tool_call: LLMToolCall, tools: Sequence[Any]
|
||||
) -> ToolCallResult:
|
||||
raise AssertionError("Tool should not be called when max_iterations=0")
|
||||
|
||||
def update_conversation(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
) -> None:
|
||||
raise AssertionError("Updater should not be called when max_iterations=0")
|
||||
|
||||
msgs: list[dict[str, Any]] = [{"role": "user", "content": "Go"}]
|
||||
results: list[ToolCallLoopResult] = []
|
||||
async for r in tool_call_loop(
|
||||
messages=msgs,
|
||||
tools=TOOL_DEFS,
|
||||
llm_call=llm_call,
|
||||
execute_tool=execute_tool,
|
||||
update_conversation=update_conversation,
|
||||
max_iterations=0,
|
||||
):
|
||||
results.append(r)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].finished_naturally is False
|
||||
assert results[0].iterations == 0
|
||||
assert "0 iterations" in results[0].response_text
|
||||
@@ -0,0 +1,2 @@
|
||||
-- Add durationMs column to ChatMessage for persisting turn elapsed time.
|
||||
ALTER TABLE "ChatMessage" ADD COLUMN "durationMs" INTEGER;
|
||||
@@ -155,6 +155,7 @@ asyncio_default_fixture_loop_scope = "session"
|
||||
addopts = "-p no:syrupy"
|
||||
markers = [
|
||||
"supplementary: tests kept for coverage but superseded by integration tests",
|
||||
"integration: end-to-end tests that require a live database (skipped in CI)",
|
||||
]
|
||||
filterwarnings = [
|
||||
"ignore:'audioop' is deprecated:DeprecationWarning:discord.player",
|
||||
|
||||
@@ -257,7 +257,8 @@ model ChatMessage {
|
||||
functionCall Json? // Deprecated but kept for compatibility
|
||||
|
||||
// Ordering within session
|
||||
sequence Int
|
||||
sequence Int
|
||||
durationMs Int? // Wall-clock milliseconds for this assistant turn
|
||||
|
||||
@@unique([sessionId, sequence])
|
||||
}
|
||||
|
||||
@@ -1,9 +1,20 @@
|
||||
import base64
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.google.gmail import GmailReadBlock
|
||||
from backend.blocks.google.gmail import (
|
||||
GmailForwardBlock,
|
||||
GmailReadBlock,
|
||||
HasRecipients,
|
||||
_build_reply_message,
|
||||
create_mime_message,
|
||||
validate_all_recipients,
|
||||
validate_email_recipients,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
|
||||
class TestGmailReadBlock:
|
||||
@@ -250,3 +261,244 @@ class TestGmailReadBlock:
|
||||
|
||||
result = await self.gmail_block._get_email_body(msg, self.mock_service)
|
||||
assert result == "This email does not contain a readable body."
|
||||
|
||||
|
||||
class TestValidateEmailRecipients:
|
||||
"""Test cases for validate_email_recipients."""
|
||||
|
||||
def test_valid_single_email(self):
|
||||
validate_email_recipients(["user@example.com"])
|
||||
|
||||
def test_valid_multiple_emails(self):
|
||||
validate_email_recipients(["a@b.com", "x@y.org", "test@sub.domain.co"])
|
||||
|
||||
def test_invalid_missing_at(self):
|
||||
with pytest.raises(ValueError, match="Invalid email address"):
|
||||
validate_email_recipients(["not-an-email"])
|
||||
|
||||
def test_invalid_missing_domain_dot(self):
|
||||
with pytest.raises(ValueError, match="Invalid email address"):
|
||||
validate_email_recipients(["user@localhost"])
|
||||
|
||||
def test_invalid_empty_string(self):
|
||||
with pytest.raises(ValueError, match="Invalid email address"):
|
||||
validate_email_recipients([""])
|
||||
|
||||
def test_invalid_json_object_string(self):
|
||||
with pytest.raises(ValueError, match="Invalid email address"):
|
||||
validate_email_recipients(['{"email": "user@example.com"}'])
|
||||
|
||||
def test_mixed_valid_and_invalid(self):
|
||||
with pytest.raises(ValueError, match="'bad-addr'"):
|
||||
validate_email_recipients(["good@example.com", "bad-addr"])
|
||||
|
||||
def test_field_name_in_error(self):
|
||||
with pytest.raises(ValueError, match="'cc'"):
|
||||
validate_email_recipients(["nope"], field_name="cc")
|
||||
|
||||
def test_whitespace_trimmed(self):
|
||||
validate_email_recipients([" user@example.com "])
|
||||
|
||||
def test_empty_list_passes(self):
|
||||
validate_email_recipients([])
|
||||
|
||||
|
||||
class TestValidateAllRecipients:
|
||||
"""Test cases for validate_all_recipients."""
|
||||
|
||||
def test_valid_all_fields(self):
|
||||
data = cast(
|
||||
HasRecipients,
|
||||
SimpleNamespace(to=["a@b.com"], cc=["c@d.com"], bcc=["e@f.com"]),
|
||||
)
|
||||
validate_all_recipients(data)
|
||||
|
||||
def test_invalid_to_raises(self):
|
||||
data = cast(HasRecipients, SimpleNamespace(to=["bad"], cc=[], bcc=[]))
|
||||
with pytest.raises(ValueError, match="'to'"):
|
||||
validate_all_recipients(data)
|
||||
|
||||
def test_invalid_cc_raises(self):
|
||||
data = cast(HasRecipients, SimpleNamespace(to=["a@b.com"], cc=["bad"], bcc=[]))
|
||||
with pytest.raises(ValueError, match="'cc'"):
|
||||
validate_all_recipients(data)
|
||||
|
||||
def test_invalid_bcc_raises(self):
|
||||
data = cast(
|
||||
HasRecipients,
|
||||
SimpleNamespace(to=["a@b.com"], cc=["c@d.com"], bcc=["bad"]),
|
||||
)
|
||||
with pytest.raises(ValueError, match="'bcc'"):
|
||||
validate_all_recipients(data)
|
||||
|
||||
def test_empty_cc_bcc_skipped(self):
|
||||
data = cast(HasRecipients, SimpleNamespace(to=["a@b.com"], cc=[], bcc=[]))
|
||||
validate_all_recipients(data)
|
||||
|
||||
|
||||
class TestCreateMimeMessageValidation:
|
||||
"""Integration tests verifying validation hooks in create_mime_message()."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_to_raises_before_mime_construction(self):
|
||||
"""Invalid 'to' recipients should raise ValueError before any MIME work."""
|
||||
input_data = SimpleNamespace(
|
||||
to=["not-an-email"],
|
||||
cc=[],
|
||||
bcc=[],
|
||||
subject="Test",
|
||||
body="Hello",
|
||||
attachments=[],
|
||||
)
|
||||
exec_ctx = cast(ExecutionContext, SimpleNamespace(graph_exec_id="test-exec-id"))
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid email address"):
|
||||
await create_mime_message(input_data, exec_ctx)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cc_raises_before_mime_construction(self):
|
||||
"""Invalid 'cc' recipients should raise ValueError."""
|
||||
input_data = SimpleNamespace(
|
||||
to=["valid@example.com"],
|
||||
cc=["bad-addr"],
|
||||
bcc=[],
|
||||
subject="Test",
|
||||
body="Hello",
|
||||
attachments=[],
|
||||
)
|
||||
exec_ctx = cast(ExecutionContext, SimpleNamespace(graph_exec_id="test-exec-id"))
|
||||
|
||||
with pytest.raises(ValueError, match="'cc'"):
|
||||
await create_mime_message(input_data, exec_ctx)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_recipients_passes_validation(self):
|
||||
"""Valid recipients should not raise during validation."""
|
||||
input_data = SimpleNamespace(
|
||||
to=["user@example.com"],
|
||||
cc=["other@example.com"],
|
||||
bcc=[],
|
||||
subject="Test",
|
||||
body="Hello",
|
||||
attachments=[],
|
||||
)
|
||||
exec_ctx = cast(ExecutionContext, SimpleNamespace(graph_exec_id="test-exec-id"))
|
||||
|
||||
# Should succeed without raising
|
||||
result = await create_mime_message(input_data, exec_ctx)
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
class TestBuildReplyMessageValidation:
|
||||
"""Integration tests verifying validation hooks in _build_reply_message()."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_to_raises_before_reply_construction(self):
|
||||
"""Invalid 'to' in reply should raise ValueError before MIME work."""
|
||||
mock_service = Mock()
|
||||
mock_parent = {
|
||||
"threadId": "thread-1",
|
||||
"payload": {
|
||||
"headers": [
|
||||
{"name": "Subject", "value": "Original"},
|
||||
{"name": "Message-ID", "value": "<msg@example.com>"},
|
||||
{"name": "From", "value": "sender@example.com"},
|
||||
]
|
||||
},
|
||||
}
|
||||
mock_service.users().messages().get().execute.return_value = mock_parent
|
||||
|
||||
input_data = SimpleNamespace(
|
||||
parentMessageId="msg-1",
|
||||
to=["not-valid"],
|
||||
cc=[],
|
||||
bcc=[],
|
||||
subject="",
|
||||
body="Reply body",
|
||||
replyAll=False,
|
||||
attachments=[],
|
||||
)
|
||||
exec_ctx = cast(ExecutionContext, SimpleNamespace(graph_exec_id="test-exec-id"))
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid email address"):
|
||||
await _build_reply_message(mock_service, input_data, exec_ctx)
|
||||
|
||||
|
||||
class TestForwardMessageValidation:
|
||||
"""Test that _forward_message() raises ValueError for invalid recipients."""
|
||||
|
||||
@staticmethod
|
||||
def _make_input(
|
||||
to: list[str] | None = None,
|
||||
cc: list[str] | None = None,
|
||||
bcc: list[str] | None = None,
|
||||
) -> "GmailForwardBlock.Input":
|
||||
mock = Mock(spec=GmailForwardBlock.Input)
|
||||
mock.messageId = "m1"
|
||||
mock.to = to or []
|
||||
mock.cc = cc or []
|
||||
mock.bcc = bcc or []
|
||||
mock.subject = ""
|
||||
mock.forwardMessage = "FYI"
|
||||
mock.includeAttachments = False
|
||||
mock.content_type = None
|
||||
mock.additionalAttachments = []
|
||||
mock.credentials = None
|
||||
return mock
|
||||
|
||||
@staticmethod
|
||||
def _exec_ctx():
|
||||
return ExecutionContext(user_id="u1", graph_exec_id="g1")
|
||||
|
||||
@staticmethod
|
||||
def _mock_service():
|
||||
"""Build a mock Gmail service that returns a parent message."""
|
||||
parent_message = {
|
||||
"id": "m1",
|
||||
"payload": {
|
||||
"headers": [
|
||||
{"name": "Subject", "value": "Original subject"},
|
||||
{"name": "From", "value": "sender@example.com"},
|
||||
{"name": "To", "value": "me@example.com"},
|
||||
{"name": "Date", "value": "Mon, 31 Mar 2026 00:00:00 +0000"},
|
||||
],
|
||||
"mimeType": "text/plain",
|
||||
"body": {
|
||||
"data": base64.urlsafe_b64encode(b"Hello world").decode(),
|
||||
},
|
||||
"parts": [],
|
||||
},
|
||||
}
|
||||
svc = Mock()
|
||||
svc.users().messages().get().execute.return_value = parent_message
|
||||
return svc
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_to_raises(self):
|
||||
block = GmailForwardBlock()
|
||||
with pytest.raises(ValueError, match="Invalid email address.*'to'"):
|
||||
await block._forward_message(
|
||||
self._mock_service(),
|
||||
self._make_input(to=["bad-addr"]),
|
||||
self._exec_ctx(),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cc_raises(self):
|
||||
block = GmailForwardBlock()
|
||||
with pytest.raises(ValueError, match="Invalid email address.*'cc'"):
|
||||
await block._forward_message(
|
||||
self._mock_service(),
|
||||
self._make_input(to=["valid@example.com"], cc=["not-valid"]),
|
||||
self._exec_ctx(),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_bcc_raises(self):
|
||||
block = GmailForwardBlock()
|
||||
with pytest.raises(ValueError, match="Invalid email address.*'bcc'"):
|
||||
await block._forward_message(
|
||||
self._mock_service(),
|
||||
self._make_input(to=["valid@example.com"], bcc=["nope"]),
|
||||
self._exec_ctx(),
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ images: {
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
@@ -569,6 +570,10 @@ async def main():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(
|
||||
os.getenv("CI") == "true",
|
||||
reason="Data seeding test requires a dedicated database; not for CI",
|
||||
)
|
||||
async def test_main_function_runs_without_errors():
|
||||
await main()
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Base stage for both dev and prod
|
||||
FROM node:21-alpine AS base
|
||||
FROM node:22.22-alpine3.23 AS base
|
||||
WORKDIR /app
|
||||
RUN corepack enable
|
||||
COPY autogpt_platform/frontend/package.json autogpt_platform/frontend/pnpm-lock.yaml ./
|
||||
@@ -33,7 +33,7 @@ ENV NEXT_PUBLIC_SOURCEMAPS="false"
|
||||
RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=8192" pnpm build; else NODE_OPTIONS="--max-old-space-size=8192" pnpm build; fi
|
||||
|
||||
# Prod stage - based on NextJS reference Dockerfile https://github.com/vercel/next.js/blob/64271354533ed16da51be5dce85f0dbd15f17517/examples/with-docker/Dockerfile
|
||||
FROM node:21-alpine AS prod
|
||||
FROM node:22.22-alpine3.23 AS prod
|
||||
ENV NODE_ENV=production
|
||||
ENV HOSTNAME=0.0.0.0
|
||||
WORKDIR /app
|
||||
|
||||
@@ -95,6 +95,8 @@ export function CopilotPage() {
|
||||
isDeleting,
|
||||
handleConfirmDelete,
|
||||
handleCancelDelete,
|
||||
// Historical durations for persisted timer stats
|
||||
historicalDurations,
|
||||
// Rate limit reset
|
||||
rateLimitMessage,
|
||||
dismissRateLimit,
|
||||
@@ -186,6 +188,7 @@ export function CopilotPage() {
|
||||
isUploadingFiles={isUploadingFiles}
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={handleDroppedFilesConsumed}
|
||||
historicalDurations={historicalDurations}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -27,6 +27,8 @@ export interface ChatContainerProps {
|
||||
droppedFiles?: File[];
|
||||
/** Called after droppedFiles have been consumed by ChatInput. */
|
||||
onDroppedFilesConsumed?: () => void;
|
||||
/** Duration in ms for historical turns, keyed by message ID. */
|
||||
historicalDurations?: Map<string, number>;
|
||||
}
|
||||
export const ChatContainer = ({
|
||||
messages,
|
||||
@@ -44,6 +46,7 @@ export const ChatContainer = ({
|
||||
isUploadingFiles,
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
historicalDurations,
|
||||
}: ChatContainerProps) => {
|
||||
const isBusy =
|
||||
status === "streaming" ||
|
||||
@@ -81,6 +84,7 @@ export const ChatContainer = ({
|
||||
isLoading={isLoadingSession}
|
||||
sessionID={sessionId}
|
||||
onRetry={handleRetry}
|
||||
historicalDurations={historicalDurations}
|
||||
/>
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useMemo } from "react";
|
||||
import { useEffect, useMemo, useRef } from "react";
|
||||
import {
|
||||
Conversation,
|
||||
ConversationContent,
|
||||
@@ -13,6 +13,7 @@ import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner
|
||||
import { FileUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { TOOL_PART_PREFIX } from "../JobStatsBar/constants";
|
||||
import { TurnStatsBar } from "../JobStatsBar/TurnStatsBar";
|
||||
import { useElapsedTimer } from "../JobStatsBar/useElapsedTimer";
|
||||
import { CopilotPendingReviews } from "../CopilotPendingReviews/CopilotPendingReviews";
|
||||
import {
|
||||
buildRenderSegments,
|
||||
@@ -37,6 +38,7 @@ interface Props {
|
||||
isLoading: boolean;
|
||||
sessionID?: string | null;
|
||||
onRetry?: () => void;
|
||||
historicalDurations?: Map<string, number>;
|
||||
}
|
||||
|
||||
function renderSegments(
|
||||
@@ -111,6 +113,7 @@ export function ChatMessagesContainer({
|
||||
isLoading,
|
||||
sessionID,
|
||||
onRetry,
|
||||
historicalDurations,
|
||||
}: Props) {
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
const graphExecId = useMemo(() => extractGraphExecId(messages), [messages]);
|
||||
@@ -139,6 +142,25 @@ export function ChatMessagesContainer({
|
||||
const showThinking =
|
||||
status === "submitted" || (status === "streaming" && !hasInflight);
|
||||
|
||||
const isActivelyStreaming = status === "streaming" || status === "submitted";
|
||||
const { elapsedSeconds } = useElapsedTimer(isActivelyStreaming);
|
||||
|
||||
// Freeze elapsed time when streaming ends so TurnStatsBar shows the final value.
|
||||
// Reset when a new streaming turn begins.
|
||||
const frozenElapsedRef = useRef(0);
|
||||
const wasStreamingRef = useRef(false);
|
||||
useEffect(() => {
|
||||
if (isActivelyStreaming) {
|
||||
if (!wasStreamingRef.current) {
|
||||
frozenElapsedRef.current = 0;
|
||||
}
|
||||
if (elapsedSeconds > 0) {
|
||||
frozenElapsedRef.current = elapsedSeconds;
|
||||
}
|
||||
}
|
||||
wasStreamingRef.current = isActivelyStreaming;
|
||||
});
|
||||
|
||||
return (
|
||||
<Conversation className="min-h-0 flex-1">
|
||||
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
|
||||
@@ -239,10 +261,19 @@ export function ChatMessagesContainer({
|
||||
{isLastInTurn && !isCurrentlyStreaming && (
|
||||
<TurnStatsBar
|
||||
turnMessages={getTurnMessages(messages, messageIndex)}
|
||||
elapsedSeconds={
|
||||
messageIndex === messages.length - 1
|
||||
? frozenElapsedRef.current
|
||||
: undefined
|
||||
}
|
||||
durationMs={historicalDurations?.get(message.id)}
|
||||
/>
|
||||
)}
|
||||
{isLastAssistant && showThinking && (
|
||||
<ThinkingIndicator active={showThinking} />
|
||||
<ThinkingIndicator
|
||||
active={showThinking}
|
||||
elapsedSeconds={elapsedSeconds}
|
||||
/>
|
||||
)}
|
||||
</MessageContent>
|
||||
{message.role === "user" && textParts.length > 0 && (
|
||||
@@ -268,7 +299,10 @@ export function ChatMessagesContainer({
|
||||
{showThinking && lastMessage?.role !== "assistant" && (
|
||||
<Message from="assistant">
|
||||
<MessageContent className="text-[1rem] leading-relaxed">
|
||||
<ThinkingIndicator active={showThinking} />
|
||||
<ThinkingIndicator
|
||||
active={showThinking}
|
||||
elapsedSeconds={elapsedSeconds}
|
||||
/>
|
||||
</MessageContent>
|
||||
</Message>
|
||||
)}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { formatElapsed } from "../../JobStatsBar/formatElapsed";
|
||||
import { ScaleLoader } from "../../ScaleLoader/ScaleLoader";
|
||||
|
||||
const THINKING_PHRASES = [
|
||||
@@ -27,6 +28,9 @@ const THINKING_PHRASES = [
|
||||
const PHRASE_CYCLE_MS = 6_000;
|
||||
const FADE_DURATION_MS = 300;
|
||||
|
||||
/** Only show elapsed time after this many seconds. */
|
||||
const SHOW_TIME_AFTER_SECONDS = 20;
|
||||
|
||||
/**
|
||||
* Cycles through thinking phrases sequentially with a fade-out/in transition.
|
||||
* Returns the current phrase and whether it's visible (for opacity).
|
||||
@@ -72,10 +76,12 @@ function useCyclingPhrase(active: boolean) {
|
||||
|
||||
interface Props {
|
||||
active: boolean;
|
||||
elapsedSeconds: number;
|
||||
}
|
||||
|
||||
export function ThinkingIndicator({ active }: Props) {
|
||||
export function ThinkingIndicator({ active, elapsedSeconds }: Props) {
|
||||
const { phrase, visible } = useCyclingPhrase(active);
|
||||
const showTime = active && elapsedSeconds >= SHOW_TIME_AFTER_SECONDS;
|
||||
|
||||
return (
|
||||
<span className="inline-flex items-center gap-1.5 text-neutral-500">
|
||||
@@ -88,6 +94,11 @@ export function ThinkingIndicator({ active }: Props) {
|
||||
{phrase}
|
||||
</span>
|
||||
</span>
|
||||
{showTime && (
|
||||
<span className="animate-pulse tabular-nums [animation-duration:1.5s]">
|
||||
• {formatElapsed(elapsedSeconds)}
|
||||
</span>
|
||||
)}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,21 +1,44 @@
|
||||
import type { UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { formatElapsed } from "./formatElapsed";
|
||||
import { getWorkDoneCounters } from "./useWorkDoneCounters";
|
||||
|
||||
interface Props {
|
||||
turnMessages: UIMessage<unknown, UIDataTypes, UITools>[];
|
||||
elapsedSeconds?: number;
|
||||
durationMs?: number;
|
||||
}
|
||||
|
||||
export function TurnStatsBar({ turnMessages }: Props) {
|
||||
export function TurnStatsBar({
|
||||
turnMessages,
|
||||
elapsedSeconds,
|
||||
durationMs,
|
||||
}: Props) {
|
||||
const { counters } = getWorkDoneCounters(turnMessages);
|
||||
|
||||
if (counters.length === 0) return null;
|
||||
// Prefer live elapsedSeconds, fall back to persisted durationMs
|
||||
const displaySeconds =
|
||||
elapsedSeconds !== undefined && elapsedSeconds > 0
|
||||
? elapsedSeconds
|
||||
: durationMs !== undefined
|
||||
? Math.round(durationMs / 1000)
|
||||
: undefined;
|
||||
|
||||
const hasTime = displaySeconds !== undefined && displaySeconds > 0;
|
||||
|
||||
if (counters.length === 0 && !hasTime) return null;
|
||||
|
||||
return (
|
||||
<div className="mt-2 flex items-center gap-1.5">
|
||||
{hasTime && (
|
||||
<span className="text-[11px] tabular-nums text-neutral-500">
|
||||
Thought for {formatElapsed(displaySeconds)}
|
||||
</span>
|
||||
)}
|
||||
{counters.map(function renderCounter(counter, index) {
|
||||
const needsDot = index > 0 || hasTime;
|
||||
return (
|
||||
<span key={counter.category} className="flex items-center gap-1">
|
||||
{index > 0 && (
|
||||
{needsDot && (
|
||||
<span className="text-xs text-neutral-300">·</span>
|
||||
)}
|
||||
<span className="text-[11px] tabular-nums text-neutral-500">
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
export function formatElapsed(totalSeconds: number): string {
|
||||
const minutes = Math.floor(totalSeconds / 60);
|
||||
const seconds = totalSeconds % 60;
|
||||
|
||||
if (minutes === 0) return `${seconds}s`;
|
||||
return `${minutes}m ${seconds}s`;
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
|
||||
export function useElapsedTimer(isRunning: boolean) {
|
||||
const [elapsedSeconds, setElapsedSeconds] = useState(0);
|
||||
const startTimeRef = useRef<number | null>(null);
|
||||
const intervalRef = useRef<ReturnType<typeof setInterval>>();
|
||||
|
||||
useEffect(() => {
|
||||
if (isRunning) {
|
||||
if (startTimeRef.current === null) {
|
||||
startTimeRef.current = Date.now();
|
||||
setElapsedSeconds(0);
|
||||
}
|
||||
|
||||
intervalRef.current = setInterval(() => {
|
||||
if (startTimeRef.current !== null) {
|
||||
setElapsedSeconds(
|
||||
Math.floor((Date.now() - startTimeRef.current) / 1000),
|
||||
);
|
||||
}
|
||||
}, 1000);
|
||||
|
||||
return () => clearInterval(intervalRef.current);
|
||||
}
|
||||
|
||||
clearInterval(intervalRef.current);
|
||||
startTimeRef.current = null;
|
||||
}, [isRunning]);
|
||||
|
||||
return { elapsedSeconds };
|
||||
}
|
||||
@@ -6,6 +6,7 @@ interface SessionChatMessage {
|
||||
content: string | null;
|
||||
tool_call_id: string | null;
|
||||
tool_calls: unknown[] | null;
|
||||
duration_ms: number | null;
|
||||
}
|
||||
|
||||
function coerceSessionChatMessages(
|
||||
@@ -34,6 +35,8 @@ function coerceSessionChatMessages(
|
||||
? null
|
||||
: String(msg.tool_call_id),
|
||||
tool_calls: Array.isArray(msg.tool_calls) ? msg.tool_calls : null,
|
||||
duration_ms:
|
||||
typeof msg.duration_ms === "number" ? msg.duration_ms : null,
|
||||
};
|
||||
})
|
||||
.filter((m): m is SessionChatMessage => m !== null);
|
||||
@@ -102,7 +105,10 @@ export function convertChatSessionMessagesToUiMessages(
|
||||
sessionId: string,
|
||||
rawMessages: unknown[],
|
||||
options?: { isComplete?: boolean },
|
||||
): UIMessage<unknown, UIDataTypes, UITools>[] {
|
||||
): {
|
||||
messages: UIMessage<unknown, UIDataTypes, UITools>[];
|
||||
durations: Map<string, number>;
|
||||
} {
|
||||
const messages = coerceSessionChatMessages(rawMessages);
|
||||
const toolOutputsByCallId = new Map<string, unknown>();
|
||||
|
||||
@@ -114,6 +120,7 @@ export function convertChatSessionMessagesToUiMessages(
|
||||
}
|
||||
|
||||
const uiMessages: UIMessage<unknown, UIDataTypes, UITools>[] = [];
|
||||
const durations = new Map<string, number>();
|
||||
|
||||
messages.forEach((msg, index) => {
|
||||
if (msg.role === "tool") return;
|
||||
@@ -186,15 +193,24 @@ export function convertChatSessionMessagesToUiMessages(
|
||||
const prevUI = uiMessages[uiMessages.length - 1];
|
||||
if (msg.role === "assistant" && prevUI && prevUI.role === "assistant") {
|
||||
prevUI.parts.push(...parts);
|
||||
// Capture duration on merged message (last assistant msg wins)
|
||||
if (msg.duration_ms != null) {
|
||||
durations.set(prevUI.id, msg.duration_ms);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const msgId = `${sessionId}-${index}`;
|
||||
uiMessages.push({
|
||||
id: `${sessionId}-${index}`,
|
||||
id: msgId,
|
||||
role: msg.role,
|
||||
parts,
|
||||
});
|
||||
|
||||
if (msg.role === "assistant" && msg.duration_ms != null) {
|
||||
durations.set(msgId, msg.duration_ms);
|
||||
}
|
||||
});
|
||||
|
||||
return uiMessages;
|
||||
return { messages: uiMessages, durations };
|
||||
}
|
||||
|
||||
@@ -61,13 +61,21 @@ export function useChatSession() {
|
||||
// array reference every render. Re-derives only when query data changes.
|
||||
// When the session is complete (no active stream), mark dangling tool
|
||||
// calls as completed so stale spinners don't persist after refresh.
|
||||
const hydratedMessages = useMemo(() => {
|
||||
if (sessionQuery.data?.status !== 200 || !sessionId) return undefined;
|
||||
return convertChatSessionMessagesToUiMessages(
|
||||
const { hydratedMessages, historicalDurations } = useMemo(() => {
|
||||
if (sessionQuery.data?.status !== 200 || !sessionId)
|
||||
return {
|
||||
hydratedMessages: undefined,
|
||||
historicalDurations: new Map<string, number>(),
|
||||
};
|
||||
const result = convertChatSessionMessagesToUiMessages(
|
||||
sessionId,
|
||||
sessionQuery.data.data.messages ?? [],
|
||||
{ isComplete: !hasActiveStream },
|
||||
);
|
||||
return {
|
||||
hydratedMessages: result.messages,
|
||||
historicalDurations: result.durations,
|
||||
};
|
||||
}, [sessionQuery.data, sessionId, hasActiveStream]);
|
||||
|
||||
const { mutateAsync: createSessionMutation, isPending: isCreatingSession } =
|
||||
@@ -122,6 +130,7 @@ export function useChatSession() {
|
||||
sessionId,
|
||||
setSessionId,
|
||||
hydratedMessages,
|
||||
historicalDurations,
|
||||
hasActiveStream,
|
||||
isLoadingSession: sessionQuery.isLoading,
|
||||
isSessionError: sessionQuery.isError,
|
||||
|
||||
@@ -39,6 +39,7 @@ export function useCopilotPage() {
|
||||
sessionId,
|
||||
setSessionId,
|
||||
hydratedMessages,
|
||||
historicalDurations,
|
||||
hasActiveStream,
|
||||
isLoadingSession,
|
||||
isSessionError,
|
||||
@@ -377,6 +378,8 @@ export function useCopilotPage() {
|
||||
handleDeleteClick,
|
||||
handleConfirmDelete,
|
||||
handleCancelDelete,
|
||||
// Historical durations for persisted timer stats
|
||||
historicalDurations,
|
||||
// Rate limit reset
|
||||
rateLimitMessage,
|
||||
dismissRateLimit,
|
||||
|
||||
@@ -731,6 +731,7 @@ _Add technical explanation here._
|
||||
| max_tokens | The maximum number of tokens to generate in the chat completion. | int | No |
|
||||
| ollama_host | Ollama host for local models | str | No |
|
||||
| agent_mode_max_iterations | Maximum iterations for agent mode. 0 = traditional mode (single LLM call, yield tool calls for external execution), -1 = infinite agent mode (loop until finished), 1+ = agent mode with max iterations limit. | int | No |
|
||||
| execution_mode | How tool calls are executed. 'built_in' uses the default tool-call loop (all providers). 'extended_thinking' delegates to an external Agent SDK for richer reasoning (currently Anthropic / OpenRouter only, requires API credentials, ignores 'Agent Mode Max Iterations'). | "built_in" \| "extended_thinking" | No |
|
||||
| conversation_compaction | Automatically compact the context window once it hits the limit | bool | No |
|
||||
|
||||
### Outputs
|
||||
|
||||
Reference in New Issue
Block a user