fix(copilot): prevent duplicate side effects from double-submit and stale-cache race (#12660)

## Why

#12604 (intermediate persistence) introduced two bugs on dev:

1. **Duplicate user messages** — `set_turn_duration` calls
`invalidate_session_cache()` which deletes the Redis key. Concurrent
`get_chat_session()` calls re-populate it from DB with stale data. The
executor loads this stale cache, misses the user message, and re-appends
it.

2. **Tool outputs lost on hydration** — Intermediate flushes save
assistant messages to DB before `StreamToolInputAvailable` sets
`tool_calls` on them. Since `_save_session_to_db` is append-only (uses
`start_sequence`), the `tool_calls` update is lost — subsequent flushes
start past that index. On page refresh / SSE reconnect, tool UIs
(SetupRequirementsCard, run_block output, etc.) are invisible.

3. **Sessions stuck running** — If a tool call hangs (e.g. WebSearch
provider not responding), the stream never completes,
`mark_session_completed` never runs, and the `active_stream` flag stays
stale in Redis.

## What

- **In-place cache update** in `set_turn_duration` — replaces
`invalidate_session_cache()` with a read-modify-write that patches the
duration on the cached session, eliminating the stale-cache repopulation
window
- **tool_calls backfill** — tracks the flush watermark and assistant
message index; when `StreamToolInputAvailable` sets `tool_calls` on an
already-flushed assistant, updates the DB record directly via
`update_message_tool_calls()`
- **Improved message dedup** — `is_message_duplicate()` /
`maybe_append_user_message()` scans trailing same-role messages (current
turn) instead of only checking `messages[-1]`
- **Idle timeout** — aborts the stream with a retryable error if no
meaningful SDK message arrives for 10 minutes, preventing hung tool
calls from leaving sessions stuck

## Changes

- `copilot/db.py` — `update_message_tool_calls()`, in-place cache update
in `set_turn_duration`
- `copilot/model.py` — `is_message_duplicate()`,
`maybe_append_user_message()`
- `copilot/sdk/service.py` — flush watermark tracking, tool_calls
backfill, idle timeout
- `copilot/baseline/service.py` — use `maybe_append_user_message()`
- `copilot/model_test.py` — unit tests for dedup
- `copilot/db_test.py` — unit tests for set_turn_duration cache update

## Checklist

- [x] My PR title follows [conventional
commit](https://www.conventionalcommits.org/) format
- [x] Out-of-scope changes are less than 20% of the PR
- [x] Changes to `data/*.py` validated for user ID checks (N/A)
- [x] Protected routes updated in middleware (N/A)
This commit is contained in:
Zamil Majdy
2026-04-03 20:09:42 +02:00
committed by GitHub
parent f6ddcbc6cb
commit 48a653dc63
7 changed files with 319 additions and 35 deletions

View File

@@ -23,6 +23,7 @@ from backend.copilot.model import (
ChatMessage,
ChatSession,
get_chat_session,
maybe_append_user_message,
update_session_title,
upsert_chat_session,
)
@@ -397,21 +398,12 @@ async def stream_chat_completion_baseline(
f"Session {session_id} not found. Please create a new session first."
)
# Append user message
new_role = "user" if is_user_message else "assistant"
if message and (
len(session.messages) == 0
or not (
session.messages[-1].role == new_role
and session.messages[-1].content == message
)
):
session.messages.append(ChatMessage(role=new_role, content=message))
if maybe_append_user_message(session, message, is_user_message):
if is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(message),
message_length=len(message or ""),
)
session = await upsert_chat_session(session)

View File

@@ -23,8 +23,9 @@ from .model import (
ChatSession,
ChatSessionInfo,
ChatSessionMetadata,
invalidate_session_cache,
cache_chat_session,
)
from .model import get_chat_session as get_chat_session_cached
logger = logging.getLogger(__name__)
@@ -380,8 +381,11 @@ async def update_tool_message_content(
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
"""Set durationMs on the last assistant message in a session.
Also invalidates the Redis session cache so the next GET returns
the updated duration.
Updates the Redis cache in-place instead of invalidating it.
Invalidation would delete the key, creating a window where concurrent
``get_chat_session`` calls re-populate the cache from DB — potentially
with stale data if the DB write from the previous turn hasn't propagated.
This race caused duplicate user messages on the next turn.
"""
last_msg = await PrismaChatMessage.prisma().find_first(
where={"sessionId": session_id, "role": "assistant"},
@@ -392,5 +396,13 @@ async def set_turn_duration(session_id: str, duration_ms: int) -> None:
where={"id": last_msg.id},
data={"durationMs": duration_ms},
)
# Invalidate cache so the session is re-fetched from DB with durationMs
await invalidate_session_cache(session_id)
# Update cache in-place rather than invalidating to avoid a
# race window where the empty cache gets re-populated with
# stale data by a concurrent get_chat_session call.
session = await get_chat_session_cached(session_id)
if session and session.messages:
for msg in reversed(session.messages):
if msg.role == "assistant":
msg.duration_ms = duration_ms
break
await cache_chat_session(session)

View File

@@ -0,0 +1,54 @@
import pytest
from .db import set_turn_duration
from .model import ChatMessage, ChatSession, get_chat_session, upsert_chat_session
@pytest.mark.asyncio(loop_scope="session")
async def test_set_turn_duration_updates_cache_in_place(setup_test_user, test_user_id):
"""set_turn_duration patches the cached session without invalidation.
Verifies that after calling set_turn_duration the Redis-cached session
reflects the updated durationMs on the last assistant message, without
the cache having been deleted and re-populated (which could race with
concurrent get_chat_session calls).
"""
session = ChatSession.new(user_id=test_user_id, dry_run=False)
session.messages = [
ChatMessage(role="user", content="hello"),
ChatMessage(role="assistant", content="hi there"),
]
session = await upsert_chat_session(session)
# Ensure the session is in cache
cached = await get_chat_session(session.session_id, test_user_id)
assert cached is not None
assert cached.messages[-1].duration_ms is None
# Update turn duration — should patch cache in-place
await set_turn_duration(session.session_id, 1234)
# Read from cache (not DB) — the cache should already have the update
updated = await get_chat_session(session.session_id, test_user_id)
assert updated is not None
assistant_msgs = [m for m in updated.messages if m.role == "assistant"]
assert len(assistant_msgs) == 1
assert assistant_msgs[0].duration_ms == 1234
@pytest.mark.asyncio(loop_scope="session")
async def test_set_turn_duration_no_assistant_message(setup_test_user, test_user_id):
"""set_turn_duration is a no-op when there are no assistant messages."""
session = ChatSession.new(user_id=test_user_id, dry_run=False)
session.messages = [
ChatMessage(role="user", content="hello"),
]
session = await upsert_chat_session(session)
# Should not raise
await set_turn_duration(session.session_id, 5678)
cached = await get_chat_session(session.session_id, test_user_id)
assert cached is not None
# User message should not have durationMs
assert cached.messages[0].duration_ms is None

View File

@@ -81,6 +81,49 @@ class ChatMessage(BaseModel):
)
def is_message_duplicate(
messages: list[ChatMessage],
role: str,
content: str,
) -> bool:
"""Check whether *content* is already present in the current pending turn.
Only inspects trailing messages that share the given *role* (i.e. the
current turn). This ensures legitimately repeated messages across different
turns are not suppressed, while same-turn duplicates from stale cache are
still caught.
"""
for m in reversed(messages):
if m.role == role:
if m.content == content:
return True
else:
break
return False
def maybe_append_user_message(
session: "ChatSession",
message: str | None,
is_user_message: bool,
) -> bool:
"""Append a user/assistant message to the session if not already present.
The route handler already persists the user message before enqueueing,
so we check trailing same-role messages to avoid re-appending when the
session cache is slightly stale.
Returns True if the message was appended, False if skipped.
"""
if not message:
return False
role = "user" if is_user_message else "assistant"
if is_message_duplicate(session.messages, role, message):
return False
session.messages.append(ChatMessage(role=role, content=message))
return True
class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int

View File

@@ -17,6 +17,8 @@ from .model import (
ChatSession,
Usage,
get_chat_session,
is_message_duplicate,
maybe_append_user_message,
upsert_chat_session,
)
@@ -424,3 +426,151 @@ async def test_concurrent_saves_collision_detection(setup_test_user, test_user_i
assert "Streaming message 1" in contents
assert "Streaming message 2" in contents
assert "Callback result" in contents
# --------------------------------------------------------------------------- #
# is_message_duplicate #
# --------------------------------------------------------------------------- #
def test_duplicate_detected_in_trailing_same_role():
"""Duplicate user message at the tail is detected."""
msgs = [
ChatMessage(role="user", content="hello"),
ChatMessage(role="assistant", content="hi there"),
ChatMessage(role="user", content="yes"),
]
assert is_message_duplicate(msgs, "user", "yes") is True
def test_duplicate_not_detected_across_turns():
"""Same text in a previous turn (separated by assistant) is NOT a duplicate."""
msgs = [
ChatMessage(role="user", content="yes"),
ChatMessage(role="assistant", content="ok"),
]
assert is_message_duplicate(msgs, "user", "yes") is False
def test_no_duplicate_on_empty_messages():
"""Empty message list never reports a duplicate."""
assert is_message_duplicate([], "user", "hello") is False
def test_no_duplicate_when_content_differs():
"""Different content in the trailing same-role block is not a duplicate."""
msgs = [
ChatMessage(role="assistant", content="response"),
ChatMessage(role="user", content="first message"),
]
assert is_message_duplicate(msgs, "user", "second message") is False
def test_duplicate_with_multiple_trailing_same_role():
"""Detects duplicate among multiple consecutive same-role messages."""
msgs = [
ChatMessage(role="assistant", content="response"),
ChatMessage(role="user", content="msg1"),
ChatMessage(role="user", content="msg2"),
]
assert is_message_duplicate(msgs, "user", "msg1") is True
assert is_message_duplicate(msgs, "user", "msg2") is True
assert is_message_duplicate(msgs, "user", "msg3") is False
def test_duplicate_check_for_assistant_role():
"""Works correctly when checking assistant role too."""
msgs = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="hello"),
ChatMessage(role="assistant", content="how can I help?"),
]
assert is_message_duplicate(msgs, "assistant", "hello") is True
assert is_message_duplicate(msgs, "assistant", "new response") is False
def test_no_false_positive_when_content_is_none():
"""Messages with content=None in the trailing block do not match."""
msgs = [
ChatMessage(role="user", content=None),
ChatMessage(role="user", content="hello"),
]
assert is_message_duplicate(msgs, "user", "hello") is True
# None-content message should not match any string
msgs2 = [
ChatMessage(role="user", content=None),
]
assert is_message_duplicate(msgs2, "user", "hello") is False
def test_all_same_role_messages():
"""When all messages share the same role, the entire list is scanned."""
msgs = [
ChatMessage(role="user", content="first"),
ChatMessage(role="user", content="second"),
ChatMessage(role="user", content="third"),
]
assert is_message_duplicate(msgs, "user", "first") is True
assert is_message_duplicate(msgs, "user", "new") is False
# --------------------------------------------------------------------------- #
# maybe_append_user_message #
# --------------------------------------------------------------------------- #
def test_maybe_append_user_message_appends_new():
"""A new user message is appended and returns True."""
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="assistant", content="hello"),
]
result = maybe_append_user_message(session, "new msg", is_user_message=True)
assert result is True
assert len(session.messages) == 2
assert session.messages[-1].role == "user"
assert session.messages[-1].content == "new msg"
def test_maybe_append_user_message_skips_duplicate():
"""A duplicate user message is skipped and returns False."""
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="assistant", content="hello"),
ChatMessage(role="user", content="dup"),
]
result = maybe_append_user_message(session, "dup", is_user_message=True)
assert result is False
assert len(session.messages) == 2
def test_maybe_append_user_message_none_message():
"""None/empty message returns False without appending."""
session = ChatSession.new(user_id="u", dry_run=False)
assert maybe_append_user_message(session, None, is_user_message=True) is False
assert maybe_append_user_message(session, "", is_user_message=True) is False
assert len(session.messages) == 0
def test_maybe_append_assistant_message():
"""Works for assistant role when is_user_message=False."""
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="user", content="hi"),
]
result = maybe_append_user_message(session, "response", is_user_message=False)
assert result is True
assert session.messages[-1].role == "assistant"
assert session.messages[-1].content == "response"
def test_maybe_append_assistant_skips_duplicate():
"""Duplicate assistant message is skipped."""
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="dup"),
]
result = maybe_append_user_message(session, "dup", is_user_message=False)
assert result is False
assert len(session.messages) == 2

