Compare commits
58 Commits
dev
...
test-scree
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bada09ff7e | ||
|
|
2e7bbd879d | ||
|
|
cdb2699477 | ||
|
|
7d061a0de0 | ||
|
|
fe660d9aaf | ||
|
|
84c3dd7000 | ||
|
|
4e0d6bbde5 | ||
|
|
d9c59e3616 | ||
|
|
457c2f4dca | ||
|
|
86c64ddb8e | ||
|
|
3a46fde82b | ||
|
|
25a216885f | ||
|
|
e621ecb824 | ||
|
|
c82c3c34ef | ||
|
|
97268bb22c | ||
|
|
571a4c9540 | ||
|
|
6c7cdea55c | ||
|
|
374694f64a | ||
|
|
7062ea7244 | ||
|
|
248293a1de | ||
|
|
7b15c4d350 | ||
|
|
816736826e | ||
|
|
b753cb7d0b | ||
|
|
60b6101f25 | ||
|
|
45fe984d66 | ||
|
|
da10cf6f47 | ||
|
|
5feeaaf39a | ||
|
|
09706ca8d2 | ||
|
|
18dd829a89 | ||
|
|
538e8619da | ||
|
|
ad77e881c9 | ||
|
|
49c7ab4011 | ||
|
|
927c6e7db0 | ||
|
|
114f91ff53 | ||
|
|
697b15ce81 | ||
|
|
7f986bc565 | ||
|
|
6f679a0e32 | ||
|
|
05495d8478 | ||
|
|
1a645e1e37 | ||
|
|
fd1d706315 | ||
|
|
89264091ad | ||
|
|
14ad37b0c7 | ||
|
|
389cd28879 | ||
|
|
f0a3afda7d | ||
|
|
9ffecbac02 | ||
|
|
c2709fbc28 | ||
|
|
3adbaacc0e | ||
|
|
56e0b568a4 | ||
|
|
0b0777ac87 | ||
|
|
698b1599cb | ||
|
|
a2f94f08d9 | ||
|
|
0c6f20f728 | ||
|
|
d100b2515b | ||
|
|
14113f96a9 | ||
|
|
ee40a4b9a8 | ||
|
|
0008cafc3b | ||
|
|
f55bc84fe7 | ||
|
|
3cfee4c4b5 |
@@ -15,7 +15,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.config import ChatConfig, CopilotMode
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
@@ -111,6 +111,11 @@ class StreamChatRequest(BaseModel):
|
||||
file_ids: list[str] | None = Field(
|
||||
default=None, max_length=20
|
||||
) # Workspace file IDs attached to this message
|
||||
mode: CopilotMode | None = Field(
|
||||
default=None,
|
||||
description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
|
||||
"If None, uses the server default (extended_thinking).",
|
||||
)
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
@@ -840,6 +845,7 @@ async def stream_chat_post(
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
file_ids=sanitized_file_ids,
|
||||
mode=request.mode,
|
||||
)
|
||||
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
|
||||
@@ -541,3 +541,41 @@ def test_create_session_rejects_nested_metadata(
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
class TestStreamChatRequestModeValidation:
|
||||
"""Pydantic-level validation of the ``mode`` field on StreamChatRequest."""
|
||||
|
||||
def test_rejects_invalid_mode_value(self) -> None:
|
||||
"""Any string outside the Literal set must raise ValidationError."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
StreamChatRequest(message="hi", mode="turbo") # type: ignore[arg-type]
|
||||
|
||||
def test_accepts_fast_mode(self) -> None:
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi", mode="fast")
|
||||
assert req.mode == "fast"
|
||||
|
||||
def test_accepts_extended_thinking_mode(self) -> None:
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi", mode="extended_thinking")
|
||||
assert req.mode == "extended_thinking"
|
||||
|
||||
def test_accepts_none_mode(self) -> None:
|
||||
"""``mode=None`` is valid (server decides via feature flags)."""
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi", mode=None)
|
||||
assert req.mode is None
|
||||
|
||||
def test_mode_defaults_to_none_when_omitted(self) -> None:
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi")
|
||||
assert req.mode is None
|
||||
|
||||
@@ -18,6 +18,7 @@ import orjson
|
||||
from langfuse import propagate_attributes
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.config import CopilotMode
|
||||
from backend.copilot.context import set_execution_context
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
@@ -52,6 +53,15 @@ from backend.copilot.service import (
|
||||
from backend.copilot.token_tracking import persist_and_record_usage
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.copilot.transcript import (
|
||||
STOP_REASON_END_TURN,
|
||||
STOP_REASON_TOOL_USE,
|
||||
TranscriptDownload,
|
||||
download_transcript,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.prompt import (
|
||||
compress_context,
|
||||
@@ -74,6 +84,19 @@ _background_tasks: set[asyncio.Task[Any]] = set()
|
||||
_MAX_TOOL_ROUNDS = 30
|
||||
|
||||
|
||||
def _resolve_baseline_model(mode: CopilotMode | None) -> str:
|
||||
"""Pick the model for the baseline path based on the per-request mode.
|
||||
|
||||
Only ``mode='fast'`` downgrades to the cheaper/faster model. Any other
|
||||
value (including ``None`` and ``'extended_thinking'``) preserves the
|
||||
default model so that users who never select a mode don't get
|
||||
silently moved to the cheaper tier.
|
||||
"""
|
||||
if mode == "fast":
|
||||
return config.fast_model
|
||||
return config.model
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BaselineStreamState:
|
||||
"""Mutable state shared between the tool-call loop callbacks.
|
||||
@@ -82,6 +105,7 @@ class _BaselineStreamState:
|
||||
can be module-level functions instead of deeply nested closures.
|
||||
"""
|
||||
|
||||
model: str = ""
|
||||
pending_events: list[StreamBaseResponse] = field(default_factory=list)
|
||||
assistant_text: str = ""
|
||||
text_block_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
@@ -109,7 +133,7 @@ async def _baseline_llm_caller(
|
||||
if tools:
|
||||
typed_tools = cast(list[ChatCompletionToolParam], tools)
|
||||
response = await client.chat.completions.create(
|
||||
model=config.model,
|
||||
model=state.model,
|
||||
messages=typed_messages,
|
||||
tools=typed_tools,
|
||||
stream=True,
|
||||
@@ -117,7 +141,7 @@ async def _baseline_llm_caller(
|
||||
)
|
||||
else:
|
||||
response = await client.chat.completions.create(
|
||||
model=config.model,
|
||||
model=state.model,
|
||||
messages=typed_messages,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
@@ -279,17 +303,17 @@ async def _baseline_tool_executor(
|
||||
)
|
||||
|
||||
|
||||
def _baseline_conversation_updater(
|
||||
def _mutate_openai_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
tool_results: list[ToolCallResult] | None,
|
||||
) -> None:
|
||||
"""Update OpenAI message list with assistant response + tool results.
|
||||
"""Append assistant / tool-result entries to the OpenAI message list.
|
||||
|
||||
Extracted from ``stream_chat_completion_baseline`` for readability.
|
||||
This is the side-effect boundary for the next LLM call — no transcript
|
||||
mutation happens here.
|
||||
"""
|
||||
if tool_results:
|
||||
# Build assistant message with tool_calls
|
||||
assistant_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if response.response_text:
|
||||
assistant_msg["content"] = response.response_text
|
||||
@@ -310,9 +334,89 @@ def _baseline_conversation_updater(
|
||||
"content": tr.content,
|
||||
}
|
||||
)
|
||||
else:
|
||||
elif response.response_text:
|
||||
messages.append({"role": "assistant", "content": response.response_text})
|
||||
|
||||
|
||||
def _record_turn_to_transcript(
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None,
|
||||
*,
|
||||
transcript_builder: TranscriptBuilder,
|
||||
model: str,
|
||||
) -> None:
|
||||
"""Append assistant + tool-result entries to the transcript builder.
|
||||
|
||||
Kept separate from :func:`_mutate_openai_messages` so the two
|
||||
concerns (next-LLM-call payload vs. durable conversation log) can
|
||||
evolve independently.
|
||||
"""
|
||||
if tool_results:
|
||||
content_blocks: list[dict[str, Any]] = []
|
||||
if response.response_text:
|
||||
messages.append({"role": "assistant", "content": response.response_text})
|
||||
content_blocks.append({"type": "text", "text": response.response_text})
|
||||
for tc in response.tool_calls:
|
||||
try:
|
||||
args = orjson.loads(tc.arguments) if tc.arguments else {}
|
||||
except (ValueError, TypeError, orjson.JSONDecodeError) as parse_err:
|
||||
logger.debug(
|
||||
"[Baseline] Failed to parse tool_call arguments "
|
||||
"(tool=%s, id=%s): %s",
|
||||
tc.name,
|
||||
tc.id,
|
||||
parse_err,
|
||||
)
|
||||
args = {}
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tc.id,
|
||||
"name": tc.name,
|
||||
"input": args,
|
||||
}
|
||||
)
|
||||
if content_blocks:
|
||||
transcript_builder.append_assistant(
|
||||
content_blocks=content_blocks,
|
||||
model=model,
|
||||
stop_reason=STOP_REASON_TOOL_USE,
|
||||
)
|
||||
for tr in tool_results:
|
||||
# Record tool result to transcript AFTER the assistant tool_use
|
||||
# block to maintain correct Anthropic API ordering:
|
||||
# assistant(tool_use) → user(tool_result)
|
||||
transcript_builder.append_tool_result(
|
||||
tool_use_id=tr.tool_call_id,
|
||||
content=tr.content,
|
||||
)
|
||||
elif response.response_text:
|
||||
transcript_builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": response.response_text}],
|
||||
model=model,
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
|
||||
def _baseline_conversation_updater(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
*,
|
||||
transcript_builder: TranscriptBuilder,
|
||||
model: str = "",
|
||||
) -> None:
|
||||
"""Update OpenAI message list with assistant response + tool results.
|
||||
|
||||
Thin composition of :func:`_mutate_openai_messages` and
|
||||
:func:`_record_turn_to_transcript`.
|
||||
"""
|
||||
_mutate_openai_messages(messages, response, tool_results)
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results,
|
||||
transcript_builder=transcript_builder,
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
@@ -329,6 +433,7 @@ async def _update_title_async(
|
||||
|
||||
async def _compress_session_messages(
|
||||
messages: list[ChatMessage],
|
||||
model: str,
|
||||
) -> list[ChatMessage]:
|
||||
"""Compress session messages if they exceed the model's token limit.
|
||||
|
||||
@@ -341,45 +446,178 @@ async def _compress_session_messages(
|
||||
msg_dict: dict[str, Any] = {"role": msg.role}
|
||||
if msg.content:
|
||||
msg_dict["content"] = msg.content
|
||||
if msg.tool_calls:
|
||||
msg_dict["tool_calls"] = msg.tool_calls
|
||||
if msg.tool_call_id:
|
||||
msg_dict["tool_call_id"] = msg.tool_call_id
|
||||
messages_dict.append(msg_dict)
|
||||
|
||||
try:
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
model=model,
|
||||
client=_get_openai_client(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
model=model,
|
||||
client=None,
|
||||
)
|
||||
|
||||
if result.was_compacted:
|
||||
logger.info(
|
||||
"[Baseline] Context compacted: %d -> %d tokens "
|
||||
"(%d summarized, %d dropped)",
|
||||
"[Baseline] Context compacted: %d -> %d tokens (%d summarized, %d dropped)",
|
||||
result.original_token_count,
|
||||
result.token_count,
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
return [
|
||||
ChatMessage(role=m["role"], content=m.get("content"))
|
||||
ChatMessage(
|
||||
role=m["role"],
|
||||
content=m.get("content"),
|
||||
tool_calls=m.get("tool_calls"),
|
||||
tool_call_id=m.get("tool_call_id"),
|
||||
)
|
||||
for m in result.messages
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def is_transcript_stale(dl: TranscriptDownload | None, session_msg_count: int) -> bool:
|
||||
"""Return ``True`` when a download doesn't cover the current session.
|
||||
|
||||
A transcript is stale when it has a known ``message_count`` and that
|
||||
count doesn't reach ``session_msg_count - 1`` (i.e. the session has
|
||||
already advanced beyond what the stored transcript captures).
|
||||
Loading a stale transcript would silently drop intermediate turns,
|
||||
so callers should treat stale as "skip load, skip upload".
|
||||
|
||||
An unknown ``message_count`` (``0``) is treated as **not stale**
|
||||
because older transcripts uploaded before msg_count tracking
|
||||
existed must still be usable.
|
||||
"""
|
||||
if dl is None:
|
||||
return False
|
||||
if not dl.message_count:
|
||||
return False
|
||||
return dl.message_count < session_msg_count - 1
|
||||
|
||||
|
||||
def should_upload_transcript(
|
||||
user_id: str | None, transcript_covers_prefix: bool
|
||||
) -> bool:
|
||||
"""Return ``True`` when the caller should upload the final transcript.
|
||||
|
||||
Uploads require a logged-in user (for the storage key) *and* a
|
||||
transcript that covered the session prefix when loaded — otherwise
|
||||
we'd be overwriting a more complete version in storage with a
|
||||
partial one built from just the current turn.
|
||||
"""
|
||||
return bool(user_id) and transcript_covers_prefix
|
||||
|
||||
|
||||
async def _load_prior_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
session_msg_count: int,
|
||||
transcript_builder: TranscriptBuilder,
|
||||
) -> bool:
|
||||
"""Download and load the prior transcript into ``transcript_builder``.
|
||||
|
||||
Returns ``True`` when the loaded transcript fully covers the session
|
||||
prefix; ``False`` otherwise (stale, missing, invalid, or download
|
||||
error). Callers should suppress uploads when this returns ``False``
|
||||
to avoid overwriting a more complete version in storage.
|
||||
"""
|
||||
try:
|
||||
dl = await download_transcript(user_id, session_id, log_prefix="[Baseline]")
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Transcript download failed: %s", e)
|
||||
return False
|
||||
|
||||
if dl is None:
|
||||
logger.debug("[Baseline] No transcript available")
|
||||
return False
|
||||
|
||||
if not validate_transcript(dl.content):
|
||||
logger.warning("[Baseline] Downloaded transcript but invalid")
|
||||
return False
|
||||
|
||||
if is_transcript_stale(dl, session_msg_count):
|
||||
logger.warning(
|
||||
"[Baseline] Transcript stale: covers %d of %d messages, skipping",
|
||||
dl.message_count,
|
||||
session_msg_count,
|
||||
)
|
||||
return False
|
||||
|
||||
transcript_builder.load_previous(dl.content, log_prefix="[Baseline]")
|
||||
logger.info(
|
||||
"[Baseline] Loaded transcript: %dB, msg_count=%d",
|
||||
len(dl.content),
|
||||
dl.message_count,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
async def _upload_final_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
transcript_builder: TranscriptBuilder,
|
||||
session_msg_count: int,
|
||||
) -> None:
|
||||
"""Serialize and upload the transcript for next-turn continuity.
|
||||
|
||||
Uses the builder's own invariants to decide whether to upload,
|
||||
avoiding a JSONL re-parse. A builder that ends with an assistant
|
||||
entry is structurally complete; a builder that doesn't (empty, or
|
||||
ends mid-turn) is skipped.
|
||||
"""
|
||||
try:
|
||||
if transcript_builder.last_entry_type != "assistant":
|
||||
logger.debug(
|
||||
"[Baseline] No complete assistant turn to upload (last_entry=%s)",
|
||||
transcript_builder.last_entry_type,
|
||||
)
|
||||
return
|
||||
content = transcript_builder.to_jsonl()
|
||||
if not content:
|
||||
logger.debug("[Baseline] Empty transcript content, skipping upload")
|
||||
return
|
||||
# Track the upload as a background task so a timeout doesn't leak an
|
||||
# orphaned coroutine; shield it so cancellation of this caller doesn't
|
||||
# abort the in-flight GCS write.
|
||||
upload_task = asyncio.create_task(
|
||||
upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=content,
|
||||
message_count=session_msg_count,
|
||||
log_prefix="[Baseline]",
|
||||
skip_strip=True,
|
||||
)
|
||||
)
|
||||
_background_tasks.add(upload_task)
|
||||
upload_task.add_done_callback(_background_tasks.discard)
|
||||
# Bound the wait: a hung storage backend must not block the response
|
||||
# from finishing. The task keeps running in _background_tasks on
|
||||
# timeout and will be cleaned up when it resolves.
|
||||
await asyncio.wait_for(asyncio.shield(upload_task), timeout=30)
|
||||
except Exception as upload_err:
|
||||
logger.error("[Baseline] Transcript upload failed: %s", upload_err)
|
||||
|
||||
|
||||
async def stream_chat_completion_baseline(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
session: ChatSession | None = None,
|
||||
mode: CopilotMode | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Baseline LLM with tool calling via OpenAI-compatible API.
|
||||
@@ -408,6 +646,47 @@ async def stream_chat_completion_baseline(
|
||||
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Select model based on the per-request mode. 'fast' downgrades to
|
||||
# the cheaper/faster model; everything else keeps the default.
|
||||
active_model = _resolve_baseline_model(mode)
|
||||
|
||||
# --- Transcript support (feature parity with SDK path) ---
|
||||
transcript_builder = TranscriptBuilder()
|
||||
transcript_covers_prefix = True
|
||||
|
||||
# Build system prompt only on the first turn to avoid mid-conversation
|
||||
# changes from concurrent chats updating business understanding.
|
||||
is_first_turn = len(session.messages) <= 1
|
||||
if is_first_turn:
|
||||
prompt_task = _build_system_prompt(user_id, has_conversation_history=False)
|
||||
else:
|
||||
prompt_task = _build_system_prompt(user_id=None, has_conversation_history=True)
|
||||
|
||||
# Run download + prompt build concurrently — both are independent I/O
|
||||
# on the request critical path.
|
||||
if user_id and len(session.messages) > 1:
|
||||
transcript_covers_prefix, (base_system_prompt, _) = await asyncio.gather(
|
||||
_load_prior_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
session_msg_count=len(session.messages),
|
||||
transcript_builder=transcript_builder,
|
||||
),
|
||||
prompt_task,
|
||||
)
|
||||
else:
|
||||
base_system_prompt, _ = await prompt_task
|
||||
|
||||
# Append user message to transcript.
|
||||
# Always append when the message is present and is from the user,
|
||||
# even on duplicate-suppressed retries (is_new_message=False).
|
||||
# The loaded transcript may be stale (uploaded before the previous
|
||||
# attempt stored this message), so skipping it would leave the
|
||||
# transcript without the user turn, creating a malformed
|
||||
# assistant-after-assistant structure when the LLM reply is added.
|
||||
if message and is_user_message:
|
||||
transcript_builder.append_user(content=message)
|
||||
|
||||
# Generate title for new sessions
|
||||
if is_user_message and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
@@ -422,30 +701,38 @@ async def stream_chat_completion_baseline(
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# Build system prompt only on the first turn to avoid mid-conversation
|
||||
# changes from concurrent chats updating business understanding.
|
||||
is_first_turn = len(session.messages) <= 1
|
||||
if is_first_turn:
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=False
|
||||
)
|
||||
else:
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id=None, has_conversation_history=True
|
||||
)
|
||||
|
||||
# Append tool documentation and technical notes
|
||||
system_prompt = base_system_prompt + get_baseline_supplement()
|
||||
|
||||
# Compress context if approaching the model's token limit
|
||||
messages_for_context = await _compress_session_messages(session.messages)
|
||||
messages_for_context = await _compress_session_messages(
|
||||
session.messages, model=active_model
|
||||
)
|
||||
|
||||
# Build OpenAI message list from session history
|
||||
# Build OpenAI message list from session history.
|
||||
# Include tool_calls on assistant messages and tool-role results so the
|
||||
# model retains full context of what tools were invoked and their outcomes.
|
||||
openai_messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": system_prompt}
|
||||
]
|
||||
for msg in messages_for_context:
|
||||
if msg.role in ("user", "assistant") and msg.content:
|
||||
if msg.role == "assistant":
|
||||
entry: dict[str, Any] = {"role": "assistant"}
|
||||
if msg.content:
|
||||
entry["content"] = msg.content
|
||||
if msg.tool_calls:
|
||||
entry["tool_calls"] = msg.tool_calls
|
||||
if msg.content or msg.tool_calls:
|
||||
openai_messages.append(entry)
|
||||
elif msg.role == "tool" and msg.tool_call_id:
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"content": msg.content or "",
|
||||
}
|
||||
)
|
||||
elif msg.role == "user" and msg.content:
|
||||
openai_messages.append({"role": msg.role, "content": msg.content})
|
||||
|
||||
tools = get_available_tools()
|
||||
@@ -470,7 +757,7 @@ async def stream_chat_completion_baseline(
|
||||
logger.warning("[Baseline] Langfuse trace context setup failed")
|
||||
|
||||
_stream_error = False # Track whether an error occurred during streaming
|
||||
state = _BaselineStreamState()
|
||||
state = _BaselineStreamState(model=active_model)
|
||||
|
||||
# Bind extracted module-level callbacks to this request's state/session
|
||||
# using functools.partial so they satisfy the Protocol signatures.
|
||||
@@ -479,6 +766,12 @@ async def stream_chat_completion_baseline(
|
||||
_baseline_tool_executor, state=state, user_id=user_id, session=session
|
||||
)
|
||||
|
||||
_bound_conversation_updater = partial(
|
||||
_baseline_conversation_updater,
|
||||
transcript_builder=transcript_builder,
|
||||
model=active_model,
|
||||
)
|
||||
|
||||
try:
|
||||
loop_result = None
|
||||
async for loop_result in tool_call_loop(
|
||||
@@ -486,7 +779,7 @@ async def stream_chat_completion_baseline(
|
||||
tools=tools,
|
||||
llm_call=_bound_llm_caller,
|
||||
execute_tool=_bound_tool_executor,
|
||||
update_conversation=_baseline_conversation_updater,
|
||||
update_conversation=_bound_conversation_updater,
|
||||
max_iterations=_MAX_TOOL_ROUNDS,
|
||||
):
|
||||
# Drain buffered events after each iteration (real-time streaming)
|
||||
@@ -555,10 +848,10 @@ async def stream_chat_completion_baseline(
|
||||
and not (_stream_error and not state.assistant_text)
|
||||
):
|
||||
state.turn_prompt_tokens = max(
|
||||
estimate_token_count(openai_messages, model=config.model), 1
|
||||
estimate_token_count(openai_messages, model=active_model), 1
|
||||
)
|
||||
state.turn_completion_tokens = estimate_token_count_str(
|
||||
state.assistant_text, model=config.model
|
||||
state.assistant_text, model=active_model
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] No streaming usage reported; estimated tokens: "
|
||||
@@ -589,6 +882,27 @@ async def stream_chat_completion_baseline(
|
||||
except Exception as persist_err:
|
||||
logger.error("[Baseline] Failed to persist session: %s", persist_err)
|
||||
|
||||
# --- Upload transcript for next-turn continuity ---
|
||||
# Backfill partial assistant text that wasn't recorded by the
|
||||
# conversation updater (e.g. when the stream aborted mid-round).
|
||||
# Without this, mode-switching after a failed turn would lose
|
||||
# the partial assistant response from the transcript.
|
||||
if _stream_error and state.assistant_text:
|
||||
if transcript_builder.last_entry_type != "assistant":
|
||||
transcript_builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": state.assistant_text}],
|
||||
model=active_model,
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
if user_id and should_upload_transcript(user_id, transcript_covers_prefix):
|
||||
await _upload_final_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
transcript_builder=transcript_builder,
|
||||
session_msg_count=len(session.messages),
|
||||
)
|
||||
|
||||
# Yield usage and finish AFTER try/finally (not inside finally).
|
||||
# PEP 525 prohibits yielding from finally in async generators during
|
||||
# aclose() — doing so raises RuntimeError on client disconnect.
|
||||
|
||||
@@ -0,0 +1,367 @@
|
||||
"""Unit tests for baseline service pure-logic helpers.
|
||||
|
||||
These tests cover ``_baseline_conversation_updater`` and ``_BaselineStreamState``
|
||||
without requiring API keys, database connections, or network access.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.baseline.service import (
|
||||
_baseline_conversation_updater,
|
||||
_BaselineStreamState,
|
||||
_compress_session_messages,
|
||||
)
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
from backend.util.prompt import CompressResult
|
||||
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
|
||||
|
||||
|
||||
class TestBaselineStreamState:
|
||||
def test_defaults(self):
|
||||
state = _BaselineStreamState()
|
||||
assert state.pending_events == []
|
||||
assert state.assistant_text == ""
|
||||
assert state.text_started is False
|
||||
assert state.turn_prompt_tokens == 0
|
||||
assert state.turn_completion_tokens == 0
|
||||
assert state.text_block_id # Should be a UUID string
|
||||
|
||||
def test_mutable_fields(self):
|
||||
state = _BaselineStreamState()
|
||||
state.assistant_text = "hello"
|
||||
state.turn_prompt_tokens = 100
|
||||
state.turn_completion_tokens = 50
|
||||
assert state.assistant_text == "hello"
|
||||
assert state.turn_prompt_tokens == 100
|
||||
assert state.turn_completion_tokens == 50
|
||||
|
||||
|
||||
class TestBaselineConversationUpdater:
|
||||
"""Tests for _baseline_conversation_updater which updates the OpenAI
|
||||
message list and transcript builder after each LLM call."""
|
||||
|
||||
def _make_transcript_builder(self) -> TranscriptBuilder:
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("test question")
|
||||
return builder
|
||||
|
||||
def test_text_only_response(self):
|
||||
"""When the LLM returns text without tool calls, the updater appends
|
||||
a single assistant message and records it in the transcript."""
|
||||
messages: list = []
|
||||
builder = self._make_transcript_builder()
|
||||
response = LLMLoopResponse(
|
||||
response_text="Hello, world!",
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
|
||||
_baseline_conversation_updater(
|
||||
messages,
|
||||
response,
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["role"] == "assistant"
|
||||
assert messages[0]["content"] == "Hello, world!"
|
||||
# Transcript should have user + assistant
|
||||
assert builder.entry_count == 2
|
||||
assert builder.last_entry_type == "assistant"
|
||||
|
||||
def test_tool_calls_response(self):
|
||||
"""When the LLM returns tool calls, the updater appends the assistant
|
||||
message with tool_calls and tool result messages."""
|
||||
messages: list = []
|
||||
builder = self._make_transcript_builder()
|
||||
response = LLMLoopResponse(
|
||||
response_text="Let me search...",
|
||||
tool_calls=[
|
||||
LLMToolCall(
|
||||
id="tc_1",
|
||||
name="search",
|
||||
arguments='{"query": "test"}',
|
||||
),
|
||||
],
|
||||
raw_response=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
tool_results = [
|
||||
ToolCallResult(
|
||||
tool_call_id="tc_1",
|
||||
tool_name="search",
|
||||
content="Found result",
|
||||
),
|
||||
]
|
||||
|
||||
_baseline_conversation_updater(
|
||||
messages,
|
||||
response,
|
||||
tool_results=tool_results,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# Messages: assistant (with tool_calls) + tool result
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["role"] == "assistant"
|
||||
assert messages[0]["content"] == "Let me search..."
|
||||
assert len(messages[0]["tool_calls"]) == 1
|
||||
assert messages[0]["tool_calls"][0]["id"] == "tc_1"
|
||||
assert messages[1]["role"] == "tool"
|
||||
assert messages[1]["tool_call_id"] == "tc_1"
|
||||
assert messages[1]["content"] == "Found result"
|
||||
|
||||
# Transcript: user + assistant(tool_use) + user(tool_result)
|
||||
assert builder.entry_count == 3
|
||||
|
||||
def test_tool_calls_without_text(self):
|
||||
"""Tool calls without accompanying text should still work."""
|
||||
messages: list = []
|
||||
builder = self._make_transcript_builder()
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[
|
||||
LLMToolCall(id="tc_1", name="run", arguments="{}"),
|
||||
],
|
||||
raw_response=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
tool_results = [
|
||||
ToolCallResult(tool_call_id="tc_1", tool_name="run", content="done"),
|
||||
]
|
||||
|
||||
_baseline_conversation_updater(
|
||||
messages,
|
||||
response,
|
||||
tool_results=tool_results,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
assert "content" not in messages[0] # No text content
|
||||
assert messages[0]["tool_calls"][0]["function"]["name"] == "run"
|
||||
|
||||
def test_no_text_no_tools(self):
|
||||
"""When the response has no text and no tool calls, nothing is appended."""
|
||||
messages: list = []
|
||||
builder = self._make_transcript_builder()
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
|
||||
_baseline_conversation_updater(
|
||||
messages,
|
||||
response,
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert len(messages) == 0
|
||||
# Only the user entry from setup
|
||||
assert builder.entry_count == 1
|
||||
|
||||
def test_multiple_tool_calls(self):
    """Each tool call of a multi-call response gets its own tool message."""
    conversation: list = []
    transcript = self._make_transcript_builder()

    requested = [
        LLMToolCall(id="tc_1", name="tool_a", arguments="{}"),
        LLMToolCall(id="tc_2", name="tool_b", arguments='{"x": 1}'),
    ]
    executed = [
        ToolCallResult(tool_call_id="tc_1", tool_name="tool_a", content="result_a"),
        ToolCallResult(tool_call_id="tc_2", tool_name="tool_b", content="result_b"),
    ]
    llm_response = LLMLoopResponse(
        response_text=None,
        tool_calls=requested,
        raw_response=None,
        prompt_tokens=0,
        completion_tokens=0,
    )

    _baseline_conversation_updater(
        conversation,
        llm_response,
        tool_results=executed,
        transcript_builder=transcript,
        model="test-model",
    )

    # Layout: [assistant(tool_calls x2), tool(tc_1), tool(tc_2)]
    assert len(conversation) == 3
    assistant_msg, first_tool, second_tool = conversation
    assert len(assistant_msg["tool_calls"]) == 2
    assert first_tool["tool_call_id"] == "tc_1"
    assert second_tool["tool_call_id"] == "tc_2"
|
||||
|
||||
def test_invalid_tool_arguments_handled(self):
    """Non-JSON tool arguments must not make the updater raise.

    The raw argument string is kept verbatim in the OpenAI-style message;
    the transcript side is expected to fall back to ``{}`` when parsing
    fails (not asserted here beyond "does not raise").
    """
    conversation: list = []
    transcript = self._make_transcript_builder()

    bad_call = LLMToolCall(id="tc_1", name="tool_x", arguments="not-json")
    llm_response = LLMLoopResponse(
        response_text=None,
        tool_calls=[bad_call],
        raw_response=None,
        prompt_tokens=0,
        completion_tokens=0,
    )
    executed = [
        ToolCallResult(tool_call_id="tc_1", tool_name="tool_x", content="ok"),
    ]

    # Reaching the assertions at all proves the call did not raise.
    _baseline_conversation_updater(
        conversation,
        llm_response,
        tool_results=executed,
        transcript_builder=transcript,
        model="test-model",
    )

    assert len(conversation) == 2
    stored_call = conversation[0]["tool_calls"][0]
    assert stored_call["function"]["arguments"] == "not-json"
|
||||
|
||||
|
||||
class TestCompressSessionMessagesPreservesToolCalls:
    """Guard the tool-call linkage across ``_compress_session_messages``.

    The helper turns ChatMessage objects into dicts for ``compress_context``
    and rebuilds ChatMessage objects from the result. Losing ``tool_calls``
    or ``tool_call_id`` on that round trip would corrupt the OpenAI message
    list and break later tool-execution rounds.
    """

    @pytest.mark.asyncio
    async def test_compressed_output_keeps_tool_calls_and_ids(self):
        def search_call() -> dict:
            # Fresh dict per use so input and mocked output never share state.
            return {
                "id": "tc_abc",
                "type": "function",
                "function": {"name": "search", "arguments": '{"q":"y"}'},
            }

        # Mocked compression output: a summary followed by the latest
        # assistant(tool_call) + tool(tool_result) pair, left intact.
        compacted = CompressResult(
            messages=[
                {"role": "system", "content": "prior turns: user asked X"},
                {
                    "role": "assistant",
                    "content": "calling tool",
                    "tool_calls": [search_call()],
                },
                {
                    "role": "tool",
                    "tool_call_id": "tc_abc",
                    "content": "search result",
                },
            ],
            token_count=100,
            was_compacted=True,
            original_token_count=5000,
            messages_summarized=10,
            messages_dropped=0,
        )

        # History handed to the function under test.
        history = [
            ChatMessage(role="user", content="q1"),
            ChatMessage(
                role="assistant",
                content="calling tool",
                tool_calls=[search_call()],
            ),
            ChatMessage(
                role="tool",
                tool_call_id="tc_abc",
                content="search result",
            ),
        ]

        compress_mock = AsyncMock(return_value=compacted)
        with patch(
            "backend.copilot.baseline.service.compress_context",
            new=compress_mock,
        ):
            compressed = await _compress_session_messages(
                history, model="openrouter/anthropic/claude-opus-4"
            )

        # Expect: summary, assistant(tool_calls), tool(tool_call_id).
        assert len(compressed) == 3

        assistant_msg = compressed[1]
        assert assistant_msg.role == "assistant"
        assert assistant_msg.tool_calls is not None
        assert len(assistant_msg.tool_calls) == 1
        kept_call = assistant_msg.tool_calls[0]
        assert kept_call["id"] == "tc_abc"
        assert kept_call["function"]["name"] == "search"

        # The tool message needs its tool_call_id for OpenAI linkage.
        tool_msg = compressed[2]
        assert tool_msg.role == "tool"
        assert tool_msg.tool_call_id == "tc_abc"
        assert tool_msg.content == "search result"

    @pytest.mark.asyncio
    async def test_uncompressed_passthrough_keeps_fields(self):
        """A no-op compression (was_compacted=False) returns the input list
        itself, with tool_calls and tool_call_id untouched."""
        history = [
            ChatMessage(
                role="assistant",
                content="c",
                tool_calls=[
                    {
                        "id": "t1",
                        "type": "function",
                        "function": {"name": "f", "arguments": "{}"},
                    }
                ],
            ),
            ChatMessage(role="tool", tool_call_id="t1", content="ok"),
        ]

        # messages is ignored by the caller when was_compacted is False.
        untouched = CompressResult(
            messages=[],
            token_count=10,
            was_compacted=False,
        )

        with patch(
            "backend.copilot.baseline.service.compress_context",
            new=AsyncMock(return_value=untouched),
        ):
            out = await _compress_session_messages(
                history, model="openrouter/anthropic/claude-opus-4"
            )

        assert out is history  # identity, not just equality
        assistant_msg, tool_msg = out[0], out[1]
        assert assistant_msg.tool_calls is not None
        assert assistant_msg.tool_calls[0]["id"] == "t1"
        assert tool_msg.tool_call_id == "t1"
|
||||
@@ -0,0 +1,667 @@
|
||||
"""Integration tests for baseline transcript flow.
|
||||
|
||||
Exercises the real helpers in ``baseline/service.py`` that download,
|
||||
validate, load, append to, backfill, and upload the transcript.
|
||||
Storage is mocked via ``download_transcript`` / ``upload_transcript``
|
||||
patches; no network access is required.
|
||||
"""
|
||||
|
||||
import json as stdlib_json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.baseline.service import (
|
||||
_load_prior_transcript,
|
||||
_record_turn_to_transcript,
|
||||
_resolve_baseline_model,
|
||||
_upload_final_transcript,
|
||||
is_transcript_stale,
|
||||
should_upload_transcript,
|
||||
)
|
||||
from backend.copilot.service import config
|
||||
from backend.copilot.transcript import (
|
||||
STOP_REASON_END_TURN,
|
||||
STOP_REASON_TOOL_USE,
|
||||
TranscriptDownload,
|
||||
)
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
|
||||
|
||||
|
||||
def _make_transcript_content(*roles: str) -> str:
    """Return a minimal JSONL transcript with one entry per role name.

    Entries are chained through ``parentUuid`` (the first entry's parent is
    the empty string). Entries whose role is ``"assistant"`` additionally get
    ``id``, ``model``, ``type`` and ``stop_reason`` fields on their message.
    The result always ends with a newline.
    """
    serialized: list[str] = []
    previous_uuid = ""
    for index, role in enumerate(roles):
        current_uuid = f"uuid-{index}"
        message: dict = {
            "role": role,
            "content": [{"type": "text", "text": f"{role} message {index}"}],
        }
        if role == "assistant":
            message["id"] = f"msg_{index}"
            message["model"] = "test-model"
            message["type"] = "message"
            message["stop_reason"] = STOP_REASON_END_TURN
        record: dict = {
            "type": role,
            "uuid": current_uuid,
            "parentUuid": previous_uuid,
            "message": message,
        }
        serialized.append(stdlib_json.dumps(record))
        previous_uuid = current_uuid
    return "\n".join(serialized) + "\n"
|
||||
|
||||
|
||||
class TestResolveBaselineModel:
    """``_resolve_baseline_model`` picks the model from the per-request mode."""

    def test_fast_mode_selects_fast_model(self):
        chosen = _resolve_baseline_model("fast")
        assert chosen == config.fast_model

    def test_extended_thinking_selects_default_model(self):
        chosen = _resolve_baseline_model("extended_thinking")
        assert chosen == config.model

    def test_none_mode_selects_default_model(self):
        """Critical: with no mode set, baseline users MUST stay on the default (opus)."""
        chosen = _resolve_baseline_model(None)
        assert chosen == config.model

    def test_default_and_fast_models_differ(self):
        """Sanity check that the configured tiers are two distinct models."""
        default_model, fast_model = config.model, config.fast_model
        assert default_model != fast_model
|
||||
|
||||
|
||||
class TestLoadPriorTranscript:
    """Covers the download + validate + load flow in ``_load_prior_transcript``."""

    @staticmethod
    async def _load_with(download_mock, builder, session_msg_count):
        """Call ``_load_prior_transcript`` with ``download_transcript`` patched.

        Always uses user "user-1" / session "session-1"; returns the
        helper's result.
        """
        with patch(
            "backend.copilot.baseline.service.download_transcript",
            new=download_mock,
        ):
            return await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=session_msg_count,
                transcript_builder=builder,
            )

    @pytest.mark.asyncio
    async def test_loads_fresh_transcript(self):
        builder = TranscriptBuilder()
        stored = TranscriptDownload(
            content=_make_transcript_content("user", "assistant"),
            message_count=2,
        )

        covers = await self._load_with(
            AsyncMock(return_value=stored), builder, session_msg_count=3
        )

        assert covers is True
        assert builder.entry_count == 2
        assert builder.last_entry_type == "assistant"

    @pytest.mark.asyncio
    async def test_rejects_stale_transcript(self):
        """A msg_count strictly below session-1 marks the transcript as stale."""
        builder = TranscriptBuilder()
        # Six session messages but only two covered by the transcript → stale.
        stored = TranscriptDownload(
            content=_make_transcript_content("user", "assistant"),
            message_count=2,
        )

        covers = await self._load_with(
            AsyncMock(return_value=stored), builder, session_msg_count=6
        )

        assert covers is False
        assert builder.is_empty

    @pytest.mark.asyncio
    async def test_missing_transcript_returns_false(self):
        builder = TranscriptBuilder()

        covers = await self._load_with(
            AsyncMock(return_value=None), builder, session_msg_count=2
        )

        assert covers is False
        assert builder.is_empty

    @pytest.mark.asyncio
    async def test_invalid_transcript_returns_false(self):
        builder = TranscriptBuilder()
        stored = TranscriptDownload(
            content='{"type":"progress","uuid":"a"}\n',
            message_count=1,
        )

        covers = await self._load_with(
            AsyncMock(return_value=stored), builder, session_msg_count=2
        )

        assert covers is False
        assert builder.is_empty

    @pytest.mark.asyncio
    async def test_download_exception_returns_false(self):
        builder = TranscriptBuilder()
        failing_download = AsyncMock(side_effect=RuntimeError("boom"))

        covers = await self._load_with(
            failing_download, builder, session_msg_count=2
        )

        assert covers is False
        assert builder.is_empty

    @pytest.mark.asyncio
    async def test_zero_message_count_not_stale(self):
        """A msg_count of 0 (unknown) bypasses the staleness check."""
        builder = TranscriptBuilder()
        stored = TranscriptDownload(
            content=_make_transcript_content("user", "assistant"),
            message_count=0,
        )

        covers = await self._load_with(
            AsyncMock(return_value=stored), builder, session_msg_count=20
        )

        assert covers is True
        assert builder.entry_count == 2
|
||||
|
||||
|
||||
class TestUploadFinalTranscript:
    """``_upload_final_transcript`` serialises the builder and hands it to storage."""

    @staticmethod
    def _hi_hello_builder() -> TranscriptBuilder:
        """Builder holding one user turn ("hi") and one assistant turn ("hello")."""
        builder = TranscriptBuilder()
        builder.append_user(content="hi")
        builder.append_assistant(
            content_blocks=[{"type": "text", "text": "hello"}],
            model="test-model",
            stop_reason=STOP_REASON_END_TURN,
        )
        return builder

    @pytest.mark.asyncio
    async def test_uploads_valid_transcript(self):
        builder = self._hi_hello_builder()
        upload_mock = AsyncMock(return_value=None)

        with patch(
            "backend.copilot.baseline.service.upload_transcript",
            new=upload_mock,
        ):
            await _upload_final_transcript(
                user_id="user-1",
                session_id="session-1",
                transcript_builder=builder,
                session_msg_count=2,
            )

        upload_mock.assert_awaited_once()
        assert upload_mock.await_args is not None
        sent = upload_mock.await_args.kwargs
        assert sent["user_id"] == "user-1"
        assert sent["session_id"] == "session-1"
        assert sent["message_count"] == 2
        assert "hello" in sent["content"]

    @pytest.mark.asyncio
    async def test_skips_upload_when_builder_empty(self):
        empty_builder = TranscriptBuilder()
        upload_mock = AsyncMock(return_value=None)

        with patch(
            "backend.copilot.baseline.service.upload_transcript",
            new=upload_mock,
        ):
            await _upload_final_transcript(
                user_id="user-1",
                session_id="session-1",
                transcript_builder=empty_builder,
                session_msg_count=0,
            )

        upload_mock.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_swallows_upload_exceptions(self):
        """A storage failure must not propagate; the user's flow carries on."""
        builder = self._hi_hello_builder()
        broken_upload = AsyncMock(side_effect=RuntimeError("storage unavailable"))

        with patch(
            "backend.copilot.baseline.service.upload_transcript",
            new=broken_upload,
        ):
            # Completing this await without an exception is the assertion.
            await _upload_final_transcript(
                user_id="user-1",
                session_id="session-1",
                transcript_builder=builder,
                session_msg_count=2,
            )
|
||||
|
||||
|
||||
class TestRecordTurnToTranscript:
    """``_record_turn_to_transcript`` maps an LLMLoopResponse onto transcript entries."""

    @staticmethod
    def _builder_with_user(text: str) -> TranscriptBuilder:
        """Builder pre-seeded with a single user entry."""
        builder = TranscriptBuilder()
        builder.append_user(content=text)
        return builder

    def test_records_final_assistant_text(self):
        builder = self._builder_with_user("hi")

        _record_turn_to_transcript(
            LLMLoopResponse(
                response_text="hello there",
                tool_calls=[],
                raw_response=None,
            ),
            tool_results=None,
            transcript_builder=builder,
            model="test-model",
        )

        assert builder.entry_count == 2
        assert builder.last_entry_type == "assistant"
        serialized = builder.to_jsonl()
        assert "hello there" in serialized
        assert STOP_REASON_END_TURN in serialized

    def test_records_tool_use_then_tool_result(self):
        """Anthropic ordering: assistant(tool_use) is followed by user(tool_result)."""
        builder = self._builder_with_user("use a tool")

        echo_call = LLMToolCall(id="call-1", name="echo", arguments='{"text":"hi"}')
        echo_result = ToolCallResult(
            tool_call_id="call-1", tool_name="echo", content="hi"
        )
        _record_turn_to_transcript(
            LLMLoopResponse(
                response_text=None,
                tool_calls=[echo_call],
                raw_response=None,
            ),
            [echo_result],
            transcript_builder=builder,
            model="test-model",
        )

        # Entries: user, assistant(tool_use), user(tool_result).
        assert builder.entry_count == 3
        serialized = builder.to_jsonl()
        assert STOP_REASON_TOOL_USE in serialized
        assert "tool_use" in serialized
        assert "tool_result" in serialized
        assert "call-1" in serialized

    def test_records_nothing_on_empty_response(self):
        builder = self._builder_with_user("hi")

        _record_turn_to_transcript(
            LLMLoopResponse(
                response_text=None,
                tool_calls=[],
                raw_response=None,
            ),
            tool_results=None,
            transcript_builder=builder,
            model="test-model",
        )

        assert builder.entry_count == 1

    def test_malformed_tool_args_dont_crash(self):
        """Unparseable tool arguments become ``{}`` instead of raising."""
        builder = self._builder_with_user("hi")

        broken_call = LLMToolCall(id="call-1", name="echo", arguments="{not-json")
        echo_result = ToolCallResult(
            tool_call_id="call-1", tool_name="echo", content="ok"
        )
        _record_turn_to_transcript(
            LLMLoopResponse(
                response_text=None,
                tool_calls=[broken_call],
                raw_response=None,
            ),
            [echo_result],
            transcript_builder=builder,
            model="test-model",
        )

        assert builder.entry_count == 3
        serialized = builder.to_jsonl()
        assert '"input":{}' in serialized
|
||||
|
||||
|
||||
class TestRoundTrip:
    """End-to-end flow: load the prior transcript, append a turn, upload."""

    @pytest.mark.asyncio
    async def test_full_round_trip(self):
        stored = TranscriptDownload(
            content=_make_transcript_content("user", "assistant"),
            message_count=2,
        )
        builder = TranscriptBuilder()

        with patch(
            "backend.copilot.baseline.service.download_transcript",
            new=AsyncMock(return_value=stored),
        ):
            covers = await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=3,
                transcript_builder=builder,
            )
        assert covers is True
        assert builder.entry_count == 2

        # The user asks something new.
        builder.append_user(content="new question")
        assert builder.entry_count == 3

        # The assistant answers.
        answer = LLMLoopResponse(
            response_text="new answer",
            tool_calls=[],
            raw_response=None,
        )
        _record_turn_to_transcript(
            answer,
            tool_results=None,
            transcript_builder=builder,
            model="test-model",
        )
        assert builder.entry_count == 4

        # Push the combined transcript to (mocked) storage.
        upload_mock = AsyncMock(return_value=None)
        with patch(
            "backend.copilot.baseline.service.upload_transcript",
            new=upload_mock,
        ):
            await _upload_final_transcript(
                user_id="user-1",
                session_id="session-1",
                transcript_builder=builder,
                session_msg_count=4,
            )

        upload_mock.assert_awaited_once()
        assert upload_mock.await_args is not None
        uploaded = upload_mock.await_args.kwargs["content"]
        assert "new question" in uploaded
        assert "new answer" in uploaded
        # The previously stored turns survive the round trip.
        assert "user message 0" in uploaded
        assert "assistant message 1" in uploaded

    @pytest.mark.asyncio
    async def test_backfill_append_guard(self):
        """The backfill appends only while the last entry is not an assistant one."""
        builder = TranscriptBuilder()
        builder.append_user(content="hi")

        # Mirror of the backfill guard in stream_chat_completion_baseline.
        assistant_text = "partial text before error"
        needs_backfill = builder.last_entry_type != "assistant"
        if needs_backfill:
            builder.append_assistant(
                content_blocks=[{"type": "text", "text": assistant_text}],
                model="test-model",
                stop_reason=STOP_REASON_END_TURN,
            )

        assert builder.last_entry_type == "assistant"
        assert "partial text before error" in builder.to_jsonl()

        # Running the guard again must not append a second assistant entry.
        count_before = builder.entry_count
        needs_backfill = builder.last_entry_type != "assistant"
        if needs_backfill:
            builder.append_assistant(
                content_blocks=[{"type": "text", "text": "duplicate"}],
                model="test-model",
                stop_reason=STOP_REASON_END_TURN,
            )
        assert builder.entry_count == count_before
|
||||
|
||||
|
||||
class TestIsTranscriptStale:
    """``is_transcript_stale`` gates prior-transcript loading.

    The cases below pin the rule these tests document: a download is stale
    when its non-zero ``message_count`` is strictly less than
    ``session_msg_count - 1``.
    """

    def test_none_download_is_not_stale(self):
        assert is_transcript_stale(None, session_msg_count=5) is False

    def test_zero_message_count_is_not_stale(self):
        """Legacy transcripts without msg_count tracking must remain usable."""
        dl = TranscriptDownload(content="", message_count=0)
        assert is_transcript_stale(dl, session_msg_count=20) is False

    def test_stale_when_covers_less_than_prefix(self):
        dl = TranscriptDownload(content="", message_count=2)
        # session has 6 messages; transcript must cover at least 5 (6-1).
        assert is_transcript_stale(dl, session_msg_count=6) is True

    def test_fresh_when_covers_full_prefix(self):
        dl = TranscriptDownload(content="", message_count=5)
        assert is_transcript_stale(dl, session_msg_count=6) is False

    def test_fresh_when_exceeds_prefix(self):
        """Race: transcript ahead of session count is still acceptable."""
        dl = TranscriptDownload(content="", message_count=10)
        assert is_transcript_stale(dl, session_msg_count=6) is False

    def test_boundary_equal_to_prefix_minus_one(self):
        """Pin both sides of the boundary at ``session_msg_count - 1``.

        Previously this test repeated ``test_fresh_when_covers_full_prefix``
        verbatim, so an off-by-one in the comparison would go unnoticed.
        """
        at_boundary = TranscriptDownload(content="", message_count=5)
        assert is_transcript_stale(at_boundary, session_msg_count=6) is False

        # One message short of the required prefix (4 < 6-1) must be stale.
        just_below = TranscriptDownload(content="", message_count=4)
        assert is_transcript_stale(just_below, session_msg_count=6) is True
|
||||
|
||||
|
||||
class TestShouldUploadTranscript:
    """``should_upload_transcript`` decides whether the final upload happens."""

    def test_upload_allowed_for_user_with_coverage(self):
        decision = should_upload_transcript("user-1", True)
        assert decision is True

    def test_upload_skipped_when_no_user(self):
        decision = should_upload_transcript(None, True)
        assert decision is False

    def test_upload_skipped_when_empty_user(self):
        decision = should_upload_transcript("", True)
        assert decision is False

    def test_upload_skipped_without_coverage(self):
        """A partial transcript must never overwrite a more complete stored one."""
        decision = should_upload_transcript("user-1", False)
        assert decision is False

    def test_upload_skipped_when_no_user_and_no_coverage(self):
        decision = should_upload_transcript(None, False)
        assert decision is False
|
||||
|
||||
|
||||
class TestTranscriptLifecycle:
    """End-to-end: download → validate → build → upload.

    Simulates the full transcript lifecycle inside
    ``stream_chat_completion_baseline`` by mocking the storage layer and
    driving each step through the real helpers.
    """

    @pytest.mark.asyncio
    async def test_full_lifecycle_happy_path(self):
        """Fresh download, append a turn, upload covers the session."""
        builder = TranscriptBuilder()
        prior = _make_transcript_content("user", "assistant")
        download = TranscriptDownload(content=prior, message_count=2)

        upload_mock = AsyncMock(return_value=None)
        with (
            patch(
                "backend.copilot.baseline.service.download_transcript",
                new=AsyncMock(return_value=download),
            ),
            patch(
                "backend.copilot.baseline.service.upload_transcript",
                new=upload_mock,
            ),
        ):
            # --- 1. Download & load prior transcript ---
            covers = await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=3,
                transcript_builder=builder,
            )
            assert covers is True

            # --- 2. Append a new user turn + a new assistant response ---
            builder.append_user(content="follow-up question")
            _record_turn_to_transcript(
                LLMLoopResponse(
                    response_text="follow-up answer",
                    tool_calls=[],
                    raw_response=None,
                ),
                tool_results=None,
                transcript_builder=builder,
                model="test-model",
            )

            # --- 3. Gate + upload ---
            assert (
                should_upload_transcript(
                    user_id="user-1", transcript_covers_prefix=covers
                )
                is True
            )
            await _upload_final_transcript(
                user_id="user-1",
                session_id="session-1",
                transcript_builder=builder,
                session_msg_count=4,
            )

        upload_mock.assert_awaited_once()
        assert upload_mock.await_args is not None
        uploaded = upload_mock.await_args.kwargs["content"]
        assert "follow-up question" in uploaded
        assert "follow-up answer" in uploaded
        # Original prior-turn content preserved.
        assert "user message 0" in uploaded
        assert "assistant message 1" in uploaded

    @pytest.mark.asyncio
    async def test_lifecycle_stale_download_suppresses_upload(self):
        """Stale download → covers=False → upload must be skipped."""
        builder = TranscriptBuilder()
        # session has 10 msgs but stored transcript only covers 2 → stale.
        stale = TranscriptDownload(
            content=_make_transcript_content("user", "assistant"),
            message_count=2,
        )

        upload_mock = AsyncMock(return_value=None)
        with (
            patch(
                "backend.copilot.baseline.service.download_transcript",
                new=AsyncMock(return_value=stale),
            ),
            patch(
                "backend.copilot.baseline.service.upload_transcript",
                new=upload_mock,
            ),
        ):
            covers = await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=10,
                transcript_builder=builder,
            )

        assert covers is False
        # The caller's gate mirrors the production path.
        assert (
            should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers)
            is False
        )
        upload_mock.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_lifecycle_anonymous_user_skips_upload(self):
        """Anonymous (user_id=None) → upload gate must return False."""
        builder = TranscriptBuilder()
        builder.append_user(content="hi")
        builder.append_assistant(
            content_blocks=[{"type": "text", "text": "hello"}],
            model="test-model",
            stop_reason=STOP_REASON_END_TURN,
        )

        assert (
            should_upload_transcript(user_id=None, transcript_covers_prefix=True)
            is False
        )

    @pytest.mark.asyncio
    async def test_lifecycle_missing_download_still_uploads_new_content(self):
        """No prior transcript → covers is False → the upload gate stays closed.

        NOTE: despite the historical test name, nothing is uploaded here. The
        assertions below pin that a missing download yields ``covers=False``
        and that ``should_upload_transcript`` therefore refuses the upload.
        """
        builder = TranscriptBuilder()
        upload_mock = AsyncMock(return_value=None)
        with (
            patch(
                "backend.copilot.baseline.service.download_transcript",
                new=AsyncMock(return_value=None),
            ),
            patch(
                "backend.copilot.baseline.service.upload_transcript",
                new=upload_mock,
            ),
        ):
            covers = await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=1,
                transcript_builder=builder,
            )
            # No download: covers is False, so the production path would
            # skip upload. This protects against overwriting a future
            # more-complete transcript with a single-turn snapshot.
            assert covers is False
            assert (
                should_upload_transcript(
                    user_id="user-1", transcript_covers_prefix=covers
                )
                is False
            )
            upload_mock.assert_not_awaited()
|
||||
@@ -8,13 +8,26 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
|
||||
# Per-request routing mode for a single chat turn.
|
||||
# - 'fast': route to the baseline OpenAI-compatible path with the cheaper model.
|
||||
# - 'extended_thinking': route to the Claude Agent SDK path with the default
|
||||
# (opus) model.
|
||||
# ``None`` means "no override"; the server falls back to the Claude Code
|
||||
# subscription flag → LaunchDarkly COPILOT_SDK → config.use_claude_agent_sdk.
|
||||
CopilotMode = Literal["fast", "extended_thinking"]
|
||||
|
||||
|
||||
class ChatConfig(BaseSettings):
|
||||
"""Configuration for the chat system."""
|
||||
|
||||
# OpenAI API Configuration
|
||||
model: str = Field(
|
||||
default="anthropic/claude-opus-4.6", description="Default model to use"
|
||||
default="anthropic/claude-opus-4.6",
|
||||
description="Default model for extended thinking mode",
|
||||
)
|
||||
fast_model: str = Field(
|
||||
default="anthropic/claude-sonnet-4",
|
||||
description="Model for fast mode (baseline path). Should be faster/cheaper than the default model.",
|
||||
)
|
||||
title_model: str = Field(
|
||||
default="openai/gpt-4o-mini",
|
||||
|
||||
@@ -13,7 +13,7 @@ import time
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.config import ChatConfig, CopilotMode
|
||||
from backend.copilot.response_model import StreamError
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
@@ -30,6 +30,57 @@ from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
|
||||
|
||||
|
||||
# ============ Mode Routing ============ #
|
||||
|
||||
|
||||
async def resolve_effective_mode(
    mode: CopilotMode | None,
    user_id: str | None,
) -> CopilotMode | None:
    """Drop ``mode`` unless the ``CHAT_MODE_OPTION`` flag is on for the user.

    The UI hides the mode toggle behind ``CHAT_MODE_OPTION``; repeating the
    check here means a hand-crafted request cannot bypass that flag.

    Returns ``mode`` unchanged when the flag is enabled, otherwise ``None``.
    A ``None`` mode is returned as-is without consulting the flag.
    """
    if mode is not None:
        toggle_enabled = await is_feature_enabled(
            Flag.CHAT_MODE_OPTION,
            user_id or "anonymous",  # flag lookup key when there is no user id
            default=False,
        )
        if toggle_enabled:
            return mode
        logger.info(f"Ignoring mode={mode} — CHAT_MODE_OPTION is disabled for user")
    return None
|
||||
|
||||
|
||||
async def resolve_use_sdk_for_mode(
    mode: CopilotMode | None,
    user_id: str | None,
    *,
    use_claude_code_subscription: bool,
    config_default: bool,
) -> bool:
    """Decide whether a single turn runs on the SDK path or the baseline path.

    An explicit per-request ``mode`` takes precedence (callers are expected to
    have applied the ``CHAT_MODE_OPTION`` gate already). Without one, the
    Claude Code subscription override is honoured first, then the
    ``COPILOT_SDK`` LaunchDarkly flag, whose fallback is the config default.

    Args:
        mode: ``"fast"``, ``"extended_thinking"`` or ``None`` for no override.
        user_id: ID used for the flag lookup; ``"anonymous"`` is substituted
            when it is missing or empty.
        use_claude_code_subscription: Forces the SDK path when no mode is set.
        config_default: Passed as ``default`` to the ``COPILOT_SDK`` lookup.

    Returns:
        ``True`` for the SDK path, ``False`` for the baseline path.
    """
    # Explicit modes map directly: "extended_thinking" -> SDK, "fast" -> baseline.
    if mode in ("fast", "extended_thinking"):
        return mode == "extended_thinking"

    # Subscription override short-circuits the flag lookup entirely.
    if use_claude_code_subscription:
        return use_claude_code_subscription

    flag_subject = user_id or "anonymous"
    return await is_feature_enabled(
        Flag.COPILOT_SDK,
        flag_subject,
        default=config_default,
    )
|
||||
|
||||
|
||||
# ============ Module Entry Points ============ #
|
||||
|
||||
# Thread-local storage for processor instances
|
||||
@@ -250,21 +301,26 @@ class CoPilotProcessor:
|
||||
if config.test_mode:
|
||||
stream_fn = stream_chat_completion_dummy
|
||||
log.warning("Using DUMMY service (CHAT_TEST_MODE=true)")
|
||||
effective_mode = None
|
||||
else:
|
||||
use_sdk = (
|
||||
config.use_claude_code_subscription
|
||||
or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
)
|
||||
# Enforce server-side feature-flag gate so unauthorised
|
||||
# users cannot force a mode by crafting the request.
|
||||
effective_mode = await resolve_effective_mode(entry.mode, entry.user_id)
|
||||
use_sdk = await resolve_use_sdk_for_mode(
|
||||
effective_mode,
|
||||
entry.user_id,
|
||||
use_claude_code_subscription=config.use_claude_code_subscription,
|
||||
config_default=config.use_claude_agent_sdk,
|
||||
)
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else stream_chat_completion_baseline
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
log.info(
|
||||
f"Using {'SDK' if use_sdk else 'baseline'} service "
|
||||
f"(mode={effective_mode or 'default'})"
|
||||
)
|
||||
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
# stream_and_publish wraps the raw stream with registry
|
||||
@@ -276,6 +332,7 @@ class CoPilotProcessor:
|
||||
user_id=entry.user_id,
|
||||
context=entry.context,
|
||||
file_ids=entry.file_ids,
|
||||
mode=effective_mode,
|
||||
)
|
||||
async for chunk in stream_registry.stream_and_publish(
|
||||
session_id=entry.session_id,
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
"""Unit tests for CoPilot mode routing logic in the processor.
|
||||
|
||||
Tests cover the mode→service mapping:
|
||||
- 'fast' → baseline service
|
||||
- 'extended_thinking' → SDK service
|
||||
- None → feature flag / config fallback
|
||||
|
||||
as well as the ``CHAT_MODE_OPTION`` server-side gate. The tests import
|
||||
the real production helpers from ``processor.py`` so the routing logic
|
||||
has meaningful coverage.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.executor.processor import (
|
||||
resolve_effective_mode,
|
||||
resolve_use_sdk_for_mode,
|
||||
)
|
||||
|
||||
|
||||
class TestResolveUseSdkForMode:
    """Tests for the per-request mode routing logic.

    Every test patches ``is_feature_enabled`` as seen from the processor
    module, so no real flag lookup is performed.
    """

    # Patch target: the name as bound inside the module under test.
    _FLAG_TARGET = "backend.copilot.executor.processor.is_feature_enabled"

    @pytest.mark.asyncio
    async def test_fast_mode_uses_baseline(self):
        """mode='fast' always routes to baseline, regardless of flags."""
        with patch(self._FLAG_TARGET, new=AsyncMock(return_value=True)) as flag_mock:
            assert (
                await resolve_use_sdk_for_mode(
                    "fast",
                    "user-1",
                    use_claude_code_subscription=True,
                    config_default=True,
                )
                is False
            )
        # "Regardless of flags" means the flag is never even consulted.
        flag_mock.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_extended_thinking_uses_sdk(self):
        """mode='extended_thinking' always routes to SDK, regardless of flags."""
        with patch(self._FLAG_TARGET, new=AsyncMock(return_value=False)) as flag_mock:
            assert (
                await resolve_use_sdk_for_mode(
                    "extended_thinking",
                    "user-1",
                    use_claude_code_subscription=False,
                    config_default=False,
                )
                is True
            )
        flag_mock.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_none_mode_uses_subscription_override(self):
        """mode=None with claude_code_subscription=True routes to SDK."""
        with patch(self._FLAG_TARGET, new=AsyncMock(return_value=False)) as flag_mock:
            assert (
                await resolve_use_sdk_for_mode(
                    None,
                    "user-1",
                    use_claude_code_subscription=True,
                    config_default=False,
                )
                is True
            )
        # The subscription override wins before the flag lookup happens.
        flag_mock.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_none_mode_uses_feature_flag(self):
        """mode=None with feature flag enabled routes to SDK."""
        with patch(self._FLAG_TARGET, new=AsyncMock(return_value=True)) as flag_mock:
            assert (
                await resolve_use_sdk_for_mode(
                    None,
                    "user-1",
                    use_claude_code_subscription=False,
                    config_default=False,
                )
                is True
            )
        flag_mock.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_none_mode_uses_config_default(self):
        """mode=None falls back to config.use_claude_agent_sdk.

        The mocked flag lookup echoes back whatever ``default`` it receives,
        so the result only matches when ``config_default`` is really
        forwarded as the lookup's fallback value.
        """
        for config_default in (True, False):
            echo_default = AsyncMock(side_effect=lambda *_args, default: default)
            with patch(self._FLAG_TARGET, new=echo_default) as flag_mock:
                assert (
                    await resolve_use_sdk_for_mode(
                        None,
                        "user-1",
                        use_claude_code_subscription=False,
                        config_default=config_default,
                    )
                    is config_default
                )
            flag_mock.assert_awaited_once()
            assert flag_mock.await_args.kwargs["default"] is config_default

    @pytest.mark.asyncio
    async def test_none_mode_all_disabled(self):
        """mode=None with all flags off routes to baseline."""
        with patch(self._FLAG_TARGET, new=AsyncMock(return_value=False)):
            assert (
                await resolve_use_sdk_for_mode(
                    None,
                    "user-1",
                    use_claude_code_subscription=False,
                    config_default=False,
                )
                is False
            )
|
||||
|
||||
|
||||
class TestResolveEffectiveMode:
    """Tests for the CHAT_MODE_OPTION server-side gate.

    Every test patches ``is_feature_enabled`` as seen from the processor
    module, so no real flag lookup is performed.
    """

    # Patch target: the name as bound inside the module under test.
    _FLAG_TARGET = "backend.copilot.executor.processor.is_feature_enabled"

    @pytest.mark.asyncio
    async def test_none_mode_passes_through(self):
        """mode=None is returned as-is without a flag check."""
        with patch(self._FLAG_TARGET, new=AsyncMock(return_value=False)) as flag_mock:
            assert await resolve_effective_mode(None, "user-1") is None
        flag_mock.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_mode_stripped_when_flag_disabled(self):
        """When CHAT_MODE_OPTION is off, mode is dropped to None."""
        with patch(self._FLAG_TARGET, new=AsyncMock(return_value=False)) as flag_mock:
            assert await resolve_effective_mode("fast", "user-1") is None
            assert await resolve_effective_mode("extended_thinking", "user-1") is None
        # One gate lookup per call that carried a mode.
        assert flag_mock.await_count == 2

    @pytest.mark.asyncio
    async def test_mode_preserved_when_flag_enabled(self):
        """When CHAT_MODE_OPTION is on, the user-selected mode is preserved."""
        with patch(self._FLAG_TARGET, new=AsyncMock(return_value=True)):
            assert await resolve_effective_mode("fast", "user-1") == "fast"
            assert (
                await resolve_effective_mode("extended_thinking", "user-1")
                == "extended_thinking"
            )

    @pytest.mark.asyncio
    async def test_anonymous_user_with_mode(self):
        """Anonymous users (user_id=None) still pass through the gate."""
        with patch(self._FLAG_TARGET, new=AsyncMock(return_value=False)) as flag_mock:
            assert await resolve_effective_mode("fast", None) is None
        flag_mock.assert_awaited_once()
        # A missing user id must be looked up under the "anonymous" key and
        # with a closed (False) fallback.
        assert flag_mock.await_args.args[1] == "anonymous"
        assert flag_mock.await_args.kwargs["default"] is False
|
||||
@@ -9,6 +9,7 @@ import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.config import CopilotMode
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
|
||||
|
||||
@@ -156,6 +157,9 @@ class CoPilotExecutionEntry(BaseModel):
|
||||
file_ids: list[str] | None = None
|
||||
"""Workspace file IDs attached to the user's message"""
|
||||
|
||||
mode: CopilotMode | None = None
|
||||
"""Autopilot mode override: 'fast' or 'extended_thinking'. None = server default."""
|
||||
|
||||
|
||||
class CancelCoPilotEvent(BaseModel):
|
||||
"""Event to cancel a CoPilot operation."""
|
||||
@@ -175,6 +179,7 @@ async def enqueue_copilot_turn(
|
||||
is_user_message: bool = True,
|
||||
context: dict[str, str] | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
mode: CopilotMode | None = None,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
@@ -186,6 +191,7 @@ async def enqueue_copilot_turn(
|
||||
is_user_message: Whether the message is from the user (vs system/assistant)
|
||||
context: Optional context for the message (e.g., {url: str, content: str})
|
||||
file_ids: Optional workspace file IDs attached to the user's message
|
||||
mode: Autopilot mode override ('fast' or 'extended_thinking'). None = server default.
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
@@ -197,6 +203,7 @@ async def enqueue_copilot_turn(
|
||||
is_user_message=is_user_message,
|
||||
context=context,
|
||||
file_ids=file_ids,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
queue_client = await get_async_copilot_queue()
|
||||
|
||||
123
autogpt_platform/backend/backend/copilot/executor/utils_test.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Tests for CoPilot executor utils (queue config, message models, logging)."""
|
||||
|
||||
from backend.copilot.executor.utils import (
|
||||
COPILOT_EXECUTION_EXCHANGE,
|
||||
COPILOT_EXECUTION_QUEUE_NAME,
|
||||
COPILOT_EXECUTION_ROUTING_KEY,
|
||||
CancelCoPilotEvent,
|
||||
CoPilotExecutionEntry,
|
||||
CoPilotLogMetadata,
|
||||
create_copilot_queue_config,
|
||||
)
|
||||
|
||||
|
||||
class TestCoPilotExecutionEntry:
    """Defaults, explicit values and JSON round-tripping of queue entries."""

    def test_basic_fields(self):
        """Required fields are stored; optional ones keep their defaults."""
        minimal = CoPilotExecutionEntry(session_id="s1", user_id="u1", message="hello")

        assert minimal.session_id == "s1"
        assert minimal.user_id == "u1"
        assert minimal.message == "hello"
        assert minimal.is_user_message is True
        # Everything not supplied defaults to None.
        for optional_value in (minimal.mode, minimal.context, minimal.file_ids):
            assert optional_value is None

    def test_mode_field(self):
        """Both known mode strings are accepted and stored unchanged."""
        for requested in ("fast", "extended_thinking"):
            with_mode = CoPilotExecutionEntry(
                session_id="s1",
                user_id="u1",
                message="test",
                mode=requested,
            )
            assert with_mode.mode == requested

    def test_optional_fields(self):
        """Explicitly supplied optional fields are stored as given."""
        supplied = {
            "turn_id": "t1",
            "context": {"url": "https://example.com"},
            "file_ids": ["f1", "f2"],
            "is_user_message": False,
        }
        full = CoPilotExecutionEntry(
            session_id="s1",
            user_id="u1",
            message="test",
            **supplied,
        )

        assert full.turn_id == "t1"
        assert full.context == {"url": "https://example.com"}
        assert full.file_ids == ["f1", "f2"]
        assert full.is_user_message is False

    def test_serialization_roundtrip(self):
        """Dumping to JSON and validating back yields an equal entry."""
        original = CoPilotExecutionEntry(
            session_id="s1",
            user_id="u1",
            message="hello",
            mode="fast",
        )
        payload = original.model_dump_json()
        assert CoPilotExecutionEntry.model_validate_json(payload) == original
|
||||
|
||||
|
||||
class TestCancelCoPilotEvent:
    """Construction and JSON round-tripping of the cancel event."""

    def test_basic(self):
        """The session id passed to the constructor is stored."""
        assert CancelCoPilotEvent(session_id="s1").session_id == "s1"

    def test_serialization(self):
        """The session id survives a dump-to-JSON / validate cycle."""
        payload = CancelCoPilotEvent(session_id="s1").model_dump_json()
        roundtripped = CancelCoPilotEvent.model_validate_json(payload)
        assert roundtripped.session_id == "s1"
|
||||
|
||||
|
||||
class TestCreateCopilotQueueConfig:
    """Shape of the RabbitMQ configuration built for the CoPilot executor."""

    def test_returns_valid_config(self):
        """The config declares two exchanges and two queues."""
        queue_config = create_copilot_queue_config()
        assert len(queue_config.exchanges) == 2
        assert len(queue_config.queues) == 2

    def test_execution_queue_properties(self):
        """The execution queue is durable and bound with the expected keys."""
        queues = create_copilot_queue_config().queues
        execution = next(
            queue for queue in queues if queue.name == COPILOT_EXECUTION_QUEUE_NAME
        )
        assert execution.durable is True
        assert execution.exchange == COPILOT_EXECUTION_EXCHANGE
        assert execution.routing_key == COPILOT_EXECUTION_ROUTING_KEY

    def test_cancel_queue_uses_fanout(self):
        """The first queue that is not the execution queue sits on a fanout exchange."""
        queues = create_copilot_queue_config().queues
        cancel = next(
            queue for queue in queues if queue.name != COPILOT_EXECUTION_QUEUE_NAME
        )
        assert cancel.exchange is not None
        assert cancel.exchange.type.value == "fanout"
|
||||
|
||||
|
||||
class TestCoPilotLogMetadata:
    """Smoke tests: the metadata logger can be constructed."""

    def test_creates_logger_with_metadata(self):
        """Construction with session and user ids yields an object."""
        import logging

        wrapped = CoPilotLogMetadata(
            logging.getLogger("test"),
            session_id="s1",
            user_id="u1",
        )
        assert wrapped is not None

    def test_filters_none_values(self):
        """Construction still succeeds when one metadata value is None."""
        import logging

        metadata = {"session_id": "s1", "user_id": None, "turn_id": "t1"}
        wrapped = CoPilotLogMetadata(logging.getLogger("test"), **metadata)
        assert wrapped is not None
|
||||
@@ -13,12 +13,21 @@ from .rate_limit import (
|
||||
RateLimitExceeded,
|
||||
SubscriptionTier,
|
||||
UsageWindow,
|
||||
_daily_key,
|
||||
_daily_reset_time,
|
||||
_weekly_key,
|
||||
_weekly_reset_time,
|
||||
acquire_reset_lock,
|
||||
check_rate_limit,
|
||||
get_daily_reset_count,
|
||||
get_global_rate_limits,
|
||||
get_usage_status,
|
||||
get_user_tier,
|
||||
increment_daily_reset_count,
|
||||
record_token_usage,
|
||||
release_reset_lock,
|
||||
reset_daily_usage,
|
||||
reset_user_usage,
|
||||
set_user_tier,
|
||||
)
|
||||
|
||||
@@ -1210,3 +1219,205 @@ class TestTierLimitsEnforced:
|
||||
assert daily == biz_daily # 20x
|
||||
# Should NOT raise — usage is within the BUSINESS tier allowance
|
||||
await check_rate_limit(_USER, daily, weekly)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private key/reset helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKeyHelpers:
    """Key naming and reset-time computation of the private rate-limit helpers."""

    def test_daily_key_format(self):
        """The daily key mentions its kind, the user and the ISO date."""
        fixed_now = datetime(2026, 4, 3, 12, 0, 0, tzinfo=UTC)
        daily = _daily_key("user-1", now=fixed_now)
        for fragment in ("daily", "user-1", "2026-04-03"):
            assert fragment in daily

    def test_daily_key_defaults_to_now(self):
        """Omitting ``now`` still produces a daily key for the user."""
        daily = _daily_key("user-1")
        for fragment in ("daily", "user-1"):
            assert fragment in daily

    def test_weekly_key_format(self):
        """The weekly key mentions its kind, the user and a year-week tag."""
        fixed_now = datetime(2026, 4, 3, 12, 0, 0, tzinfo=UTC)
        weekly = _weekly_key("user-1", now=fixed_now)
        for fragment in ("weekly", "user-1", "2026-W"):
            assert fragment in weekly

    def test_weekly_key_defaults_to_now(self):
        """Omitting ``now`` still produces a weekly key."""
        assert "weekly" in _weekly_key("user-1")

    def test_daily_reset_time_is_next_midnight(self):
        """A mid-afternoon timestamp resets at the following midnight (UTC)."""
        afternoon = datetime(2026, 4, 3, 15, 30, 0, tzinfo=UTC)
        expected = datetime(2026, 4, 4, 0, 0, 0, tzinfo=UTC)
        assert _daily_reset_time(now=afternoon) == expected

    def test_daily_reset_time_defaults_to_now(self):
        """Without ``now`` the reset time still falls on 00:00."""
        upcoming = _daily_reset_time()
        assert upcoming.hour == 0
        assert upcoming.minute == 0

    def test_weekly_reset_time_is_next_monday(self):
        """From Friday 2026-04-03 the weekly reset is Monday 2026-04-06 00:00 UTC."""
        friday = datetime(2026, 4, 3, 15, 30, 0, tzinfo=UTC)
        upcoming = _weekly_reset_time(now=friday)
        assert upcoming.weekday() == 0  # Monday
        assert upcoming == datetime(2026, 4, 6, 0, 0, 0, tzinfo=UTC)

    def test_weekly_reset_time_defaults_to_now(self):
        """Without ``now`` the reset time still falls on a Monday."""
        assert _weekly_reset_time().weekday() == 0  # Monday
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# acquire_reset_lock / release_reset_lock
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResetLock:
    """Behaviour of acquire_reset_lock / release_reset_lock with Redis mocked out."""

    @pytest.mark.asyncio
    async def test_acquire_lock_success(self):
        """A truthy result from Redis ``set`` means the lock was acquired."""
        redis_stub = AsyncMock()
        redis_stub.set = AsyncMock(return_value=True)
        with patch(
            "backend.copilot.rate_limit.get_redis_async", return_value=redis_stub
        ):
            acquired = await acquire_reset_lock("user-1")
        assert acquired is True

    @pytest.mark.asyncio
    async def test_acquire_lock_already_held(self):
        """A falsy result from Redis ``set`` means the lock was not acquired."""
        redis_stub = AsyncMock()
        redis_stub.set = AsyncMock(return_value=False)
        with patch(
            "backend.copilot.rate_limit.get_redis_async", return_value=redis_stub
        ):
            acquired = await acquire_reset_lock("user-1")
        assert acquired is False

    @pytest.mark.asyncio
    async def test_acquire_lock_redis_unavailable(self):
        """A RedisError while obtaining the client yields False."""
        with patch(
            "backend.copilot.rate_limit.get_redis_async",
            side_effect=RedisError("down"),
        ):
            acquired = await acquire_reset_lock("user-1")
        assert acquired is False

    @pytest.mark.asyncio
    async def test_release_lock_success(self):
        """Releasing the lock issues exactly one Redis ``delete``."""
        redis_stub = AsyncMock()
        with patch(
            "backend.copilot.rate_limit.get_redis_async", return_value=redis_stub
        ):
            await release_reset_lock("user-1")
        redis_stub.delete.assert_called_once()

    @pytest.mark.asyncio
    async def test_release_lock_redis_unavailable(self):
        """A RedisError while obtaining the client does not propagate."""
        with patch(
            "backend.copilot.rate_limit.get_redis_async",
            side_effect=RedisError("down"),
        ):
            # Should not raise
            await release_reset_lock("user-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_daily_reset_count / increment_daily_reset_count
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDailyResetCount:
    """Reading and incrementing the per-user daily reset counter (Redis mocked)."""

    @pytest.mark.asyncio
    async def test_get_count_returns_value(self):
        """A stored string value is returned as an integer."""
        redis_stub = AsyncMock()
        redis_stub.get = AsyncMock(return_value="3")
        with patch(
            "backend.copilot.rate_limit.get_redis_async", return_value=redis_stub
        ):
            observed = await get_daily_reset_count("user-1")
        assert observed == 3

    @pytest.mark.asyncio
    async def test_get_count_returns_zero_when_no_key(self):
        """A missing key (Redis returns None) counts as zero."""
        redis_stub = AsyncMock()
        redis_stub.get = AsyncMock(return_value=None)
        with patch(
            "backend.copilot.rate_limit.get_redis_async", return_value=redis_stub
        ):
            observed = await get_daily_reset_count("user-1")
        assert observed == 0

    @pytest.mark.asyncio
    async def test_get_count_returns_none_when_redis_unavailable(self):
        """A RedisError while obtaining the client yields None."""
        with patch(
            "backend.copilot.rate_limit.get_redis_async",
            side_effect=RedisError("down"),
        ):
            observed = await get_daily_reset_count("user-1")
        assert observed is None

    @pytest.mark.asyncio
    async def test_increment_count(self):
        """Incrementing queues one ``incr`` and one ``expire`` on the pipeline."""
        # The pipeline's queueing calls are synchronous; only execute() is awaited.
        pipeline_stub = MagicMock(
            incr=MagicMock(),
            expire=MagicMock(),
            execute=AsyncMock(),
        )
        redis_stub = AsyncMock(pipeline=MagicMock(return_value=pipeline_stub))

        with patch(
            "backend.copilot.rate_limit.get_redis_async", return_value=redis_stub
        ):
            await increment_daily_reset_count("user-1")

        pipeline_stub.incr.assert_called_once()
        pipeline_stub.expire.assert_called_once()

    @pytest.mark.asyncio
    async def test_increment_count_redis_unavailable(self):
        """A RedisError while obtaining the client does not propagate."""
        with patch(
            "backend.copilot.rate_limit.get_redis_async",
            side_effect=RedisError("down"),
        ):
            # Should not raise
            await increment_daily_reset_count("user-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reset_user_usage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResetUserUsage:
    """reset_user_usage: which keys are deleted and how Redis failures surface."""

    @pytest.mark.asyncio
    async def test_resets_daily_key(self):
        """The default call issues exactly one Redis ``delete``."""
        redis_stub = AsyncMock()
        with patch(
            "backend.copilot.rate_limit.get_redis_async", return_value=redis_stub
        ):
            await reset_user_usage("user-1")
        redis_stub.delete.assert_called_once()

    @pytest.mark.asyncio
    async def test_resets_daily_and_weekly(self):
        """With ``reset_weekly=True`` the delete call receives two keys."""
        redis_stub = AsyncMock()
        with patch(
            "backend.copilot.rate_limit.get_redis_async", return_value=redis_stub
        ):
            await reset_user_usage("user-1", reset_weekly=True)
        deleted_keys = redis_stub.delete.call_args.args
        assert len(deleted_keys) == 2  # both daily and weekly keys

    @pytest.mark.asyncio
    async def test_raises_on_redis_failure(self):
        """Unlike the best-effort helpers, a RedisError propagates to the caller."""
        with patch(
            "backend.copilot.rate_limit.get_redis_async",
            side_effect=RedisError("down"),
        ), pytest.raises(RedisError):
            await reset_user_usage("user-1")
|
||||
|
||||
@@ -8,20 +8,19 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _friendly_error_text, _is_prompt_too_long
|
||||
from .transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_run_compression,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
from backend.util import json
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _friendly_error_text, _is_prompt_too_long
|
||||
from .transcript import compact_transcript, validate_transcript
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_assistant_content
|
||||
@@ -403,7 +402,7 @@ class TestCompactTranscript:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
@@ -438,7 +437,7 @@ class TestCompactTranscript:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
@@ -462,7 +461,7 @@ class TestCompactTranscript:
|
||||
]
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("LLM unavailable"),
|
||||
):
|
||||
@@ -568,11 +567,11 @@ class TestRunCompressionTimeout:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
"backend.copilot.transcript.get_openai_client",
|
||||
return_value="fake-client",
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
"backend.copilot.transcript.compress_context",
|
||||
side_effect=_mock_compress,
|
||||
),
|
||||
):
|
||||
@@ -602,11 +601,11 @@ class TestRunCompressionTimeout:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
"backend.copilot.transcript.get_openai_client",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
"backend.copilot.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=truncation_result,
|
||||
) as mock_compress,
|
||||
|
||||
@@ -26,18 +26,17 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _MAX_STREAM_ATTEMPTS, _reduce_context
|
||||
from .transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _MAX_STREAM_ATTEMPTS, _reduce_context
|
||||
from .transcript import compact_transcript, validate_transcript
|
||||
from .transcript_builder import TranscriptBuilder
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -113,7 +112,7 @@ class TestScenarioCompactAndRetry:
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
),
|
||||
@@ -170,7 +169,7 @@ class TestScenarioCompactFailsFallback:
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("LLM unavailable"),
|
||||
),
|
||||
@@ -261,7 +260,7 @@ class TestScenarioDoubleFailDBFallback:
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
),
|
||||
@@ -337,7 +336,7 @@ class TestScenarioCompactionIdentical:
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
),
|
||||
@@ -730,7 +729,7 @@ class TestRetryEdgeCases:
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
),
|
||||
@@ -841,7 +840,7 @@ class TestRetryStateReset:
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("boom"),
|
||||
),
|
||||
@@ -1405,9 +1404,9 @@ class TestStreamChatCompletionRetryIntegration:
|
||||
events.append(event)
|
||||
|
||||
# Should NOT retry — only 1 attempt for auth errors
|
||||
assert attempt_count[0] == 1, (
|
||||
f"Expected 1 attempt (no retry for auth error), " f"got {attempt_count[0]}"
|
||||
)
|
||||
assert (
|
||||
attempt_count[0] == 1
|
||||
), f"Expected 1 attempt (no retry for auth error), got {attempt_count[0]}"
|
||||
errors = [e for e in events if isinstance(e, StreamError)]
|
||||
assert errors, "Expected StreamError"
|
||||
assert errors[0].code == "sdk_stream_error"
|
||||
|
||||
@@ -34,6 +34,17 @@ from pydantic import BaseModel
|
||||
from backend.copilot.context import get_workspace_manager
|
||||
from backend.copilot.permissions import apply_tool_permissions
|
||||
from backend.copilot.rate_limit import get_user_tier
|
||||
from backend.copilot.transcript import (
|
||||
_run_compression,
|
||||
cleanup_stale_project_dirs,
|
||||
compact_transcript,
|
||||
download_transcript,
|
||||
read_compacted_entries,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
)
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -94,17 +105,6 @@ from .tool_adapter import (
|
||||
set_execution_context,
|
||||
wait_for_stash,
|
||||
)
|
||||
from .transcript import (
|
||||
_run_compression,
|
||||
cleanup_stale_project_dirs,
|
||||
compact_transcript,
|
||||
download_transcript,
|
||||
read_compacted_entries,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
)
|
||||
from .transcript_builder import TranscriptBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -27,20 +27,19 @@ from backend.copilot.response_model import (
|
||||
StreamTextDelta,
|
||||
StreamTextStart,
|
||||
)
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_structured_transcript
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .service import _format_sdk_content_blocks
|
||||
from .transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_find_last_assistant_entry,
|
||||
_flatten_assistant_content,
|
||||
_messages_to_transcript,
|
||||
_rechain_tail,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_structured_transcript
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .service import _format_sdk_content_blocks
|
||||
from .transcript import compact_transcript, validate_transcript
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: realistic thinking block content
|
||||
@@ -439,7 +438,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
@@ -498,7 +497,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
)()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
side_effect=mock_compression,
|
||||
):
|
||||
await compact_transcript(transcript, model="test-model")
|
||||
@@ -551,7 +550,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
@@ -601,7 +600,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
@@ -638,7 +637,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
@@ -699,7 +698,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
|
||||
@@ -1,235 +1,10 @@
|
||||
"""Build complete JSONL transcript from SDK messages.
|
||||
"""Re-export from shared ``backend.copilot.transcript_builder`` for backward compat.
|
||||
|
||||
The transcript represents the FULL active context at any point in time.
|
||||
Each upload REPLACES the previous transcript atomically.
|
||||
|
||||
Flow:
|
||||
Turn 1: Upload [msg1, msg2]
|
||||
Turn 2: Download [msg1, msg2] → Upload [msg1, msg2, msg3, msg4] (REPLACE)
|
||||
Turn 3: Download [msg1, msg2, msg3, msg4] → Upload [all messages] (REPLACE)
|
||||
|
||||
The transcript is never incremental - always the complete atomic state.
|
||||
The canonical implementation now lives at ``backend.copilot.transcript_builder``
|
||||
so both the SDK and baseline paths can import without cross-package
|
||||
dependencies.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder, TranscriptEntry
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .transcript import STRIPPABLE_TYPES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranscriptEntry(BaseModel):
    """One line of the JSONL transcript (a user or assistant turn)."""

    # Entry kind, e.g. "user" or "assistant".
    type: str
    # Identifier of this entry.
    uuid: str
    # uuid of the preceding entry; ``None`` when there is no predecessor.
    parentUuid: str | None
    # Set on compaction-summary entries; ``None`` when not specified.
    isCompactSummary: bool | None = None
    # The message payload carried by this entry.
    message: dict[str, Any]
|
||||
|
||||
|
||||
class TranscriptBuilder:
    """Build complete JSONL transcript from SDK messages.

    This builder maintains the FULL conversation state, not incremental changes.
    The output is always the complete active context.
    """

    def __init__(self) -> None:
        self._entries: list[TranscriptEntry] = []
        # uuid of the newest entry; becomes ``parentUuid`` of the next append.
        self._last_uuid: str | None = None

    def _last_is_assistant(self) -> bool:
        """Whether the newest entry is an assistant entry."""
        return bool(self._entries) and self._entries[-1].type == "assistant"

    def _last_message_id(self) -> str:
        """Return the message.id of the last entry, or '' if none."""
        if self._entries:
            return self._entries[-1].message.get("id", "")
        return ""

    @staticmethod
    def _parse_entry(data: dict) -> TranscriptEntry | None:
        """Parse a single transcript entry, filtering strippable types.

        Returns ``None`` for entries that should be skipped: strippable types
        that are not compaction summaries, and values that are not JSON
        objects at all (e.g. a list or a bare string).
        """
        # A decoded JSON document is not necessarily an object; ``.get`` would
        # raise AttributeError on anything else, so reject it up front.
        if not isinstance(data, dict):
            return None
        entry_type = data.get("type", "")
        if entry_type in STRIPPABLE_TYPES and not data.get("isCompactSummary"):
            return None
        return TranscriptEntry(
            type=entry_type,
            # Entries without a uuid get a fresh one so the chain stays linkable.
            uuid=data.get("uuid") or str(uuid4()),
            parentUuid=data.get("parentUuid"),
            isCompactSummary=data.get("isCompactSummary"),
            message=data.get("message", {}),
        )

    def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
        """Load complete previous transcript.

        This loads the FULL previous context. As new messages come in,
        we append to this state. The final output is the complete context
        (previous + new), not just the delta.

        Lines that cannot be decoded into a JSON object are logged and
        skipped instead of aborting the load.
        """
        if not content or not content.strip():
            return

        lines = content.strip().split("\n")
        for line_num, line in enumerate(lines, 1):
            if not line.strip():
                continue

            data = json.loads(line, fallback=None)
            # ``None`` is the fallback for an unparseable line; any other
            # non-dict value is valid JSON that cannot be a transcript entry.
            if not isinstance(data, dict):
                logger.warning(
                    "%s Failed to parse transcript line %d/%d",
                    log_prefix,
                    line_num,
                    len(lines),
                )
                continue

            entry = self._parse_entry(data)
            if entry is None:
                continue
            self._entries.append(entry)
            self._last_uuid = entry.uuid

        logger.info(
            "%s Loaded %d entries from previous transcript (last_uuid=%s)",
            log_prefix,
            len(self._entries),
            self._last_uuid[:12] if self._last_uuid else None,
        )

    def append_user(self, content: str | list[dict], uuid: str | None = None) -> None:
        """Append a user entry, chained to the current last entry."""
        msg_uuid = uuid or str(uuid4())

        self._entries.append(
            TranscriptEntry(
                type="user",
                uuid=msg_uuid,
                parentUuid=self._last_uuid,
                message={"role": "user", "content": content},
            )
        )
        self._last_uuid = msg_uuid

    def append_tool_result(self, tool_use_id: str, content: str) -> None:
        """Append a tool result as a user entry (one per tool call)."""
        self.append_user(
            content=[
                {"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
            ]
        )

    def append_assistant(
        self,
        content_blocks: list[dict],
        model: str = "",
        stop_reason: str | None = None,
    ) -> None:
        """Append an assistant entry.

        Consecutive assistant entries automatically share the same message ID
        so the CLI can merge them (thinking → text → tool_use) into a single
        API message on ``--resume``.  A new ID is assigned whenever an
        assistant entry follows a non-assistant entry (user message or tool
        result), because that marks the start of a new API response.
        """
        message_id = (
            self._last_message_id()
            if self._last_is_assistant()
            else f"msg_sdk_{uuid4().hex[:24]}"
        )

        msg_uuid = str(uuid4())

        self._entries.append(
            TranscriptEntry(
                type="assistant",
                uuid=msg_uuid,
                parentUuid=self._last_uuid,
                message={
                    "role": "assistant",
                    "model": model,
                    "id": message_id,
                    "type": "message",
                    "content": content_blocks,
                    "stop_reason": stop_reason,
                    "stop_sequence": None,
                },
            )
        )
        self._last_uuid = msg_uuid

    def replace_entries(
        self, compacted_entries: list[dict], log_prefix: str = "[Transcript]"
    ) -> None:
        """Replace all entries with compacted entries from the CLI session file.

        Called after mid-stream compaction so TranscriptBuilder mirrors the
        CLI's active context (compaction summary + post-compaction entries).

        Builds the new list first and validates it's non-empty before swapping,
        so corrupt input cannot wipe the conversation history.
        """
        parsed = (self._parse_entry(data) for data in compacted_entries)
        new_entries: list[TranscriptEntry] = [
            entry for entry in parsed if entry is not None
        ]

        if not new_entries:
            logger.warning(
                "%s replace_entries produced 0 entries from %d inputs, keeping old (%d entries)",
                log_prefix,
                len(compacted_entries),
                len(self._entries),
            )
            return

        old_count = len(self._entries)
        self._entries = new_entries
        self._last_uuid = new_entries[-1].uuid

        logger.info(
            "%s TranscriptBuilder compacted: %d entries -> %d entries",
            log_prefix,
            old_count,
            len(self._entries),
        )

    def to_jsonl(self) -> str:
        """Export complete context as JSONL.

        Consecutive assistant entries are kept separate to match the
        native CLI format — the SDK merges them internally on resume.

        Returns the FULL conversation state (all entries), not incremental.
        This output REPLACES any previous transcript.
        """
        if not self._entries:
            return ""

        # ``exclude_none`` drops unset optional fields from each line.
        lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
        return "\n".join(lines) + "\n"

    @property
    def entry_count(self) -> int:
        """Total number of entries in the complete context."""
        return len(self._entries)

    @property
    def is_empty(self) -> bool:
        """Whether this builder has any entries."""
        return len(self._entries) == 0
|
||||
__all__ = ["TranscriptBuilder", "TranscriptEntry"]
|
||||
|
||||
@@ -303,7 +303,7 @@ class TestDeleteTranscript:
|
||||
mock_storage.delete = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -323,7 +323,7 @@ class TestDeleteTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -341,7 +341,7 @@ class TestDeleteTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -850,7 +850,7 @@ class TestRunCompression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_client_uses_truncation(self):
|
||||
"""Path (a): ``get_openai_client()`` returns None → truncation only."""
|
||||
from .transcript import _run_compression
|
||||
from backend.copilot.transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated"}]
|
||||
@@ -858,11 +858,11 @@ class TestRunCompression:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
"backend.copilot.transcript.get_openai_client",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
"backend.copilot.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=truncation_result,
|
||||
) as mock_compress,
|
||||
@@ -885,7 +885,7 @@ class TestRunCompression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_success_returns_llm_result(self):
|
||||
"""Path (b): ``get_openai_client()`` returns a client → LLM compresses."""
|
||||
from .transcript import _run_compression
|
||||
from backend.copilot.transcript import _run_compression
|
||||
|
||||
llm_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "LLM summary"}]
|
||||
@@ -894,11 +894,11 @@ class TestRunCompression:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
"backend.copilot.transcript.get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
"backend.copilot.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=llm_result,
|
||||
) as mock_compress,
|
||||
@@ -916,7 +916,7 @@ class TestRunCompression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_failure_falls_back_to_truncation(self):
|
||||
"""Path (c): LLM call raises → truncation fallback used instead."""
|
||||
from .transcript import _run_compression
|
||||
from backend.copilot.transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated fallback"}]
|
||||
@@ -932,11 +932,11 @@ class TestRunCompression:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
"backend.copilot.transcript.get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
"backend.copilot.transcript.compress_context",
|
||||
side_effect=_compress_side_effect,
|
||||
),
|
||||
):
|
||||
@@ -953,7 +953,7 @@ class TestRunCompression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_timeout_falls_back_to_truncation(self):
|
||||
"""Path (d): LLM call exceeds timeout → truncation fallback used."""
|
||||
from .transcript import _run_compression
|
||||
from backend.copilot.transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated after timeout"}]
|
||||
@@ -970,19 +970,19 @@ class TestRunCompression:
|
||||
fake_client = MagicMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
"backend.copilot.transcript.get_openai_client",
|
||||
return_value=fake_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
"backend.copilot.transcript.compress_context",
|
||||
side_effect=_compress_side_effect,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._COMPACTION_TIMEOUT_SECONDS",
|
||||
"backend.copilot.transcript._COMPACTION_TIMEOUT_SECONDS",
|
||||
0.05,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._TRUNCATION_TIMEOUT_SECONDS",
|
||||
"backend.copilot.transcript._TRUNCATION_TIMEOUT_SECONDS",
|
||||
5,
|
||||
),
|
||||
):
|
||||
@@ -1007,7 +1007,7 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_removes_old_copilot_dirs(self, tmp_path, monkeypatch):
|
||||
"""Directories matching copilot pattern older than threshold are removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
@@ -1015,7 +1015,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
"backend.copilot.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1039,12 +1039,12 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_ignores_non_copilot_dirs(self, tmp_path, monkeypatch):
|
||||
"""Directories not matching copilot pattern are left alone."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
from backend.copilot.transcript import cleanup_stale_project_dirs
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
"backend.copilot.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1062,7 +1062,7 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_ttl_boundary_not_removed(self, tmp_path, monkeypatch):
|
||||
"""A directory exactly at the TTL boundary should NOT be removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
@@ -1070,7 +1070,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
"backend.copilot.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1088,7 +1088,7 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_skips_non_directory_entries(self, tmp_path, monkeypatch):
|
||||
"""Regular files matching the copilot pattern are not removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
@@ -1096,7 +1096,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
"backend.copilot.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1114,11 +1114,11 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_missing_base_dir_returns_zero(self, tmp_path, monkeypatch):
|
||||
"""If the projects base directory doesn't exist, return 0 gracefully."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
from backend.copilot.transcript import cleanup_stale_project_dirs
|
||||
|
||||
nonexistent = str(tmp_path / "does-not-exist" / "projects")
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
"backend.copilot.transcript._projects_base",
|
||||
lambda: nonexistent,
|
||||
)
|
||||
|
||||
@@ -1129,7 +1129,7 @@ class TestCleanupStaleProjectDirs:
|
||||
"""When encoded_cwd is supplied only that directory is swept."""
|
||||
import time
|
||||
|
||||
from backend.copilot.sdk.transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
@@ -1137,7 +1137,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
"backend.copilot.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1160,12 +1160,12 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_scoped_fresh_dir_not_removed(self, tmp_path, monkeypatch):
|
||||
"""Scoped sweep leaves a fresh directory alone."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
from backend.copilot.transcript import cleanup_stale_project_dirs
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
"backend.copilot.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1181,7 +1181,7 @@ class TestCleanupStaleProjectDirs:
|
||||
"""Scoped sweep refuses to remove a non-copilot directory."""
|
||||
import time
|
||||
|
||||
from backend.copilot.sdk.transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
@@ -1189,7 +1189,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
"backend.copilot.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
from .model import create_chat_session, get_chat_session, upsert_chat_session
|
||||
from .response_model import StreamError, StreamTextDelta
|
||||
from .sdk import service as sdk_service
|
||||
from .sdk.transcript import download_transcript
|
||||
from .transcript import download_transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -33,12 +33,23 @@ _GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
|
||||
_GMAIL_SEND_BLOCK_ID = "6c27abc2-e51d-499e-a85f-5a0041ba94f0"
|
||||
_TEXT_REPLACE_BLOCK_ID = "7e7c87ab-3469-4bcc-9abe-67705091b713"
|
||||
|
||||
# Default OrchestratorBlock model/mode — kept in sync with ChatConfig.model.
|
||||
# ChatConfig uses the OpenRouter format ("anthropic/claude-opus-4.6");
|
||||
# OrchestratorBlock uses the native Anthropic model name.
|
||||
ORCHESTRATOR_DEFAULT_MODEL = "claude-opus-4-6"
|
||||
ORCHESTRATOR_DEFAULT_EXECUTION_MODE = "extended_thinking"
|
||||
|
||||
# Defaults applied to OrchestratorBlock nodes by the fixer.
|
||||
_SDM_DEFAULTS: dict[str, int | bool] = {
|
||||
# execution_mode and model match the copilot's default (extended thinking
|
||||
# with Opus) so generated agents inherit the same reasoning capabilities.
|
||||
# If the user explicitly sets these fields, the fixer won't override them.
|
||||
_SDM_DEFAULTS: dict[str, int | bool | str] = {
|
||||
"agent_mode_max_iterations": 10,
|
||||
"conversation_compaction": True,
|
||||
"retry": 3,
|
||||
"multiple_tool_calls": False,
|
||||
"execution_mode": ORCHESTRATOR_DEFAULT_EXECUTION_MODE,
|
||||
"model": ORCHESTRATOR_DEFAULT_MODEL,
|
||||
}
|
||||
|
||||
|
||||
@@ -1649,6 +1660,8 @@ class AgentFixer:
|
||||
2. ``conversation_compaction`` defaults to ``True``
|
||||
3. ``retry`` defaults to ``3``
|
||||
4. ``multiple_tool_calls`` defaults to ``False``
|
||||
5. ``execution_mode`` defaults to ``"extended_thinking"``
|
||||
6. ``model`` defaults to ``"claude-opus-4-6"``
|
||||
|
||||
Args:
|
||||
agent: The agent dictionary to fix
|
||||
@@ -1748,6 +1761,12 @@ class AgentFixer:
|
||||
agent = self.fix_node_x_coordinates(agent, node_lookup=node_lookup)
|
||||
agent = self.fix_getcurrentdate_offset(agent)
|
||||
|
||||
# Apply OrchestratorBlock defaults BEFORE fix_ai_model_parameter so that
|
||||
# the orchestrator-specific model (claude-opus-4-6) is set first and
|
||||
# fix_ai_model_parameter sees it as a valid allowed model instead of
|
||||
# overwriting it with the generic default (gpt-4o).
|
||||
agent = self.fix_orchestrator_blocks(agent)
|
||||
|
||||
# Apply fixes that require blocks information
|
||||
if blocks:
|
||||
agent = self.fix_invalid_nested_sink_links(
|
||||
@@ -1765,9 +1784,6 @@ class AgentFixer:
|
||||
# Apply fixes for MCPToolBlock nodes
|
||||
agent = self.fix_mcp_tool_blocks(agent)
|
||||
|
||||
# Apply fixes for OrchestratorBlock nodes (agent-mode defaults)
|
||||
agent = self.fix_orchestrator_blocks(agent)
|
||||
|
||||
# Apply fixes for AgentExecutorBlock nodes (sub-agents)
|
||||
if library_agents:
|
||||
agent = self.fix_agent_executor_blocks(agent, library_agents)
|
||||
|
||||
1247
autogpt_platform/backend/backend/copilot/transcript.py
Normal file
240
autogpt_platform/backend/backend/copilot/transcript_builder.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Build complete JSONL transcript from SDK messages.
|
||||
|
||||
The transcript represents the FULL active context at any point in time.
|
||||
Each upload REPLACES the previous transcript atomically.
|
||||
|
||||
Flow:
|
||||
Turn 1: Upload [msg1, msg2]
|
||||
Turn 2: Download [msg1, msg2] → Upload [msg1, msg2, msg3, msg4] (REPLACE)
|
||||
Turn 3: Download [msg1, msg2, msg3, msg4] → Upload [all messages] (REPLACE)
|
||||
|
||||
The transcript is never incremental - always the complete atomic state.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .transcript import STRIPPABLE_TYPES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranscriptEntry(BaseModel):
    """One line of the JSONL transcript (a user or assistant turn)."""

    # Entry kind, e.g. "user" or "assistant".
    type: str
    # Identifier of this entry.
    uuid: str
    # uuid of the preceding entry; empty string when there is no predecessor.
    parentUuid: str = ""
    # Set on compaction-summary entries; ``None`` when not specified.
    isCompactSummary: bool | None = None
    # The message payload carried by this entry.
    message: dict[str, Any]
|
||||
|
||||
|
||||
class TranscriptBuilder:
    """Build complete JSONL transcript from SDK messages.

    This builder maintains the FULL conversation state, not incremental changes.
    The output is always the complete active context.
    """

    def __init__(self) -> None:
        self._entries: list[TranscriptEntry] = []
        # uuid of the newest entry; becomes ``parentUuid`` of the next append.
        self._last_uuid: str | None = None

    def _last_is_assistant(self) -> bool:
        """Whether the newest entry is an assistant entry."""
        return bool(self._entries) and self._entries[-1].type == "assistant"

    def _last_message_id(self) -> str:
        """Return the message.id of the last entry, or '' if none."""
        if self._entries:
            return self._entries[-1].message.get("id", "")
        return ""

    @staticmethod
    def _parse_entry(data: dict) -> TranscriptEntry | None:
        """Parse a single transcript entry, filtering strippable types.

        Returns ``None`` for entries that should be skipped: strippable types
        that are not compaction summaries, and values that are not JSON
        objects at all (e.g. a list or a bare string).
        """
        # A decoded JSON document is not necessarily an object; ``.get`` would
        # raise AttributeError on anything else, so reject it up front.
        if not isinstance(data, dict):
            return None
        entry_type = data.get("type", "")
        if entry_type in STRIPPABLE_TYPES and not data.get("isCompactSummary"):
            return None
        return TranscriptEntry(
            type=entry_type,
            # Entries without a uuid get a fresh one so the chain stays linkable.
            uuid=data.get("uuid") or str(uuid4()),
            # A missing or null parent is normalised to the empty string.
            parentUuid=data.get("parentUuid") or "",
            isCompactSummary=data.get("isCompactSummary"),
            message=data.get("message", {}),
        )

    def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
        """Load complete previous transcript.

        This loads the FULL previous context. As new messages come in,
        we append to this state. The final output is the complete context
        (previous + new), not just the delta.

        Lines that cannot be decoded into a JSON object are logged and
        skipped instead of aborting the load.
        """
        if not content or not content.strip():
            return

        lines = content.strip().split("\n")
        for line_num, line in enumerate(lines, 1):
            if not line.strip():
                continue

            data = json.loads(line, fallback=None)
            # ``None`` is the fallback for an unparseable line; any other
            # non-dict value is valid JSON that cannot be a transcript entry.
            if not isinstance(data, dict):
                logger.warning(
                    "%s Failed to parse transcript line %d/%d",
                    log_prefix,
                    line_num,
                    len(lines),
                )
                continue

            entry = self._parse_entry(data)
            if entry is None:
                continue
            self._entries.append(entry)
            self._last_uuid = entry.uuid

        logger.info(
            "%s Loaded %d entries from previous transcript (last_uuid=%s)",
            log_prefix,
            len(self._entries),
            self._last_uuid[:12] if self._last_uuid else None,
        )

    def append_user(self, content: str | list[dict], uuid: str | None = None) -> None:
        """Append a user entry, chained to the current last entry."""
        msg_uuid = uuid or str(uuid4())

        self._entries.append(
            TranscriptEntry(
                type="user",
                uuid=msg_uuid,
                parentUuid=self._last_uuid or "",
                message={"role": "user", "content": content},
            )
        )
        self._last_uuid = msg_uuid

    def append_tool_result(self, tool_use_id: str, content: str) -> None:
        """Append a tool result as a user entry (one per tool call)."""
        self.append_user(
            content=[
                {"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
            ]
        )

    def append_assistant(
        self,
        content_blocks: list[dict],
        model: str = "",
        stop_reason: str | None = None,
    ) -> None:
        """Append an assistant entry.

        Consecutive assistant entries automatically share the same message ID
        so the CLI can merge them (thinking → text → tool_use) into a single
        API message on ``--resume``.  A new ID is assigned whenever an
        assistant entry follows a non-assistant entry (user message or tool
        result), because that marks the start of a new API response.
        """
        message_id = (
            self._last_message_id()
            if self._last_is_assistant()
            else f"msg_sdk_{uuid4().hex[:24]}"
        )

        msg_uuid = str(uuid4())

        self._entries.append(
            TranscriptEntry(
                type="assistant",
                uuid=msg_uuid,
                parentUuid=self._last_uuid or "",
                message={
                    "role": "assistant",
                    "model": model,
                    "id": message_id,
                    "type": "message",
                    "content": content_blocks,
                    "stop_reason": stop_reason,
                    "stop_sequence": None,
                },
            )
        )
        self._last_uuid = msg_uuid

    def replace_entries(
        self, compacted_entries: list[dict], log_prefix: str = "[Transcript]"
    ) -> None:
        """Replace all entries with compacted entries from the CLI session file.

        Called after mid-stream compaction so TranscriptBuilder mirrors the
        CLI's active context (compaction summary + post-compaction entries).

        Builds the new list first and validates it's non-empty before swapping,
        so corrupt input cannot wipe the conversation history.
        """
        parsed = (self._parse_entry(data) for data in compacted_entries)
        new_entries: list[TranscriptEntry] = [
            entry for entry in parsed if entry is not None
        ]

        if not new_entries:
            logger.warning(
                "%s replace_entries produced 0 entries from %d inputs, keeping old (%d entries)",
                log_prefix,
                len(compacted_entries),
                len(self._entries),
            )
            return

        old_count = len(self._entries)
        self._entries = new_entries
        self._last_uuid = new_entries[-1].uuid

        logger.info(
            "%s TranscriptBuilder compacted: %d entries -> %d entries",
            log_prefix,
            old_count,
            len(self._entries),
        )

    def to_jsonl(self) -> str:
        """Export complete context as JSONL.

        Consecutive assistant entries are kept separate to match the
        native CLI format — the SDK merges them internally on resume.

        Returns the FULL conversation state (all entries), not incremental.
        This output REPLACES any previous transcript.
        """
        if not self._entries:
            return ""

        # ``exclude_none`` drops unset optional fields from each line.
        lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
        return "\n".join(lines) + "\n"

    @property
    def entry_count(self) -> int:
        """Total number of entries in the complete context."""
        return len(self._entries)

    @property
    def is_empty(self) -> bool:
        """Whether this builder has any entries."""
        return len(self._entries) == 0

    @property
    def last_entry_type(self) -> str | None:
        """Type of the last entry, or None if empty."""
        return self._entries[-1].type if self._entries else None
|
||||
@@ -0,0 +1,260 @@
|
||||
"""Tests for canonical TranscriptBuilder (backend.copilot.transcript_builder).
|
||||
|
||||
These tests directly import from the canonical module to ensure codecov
|
||||
patch coverage for the new file.
|
||||
"""
|
||||
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder, TranscriptEntry
|
||||
from backend.util import json
|
||||
|
||||
|
||||
def _make_jsonl(*entries: dict) -> str:
|
||||
return "\n".join(json.dumps(e) for e in entries) + "\n"
|
||||
|
||||
|
||||
# Minimal user transcript entry; it carries no ``parentUuid`` key.
USER_MSG = {
    "type": "user",
    "uuid": "u1",
    "message": {"role": "user", "content": "hello"},
}

# Assistant transcript entry whose parent is ``USER_MSG``.
ASST_MSG = {
    "type": "assistant",
    "uuid": "a1",
    "parentUuid": "u1",
    "message": {
        "role": "assistant",
        "id": "msg_1",
        "type": "message",
        "content": [{"type": "text", "text": "hi"}],
        "stop_reason": "end_turn",
        "stop_sequence": None,
    },
}
|
||||
|
||||
|
||||
class TestTranscriptEntry:
    """Construction behaviour of the ``TranscriptEntry`` model."""

    def test_basic_construction(self):
        # Only the required fields are given; the optional ones take defaults.
        required_only = {
            "type": "user",
            "uuid": "u1",
            "message": {"role": "user", "content": "hi"},
        }
        minimal = TranscriptEntry(**required_only)
        assert minimal.type == "user"
        assert minimal.uuid == "u1"
        assert minimal.parentUuid == ""
        assert minimal.isCompactSummary is None

    def test_optional_fields(self):
        # Explicit values for the optional fields are stored as given.
        summary = TranscriptEntry(
            type="summary",
            uuid="s1",
            parentUuid="p1",
            isCompactSummary=True,
            message={"role": "user", "content": "summary"},
        )
        assert summary.isCompactSummary is True
        assert summary.parentUuid == "p1"
|
||||
|
||||
|
||||
class TestTranscriptBuilderInit:
    """State of a freshly constructed ``TranscriptBuilder``."""

    def test_starts_empty(self):
        fresh = TranscriptBuilder()
        # A new builder holds nothing and exports an empty string.
        assert fresh.is_empty
        assert fresh.entry_count == 0
        assert fresh.last_entry_type is None
        assert fresh.to_jsonl() == ""
|
||||
|
||||
|
||||
class TestAppendUser:
    """Checks for TranscriptBuilder.append_user."""

    def test_appends_user_entry(self):
        """One append yields one entry whose type is "user"."""
        tb = TranscriptBuilder()
        tb.append_user("hello")
        assert tb.entry_count == 1
        assert tb.last_entry_type == "user"

    def test_chains_parent_uuid(self):
        """The second entry's parentUuid is the first entry's uuid."""
        tb = TranscriptBuilder()
        for text, uid in (("first", "u1"), ("second", "u2")):
            tb.append_user(text, uuid=uid)
        serialized = tb.to_jsonl()
        rows = [json.loads(raw) for raw in serialized.strip().split("\n")]
        assert rows[0]["parentUuid"] == ""
        assert rows[1]["parentUuid"] == "u1"

    def test_custom_uuid(self):
        """A caller-supplied uuid appears verbatim in the serialized entry."""
        tb = TranscriptBuilder()
        tb.append_user("hello", uuid="custom-id")
        serialized = tb.to_jsonl()
        row = json.loads(serialized.strip())
        assert row["uuid"] == "custom-id"
|
||||
|
||||
|
||||
class TestAppendToolResult:
    """Checks for TranscriptBuilder.append_tool_result."""

    def test_appends_as_user_entry(self):
        """A tool result is recorded as a "user" entry holding a single
        tool_result content block."""
        tb = TranscriptBuilder()
        tb.append_tool_result(tool_use_id="tc_1", content="result text")
        assert tb.entry_count == 1
        assert tb.last_entry_type == "user"

        serialized = tb.to_jsonl()
        row = json.loads(serialized.strip())
        blocks = row["message"]["content"]
        assert len(blocks) == 1
        only_block = blocks[0]
        assert only_block["type"] == "tool_result"
        assert only_block["tool_use_id"] == "tc_1"
        assert only_block["content"] == "result text"
|
||||
|
||||
|
||||
class TestAppendAssistant:
    """Checks for TranscriptBuilder.append_assistant."""

    @staticmethod
    def _text(text):
        """Return a one-element content-block list wrapping *text*."""
        return [{"type": "text", "text": text}]

    def test_appends_assistant_entry(self):
        """After a user turn, an assistant append brings the count to two."""
        tb = TranscriptBuilder()
        tb.append_user("hi")
        tb.append_assistant(
            content_blocks=self._text("hello"),
            model="test-model",
            stop_reason="end_turn",
        )
        assert tb.entry_count == 2
        assert tb.last_entry_type == "assistant"

    def test_consecutive_assistants_share_message_id(self):
        """Back-to-back assistant entries carry the same message id."""
        tb = TranscriptBuilder()
        tb.append_user("hi")
        for chunk in ("part 1", "part 2"):
            tb.append_assistant(content_blocks=self._text(chunk), model="m")
        serialized = tb.to_jsonl()
        rows = [json.loads(raw) for raw in serialized.strip().split("\n")]
        # rows[0] is the user turn; rows[1] and rows[2] are the assistant parts.
        assert rows[1]["message"]["id"] == rows[2]["message"]["id"]

    def test_non_consecutive_assistants_get_different_ids(self):
        """Assistant entries separated by a user turn get distinct message ids."""
        tb = TranscriptBuilder()
        for question, answer in (("q1", "a1"), ("q2", "a2")):
            tb.append_user(question)
            tb.append_assistant(content_blocks=self._text(answer), model="m")
        serialized = tb.to_jsonl()
        rows = [json.loads(raw) for raw in serialized.strip().split("\n")]
        assert rows[1]["message"]["id"] != rows[3]["message"]["id"]
|
||||
|
||||
|
||||
class TestLoadPrevious:
    """Checks for TranscriptBuilder.load_previous."""

    def test_loads_valid_entries(self):
        """A user + assistant transcript loads as two entries."""
        tb = TranscriptBuilder()
        tb.load_previous(_make_jsonl(USER_MSG, ASST_MSG))
        assert tb.entry_count == 2

    def test_skips_empty_content(self):
        """Empty and whitespace-only input leave the builder empty."""
        tb = TranscriptBuilder()
        tb.load_previous("")
        assert tb.is_empty
        tb.load_previous("   ")
        assert tb.is_empty

    def test_skips_strippable_types(self):
        """A "progress" entry between two turns is not loaded."""
        progress_row = {"type": "progress", "uuid": "p1", "message": {}}
        tb = TranscriptBuilder()
        tb.load_previous(_make_jsonl(USER_MSG, progress_row, ASST_MSG))
        assert tb.entry_count == 2  # progress was skipped

    def test_preserves_compact_summary(self):
        """A summary entry flagged isCompactSummary is kept when loading."""
        compact_row = {
            "type": "summary",
            "uuid": "cs1",
            "isCompactSummary": True,
            "message": {"role": "user", "content": "summary"},
        }
        tb = TranscriptBuilder()
        tb.load_previous(_make_jsonl(compact_row, ASST_MSG))
        assert tb.entry_count == 2

    def test_skips_invalid_json_lines(self):
        """A line that is not valid JSON is dropped; the valid line is kept."""
        raw = (
            '{"type":"user","uuid":"u1","message":{}}\n'
            "not-valid-json\n"
        )
        tb = TranscriptBuilder()
        tb.load_previous(raw)
        assert tb.entry_count == 1
|
||||
|
||||
|
||||
class TestToJsonl:
    """Checks for TranscriptBuilder.to_jsonl."""

    def test_roundtrip(self):
        """Output is newline-terminated JSONL; each line parses and carries
        the type, uuid and message keys."""
        tb = TranscriptBuilder()
        tb.append_user("hello", uuid="u1")
        tb.append_assistant(
            content_blocks=[{"type": "text", "text": "world"}],
            model="m",
        )
        serialized = tb.to_jsonl()
        assert serialized.endswith("\n")
        raw_rows = serialized.strip().split("\n")
        assert len(raw_rows) == 2
        for raw in raw_rows:
            row = json.loads(raw)
            for required_key in ("type", "uuid", "message"):
                assert required_key in row
|
||||
|
||||
|
||||
class TestReplaceEntries:
    """Checks for TranscriptBuilder.replace_entries."""

    def test_replaces_all_entries(self):
        """Replacing two entries with a single compacted one leaves one entry."""
        tb = TranscriptBuilder()
        tb.append_user("old")
        tb.append_assistant(
            content_blocks=[{"type": "text", "text": "old answer"}], model="m"
        )
        assert tb.entry_count == 2

        summary_row = {
            "type": "summary",
            "uuid": "cs1",
            "isCompactSummary": True,
            "message": {"role": "user", "content": "compacted"},
        }
        tb.replace_entries([summary_row])
        assert tb.entry_count == 1

    def test_empty_replacement_keeps_existing(self):
        """An empty replacement list does not clear the existing entry."""
        tb = TranscriptBuilder()
        tb.append_user("keep me")
        tb.replace_entries([])
        assert tb.entry_count == 1
|
||||
|
||||
|
||||
class TestParseEntry:
    """Checks for the TranscriptBuilder._parse_entry helper."""

    def test_filters_strippable_non_compact(self):
        """A "progress" entry parses to None."""
        raw = {"type": "progress", "uuid": "p1", "message": {}}
        parsed = TranscriptBuilder._parse_entry(raw)
        assert parsed is None

    def test_keeps_compact_summary(self):
        """A summary flagged isCompactSummary is returned with the flag set."""
        raw = {
            "type": "summary",
            "uuid": "cs1",
            "isCompactSummary": True,
            "message": {},
        }
        parsed = TranscriptBuilder._parse_entry(raw)
        assert parsed is not None
        assert parsed.isCompactSummary is True

    def test_generates_uuid_if_missing(self):
        """An entry without a uuid key still comes back with a truthy uuid."""
        raw = {"type": "user", "message": {"role": "user", "content": "hi"}}
        parsed = TranscriptBuilder._parse_entry(raw)
        assert parsed is not None
        assert parsed.uuid  # Should be a generated UUID
|
||||
726
autogpt_platform/backend/backend/copilot/transcript_test.py
Normal file
@@ -0,0 +1,726 @@
|
||||
"""Tests for canonical transcript module (backend.copilot.transcript).
|
||||
|
||||
Covers pure helper functions that are not exercised by the SDK re-export tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .transcript import (
|
||||
TranscriptDownload,
|
||||
_build_path_from_parts,
|
||||
_find_last_assistant_entry,
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_meta_storage_path_parts,
|
||||
_rechain_tail,
|
||||
_sanitize_id,
|
||||
_storage_path_parts,
|
||||
_transcript_to_messages,
|
||||
strip_for_upload,
|
||||
validate_transcript,
|
||||
)
|
||||
|
||||
|
||||
def _make_jsonl(*entries: dict) -> str:
    """Render *entries* as newline-terminated JSONL, one JSON document per line."""
    rows = [json.dumps(entry) for entry in entries]
    return "\n".join(rows) + "\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _sanitize_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSanitizeId:
    """Checks for the _sanitize_id helper."""

    def test_uuid_passes_through(self):
        """A well-formed UUID string is returned unchanged."""
        uuid_like = "abcdef12-3456-7890-abcd-ef1234567890"
        assert _sanitize_id(uuid_like) == uuid_like

    def test_strips_non_hex_characters(self):
        """Path-traversal input comes back containing only hex digits and hyphens."""
        cleaned = _sanitize_id("abc/../../etc/passwd")
        assert "/" not in cleaned
        assert "." not in cleaned
        # Letters such as 'p', 's', 'w' are not hex digits, so none may remain.
        allowed = "0123456789abcdefABCDEF-"
        assert all(ch in allowed for ch in cleaned)

    def test_truncates_to_max_len(self):
        """Output length is capped at max_len."""
        cleaned = _sanitize_id("a" * 100, max_len=10)
        assert len(cleaned) == 10

    def test_empty_returns_unknown(self):
        """The empty string maps to "unknown"."""
        assert _sanitize_id("") == "unknown"

    def test_none_returns_unknown(self):
        """None maps to "unknown"."""
        assert _sanitize_id(None) == "unknown"  # type: ignore[arg-type]

    def test_special_chars_only_returns_unknown(self):
        """Input with no permitted characters maps to "unknown"."""
        assert _sanitize_id("!@#$%^&*()") == "unknown"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _storage_path_parts / _meta_storage_path_parts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStoragePathParts:
    """Checks for _storage_path_parts and _meta_storage_path_parts."""

    def test_returns_triple(self):
        """The transcript path is a (prefix, user part, filename) triple."""
        parts = _storage_path_parts("user-1", "sess-2")
        prefix, uid, fname = parts
        assert prefix == "chat-transcripts"
        assert "e" in uid  # hex chars from "user-1" sanitized
        assert fname.endswith(".jsonl")

    def test_meta_returns_meta_json(self):
        """The metadata path shares the prefix and ends in ".meta.json"."""
        prefix, _uid, fname = _meta_storage_path_parts("user-1", "sess-2")
        assert prefix == "chat-transcripts"
        assert fname.endswith(".meta.json")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_path_from_parts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildPathFromParts:
    """Checks for _build_path_from_parts against both backend kinds."""

    def test_gcs_backend(self):
        """A GCS-spec'd backend yields a gcs:// path under the bucket's workspaces/."""
        from backend.util.workspace_storage import GCSWorkspaceStorage

        gcs_backend = MagicMock(spec=GCSWorkspaceStorage)
        gcs_backend.bucket_name = "my-bucket"
        parts = ("wid", "fid", "file.jsonl")
        built = _build_path_from_parts(parts, gcs_backend)
        assert built == "gcs://my-bucket/workspaces/wid/fid/file.jsonl"

    def test_local_backend(self):
        """Any backend that is not a GCSWorkspaceStorage yields a local:// path."""
        # A plain object (not MagicMock) so isinstance(GCSWorkspaceStorage) is False.
        plain_backend = type("LocalBackend", (), {})()
        parts = ("wid", "fid", "file.jsonl")
        built = _build_path_from_parts(parts, plain_backend)
        assert built == "local://wid/fid/file.jsonl"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TranscriptDownload dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTranscriptDownload:
    """Checks for the TranscriptDownload dataclass."""

    def test_defaults(self):
        """With only content given, message_count is 0 and uploaded_at is 0.0."""
        download = TranscriptDownload(content="hello")
        assert download.content == "hello"
        assert download.message_count == 0
        assert download.uploaded_at == 0.0

    def test_custom_values(self):
        """Explicit message_count / uploaded_at values are stored as given."""
        download = TranscriptDownload(
            content="data",
            message_count=5,
            uploaded_at=123.45,
        )
        assert download.message_count == 5
        assert download.uploaded_at == 123.45
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_assistant_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenAssistantContent:
    """Checks for _flatten_assistant_content."""

    def test_text_blocks(self):
        """Text blocks are joined with newlines."""
        content = [
            {"type": "text", "text": "Hello"},
            {"type": "text", "text": "World"},
        ]
        flattened = _flatten_assistant_content(content)
        assert flattened == "Hello\nWorld"

    def test_thinking_blocks_stripped(self):
        """thinking / redacted_thinking blocks contribute nothing."""
        content = [
            {"type": "thinking", "thinking": "hmm..."},
            {"type": "text", "text": "answer"},
            {"type": "redacted_thinking", "data": "secret"},
        ]
        flattened = _flatten_assistant_content(content)
        assert flattened == "answer"

    def test_tool_use_blocks_stripped(self):
        """tool_use blocks contribute nothing."""
        content = [
            {"type": "text", "text": "I'll run a tool"},
            {"type": "tool_use", "name": "bash", "id": "tc1", "input": {}},
        ]
        flattened = _flatten_assistant_content(content)
        assert flattened == "I'll run a tool"

    def test_string_blocks(self):
        """Bare strings are treated as text and joined with newlines."""
        flattened = _flatten_assistant_content(["hello", "world"])
        assert flattened == "hello\nworld"

    def test_empty_blocks(self):
        """An empty block list flattens to the empty string."""
        flattened = _flatten_assistant_content([])
        assert flattened == ""

    def test_unknown_dict_blocks_skipped(self):
        """A dict block of an unrecognised type is skipped."""
        content = [{"type": "image", "data": "base64..."}]
        flattened = _flatten_assistant_content(content)
        assert flattened == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_tool_result_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenToolResultContent:
    """Checks for _flatten_tool_result_content."""

    @staticmethod
    def _tool_result(inner):
        """Return a one-element block list: a tool_result for "tc1" wrapping *inner*."""
        return [{"type": "tool_result", "tool_use_id": "tc1", "content": inner}]

    def test_tool_result_with_text_content(self):
        """Text nested inside a tool_result is extracted."""
        payload = self._tool_result([{"type": "text", "text": "output data"}])
        assert _flatten_tool_result_content(payload) == "output data"

    def test_tool_result_with_string_content(self):
        """A plain-string tool_result content is returned as-is."""
        payload = self._tool_result("simple string")
        assert _flatten_tool_result_content(payload) == "simple string"

    def test_tool_result_with_image_placeholder(self):
        """A nested image block becomes the "[__image__]" placeholder."""
        payload = self._tool_result([{"type": "image", "data": "base64..."}])
        assert _flatten_tool_result_content(payload) == "[__image__]"

    def test_tool_result_with_document_placeholder(self):
        """A nested document block becomes the "[__document__]" placeholder."""
        payload = self._tool_result([{"type": "document", "data": "base64..."}])
        assert _flatten_tool_result_content(payload) == "[__document__]"

    def test_tool_result_with_none_content(self):
        """A tool_result whose content is None flattens to ""."""
        payload = self._tool_result(None)
        assert _flatten_tool_result_content(payload) == ""

    def test_text_block_outside_tool_result(self):
        """A top-level text block is emitted directly."""
        payload = [{"type": "text", "text": "standalone"}]
        assert _flatten_tool_result_content(payload) == "standalone"

    def test_unknown_dict_block_placeholder(self):
        """A top-level dict of unknown type becomes a "[__<type>__]" placeholder."""
        payload = [{"type": "custom_widget", "data": "x"}]
        assert _flatten_tool_result_content(payload) == "[__custom_widget__]"

    def test_string_blocks(self):
        """A bare top-level string is emitted directly."""
        assert _flatten_tool_result_content(["raw text"]) == "raw text"

    def test_empty_blocks(self):
        """An empty block list flattens to ""."""
        assert _flatten_tool_result_content([]) == ""

    def test_mixed_content_in_tool_result(self):
        """Text, image and raw-string items inside one tool_result all appear."""
        payload = self._tool_result(
            [
                {"type": "text", "text": "line1"},
                {"type": "image", "data": "..."},
                "raw string",
            ]
        )
        flattened = _flatten_tool_result_content(payload)
        for fragment in ("line1", "[__image__]", "raw string"):
            assert fragment in flattened

    def test_tool_result_with_dict_without_text_key(self):
        """A nested dict with neither type nor text still surfaces its value."""
        payload = self._tool_result([{"count": 42}])
        flattened = _flatten_tool_result_content(payload)
        assert "42" in flattened

    def test_tool_result_content_list_with_list_content(self):
        """A nested text block whose text is None flattens to the string "None".

        NOTE(review): the method name mentions list content, but the input
        exercised here is a text block with a None value.
        """
        payload = self._tool_result([{"type": "text", "text": None}])
        flattened = _flatten_tool_result_content(payload)
        assert flattened == "None"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _transcript_to_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Root user-turn fixture (empty parentUuid).
USER_ENTRY = dict(
    type="user",
    uuid="u1",
    parentUuid="",
    message=dict(role="user", content="hello"),
)
# Assistant-turn fixture parented to USER_ENTRY.
ASST_ENTRY = dict(
    type="assistant",
    uuid="a1",
    parentUuid="u1",
    message=dict(
        role="assistant",
        id="msg_1",
        content=[dict(type="text", text="hi there")],
    ),
)
# Progress fixture parented to USER_ENTRY; carries "data" rather than "message".
PROGRESS_ENTRY = dict(
    type="progress",
    uuid="p1",
    parentUuid="u1",
    data={},
)
|
||||
|
||||
|
||||
class TestTranscriptToMessages:
    """Checks for _transcript_to_messages."""

    def test_basic_conversion(self):
        """User and assistant entries become role/content message dicts."""
        converted = _transcript_to_messages(_make_jsonl(USER_ENTRY, ASST_ENTRY))
        assert len(converted) == 2
        user_msg, asst_msg = converted[0], converted[1]
        assert user_msg == {"role": "user", "content": "hello"}
        assert asst_msg["role"] == "assistant"
        assert asst_msg["content"] == "hi there"

    def test_skips_strippable_types(self):
        """A progress entry produces no message."""
        raw = _make_jsonl(USER_ENTRY, PROGRESS_ENTRY, ASST_ENTRY)
        converted = _transcript_to_messages(raw)
        assert len(converted) == 2

    def test_skips_entries_without_role(self):
        """An entry whose message has no role is dropped."""
        roleless = {"type": "user", "uuid": "x", "message": {"content": "no role"}}
        converted = _transcript_to_messages(_make_jsonl(roleless))
        assert len(converted) == 0

    def test_handles_string_content(self):
        """String content is passed through unchanged."""
        row = {
            "type": "user",
            "uuid": "u1",
            "message": {"role": "user", "content": "plain string"},
        }
        converted = _transcript_to_messages(_make_jsonl(row))
        assert converted[0]["content"] == "plain string"

    def test_handles_tool_result_content(self):
        """A tool_result block list in a user message is flattened to its text."""
        tool_block = {"type": "tool_result", "tool_use_id": "tc1", "content": "output"}
        row = {
            "type": "user",
            "uuid": "u1",
            "message": {"role": "user", "content": [tool_block]},
        }
        converted = _transcript_to_messages(_make_jsonl(row))
        assert converted[0]["content"] == "output"

    def test_handles_none_content(self):
        """None content becomes the empty string."""
        row = {
            "type": "assistant",
            "uuid": "a1",
            "message": {"role": "assistant", "content": None},
        }
        converted = _transcript_to_messages(_make_jsonl(row))
        assert converted[0]["content"] == ""

    def test_skips_invalid_json(self):
        """A non-JSON line yields no messages."""
        converted = _transcript_to_messages("not valid json\n")
        assert len(converted) == 0

    def test_preserves_compact_summary(self):
        """A summary flagged isCompactSummary is converted to a message."""
        compact_row = {
            "type": "summary",
            "uuid": "cs1",
            "isCompactSummary": True,
            "message": {"role": "user", "content": "summary of conversation"},
        }
        converted = _transcript_to_messages(_make_jsonl(compact_row))
        assert len(converted) == 1

    def test_strips_summary_without_compact_flag(self):
        """A summary without the isCompactSummary flag is dropped."""
        plain_summary = {
            "type": "summary",
            "uuid": "s1",
            "message": {"role": "user", "content": "summary"},
        }
        converted = _transcript_to_messages(_make_jsonl(plain_summary))
        assert len(converted) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _messages_to_transcript
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessagesToTranscript:
    """Checks for _messages_to_transcript."""

    def test_basic_roundtrip(self):
        """A user/assistant pair becomes two chained JSONL entries."""
        conversation = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "world"},
        ]
        transcript = _messages_to_transcript(conversation)
        assert transcript.endswith("\n")
        raw_rows = transcript.strip().split("\n")
        assert len(raw_rows) == 2

        first = json.loads(raw_rows[0])
        assert first["type"] == "user"
        assert first["message"]["role"] == "user"
        assert first["message"]["content"] == "hello"
        assert first["parentUuid"] == ""

        second = json.loads(raw_rows[1])
        assert second["type"] == "assistant"
        assert second["message"]["role"] == "assistant"
        assert second["message"]["content"] == [{"type": "text", "text": "world"}]
        assert second["parentUuid"] == first["uuid"]

    def test_empty_messages(self):
        """No messages yields the empty string."""
        assert _messages_to_transcript([]) == ""

    def test_assistant_has_message_envelope(self):
        """Assistant entries carry id, type, stop_reason and stop_sequence fields."""
        transcript = _messages_to_transcript([{"role": "assistant", "content": "test"}])
        envelope = json.loads(transcript.strip())["message"]
        assert "id" in envelope
        assert envelope["id"].startswith("msg_compact_")
        assert envelope["type"] == "message"
        assert envelope["stop_reason"] == "end_turn"
        assert envelope["stop_sequence"] is None

    def test_uuid_chain(self):
        """Each entry's parentUuid is the uuid of the entry before it."""
        conversation = [
            {"role": "user", "content": "a"},
            {"role": "assistant", "content": "b"},
            {"role": "user", "content": "c"},
        ]
        transcript = _messages_to_transcript(conversation)
        rows = [json.loads(raw) for raw in transcript.strip().split("\n")]
        assert rows[0]["parentUuid"] == ""
        assert rows[1]["parentUuid"] == rows[0]["uuid"]
        assert rows[2]["parentUuid"] == rows[1]["uuid"]

    def test_assistant_with_empty_content(self):
        """Empty assistant content serializes as an empty block list."""
        transcript = _messages_to_transcript([{"role": "assistant", "content": ""}])
        row = json.loads(transcript.strip())
        assert row["message"]["content"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _find_last_assistant_entry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindLastAssistantEntry:
    """Checks for _find_last_assistant_entry (prefix/tail split)."""

    @staticmethod
    def _user(uuid, text):
        """Build a user transcript entry."""
        return {
            "type": "user",
            "uuid": uuid,
            "message": {"role": "user", "content": text},
        }

    @staticmethod
    def _assistant(uuid, content, msg_id=None):
        """Build an assistant transcript entry; the message id is optional."""
        message = {"role": "assistant"}
        if msg_id is not None:
            message["id"] = msg_id
        message["content"] = content
        return {"type": "assistant", "uuid": uuid, "message": message}

    def test_splits_at_last_assistant(self):
        """User then assistant: one line in the prefix, one in the tail."""
        raw = _make_jsonl(
            self._user("u1", "hi"),
            self._assistant("a1", "answer", msg_id="msg1"),
        )
        prefix, tail = _find_last_assistant_entry(raw)
        assert len(prefix) == 1
        assert len(tail) == 1

    def test_no_assistant_returns_all_in_prefix(self):
        """With no assistant entry, everything is prefix and the tail is empty."""
        raw = _make_jsonl(self._user("u1", "hi"), self._user("u2", "hey"))
        prefix, tail = _find_last_assistant_entry(raw)
        assert len(prefix) == 2
        assert len(tail) == 0

    def test_multi_entry_turn_preserved(self):
        """Assistant entries sharing a message id stay together in the tail."""
        raw = _make_jsonl(
            self._user("u1", "q"),
            self._assistant(
                "a1", [{"type": "thinking", "thinking": "hmm"}], msg_id="msg_turn"
            ),
            self._assistant(
                "a2", [{"type": "text", "text": "answer"}], msg_id="msg_turn"
            ),
        )
        prefix, tail = _find_last_assistant_entry(raw)
        assert len(prefix) == 1  # just the user
        assert len(tail) == 2  # both assistant entries

    def test_assistant_without_id(self):
        """An assistant entry lacking a message id still splits off as the tail."""
        raw = _make_jsonl(self._user("u1", "q"), self._assistant("a1", "no id"))
        prefix, tail = _find_last_assistant_entry(raw)
        assert len(prefix) == 1
        assert len(tail) == 1

    def test_trailing_user_after_assistant(self):
        """A user entry after the last assistant belongs to the tail."""
        raw = _make_jsonl(
            self._user("u1", "q"),
            self._assistant("a1", "a", msg_id="msg1"),
            self._user("u2", "follow"),
        )
        prefix, tail = _find_last_assistant_entry(raw)
        assert len(prefix) == 1  # user1
        assert len(tail) == 2  # asst + user2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _rechain_tail
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRechainTail:
    """Checks for _rechain_tail (re-linking tail entries onto a prefix)."""

    @staticmethod
    def _prefix_with_uuid(uuid):
        """Return a one-line JSONL prefix whose only entry has the given uuid."""
        return json.dumps({"uuid": uuid, "type": "user", "message": {}}) + "\n"

    def test_empty_tail(self):
        """An empty tail produces the empty string."""
        assert _rechain_tail("some prefix\n", []) == ""

    def test_patches_first_entry_parent(self):
        """The first tail entry is re-parented to the prefix's last uuid."""
        prefix_text = self._prefix_with_uuid("last-prefix-uuid")
        stale_tail = {
            "uuid": "t1",
            "parentUuid": "old-parent",
            "type": "assistant",
            "message": {},
        }

        rechained = _rechain_tail(prefix_text, [json.dumps(stale_tail)])
        row = json.loads(rechained.strip())
        assert row["parentUuid"] == "last-prefix-uuid"

    def test_chains_consecutive_tail_entries(self):
        """Each later tail entry is parented to the tail entry before it."""
        prefix_text = self._prefix_with_uuid("p1")
        tail_rows = [
            {"uuid": "t1", "parentUuid": "old1", "type": "assistant", "message": {}},
            {"uuid": "t2", "parentUuid": "old2", "type": "user", "message": {}},
        ]
        tail_lines = [json.dumps(row) for row in tail_rows]

        rechained = _rechain_tail(prefix_text, tail_lines)
        rows = [json.loads(raw) for raw in rechained.strip().split("\n")]
        assert rows[0]["parentUuid"] == "p1"
        assert rows[1]["parentUuid"] == "t1"

    def test_non_dict_lines_passed_through(self):
        """A tail line that is not a JSON object still appears in the output."""
        prefix_text = self._prefix_with_uuid("p1")
        rechained = _rechain_tail(prefix_text, ["not-a-json-dict"])
        assert "not-a-json-dict" in rechained
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# strip_for_upload (combined single-parse)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStripForUpload:
    """Checks for strip_for_upload."""

    def test_strips_progress_and_thinking(self):
        """Progress entries are removed; thinking blocks are dropped from the
        older assistant entry and kept on the last one."""
        first_user = {
            "type": "user",
            "uuid": "u1",
            "parentUuid": "",
            "message": {"role": "user", "content": "hi"},
        }
        progress_row = {"type": "progress", "uuid": "p1", "parentUuid": "u1", "data": {}}
        older_assistant = {
            "type": "assistant",
            "uuid": "a1",
            "parentUuid": "p1",
            "message": {
                "role": "assistant",
                "id": "msg_old",
                "content": [
                    {"type": "thinking", "thinking": "stale thinking"},
                    {"type": "text", "text": "old answer"},
                ],
            },
        }
        second_user = {
            "type": "user",
            "uuid": "u2",
            "parentUuid": "a1",
            "message": {"role": "user", "content": "next"},
        }
        latest_assistant = {
            "type": "assistant",
            "uuid": "a2",
            "parentUuid": "u2",
            "message": {
                "role": "assistant",
                "id": "msg_new",
                "content": [
                    {"type": "thinking", "thinking": "fresh thinking"},
                    {"type": "text", "text": "new answer"},
                ],
            },
        }
        raw = _make_jsonl(
            first_user, progress_row, older_assistant, second_user, latest_assistant
        )
        stripped = strip_for_upload(raw)

        raw_rows = stripped.strip().split("\n")
        # Five entries went in; the progress entry is dropped, leaving four.
        assert len(raw_rows) == 4

        # No surviving entry may be of type "progress".
        rows = [json.loads(line) for line in raw_rows]
        kinds = [row.get("type") for row in rows]
        assert "progress" not in kinds

        def block_types(message_id):
            """Types of the dict content blocks of the entry with *message_id*."""
            match = next(
                row for row in rows if row.get("message", {}).get("id") == message_id
            )
            return [
                blk["type"]
                for blk in match["message"]["content"]
                if isinstance(blk, dict)
            ]

        # Older assistant: thinking removed, text retained.
        older_types = block_types("msg_old")
        assert "thinking" not in older_types
        assert "text" in older_types

        # Last assistant: thinking is preserved.
        latest_types = block_types("msg_new")
        assert "thinking" in latest_types  # last assistant preserved

    def test_empty_content(self):
        """Empty input yields output that is empty once whitespace is stripped."""
        stripped = strip_for_upload("")
        # Only whitespace (e.g. a lone newline) is tolerated in the result.
        assert stripped.strip() == ""

    def test_preserves_compact_summary(self):
        """A compact summary followed by an assistant entry keeps both lines."""
        compact_row = {
            "type": "summary",
            "uuid": "cs1",
            "isCompactSummary": True,
            "message": {"role": "user", "content": "summary"},
        }
        assistant_row = {
            "type": "assistant",
            "uuid": "a1",
            "parentUuid": "cs1",
            "message": {"role": "assistant", "id": "msg1", "content": "answer"},
        }
        stripped = strip_for_upload(_make_jsonl(compact_row, assistant_row))
        raw_rows = stripped.strip().split("\n")
        assert len(raw_rows) == 2

    def test_no_assistant_entries(self):
        """A transcript with only a user entry keeps that single line."""
        user_row = {
            "type": "user",
            "uuid": "u1",
            "message": {"role": "user", "content": "hi"},
        }
        stripped = strip_for_upload(_make_jsonl(user_row))
        raw_rows = stripped.strip().split("\n")
        assert len(raw_rows) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_transcript (additional edge cases)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateTranscript:
    """Edge-case checks for validate_transcript."""

    def test_valid_with_assistant(self):
        """A user + assistant transcript validates."""
        raw = _make_jsonl(USER_ENTRY, ASST_ENTRY)
        verdict = validate_transcript(raw)
        assert verdict is True

    def test_none_returns_false(self):
        """None is rejected."""
        verdict = validate_transcript(None)
        assert verdict is False

    def test_whitespace_only_returns_false(self):
        """Whitespace-only content is rejected."""
        verdict = validate_transcript("  \n  ")
        assert verdict is False

    def test_no_assistant_returns_false(self):
        """A transcript lacking any assistant entry is rejected."""
        raw = _make_jsonl(USER_ENTRY)
        verdict = validate_transcript(raw)
        assert verdict is False

    def test_invalid_json_returns_false(self):
        """Non-JSON content is rejected."""
        verdict = validate_transcript("not json\n")
        assert verdict is False

    def test_assistant_only_is_valid(self):
        """A lone assistant entry is enough to validate."""
        raw = _make_jsonl(ASST_ENTRY)
        verdict = validate_transcript(raw)
        assert verdict is True
|
||||
@@ -38,6 +38,7 @@ class Flag(str, Enum):
|
||||
AGENT_ACTIVITY = "agent-activity"
|
||||
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
|
||||
CHAT = "chat"
|
||||
CHAT_MODE_OPTION = "chat-mode-option"
|
||||
COPILOT_SDK = "copilot-sdk"
|
||||
COPILOT_DAILY_TOKEN_LIMIT = "copilot-daily-token-limit"
|
||||
COPILOT_WEEKLY_TOKEN_LIMIT = "copilot-weekly-token-limit"
|
||||
|
||||
@@ -140,7 +140,9 @@ class TestFixOrchestratorBlocks:
|
||||
assert defaults["conversation_compaction"] is True
|
||||
assert defaults["retry"] == 3
|
||||
assert defaults["multiple_tool_calls"] is False
|
||||
assert len(fixer.fixes_applied) == 4
|
||||
assert defaults["execution_mode"] == "extended_thinking"
|
||||
assert defaults["model"] == "claude-opus-4-6"
|
||||
assert len(fixer.fixes_applied) == 6
|
||||
|
||||
def test_preserves_existing_values(self):
|
||||
"""Existing user-set values are never overwritten."""
|
||||
@@ -153,6 +155,8 @@ class TestFixOrchestratorBlocks:
|
||||
"conversation_compaction": False,
|
||||
"retry": 1,
|
||||
"multiple_tool_calls": True,
|
||||
"execution_mode": "built_in",
|
||||
"model": "gpt-4o",
|
||||
}
|
||||
)
|
||||
],
|
||||
@@ -166,6 +170,8 @@ class TestFixOrchestratorBlocks:
|
||||
assert defaults["conversation_compaction"] is False
|
||||
assert defaults["retry"] == 1
|
||||
assert defaults["multiple_tool_calls"] is True
|
||||
assert defaults["execution_mode"] == "built_in"
|
||||
assert defaults["model"] == "gpt-4o"
|
||||
assert len(fixer.fixes_applied) == 0
|
||||
|
||||
def test_partial_defaults(self):
|
||||
@@ -189,7 +195,9 @@ class TestFixOrchestratorBlocks:
|
||||
assert defaults["conversation_compaction"] is True # filled
|
||||
assert defaults["retry"] == 3 # filled
|
||||
assert defaults["multiple_tool_calls"] is False # filled
|
||||
assert len(fixer.fixes_applied) == 3
|
||||
assert defaults["execution_mode"] == "extended_thinking" # filled
|
||||
assert defaults["model"] == "claude-opus-4-6" # filled
|
||||
assert len(fixer.fixes_applied) == 5
|
||||
|
||||
def test_skips_non_sdm_nodes(self):
|
||||
"""Non-Orchestrator nodes are untouched."""
|
||||
@@ -258,11 +266,13 @@ class TestFixOrchestratorBlocks:
|
||||
result = fixer.fix_orchestrator_blocks(agent)
|
||||
|
||||
defaults = result["nodes"][0]["input_default"]
|
||||
assert defaults["agent_mode_max_iterations"] == 10 # None → default
|
||||
assert defaults["conversation_compaction"] is True # None → default
|
||||
assert defaults["agent_mode_max_iterations"] == 10 # None -> default
|
||||
assert defaults["conversation_compaction"] is True # None -> default
|
||||
assert defaults["retry"] == 3 # kept
|
||||
assert defaults["multiple_tool_calls"] is False # kept
|
||||
assert len(fixer.fixes_applied) == 2
|
||||
assert defaults["execution_mode"] == "extended_thinking" # filled
|
||||
assert defaults["model"] == "claude-opus-4-6" # filled
|
||||
assert len(fixer.fixes_applied) == 4
|
||||
|
||||
def test_multiple_sdm_nodes(self):
|
||||
"""Multiple SDM nodes are all fixed independently."""
|
||||
@@ -277,11 +287,11 @@ class TestFixOrchestratorBlocks:
|
||||
|
||||
result = fixer.fix_orchestrator_blocks(agent)
|
||||
|
||||
# First node: 3 defaults filled (agent_mode was already set)
|
||||
# First node: 5 defaults filled (agent_mode was already set)
|
||||
assert result["nodes"][0]["input_default"]["agent_mode_max_iterations"] == 3
|
||||
# Second node: all 4 defaults filled
|
||||
# Second node: all 6 defaults filled
|
||||
assert result["nodes"][1]["input_default"]["agent_mode_max_iterations"] == 10
|
||||
assert len(fixer.fixes_applied) == 7 # 3 + 4
|
||||
assert len(fixer.fixes_applied) == 11 # 5 + 6
|
||||
|
||||
def test_registered_in_apply_all_fixes(self):
|
||||
"""fix_orchestrator_blocks runs as part of apply_all_fixes."""
|
||||
@@ -655,6 +665,7 @@ class TestOrchestratorE2EPipeline:
|
||||
"conversation_compaction": {"type": "boolean"},
|
||||
"retry": {"type": "integer"},
|
||||
"multiple_tool_calls": {"type": "boolean"},
|
||||
"execution_mode": {"type": "string"},
|
||||
},
|
||||
"required": ["prompt"],
|
||||
},
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
import { describe, it, expect, beforeEach } from "vitest";
|
||||
import { useOnboardingWizardStore } from "../store";
|
||||
|
||||
// Fresh snapshot of the wizard store (state plus actions) on every call.
const wizard = () => useOnboardingWizardStore.getState();