View File

@@ -52,6 +52,7 @@ from ..model import (
ChatMessage,
ChatSession,
get_chat_session,
maybe_append_user_message,
update_session_title,
upsert_chat_session,
)
@@ -130,6 +131,11 @@ _CIRCUIT_BREAKER_ERROR_MSG = (
"Try breaking your request into smaller parts."
)
# Idle timeout: abort the stream if no meaningful SDK message (only heartbeats)
# arrives for this many seconds. This catches hung tool calls (e.g. WebSearch
# hanging on a search provider that never responds).
_IDLE_TIMEOUT_SECONDS = 10 * 60 # 10 minutes
# Patterns that indicate the prompt/request exceeds the model's context limit.
# Matched case-insensitively against the full exception chain.
_PROMPT_TOO_LONG_PATTERNS: tuple[str, ...] = (
@@ -1272,6 +1278,8 @@ async def _run_stream_attempt(
await client.query(state.query_message, session_id=ctx.session_id)
state.transcript_builder.append_user(content=ctx.current_message)
_last_real_msg_time = time.monotonic()
async for sdk_msg in _iter_sdk_messages(client):
# Heartbeat sentinel — refresh lock and keep SSE alive
if sdk_msg is None:
@@ -1279,8 +1287,34 @@ async def _run_stream_attempt(
for ev in ctx.compaction.emit_start_if_ready():
yield ev
yield StreamHeartbeat()
# Idle timeout: if no real SDK message for too long, a tool
# call is likely hung (e.g. WebSearch provider not responding).
idle_seconds = time.monotonic() - _last_real_msg_time
if idle_seconds >= _IDLE_TIMEOUT_SECONDS:
logger.error(
"%s Idle timeout after %.0fs with no SDK message — "
"aborting stream (likely hung tool call)",
ctx.log_prefix,
idle_seconds,
)
stream_error_msg = (
"A tool call appears to be stuck "
"(no response for 10 minutes). "
"Please try again."
)
stream_error_code = "idle_timeout"
_append_error_marker(ctx.session, stream_error_msg, retryable=True)
yield StreamError(
errorText=stream_error_msg,
code=stream_error_code,
)
ended_with_stream_error = True
break
continue
_last_real_msg_time = time.monotonic()
logger.info(
"%s Received: %s %s (unresolved=%d, current=%d, resolved=%d)",
ctx.log_prefix,
@@ -1529,9 +1563,21 @@ async def _run_stream_attempt(
# --- Intermediate persistence ---
# Flush session messages to DB periodically so page reloads
# show progress during long-running turns.
#
# IMPORTANT: Skip the flush while tool calls are pending
# (tool_calls set on assistant but results not yet received).
# The DB save is append-only (uses start_sequence), so if we
# flush the assistant message before tool_calls are set on it
# (text and tool_use arrive as separate SDK events), the
# tool_calls update is lost — the next flush starts past it.
_msgs_since_flush += 1
now = time.monotonic()
if (
has_pending_tools = (
acc.has_appended_assistant
and acc.accumulated_tool_calls
and not acc.has_tool_results
)
if not has_pending_tools and (
_msgs_since_flush >= _FLUSH_MESSAGE_THRESHOLD
or (now - _last_flush_time) >= _FLUSH_INTERVAL_SECONDS
):
@@ -1670,19 +1716,12 @@ async def stream_chat_completion_sdk(
)
session.messages.pop()
# Append the new message to the session if it's not already there
new_message_role = "user" if is_user_message else "assistant"
if message and (
len(session.messages) == 0
or not (
session.messages[-1].role == new_message_role
and session.messages[-1].content == message
)
):
session.messages.append(ChatMessage(role=new_message_role, content=message))
if maybe_append_user_message(session, message, is_user_message):
if is_user_message:
track_user_message(
user_id=user_id, session_id=session_id, message_length=len(message)
user_id=user_id,
session_id=session_id,
message_length=len(message or ""),
)
# Structured log prefix: [SDK][<session>][T<turn>]

View File

@@ -2,7 +2,6 @@ import { MessageResponse } from "@/components/ai-elements/message";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { ExclamationMarkIcon } from "@phosphor-icons/react";
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { useState } from "react";
import { AskQuestionTool } from "../../../tools/AskQuestion/AskQuestion";
import { ConnectIntegrationTool } from "../../../tools/ConnectIntegrationTool/ConnectIntegrationTool";
import { CreateAgentTool } from "../../../tools/CreateAgent/CreateAgent";
@@ -29,12 +28,10 @@ import { parseSpecialMarkers, resolveWorkspaceUrls } from "../helpers";
*/
function WorkspaceMediaImage(props: React.JSX.IntrinsicElements["img"]) {
const { src, alt, ...rest } = props;
const [imgFailed, setImgFailed] = useState(false);
const isWorkspace = src?.includes("/workspace/files/") ?? false;
if (!src) return null;
if (alt?.startsWith("video:") || (imgFailed && isWorkspace)) {
if (alt?.startsWith("video:")) {
return (
<span className="my-2 inline-block">
<video
@@ -56,9 +53,6 @@ function WorkspaceMediaImage(props: React.JSX.IntrinsicElements["img"]) {
alt={alt || "Image"}
className="h-auto max-w-full rounded-md border border-zinc-200"
loading="lazy"
onError={() => {
if (isWorkspace) setImgFailed(true);
}}
{...rest}
/>
);