// Start each test from the store's defaults.
beforeEach(() => {
  wizard().reset();
});

describe("useOnboardingWizardStore", () => {
  // Asserts that the step counter and every form field hold their defaults.
  function expectPristine() {
    const snapshot = wizard();
    expect(snapshot.currentStep).toBe(1);
    expect(snapshot.name).toBe("");
    expect(snapshot.role).toBe("");
    expect(snapshot.otherRole).toBe("");
    expect(snapshot.painPoints).toEqual([]);
    expect(snapshot.otherPainPoint).toBe("");
  }

  describe("initial state", () => {
    it("starts at step 1 with empty fields", () => {
      expectPristine();
    });
  });

  describe("setName", () => {
    it("updates the name", () => {
      wizard().setName("Alice");
      expect(wizard().name).toBe("Alice");
    });
  });

  describe("setRole", () => {
    it("updates the role", () => {
      wizard().setRole("Engineer");
      expect(wizard().role).toBe("Engineer");
    });
  });

  describe("setOtherRole", () => {
    it("updates the other role text", () => {
      wizard().setOtherRole("Designer");
      expect(wizard().otherRole).toBe("Designer");
    });
  });

  describe("togglePainPoint", () => {
    it("adds a pain point", () => {
      wizard().togglePainPoint("slow builds");
      expect(wizard().painPoints).toEqual(["slow builds"]);
    });

    it("removes a pain point when toggled again", () => {
      wizard().togglePainPoint("slow builds");
      wizard().togglePainPoint("slow builds");
      expect(wizard().painPoints).toEqual([]);
    });

    it("handles multiple pain points", () => {
      wizard().togglePainPoint("slow builds");
      wizard().togglePainPoint("no tests");
      expect(wizard().painPoints).toEqual(["slow builds", "no tests"]);

      // Toggling the first one off leaves the second in place.
      wizard().togglePainPoint("slow builds");
      expect(wizard().painPoints).toEqual(["no tests"]);
    });
  });

  describe("setOtherPainPoint", () => {
    it("updates the other pain point text", () => {
      wizard().setOtherPainPoint("flaky CI");
      expect(wizard().otherPainPoint).toBe("flaky CI");
    });
  });

  describe("nextStep", () => {
    it("increments the step", () => {
      wizard().nextStep();
      expect(wizard().currentStep).toBe(2);
    });

    it("clamps at step 4", () => {
      wizard().goToStep(4);
      wizard().nextStep();
      expect(wizard().currentStep).toBe(4);
    });
  });

  describe("prevStep", () => {
    it("decrements the step", () => {
      wizard().goToStep(3);
      wizard().prevStep();
      expect(wizard().currentStep).toBe(2);
    });

    it("clamps at step 1", () => {
      wizard().prevStep();
      expect(wizard().currentStep).toBe(1);
    });
  });

  describe("goToStep", () => {
    it("jumps to an arbitrary step", () => {
      wizard().goToStep(3);
      expect(wizard().currentStep).toBe(3);
    });
  });

  describe("reset", () => {
    it("resets all fields to defaults", () => {
      // Dirty every field and move off the first step before resetting.
      wizard().setName("Alice");
      wizard().setRole("Engineer");
      wizard().setOtherRole("Other");
      wizard().togglePainPoint("slow builds");
      wizard().setOtherPainPoint("flaky CI");
      wizard().goToStep(3);

      wizard().reset();

      expectPristine();
    });
  });
});
|
||||
@@ -0,0 +1,221 @@
|
||||
import { describe, expect, it, beforeEach, vi } from "vitest";
|
||||
import { useCopilotUIStore } from "../store";
|
||||
|
||||
vi.mock("@sentry/nextjs", () => ({
|
||||
captureException: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("@/services/environment", () => ({
|
||||
environment: {
|
||||
isServerSide: vi.fn(() => false),
|
||||
},
|
||||
}));
|
||||
|
||||
describe("useCopilotUIStore", () => {
  // localStorage keys the assertions below read back.
  const COMPLETED_KEY = "copilot-completed-sessions";
  const SOUND_KEY = "copilot-sound-enabled";
  const MODE_KEY = "copilot-mode";
  const NOTIFICATIONS_KEY = "copilot-notifications-enabled";

  // Fresh snapshot of the store (state plus actions) on every call.
  const ui = () => useCopilotUIStore.getState();
  // Raw localStorage value for a key, or null when absent.
  const stored = (key: string) => window.localStorage.getItem(key);

  // Wipe persisted values and force the store back to a known baseline.
  beforeEach(() => {
    window.localStorage.clear();
    useCopilotUIStore.setState({
      initialPrompt: null,
      sessionToDelete: null,
      isDrawerOpen: false,
      completedSessionIDs: new Set<string>(),
      isNotificationsEnabled: false,
      isSoundEnabled: true,
      showNotificationDialog: false,
      copilotMode: "extended_thinking",
    });
  });

  describe("initialPrompt", () => {
    it("starts as null", () => {
      expect(ui().initialPrompt).toBeNull();
    });

    it("sets and clears prompt", () => {
      ui().setInitialPrompt("Hello");
      expect(ui().initialPrompt).toBe("Hello");

      ui().setInitialPrompt(null);
      expect(ui().initialPrompt).toBeNull();
    });
  });

  describe("sessionToDelete", () => {
    it("starts as null", () => {
      expect(ui().sessionToDelete).toBeNull();
    });

    it("sets and clears a delete target", () => {
      ui().setSessionToDelete({ id: "abc", title: "Test" });
      expect(ui().sessionToDelete).toEqual({ id: "abc", title: "Test" });

      ui().setSessionToDelete(null);
      expect(ui().sessionToDelete).toBeNull();
    });
  });

  describe("drawer", () => {
    it("starts closed", () => {
      expect(ui().isDrawerOpen).toBe(false);
    });

    it("opens and closes", () => {
      ui().setDrawerOpen(true);
      expect(ui().isDrawerOpen).toBe(true);

      ui().setDrawerOpen(false);
      expect(ui().isDrawerOpen).toBe(false);
    });
  });

  describe("completedSessionIDs", () => {
    it("starts empty", () => {
      expect(ui().completedSessionIDs.size).toBe(0);
    });

    it("adds a completed session", () => {
      ui().addCompletedSession("s1");
      expect(ui().completedSessionIDs.has("s1")).toBe(true);
    });

    it("persists added sessions to localStorage", () => {
      ui().addCompletedSession("s1");
      ui().addCompletedSession("s2");

      const raw = stored(COMPLETED_KEY);
      expect(raw).not.toBeNull();

      const ids = JSON.parse(raw!) as string[];
      expect(ids).toContain("s1");
      expect(ids).toContain("s2");
    });

    it("clears a single completed session", () => {
      ui().addCompletedSession("s1");
      ui().addCompletedSession("s2");
      ui().clearCompletedSession("s1");

      const remaining = ui().completedSessionIDs;
      expect(remaining.has("s1")).toBe(false);
      expect(remaining.has("s2")).toBe(true);
    });

    it("updates localStorage when a session is cleared", () => {
      ui().addCompletedSession("s1");
      ui().addCompletedSession("s2");
      ui().clearCompletedSession("s1");

      const ids = JSON.parse(stored(COMPLETED_KEY)!) as string[];
      expect(ids).not.toContain("s1");
      expect(ids).toContain("s2");
    });

    it("clears all completed sessions", () => {
      ui().addCompletedSession("s1");
      ui().addCompletedSession("s2");
      ui().clearAllCompletedSessions();
      expect(ui().completedSessionIDs.size).toBe(0);
    });

    it("removes localStorage key when all sessions are cleared", () => {
      ui().addCompletedSession("s1");
      ui().clearAllCompletedSessions();
      expect(stored(COMPLETED_KEY)).toBeNull();
    });
  });

  describe("sound toggle", () => {
    it("starts enabled", () => {
      expect(ui().isSoundEnabled).toBe(true);
    });

    it("toggles sound off and on", () => {
      ui().toggleSound();
      expect(ui().isSoundEnabled).toBe(false);

      ui().toggleSound();
      expect(ui().isSoundEnabled).toBe(true);
    });

    it("persists to localStorage", () => {
      ui().toggleSound();
      expect(stored(SOUND_KEY)).toBe("false");
    });
  });

  describe("copilotMode", () => {
    it("defaults to extended_thinking", () => {
      expect(ui().copilotMode).toBe("extended_thinking");
    });

    it("sets mode to fast", () => {
      ui().setCopilotMode("fast");
      expect(ui().copilotMode).toBe("fast");
      expect(stored(MODE_KEY)).toBe("fast");
    });

    it("sets mode back to extended_thinking", () => {
      ui().setCopilotMode("fast");
      ui().setCopilotMode("extended_thinking");
      expect(ui().copilotMode).toBe("extended_thinking");
    });
  });

  describe("clearCopilotLocalData", () => {
    it("resets state and clears localStorage keys", () => {
      // Move every persisted preference away from its default first.
      ui().setCopilotMode("fast");
      ui().setNotificationsEnabled(true);
      ui().toggleSound();
      ui().addCompletedSession("s1");

      ui().clearCopilotLocalData();

      const after = ui();
      expect(after.copilotMode).toBe("extended_thinking");
      expect(after.isNotificationsEnabled).toBe(false);
      expect(after.isSoundEnabled).toBe(true);
      expect(after.completedSessionIDs.size).toBe(0);
      expect(stored(MODE_KEY)).toBeNull();
      expect(stored(NOTIFICATIONS_KEY)).toBeNull();
      expect(stored(SOUND_KEY)).toBeNull();
      expect(stored(COMPLETED_KEY)).toBeNull();
    });
  });

  describe("notifications", () => {
    it("sets notification preference", () => {
      ui().setNotificationsEnabled(true);
      expect(ui().isNotificationsEnabled).toBe(true);
      expect(stored(NOTIFICATIONS_KEY)).toBe("true");
    });

    it("shows and hides notification dialog", () => {
      ui().setShowNotificationDialog(true);
      expect(ui().showNotificationDialog).toBe(true);

      ui().setShowNotificationDialog(false);
      expect(ui().showNotificationDialog).toBe(false);
    });
  });
});
|
||||
@@ -5,17 +5,21 @@ import {
|
||||
PromptInputTextarea,
|
||||
PromptInputTools,
|
||||
} from "@/components/ai-elements/prompt-input";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { InputGroup } from "@/components/ui/input-group";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { ChangeEvent, useEffect, useState } from "react";
|
||||
import { AttachmentMenu } from "./components/AttachmentMenu";
|
||||
import { FileChips } from "./components/FileChips";
|
||||
import { ModeToggleButton } from "./components/ModeToggleButton";
|
||||
import { RecordingButton } from "./components/RecordingButton";
|
||||
import { RecordingIndicator } from "./components/RecordingIndicator";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import { useChatInput } from "./useChatInput";
|
||||
import { useVoiceRecording } from "./useVoiceRecording";
|
||||
|
||||
export interface Props {
|
||||
interface Props {
|
||||
onSend: (message: string, files?: File[]) => void | Promise<void>;
|
||||
disabled?: boolean;
|
||||
isStreaming?: boolean;
|
||||
@@ -42,8 +46,26 @@ export function ChatInput({
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
}: Props) {
|
||||
const { copilotMode, setCopilotMode } = useCopilotUIStore();
|
||||
const showModeToggle = useGetFlag(Flag.CHAT_MODE_OPTION);
|
||||
const [files, setFiles] = useState<File[]>([]);
|
||||
|
||||
// Flip between the two copilot modes and confirm the switch with a toast.
function handleToggleMode() {
  const switchingToFast = copilotMode === "extended_thinking";
  const nextMode = switchingToFast ? "fast" : "extended_thinking";
  setCopilotMode(nextMode);

  // Title names the mode just entered; description states its trade-off.
  if (switchingToFast) {
    toast({
      title: "Switched to Fast mode",
      description: "Response quality may differ.",
    });
  } else {
    toast({
      title: "Switched to Extended Thinking mode",
      description: "Responses may take longer.",
    });
  }
}
|
||||
|
||||
// Merge files dropped onto the chat window into internal state.
|
||||
useEffect(() => {
|
||||
if (droppedFiles && droppedFiles.length > 0) {
|
||||
@@ -157,6 +179,13 @@ export function ChatInput({
|
||||
onFilesSelected={handleFilesSelected}
|
||||
disabled={isBusy}
|
||||
/>
|
||||
{showModeToggle && (
|
||||
<ModeToggleButton
|
||||
mode={copilotMode}
|
||||
isStreaming={isStreaming}
|
||||
onToggle={handleToggleMode}
|
||||
/>
|
||||
)}
|
||||
</PromptInputTools>
|
||||
|
||||
<div className="flex items-center gap-4">
|
||||
|
||||
@@ -0,0 +1,199 @@
|
||||
import {
|
||||
render,
|
||||
screen,
|
||||
fireEvent,
|
||||
cleanup,
|
||||
} from "@/tests/integrations/test-utils";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { ChatInput } from "../ChatInput";
|
||||
|
||||
let mockCopilotMode = "extended_thinking";
|
||||
const mockSetCopilotMode = vi.fn((mode: string) => {
|
||||
mockCopilotMode = mode;
|
||||
});
|
||||
|
||||
vi.mock("@/app/(platform)/copilot/store", () => ({
|
||||
useCopilotUIStore: () => ({
|
||||
copilotMode: mockCopilotMode,
|
||||
setCopilotMode: mockSetCopilotMode,
|
||||
initialPrompt: null,
|
||||
setInitialPrompt: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
let mockFlagValue = false;
|
||||
vi.mock("@/services/feature-flags/use-get-flag", () => ({
|
||||
Flag: { CHAT_MODE_OPTION: "CHAT_MODE_OPTION" },
|
||||
useGetFlag: () => mockFlagValue,
|
||||
}));
|
||||
|
||||
vi.mock("@/components/molecules/Toast/use-toast", () => ({
|
||||
toast: vi.fn(),
|
||||
useToast: () => ({ toast: vi.fn(), dismiss: vi.fn() }),
|
||||
}));
|
||||
|
||||
vi.mock("../useVoiceRecording", () => ({
|
||||
useVoiceRecording: () => ({
|
||||
isRecording: false,
|
||||
isTranscribing: false,
|
||||
elapsedTime: 0,
|
||||
toggleRecording: vi.fn(),
|
||||
handleKeyDown: vi.fn(),
|
||||
showMicButton: false,
|
||||
isInputDisabled: false,
|
||||
audioStream: null,
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("@/components/ai-elements/prompt-input", () => ({
|
||||
PromptInputBody: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
PromptInputFooter: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
PromptInputSubmit: ({ disabled }: { disabled?: boolean }) => (
|
||||
<button disabled={disabled} data-testid="submit">
|
||||
Send
|
||||
</button>
|
||||
),
|
||||
PromptInputTextarea: (props: {
|
||||
id?: string;
|
||||
value?: string;
|
||||
onChange?: React.ChangeEventHandler<HTMLTextAreaElement>;
|
||||
disabled?: boolean;
|
||||
placeholder?: string;
|
||||
}) => (
|
||||
<textarea
|
||||
id={props.id}
|
||||
value={props.value}
|
||||
onChange={props.onChange}
|
||||
disabled={props.disabled}
|
||||
placeholder={props.placeholder}
|
||||
data-testid="textarea"
|
||||
/>
|
||||
),
|
||||
PromptInputTools: ({ children }: { children: React.ReactNode }) => (
|
||||
<div data-testid="tools">{children}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
vi.mock("@/components/ui/input-group", () => ({
|
||||
InputGroup: ({
|
||||
children,
|
||||
className,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
className?: string;
|
||||
}) => <div className={className}>{children}</div>,
|
||||
}));
|
||||
|
||||
vi.mock("../components/AttachmentMenu", () => ({
|
||||
AttachmentMenu: () => <div data-testid="attachment-menu" />,
|
||||
}));
|
||||
vi.mock("../components/FileChips", () => ({
|
||||
FileChips: () => null,
|
||||
}));
|
||||
vi.mock("../components/RecordingButton", () => ({
|
||||
RecordingButton: () => null,
|
||||
}));
|
||||
vi.mock("../components/RecordingIndicator", () => ({
|
||||
RecordingIndicator: () => null,
|
||||
}));
|
||||
|
||||
const mockOnSend = vi.fn();
|
||||
|
||||
// Restore shared mock state so no test inherits another test's setup.
afterEach(() => {
  cleanup();
  vi.clearAllMocks();
  mockCopilotMode = "extended_thinking";
  // The feature-flag mock is module-level mutable state as well; put it back
  // to its declared default so test order cannot change what gets rendered.
  mockFlagValue = false;
});
|
||||
|
||||
describe("ChatInput mode toggle", () => {
  // Accessible-name patterns of the toggle in each mode.
  const TO_FAST = /switch to fast mode/i;
  const TO_EXTENDED = /switch to extended thinking/i;

  // Configure the mocked flag (and optionally the mocked mode), then render.
  // `mode` is only written when given, so tests that omit it keep whatever
  // value the mock currently holds.
  function mount(setup: {
    flag: boolean;
    mode?: string;
    isStreaming?: boolean;
  }) {
    mockFlagValue = setup.flag;
    if (setup.mode !== undefined) {
      mockCopilotMode = setup.mode;
    }
    render(<ChatInput onSend={mockOnSend} isStreaming={setup.isStreaming} />);
  }

  it("does not render mode toggle when flag is disabled", () => {
    mount({ flag: false });
    expect(screen.queryByLabelText(/switch to/i)).toBeNull();
  });

  it("renders mode toggle when flag is enabled", () => {
    mount({ flag: true });
    expect(screen.getByLabelText(TO_FAST)).toBeDefined();
  });

  it("shows Thinking label in extended_thinking mode", () => {
    mount({ flag: true, mode: "extended_thinking" });
    expect(screen.getByText("Thinking")).toBeDefined();
  });

  it("shows Fast label in fast mode", () => {
    mount({ flag: true, mode: "fast" });
    expect(screen.getByText("Fast")).toBeDefined();
  });

  it("toggles from extended_thinking to fast on click", () => {
    mount({ flag: true, mode: "extended_thinking" });
    fireEvent.click(screen.getByLabelText(TO_FAST));
    expect(mockSetCopilotMode).toHaveBeenCalledWith("fast");
  });

  it("toggles from fast to extended_thinking on click", () => {
    mount({ flag: true, mode: "fast" });
    fireEvent.click(screen.getByLabelText(TO_EXTENDED));
    expect(mockSetCopilotMode).toHaveBeenCalledWith("extended_thinking");
  });

  it("disables toggle button when streaming", () => {
    mount({ flag: true, isStreaming: true });
    const toggle = screen.getByLabelText(TO_FAST);
    expect(toggle.hasAttribute("disabled")).toBe(true);
  });

  it("exposes aria-pressed=true in extended_thinking mode", () => {
    mount({ flag: true, mode: "extended_thinking" });
    const toggle = screen.getByLabelText(TO_FAST);
    expect(toggle.getAttribute("aria-pressed")).toBe("true");
  });

  it("sets aria-pressed=false in fast mode", () => {
    mount({ flag: true, mode: "fast" });
    const toggle = screen.getByLabelText(TO_EXTENDED);
    expect(toggle.getAttribute("aria-pressed")).toBe("false");
  });

  it("uses streaming-specific tooltip when disabled", () => {
    mount({ flag: true, isStreaming: true });
    const toggle = screen.getByLabelText(TO_FAST);
    expect(toggle.getAttribute("title")).toBe(
      "Mode cannot be changed while streaming",
    );
  });

  it("shows a toast when the user toggles mode", async () => {
    // Resolves to the mocked module, so `toast` is the spy to assert on.
    const { toast } = await import("@/components/molecules/Toast/use-toast");
    mount({ flag: true, mode: "extended_thinking" });

    fireEvent.click(screen.getByLabelText(TO_FAST));

    expect(toast).toHaveBeenCalledWith(
      expect.objectContaining({
        title: expect.stringMatching(/switched to fast mode/i),
      }),
    );
  });
});
|
||||
@@ -0,0 +1,122 @@
|
||||
import { renderHook, act } from "@testing-library/react";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { useChatInput } from "../useChatInput";
|
||||
|
||||
vi.mock("@/app/(platform)/copilot/store", () => ({
|
||||
useCopilotUIStore: () => ({
|
||||
initialPrompt: null,
|
||||
setInitialPrompt: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("useChatInput", () => {
  const mockOnSend = vi.fn();

  // Clear call history and make `onSend` resolve by default; individual
  // tests override the implementation when they need a failure.
  beforeEach(() => {
    vi.clearAllMocks();
    mockOnSend.mockResolvedValue(undefined);
  });

  it("does not send when value is empty", async () => {
    const { result } = renderHook(() => useChatInput({ onSend: mockOnSend }));

    await act(async () => {
      await result.current.handleSend();
    });

    expect(mockOnSend).not.toHaveBeenCalled();
  });

  it("sends trimmed value and clears input", async () => {
    const { result } = renderHook(() => useChatInput({ onSend: mockOnSend }));

    act(() => {
      result.current.setValue(" hello ");
    });

    await act(async () => {
      await result.current.handleSend();
    });

    expect(mockOnSend).toHaveBeenCalledWith("hello");
    expect(result.current.value).toBe("");
  });

  it("does not send when disabled", async () => {
    const { result } = renderHook(() =>
      useChatInput({ onSend: mockOnSend, disabled: true }),
    );

    act(() => {
      result.current.setValue("hello");
    });

    await act(async () => {
      await result.current.handleSend();
    });

    expect(mockOnSend).not.toHaveBeenCalled();
  });

  it("prevents double-submit via ref guard", async () => {
    // `onSend` stays pending until the test releases it explicitly.
    let resolveFirst: () => void = () => {};
    const slowSend = vi.fn(
      () =>
        new Promise<void>((resolve) => {
          resolveFirst = resolve;
        }),
    );

    const { result } = renderHook(() => useChatInput({ onSend: slowSend }));

    act(() => {
      result.current.setValue("hello");
    });

    // Take the handler from ONE render and call it twice before React can
    // re-render. Both calls close over the same `isSending === false`, so the
    // state check cannot reject the second call; only the synchronous ref
    // guard can. Calling `result.current.handleSend` after a re-render would
    // be stopped by `isSending` and never reach the ref.
    const sendFromSameRender = result.current.handleSend;

    await act(async () => {
      void sendFromSameRender();
      await sendFromSameRender();
    });

    expect(slowSend).toHaveBeenCalledTimes(1);

    // Let the pending send finish so the hook settles inside act().
    await act(async () => {
      resolveFirst();
    });
  });

  it("allows sending empty when canSendEmpty is true", async () => {
    const { result } = renderHook(() =>
      useChatInput({ onSend: mockOnSend, canSendEmpty: true }),
    );

    await act(async () => {
      await result.current.handleSend();
    });

    expect(mockOnSend).toHaveBeenCalledWith("");
  });

  it("resets isSending after onSend throws", async () => {
    mockOnSend.mockRejectedValue(new Error("fail"));

    const { result } = renderHook(() => useChatInput({ onSend: mockOnSend }));

    act(() => {
      result.current.setValue("hello");
    });

    await act(async () => {
      try {
        await result.current.handleSend();
      } catch {
        // expected: handleSend has no catch, so the rejection propagates
      }
    });

    expect(result.current.isSending).toBe(false);
  });
});
|
||||
@@ -0,0 +1,53 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Brain, Lightning } from "@phosphor-icons/react";
|
||||
|
||||
type CopilotMode = "extended_thinking" | "fast";

interface Props {
  // Mode currently in effect; decides icon, text, colours and labels.
  mode: CopilotMode;
  // While true the button is disabled and shows the streaming tooltip.
  isStreaming: boolean;
  // Called on click; the parent owns the actual mode change.
  onToggle: () => void;
}

/**
 * Toggle button that displays the active copilot mode ("Thinking" or "Fast")
 * and invokes `onToggle` when clicked. `aria-pressed` is true in
 * extended-thinking mode.
 */
export function ModeToggleButton({ mode, isStreaming, onToggle }: Props) {
  const isExtended = mode === "extended_thinking";

  const Icon = isExtended ? Brain : Lightning;
  const caption = isExtended ? "Thinking" : "Fast";
  const palette = isExtended
    ? "bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-300"
    : "bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-300";
  const ariaLabel = isExtended
    ? "Switch to Fast mode"
    : "Switch to Extended Thinking mode";

  // The streaming notice takes precedence over the per-mode tooltip.
  let tooltip: string;
  if (isStreaming) {
    tooltip = "Mode cannot be changed while streaming";
  } else if (isExtended) {
    tooltip =
      "Extended Thinking mode — deeper reasoning (click to switch to Fast mode)";
  } else {
    tooltip =
      "Fast mode — quicker responses (click to switch to Extended Thinking)";
  }

  return (
    <button
      type="button"
      aria-pressed={isExtended}
      aria-label={ariaLabel}
      title={tooltip}
      disabled={isStreaming}
      onClick={onToggle}
      className={cn(
        "inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
        palette,
        isStreaming && "cursor-not-allowed opacity-50",
      )}
    >
      <Icon size={14} />
      {caption}
    </button>
  );
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useCopilotUIStore } from "@/app/(platform)/copilot/store";
|
||||
import { ChangeEvent, FormEvent, useEffect, useState } from "react";
|
||||
import { ChangeEvent, FormEvent, useEffect, useRef, useState } from "react";
|
||||
|
||||
interface Args {
|
||||
onSend: (message: string) => void;
|
||||
@@ -17,6 +17,9 @@ export function useChatInput({
|
||||
}: Args) {
|
||||
const [value, setValue] = useState("");
|
||||
const [isSending, setIsSending] = useState(false);
|
||||
// Synchronous guard against double-submit — refs update immediately,
|
||||
// unlike state which batches and can leave a gap for a second call.
|
||||
const isSubmittingRef = useRef(false);
|
||||
const { initialPrompt, setInitialPrompt } = useCopilotUIStore();
|
||||
|
||||
useEffect(
|
||||
@@ -47,12 +50,15 @@ export function useChatInput({
|
||||
|
||||
/**
 * Submit the current input value through `onSend`, then clear the input.
 *
 * Does nothing when the input is disabled, a send is already in flight
 * (tracked by both the `isSending` state and the synchronous
 * `isSubmittingRef`), or there is no text and empty sends are not allowed.
 * The in-flight markers are always reset, even if `onSend` rejects; in
 * that case the typed text is kept.
 */
async function handleSend() {
  const trimmed = value.trim();
  const nothingToSend = !(trimmed || canSendEmpty);

  if (disabled || isSending || nothingToSend || isSubmittingRef.current) {
    return;
  }

  // Set the ref first: it takes effect immediately, whereas the state
  // update is only visible after the next render.
  isSubmittingRef.current = true;
  setIsSending(true);

  try {
    await onSend(trimmed);
    setValue("");
  } finally {
    isSubmittingRef.current = false;
    setIsSending(false);
  }
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import type { UIMessage } from "ai";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
ORIGINAL_TITLE,
|
||||
extractSendMessageText,
|
||||
formatNotificationTitle,
|
||||
parseSessionIDs,
|
||||
shouldSuppressDuplicateSend,
|
||||
} from "./helpers";
|
||||
|
||||
describe("formatNotificationTitle", () => {
|
||||
@@ -74,3 +77,118 @@ describe("parseSessionIDs", () => {
|
||||
expect(parseSessionIDs('["a","a","b"]')).toEqual(new Set(["a", "b"]));
|
||||
});
|
||||
});
|
||||
|
||||
describe("extractSendMessageText", () => {
|
||||
it("extracts text from a string argument", () => {
|
||||
expect(extractSendMessageText("hello")).toBe("hello");
|
||||
});
|
||||
|
||||
it("extracts text from an object with text property", () => {
|
||||
expect(extractSendMessageText({ text: "world" })).toBe("world");
|
||||
});
|
||||
|
||||
it("returns empty string for null", () => {
|
||||
expect(extractSendMessageText(null)).toBe("");
|
||||
});
|
||||
|
||||
it("returns empty string for undefined", () => {
|
||||
expect(extractSendMessageText(undefined)).toBe("");
|
||||
});
|
||||
|
||||
it("converts numbers to string", () => {
|
||||
expect(extractSendMessageText(42)).toBe("42");
|
||||
});
|
||||
});
|
||||
|
||||
/**
 * Test fixture: build a minimal `UIMessage` with the given role and a
 * single text part. The id is randomised so fixtures do not share ids.
 */
function makeMsg(role: "user" | "assistant", text: string): UIMessage {
  const id = `msg-${Math.random()}`;
  const textPart = { type: "text" as const, text };
  return { id, role, parts: [textPart] };
}
|
||||
|
||||
describe("shouldSuppressDuplicateSend", () => {
|
||||
it("suppresses when reconnect is scheduled", () => {
|
||||
expect(
|
||||
shouldSuppressDuplicateSend({
|
||||
text: "hello",
|
||||
isReconnectScheduled: true,
|
||||
lastSubmittedText: null,
|
||||
messages: [],
|
||||
}),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
it("allows send when not reconnecting and no prior submission", () => {
|
||||
expect(
|
||||
shouldSuppressDuplicateSend({
|
||||
text: "hello",
|
||||
isReconnectScheduled: false,
|
||||
lastSubmittedText: null,
|
||||
messages: [],
|
||||
}),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("suppresses when text matches last submitted AND last user message", () => {
|
||||
const messages = [makeMsg("user", "hello"), makeMsg("assistant", "hi")];
|
||||
expect(
|
||||
shouldSuppressDuplicateSend({
|
||||
text: "hello",
|
||||
isReconnectScheduled: false,
|
||||
lastSubmittedText: "hello",
|
||||
messages,
|
||||
}),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
it("allows send when text matches last submitted but differs from last user message", () => {
|
||||
const messages = [
|
||||
makeMsg("user", "different"),
|
||||
makeMsg("assistant", "reply"),
|
||||
];
|
||||
expect(
|
||||
shouldSuppressDuplicateSend({
|
||||
text: "hello",
|
||||
isReconnectScheduled: false,
|
||||
lastSubmittedText: "hello",
|
||||
messages,
|
||||
}),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("allows send when text differs from last submitted", () => {
|
||||
const messages = [makeMsg("user", "hello")];
|
||||
expect(
|
||||
shouldSuppressDuplicateSend({
|
||||
text: "new message",
|
||||
isReconnectScheduled: false,
|
||||
lastSubmittedText: "hello",
|
||||
messages,
|
||||
}),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("allows send when text is empty", () => {
|
||||
expect(
|
||||
shouldSuppressDuplicateSend({
|
||||
text: "",
|
||||
isReconnectScheduled: false,
|
||||
lastSubmittedText: "",
|
||||
messages: [],
|
||||
}),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("allows send with empty messages array even if text matches lastSubmitted", () => {
|
||||
expect(
|
||||
shouldSuppressDuplicateSend({
|
||||
text: "hello",
|
||||
isReconnectScheduled: false,
|
||||
lastSubmittedText: "hello",
|
||||
messages: [],
|
||||
}),
|
||||
).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -65,6 +65,72 @@ export function resolveInProgressTools(
|
||||
}));
|
||||
}
|
||||
|
||||
/**
 * Extract the user-visible text from the first argument passed to
 * `sendMessage`. Always returns a string.
 *
 * Handled shapes:
 * - `sendMessage("hello")`: the string itself.
 * - `sendMessage({ text: "hello" })`: the `text` value (stringified if it
 *   is not already a string).
 * - a message-like object carrying `parts`: the concatenated `text` of
 *   every part that has a string `text`.
 * - any other object (for example a files-only payload), or an object
 *   whose `text` is `null`/`undefined`: `""`, rather than a non-string
 *   value or the meaningless `"[object Object]"`.
 * - `null` / `undefined`: `""`.
 * - other primitives: stringified (`42` becomes `"42"`).
 */
export function extractSendMessageText(firstArg: unknown): string {
  if (firstArg === null || firstArg === undefined) return "";
  if (typeof firstArg !== "object") return String(firstArg);

  const { text, parts } = firstArg as { text?: unknown; parts?: unknown };

  if (text !== null && text !== undefined) return String(text);

  if (Array.isArray(parts)) {
    return parts
      .map((part) =>
        part && typeof part === "object" && typeof part.text === "string"
          ? part.text
          : "",
      )
      .join("");
  }

  // Object with no usable text (e.g. attachments only): nothing to show.
  return "";
}
|
||||
|
||||
interface SuppressDuplicateArgs {
|
||||
text: string;
|
||||
isReconnectScheduled: boolean;
|
||||
lastSubmittedText: string | null;
|
||||
messages: UIMessage[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Reason a sendMessage was suppressed, or ``null`` to pass through.
|
||||
*
|
||||
* - ``"reconnecting"``: the stream is reconnecting; the caller should
|
||||
* notify the user (the UI may not yet reflect the disabled state).
|
||||
* - ``"duplicate"``: the same text was just submitted and echoed back
|
||||
* by the session — safe to silently drop (user double-clicked).
|
||||
*/
|
||||
export type SuppressReason = "reconnecting" | "duplicate" | null;
|
||||
|
||||
/**
 * Determine whether a sendMessage call should be suppressed to prevent
 * duplicate POSTs during reconnect cycles or re-submits of the same text.
 *
 * Returns the reason so callers can surface user-visible feedback when
 * the suppression isn't just a silent duplicate:
 *
 * - `"reconnecting"` whenever `isReconnectScheduled` is true, regardless
 *   of the text.
 * - `"duplicate"` when `text` is non-blank, equals `lastSubmittedText`,
 *   and the most recent user message in `messages` carries the same text.
 *   Both sides of that last comparison are trimmed, so surrounding
 *   whitespace in `text` cannot hide a duplicate.
 * - `null` otherwise (the send should go through).
 */
export function getSendSuppressionReason({
  text,
  isReconnectScheduled,
  lastSubmittedText,
  messages,
}: SuppressDuplicateArgs): SuppressReason {
  if (isReconnectScheduled) return "reconnecting";

  const normalizedText = text.trim();
  if (!normalizedText || lastSubmittedText !== text) return null;

  // Scan backwards for the latest user message; only that one decides.
  for (let i = messages.length - 1; i >= 0; i--) {
    const message = messages[i];
    if (message.role !== "user") continue;

    const echoedText = (message.parts ?? [])
      .map((p) => ("text" in p ? p.text : ""))
      .join("")
      .trim();

    return echoedText === normalizedText ? "duplicate" : null;
  }

  return null;
}
|
||||
|
||||
/**
 * Boolean form of ``getSendSuppressionReason``: `true` when the send
 * would be suppressed for any reason, `false` when it should proceed.
 *
 * @deprecated Call ``getSendSuppressionReason`` directly so callers can
 * distinguish between reconnect and duplicate suppression.
 */
export function shouldSuppressDuplicateSend(
  args: SuppressDuplicateArgs,
): boolean {
  const reason = getSendSuppressionReason(args);
  return reason !== null;
}
|
||||
|
||||
/**
|
||||
* Deduplicate messages by ID and by consecutive content fingerprint.
|
||||
*
|
||||
|
||||
@@ -47,6 +47,10 @@ interface CopilotUIState {
|
||||
showNotificationDialog: boolean;
|
||||
setShowNotificationDialog: (show: boolean) => void;
|
||||
|
||||
/** Autopilot mode: 'extended_thinking' (default) or 'fast'. */
|
||||
copilotMode: "extended_thinking" | "fast";
|
||||
setCopilotMode: (mode: "extended_thinking" | "fast") => void;
|
||||
|
||||
clearCopilotLocalData: () => void;
|
||||
}
|
||||
|
||||
@@ -104,16 +108,27 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
showNotificationDialog: false,
|
||||
setShowNotificationDialog: (show) => set({ showNotificationDialog: show }),
|
||||
|
||||
copilotMode:
|
||||
isClient && storage.get(Key.COPILOT_MODE) === "fast"
|
||||
? "fast"
|
||||
: "extended_thinking",
|
||||
setCopilotMode: (mode) => {
|
||||
storage.set(Key.COPILOT_MODE, mode);
|
||||
set({ copilotMode: mode });
|
||||
},
|
||||
|
||||
clearCopilotLocalData: () => {
|
||||
storage.clean(Key.COPILOT_NOTIFICATIONS_ENABLED);
|
||||
storage.clean(Key.COPILOT_SOUND_ENABLED);
|
||||
storage.clean(Key.COPILOT_NOTIFICATION_BANNER_DISMISSED);
|
||||
storage.clean(Key.COPILOT_NOTIFICATION_DIALOG_DISMISSED);
|
||||
storage.clean(Key.COPILOT_MODE);
|
||||
storage.clean(Key.COPILOT_COMPLETED_SESSIONS);
|
||||
set({
|
||||
completedSessionIDs: new Set<string>(),
|
||||
isNotificationsEnabled: false,
|
||||
isSoundEnabled: true,
|
||||
copilotMode: "extended_thinking",
|
||||
});
|
||||
if (isClient) {
|
||||
document.title = ORIGINAL_TITLE;
|
||||
|
||||
@@ -10,6 +10,7 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import type { FileUIPart } from "ai";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useCopilotUIStore } from "./store";
|
||||
import { useChatSession } from "./useChatSession";
|
||||
@@ -32,8 +33,15 @@ export function useCopilotPage() {
|
||||
const [pendingMessage, setPendingMessage] = useState<string | null>(null);
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const { sessionToDelete, setSessionToDelete, isDrawerOpen, setDrawerOpen } =
|
||||
useCopilotUIStore();
|
||||
const isModeToggleEnabled = useGetFlag(Flag.CHAT_MODE_OPTION);
|
||||
|
||||
const {
|
||||
sessionToDelete,
|
||||
setSessionToDelete,
|
||||
isDrawerOpen,
|
||||
setDrawerOpen,
|
||||
copilotMode,
|
||||
} = useCopilotUIStore();
|
||||
|
||||
const {
|
||||
sessionId,
|
||||
@@ -64,6 +72,7 @@ export function useCopilotPage() {
|
||||
hydratedMessages,
|
||||
hasActiveStream,
|
||||
refetchSession,
|
||||
copilotMode: isModeToggleEnabled ? copilotMode : undefined,
|
||||
});
|
||||
|
||||
useCopilotNotifications(sessionId);
|
||||
|
||||
@@ -10,11 +10,13 @@ import { useChat } from "@ai-sdk/react";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { DefaultChatTransport } from "ai";
|
||||
import type { FileUIPart, UIMessage } from "ai";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
import {
|
||||
deduplicateMessages,
|
||||
extractSendMessageText,
|
||||
hasActiveBackendStream,
|
||||
resolveInProgressTools,
|
||||
getSendSuppressionReason,
|
||||
} from "./helpers";
|
||||
|
||||
const RECONNECT_BASE_DELAY_MS = 1_000;
|
||||
@@ -38,6 +40,8 @@ interface UseCopilotStreamArgs {
|
||||
hydratedMessages: UIMessage[] | undefined;
|
||||
hasActiveStream: boolean;
|
||||
refetchSession: () => Promise<{ data?: unknown }>;
|
||||
/** Autopilot mode to use for requests. `undefined` = let backend decide via feature flags. */
|
||||
copilotMode: "extended_thinking" | "fast" | undefined;
|
||||
}
|
||||
|
||||
export function useCopilotStream({
|
||||
@@ -45,10 +49,18 @@ export function useCopilotStream({
|
||||
hydratedMessages,
|
||||
hasActiveStream,
|
||||
refetchSession,
|
||||
copilotMode,
|
||||
}: UseCopilotStreamArgs) {
|
||||
const queryClient = useQueryClient();
|
||||
const [rateLimitMessage, setRateLimitMessage] = useState<string | null>(null);
|
||||
const dismissRateLimit = useCallback(() => setRateLimitMessage(null), []);
|
||||
function dismissRateLimit() {
|
||||
setRateLimitMessage(null);
|
||||
}
|
||||
// Use a ref for copilotMode so the transport closure always reads the
|
||||
// latest value without recreating the DefaultChatTransport (which would
|
||||
// reset useChat's internal Chat instance and break mid-session streaming).
|
||||
const copilotModeRef = useRef(copilotMode);
|
||||
copilotModeRef.current = copilotMode;
|
||||
|
||||
// Connect directly to the Python backend for SSE, bypassing the Next.js
|
||||
// serverless proxy. This eliminates the Vercel 800s function timeout that
|
||||
@@ -79,6 +91,7 @@ export function useCopilotStream({
|
||||
is_user_message: last.role === "user",
|
||||
context: null,
|
||||
file_ids: fileIds && fileIds.length > 0 ? fileIds : null,
|
||||
mode: copilotModeRef.current ?? null,
|
||||
},
|
||||
headers: await getAuthHeaders(),
|
||||
};
|
||||
@@ -147,9 +160,14 @@ export function useCopilotStream({
|
||||
}, delay);
|
||||
}
|
||||
|
||||
// Tracks the ID of the last user message that was submitted via sendMessage.
|
||||
// During a reconnect cycle, if the session already contains this message, we
|
||||
// must not POST it again — only GET-resume is safe.
|
||||
const lastSubmittedMsgRef = useRef<string | null>(null);
|
||||
|
||||
const {
|
||||
messages: rawMessages,
|
||||
sendMessage,
|
||||
sendMessage: sdkSendMessage,
|
||||
stop: sdkStop,
|
||||
status,
|
||||
error,
|
||||
@@ -236,6 +254,36 @@ export function useCopilotStream({
|
||||
},
|
||||
});
|
||||
|
||||
// Guarded replacement for sdkSendMessage. A send is forwarded to the SDK
// only when getSendSuppressionReason reports no reason to hold it back;
// during a reconnect/resume flow the user message must never be re-POSTed.
const sendMessage: typeof sdkSendMessage = async (...args) => {
  const [firstArg] = args;
  const text = extractSendMessageText(firstArg);

  const reason = getSendSuppressionReason({
    text,
    isReconnectScheduled: isReconnectScheduledRef.current,
    lastSubmittedText: lastSubmittedMsgRef.current,
    messages: rawMessages,
  });

  switch (reason) {
    case "reconnecting":
      // The ref becomes ``true`` synchronously, but the React state that
      // disables the input only catches up on the next render. The user
      // may therefore have hit send on a still-enabled input, so tell
      // them the message was held back rather than dropping it silently.
      toast({
        title: "Reconnecting",
        description: "Wait for the connection to resume before sending.",
      });
      return;

    case "duplicate":
      // Same text was just submitted and echoed back: drop quietly.
      return;

    default:
      lastSubmittedMsgRef.current = text;
      return sdkSendMessage(...args);
  }
};
|
||||
|
||||
// Deduplicate messages continuously to prevent duplicates when resuming streams
|
||||
const messages = useMemo(
|
||||
() => deduplicateMessages(rawMessages),
|
||||
@@ -381,6 +429,7 @@ export function useCopilotStream({
|
||||
setRateLimitMessage(null);
|
||||
hasShownDisconnectToast.current = false;
|
||||
isUserStoppingRef.current = false;
|
||||
lastSubmittedMsgRef.current = null;
|
||||
setReconnectExhausted(false);
|
||||
setIsSyncing(false);
|
||||
hasResumedRef.current.clear();
|
||||
@@ -409,6 +458,7 @@ export function useCopilotStream({
|
||||
if (status === "ready") {
|
||||
reconnectAttemptsRef.current = 0;
|
||||
hasShownDisconnectToast.current = false;
|
||||
lastSubmittedMsgRef.current = null;
|
||||
setReconnectExhausted(false);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13175,6 +13175,14 @@
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "File Ids"
|
||||
},
|
||||
"mode": {
|
||||
"anyOf": [
|
||||
{ "type": "string", "enum": ["fast", "extended_thinking"] },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Mode",
|
||||
"description": "Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. If None, uses the server default (extended_thinking)."
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
|
||||
@@ -13,6 +13,7 @@ export enum Flag {
|
||||
AGENT_FAVORITING = "agent-favoriting",
|
||||
MARKETPLACE_SEARCH_TERMS = "marketplace-search-terms",
|
||||
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment",
|
||||
CHAT_MODE_OPTION = "chat-mode-option",
|
||||
}
|
||||
|
||||
const isPwMockEnabled = process.env.NEXT_PUBLIC_PW_TEST === "true";
|
||||
@@ -26,6 +27,7 @@ const defaultFlags = {
|
||||
[Flag.AGENT_FAVORITING]: false,
|
||||
[Flag.MARKETPLACE_SEARCH_TERMS]: DEFAULT_SEARCH_TERMS,
|
||||
[Flag.ENABLE_PLATFORM_PAYMENT]: false,
|
||||
[Flag.CHAT_MODE_OPTION]: true,
|
||||
};
|
||||
|
||||
type FlagValues = typeof defaultFlags;
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
import { describe, expect, it, beforeEach, vi } from "vitest";
|
||||
|
||||
vi.mock("@sentry/nextjs", () => ({
|
||||
captureException: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("@/services/environment", () => ({
|
||||
environment: {
|
||||
isServerSide: vi.fn(() => false),
|
||||
},
|
||||
}));
|
||||
|
||||
import { Key, storage } from "../local-storage";
|
||||
import { environment } from "@/services/environment";
|
||||
|
||||
describe("storage", () => {
|
||||
beforeEach(() => {
|
||||
window.localStorage.clear();
|
||||
vi.mocked(environment.isServerSide).mockReturnValue(false);
|
||||
});
|
||||
|
||||
describe("set and get", () => {
|
||||
it("stores and retrieves a value", () => {
|
||||
storage.set(Key.COPILOT_MODE, "fast");
|
||||
expect(storage.get(Key.COPILOT_MODE)).toBe("fast");
|
||||
});
|
||||
|
||||
it("returns null for unset keys", () => {
|
||||
expect(storage.get(Key.COPILOT_MODE)).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe("clean", () => {
|
||||
it("removes a stored value", () => {
|
||||
storage.set(Key.COPILOT_SOUND_ENABLED, "true");
|
||||
storage.clean(Key.COPILOT_SOUND_ENABLED);
|
||||
expect(storage.get(Key.COPILOT_SOUND_ENABLED)).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe("server-side guard", () => {
|
||||
it("returns undefined for get when on server side", () => {
|
||||
vi.mocked(environment.isServerSide).mockReturnValue(true);
|
||||
expect(storage.get(Key.COPILOT_MODE)).toBeUndefined();
|
||||
});
|
||||
|
||||
it("returns undefined for set when on server side", () => {
|
||||
vi.mocked(environment.isServerSide).mockReturnValue(true);
|
||||
expect(storage.set(Key.COPILOT_MODE, "fast")).toBeUndefined();
|
||||
});
|
||||
|
||||
it("returns undefined for clean when on server side", () => {
|
||||
vi.mocked(environment.isServerSide).mockReturnValue(true);
|
||||
expect(storage.clean(Key.COPILOT_MODE)).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Key enum", () => {
|
||||
it("has expected keys", () => {
|
||||
expect(Key.COPILOT_MODE).toBe("copilot-mode");
|
||||
expect(Key.COPILOT_SOUND_ENABLED).toBe("copilot-sound-enabled");
|
||||
expect(Key.COPILOT_NOTIFICATIONS_ENABLED).toBe(
|
||||
"copilot-notifications-enabled",
|
||||
);
|
||||
expect(Key.CHAT_SESSION_ID).toBe("chat_session_id");
|
||||
});
|
||||
});
|
||||
@@ -15,6 +15,7 @@ export enum Key {
|
||||
COPILOT_NOTIFICATIONS_ENABLED = "copilot-notifications-enabled",
|
||||
COPILOT_NOTIFICATION_BANNER_DISMISSED = "copilot-notification-banner-dismissed",
|
||||
COPILOT_NOTIFICATION_DIALOG_DISMISSED = "copilot-notification-dialog-dismissed",
|
||||
COPILOT_MODE = "copilot-mode",
|
||||
COPILOT_COMPLETED_SESSIONS = "copilot-completed-sessions",
|
||||
}
|
||||
|
||||
|
||||
BIN
test-screenshots/PR-12623/01-copilot-initial.png
Normal file
|
After Width: | Height: | Size: 33 KiB |
BIN
test-screenshots/PR-12623/02-copilot-page.png
Normal file
|
After Width: | Height: | Size: 33 KiB |
BIN
test-screenshots/PR-12623/03-copilot-with-toggle.png
Normal file
|
After Width: | Height: | Size: 33 KiB |
BIN
test-screenshots/PR-12623/04-copilot-page.png
Normal file
|
After Width: | Height: | Size: 33 KiB |
BIN
test-screenshots/PR-12623/05-copilot-loaded.png
Normal file
|
After Width: | Height: | Size: 77 KiB |
BIN
test-screenshots/PR-12623/06-chat-input-area.png
Normal file
|
After Width: | Height: | Size: 77 KiB |
BIN
test-screenshots/PR-12623/07-chat-streaming.png
Normal file
|
After Width: | Height: | Size: 60 KiB |
BIN
test-screenshots/PR-12623/08-chat-response.png
Normal file
|
After Width: | Height: | Size: 101 KiB |
BIN
test-screenshots/PR-12623/10-copilot-extended-thinking-mode.png
Normal file
|
After Width: | Height: | Size: 77 KiB |
BIN
test-screenshots/PR-12623/11-copilot-fast-mode-after-toggle.png
Normal file
|
After Width: | Height: | Size: 77 KiB |
BIN
test-screenshots/PR-12623/12-copilot-fast-mode-message-sent.png
Normal file
|
After Width: | Height: | Size: 108 KiB |
|
After Width: | Height: | Size: 76 KiB |