From b2f7faabc778e8bf8b5a28de1dd3143763825845 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Wed, 15 Apr 2026 21:09:44 +0700 Subject: [PATCH 01/10] fix(backend/copilot): pre-create assistant msg before first yield to prevent last_role=tool (#12797) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Changes **Root cause:** When a copilot session ends with a tool result as the last saved message (`last_role=tool`), the next assistant response is never persisted. This happens when: 1. An intermediate flush saves the session with `last_role=tool` (after a tool call completes) 2. The Claude Agent SDK generates a text response for the next turn 3. The client disconnects (`GeneratorExit`) at the `yield StreamStartStep` — the very first yield of the new turn 4. `_dispatch_response(StreamTextDelta)` is never called, so the assistant message is never appended to `ctx.session.messages` 5. The session `finally` block persists the session still with `last_role=tool` **Fix:** In `_run_stream_attempt`, after `convert_message()` returns the full list of adapter responses but *before* entering the yield loop, pre-create the assistant message placeholder in `ctx.session.messages` when: - `acc.has_tool_results` is True (there are pending tool results) - `acc.has_appended_assistant` is True (at least one prior message exists) - A `StreamTextDelta` is present in the batch (confirms this is a text response turn) This ensures that even if `GeneratorExit` fires at the first `yield`, the placeholder assistant message is already in the session and will be persisted by the `finally` block. **Tests:** Added `session_persistence_test.py` with 7 unit tests covering the pre-create condition logic and delta accumulation behavior. **Confirmed:** Langfuse trace `e57ebd26` for session `465bf5cf-7219-4313-a1f6-5194d2a44ff8` showed the final assistant response was logged at 13:06:49 but never reached DB — session had 51 messages with `last_role=tool`. 
## Checklist - [x] My code follows the code style of this project - [x] I have performed a self-review of my own code - [x] I have commented my code, particularly in hard-to-understand areas - [x] I have made corresponding changes to the documentation (N/A) - [x] My changes generate no new warnings (Pyright warnings are pre-existing) - [x] I have added tests that prove my fix is effective - [x] New and existing unit tests pass locally with my changes --------- Co-authored-by: Zamil Majdy --- .../backend/backend/copilot/sdk/service.py | 33 +++ .../copilot/sdk/session_persistence_test.py | 217 ++++++++++++++++++ 2 files changed, 250 insertions(+) create mode 100644 autogpt_platform/backend/backend/copilot/sdk/session_persistence_test.py diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index a55380ea6e..d76f2ece80 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -1994,6 +1994,39 @@ async def _run_stream_attempt( # --- Dispatch adapter responses --- adapter_responses = state.adapter.convert_message(sdk_msg) + + # Pre-create the new assistant message in the session BEFORE + # yielding any events so it survives a GeneratorExit (client + # disconnect) that interrupts the yield loop at StreamStartStep. + # + # Without this, the sequence is: + # tool result saved → intermediate flush → StreamStartStep + # yield → GeneratorExit → finally saves session with + # last_role=tool (the text response was generated but never + # appended because _dispatch_response(StreamTextDelta) was + # skipped). + # + # We only pre-create when: + # 1. Tool results were received this turn (has_tool_results). + # 2. The prior assistant message is already appended + # (has_appended_assistant) — so this is a post-tool turn. + # 3. 
This batch contains StreamTextDelta — text IS coming, so + # we won't leave a spurious empty message for tool-only turns. + # + # Subsequent StreamTextDelta dispatches accumulate content into + # acc.assistant_response in-place (ChatMessage is mutable), so + # the DB record is updated without a second append. + if ( + acc.has_tool_results + and acc.has_appended_assistant + and any(isinstance(r, StreamTextDelta) for r in adapter_responses) + ): + acc.assistant_response = ChatMessage(role="assistant", content="") + acc.accumulated_tool_calls = [] + acc.has_tool_results = False + ctx.session.messages.append(acc.assistant_response) + # acc.has_appended_assistant stays True — placeholder is live + # When StreamFinish is in this batch (ResultMessage), flush any # text buffered by the thinking stripper and inject it as a # StreamTextDelta BEFORE the StreamTextEnd so the Vercel AI SDK diff --git a/autogpt_platform/backend/backend/copilot/sdk/session_persistence_test.py b/autogpt_platform/backend/backend/copilot/sdk/session_persistence_test.py new file mode 100644 index 0000000000..ea7b128927 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/sdk/session_persistence_test.py @@ -0,0 +1,217 @@ +"""Tests for the pre-create assistant message logic that prevents +last_role=tool after client disconnect. + +Reproduces the bug where: + 1. Tool result is saved by intermediate flush → last_role=tool + 2. SDK generates a text response + 3. GeneratorExit at StreamStartStep yield (client disconnect) + 4. _dispatch_response(StreamTextDelta) is never called + 5. Session saved with last_role=tool instead of last_role=assistant + +The fix: before yielding any events, pre-create the assistant message in +ctx.session.messages when has_tool_results=True and a StreamTextDelta is +present in adapter_responses. This test verifies the resulting accumulator +state allows correct content accumulation by _dispatch_response. 
+""" + +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +from backend.copilot.model import ChatMessage, ChatSession +from backend.copilot.response_model import StreamStartStep, StreamTextDelta +from backend.copilot.sdk.service import _dispatch_response, _StreamAccumulator + +_NOW = datetime(2024, 1, 1, tzinfo=timezone.utc) + + +def _make_session() -> ChatSession: + return ChatSession( + session_id="test", + user_id="test-user", + title="test", + messages=[], + usage=[], + started_at=_NOW, + updated_at=_NOW, + ) + + +def _make_ctx(session: ChatSession | None = None) -> MagicMock: + ctx = MagicMock() + ctx.session = session or _make_session() + ctx.log_prefix = "[test]" + return ctx + + +def _make_state() -> MagicMock: + state = MagicMock() + state.transcript_builder = MagicMock() + return state + + +def _simulate_pre_create(acc: _StreamAccumulator, ctx: MagicMock) -> None: + """Mirror the pre-create block from _run_stream_attempt so tests + can verify its effect without invoking the full async generator. + + Keep in sync with the block in service.py _run_stream_attempt + (search: "Pre-create the new assistant message"). 
+ """ + acc.assistant_response = ChatMessage(role="assistant", content="") + acc.accumulated_tool_calls = [] + acc.has_tool_results = False + ctx.session.messages.append(acc.assistant_response) + # acc.has_appended_assistant stays True + + +class TestPreCreateAssistantMessage: + """Verify that the pre-create logic correctly seeds the session message + and that subsequent _dispatch_response(StreamTextDelta) accumulates + content in-place without a double-append.""" + + def test_pre_create_adds_message_to_session(self) -> None: + """After pre-create, session has one assistant message.""" + session = _make_session() + ctx = _make_ctx(session) + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[], + has_appended_assistant=True, + has_tool_results=True, + ) + + _simulate_pre_create(acc, ctx) + + assert len(session.messages) == 1 + assert session.messages[-1].role == "assistant" + assert session.messages[-1].content == "" + + def test_pre_create_resets_tool_result_flag(self) -> None: + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[], + has_appended_assistant=True, + has_tool_results=True, + ) + ctx = _make_ctx() + _simulate_pre_create(acc, ctx) + + assert acc.has_tool_results is False + + def test_pre_create_resets_accumulated_tool_calls(self) -> None: + existing_call = { + "id": "call_1", + "type": "function", + "function": {"name": "bash"}, + } + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[existing_call], + has_appended_assistant=True, + has_tool_results=True, + ) + ctx = _make_ctx() + _simulate_pre_create(acc, ctx) + + assert acc.accumulated_tool_calls == [] + + def test_text_delta_accumulates_in_preexisting_message(self) -> None: + """StreamTextDelta after pre-create updates the already-appended message + in-place — no double-append.""" + session = 
_make_session() + ctx = _make_ctx(session) + state = _make_state() + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[], + has_appended_assistant=True, + has_tool_results=True, + ) + + _simulate_pre_create(acc, ctx) + assert len(session.messages) == 1 + + # Simulate the first text delta arriving after pre-create + delta = StreamTextDelta(id="t1", delta="Hello world") + _dispatch_response(delta, acc, ctx, state, False, "[test]") + + # Still only one message (no double-append) + assert len(session.messages) == 1 + # Content accumulated in the pre-created message + assert session.messages[-1].content == "Hello world" + assert session.messages[-1].role == "assistant" + + def test_subsequent_deltas_append_to_content(self) -> None: + """Multiple deltas build up the full response text.""" + session = _make_session() + ctx = _make_ctx(session) + state = _make_state() + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[], + has_appended_assistant=True, + has_tool_results=True, + ) + + _simulate_pre_create(acc, ctx) + + for word in ["You're ", "right ", "about ", "that."]: + _dispatch_response( + StreamTextDelta(id="t1", delta=word), acc, ctx, state, False, "[test]" + ) + + assert len(session.messages) == 1 + assert session.messages[-1].content == "You're right about that." 
+ + def test_pre_create_not_triggered_without_tool_results(self) -> None: + """Pre-create condition requires has_tool_results=True; no-op otherwise.""" + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[], + has_appended_assistant=True, + has_tool_results=False, # no prior tool results + ) + ctx = _make_ctx() + + # Condition is False — simulate: do nothing + if acc.has_tool_results and acc.has_appended_assistant: + _simulate_pre_create(acc, ctx) + + assert len(ctx.session.messages) == 0 + + def test_pre_create_not_triggered_when_not_yet_appended(self) -> None: + """Pre-create requires has_appended_assistant=True.""" + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[], + has_appended_assistant=False, # first turn, nothing appended yet + has_tool_results=True, + ) + ctx = _make_ctx() + + if acc.has_tool_results and acc.has_appended_assistant: + _simulate_pre_create(acc, ctx) + + assert len(ctx.session.messages) == 0 + + def test_pre_create_not_triggered_without_text_delta(self) -> None: + """Pre-create is skipped when adapter_responses has no StreamTextDelta + (e.g. a tool-only batch). 
Verifies the third guard condition.""" + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[], + has_appended_assistant=True, + has_tool_results=True, + ) + ctx = _make_ctx() + adapter_responses = [StreamStartStep()] # no StreamTextDelta + + if ( + acc.has_tool_results + and acc.has_appended_assistant + and any(isinstance(r, StreamTextDelta) for r in adapter_responses) + ): + _simulate_pre_create(acc, ctx) + + assert len(ctx.session.messages) == 0 From ab3221a2511a0690fcb2013d942ff0819b27a279 Mon Sep 17 00:00:00 2001 From: Nicholas Tindle Date: Wed, 15 Apr 2026 09:40:43 -0500 Subject: [PATCH 02/10] feat(backend): MemoryEnvelope metadata model, scoped retrieval, and memory hardening (#12765) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Why / What / How **Why:** CoPilot's Graphiti memory system needed structured metadata to distinguish memory types (rules, procedures, facts, preferences), support scoped retrieval, enable targeted deletion, and track memory costs under the AutoPilot billing account separately from the platform. **What:** Adds the MemoryEnvelope metadata model, structured rule/procedure memory types, a derived-finding lane for assistant-distilled knowledge, two-step forget tools, scope-aware retrieval filtering, AutoPilot-dedicated API key routing, and several reliability fixes (streaming socket leaks, event-loop-scoped caches, ingestion hardening). **How:** MemoryEnvelope wraps every stored episode with typed metadata (source_kind, memory_kind, scope, status, confidence) serialized as JSON. Retrieval filters by scope at the context layer. The forget flow uses a search-then-confirm two-step pattern. Ingestion queues and client caches are scoped per event loop via WeakKeyDictionary to prevent cross-loop RuntimeErrors in multi-worker deployments. 
API key resolution falls back to AutoPilot-dedicated keys (CHAT_API_KEY, CHAT_OPENAI_API_KEY) before platform-wide keys. ### Changes 🏗️ **New: MemoryEnvelope metadata model** (`memory_model.py`) - Typed memory categories: fact, preference, rule, finding, plan, event, procedure - Source tracking: user_asserted, assistant_derived, tool_observed - Scope namespacing: `real:global`, `project:`, `book:`, `session:<id>` - Status lifecycle: active, tentative, superseded, contradicted - Structured `RuleMemory` and `ProcedureMemory` models for complex instructions **New: Targeted forget tools** (`graphiti_forget.py`) - `memory_forget_search`: returns candidate facts with UUIDs for user confirmation - `memory_forget_confirm`: deletes specific edges by UUID after confirmation **New: Architecture test** (`architecture_test.py`) - Validates no new `@cached(...)` usage around event-loop-bound async clients - Allowlists pre-existing violations for future cleanup **Enhanced: memory_store tool** (`graphiti_store.py`) - Accepts MemoryEnvelope metadata fields (source_kind, scope, memory_kind, rule, procedure) - Wraps content in MemoryEnvelope before ingestion **Enhanced: memory_search tool** (`graphiti_search.py`) - Scope-aware retrieval with hard filtering on group_id **Enhanced: Ingestion pipeline** (`ingest.py`) - Derived-finding lane: distills substantive assistant responses into tentative findings - Event-loop-scoped queues and workers via WeakKeyDictionary (fixes multi-worker RuntimeError) - Improved error handling and dropped-episode reporting **Enhanced: Client cache** (`client.py`) - Per-loop client cache and lock via WeakKeyDictionary (fixes "Future attached to a different loop") **Enhanced: Warm context** (`context.py`) - Filters out non-global-scope episodes from warm context **Fix: Streaming socket leak** (`baseline/service.py`) - try/finally around async stream iteration to release httpx connections on early exit **Config: AutoPilot key routing** (`config.py`, 
`.env.default`) - LLM key fallback: GRAPHITI_LLM_API_KEY → CHAT_API_KEY → OPEN_ROUTER_API_KEY - Embedder key fallback: GRAPHITI_EMBEDDER_API_KEY → CHAT_OPENAI_API_KEY → OPENAI_API_KEY - Backwards-compatible: existing behavior unchanged until new keys are provisioned ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] `poetry run pytest backend/copilot/graphiti/config_test.py` — 16 tests pass (key fallback priority) - [x] `poetry run pytest backend/copilot/tools/graphiti_store_test.py` — store envelope tests pass - [x] `poetry run pytest backend/copilot/graphiti/ingest_test.py` — ingestion tests pass - [x] `poetry run pytest backend/util/architecture_test.py` — structural validation passes - [x] Verify memory store/retrieve/forget cycle via copilot chat - [x] Run AgentProbe multi-session memory benchmark (31 scenarios x3 repeats) - [x] Confirm no CLOSE_WAIT socket accumulation under sustained streaming load - [x] Verify multi-worker deployment doesn't produce loop-binding errors #### For configuration changes: - [x] `.env.default` is updated or already compatible with my changes - [x] `docker-compose.yml` is updated or already compatible with my changes - Configuration changes: - New optional env var `CHAT_OPENAI_API_KEY` — AutoPilot-dedicated OpenAI key for Graphiti embeddings (falls back to `OPENAI_API_KEY` if not set) - `CHAT_API_KEY` now used as first fallback for Graphiti LLM calls (was `OPEN_ROUTER_API_KEY`) - Infra action needed: add `CHAT_OPENAI_API_KEY` sealed secret in `autogpt-shared-config` values (dev + prod) 🤖 Generated with [Claude Code](https://claude.com/claude-code) <!-- CURSOR_SUMMARY --> --- > [!NOTE] > **Medium Risk** > Touches Graphiti memory ingestion/retrieval and introduces hard-delete capabilities plus event-loop–scoped caching/queues; failures could affect memory correctness or delete the wrong 
edges. Also changes streaming resource cleanup and key routing, which could surface as connection or billing/cost attribution issues if misconfigured. > > **Overview** > **Graphiti memory is upgraded from plain text episodes to a structured JSON `MemoryEnvelope`.** `memory_store` now wraps content with typed metadata (source, kind, scope, status) and optional structured `rule`/`procedure` payloads, and ingestion supports JSON episodes. > > **Memory retrieval and lifecycle controls are expanded.** `memory_search` adds optional scope hard-filtering to prevent cross-scope leakage, warm-context formatting drops non-global scoped episodes (and avoids empty wrappers), and new two-step tools (`memory_forget_search` → `memory_forget_confirm`) enable targeted soft- or hard-deletion of specific graph edges by UUID. > > **Reliability and multi-worker safety improvements.** Graphiti client caching and ingestion worker registries are now per-event-loop (avoiding cross-loop `Future` errors), streaming chat completions explicitly close async streams to prevent `CLOSE_WAIT` socket leaks, warm-context is injected into the first user message to keep the system prompt cacheable, and a new `architecture_test.py` blocks future process-wide caching of event-loop–bound async clients. Config updates route Graphiti LLM/embedder keys to AutoPilot-specific env vars first, and OpenAPI schema exports include the new memory response types. > > <sup>Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit 5fb4bd0a43ac2a6d7a5c9dcd0ea97834547538cf. Bugbot is set up for automated code reviews on this repo. 
Configure [here](https://www.cursor.com/dashboard/bugbot).</sup> <!-- /CURSOR_SUMMARY --> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> --- autogpt_platform/backend/.env.default | 3 +- .../backend/api/features/chat/routes.py | 8 + .../backend/copilot/baseline/service.py | 140 ++++--- .../backend/copilot/graphiti/_format.py | 17 +- .../backend/copilot/graphiti/client.py | 47 ++- .../backend/copilot/graphiti/config.py | 18 +- .../backend/copilot/graphiti/config_test.py | 22 +- .../backend/copilot/graphiti/context.py | 36 +- .../backend/copilot/graphiti/context_test.py | 214 ++++++++++- .../backend/copilot/graphiti/ingest.py | 154 +++++++- .../backend/copilot/graphiti/ingest_test.py | 174 +++++++-- .../backend/copilot/graphiti/memory_model.py | 118 ++++++ .../backend/backend/copilot/permissions.py | 2 + .../backend/backend/copilot/sdk/service.py | 21 +- .../backend/backend/copilot/tools/__init__.py | 3 + .../backend/copilot/tools/graphiti_forget.py | 349 ++++++++++++++++++ .../copilot/tools/graphiti_forget_test.py | 77 ++++ .../backend/copilot/tools/graphiti_search.py | 55 ++- .../copilot/tools/graphiti_search_test.py | 64 ++++ .../backend/copilot/tools/graphiti_store.py | 147 +++++++- .../copilot/tools/graphiti_store_test.py | 152 +++++++- .../backend/backend/copilot/tools/models.py | 17 + .../backend/backend/util/architecture_test.py | 134 +++++++ .../frontend/src/app/api/openapi.json | 111 +++++- 24 files changed, 1946 insertions(+), 137 deletions(-) create mode 100644 autogpt_platform/backend/backend/copilot/graphiti/memory_model.py create mode 100644 autogpt_platform/backend/backend/copilot/tools/graphiti_forget.py create mode 100644 autogpt_platform/backend/backend/copilot/tools/graphiti_forget_test.py create mode 100644 autogpt_platform/backend/backend/copilot/tools/graphiti_search_test.py create mode 100644 
autogpt_platform/backend/backend/util/architecture_test.py diff --git a/autogpt_platform/backend/.env.default b/autogpt_platform/backend/.env.default index c01da95a03..e731f9f9bf 100644 --- a/autogpt_platform/backend/.env.default +++ b/autogpt_platform/backend/.env.default @@ -60,7 +60,8 @@ NVIDIA_API_KEY= # Graphiti Temporal Knowledge Graph Memory # Rollout controlled by LaunchDarkly flag "graphiti-memory" -# LLM/embedder keys fall back to OPEN_ROUTER_API_KEY and OPENAI_API_KEY when empty. +# LLM key falls back to CHAT_API_KEY (AutoPilot), then OPEN_ROUTER_API_KEY. +# Embedder key falls back to CHAT_OPENAI_API_KEY (AutoPilot), then OPENAI_API_KEY. GRAPHITI_FALKORDB_HOST=localhost GRAPHITI_FALKORDB_PORT=6380 GRAPHITI_FALKORDB_PASSWORD= diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index 7d0521cb81..7496c214ac 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -62,6 +62,10 @@ from backend.copilot.tools.models import ( InputValidationErrorResponse, MCPToolOutputResponse, MCPToolsDiscoveredResponse, + MemoryForgetCandidatesResponse, + MemoryForgetConfirmResponse, + MemorySearchResponse, + MemoryStoreResponse, NeedLoginResponse, NoResultsResponse, SetupRequirementsResponse, @@ -1365,6 +1369,10 @@ ToolResponseUnion = ( | DocPageResponse | MCPToolsDiscoveredResponse | MCPToolOutputResponse + | MemoryStoreResponse + | MemorySearchResponse + | MemoryForgetCandidatesResponse + | MemoryForgetConfirmResponse ) diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index bb3906811c..dd6aa121b6 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -293,56 +293,69 @@ async def _baseline_llm_caller( ) tool_calls_by_index: 
dict[int, dict[str, str]] = {} - async for chunk in response: - if chunk.usage: - state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0 - state.turn_completion_tokens += chunk.usage.completion_tokens or 0 - # Extract cache token details when available (OpenAI / - # OpenRouter include these in prompt_tokens_details). - ptd = getattr(chunk.usage, "prompt_tokens_details", None) - if ptd: - state.turn_cache_read_tokens += ( - getattr(ptd, "cached_tokens", 0) or 0 - ) - # cache_creation_input_tokens is reported by some providers - # (e.g. Anthropic native) but not standard OpenAI streaming. - state.turn_cache_creation_tokens += ( - getattr(ptd, "cache_creation_input_tokens", 0) or 0 - ) - - delta = chunk.choices[0].delta if chunk.choices else None - if not delta: - continue - - if delta.content: - emit = state.thinking_stripper.process(delta.content) - if emit: - if not state.text_started: - state.pending_events.append( - StreamTextStart(id=state.text_block_id) + # Iterate under an inner try/finally so early exits (cancel, tool-call + # break, exception) always release the underlying httpx connection. + # Without this, openai.AsyncStream leaks the streaming response and + # the TCP socket ends up in CLOSE_WAIT until the process exits. + try: + async for chunk in response: + if chunk.usage: + state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0 + state.turn_completion_tokens += chunk.usage.completion_tokens or 0 + # Extract cache token details when available (OpenAI / + # OpenRouter include these in prompt_tokens_details). + ptd = getattr(chunk.usage, "prompt_tokens_details", None) + if ptd: + state.turn_cache_read_tokens += ( + getattr(ptd, "cached_tokens", 0) or 0 + ) + # cache_creation_input_tokens is reported by some providers + # (e.g. Anthropic native) but not standard OpenAI streaming. 
+ state.turn_cache_creation_tokens += ( + getattr(ptd, "cache_creation_input_tokens", 0) or 0 ) - state.text_started = True - round_text += emit - state.pending_events.append( - StreamTextDelta(id=state.text_block_id, delta=emit) - ) - if delta.tool_calls: - for tc in delta.tool_calls: - idx = tc.index - if idx not in tool_calls_by_index: - tool_calls_by_index[idx] = { - "id": "", - "name": "", - "arguments": "", - } - entry = tool_calls_by_index[idx] - if tc.id: - entry["id"] = tc.id - if tc.function and tc.function.name: - entry["name"] = tc.function.name - if tc.function and tc.function.arguments: - entry["arguments"] += tc.function.arguments + delta = chunk.choices[0].delta if chunk.choices else None + if not delta: + continue + + if delta.content: + emit = state.thinking_stripper.process(delta.content) + if emit: + if not state.text_started: + state.pending_events.append( + StreamTextStart(id=state.text_block_id) + ) + state.text_started = True + round_text += emit + state.pending_events.append( + StreamTextDelta(id=state.text_block_id, delta=emit) + ) + + if delta.tool_calls: + for tc in delta.tool_calls: + idx = tc.index + if idx not in tool_calls_by_index: + tool_calls_by_index[idx] = { + "id": "", + "name": "", + "arguments": "", + } + entry = tool_calls_by_index[idx] + if tc.id: + entry["id"] = tc.id + if tc.function and tc.function.name: + entry["name"] = tc.function.name + if tc.function and tc.function.arguments: + entry["arguments"] += tc.function.arguments + finally: + # Release the streaming httpx connection back to the pool on every + # exit path (normal completion, break, exception). openai.AsyncStream + # does not auto-close when the async-for loop exits early. + try: + await response.close() + except Exception: + pass # Flush any buffered text held back by the thinking stripper. 
tail = state.thinking_stripper.flush() @@ -940,13 +953,14 @@ async def stream_chat_completion_baseline( graphiti_supplement = get_graphiti_supplement() if graphiti_enabled else "" system_prompt = base_system_prompt + get_baseline_supplement() + graphiti_supplement - # Warm context: pre-load relevant facts from Graphiti on first turn + # Warm context: pre-load relevant facts from Graphiti on first turn. + # Stored here but injected into the user message (not the system prompt) + # after openai_messages is built — keeps system prompt static for caching. + warm_ctx: str | None = None if graphiti_enabled and user_id and len(session.messages) <= 1: from backend.copilot.graphiti.context import fetch_warm_context warm_ctx = await fetch_warm_context(user_id, message or "") - if warm_ctx: - system_prompt += f"\n\n{warm_ctx}" # Compress context if approaching the model's token limit messages_for_context = await _compress_session_messages( @@ -996,6 +1010,20 @@ async def stream_chat_completion_baseline( else: logger.warning("[Baseline] No user message found for context injection") + # Inject Graphiti warm context into the first user message (not the + # system prompt) so the system prompt stays static and cacheable. + # warm_ctx is already wrapped in <temporal_context>. + # Appended AFTER user_context so <user_context> stays at the very start. + if warm_ctx: + for msg in openai_messages: + if msg["role"] == "user": + existing = msg.get("content", "") + if isinstance(existing, str): + msg["content"] = f"{existing}\n\n{warm_ctx}" + break + # Do NOT append warm_ctx to user_message_for_transcript — it would + # persist stale temporal context into the transcript for future turns. + # Append user message to transcript. # Always append when the message is present and is from the user, # even on duplicate-suppressed retries (is_new_message=False). 
@@ -1253,8 +1281,16 @@ async def stream_chat_completion_baseline( if graphiti_enabled and user_id and message and is_user_message: from backend.copilot.graphiti.ingest import enqueue_conversation_turn + # Pass only the final assistant reply (after stripping tool-loop + # chatter) so derived-finding distillation sees the substantive + # response, not intermediate tool-planning text. _ingest_task = asyncio.create_task( - enqueue_conversation_turn(user_id, session_id, message) + enqueue_conversation_turn( + user_id, + session_id, + message, + assistant_msg=final_text if state else "", + ) ) _background_tasks.add(_ingest_task) _ingest_task.add_done_callback(_background_tasks.discard) diff --git a/autogpt_platform/backend/backend/copilot/graphiti/_format.py b/autogpt_platform/backend/backend/copilot/graphiti/_format.py index fb4a93e393..c6975c5c39 100644 --- a/autogpt_platform/backend/backend/copilot/graphiti/_format.py +++ b/autogpt_platform/backend/backend/copilot/graphiti/_format.py @@ -18,15 +18,24 @@ def extract_temporal_validity(edge) -> tuple[str, str]: return str(valid_from), str(valid_to) -def extract_episode_body(episode, max_len: int = 500) -> str: - """Extract the body text from an episode object, truncated to *max_len*.""" - body = str( +def extract_episode_body_raw(episode) -> str: + """Extract the full body text from an episode object (no truncation). + + Use this when the body needs to be parsed as JSON (e.g. scope filtering + on MemoryEnvelope payloads). For display purposes, use + ``extract_episode_body()`` which truncates. 
+ """ + return str( getattr(episode, "content", None) or getattr(episode, "body", None) or getattr(episode, "episode_body", None) or "" ) - return body[:max_len] + + +def extract_episode_body(episode, max_len: int = 500) -> str: + """Extract the body text from an episode object, truncated to *max_len*.""" + return extract_episode_body_raw(episode)[:max_len] def extract_episode_timestamp(episode) -> str: diff --git a/autogpt_platform/backend/backend/copilot/graphiti/client.py b/autogpt_platform/backend/backend/copilot/graphiti/client.py index 9710354915..65fcdb3abb 100644 --- a/autogpt_platform/backend/backend/copilot/graphiti/client.py +++ b/autogpt_platform/backend/backend/copilot/graphiti/client.py @@ -3,6 +3,7 @@ import asyncio import logging import re +import weakref from cachetools import TTLCache @@ -13,8 +14,36 @@ logger = logging.getLogger(__name__) _GROUP_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$") _MAX_GROUP_ID_LEN = 128 -_client_cache: TTLCache | None = None -_cache_lock = asyncio.Lock() + +# Graphiti clients wrap redis.asyncio connections whose internal Futures are +# pinned to the event loop they were first used on. The CoPilot executor runs +# one asyncio loop per worker thread, so a process-wide client cache would +# hand a loop-1-bound connection to a task running on loop 2 → RuntimeError +# "got Future attached to a different loop". Scope the cache (and its lock) +# per running loop so each loop gets its own clients. 
+class _LoopState: + __slots__ = ("cache", "lock") + + def __init__(self) -> None: + self.cache: TTLCache = _EvictingTTLCache( + maxsize=graphiti_config.client_cache_maxsize, + ttl=graphiti_config.client_cache_ttl, + ) + self.lock = asyncio.Lock() + + +_loop_state: "weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, _LoopState]" = ( + weakref.WeakKeyDictionary() +) + + +def _get_loop_state() -> _LoopState: + loop = asyncio.get_running_loop() + state = _loop_state.get(loop) + if state is None: + state = _LoopState() + _loop_state[loop] = state + return state def derive_group_id(user_id: str) -> str: @@ -88,13 +117,8 @@ class _EvictingTTLCache(TTLCache): def _get_cache() -> TTLCache: - global _client_cache - if _client_cache is None: - _client_cache = _EvictingTTLCache( - maxsize=graphiti_config.client_cache_maxsize, - ttl=graphiti_config.client_cache_ttl, - ) - return _client_cache + """Return the client cache for the current running event loop.""" + return _get_loop_state().cache async def get_graphiti_client(group_id: str): @@ -113,9 +137,10 @@ async def get_graphiti_client(group_id: str): from .falkordb_driver import AutoGPTFalkorDriver - cache = _get_cache() + state = _get_loop_state() + cache = state.cache - async with _cache_lock: + async with state.lock: if group_id in cache: return cache[group_id] diff --git a/autogpt_platform/backend/backend/copilot/graphiti/config.py b/autogpt_platform/backend/backend/copilot/graphiti/config.py index 94a452165a..08b533b6fc 100644 --- a/autogpt_platform/backend/backend/copilot/graphiti/config.py +++ b/autogpt_platform/backend/backend/copilot/graphiti/config.py @@ -20,8 +20,10 @@ class GraphitiConfig(BaseSettings): """Configuration for Graphiti memory integration. All fields use the ``GRAPHITI_`` env-var prefix, e.g. ``GRAPHITI_ENABLED``. - LLM/embedder keys fall back to the platform-wide OpenRouter and OpenAI keys - when left empty so that operators don't need to manage separate credentials. 
+ LLM/embedder keys fall back to the AutoPilot-dedicated keys + (``CHAT_API_KEY`` / ``CHAT_OPENAI_API_KEY``) so that memory costs are + tracked under AutoPilot, then to the platform-wide OpenRouter / OpenAI + keys as a last resort. """ model_config = SettingsConfigDict(env_prefix="GRAPHITI_", extra="allow") @@ -42,7 +44,7 @@ class GraphitiConfig(BaseSettings): ) llm_api_key: str = Field( default="", - description="API key for LLM — empty falls back to OPEN_ROUTER_API_KEY", + description="API key for LLM — empty falls back to CHAT_API_KEY, then OPEN_ROUTER_API_KEY", ) # Embedder (separate from LLM — embeddings go direct to OpenAI) @@ -53,7 +55,7 @@ class GraphitiConfig(BaseSettings): ) embedder_api_key: str = Field( default="", - description="API key for embedder — empty falls back to OPENAI_API_KEY", + description="API key for embedder — empty falls back to CHAT_OPENAI_API_KEY, then OPENAI_API_KEY", ) # Concurrency @@ -96,7 +98,9 @@ class GraphitiConfig(BaseSettings): def resolve_llm_api_key(self) -> str: if self.llm_api_key: return self.llm_api_key - return os.getenv("OPEN_ROUTER_API_KEY", "") + # Prefer the AutoPilot-dedicated key so memory costs are tracked + # separately from the platform-wide OpenRouter key. + return os.getenv("CHAT_API_KEY") or os.getenv("OPEN_ROUTER_API_KEY", "") def resolve_llm_base_url(self) -> str: if self.llm_base_url: @@ -106,7 +110,9 @@ class GraphitiConfig(BaseSettings): def resolve_embedder_api_key(self) -> str: if self.embedder_api_key: return self.embedder_api_key - return os.getenv("OPENAI_API_KEY", "") + # Prefer the AutoPilot-dedicated OpenAI key so memory costs are + # tracked separately from the platform-wide OpenAI key. 
+ return os.getenv("CHAT_OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY", "") def resolve_embedder_base_url(self) -> str | None: if self.embedder_base_url: diff --git a/autogpt_platform/backend/backend/copilot/graphiti/config_test.py b/autogpt_platform/backend/backend/copilot/graphiti/config_test.py index 7c7a90d7bc..efe36c8586 100644 --- a/autogpt_platform/backend/backend/copilot/graphiti/config_test.py +++ b/autogpt_platform/backend/backend/copilot/graphiti/config_test.py @@ -8,6 +8,8 @@ _ENV_VARS_TO_CLEAR = ( "GRAPHITI_FALKORDB_HOST", "GRAPHITI_FALKORDB_PORT", "GRAPHITI_FALKORDB_PASSWORD", + "CHAT_API_KEY", + "CHAT_OPENAI_API_KEY", "OPEN_ROUTER_API_KEY", "OPENAI_API_KEY", ) @@ -31,7 +33,15 @@ class TestResolveLlmApiKey: cfg = GraphitiConfig(llm_api_key="my-llm-key") assert cfg.resolve_llm_api_key() == "my-llm-key" - def test_falls_back_to_open_router_env( + def test_falls_back_to_chat_api_key_first( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("CHAT_API_KEY", "autopilot-key") + monkeypatch.setenv("OPEN_ROUTER_API_KEY", "platform-key") + cfg = GraphitiConfig(llm_api_key="") + assert cfg.resolve_llm_api_key() == "autopilot-key" + + def test_falls_back_to_open_router_when_no_chat_key( self, monkeypatch: pytest.MonkeyPatch ) -> None: monkeypatch.setenv("OPEN_ROUTER_API_KEY", "fallback-router-key") @@ -59,7 +69,15 @@ class TestResolveEmbedderApiKey: cfg = GraphitiConfig(embedder_api_key="my-embedder-key") assert cfg.resolve_embedder_api_key() == "my-embedder-key" - def test_falls_back_to_openai_api_key_env( + def test_falls_back_to_chat_openai_api_key_first( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("CHAT_OPENAI_API_KEY", "autopilot-openai-key") + monkeypatch.setenv("OPENAI_API_KEY", "platform-openai-key") + cfg = GraphitiConfig(embedder_api_key="") + assert cfg.resolve_embedder_api_key() == "autopilot-openai-key" + + def test_falls_back_to_openai_when_no_chat_openai_key( self, monkeypatch: 
pytest.MonkeyPatch ) -> None: monkeypatch.setenv("OPENAI_API_KEY", "fallback-openai-key") diff --git a/autogpt_platform/backend/backend/copilot/graphiti/context.py b/autogpt_platform/backend/backend/copilot/graphiti/context.py index 46f9855ab7..29d4e95f47 100644 --- a/autogpt_platform/backend/backend/copilot/graphiti/context.py +++ b/autogpt_platform/backend/backend/copilot/graphiti/context.py @@ -6,6 +6,7 @@ from datetime import datetime, timezone from ._format import ( extract_episode_body, + extract_episode_body_raw, extract_episode_timestamp, extract_fact, extract_temporal_validity, @@ -68,7 +69,7 @@ async def _fetch(user_id: str, message: str) -> str | None: return _format_context(edges, episodes) -def _format_context(edges, episodes) -> str: +def _format_context(edges, episodes) -> str | None: sections: list[str] = [] if edges: @@ -82,12 +83,35 @@ def _format_context(edges, episodes) -> str: if episodes: ep_lines = [] for ep in episodes: + # Use raw body (no truncation) for scope parsing — truncated + # JSON from extract_episode_body() would fail json.loads(). 
+ raw_body = extract_episode_body_raw(ep) + if _is_non_global_scope(raw_body): + continue + display_body = extract_episode_body(ep) ts = extract_episode_timestamp(ep) - body = extract_episode_body(ep) - ep_lines.append(f" - [{ts}] {body}") - sections.append( - "<RECENT_EPISODES>\n" + "\n".join(ep_lines) + "\n</RECENT_EPISODES>" - ) + ep_lines.append(f" - [{ts}] {display_body}") + if ep_lines: + sections.append( + "<RECENT_EPISODES>\n" + "\n".join(ep_lines) + "\n</RECENT_EPISODES>" + ) + + if not sections: + return None body = "\n\n".join(sections) return f"<temporal_context>\n{body}\n</temporal_context>" + + +def _is_non_global_scope(body: str) -> bool: + """Check if an episode body is a MemoryEnvelope with a non-global scope.""" + import json + + try: + data = json.loads(body) + if not isinstance(data, dict): + return False + scope = data.get("scope", "real:global") + return scope != "real:global" + except (json.JSONDecodeError, TypeError): + return False diff --git a/autogpt_platform/backend/backend/copilot/graphiti/context_test.py b/autogpt_platform/backend/backend/copilot/graphiti/context_test.py index 616fefa218..ce419b11ff 100644 --- a/autogpt_platform/backend/backend/copilot/graphiti/context_test.py +++ b/autogpt_platform/backend/backend/copilot/graphiti/context_test.py @@ -1,12 +1,15 @@ """Tests for Graphiti warm context retrieval.""" import asyncio +from types import SimpleNamespace from unittest.mock import AsyncMock, patch import pytest from . 
import context -from .context import fetch_warm_context +from ._format import extract_episode_body +from .context import _format_context, _is_non_global_scope, fetch_warm_context +from .memory_model import MemoryEnvelope, MemoryKind, SourceKind class TestFetchWarmContextEmptyUserId: @@ -52,3 +55,212 @@ class TestFetchWarmContextGeneralError: result = await fetch_warm_context("abc", "hello") assert result is None + + +# --------------------------------------------------------------------------- +# Bug: extract_episode_body() truncation breaks scope filtering +# --------------------------------------------------------------------------- + + +class TestFetchInternal: + """Test the internal _fetch function with mocked graphiti client.""" + + @pytest.mark.asyncio + async def test_returns_none_when_no_edges_or_episodes(self) -> None: + mock_client = AsyncMock() + mock_client.search.return_value = [] + mock_client.retrieve_episodes.return_value = [] + + with ( + patch.object(context, "derive_group_id", return_value="user_abc"), + patch.object( + context, + "get_graphiti_client", + new_callable=AsyncMock, + return_value=mock_client, + ), + ): + result = await context._fetch("test-user", "hello") + + assert result is None + + @pytest.mark.asyncio + async def test_returns_context_with_edges(self) -> None: + edge = SimpleNamespace( + fact="user likes python", + name="preference", + valid_at="2025-01-01", + invalid_at=None, + ) + mock_client = AsyncMock() + mock_client.search.return_value = [edge] + mock_client.retrieve_episodes.return_value = [] + + with ( + patch.object(context, "derive_group_id", return_value="user_abc"), + patch.object( + context, + "get_graphiti_client", + new_callable=AsyncMock, + return_value=mock_client, + ), + ): + result = await context._fetch("test-user", "hello") + + assert result is not None + assert "<temporal_context>" in result + assert "user likes python" in result + + @pytest.mark.asyncio + async def test_returns_context_with_episodes(self) 
-> None: + ep = SimpleNamespace( + content="talked about coffee", + created_at="2025-06-01T00:00:00Z", + ) + mock_client = AsyncMock() + mock_client.search.return_value = [] + mock_client.retrieve_episodes.return_value = [ep] + + with ( + patch.object(context, "derive_group_id", return_value="user_abc"), + patch.object( + context, + "get_graphiti_client", + new_callable=AsyncMock, + return_value=mock_client, + ), + ): + result = await context._fetch("test-user", "hello") + + assert result is not None + assert "talked about coffee" in result + + +class TestFormatContextWithContent: + """Test _format_context with actual edges and episodes.""" + + def test_with_edges_only(self) -> None: + edge = SimpleNamespace( + fact="user likes coffee", + name="preference", + valid_at="2025-01-01", + invalid_at="present", + ) + result = _format_context(edges=[edge], episodes=[]) + assert result is not None + assert "<FACTS>" in result + assert "user likes coffee" in result + assert "<temporal_context>" in result + + def test_with_episodes_only(self) -> None: + ep = SimpleNamespace( + content="plain conversation text", + created_at="2025-01-01T00:00:00Z", + ) + result = _format_context(edges=[], episodes=[ep]) + assert result is not None + assert "<RECENT_EPISODES>" in result + assert "plain conversation text" in result + + def test_with_both_edges_and_episodes(self) -> None: + edge = SimpleNamespace( + fact="user likes coffee", + valid_at="2025-01-01", + invalid_at=None, + ) + ep = SimpleNamespace( + content="talked about coffee", + created_at="2025-06-01T00:00:00Z", + ) + result = _format_context(edges=[edge], episodes=[ep]) + assert result is not None + assert "<FACTS>" in result + assert "<RECENT_EPISODES>" in result + + def test_global_scope_episode_included(self) -> None: + envelope = MemoryEnvelope(content="global note", scope="real:global") + ep = SimpleNamespace( + content=envelope.model_dump_json(), + created_at="2025-01-01T00:00:00Z", + ) + result = 
_format_context(edges=[], episodes=[ep]) + assert result is not None + assert "<RECENT_EPISODES>" in result + + def test_non_global_scope_episode_excluded(self) -> None: + envelope = MemoryEnvelope(content="project note", scope="project:crm") + ep = SimpleNamespace( + content=envelope.model_dump_json(), + created_at="2025-01-01T00:00:00Z", + ) + result = _format_context(edges=[], episodes=[ep]) + assert result is None + + +class TestIsNonGlobalScopeEdgeCases: + """Verify _is_non_global_scope handles non-dict JSON without crashing.""" + + def test_list_json_treated_as_global(self) -> None: + assert _is_non_global_scope("[1, 2, 3]") is False + + def test_string_json_treated_as_global(self) -> None: + assert _is_non_global_scope('"just a string"') is False + + def test_null_json_treated_as_global(self) -> None: + assert _is_non_global_scope("null") is False + + def test_plain_text_treated_as_global(self) -> None: + assert _is_non_global_scope("plain conversation text") is False + + +class TestIsNonGlobalScopeTruncation: + """Verify _is_non_global_scope handles long MemoryEnvelope JSON. + + extract_episode_body() truncates to 500 chars. A MemoryEnvelope with + a long content field serializes to >500 chars, so the truncated string + is invalid JSON. The except clause falls through to return False, + incorrectly treating a project-scoped episode as global. + """ + + def test_long_envelope_with_non_global_scope_detected(self) -> None: + """Long MemoryEnvelope JSON should be parsed with raw (untruncated) body.""" + envelope = MemoryEnvelope( + content="x" * 600, + source_kind=SourceKind.user_asserted, + scope="project:crm", + memory_kind=MemoryKind.fact, + ) + full_json = envelope.model_dump_json() + assert len(full_json) > 500, "precondition: JSON must exceed truncation limit" + + # With the fix: _is_non_global_scope on the raw (untruncated) body + # correctly detects the non-global scope. 
+ assert _is_non_global_scope(full_json) is True + + # Truncated body still fails — that's expected; callers must use raw body. + ep = SimpleNamespace(content=full_json) + truncated = extract_episode_body(ep) + assert _is_non_global_scope(truncated) is False # truncated JSON → parse fails + + +# --------------------------------------------------------------------------- +# Bug: empty <temporal_context> wrapper when all episodes are non-global +# --------------------------------------------------------------------------- + + +class TestFormatContextEmptyWrapper: + """When all episodes are non-global and edges is empty, _format_context + should return None (no useful content) instead of an empty XML wrapper. + """ + + def test_returns_none_when_all_episodes_filtered(self) -> None: + envelope = MemoryEnvelope( + content="project-only note", + scope="project:crm", + ) + ep = SimpleNamespace( + content=envelope.model_dump_json(), + created_at="2025-01-01T00:00:00Z", + ) + result = _format_context(edges=[], episodes=[ep]) + assert result is None diff --git a/autogpt_platform/backend/backend/copilot/graphiti/ingest.py b/autogpt_platform/backend/backend/copilot/graphiti/ingest.py index e36f521a35..58d086e55c 100644 --- a/autogpt_platform/backend/backend/copilot/graphiti/ingest.py +++ b/autogpt_platform/backend/backend/copilot/graphiti/ingest.py @@ -7,17 +7,45 @@ ingestion while keeping it fire-and-forget from the caller's perspective. 
import asyncio import logging +import weakref from datetime import datetime, timezone from graphiti_core.nodes import EpisodeType from .client import derive_group_id, get_graphiti_client +from .memory_model import MemoryEnvelope, MemoryKind, MemoryStatus, SourceKind logger = logging.getLogger(__name__) -_user_queues: dict[str, asyncio.Queue] = {} -_user_workers: dict[str, asyncio.Task] = {} -_workers_lock = asyncio.Lock() + +# The CoPilot executor runs one asyncio loop per worker thread, and +# asyncio.Queue / asyncio.Lock / asyncio.Task are all bound to the loop they +# were first used on. A process-wide worker registry would hand a loop-1-bound +# Queue to a coroutine running on loop 2 → RuntimeError "Future attached to a +# different loop". Scope the registry per running loop so each loop has its +# own queues, workers, and lock. Entries auto-clean when the loop is GC'd. +class _LoopIngestState: + __slots__ = ("user_queues", "user_workers", "workers_lock") + + def __init__(self) -> None: + self.user_queues: dict[str, asyncio.Queue] = {} + self.user_workers: dict[str, asyncio.Task] = {} + self.workers_lock = asyncio.Lock() + + +_loop_state: ( + "weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, _LoopIngestState]" +) = weakref.WeakKeyDictionary() + + +def _get_loop_state() -> _LoopIngestState: + loop = asyncio.get_running_loop() + state = _loop_state.get(loop) + if state is None: + state = _LoopIngestState() + _loop_state[loop] = state + return state + # Idle workers are cleaned up after this many seconds of inactivity. _WORKER_IDLE_TIMEOUT = 60 @@ -37,6 +65,10 @@ async def _ingestion_worker(user_id: str, queue: asyncio.Queue) -> None: Exits after ``_WORKER_IDLE_TIMEOUT`` seconds of inactivity so that idle workers don't leak memory indefinitely. """ + # Snapshot the loop-local state at task start so cleanup always runs + # against the same state dict the worker was registered in, even if the + # worker is cancelled from another task. 
+ state = _get_loop_state() try: while True: try: @@ -63,20 +95,25 @@ async def _ingestion_worker(user_id: str, queue: asyncio.Queue) -> None: raise finally: # Clean up so the next message re-creates the worker. - _user_queues.pop(user_id, None) - _user_workers.pop(user_id, None) + state.user_queues.pop(user_id, None) + state.user_workers.pop(user_id, None) async def enqueue_conversation_turn( user_id: str, session_id: str, user_msg: str, + assistant_msg: str = "", ) -> None: """Enqueue a conversation turn for async background ingestion. This returns almost immediately — the actual graphiti-core ``add_episode()`` call (which triggers LLM entity extraction) runs in a background worker task. + + If ``assistant_msg`` is provided and contains substantive findings + (not just acknowledgments), a separate derived-finding episode is + queued with ``source_kind=assistant_derived`` and ``status=tentative``. """ if not user_id: return @@ -117,6 +154,35 @@ async def enqueue_conversation_turn( "Graphiti ingestion queue full for user %s — dropping episode", user_id[:12], ) + return + + # --- Derived-finding lane --- + # If the assistant response is substantive, distill it into a + # structured finding with tentative status. 
+ if assistant_msg and _is_finding_worthy(assistant_msg): + finding = _distill_finding(assistant_msg) + if finding: + envelope = MemoryEnvelope( + content=finding, + source_kind=SourceKind.assistant_derived, + memory_kind=MemoryKind.finding, + status=MemoryStatus.tentative, + provenance=f"session:{session_id}", + ) + try: + queue.put_nowait( + { + "name": f"finding_{session_id}", + "episode_body": envelope.model_dump_json(), + "source": EpisodeType.json, + "source_description": f"Assistant-derived finding in session {session_id}", + "reference_time": datetime.now(timezone.utc), + "group_id": group_id, + "custom_extraction_instructions": CUSTOM_EXTRACTION_INSTRUCTIONS, + } + ) + except asyncio.QueueFull: + pass # user canonical episode already queued — finding is best-effort async def enqueue_episode( @@ -126,12 +192,18 @@ async def enqueue_episode( name: str, episode_body: str, source_description: str = "Conversation memory", + is_json: bool = False, ) -> bool: """Enqueue an arbitrary episode for background ingestion. Used by ``MemoryStoreTool`` so that explicit memory-store calls go through the same per-user serialization queue as conversation turns. + Args: + is_json: When ``True``, ingest as ``EpisodeType.json`` (for + structured ``MemoryEnvelope`` payloads). Otherwise uses + ``EpisodeType.text``. + Returns ``True`` if the episode was queued, ``False`` if it was dropped. """ if not user_id: @@ -145,12 +217,14 @@ async def enqueue_episode( queue = await _ensure_worker(user_id) + source = EpisodeType.json if is_json else EpisodeType.text + try: queue.put_nowait( { "name": name, "episode_body": episode_body, - "source": EpisodeType.text, + "source": source, "source_description": source_description, "reference_time": datetime.now(timezone.utc), "group_id": group_id, @@ -170,18 +244,19 @@ async def _ensure_worker(user_id: str) -> asyncio.Queue: """Create a queue and worker for *user_id* if one doesn't exist. 
Returns the queue directly so callers don't need to look it up from - ``_user_queues`` (which avoids a TOCTOU race if the worker times out + the state dict (which avoids a TOCTOU race if the worker times out and cleans up between this call and the put_nowait). """ - async with _workers_lock: - if user_id not in _user_queues: + state = _get_loop_state() + async with state.workers_lock: + if user_id not in state.user_queues: q: asyncio.Queue = asyncio.Queue(maxsize=100) - _user_queues[user_id] = q - _user_workers[user_id] = asyncio.create_task( + state.user_queues[user_id] = q + state.user_workers[user_id] = asyncio.create_task( _ingestion_worker(user_id, q), name=f"graphiti-ingest-{user_id[:12]}", ) - return _user_queues[user_id] + return state.user_queues[user_id] async def _resolve_user_name(user_id: str) -> str: @@ -195,3 +270,58 @@ async def _resolve_user_name(user_id: str) -> str: except Exception: logger.debug("Could not resolve user name for %s", user_id[:12]) return "User" + + +# --- Derived-finding distillation --- + +# Phrases that indicate workflow chatter, not substantive findings. +_CHATTER_PREFIXES = ( + "done", + "got it", + "sure, i", + "sure!", + "ok", + "okay", + "i've created", + "i've updated", + "i've sent", + "i'll ", + "let me ", + "a sign-in button", + "please click", +) + +# Minimum length for an assistant message to be considered finding-worthy. +_MIN_FINDING_LENGTH = 150 + + +def _is_finding_worthy(assistant_msg: str) -> bool: + """Heuristic gate: is this assistant response worth distilling into a finding? + + Skips short acknowledgments, workflow chatter, and UI prompts. + Only passes through responses that likely contain substantive + factual content (research results, analysis, conclusions). 
+ """ + if len(assistant_msg) < _MIN_FINDING_LENGTH: + return False + + lower = assistant_msg.lower().strip() + for prefix in _CHATTER_PREFIXES: + if lower.startswith(prefix): + return False + + return True + + +def _distill_finding(assistant_msg: str) -> str | None: + """Extract the core finding from an assistant response. + + For now, uses a simple truncation approach. Phase 3+ could use + a lightweight LLM call for proper distillation. + """ + # Take the first 500 chars as the finding content. + # Strip markdown formatting artifacts. + content = assistant_msg.strip() + if len(content) > 500: + content = content[:500] + "..." + return content if content else None diff --git a/autogpt_platform/backend/backend/copilot/graphiti/ingest_test.py b/autogpt_platform/backend/backend/copilot/graphiti/ingest_test.py index 3aebd283a5..6cb9c5fbaf 100644 --- a/autogpt_platform/backend/backend/copilot/graphiti/ingest_test.py +++ b/autogpt_platform/backend/backend/copilot/graphiti/ingest_test.py @@ -8,21 +8,9 @@ import pytest from . import ingest - -def _clean_module_state() -> None: - """Reset module-level state to avoid cross-test contamination.""" - ingest._user_queues.clear() - ingest._user_workers.clear() - - -@pytest.fixture(autouse=True) -def _reset_state(): - _clean_module_state() - yield - # Cancel any lingering worker tasks. - for task in ingest._user_workers.values(): - task.cancel() - _clean_module_state() +# Per-loop state in ingest.py auto-isolates between tests: pytest-asyncio +# creates a fresh event loop per test function, and the WeakKeyDictionary +# forgets the previous loop's state when it is GC'd. No manual reset needed. class TestIngestionWorkerExceptionHandling: @@ -75,7 +63,7 @@ class TestEnqueueConversationTurn: user_msg="hi", ) # No queue should have been created. 
- assert len(ingest._user_queues) == 0 + assert len(ingest._get_loop_state().user_queues) == 0 class TestQueueFullScenario: @@ -106,7 +94,7 @@ class TestQueueFullScenario: # Replace the queue with one that is already full. tiny_q: asyncio.Queue = asyncio.Queue(maxsize=1) tiny_q.put_nowait({"dummy": True}) - ingest._user_queues[user_id] = tiny_q + ingest._get_loop_state().user_queues[user_id] = tiny_q # Should not raise even though the queue is full. await ingest.enqueue_conversation_turn( @@ -162,6 +150,149 @@ class TestResolveUserName: assert name == "User" +class TestEnqueueEpisode: + @pytest.mark.asyncio + async def test_enqueue_episode_returns_true_on_success(self) -> None: + with ( + patch.object(ingest, "derive_group_id", return_value="user_abc"), + patch.object( + ingest, "_ensure_worker", new_callable=AsyncMock + ) as mock_worker, + ): + q: asyncio.Queue = asyncio.Queue(maxsize=100) + mock_worker.return_value = q + + result = await ingest.enqueue_episode( + user_id="abc", + session_id="sess1", + name="test_ep", + episode_body="hello", + is_json=False, + ) + assert result is True + assert not q.empty() + + @pytest.mark.asyncio + async def test_enqueue_episode_returns_false_for_empty_user(self) -> None: + result = await ingest.enqueue_episode( + user_id="", + session_id="sess1", + name="test_ep", + episode_body="hello", + ) + assert result is False + + @pytest.mark.asyncio + async def test_enqueue_episode_returns_false_on_invalid_user(self) -> None: + with patch.object(ingest, "derive_group_id", side_effect=ValueError("bad id")): + result = await ingest.enqueue_episode( + user_id="bad", + session_id="sess1", + name="test_ep", + episode_body="hello", + ) + assert result is False + + @pytest.mark.asyncio + async def test_enqueue_episode_json_mode(self) -> None: + with ( + patch.object(ingest, "derive_group_id", return_value="user_abc"), + patch.object( + ingest, "_ensure_worker", new_callable=AsyncMock + ) as mock_worker, + ): + q: asyncio.Queue = 
asyncio.Queue(maxsize=100) + mock_worker.return_value = q + + result = await ingest.enqueue_episode( + user_id="abc", + session_id="sess1", + name="test_ep", + episode_body='{"content": "hello"}', + is_json=True, + ) + assert result is True + item = q.get_nowait() + from graphiti_core.nodes import EpisodeType + + assert item["source"] == EpisodeType.json + + +class TestDerivedFindingLane: + @pytest.mark.asyncio + async def test_finding_worthy_message_enqueues_two_episodes(self) -> None: + """A substantive assistant message should enqueue both the user + episode and a derived-finding episode.""" + long_msg = "The analysis reveals significant growth patterns " + "x" * 200 + + with ( + patch.object(ingest, "derive_group_id", return_value="user_abc"), + patch.object( + ingest, "_ensure_worker", new_callable=AsyncMock + ) as mock_worker, + patch( + "backend.copilot.graphiti.ingest._resolve_user_name", + new_callable=AsyncMock, + return_value="Alice", + ), + ): + q: asyncio.Queue = asyncio.Queue(maxsize=100) + mock_worker.return_value = q + + await ingest.enqueue_conversation_turn( + user_id="abc", + session_id="sess1", + user_msg="tell me about growth", + assistant_msg=long_msg, + ) + # Should have 2 items: user episode + derived finding + assert q.qsize() == 2 + + @pytest.mark.asyncio + async def test_short_assistant_msg_skips_finding(self) -> None: + with ( + patch.object(ingest, "derive_group_id", return_value="user_abc"), + patch.object( + ingest, "_ensure_worker", new_callable=AsyncMock + ) as mock_worker, + patch( + "backend.copilot.graphiti.ingest._resolve_user_name", + new_callable=AsyncMock, + return_value="Alice", + ), + ): + q: asyncio.Queue = asyncio.Queue(maxsize=100) + mock_worker.return_value = q + + await ingest.enqueue_conversation_turn( + user_id="abc", + session_id="sess1", + user_msg="hi", + assistant_msg="ok", + ) + # Only 1 item: the user episode (no finding for short msg) + assert q.qsize() == 1 + + +class TestDerivedFindingDistillation: + 
"""_is_finding_worthy and _distill_finding gate derived-finding creation.""" + + def test_short_message_not_finding_worthy(self) -> None: + assert ingest._is_finding_worthy("ok") is False + + def test_chatter_prefix_not_finding_worthy(self) -> None: + assert ingest._is_finding_worthy("done " + "x" * 200) is False + + def test_long_substantive_message_is_finding_worthy(self) -> None: + msg = "The quarterly revenue analysis shows a 15% increase " + "x" * 200 + assert ingest._is_finding_worthy(msg) is True + + def test_distill_finding_truncates_to_500(self) -> None: + result = ingest._distill_finding("x" * 600) + assert result is not None + assert len(result) == 503 # 500 + "..." + + class TestWorkerIdleTimeout: @pytest.mark.asyncio async def test_worker_cleans_up_on_idle(self) -> None: @@ -169,9 +300,10 @@ class TestWorkerIdleTimeout: queue: asyncio.Queue = asyncio.Queue(maxsize=10) # Pre-populate state so cleanup can remove entries. - ingest._user_queues[user_id] = queue + state = ingest._get_loop_state() + state.user_queues[user_id] = queue task_sentinel = MagicMock() - ingest._user_workers[user_id] = task_sentinel + state.user_workers[user_id] = task_sentinel original_timeout = ingest._WORKER_IDLE_TIMEOUT ingest._WORKER_IDLE_TIMEOUT = 0.05 @@ -181,5 +313,5 @@ class TestWorkerIdleTimeout: ingest._WORKER_IDLE_TIMEOUT = original_timeout # After idle timeout the worker should have cleaned up. - assert user_id not in ingest._user_queues - assert user_id not in ingest._user_workers + assert user_id not in state.user_queues + assert user_id not in state.user_workers diff --git a/autogpt_platform/backend/backend/copilot/graphiti/memory_model.py b/autogpt_platform/backend/backend/copilot/graphiti/memory_model.py new file mode 100644 index 0000000000..d8105cb731 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/graphiti/memory_model.py @@ -0,0 +1,118 @@ +"""Generic memory metadata model for Graphiti episodes. 
+ +Domain-agnostic envelope that works across business, fiction, research, +personal life, and arbitrary knowledge domains. Designed so retrieval +can distinguish user-asserted facts from assistant-derived findings +and filter by scope. +""" + +from enum import Enum + +from pydantic import BaseModel, Field + + +class SourceKind(str, Enum): + user_asserted = "user_asserted" + assistant_derived = "assistant_derived" + tool_observed = "tool_observed" + + +class MemoryKind(str, Enum): + fact = "fact" + preference = "preference" + rule = "rule" + finding = "finding" + plan = "plan" + event = "event" + procedure = "procedure" + + +class MemoryStatus(str, Enum): + active = "active" + tentative = "tentative" + superseded = "superseded" + contradicted = "contradicted" + + +class RuleMemory(BaseModel): + """Structured representation of a standing instruction or rule. + + Preserves the exact user intent rather than relying on LLM + extraction to reconstruct it from prose. + """ + + instruction: str = Field( + description="The actionable instruction (e.g. 'CC Sarah on client communications')" + ) + actor: str | None = Field( + default=None, description="Who performs or is subject to the rule" + ) + trigger: str | None = Field( + default=None, + description="When the rule applies (e.g. 'client-related communications')", + ) + negation: str | None = Field( + default=None, + description="What NOT to do, if applicable (e.g. 
'do not use SMTP')", + ) + + +class ProcedureStep(BaseModel): + """A single step in a multi-step procedure.""" + + order: int = Field(description="Step number (1-based)") + action: str = Field(description="What to do in this step") + tool: str | None = Field(default=None, description="Tool or service to use") + condition: str | None = Field(default=None, description="When/if this step applies") + negation: str | None = Field( + default=None, description="What NOT to do in this step" + ) + + +class ProcedureMemory(BaseModel): + """Structured representation of a multi-step workflow. + + Steps with ordering, tools, conditions, and negations that don't + decompose cleanly into fact triples. + """ + + description: str = Field(description="What this procedure accomplishes") + steps: list[ProcedureStep] = Field(default_factory=list) + + +class MemoryEnvelope(BaseModel): + """Structured wrapper for explicit memory storage. + + Serialized as JSON and ingested via ``EpisodeType.json`` so that + Graphiti extracts entities from the ``content`` field while the + metadata fields survive as episode-level context. + + For ``memory_kind=rule``, populate the ``rule`` field with a + ``RuleMemory`` to preserve the exact instruction. For + ``memory_kind=procedure``, populate ``procedure`` with a + ``ProcedureMemory`` for structured steps. 
+ """ + + content: str = Field( + description="The memory content — the actual fact, rule, or finding" + ) + source_kind: SourceKind = Field(default=SourceKind.user_asserted) + scope: str = Field( + default="real:global", + description="Namespace: 'real:global', 'project:<name>', 'book:<title>', 'session:<id>'", + ) + memory_kind: MemoryKind = Field(default=MemoryKind.fact) + status: MemoryStatus = Field(default=MemoryStatus.active) + confidence: float | None = Field(default=None, ge=0.0, le=1.0) + provenance: str | None = Field( + default=None, + description="Origin reference — session_id, tool_call_id, or URL", + ) + rule: RuleMemory | None = Field( + default=None, + description="Structured rule data — populate when memory_kind=rule", + ) + procedure: ProcedureMemory | None = Field( + default=None, + description="Structured procedure data — populate when memory_kind=procedure", + ) diff --git a/autogpt_platform/backend/backend/copilot/permissions.py b/autogpt_platform/backend/backend/copilot/permissions.py index cc01a124c4..a30ee282f7 100644 --- a/autogpt_platform/backend/backend/copilot/permissions.py +++ b/autogpt_platform/backend/backend/copilot/permissions.py @@ -89,6 +89,8 @@ ToolName = Literal[ "get_mcp_guide", "list_folders", "list_workspace_files", + "memory_forget_confirm", + "memory_forget_search", "memory_search", "memory_store", "move_agents_to_folder", diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index d76f2ece80..19f151f008 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -2370,6 +2370,7 @@ async def stream_chat_completion_sdk( turn_cache_creation_tokens = 0 turn_cost_usd: float | None = None graphiti_enabled = False + pre_attempt_msg_count = 0 # Defaults ensure the finally block can always reference these safely even when # an early return (e.g. 
sdk_cwd error) skips their normal assignment below. sdk_model: str | None = None @@ -2788,6 +2789,9 @@ async def stream_chat_completion_sdk( if attachments.hint: query_message = f"{query_message}\n\n{attachments.hint}" + # warm_ctx is injected via inject_user_context above (warm_ctx= kwarg). + # No separate injection needed here. + # When running without --resume and no prior transcript in storage, # seed the transcript builder from compressed DB messages so that # upload_transcript saves a compact version for future turns. @@ -2937,6 +2941,8 @@ async def stream_chat_completion_sdk( ) if attachments.hint: state.query_message = f"{state.query_message}\n\n{attachments.hint}" + # warm_ctx is already baked into current_message via + # inject_user_context — no separate injection needed. state.adapter = SDKResponseAdapter( message_id=message_id, session_id=session_id ) @@ -3341,8 +3347,21 @@ async def stream_chat_completion_sdk( if graphiti_enabled and user_id and message and is_user_message: from ..graphiti.ingest import enqueue_conversation_turn + # Extract last assistant message from THIS TURN only (not all + # session history) to avoid distilling stale content from prior + # turns when the current turn errors before producing output. 
+ _this_turn_msgs = ( + session.messages[pre_attempt_msg_count:] if session else [] + ) + _assistant_msgs = [ + m.content or "" for m in _this_turn_msgs if m.role == "assistant" + ] + _last_assistant = _assistant_msgs[-1] if _assistant_msgs else "" + _ingest_task = asyncio.create_task( - enqueue_conversation_turn(user_id, session_id, message) + enqueue_conversation_turn( + user_id, session_id, message, assistant_msg=_last_assistant + ) ) _background_tasks.add(_ingest_task) _ingest_task.add_done_callback(_background_tasks.discard) diff --git a/autogpt_platform/backend/backend/copilot/tools/__init__.py b/autogpt_platform/backend/backend/copilot/tools/__init__.py index c4913a9411..75a0a8f4e4 100644 --- a/autogpt_platform/backend/backend/copilot/tools/__init__.py +++ b/autogpt_platform/backend/backend/copilot/tools/__init__.py @@ -26,6 +26,7 @@ from .fix_agent import FixAgentGraphTool from .get_agent_building_guide import GetAgentBuildingGuideTool from .get_doc_page import GetDocPageTool from .get_mcp_guide import GetMCPGuideTool +from .graphiti_forget import MemoryForgetConfirmTool, MemoryForgetSearchTool from .graphiti_search import MemorySearchTool from .graphiti_store import MemoryStoreTool from .manage_folders import ( @@ -66,6 +67,8 @@ TOOL_REGISTRY: dict[str, BaseTool] = { "find_block": FindBlockTool(), "find_library_agent": FindLibraryAgentTool(), # Graphiti memory tools + "memory_forget_confirm": MemoryForgetConfirmTool(), + "memory_forget_search": MemoryForgetSearchTool(), "memory_search": MemorySearchTool(), "memory_store": MemoryStoreTool(), # Folder management tools diff --git a/autogpt_platform/backend/backend/copilot/tools/graphiti_forget.py b/autogpt_platform/backend/backend/copilot/tools/graphiti_forget.py new file mode 100644 index 0000000000..c3a30a583e --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/tools/graphiti_forget.py @@ -0,0 +1,349 @@ +"""Two-step tool for targeted memory deletion. 
+ +Step 1 (memory_forget_search): search for matching facts, return candidates. +Step 2 (memory_forget_confirm): delete specific edges by UUID after user confirms. +""" + +import logging +from typing import Any + +from backend.copilot.graphiti._format import extract_fact, extract_temporal_validity +from backend.copilot.graphiti.client import derive_group_id, get_graphiti_client +from backend.copilot.graphiti.config import is_enabled_for_user +from backend.copilot.model import ChatSession + +from .base import BaseTool +from .models import ( + ErrorResponse, + MemoryForgetCandidatesResponse, + MemoryForgetConfirmResponse, + ToolResponseBase, +) + +logger = logging.getLogger(__name__) + + +class MemoryForgetSearchTool(BaseTool): + """Search for memories to forget — returns candidates for user confirmation.""" + + @property + def name(self) -> str: + return "memory_forget_search" + + @property + def description(self) -> str: + return ( + "Search for stored memories matching a description so the user can " + "choose which to delete. Returns candidate facts with UUIDs. " + "Use memory_forget_confirm with the UUIDs to actually delete them." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Natural language description of what to forget (e.g. 
'the Q2 marketing budget')", + }, + }, + "required": ["query"], + } + + @property + def requires_auth(self) -> bool: + return True + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + *, + query: str = "", + **kwargs, + ) -> ToolResponseBase: + if not user_id: + return ErrorResponse( + message="Authentication required.", + session_id=session.session_id, + ) + + if not await is_enabled_for_user(user_id): + return ErrorResponse( + message="Memory features are not enabled for your account.", + session_id=session.session_id, + ) + + if not query: + return ErrorResponse( + message="A search query is required to find memories to forget.", + session_id=session.session_id, + ) + + try: + group_id = derive_group_id(user_id) + except ValueError: + return ErrorResponse( + message="Invalid user ID for memory operations.", + session_id=session.session_id, + ) + + try: + client = await get_graphiti_client(group_id) + edges = await client.search( + query=query, + group_ids=[group_id], + num_results=10, + ) + except Exception: + logger.warning( + "Memory forget search failed for user %s", user_id[:12], exc_info=True + ) + return ErrorResponse( + message="Memory search is temporarily unavailable.", + session_id=session.session_id, + ) + + if not edges: + return MemoryForgetCandidatesResponse( + message="No matching memories found.", + session_id=session.session_id, + candidates=[], + ) + + candidates = [] + for e in edges: + edge_uuid = getattr(e, "uuid", None) or getattr(e, "id", None) + if not edge_uuid: + continue + fact = extract_fact(e) + valid_from, valid_to = extract_temporal_validity(e) + candidates.append( + { + "uuid": str(edge_uuid), + "fact": fact, + "valid_from": str(valid_from), + "valid_to": str(valid_to), + } + ) + + return MemoryForgetCandidatesResponse( + message=f"Found {len(candidates)} candidate(s). 
Show these to the user and ask which to delete, then call memory_forget_confirm with the UUIDs.", + session_id=session.session_id, + candidates=candidates, + ) + + +class MemoryForgetConfirmTool(BaseTool): + """Delete specific memory edges by UUID after user confirmation. + + Supports both soft delete (temporal invalidation — reversible) and + hard delete (remove from graph — irreversible, for GDPR). + """ + + @property + def name(self) -> str: + return "memory_forget_confirm" + + @property + def description(self) -> str: + return ( + "Delete specific memories by UUID. Use after memory_forget_search " + "returns candidates and the user confirms which to delete. " + "Default is soft delete (marks as expired but keeps history). " + "Set hard_delete=true for permanent removal (GDPR)." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "uuids": { + "type": "array", + "items": {"type": "string"}, + "description": "List of edge UUIDs to delete (from memory_forget_search results)", + }, + "hard_delete": { + "type": "boolean", + "description": "If true, permanently removes edges from the graph (GDPR). Default false (soft delete — marks as expired).", + "default": False, + }, + }, + "required": ["uuids"], + } + + @property + def requires_auth(self) -> bool: + return True + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + *, + uuids: list[str] | None = None, + hard_delete: bool = False, + **kwargs, + ) -> ToolResponseBase: + if not user_id: + return ErrorResponse( + message="Authentication required.", + session_id=session.session_id, + ) + + if not await is_enabled_for_user(user_id): + return ErrorResponse( + message="Memory features are not enabled for your account.", + session_id=session.session_id, + ) + + if not uuids: + return ErrorResponse( + message="At least one UUID is required. 
Use memory_forget_search first.", + session_id=session.session_id, + ) + + try: + group_id = derive_group_id(user_id) + except ValueError: + return ErrorResponse( + message="Invalid user ID for memory operations.", + session_id=session.session_id, + ) + + try: + client = await get_graphiti_client(group_id) + except Exception: + logger.warning( + "Failed to get Graphiti client for user %s", user_id[:12], exc_info=True + ) + return ErrorResponse( + message="Memory service is temporarily unavailable.", + session_id=session.session_id, + ) + + driver = getattr(client, "graph_driver", None) or getattr( + client, "driver", None + ) + if not driver: + return ErrorResponse( + message="Could not access graph driver for deletion.", + session_id=session.session_id, + ) + + if hard_delete: + deleted, failed = await _hard_delete_edges(driver, uuids, user_id) + mode = "permanently deleted" + else: + deleted, failed = await _soft_delete_edges(driver, uuids, user_id) + mode = "invalidated" + + return MemoryForgetConfirmResponse( + message=( + f"{len(deleted)} memory edge(s) {mode}." + + (f" {len(failed)} failed." if failed else "") + ), + session_id=session.session_id, + deleted_uuids=deleted, + failed_uuids=failed, + ) + + +async def _soft_delete_edges( + driver, uuids: list[str], user_id: str +) -> tuple[list[str], list[str]]: + """Temporal invalidation — mark edges as expired without removing them. + + Sets ``invalid_at`` and ``expired_at`` to now, which excludes them + from default search results while preserving history. + + Matches the same edge types as ``_hard_delete_edges`` so that edges of + any type (RELATES_TO, MENTIONS, HAS_MEMBER) can be soft-deleted. 
+ """ + deleted = [] + failed = [] + for uuid in uuids: + try: + records, _, _ = await driver.execute_query( + """ + MATCH ()-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->() + SET e.invalid_at = datetime(), + e.expired_at = datetime() + RETURN e.uuid AS uuid + """, + uuid=uuid, + ) + if records: + deleted.append(uuid) + else: + failed.append(uuid) + except Exception: + logger.warning( + "Failed to soft-delete edge %s for user %s", + uuid, + user_id[:12], + exc_info=True, + ) + failed.append(uuid) + return deleted, failed + + +async def _hard_delete_edges( + driver, uuids: list[str], user_id: str +) -> tuple[list[str], list[str]]: + """Permanent removal — delete edges and clean up back-references. + + Uses graphiti's ``Edge.delete()`` pattern (handles MENTIONS, + RELATES_TO, HAS_MEMBER in one query). Does NOT delete orphaned + entity nodes — they may have summaries, embeddings, or future + connections. Cleans up episode ``entity_edges`` back-references. + """ + deleted = [] + failed = [] + for uuid in uuids: + try: + # Use WITH to capture the uuid before DELETE so we don't + # access properties of deleted relationships (FalkorDB #1393). + # Single atomic query avoids TOCTOU between check and delete. + records, _, _ = await driver.execute_query( + """ + MATCH ()-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->() + WITH e.uuid AS uuid, e + DELETE e + RETURN uuid + """, + uuid=uuid, + ) + if not records: + failed.append(uuid) + continue + # Edge was deleted — report success regardless of cleanup outcome. + deleted.append(uuid) + # Clean up episode back-references (best-effort). 
+ try: + await driver.execute_query( + """ + MATCH (ep:Episodic) + WHERE $uuid IN ep.entity_edges + SET ep.entity_edges = [x IN ep.entity_edges WHERE x <> $uuid] + """, + uuid=uuid, + ) + except Exception: + logger.warning( + "Edge %s deleted but back-ref cleanup failed for user %s", + uuid, + user_id[:12], + exc_info=True, + ) + except Exception: + logger.warning( + "Failed to hard-delete edge %s for user %s", + uuid, + user_id[:12], + exc_info=True, + ) + failed.append(uuid) + return deleted, failed diff --git a/autogpt_platform/backend/backend/copilot/tools/graphiti_forget_test.py b/autogpt_platform/backend/backend/copilot/tools/graphiti_forget_test.py new file mode 100644 index 0000000000..94bbeb5d4f --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/tools/graphiti_forget_test.py @@ -0,0 +1,77 @@ +"""Tests for graphiti_forget delete helpers.""" + +from unittest.mock import AsyncMock + +import pytest + +from backend.copilot.tools.graphiti_forget import _hard_delete_edges, _soft_delete_edges + + +class TestSoftDeleteOverReportsSuccess: + """_soft_delete_edges must NOT report success when the Cypher MATCH + found no edge (query succeeds but matches nothing) — report failure. + """ + + @pytest.mark.asyncio + async def test_reports_failure_when_no_edge_matched(self) -> None: + driver = AsyncMock() + # execute_query returns empty result set — no edge matched + driver.execute_query.return_value = ([], None, None) + + deleted, failed = await _soft_delete_edges( + driver, ["nonexistent-uuid"], "test-user" + ) + # Should NOT report success when nothing was actually updated + assert deleted == [], f"over-reported success: {deleted}" + assert failed == ["nonexistent-uuid"] + + +class TestSoftDeleteNoMatchReportsFailure: + """When the query returns empty records (no edge with that UUID exists + in the database), _soft_delete_edges should report it as failed. 
+ """ + + @pytest.mark.asyncio + async def test_soft_delete_handles_non_relates_to_edge(self) -> None: + driver = AsyncMock() + # Simulate: RELATES_TO match returns nothing (edge is MENTIONS type) + driver.execute_query.return_value = ([], None, None) + + deleted, failed = await _soft_delete_edges( + driver, ["mentions-edge-uuid"], "test-user" + ) + # With the bug, this reports success even though nothing was updated + assert "mentions-edge-uuid" not in deleted + + +class TestHardDeleteBasicFlow: + """Verify _hard_delete_edges calls the right queries.""" + + @pytest.mark.asyncio + async def test_hard_delete_calls_both_queries(self) -> None: + driver = AsyncMock() + # First call (delete) returns a matched record, second (cleanup) returns empty + driver.execute_query.side_effect = [ + ([{"uuid": "uuid-1"}], None, None), + ([], None, None), + ] + + deleted, failed = await _hard_delete_edges(driver, ["uuid-1"], "test-user") + assert deleted == ["uuid-1"] + assert failed == [] + # Should call: 1) delete edge, 2) clean episode back-refs + assert driver.execute_query.call_count == 2 + + @pytest.mark.asyncio + async def test_hard_delete_reports_failure_when_no_edge_matched(self) -> None: + driver = AsyncMock() + # Delete query returns no records — edge not found + driver.execute_query.return_value = ([], None, None) + + deleted, failed = await _hard_delete_edges( + driver, ["nonexistent-uuid"], "test-user" + ) + assert deleted == [] + assert failed == ["nonexistent-uuid"] + # Only the delete query should run — cleanup skipped + assert driver.execute_query.call_count == 1 diff --git a/autogpt_platform/backend/backend/copilot/tools/graphiti_search.py b/autogpt_platform/backend/backend/copilot/tools/graphiti_search.py index 27f47a6b29..0aef554bbf 100644 --- a/autogpt_platform/backend/backend/copilot/tools/graphiti_search.py +++ b/autogpt_platform/backend/backend/copilot/tools/graphiti_search.py @@ -7,6 +7,7 @@ from typing import Any from backend.copilot.graphiti._format 
import ( extract_episode_body, + extract_episode_body_raw, extract_episode_timestamp, extract_fact, extract_temporal_validity, @@ -52,6 +53,15 @@ class MemorySearchTool(BaseTool): "description": "Maximum number of results to return", "default": 15, }, + "scope": { + "type": "string", + "description": ( + "Optional scope filter. When set, only memories matching " + "this scope are returned (hard filter). " + "Examples: 'real:global', 'project:crm', 'book:my-novel'. " + "Omit to search all scopes." + ), + }, }, "required": ["query"], } @@ -67,6 +77,7 @@ class MemorySearchTool(BaseTool): *, query: str = "", limit: int = 15, + scope: str = "", **kwargs, ) -> ToolResponseBase: if not user_id: @@ -122,7 +133,14 @@ class MemorySearchTool(BaseTool): ) facts = _format_edges(edges) - recent = _format_episodes(episodes) + + # Scope hard-filter: if a scope was requested, filter episodes + # whose MemoryEnvelope JSON contains a different scope. + # Skip redundant _format_episodes() when scope is set. + if scope: + recent = _filter_episodes_by_scope(episodes, scope) + else: + recent = _format_episodes(episodes) if not facts and not recent: return MemorySearchResponse( @@ -132,9 +150,10 @@ class MemorySearchTool(BaseTool): recent_episodes=[], ) + scope_note = f" (scope filter: {scope})" if scope else "" return MemorySearchResponse( message=( - f"Found {len(facts)} relationship facts and {len(recent)} stored memories. " + f"Found {len(facts)} relationship facts and {len(recent)} stored memories{scope_note}. " "Use BOTH sections to answer — stored memories often contain operational " "rules and instructions that relationship facts summarize." ), @@ -160,3 +179,35 @@ def _format_episodes(episodes) -> list[str]: body = extract_episode_body(ep) results.append(f"[{ts}] {body}") return results + + +def _filter_episodes_by_scope(episodes, scope: str) -> list[str]: + """Filter episodes by scope — hard filter on MemoryEnvelope JSON content. 
+ + Episodes that are plain conversation text (not JSON envelopes) are + included by default since they have no scope metadata and belong + to the implicit ``real:global`` scope. + + Uses ``extract_episode_body_raw`` (no truncation) for JSON parsing + so that long MemoryEnvelope payloads are parsed correctly. + """ + import json + + results = [] + for ep in episodes: + raw_body = extract_episode_body_raw(ep) + try: + data = json.loads(raw_body) + if not isinstance(data, dict): + raise TypeError("non-dict JSON") + ep_scope = data.get("scope", "real:global") + if ep_scope != scope: + continue + except (json.JSONDecodeError, TypeError): + # Not JSON or non-dict JSON — plain conversation episode, treat as real:global + if scope != "real:global": + continue + display_body = extract_episode_body(ep) + ts = extract_episode_timestamp(ep) + results.append(f"[{ts}] {display_body}") + return results diff --git a/autogpt_platform/backend/backend/copilot/tools/graphiti_search_test.py b/autogpt_platform/backend/backend/copilot/tools/graphiti_search_test.py new file mode 100644 index 0000000000..99e2de78ea --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/tools/graphiti_search_test.py @@ -0,0 +1,64 @@ +"""Tests for graphiti_search helper functions.""" + +from types import SimpleNamespace + +from backend.copilot.graphiti.memory_model import MemoryEnvelope, MemoryKind, SourceKind +from backend.copilot.tools.graphiti_search import ( + _filter_episodes_by_scope, + _format_episodes, +) + + +class TestFilterEpisodesByScopeTruncation: + """extract_episode_body() truncates to 500 chars, so a MemoryEnvelope + with a long content field would yield invalid JSON if parsed from it. + _filter_episodes_by_scope must parse the untruncated raw body so such + project-scoped episodes are still filtered out of global results. 
+ """ + + def test_long_envelope_filtered_by_scope(self) -> None: + envelope = MemoryEnvelope( + content="x" * 600, + source_kind=SourceKind.user_asserted, + scope="project:crm", + memory_kind=MemoryKind.fact, + ) + ep = SimpleNamespace( + content=envelope.model_dump_json(), + created_at="2025-01-01T00:00:00Z", + ) + # Requesting real:global scope — this project:crm episode should be excluded + results = _filter_episodes_by_scope([ep], "real:global") + assert ( + results == [] + ), f"project-scoped episode leaked into global results: {results}" + + def test_short_envelope_filtered_correctly(self) -> None: + """Short envelopes (under 500 chars) are parsed correctly.""" + envelope = MemoryEnvelope( + content="short note", + scope="project:crm", + ) + ep = SimpleNamespace( + content=envelope.model_dump_json(), + created_at="2025-01-01T00:00:00Z", + ) + results = _filter_episodes_by_scope([ep], "real:global") + assert results == [] + + +class TestRedundantFormatting: + """_format_episodes is called even when scope filter will overwrite it. + Not a correctness bug, but verify the scope path doesn't depend on it. 
+ """ + + def test_scope_filter_independent_of_format_episodes(self) -> None: + envelope = MemoryEnvelope(content="note", scope="real:global") + ep = SimpleNamespace( + content=envelope.model_dump_json(), + created_at="2025-01-01T00:00:00Z", + ) + from_format = _format_episodes([ep]) + from_scope = _filter_episodes_by_scope([ep], "real:global") + assert len(from_format) == 1 + assert len(from_scope) == 1 diff --git a/autogpt_platform/backend/backend/copilot/tools/graphiti_store.py b/autogpt_platform/backend/backend/copilot/tools/graphiti_store.py index 6e75eb2ed4..3112820e54 100644 --- a/autogpt_platform/backend/backend/copilot/tools/graphiti_store.py +++ b/autogpt_platform/backend/backend/copilot/tools/graphiti_store.py @@ -5,6 +5,15 @@ from typing import Any from backend.copilot.graphiti.config import is_enabled_for_user from backend.copilot.graphiti.ingest import enqueue_episode +from backend.copilot.graphiti.memory_model import ( + MemoryEnvelope, + MemoryKind, + MemoryStatus, + ProcedureMemory, + ProcedureStep, + RuleMemory, + SourceKind, +) from backend.copilot.model import ChatSession from .base import BaseTool @@ -26,7 +35,7 @@ class MemoryStoreTool(BaseTool): "Store a memory or fact about the user for future recall. " "Use when the user shares preferences, business context, decisions, " "relationships, or other important information worth remembering " - "across sessions." + "across sessions. Supports optional metadata for scoping and classification." 
) @property @@ -47,6 +56,94 @@ class MemoryStoreTool(BaseTool): "description": "Context about where this info came from", "default": "Conversation memory", }, + "source_kind": { + "type": "string", + "enum": [e.value for e in SourceKind], + "description": "Who asserted this: user_asserted (default), assistant_derived, or tool_observed", + "default": "user_asserted", + }, + "scope": { + "type": "string", + "description": "Namespace for this memory: 'real:global' (default), 'project:<name>', 'book:<title>'", + "default": "real:global", + }, + "memory_kind": { + "type": "string", + "enum": [e.value for e in MemoryKind], + "description": "Type of memory: fact (default), preference, rule, finding, plan, event, procedure", + "default": "fact", + }, + "rule": { + "type": "object", + "description": ( + "Structured rule data — use when memory_kind=rule to preserve " + "exact operational instructions. Example: " + '{"instruction": "CC Sarah on client communications", ' + '"actor": "Sarah", "trigger": "client-related communications"}' + ), + "properties": { + "instruction": { + "type": "string", + "description": "The actionable instruction", + }, + "actor": { + "type": "string", + "description": "Who performs or is subject to the rule", + }, + "trigger": { + "type": "string", + "description": "When the rule applies", + }, + "negation": { + "type": "string", + "description": "What NOT to do, if applicable", + }, + }, + "required": ["instruction"], + }, + "procedure": { + "type": "object", + "description": ( + "Structured procedure data — use when memory_kind=procedure " + "for multi-step workflows with ordering, tools, and conditions." 
+ ), + "properties": { + "description": { + "type": "string", + "description": "What this procedure accomplishes", + }, + "steps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "order": { + "type": "integer", + "description": "Step number", + }, + "action": { + "type": "string", + "description": "What to do", + }, + "tool": { + "type": "string", + "description": "Tool or service to use", + }, + "condition": { + "type": "string", + "description": "When this step applies", + }, + "negation": { + "type": "string", + "description": "What NOT to do", + }, + }, + "required": ["order", "action"], + }, + }, + }, + "required": ["description", "steps"], + }, }, "required": ["name", "content"], } @@ -63,6 +160,11 @@ class MemoryStoreTool(BaseTool): name: str = "", content: str = "", source_description: str = "Conversation memory", + source_kind: str = "user_asserted", + scope: str = "real:global", + memory_kind: str = "fact", + rule: dict | None = None, + procedure: dict | None = None, **kwargs, ) -> ToolResponseBase: if not user_id: @@ -83,12 +185,53 @@ class MemoryStoreTool(BaseTool): session_id=session.session_id, ) + rule_model = None + if rule and memory_kind == "rule": + try: + rule_model = RuleMemory(**rule) + except Exception: + logger.warning("Invalid rule data, storing as plain fact") + memory_kind = "fact" + + procedure_model = None + if procedure and memory_kind == "procedure": + try: + steps = [ProcedureStep(**s) for s in procedure.get("steps", [])] + procedure_model = ProcedureMemory( + description=procedure.get("description", content), + steps=steps, + ) + except Exception: + logger.warning("Invalid procedure data, storing as plain fact") + memory_kind = "fact" + + try: + resolved_source = SourceKind(source_kind) + except ValueError: + resolved_source = SourceKind.user_asserted + try: + resolved_kind = MemoryKind(memory_kind) + except ValueError: + resolved_kind = MemoryKind.fact + + envelope = MemoryEnvelope( + content=content, + 
source_kind=resolved_source, + scope=scope, + memory_kind=resolved_kind, + status=MemoryStatus.active, + provenance=session.session_id, + rule=rule_model, + procedure=procedure_model, + ) + queued = await enqueue_episode( user_id, session.session_id, name=name, - episode_body=content, + episode_body=envelope.model_dump_json(), source_description=source_description, + is_json=True, ) if not queued: diff --git a/autogpt_platform/backend/backend/copilot/tools/graphiti_store_test.py b/autogpt_platform/backend/backend/copilot/tools/graphiti_store_test.py index 3742355d76..21224d39c0 100644 --- a/autogpt_platform/backend/backend/copilot/tools/graphiti_store_test.py +++ b/autogpt_platform/backend/backend/copilot/tools/graphiti_store_test.py @@ -1,5 +1,6 @@ """Tests for MemoryStoreTool.""" +import json from datetime import UTC, datetime from unittest.mock import AsyncMock, patch @@ -153,13 +154,14 @@ class TestMemoryStoreTool: assert "queued for storage" in result.message assert result.session_id == "test-session" - mock_enqueue.assert_awaited_once_with( - "user-1", - "test-session", - name="user_prefers_python", - episode_body="The user prefers Python over JavaScript.", - source_description="Direct statement", - ) + mock_enqueue.assert_awaited_once() + call_kwargs = mock_enqueue.await_args.kwargs + assert call_kwargs["name"] == "user_prefers_python" + assert call_kwargs["source_description"] == "Direct statement" + assert call_kwargs["is_json"] is True + envelope = json.loads(call_kwargs["episode_body"]) + assert envelope["content"] == "The user prefers Python over JavaScript." 
+ assert envelope["memory_kind"] == "fact" @pytest.mark.asyncio async def test_store_success_uses_default_source_description(self): @@ -187,10 +189,132 @@ class TestMemoryStoreTool: ) assert isinstance(result, MemoryStoreResponse) - mock_enqueue.assert_awaited_once_with( - "user-1", - "test-session", - name="some_fact", - episode_body="A fact worth remembering.", - source_description="Conversation memory", - ) + mock_enqueue.assert_awaited_once() + call_kwargs = mock_enqueue.await_args.kwargs + assert call_kwargs["name"] == "some_fact" + assert call_kwargs["source_description"] == "Conversation memory" + assert call_kwargs["is_json"] is True + envelope = json.loads(call_kwargs["episode_body"]) + assert envelope["content"] == "A fact worth remembering." + + @pytest.mark.asyncio + async def test_store_invalid_source_kind_falls_back(self): + """Invalid enum values should fall back to defaults, not crash.""" + tool = MemoryStoreTool() + session = _make_session() + + mock_enqueue = AsyncMock() + + with ( + patch( + "backend.copilot.tools.graphiti_store.is_enabled_for_user", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "backend.copilot.tools.graphiti_store.enqueue_episode", + mock_enqueue, + ), + ): + result = await tool._execute( + user_id="user-1", + session=session, + name="some_fact", + content="A fact.", + source_kind="INVALID_SOURCE", + memory_kind="INVALID_KIND", + ) + + assert isinstance(result, MemoryStoreResponse) + envelope = json.loads(mock_enqueue.await_args.kwargs["episode_body"]) + assert envelope["source_kind"] == "user_asserted" + assert envelope["memory_kind"] == "fact" + + @pytest.mark.asyncio + async def test_store_valid_enum_values_preserved(self): + tool = MemoryStoreTool() + session = _make_session() + + mock_enqueue = AsyncMock() + + with ( + patch( + "backend.copilot.tools.graphiti_store.is_enabled_for_user", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "backend.copilot.tools.graphiti_store.enqueue_episode", + 
mock_enqueue, + ), + ): + result = await tool._execute( + user_id="user-1", + session=session, + name="rule_1", + content="Always CC Sarah.", + source_kind="user_asserted", + memory_kind="rule", + ) + + assert isinstance(result, MemoryStoreResponse) + envelope = json.loads(mock_enqueue.await_args.kwargs["episode_body"]) + assert envelope["source_kind"] == "user_asserted" + assert envelope["memory_kind"] == "rule" + + @pytest.mark.asyncio + async def test_store_queue_full_returns_error(self): + tool = MemoryStoreTool() + session = _make_session() + + with ( + patch( + "backend.copilot.tools.graphiti_store.is_enabled_for_user", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "backend.copilot.tools.graphiti_store.enqueue_episode", + new_callable=AsyncMock, + return_value=False, + ), + ): + result = await tool._execute( + user_id="user-1", + session=session, + name="pref", + content="likes python", + ) + + assert isinstance(result, ErrorResponse) + assert "queue" in result.message.lower() + + @pytest.mark.asyncio + async def test_store_with_scope(self): + tool = MemoryStoreTool() + session = _make_session() + + mock_enqueue = AsyncMock() + + with ( + patch( + "backend.copilot.tools.graphiti_store.is_enabled_for_user", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "backend.copilot.tools.graphiti_store.enqueue_episode", + mock_enqueue, + ), + ): + result = await tool._execute( + user_id="user-1", + session=session, + name="project_note", + content="CRM uses PostgreSQL.", + scope="project:crm", + ) + + assert isinstance(result, MemoryStoreResponse) + envelope = json.loads(mock_enqueue.await_args.kwargs["episode_body"]) + assert envelope["scope"] == "project:crm" diff --git a/autogpt_platform/backend/backend/copilot/tools/models.py b/autogpt_platform/backend/backend/copilot/tools/models.py index bf211e2da7..90aa3d51db 100644 --- a/autogpt_platform/backend/backend/copilot/tools/models.py +++ 
b/autogpt_platform/backend/backend/copilot/tools/models.py @@ -84,6 +84,8 @@ class ResponseType(str, Enum): # Graphiti memory MEMORY_STORE = "memory_store" MEMORY_SEARCH = "memory_search" + MEMORY_FORGET_CANDIDATES = "memory_forget_candidates" + MEMORY_FORGET_CONFIRM = "memory_forget_confirm" # Base response model @@ -712,3 +714,18 @@ class MemorySearchResponse(ToolResponseBase): type: ResponseType = ResponseType.MEMORY_SEARCH facts: list[str] = Field(default_factory=list) recent_episodes: list[str] = Field(default_factory=list) + + +class MemoryForgetCandidatesResponse(ToolResponseBase): + """Response with candidate memories to forget.""" + + type: ResponseType = ResponseType.MEMORY_FORGET_CANDIDATES + candidates: list[dict[str, str]] = Field(default_factory=list) + + +class MemoryForgetConfirmResponse(ToolResponseBase): + """Response after deleting specific memory edges.""" + + type: ResponseType = ResponseType.MEMORY_FORGET_CONFIRM + deleted_uuids: list[str] = Field(default_factory=list) + failed_uuids: list[str] = Field(default_factory=list) diff --git a/autogpt_platform/backend/backend/util/architecture_test.py b/autogpt_platform/backend/backend/util/architecture_test.py new file mode 100644 index 0000000000..b3cf457911 --- /dev/null +++ b/autogpt_platform/backend/backend/util/architecture_test.py @@ -0,0 +1,134 @@ +""" +Architectural tests for the backend package. + +Each rule here exists to prevent a *class* of bug, not to police style. +When adding a rule, document the incident or failure mode that motivated +it so future maintainers know whether the rule still earns its keep. +""" + +import ast +import pathlib + +BACKEND_ROOT = pathlib.Path(__file__).resolve().parents[1] + + +# --------------------------------------------------------------------------- +# Rule: no process-wide @cached(...) 
around event-loop-bound async clients +# --------------------------------------------------------------------------- +# +# Motivation: `backend.util.cache.cached` stores its result in a process-wide +# dict for ttl_seconds. Async clients (AsyncOpenAI, httpx.AsyncClient, +# AsyncRabbitMQ, supabase AClient, ...) wrap connection pools whose internal +# asyncio primitives lazily bind to the first event loop that uses them. The +# executor runs two long-lived loops on separate threads; once the cache is +# populated from loop A, any subsequent call from loop B raises +# `RuntimeError: ... bound to a different event loop`, surfaced as an opaque +# `APIConnectionError: Connection error.` and poisons the cache for a full +# TTL window. +# +# Use `per_loop_cached` (keyed on id(running loop)) or construct per-call. + +LOOP_BOUND_TYPES = frozenset( + { + "AsyncOpenAI", + "LangfuseAsyncOpenAI", + "AsyncClient", # httpx, openai internal + "AsyncRabbitMQ", + "AClient", # supabase async + "AsyncRedisExecutionEventBus", + } +) + +# Pre-existing offenders tracked for future cleanup. Exclude from this test +# so the rule can still catch NEW violations without blocking unrelated PRs. 
+_KNOWN_OFFENDERS = frozenset( + { + "util/clients.py get_async_supabase", + "util/clients.py get_openai_client", + } +) + + +def _decorator_name(node: ast.expr) -> str | None: + if isinstance(node, ast.Call): + return _decorator_name(node.func) + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Attribute): + return node.attr + return None + + +def _annotation_names(annotation: ast.expr | None) -> set[str]: + if annotation is None: + return set() + if isinstance(annotation, ast.Constant) and isinstance(annotation.value, str): + try: + parsed = ast.parse(annotation.value, mode="eval").body + except SyntaxError: + return set() + return _annotation_names(parsed) + names: set[str] = set() + for child in ast.walk(annotation): + if isinstance(child, ast.Name): + names.add(child.id) + elif isinstance(child, ast.Attribute): + names.add(child.attr) + return names + + +def _iter_backend_py_files(): + for path in BACKEND_ROOT.rglob("*.py"): + if "__pycache__" in path.parts: + continue + yield path + + +def test_known_offenders_use_posix_separators(): + """_KNOWN_OFFENDERS must use forward slashes since the comparison key + is built from pathlib.Path.relative_to() which uses OS-native separators. + On Windows this would be backslashes, causing false positives. + + Ensure the key construction normalises to forward slashes. + """ + for entry in _KNOWN_OFFENDERS: + path_part = entry.split()[0] + assert "\\" not in path_part, ( + f"_KNOWN_OFFENDERS entry uses backslash: {entry!r}. " + "Use forward slashes — the test should normalise Path separators." 
+ ) + + +def test_no_process_cached_loop_bound_clients(): + offenders: list[str] = [] + for py in _iter_backend_py_files(): + try: + tree = ast.parse(py.read_text(encoding="utf-8"), filename=str(py)) + except SyntaxError: + continue + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + decorators = {_decorator_name(d) for d in node.decorator_list} + if "cached" not in decorators: + continue + bound = _annotation_names(node.returns) & LOOP_BOUND_TYPES + if bound: + rel = py.relative_to(BACKEND_ROOT) + key = f"{rel.as_posix()} {node.name}" + if key in _KNOWN_OFFENDERS: + continue + offenders.append( + f"{rel}:{node.lineno} {node.name}() -> {sorted(bound)}" + ) + + assert not offenders, ( + "Process-wide @cached(...) must not wrap functions returning event-" + "loop-bound async clients. These objects lazily bind their connection " + "pool to the first event loop that uses them; caching them across " + "loops poisons the cache and surfaces as opaque connection errors.\n\n" + "Offenders:\n " + "\n ".join(offenders) + "\n\n" + "Fix: construct the client per-call, or introduce a per-loop factory " + "keyed on id(asyncio.get_running_loop()). See " + "backend/util/clients.py::get_openai_client for context." 
+ ) diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index e5ad3bf296..ef775cf92b 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -1346,7 +1346,15 @@ { "$ref": "#/components/schemas/MCPToolsDiscoveredResponse" }, - { "$ref": "#/components/schemas/MCPToolOutputResponse" } + { "$ref": "#/components/schemas/MCPToolOutputResponse" }, + { "$ref": "#/components/schemas/MemoryStoreResponse" }, + { "$ref": "#/components/schemas/MemorySearchResponse" }, + { + "$ref": "#/components/schemas/MemoryForgetCandidatesResponse" + }, + { + "$ref": "#/components/schemas/MemoryForgetConfirmResponse" + } ], "title": "Response Getv2[Dummy] Tool Response Type Export For Codegen" } @@ -11525,6 +11533,103 @@ "title": "MarketplaceListingCreator", "description": "Creator information for a marketplace listing." }, + "MemoryForgetCandidatesResponse": { + "properties": { + "type": { + "$ref": "#/components/schemas/ResponseType", + "default": "memory_forget_candidates" + }, + "message": { "type": "string", "title": "Message" }, + "session_id": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Session Id" + }, + "candidates": { + "items": { + "additionalProperties": { "type": "string" }, + "type": "object" + }, + "type": "array", + "title": "Candidates" + } + }, + "type": "object", + "required": ["message"], + "title": "MemoryForgetCandidatesResponse", + "description": "Response with candidate memories to forget." 
+ }, + "MemoryForgetConfirmResponse": { + "properties": { + "type": { + "$ref": "#/components/schemas/ResponseType", + "default": "memory_forget_confirm" + }, + "message": { "type": "string", "title": "Message" }, + "session_id": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Session Id" + }, + "deleted_uuids": { + "items": { "type": "string" }, + "type": "array", + "title": "Deleted Uuids" + }, + "failed_uuids": { + "items": { "type": "string" }, + "type": "array", + "title": "Failed Uuids" + } + }, + "type": "object", + "required": ["message"], + "title": "MemoryForgetConfirmResponse", + "description": "Response after deleting specific memory edges." + }, + "MemorySearchResponse": { + "properties": { + "type": { + "$ref": "#/components/schemas/ResponseType", + "default": "memory_search" + }, + "message": { "type": "string", "title": "Message" }, + "session_id": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Session Id" + }, + "facts": { + "items": { "type": "string" }, + "type": "array", + "title": "Facts" + }, + "recent_episodes": { + "items": { "type": "string" }, + "type": "array", + "title": "Recent Episodes" + } + }, + "type": "object", + "required": ["message"], + "title": "MemorySearchResponse", + "description": "Response when memories are searched." + }, + "MemoryStoreResponse": { + "properties": { + "type": { + "$ref": "#/components/schemas/ResponseType", + "default": "memory_store" + }, + "message": { "type": "string", "title": "Message" }, + "session_id": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Session Id" + }, + "memory_name": { "type": "string", "title": "Memory Name" } + }, + "type": "object", + "required": ["message", "memory_name"], + "title": "MemoryStoreResponse", + "description": "Response when a memory is stored." 
+ }, "Message": { "properties": { "query": { "type": "string", "title": "Query" }, @@ -12921,7 +13026,9 @@ "feature_request_search", "feature_request_created", "memory_store", - "memory_search" + "memory_search", + "memory_forget_candidates", + "memory_forget_confirm" ], "title": "ResponseType", "description": "Types of tool responses." From 4efa1c4310798727b2949e28e43d4699b81d4098 Mon Sep 17 00:00:00 2001 From: majdyz <zamil.majdy@agpt.co> Date: Wed, 15 Apr 2026 16:31:07 +0700 Subject: [PATCH 03/10] fix(copilot): set session_id on mode-switch T1 to enable --resume on subsequent turns MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a user switches from baseline (fast) mode to SDK (extended_thinking) mode mid-session, the first SDK turn has has_history=True (prior baseline messages in DB) but no CLI session file in storage. The old code gated session_id on `not has_history`, so mode-switch T1 never received a session_id — the CLI generated a random ID that wasn't uploaded under the expected key. Every subsequent SDK turn would fail to restore the CLI session and run without --resume, injecting the full compressed history on each turn, causing model confusion. Fix: set session_id whenever not using --resume (the `else` branch), covering T1 fresh, mode-switch T1, and T2+ fallback turns. The retry path is updated to use `"session_id" in sdk_options_kwargs` as the discriminator (instead of `not has_history`) so mode-switch T1 retries also keep the session_id while T2+ retries (where T1 restored a session file via restore_cli_session) still remove it to avoid "Session ID already in use". 
--- autogpt_platform/backend/backend/copilot/sdk/service.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index 19f151f008..ed27b7c134 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -2915,9 +2915,10 @@ async def stream_chat_completion_sdk( sdk_options_kwargs_retry.pop("resume", None) sdk_options_kwargs_retry["session_id"] = session_id else: - # T2+ retry without --resume: do not pass --session-id. - # The T1 session file already exists at that path; re-using - # the same ID would fail with "Session ID already in use". + # T2+ retry without --resume: initial invocation used + # --resume, which restored the T1 session file to local + # storage. Re-using session_id without --resume would + # fail with "Session ID already in use". sdk_options_kwargs_retry.pop("resume", None) sdk_options_kwargs_retry.pop("session_id", None) # Recompute system_prompt for retry — ctx.use_resume may have From df205b54448a1ad29d7b10db3406a48b9b8cafef Mon Sep 17 00:00:00 2001 From: Zamil Majdy <zamil.majdy@gmail.com> Date: Wed, 15 Apr 2026 23:18:59 +0700 Subject: [PATCH 04/10] fix(backend/copilot): strip CLI session file to prevent auto-compaction context loss The Claude Code CLI auto-compacts its native session JSONL when the context approaches the model's token limit (~200K for Sonnet). After compaction the detailed conversation history is replaced by a ~27K-token summary, causing the silent context loss users see as memory failures in long sessions. 
Root cause identified from production logs for session 93ecf7c9: - T6 CLI session: 233KB / ~207K tokens (near Sonnet limit) - T7 CLI compacted session -> ~167KB / ~47K tokens (PreCompact hook missed) - T12 second compaction -> ~176KB / ~27K tokens (just system prompt + summary) - T14-T21: cache_read=26714 constantly -- only system prompt visible to Claude The same stripping we already apply to our transcript (stale thinking blocks, progress/metadata entries) now also runs on the CLI native session file. At ~2x the size of the stripped transcript, unstripped sessions routinely hit the compaction threshold within 6-10 turns of a heavy Opus/thinking session. After stripping: - same-pod turns reuse the stripped local file (no compaction trigger) - cross-pod turns restore the stripped GCS file (same benefit) --- .../backend/backend/copilot/transcript.py | 28 ++- .../backend/copilot/transcript_test.py | 196 ++++++++++++++++++ 2 files changed, 223 insertions(+), 1 deletion(-) diff --git a/autogpt_platform/backend/backend/copilot/transcript.py b/autogpt_platform/backend/backend/copilot/transcript.py index a1e11f352d..ea1bc2e81c 100644 --- a/autogpt_platform/backend/backend/copilot/transcript.py +++ b/autogpt_platform/backend/backend/copilot/transcript.py @@ -716,7 +716,7 @@ async def upload_cli_session( return try: - content = Path(real_path).read_bytes() + raw_bytes = Path(real_path).read_bytes() except FileNotFoundError: logger.debug( "%s CLI session file not found, skipping upload: %s", @@ -728,6 +728,32 @@ async def upload_cli_session( logger.warning("%s Failed to read CLI session file: %s", log_prefix, e) return + # Strip stale thinking blocks and metadata entries (progress, file-history-snapshot, + # queue-operation) from the CLI session before writing it back locally and uploading + # to GCS. 
Thinking blocks from non-last assistant turns are not needed for --resume + # but can be massive (tens of thousands of tokens each), causing the CLI to auto-compact + # its session when the context window fills up. Stripping keeps the session well below + # the ~200K-token compaction threshold and prevents silent context loss. + try: + raw_text = raw_bytes.decode("utf-8") + stripped_text = strip_for_upload(raw_text) + stripped_bytes = stripped_text.encode("utf-8") + if len(stripped_bytes) < len(raw_bytes): + # Write the stripped version back locally so same-pod turns also benefit. + Path(real_path).write_bytes(stripped_bytes) + logger.info( + "%s Stripped CLI session file: %dB → %dB", + log_prefix, + len(raw_bytes), + len(stripped_bytes), + ) + content = stripped_bytes + except Exception as e: + logger.warning( + "%s Failed to strip CLI session file, uploading raw: %s", log_prefix, e + ) + content = raw_bytes + storage = await get_workspace_storage() wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id) try: diff --git a/autogpt_platform/backend/backend/copilot/transcript_test.py b/autogpt_platform/backend/backend/copilot/transcript_test.py index fec869b6ac..88be88b07a 100644 --- a/autogpt_platform/backend/backend/copilot/transcript_test.py +++ b/autogpt_platform/backend/backend/copilot/transcript_test.py @@ -918,6 +918,202 @@ class TestUploadCliSession: mock_storage.store.assert_not_called() + def test_strips_session_before_upload_and_writes_back(self, tmp_path): + """Strippable entries (progress, thinking blocks) are removed before upload. + + The stripped content is written back to disk (so same-pod turns benefit) + and the smaller bytes are uploaded to GCS. 
+ """ + import asyncio + import os + import re + from unittest.mock import AsyncMock, patch + + from .transcript import _sanitize_id, upload_cli_session + + projects_base = str(tmp_path) + session_id = "12345678-0000-0000-0000-000000000010" + sdk_cwd = str(tmp_path) + + encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) + session_dir = tmp_path / encoded_cwd + session_dir.mkdir(parents=True, exist_ok=True) + session_file = session_dir / f"{_sanitize_id(session_id)}.jsonl" + + # A CLI session with a progress entry (strippable) and a real assistant message. + import json + + progress_entry = { + "type": "progress", + "uuid": "p1", + "parentUuid": "u1", + "data": {"type": "bash_progress", "stdout": "running..."}, + } + user_entry = { + "type": "user", + "uuid": "u1", + "message": {"role": "user", "content": "hello"}, + } + asst_entry = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "u1", + "message": {"role": "assistant", "content": "world"}, + } + raw_content = ( + json.dumps(progress_entry) + + "\n" + + json.dumps(user_entry) + + "\n" + + json.dumps(asst_entry) + + "\n" + ) + raw_bytes = raw_content.encode("utf-8") + session_file.write_bytes(raw_bytes) + + mock_storage = AsyncMock() + + with ( + patch( + "backend.copilot.transcript._projects_base", + return_value=projects_base, + ), + patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ), + ): + asyncio.run( + upload_cli_session( + user_id="user-1", + session_id=session_id, + sdk_cwd=sdk_cwd, + ) + ) + + # Upload should have been called with stripped bytes (no progress entry). 
+ mock_storage.store.assert_called_once() + stored_content: bytes = mock_storage.store.call_args.kwargs["content"] + stored_lines = stored_content.decode("utf-8").strip().split("\n") + stored_types = [json.loads(line).get("type") for line in stored_lines] + assert "progress" not in stored_types + assert "user" in stored_types + assert "assistant" in stored_types + # Stripped bytes should be smaller than raw. + assert len(stored_content) < len(raw_bytes) + # File on disk should also be the stripped version. + disk_content = session_file.read_bytes() + assert disk_content == stored_content + + def test_strips_stale_thinking_blocks_before_upload(self, tmp_path): + """Thinking blocks in non-last assistant turns are stripped to reduce size.""" + import asyncio + import json + import os + import re + from unittest.mock import AsyncMock, patch + + from .transcript import _sanitize_id, upload_cli_session + + projects_base = str(tmp_path) + session_id = "12345678-0000-0000-0000-000000000011" + sdk_cwd = str(tmp_path) + + encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) + session_dir = tmp_path / encoded_cwd + session_dir.mkdir(parents=True, exist_ok=True) + session_file = session_dir / f"{_sanitize_id(session_id)}.jsonl" + + # Two turns: first assistant has thinking block (stale), second doesn't. 
+ u1 = { + "type": "user", + "uuid": "u1", + "message": {"role": "user", "content": "q1"}, + } + a1_with_thinking = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "u1", + "message": { + "id": "msg_a1", + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "A" * 5000}, + {"type": "text", "text": "answer1"}, + ], + }, + } + u2 = { + "type": "user", + "uuid": "u2", + "parentUuid": "a1", + "message": {"role": "user", "content": "q2"}, + } + a2_no_thinking = { + "type": "assistant", + "uuid": "a2", + "parentUuid": "u2", + "message": { + "id": "msg_a2", + "role": "assistant", + "content": [{"type": "text", "text": "answer2"}], + }, + } + raw_content = ( + json.dumps(u1) + + "\n" + + json.dumps(a1_with_thinking) + + "\n" + + json.dumps(u2) + + "\n" + + json.dumps(a2_no_thinking) + + "\n" + ) + raw_bytes = raw_content.encode("utf-8") + session_file.write_bytes(raw_bytes) + + mock_storage = AsyncMock() + + with ( + patch( + "backend.copilot.transcript._projects_base", + return_value=projects_base, + ), + patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ), + ): + asyncio.run( + upload_cli_session( + user_id="user-1", + session_id=session_id, + sdk_cwd=sdk_cwd, + ) + ) + + stored_content: bytes = mock_storage.store.call_args.kwargs["content"] + stored_lines = stored_content.decode("utf-8").strip().split("\n") + + # a1 should have its thinking block stripped (it's not the last assistant turn). + a1_stored = json.loads(stored_lines[1]) + a1_content = a1_stored["message"]["content"] + assert all( + b["type"] != "thinking" for b in a1_content + ), "stale thinking block should be stripped from a1" + assert any( + b["type"] == "text" for b in a1_content + ), "text block should be kept in a1" + + # a2 (last turn) should be unchanged. 
+ a2_stored = json.loads(stored_lines[3]) + assert a2_stored["message"]["content"] == [{"type": "text", "text": "answer2"}] + + # Stripped bytes smaller than raw. + assert len(stored_content) < len(raw_bytes) + class TestRestoreCliSession: def test_returns_false_when_file_not_found_in_storage(self): From fffbe0aad8225ec0bc2641bd56bec69b4b7fca31 Mon Sep 17 00:00:00 2001 From: Nicholas Tindle <nicholas.tindle@agpt.co> Date: Wed, 15 Apr 2026 11:53:30 -0500 Subject: [PATCH 05/10] fix(backend): default copilot sonnet to 4.6 (#12799) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Why / What / How Why: Copilot/Autopilot standard requests were still defaulting to Claude Sonnet 4, while the expected default for this path is Sonnet 4.6. What: This PR updates the backend Copilot defaults so the standard/default path and fast path use Sonnet 4.6, and aligns the SDK fallback model and related test expectations. How: It changes `ChatConfig.model`, `ChatConfig.fast_model`, and `ChatConfig.claude_agent_fallback_model` to Sonnet 4.6 values, then updates backend tests that assert the default Sonnet model strings. 
### Changes 🏗️ - Switch `ChatConfig.model` from `anthropic/claude-sonnet-4` to `anthropic/claude-sonnet-4-6` - Switch `ChatConfig.fast_model` from `anthropic/claude-sonnet-4` to `anthropic/claude-sonnet-4-6` - Switch `ChatConfig.claude_agent_fallback_model` from `claude-sonnet-4-20250514` to `claude-sonnet-4-6` - Update backend Copilot tests that assert the default Sonnet model strings - Configuration changes: - No new environment variables or docker-compose changes are required - Existing `.env.default` and compose files remain compatible because this only changes backend default model values in code ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] `poetry run format` - [x] `poetry run pytest backend/copilot/baseline/transcript_integration_test.py` - [x] `poetry run pytest backend/copilot/sdk/service_helpers_test.py` - [x] `poetry run pytest backend/copilot/sdk/service_test.py` - [x] `poetry run pytest backend/copilot/sdk/p0_guardrails_test.py` <details> <summary>Example test plan</summary> - [ ] Create from scratch and execute an agent with at least 3 blocks - [ ] Import an agent from file upload, and confirm it executes correctly - [ ] Upload agent to marketplace - [ ] Import an agent from marketplace and confirm it executes correctly - [ ] Edit an agent from monitor, and confirm it executes correctly </details> #### For configuration changes: - [x] `.env.default` is updated or already compatible with my changes - [x] `docker-compose.yml` is updated or already compatible with my changes - [x] I have included a list of my configuration changes in the PR description (under **Changes**) <details> <summary>Examples of configuration changes</summary> - Changing ports - Adding new services that need to communicate with each other - Secrets or environment variable changes - New or infrastructure changes such as databases 
</details> <!-- CURSOR_SUMMARY --> --- > [!NOTE] > **Medium Risk** > Changes default/fallback LLM model identifiers for Copilot requests, which can affect runtime behavior, cost, and availability characteristics across both baseline and SDK paths. Risk is mitigated by being a small, config-only change with updated tests. > > **Overview** > Updates Copilot backend defaults so both the standard (`model`) and fast (`fast_model`) paths use `anthropic/claude-sonnet-4-6`, and aligns the Claude Agent SDK fallback model to `claude-sonnet-4-6`. > > Adjusts related test expectations in baseline transcript integration and SDK helper tests to match the new Sonnet 4.6 model strings. > > <sup>Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit 563361ac11d46d36c553e0b62fcfd1fb339e2837. Bugbot is set up for automated code reviews on this repo. Configure [here](https://www.cursor.com/dashboard/bugbot).</sup> <!-- /CURSOR_SUMMARY --> --- .../copilot/baseline/transcript_integration_test.py | 2 +- autogpt_platform/backend/backend/copilot/config.py | 10 +++++----- .../backend/copilot/sdk/service_helpers_test.py | 4 +++- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py index 624abb9acd..baeb3e3648 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py @@ -68,7 +68,7 @@ class TestResolveBaselineModel: assert _resolve_baseline_model(None) == config.model def test_default_and_fast_models_same(self): - """SDK 0.1.58: both tiers now use the same model (anthropic/claude-sonnet-4).""" + """SDK defaults currently keep standard and fast on Sonnet 4.6.""" assert config.model == config.fast_model diff --git a/autogpt_platform/backend/backend/copilot/config.py 
b/autogpt_platform/backend/backend/copilot/config.py index d5418bf872..8792717cad 100644 --- a/autogpt_platform/backend/backend/copilot/config.py +++ b/autogpt_platform/backend/backend/copilot/config.py @@ -29,13 +29,13 @@ class ChatConfig(BaseSettings): # OpenAI API Configuration model: str = Field( - default="anthropic/claude-sonnet-4", + default="anthropic/claude-sonnet-4-6", description="Default model for extended thinking mode. " - "Changed from Opus ($15/$75 per M) to Sonnet ($3/$15 per M) — " - "5x cheaper. Override via CHAT_MODEL env var for Opus.", + "Uses Sonnet 4.6 as the balanced default. " + "Override via CHAT_MODEL env var if you want a different default.", ) fast_model: str = Field( - default="anthropic/claude-sonnet-4", + default="anthropic/claude-sonnet-4-6", description="Model for fast mode (baseline path). Should be faster/cheaper than the default model.", ) title_model: str = Field( @@ -156,7 +156,7 @@ class ChatConfig(BaseSettings): "history compression. Falls back to compression when unavailable.", ) claude_agent_fallback_model: str = Field( - default="claude-sonnet-4-20250514", + default="claude-sonnet-4-6", description="Fallback model when the primary model is unavailable (e.g. 529 " "overloaded). 
The SDK automatically retries with this cheaper model.", ) diff --git a/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py b/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py index 470858dc55..7c5e429697 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py @@ -392,7 +392,9 @@ class TestNormalizeModelName: def test_sonnet_openrouter_model(self): """Sonnet model as stored in config (OpenRouter-prefixed) strips cleanly.""" - assert _normalize_model_name("anthropic/claude-sonnet-4") == "claude-sonnet-4" + assert ( + _normalize_model_name("anthropic/claude-sonnet-4-6") == "claude-sonnet-4-6" + ) # --------------------------------------------------------------------------- From 2740b2be3ad213ce3ea9156228626d8a18d11fe9 Mon Sep 17 00:00:00 2001 From: Zamil Majdy <zamil.majdy@agpt.co> Date: Thu, 16 Apr 2026 01:22:20 +0700 Subject: [PATCH 06/10] fix(backend/copilot): disable fallback model to fix prod CLI rejection (#12802) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Why / What / How **Why:** `fffbe0aad8` changed both `ChatConfig.model` and `ChatConfig.claude_agent_fallback_model` to `claude-sonnet-4-6`. The Claude Code CLI rejects this with `Error: Fallback model cannot be the same as the main model`, causing every standard-mode copilot turn to fail with exit code 1 — the session "completes" in ~30s but produces no response and drops the transcript. **What:** Set `claude_agent_fallback_model` default to `""`. `_resolve_fallback_model()` already returns `None` on empty string, which means the `--fallback-model` flag is simply not passed to the CLI. On 529 overload errors the turn will surface normally instead of silently retrying with a fallback. **How:** One-line config change + test update. 
### Changes 🏗️ - `ChatConfig.claude_agent_fallback_model` default: `"claude-sonnet-4-6"` → `""` - Update `test_fallback_model_default` to assert the empty default ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] `poetry run pytest backend/copilot/sdk/p0_guardrails_test.py` #### For configuration changes: - [x] `.env.default` is updated or already compatible with my changes - [x] `docker-compose.yml` is updated or already compatible with my changes --- autogpt_platform/backend/backend/copilot/config.py | 5 +++-- .../backend/backend/copilot/sdk/p0_guardrails_test.py | 8 +++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/autogpt_platform/backend/backend/copilot/config.py b/autogpt_platform/backend/backend/copilot/config.py index 8792717cad..36644de680 100644 --- a/autogpt_platform/backend/backend/copilot/config.py +++ b/autogpt_platform/backend/backend/copilot/config.py @@ -156,9 +156,10 @@ class ChatConfig(BaseSettings): "history compression. Falls back to compression when unavailable.", ) claude_agent_fallback_model: str = Field( - default="claude-sonnet-4-6", + default="", description="Fallback model when the primary model is unavailable (e.g. 529 " - "overloaded). The SDK automatically retries with this cheaper model.", + "overloaded). The SDK automatically retries with this cheaper model. 
" + "Empty string disables the fallback (no --fallback-model flag passed to CLI).", ) claude_agent_max_turns: int = Field( default=50, diff --git a/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py b/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py index 9305320fea..17b54797b8 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py @@ -86,15 +86,14 @@ class TestResolveFallbackModel: assert result == "claude-sonnet-4.5-20250514" def test_default_value(self): - """Default fallback model resolves to a valid string.""" + """Default fallback model resolves to None (disabled by default).""" cfg = _make_config() with patch(f"{_SVC}.config", cfg): from backend.copilot.sdk.service import _resolve_fallback_model result = _resolve_fallback_model() - assert result is not None - assert "sonnet" in result.lower() or "claude" in result.lower() + assert result is None # --------------------------------------------------------------------------- @@ -198,8 +197,7 @@ class TestConfigDefaults: def test_fallback_model_default(self): cfg = _make_config() - assert cfg.claude_agent_fallback_model - assert "sonnet" in cfg.claude_agent_fallback_model.lower() + assert cfg.claude_agent_fallback_model == "" def test_max_turns_default(self): cfg = _make_config() From bd2efed080fffe52b4f2b38768fedd6665501471 Mon Sep 17 00:00:00 2001 From: chernistry <73943355+chernistry@users.noreply.github.com> Date: Thu, 16 Apr 2026 00:25:07 +0300 Subject: [PATCH 07/10] fix(frontend): allow zooming out more in the builder (#12690) Reduced minZoom on the builder canvas from 0.1 to 0.05 to allow zooming out further when working with large agent graphs. 
Fixes #9325 Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co> --- .../app/(platform)/build/components/FlowEditor/Flow/Flow.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx index 3a55fabf1d..186c8d96fe 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx @@ -110,7 +110,7 @@ export const Flow = () => { event.preventDefault(); }} maxZoom={2} - minZoom={0.1} + minZoom={0.05} onDragOver={onDragOver} onDrop={onDrop} nodesDraggable={!isLocked} From d01a51be0ed2176461a31fb220e6c2221c105a12 Mon Sep 17 00:00:00 2001 From: Toran Bruce Richards <toran.richards@gmail.com> Date: Thu, 16 Apr 2026 06:09:00 +0100 Subject: [PATCH 08/10] Add check for GitHub account connection status (#12807) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added instruction to check GitHub authentication status before prompting user. This prevents repeated, unnecessary asking of the user to add their GitHub credentials when they're already added, which is currently a prevalent bug. ### Changes 🏗️ - Added one line to `autogpt_platform/backend/backend/copilot/prompting.py` instructing AutoPilot to run `gh auth status` before prompting the user to connect their GitHub account. 
Co-authored-by: Toran Bruce Richards <22963551+Torantulino@users.noreply.github.com> --- autogpt_platform/backend/backend/copilot/prompting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/autogpt_platform/backend/backend/copilot/prompting.py b/autogpt_platform/backend/backend/copilot/prompting.py index ec136933e9..ed436733dd 100644 --- a/autogpt_platform/backend/backend/copilot/prompting.py +++ b/autogpt_platform/backend/backend/copilot/prompting.py @@ -174,6 +174,7 @@ sandbox so `bash_exec` can access it for further processing. The exact sandbox path is shown in the `[Sandbox copy available at ...]` note. ### GitHub CLI (`gh`) and git +- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before asking them to connect it. - If the user has connected their GitHub account, both `gh` and `git` are pre-authenticated — use them directly without any manual login step. `git` HTTPS operations (clone, push, pull) work automatically. From 0cd0a76305bfb3bcb604d9b212530cacb8dad3b0 Mon Sep 17 00:00:00 2001 From: Zamil Majdy <zamil.majdy@gmail.com> Date: Thu, 16 Apr 2026 14:58:27 +0700 Subject: [PATCH 09/10] fix(backend/copilot): baseline always uploads when GCS has no transcript MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _load_prior_transcript was returning False for missing/invalid transcripts, which caused should_upload_transcript to suppress the upload. The original intent was to protect against overwriting a *newer* GCS version — but a missing or corrupt file is not 'newer'. Only stale (watermark ahead) and download errors (unknown GCS state) should suppress upload. Also renames transcript_covers_prefix → transcript_upload_safe throughout to accurately describe what the flag means. 
--- .../backend/copilot/baseline/service.py | 40 +++++++++++-------- .../baseline/transcript_integration_test.py | 37 +++++++++-------- 2 files changed, 41 insertions(+), 36 deletions(-) diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index dd6aa121b6..8239d8248e 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -720,16 +720,15 @@ def is_transcript_stale(dl: TranscriptDownload | None, session_msg_count: int) - def should_upload_transcript( - user_id: str | None, transcript_covers_prefix: bool + user_id: str | None, upload_safe: bool ) -> bool: """Return ``True`` when the caller should upload the final transcript. - Uploads require a logged-in user (for the storage key) *and* a - transcript that covered the session prefix when loaded — otherwise - we'd be overwriting a more complete version in storage with a - partial one built from just the current turn. + Uploads require a logged-in user (for the storage key) *and* a safe + upload signal from ``_load_prior_transcript`` — i.e. GCS does not hold a + newer version that we'd be overwriting. """ - return bool(user_id) and transcript_covers_prefix + return bool(user_id) and upload_safe async def _load_prior_transcript( @@ -740,24 +739,30 @@ async def _load_prior_transcript( ) -> bool: """Download and load the prior transcript into ``transcript_builder``. - Returns ``True`` when the loaded transcript fully covers the session - prefix; ``False`` otherwise (stale, missing, invalid, or download - error). Callers should suppress uploads when this returns ``False`` - to avoid overwriting a more complete version in storage. + Returns ``True`` when upload is safe at the end of this turn; ``False`` + when GCS has a *newer* version that we must not overwrite (stale case). 
+ + Upload is suppressed only for **stale** transcripts (GCS watermark > + current turn's prefix) and **download errors** (we can't know what GCS + holds). Missing and invalid transcripts return ``True`` because there is + nothing in GCS worth protecting — uploading is always safe. """ try: dl = await download_transcript(user_id, session_id, log_prefix="[Baseline]") except Exception as e: logger.warning("[Baseline] Transcript download failed: %s", e) + # Unknown GCS state — be conservative and skip upload. return False if dl is None: - logger.debug("[Baseline] No transcript available") - return False + logger.debug("[Baseline] No transcript available — will upload fresh") + # Nothing in GCS to protect; allow upload. + return True if not validate_transcript(dl.content): - logger.warning("[Baseline] Downloaded transcript but invalid") - return False + logger.warning("[Baseline] Downloaded transcript is invalid — will overwrite") + # Corrupt file in GCS; uploading a valid one is strictly better. + return True if is_transcript_stale(dl, session_msg_count): logger.warning( @@ -765,6 +770,7 @@ async def _load_prior_transcript( dl.message_count, session_msg_count, ) + # GCS watermark is ahead of this turn — do not overwrite. return False transcript_builder.load_previous(dl.content, log_prefix="[Baseline]") @@ -897,7 +903,7 @@ async def stream_chat_completion_baseline( # --- Transcript support (feature parity with SDK path) --- transcript_builder = TranscriptBuilder() - transcript_covers_prefix = True + transcript_upload_safe = True # Build system prompt only on the first turn to avoid mid-conversation # changes from concurrent chats updating business understanding. @@ -916,7 +922,7 @@ async def stream_chat_completion_baseline( # on the request critical path. 
if user_id and len(session.messages) > 1: ( - transcript_covers_prefix, + transcript_upload_safe, (base_system_prompt, understanding), ) = await asyncio.gather( _load_prior_transcript( @@ -1308,7 +1314,7 @@ async def stream_chat_completion_baseline( stop_reason=STOP_REASON_END_TURN, ) - if user_id and should_upload_transcript(user_id, transcript_covers_prefix): + if user_id and should_upload_transcript(user_id, transcript_upload_safe): await _upload_final_transcript( user_id=user_id, session_id=session_id, diff --git a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py index baeb3e3648..336f3badbc 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py @@ -119,24 +119,26 @@ class TestLoadPriorTranscript: assert builder.is_empty @pytest.mark.asyncio - async def test_missing_transcript_returns_false(self): + async def test_missing_transcript_allows_upload(self): + """Nothing in GCS → safe to upload fresh transcript after the turn.""" builder = TranscriptBuilder() with patch( "backend.copilot.baseline.service.download_transcript", new=AsyncMock(return_value=None), ): - covers = await _load_prior_transcript( + upload_safe = await _load_prior_transcript( user_id="user-1", session_id="session-1", session_msg_count=2, transcript_builder=builder, ) - assert covers is False + assert upload_safe is True assert builder.is_empty @pytest.mark.asyncio - async def test_invalid_transcript_returns_false(self): + async def test_invalid_transcript_allows_upload(self): + """Corrupt file in GCS → overwriting with valid data is better.""" builder = TranscriptBuilder() download = TranscriptDownload( content='{"type":"progress","uuid":"a"}\n', @@ -146,14 +148,14 @@ class TestLoadPriorTranscript: "backend.copilot.baseline.service.download_transcript", 
new=AsyncMock(return_value=download), ): - covers = await _load_prior_transcript( + upload_safe = await _load_prior_transcript( user_id="user-1", session_id="session-1", session_msg_count=2, transcript_builder=builder, ) - assert covers is False + assert upload_safe is True assert builder.is_empty @pytest.mark.asyncio @@ -560,7 +562,7 @@ class TestTranscriptLifecycle: # --- 3. Gate + upload --- assert ( should_upload_transcript( - user_id="user-1", transcript_covers_prefix=covers + user_id="user-1", upload_safe=covers ) is True ) @@ -611,7 +613,7 @@ class TestTranscriptLifecycle: assert covers is False # The caller's gate mirrors the production path. assert ( - should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers) + should_upload_transcript(user_id="user-1", upload_safe=covers) is False ) upload_mock.assert_not_awaited() @@ -628,14 +630,13 @@ class TestTranscriptLifecycle: ) assert ( - should_upload_transcript(user_id=None, transcript_covers_prefix=True) + should_upload_transcript(user_id=None, upload_safe=True) is False ) @pytest.mark.asyncio async def test_lifecycle_missing_download_still_uploads_new_content(self): - """No prior transcript → covers defaults to True in the service, - new turn should upload cleanly.""" + """No prior transcript → upload is safe; the turn writes the first snapshot.""" builder = TranscriptBuilder() upload_mock = AsyncMock(return_value=None) with ( @@ -648,20 +649,18 @@ class TestTranscriptLifecycle: new=upload_mock, ), ): - covers = await _load_prior_transcript( + upload_safe = await _load_prior_transcript( user_id="user-1", session_id="session-1", session_msg_count=1, transcript_builder=builder, ) - # No download: covers is False, so the production path would - # skip upload. This protects against overwriting a future - # more-complete transcript with a single-turn snapshot. - assert covers is False + # Nothing in GCS → upload is safe so the first baseline turn + # can write the initial snapshot. 
+ assert upload_safe is True assert ( should_upload_transcript( - user_id="user-1", transcript_covers_prefix=covers + user_id="user-1", upload_safe=upload_safe ) - is False + is True ) - upload_mock.assert_not_awaited() From 0d4b31e8a181a7408784ecd00e07987c8375fb1d Mon Sep 17 00:00:00 2001 From: Zamil Majdy <zamil.majdy@agpt.co> Date: Thu, 16 Apr 2026 15:35:18 +0700 Subject: [PATCH 10/10] =?UTF-8?q?refactor(backend/copilot):=20unified=20tr?= =?UTF-8?q?anscript=20context=20=E2=80=94=20extract=5Fcontext=5Fmessages,?= =?UTF-8?q?=20mode-gated=20--resume,=20compaction-aware=20gap-fill=20(#128?= =?UTF-8?q?04)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Why / What / How **Why:** The copilot had two separate GCS paths (`cli-sessions/` and `chat-transcripts/`), redundant function names (`upload_cli_session`/`restore_cli_session`), and no shared context strategy between modes. When switching from baseline→SDK or SDK→baseline, the receiving mode discarded the stored transcript and fell back to full DB reconstruction — loading all raw messages instead of the compacted form — causing inflated context, wasted tokens, and loss of CLI compaction summaries. 
**What:** - Single GCS path (`cli-sessions/`) for both modes — `chat-transcripts/` removed - Unified public API: `upload_transcript` / `download_transcript` / `TranscriptDownload` - `TranscriptMode = Literal["sdk", "baseline"]` persisted in `.meta.json` — SDK skips `--resume` when `mode != "sdk"` (baseline-written JSONL has stripped fields / synthetic IDs) - `extract_context_messages(download, session_messages)` — shared context primitive used by **both SDK and baseline**: reads compacted transcript content + fills only the DB gap (messages after watermark), so CLI compaction summaries are preserved across mode switches - Watermark fix: `_jsonl_covered = transcript_msg_count + 2` when a real transcript is present, preventing false gap detection after `--resume` - Baseline gap-fill: `_append_gap_to_builder` converts `ChatMessage` → JSONL entries; no more silently discarded stale transcripts **How:** ``` SDK turn (mode="sdk" transcript available): ──► --resume [full CLI session restored natively] ──► inject gap prefix if DB has messages after watermark SDK turn (mode="baseline" transcript available): ──► cannot --resume (synthetic CLI IDs) ──► extract_context_messages(download, session_messages): returns transcript JSONL (compacted, isCompactSummary preserved) + gap excludes session_messages[-1] (current turn — caller injects it separately) ──► format as <conversation_history> + "Now, the user says: {current}" Baseline turn (any transcript): ──► _load_prior_transcript → TranscriptDownload ──► extract_context_messages(download, session_messages) + session_messages[-1] replaces full session.messages DB read ──► LLM messages: [compacted history + gap] + [current user turn] Transcript unavailable — both SDK (use_resume=False) and baseline: ──► extract_context_messages(None, session_messages) returns session_messages[:-1] (all prior DB messages except the current user turn at [-1]) ──► graceful fallback — no crash, no empty context ──► covers: first turn, GCS error, 
corrupt JSONL, missing .meta.json ──► next successful response uploads a fresh transcript ``` `extract_context_messages` is the shared primitive — both modes call the same function, which handles: - `download=None` (first turn, GCS unavailable) → falls back to `session_messages[:-1]` - Empty/corrupt content → falls back to `session_messages[:-1]` - `bytes` content (raw GCS) or `str` content (pre-decoded baseline path) - `isCompactSummary=True` entries → preserved so CLI compaction survives mode switches - Missing/corrupt `.meta.json` → `message_count` defaults to `0`, `mode` defaults to `"sdk"` **Why `[:-1]` and not all messages?** `session_messages[-1]` is always the current user turn being handled right now. Both callers inject it separately — SDK wraps it as `"Now, the user says: ..."`, baseline appends it as the final message in the LLM array. Returning it inside `extract_context_messages` would double-inject it. ### Changes 🏗️ - **`transcript.py`**: `CliSessionRestore` → `TranscriptDownload` + `mode` field; `upload_cli_session` → `upload_transcript`; `restore_cli_session` → `download_transcript`; add `TranscriptMode`, `detect_gap`, `extract_context_messages`; import `ChatMessage` via relative path to match `service.py` style - **`sdk/service.py`**: mode-check before `--resume`; `_RestoreResult` carries `baseline_download` + `context_messages` + `transcript_content`; `_build_query_message` accepts `prior_messages` override; `_restore_cli_session_for_turn` populates `context_messages` via `extract_context_messages` and sets `transcript_content` to prevent duplicate DB reconstruction; watermark fix (`_jsonl_covered = transcript_msg_count + 2`) - **`baseline/service.py`**: `_load_prior_transcript` returns `(bool, TranscriptDownload | None)`; LLM context replaced with `extract_context_messages(download, messages)`; `_append_gap_to_builder` + `detect_gap` call; `upload_transcript(mode="baseline")` - **`sdk/transcript.py`**: updated re-exports, old aliases removed - 
**`scripts/download_transcripts.py`**: updated for `bytes | str` content type - **Test files**: 179 tests total; `transcript_test.py`, `baseline/transcript_integration_test.py`, `sdk/service_helpers_test.py`, `sdk/test_transcript_watermark.py`, `test/copilot/test_transcript_watermark.py` all updated/added ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] 179 unit tests pass — `transcript_test`, `baseline/transcript_integration_test`, `sdk/service_helpers_test`, `sdk/test_transcript_watermark` - [x] pyright 0 errors on all changed files - [x] SDK `--resume` path still works when `mode="sdk"` transcript is present - [x] SDK fallback uses `extract_context_messages` (compacted baseline content + gap) when `mode="baseline"` transcript is stored — no more full DB reconstruction - [x] Baseline uses `extract_context_messages` per turn instead of full `session.messages` DB read - [x] `isCompactSummary=True` entries preserved across mode switches - [x] Watermark (`_jsonl_covered`) fix prevents false gap detection after `--resume` - [x] Baseline gap detection no longer silently discards stale transcripts - [x] `TranscriptDownload.content` accepts `bytes | str` — backward compatible - [x] Transcript unavailable (GCS error, first turn, corrupt file) gracefully falls back to `session_messages[:-1]` without crash — applies to both SDK and baseline paths --------- Co-authored-by: chernistry <73943355+chernistry@users.noreply.github.com> Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co> --- .../backend/copilot/baseline/service.py | 210 +++-- .../baseline/transcript_integration_test.py | 329 ++++--- .../backend/backend/copilot/context.py | 2 +- .../copilot/sdk/mode_switch_context_test.py | 57 +- .../copilot/sdk/retry_scenarios_test.py | 25 +- .../backend/copilot/sdk/security_hooks.py | 2 +- .../backend/backend/copilot/sdk/service.py 
| 642 ++++++++----- .../copilot/sdk/service_helpers_test.py | 338 +++++++ .../copilot/sdk/test_transcript_watermark.py | 95 ++ .../backend/backend/copilot/sdk/transcript.py | 16 +- .../backend/copilot/sdk/transcript_test.py | 197 +++- .../backend/backend/copilot/service_test.py | 15 +- .../backend/backend/copilot/transcript.py | 512 +++++------ .../backend/copilot/transcript_test.py | 855 +++++++++++------- .../backend/scripts/download_transcripts.py | 14 +- .../test/copilot/test_transcript_watermark.py | 140 +++ 16 files changed, 2396 insertions(+), 1053 deletions(-) create mode 100644 autogpt_platform/backend/backend/copilot/sdk/test_transcript_watermark.py create mode 100644 autogpt_platform/backend/test/copilot/test_transcript_watermark.py diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index dd6aa121b6..a2813ad881 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -67,11 +67,15 @@ from backend.copilot.transcript import ( STOP_REASON_END_TURN, STOP_REASON_TOOL_USE, TranscriptDownload, + detect_gap, download_transcript, + extract_context_messages, + strip_for_upload, upload_transcript, validate_transcript, ) from backend.copilot.transcript_builder import TranscriptBuilder +from backend.util import json as util_json from backend.util.exceptions import NotFoundError from backend.util.prompt import ( compress_context, @@ -699,81 +703,147 @@ async def _compress_session_messages( return messages -def is_transcript_stale(dl: TranscriptDownload | None, session_msg_count: int) -> bool: - """Return ``True`` when a download doesn't cover the current session. - - A transcript is stale when it has a known ``message_count`` and that - count doesn't reach ``session_msg_count - 1`` (i.e. the session has - already advanced beyond what the stored transcript captures). 
- Loading a stale transcript would silently drop intermediate turns, - so callers should treat stale as "skip load, skip upload". - - An unknown ``message_count`` (``0``) is treated as **not stale** - because older transcripts uploaded before msg_count tracking - existed must still be usable. - """ - if dl is None: - return False - if not dl.message_count: - return False - return dl.message_count < session_msg_count - 1 - - -def should_upload_transcript( - user_id: str | None, transcript_covers_prefix: bool -) -> bool: +def should_upload_transcript(user_id: str | None, upload_safe: bool) -> bool: """Return ``True`` when the caller should upload the final transcript. - Uploads require a logged-in user (for the storage key) *and* a - transcript that covered the session prefix when loaded — otherwise - we'd be overwriting a more complete version in storage with a - partial one built from just the current turn. + Uploads require a logged-in user (for the storage key) *and* a safe + upload signal from ``_load_prior_transcript`` — i.e. GCS does not hold a + newer version that we'd be overwriting. """ - return bool(user_id) and transcript_covers_prefix + return bool(user_id) and upload_safe + + +def _append_gap_to_builder( + gap: list[ChatMessage], + builder: TranscriptBuilder, +) -> None: + """Append gap messages from chat-db into the TranscriptBuilder. + + Converts ChatMessage (OpenAI format) to TranscriptBuilder entries + (Claude CLI JSONL format) so the uploaded transcript covers all turns. + + Pre-condition: ``gap`` always starts at a user or assistant boundary + (never mid-turn at a ``tool`` role), because ``detect_gap`` enforces + ``session_messages[wm-1].role == 'assistant'`` before returning a non-empty + gap. Any ``tool`` role messages within the gap always follow an assistant + entry that already exists in the builder or in the gap itself. 
+ """ + for msg in gap: + if msg.role == "user": + builder.append_user(msg.content or "") + elif msg.role == "assistant": + content_blocks: list[dict] = [] + if msg.content: + content_blocks.append({"type": "text", "text": msg.content}) + if msg.tool_calls: + for tc in msg.tool_calls: + fn = tc.get("function", {}) if isinstance(tc, dict) else {} + input_data = util_json.loads(fn.get("arguments", "{}"), fallback={}) + content_blocks.append( + { + "type": "tool_use", + "id": tc.get("id", "") if isinstance(tc, dict) else "", + "name": fn.get("name", "unknown"), + "input": input_data, + } + ) + if not content_blocks: + # Fallback: ensure every assistant gap message produces an entry + # so the builder's entry count matches the gap length. + content_blocks.append({"type": "text", "text": ""}) + builder.append_assistant(content_blocks=content_blocks) + elif msg.role == "tool": + if msg.tool_call_id: + builder.append_tool_result( + tool_use_id=msg.tool_call_id, + content=msg.content or "", + ) + else: + # Malformed tool message — no tool_call_id to link to an + # assistant tool_use block. Skip to avoid an unmatched + # tool_result entry in the builder (which would confuse --resume). + logger.warning( + "[Baseline] Skipping tool gap message with no tool_call_id" + ) async def _load_prior_transcript( user_id: str, session_id: str, - session_msg_count: int, + session_messages: list[ChatMessage], transcript_builder: TranscriptBuilder, -) -> bool: - """Download and load the prior transcript into ``transcript_builder``. +) -> tuple[bool, "TranscriptDownload | None"]: + """Download and load the prior CLI session into ``transcript_builder``. - Returns ``True`` when the loaded transcript fully covers the session - prefix; ``False`` otherwise (stale, missing, invalid, or download - error). Callers should suppress uploads when this returns ``False`` - to avoid overwriting a more complete version in storage. 
+ Returns a tuple of (upload_safe, transcript_download): + - ``upload_safe`` is ``True`` when it is safe to upload at the end of this + turn. Upload is suppressed only for **download errors** (unknown GCS + state) — missing and invalid files return ``True`` because there is + nothing in GCS worth protecting against overwriting. + - ``transcript_download`` is a ``TranscriptDownload`` with str content + (pre-decoded and stripped) when available, or ``None`` when no valid + transcript could be loaded. Callers pass this to + ``extract_context_messages`` to build the LLM context. """ try: - dl = await download_transcript(user_id, session_id, log_prefix="[Baseline]") - except Exception as e: - logger.warning("[Baseline] Transcript download failed: %s", e) - return False - - if dl is None: - logger.debug("[Baseline] No transcript available") - return False - - if not validate_transcript(dl.content): - logger.warning("[Baseline] Downloaded transcript but invalid") - return False - - if is_transcript_stale(dl, session_msg_count): - logger.warning( - "[Baseline] Transcript stale: covers %d of %d messages, skipping", - dl.message_count, - session_msg_count, + restore = await download_transcript( + user_id, session_id, log_prefix="[Baseline]" ) - return False + except Exception as e: + logger.warning("[Baseline] Session restore failed: %s", e) + # Unknown GCS state — be conservative, skip upload. + return False, None - transcript_builder.load_previous(dl.content, log_prefix="[Baseline]") + if restore is None: + logger.debug("[Baseline] No CLI session available — will upload fresh") + # Nothing in GCS to protect; allow upload so the first baseline turn + # writes the initial transcript snapshot. 
+ return True, None + + content_bytes = restore.content + try: + raw_str = ( + content_bytes.decode("utf-8") + if isinstance(content_bytes, bytes) + else content_bytes + ) + except UnicodeDecodeError: + logger.warning("[Baseline] CLI session content is not valid UTF-8") + # Corrupt file in GCS; overwriting with a valid one is better. + return True, None + + stripped = strip_for_upload(raw_str) + if not validate_transcript(stripped): + logger.warning("[Baseline] CLI session content invalid after strip") + # Corrupt file in GCS; overwriting with a valid one is better. + return True, None + + transcript_builder.load_previous(stripped, log_prefix="[Baseline]") logger.info( - "[Baseline] Loaded transcript: %dB, msg_count=%d", - len(dl.content), - dl.message_count, + "[Baseline] Loaded CLI session: %dB, msg_count=%d", + len(content_bytes) if isinstance(content_bytes, bytes) else len(raw_str), + restore.message_count, ) - return True + + gap = detect_gap(restore, session_messages) + if gap: + _append_gap_to_builder(gap, transcript_builder) + logger.info( + "[Baseline] Filled gap: loaded %d transcript msgs + %d gap msgs from DB", + restore.message_count, + len(gap), + ) + + # Return a str-content version so extract_context_messages receives a + # pre-decoded, stripped transcript (avoids redundant decode + strip). + # TranscriptDownload.content is typed as bytes | str; we pass str here + # to avoid a redundant encode + decode round-trip. 
+ str_restore = TranscriptDownload( + content=stripped, + message_count=restore.message_count, + mode=restore.mode, + ) + return True, str_restore async def _upload_final_transcript( @@ -807,10 +877,10 @@ async def _upload_final_transcript( upload_transcript( user_id=user_id, session_id=session_id, - content=content, + content=content.encode("utf-8"), message_count=session_msg_count, + mode="baseline", log_prefix="[Baseline]", - skip_strip=True, ) ) _background_tasks.add(upload_task) @@ -897,7 +967,7 @@ async def stream_chat_completion_baseline( # --- Transcript support (feature parity with SDK path) --- transcript_builder = TranscriptBuilder() - transcript_covers_prefix = True + transcript_upload_safe = True # Build system prompt only on the first turn to avoid mid-conversation # changes from concurrent chats updating business understanding. @@ -914,15 +984,16 @@ async def stream_chat_completion_baseline( # Run download + prompt build concurrently — both are independent I/O # on the request critical path. + transcript_download: TranscriptDownload | None = None if user_id and len(session.messages) > 1: ( - transcript_covers_prefix, + (transcript_upload_safe, transcript_download), (base_system_prompt, understanding), ) = await asyncio.gather( _load_prior_transcript( user_id=user_id, session_id=session_id, - session_msg_count=len(session.messages), + session_messages=session.messages, transcript_builder=transcript_builder, ), prompt_task, @@ -962,9 +1033,14 @@ async def stream_chat_completion_baseline( warm_ctx = await fetch_warm_context(user_id, message or "") - # Compress context if approaching the model's token limit + # Context path: transcript content (compacted, isCompactSummary preserved) + + # gap (DB messages after watermark) + current user turn. + # This avoids re-reading the full session history from DB on every turn. + # See extract_context_messages() in transcript.py for the shared primitive. 
+ prior_context = extract_context_messages(transcript_download, session.messages) messages_for_context = await _compress_session_messages( - session.messages, model=active_model + prior_context + ([session.messages[-1]] if session.messages else []), + model=active_model, ) # Build OpenAI message list from session history. @@ -1308,7 +1384,7 @@ async def stream_chat_completion_baseline( stop_reason=STOP_REASON_END_TURN, ) - if user_id and should_upload_transcript(user_id, transcript_covers_prefix): + if user_id and should_upload_transcript(user_id, transcript_upload_safe): await _upload_final_transcript( user_id=user_id, session_id=session_id, diff --git a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py index baeb3e3648..4247c76c19 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py @@ -1,7 +1,7 @@ """Integration tests for baseline transcript flow. -Exercises the real helpers in ``baseline/service.py`` that download, -validate, load, append to, backfill, and upload the transcript. +Exercises the real helpers in ``baseline/service.py`` that restore, +validate, load, append to, backfill, and upload the CLI session. Storage is mocked via ``download_transcript`` / ``upload_transcript`` patches; no network access is required. 
""" @@ -12,13 +12,14 @@ from unittest.mock import AsyncMock, patch import pytest from backend.copilot.baseline.service import ( + _append_gap_to_builder, _load_prior_transcript, _record_turn_to_transcript, _resolve_baseline_model, _upload_final_transcript, - is_transcript_stale, should_upload_transcript, ) +from backend.copilot.model import ChatMessage from backend.copilot.service import config from backend.copilot.transcript import ( STOP_REASON_END_TURN, @@ -54,6 +55,13 @@ def _make_transcript_content(*roles: str) -> str: return "\n".join(lines) + "\n" +def _make_session_messages(*roles: str) -> list[ChatMessage]: + """Build a list of ChatMessage objects matching the given roles.""" + return [ + ChatMessage(role=r, content=f"{r} message {i}") for i, r in enumerate(roles) + ] + + class TestResolveBaselineModel: """Model selection honours the per-request mode.""" @@ -73,87 +81,102 @@ class TestResolveBaselineModel: class TestLoadPriorTranscript: - """``_load_prior_transcript`` wraps the download + validate + load flow.""" + """``_load_prior_transcript`` wraps the CLI session restore + validate + load flow.""" @pytest.mark.asyncio async def test_loads_fresh_transcript(self): builder = TranscriptBuilder() content = _make_transcript_content("user", "assistant") - download = TranscriptDownload(content=content, message_count=2) + restore = TranscriptDownload( + content=content.encode("utf-8"), message_count=2, mode="sdk" + ) with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=3, + session_messages=_make_session_messages("user", "assistant", "user"), transcript_builder=builder, ) assert covers is True + assert dl is not None + assert dl.message_count == 2 assert builder.entry_count == 2 assert builder.last_entry_type == 
"assistant" @pytest.mark.asyncio - async def test_rejects_stale_transcript(self): - """msg_count strictly less than session-1 is treated as stale.""" + async def test_fills_gap_when_transcript_is_behind(self): + """When transcript covers fewer messages than session, gap is filled from DB.""" builder = TranscriptBuilder() content = _make_transcript_content("user", "assistant") - # session has 6 messages, transcript only covers 2 → stale. - download = TranscriptDownload(content=content, message_count=2) + # transcript covers 2 messages, session has 4 (plus current user turn = 5) + restore = TranscriptDownload( + content=content.encode("utf-8"), message_count=2, mode="baseline" + ) with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=6, + session_messages=_make_session_messages( + "user", "assistant", "user", "assistant", "user" + ), transcript_builder=builder, ) - assert covers is False - assert builder.is_empty + assert covers is True + assert dl is not None + # 2 from transcript + 2 gap messages (user+assistant at positions 2,3) + assert builder.entry_count == 4 @pytest.mark.asyncio - async def test_missing_transcript_returns_false(self): + async def test_missing_transcript_allows_upload(self): + """Nothing in GCS → upload is safe; the turn writes the first snapshot.""" builder = TranscriptBuilder() with patch( "backend.copilot.baseline.service.download_transcript", new=AsyncMock(return_value=None), ): - covers = await _load_prior_transcript( + upload_safe, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=2, + session_messages=_make_session_messages("user", "assistant"), transcript_builder=builder, ) - assert covers is False + assert upload_safe is True + assert dl is None 
assert builder.is_empty @pytest.mark.asyncio - async def test_invalid_transcript_returns_false(self): + async def test_invalid_transcript_allows_upload(self): + """Corrupt file in GCS → overwriting with a valid one is better.""" builder = TranscriptBuilder() - download = TranscriptDownload( - content='{"type":"progress","uuid":"a"}\n', + restore = TranscriptDownload( + content=b'{"type":"progress","uuid":"a"}\n', message_count=1, + mode="sdk", ) with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + upload_safe, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=2, + session_messages=_make_session_messages("user", "assistant"), transcript_builder=builder, ) - assert covers is False + assert upload_safe is True + assert dl is None assert builder.is_empty @pytest.mark.asyncio @@ -163,36 +186,39 @@ class TestLoadPriorTranscript: "backend.copilot.baseline.service.download_transcript", new=AsyncMock(side_effect=RuntimeError("boom")), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=2, + session_messages=_make_session_messages("user", "assistant"), transcript_builder=builder, ) assert covers is False + assert dl is None assert builder.is_empty @pytest.mark.asyncio async def test_zero_message_count_not_stale(self): - """When msg_count is 0 (unknown), staleness check is skipped.""" + """When msg_count is 0 (unknown), gap detection is skipped.""" builder = TranscriptBuilder() - download = TranscriptDownload( - content=_make_transcript_content("user", "assistant"), + restore = TranscriptDownload( + content=_make_transcript_content("user", "assistant").encode("utf-8"), message_count=0, + mode="sdk", ) with patch( "backend.copilot.baseline.service.download_transcript", - 
new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=20, + session_messages=_make_session_messages(*["user"] * 20), transcript_builder=builder, ) assert covers is True + assert dl is not None assert builder.entry_count == 2 @@ -227,7 +253,7 @@ class TestUploadFinalTranscript: assert call_kwargs["user_id"] == "user-1" assert call_kwargs["session_id"] == "session-1" assert call_kwargs["message_count"] == 2 - assert "hello" in call_kwargs["content"] + assert b"hello" in call_kwargs["content"] @pytest.mark.asyncio async def test_skips_upload_when_builder_empty(self): @@ -374,17 +400,19 @@ class TestRoundTrip: @pytest.mark.asyncio async def test_full_round_trip(self): prior = _make_transcript_content("user", "assistant") - download = TranscriptDownload(content=prior, message_count=2) + restore = TranscriptDownload( + content=prior.encode("utf-8"), message_count=2, mode="sdk" + ) builder = TranscriptBuilder() with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, _ = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=3, + session_messages=_make_session_messages("user", "assistant", "user"), transcript_builder=builder, ) assert covers is True @@ -424,11 +452,11 @@ class TestRoundTrip: upload_mock.assert_awaited_once() assert upload_mock.await_args is not None uploaded = upload_mock.await_args.kwargs["content"] - assert "new question" in uploaded - assert "new answer" in uploaded + assert b"new question" in uploaded + assert b"new answer" in uploaded # Original content preserved in the round trip. 
- assert "user message 0" in uploaded - assert "assistant message 1" in uploaded + assert b"user message 0" in uploaded + assert b"assistant message 1" in uploaded @pytest.mark.asyncio async def test_backfill_append_guard(self): @@ -459,36 +487,6 @@ class TestRoundTrip: assert builder.entry_count == initial_count -class TestIsTranscriptStale: - """``is_transcript_stale`` gates prior-transcript loading.""" - - def test_none_download_is_not_stale(self): - assert is_transcript_stale(None, session_msg_count=5) is False - - def test_zero_message_count_is_not_stale(self): - """Legacy transcripts without msg_count tracking must remain usable.""" - dl = TranscriptDownload(content="", message_count=0) - assert is_transcript_stale(dl, session_msg_count=20) is False - - def test_stale_when_covers_less_than_prefix(self): - dl = TranscriptDownload(content="", message_count=2) - # session has 6 messages; transcript must cover at least 5 (6-1). - assert is_transcript_stale(dl, session_msg_count=6) is True - - def test_fresh_when_covers_full_prefix(self): - dl = TranscriptDownload(content="", message_count=5) - assert is_transcript_stale(dl, session_msg_count=6) is False - - def test_fresh_when_exceeds_prefix(self): - """Race: transcript ahead of session count is still acceptable.""" - dl = TranscriptDownload(content="", message_count=10) - assert is_transcript_stale(dl, session_msg_count=6) is False - - def test_boundary_equal_to_prefix_minus_one(self): - dl = TranscriptDownload(content="", message_count=5) - assert is_transcript_stale(dl, session_msg_count=6) is False - - class TestShouldUploadTranscript: """``should_upload_transcript`` gates the final upload.""" @@ -510,7 +508,7 @@ class TestShouldUploadTranscript: class TestTranscriptLifecycle: - """End-to-end: download → validate → build → upload. + """End-to-end: restore → validate → build → upload. 
Simulates the full transcript lifecycle inside ``stream_chat_completion_baseline`` by mocking the storage layer and @@ -519,27 +517,29 @@ class TestTranscriptLifecycle: @pytest.mark.asyncio async def test_full_lifecycle_happy_path(self): - """Fresh download, append a turn, upload covers the session.""" + """Fresh restore, append a turn, upload covers the session.""" builder = TranscriptBuilder() prior = _make_transcript_content("user", "assistant") - download = TranscriptDownload(content=prior, message_count=2) + restore = TranscriptDownload( + content=prior.encode("utf-8"), message_count=2, mode="sdk" + ) upload_mock = AsyncMock(return_value=None) with ( patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ), patch( "backend.copilot.baseline.service.upload_transcript", new=upload_mock, ), ): - # --- 1. Download & load prior transcript --- - covers = await _load_prior_transcript( + # --- 1. Restore & load prior session --- + covers, _ = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=3, + session_messages=_make_session_messages("user", "assistant", "user"), transcript_builder=builder, ) assert covers is True @@ -559,10 +559,7 @@ class TestTranscriptLifecycle: # --- 3. Gate + upload --- assert ( - should_upload_transcript( - user_id="user-1", transcript_covers_prefix=covers - ) - is True + should_upload_transcript(user_id="user-1", upload_safe=covers) is True ) await _upload_final_transcript( user_id="user-1", @@ -574,20 +571,21 @@ class TestTranscriptLifecycle: upload_mock.assert_awaited_once() assert upload_mock.await_args is not None uploaded = upload_mock.await_args.kwargs["content"] - assert "follow-up question" in uploaded - assert "follow-up answer" in uploaded + assert b"follow-up question" in uploaded + assert b"follow-up answer" in uploaded # Original prior-turn content preserved. 
- assert "user message 0" in uploaded - assert "assistant message 1" in uploaded + assert b"user message 0" in uploaded + assert b"assistant message 1" in uploaded @pytest.mark.asyncio - async def test_lifecycle_stale_download_suppresses_upload(self): - """Stale download → covers=False → upload must be skipped.""" + async def test_lifecycle_stale_download_fills_gap(self): + """When transcript covers fewer messages, gap is filled rather than rejected.""" builder = TranscriptBuilder() - # session has 10 msgs but stored transcript only covers 2 → stale. + # session has 5 msgs but stored transcript only covers 2 → gap filled. stale = TranscriptDownload( - content=_make_transcript_content("user", "assistant"), + content=_make_transcript_content("user", "assistant").encode("utf-8"), message_count=2, + mode="baseline", ) upload_mock = AsyncMock(return_value=None) @@ -601,20 +599,18 @@ class TestTranscriptLifecycle: new=upload_mock, ), ): - covers = await _load_prior_transcript( + covers, _ = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=10, + session_messages=_make_session_messages( + "user", "assistant", "user", "assistant", "user" + ), transcript_builder=builder, ) - assert covers is False - # The caller's gate mirrors the production path. 
- assert ( - should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers) - is False - ) - upload_mock.assert_not_awaited() + assert covers is True + # Gap was filled: 2 from transcript + 2 gap messages + assert builder.entry_count == 4 @pytest.mark.asyncio async def test_lifecycle_anonymous_user_skips_upload(self): @@ -627,15 +623,11 @@ class TestTranscriptLifecycle: stop_reason=STOP_REASON_END_TURN, ) - assert ( - should_upload_transcript(user_id=None, transcript_covers_prefix=True) - is False - ) + assert should_upload_transcript(user_id=None, upload_safe=True) is False @pytest.mark.asyncio async def test_lifecycle_missing_download_still_uploads_new_content(self): - """No prior transcript → covers defaults to True in the service, - new turn should upload cleanly.""" + """No prior session → upload is safe; the turn writes the first snapshot.""" builder = TranscriptBuilder() upload_mock = AsyncMock(return_value=None) with ( @@ -648,20 +640,117 @@ class TestTranscriptLifecycle: new=upload_mock, ), ): - covers = await _load_prior_transcript( + upload_safe, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=1, + session_messages=_make_session_messages("user"), transcript_builder=builder, ) - # No download: covers is False, so the production path would - # skip upload. This protects against overwriting a future - # more-complete transcript with a single-turn snapshot. - assert covers is False + # Nothing in GCS → upload is safe so the first baseline turn + # can write the initial transcript snapshot. 
+ assert upload_safe is True + assert dl is None assert ( - should_upload_transcript( - user_id="user-1", transcript_covers_prefix=covers - ) - is False + should_upload_transcript(user_id="user-1", upload_safe=upload_safe) + is True ) - upload_mock.assert_not_awaited() + + +# --------------------------------------------------------------------------- +# _append_gap_to_builder +# --------------------------------------------------------------------------- + + +class TestAppendGapToBuilder: + """``_append_gap_to_builder`` converts ChatMessage objects to TranscriptBuilder entries.""" + + def test_user_message_appended(self): + builder = TranscriptBuilder() + msgs = [ChatMessage(role="user", content="hello")] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + assert builder.last_entry_type == "user" + + def test_assistant_text_message_appended(self): + builder = TranscriptBuilder() + msgs = [ + ChatMessage(role="user", content="q"), + ChatMessage(role="assistant", content="answer"), + ] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 2 + assert builder.last_entry_type == "assistant" + assert "answer" in builder.to_jsonl() + + def test_assistant_with_tool_calls_appended(self): + """Assistant tool_calls are recorded as tool_use blocks in the transcript.""" + builder = TranscriptBuilder() + tool_call = { + "id": "tc-1", + "type": "function", + "function": {"name": "my_tool", "arguments": '{"key":"val"}'}, + } + msgs = [ChatMessage(role="assistant", content=None, tool_calls=[tool_call])] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + jsonl = builder.to_jsonl() + assert "tool_use" in jsonl + assert "my_tool" in jsonl + assert "tc-1" in jsonl + + def test_assistant_invalid_json_args_uses_empty_dict(self): + """Malformed JSON in tool_call arguments falls back to {}.""" + builder = TranscriptBuilder() + tool_call = { + "id": "tc-bad", + "type": "function", + "function": {"name": "bad_tool", 
"arguments": "not-json"}, + } + msgs = [ChatMessage(role="assistant", content=None, tool_calls=[tool_call])] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + jsonl = builder.to_jsonl() + assert '"input":{}' in jsonl + + def test_assistant_empty_content_and_no_tools_uses_fallback(self): + """Assistant with no content and no tool_calls gets a fallback empty text block.""" + builder = TranscriptBuilder() + msgs = [ChatMessage(role="assistant", content=None)] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + jsonl = builder.to_jsonl() + assert "text" in jsonl + + def test_tool_role_with_tool_call_id_appended(self): + """Tool result messages are appended when tool_call_id is set.""" + builder = TranscriptBuilder() + # Need a preceding assistant tool_use entry + builder.append_user("use tool") + builder.append_assistant( + content_blocks=[ + {"type": "tool_use", "id": "tc-1", "name": "my_tool", "input": {}} + ] + ) + msgs = [ChatMessage(role="tool", tool_call_id="tc-1", content="result")] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 3 + assert "tool_result" in builder.to_jsonl() + + def test_tool_role_without_tool_call_id_skipped(self): + """Tool messages without tool_call_id are silently skipped.""" + builder = TranscriptBuilder() + msgs = [ChatMessage(role="tool", tool_call_id=None, content="orphan")] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 0 + + def test_tool_call_missing_function_key_uses_unknown_name(self): + """A tool_call dict with no 'function' key uses 'unknown' as the tool name.""" + builder = TranscriptBuilder() + # Tool call dict exists but 'function' sub-dict is missing entirely + msgs = [ + ChatMessage(role="assistant", content=None, tool_calls=[{"id": "tc-x"}]) + ] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + jsonl = builder.to_jsonl() + assert "unknown" in jsonl diff --git 
a/autogpt_platform/backend/backend/copilot/context.py b/autogpt_platform/backend/backend/copilot/context.py index 895aa6c4a1..7a22f02cb2 100644 --- a/autogpt_platform/backend/backend/copilot/context.py +++ b/autogpt_platform/backend/backend/copilot/context.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: # Allowed base directory for the Read tool. Public so service.py can use it # for sweep operations without depending on a private implementation detail. # Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's -# _projects_base() function. +# projects_base() function. _config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude") SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects")) diff --git a/autogpt_platform/backend/backend/copilot/sdk/mode_switch_context_test.py b/autogpt_platform/backend/backend/copilot/sdk/mode_switch_context_test.py index 5e1ef41979..212fca189b 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/mode_switch_context_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/mode_switch_context_test.py @@ -8,7 +8,7 @@ Cross-mode transcript flow ========================== Both ``baseline/service.py`` (fast mode) and ``sdk/service.py`` (extended_thinking -mode) read and write the same JSONL transcript store via +mode) read and write the same CLI session store via ``backend.copilot.transcript.upload_transcript`` / ``download_transcript``. 
@@ -250,8 +250,9 @@ class TestSdkToFastModeSwitch: @pytest.mark.asyncio async def test_scenario_s_baseline_loads_sdk_transcript(self): - """Scenario S: SDK-written transcript is accepted by baseline's load helper.""" + """Scenario S: SDK-written CLI session is accepted by baseline's load helper.""" from backend.copilot.baseline.service import _load_prior_transcript + from backend.copilot.model import ChatMessage from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload from backend.copilot.transcript_builder import TranscriptBuilder @@ -267,33 +268,41 @@ class TestSdkToFastModeSwitch: sdk_transcript = builder_sdk.to_jsonl() # Baseline session now has those 2 SDK messages + 1 new baseline message. - download = TranscriptDownload(content=sdk_transcript, message_count=2) + restore = TranscriptDownload( + content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk" + ) baseline_builder = TranscriptBuilder() with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=3, # 2 SDK + 1 new baseline + session_messages=[ + ChatMessage(role="user", content="sdk-question"), + ChatMessage(role="assistant", content="sdk-answer"), + ChatMessage(role="user", content="baseline-question"), + ], transcript_builder=baseline_builder, ) - # Transcript is valid and covers the prefix. + # CLI session is valid and covers the prefix. assert covers is True + assert dl is not None assert baseline_builder.entry_count == 2 @pytest.mark.asyncio async def test_scenario_s_stale_sdk_transcript_not_loaded(self): - """Scenario S (stale): SDK transcript is stale — baseline does not load it. + """Scenario S (stale): SDK CLI session is stale — baseline does not load it. - If SDK mode produced more turns than the transcript captured (e.g. 
-        upload failed on one turn), the baseline rejects the stale transcript
+        If SDK mode produced more turns than the session captured (e.g.
+        upload failed on one turn), the baseline fills the gap from the DB
+        messages instead of rejecting the stale session outright.
         """
         from backend.copilot.baseline.service import _load_prior_transcript
+        from backend.copilot.model import ChatMessage
         from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
         from backend.copilot.transcript_builder import TranscriptBuilder

@@ -306,21 +315,33 @@ class TestSdkToFastModeSwitch:
         )
         sdk_transcript = builder_sdk.to_jsonl()

-        # Transcript covers only 2 messages but session has 10 (many SDK turns).
-        download = TranscriptDownload(content=sdk_transcript, message_count=2)
+        # CLI session covers only 2 messages but DB session has 10 (SDK turns).
+        # With watermark=2 and 10 total messages, detect_gap will fill the gap
+        # by appending messages 2..8 (positions 2 to total-2).
+        restore = TranscriptDownload(
+            content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
+        )
+
+        # Build a session with 10 alternating user/assistant messages + current user
+        session_messages = [
+            ChatMessage(role="user" if i % 2 == 0 else "assistant", content=f"msg-{i}")
+            for i in range(10)
+        ]

         baseline_builder = TranscriptBuilder()
         with patch(
             "backend.copilot.baseline.service.download_transcript",
-            new=AsyncMock(return_value=download),
+            new=AsyncMock(return_value=restore),
         ):
-            covers = await _load_prior_transcript(
+            covers, dl = await _load_prior_transcript(
                 user_id="user-1",
                 session_id="session-1",
-                session_msg_count=10,
+                session_messages=session_messages,
                 transcript_builder=baseline_builder,
             )

-        # Stale transcript must be rejected.
-        assert covers is False
-        assert baseline_builder.is_empty
+        # With gap filling, covers is True and gap messages are appended.
+ assert covers is True + assert dl is not None + # 2 from transcript + 7 gap messages (positions 2..8, excluding last user turn) + assert baseline_builder.entry_count == 9 diff --git a/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py b/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py index a48d7def3d..60c65f00ce 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py @@ -27,6 +27,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from backend.copilot.transcript import ( + TranscriptDownload, _flatten_assistant_content, _flatten_tool_result_content, _messages_to_transcript, @@ -999,14 +1000,15 @@ def _make_sdk_patches( f"{_SVC}.download_transcript", dict( new_callable=AsyncMock, - return_value=MagicMock(content=original_transcript, message_count=2), + return_value=TranscriptDownload( + content=original_transcript.encode("utf-8"), + message_count=2, + mode="sdk", + ), ), ), - ( - f"{_SVC}.restore_cli_session", - dict(new_callable=AsyncMock, return_value=True), - ), - (f"{_SVC}.upload_cli_session", dict(new_callable=AsyncMock)), + (f"{_SVC}.strip_for_upload", dict(return_value=original_transcript)), + (f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)), (f"{_SVC}.validate_transcript", dict(return_value=True)), ( f"{_SVC}.compact_transcript", @@ -1037,7 +1039,6 @@ def _make_sdk_patches( claude_agent_fallback_model=None, ), ), - (f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)), (f"{_SVC}.get_user_tier", dict(new_callable=AsyncMock, return_value=None)), ] @@ -1914,14 +1915,14 @@ class TestStreamChatCompletionRetryIntegration: compacted_transcript=None, client_side_effect=_client_factory, ) - # Override restore_cli_session to return False (CLI native session unavailable) + # Override download_transcript to return None (CLI native session unavailable) patches = [ ( ( - 
f"{_SVC}.restore_cli_session", - dict(new_callable=AsyncMock, return_value=False), + f"{_SVC}.download_transcript", + dict(new_callable=AsyncMock, return_value=None), ) - if p[0] == f"{_SVC}.restore_cli_session" + if p[0] == f"{_SVC}.download_transcript" else p ) for p in patches @@ -1944,7 +1945,7 @@ class TestStreamChatCompletionRetryIntegration: # captured_options holds {"options": ClaudeAgentOptions}, so check # the attribute directly rather than dict keys. assert not getattr(captured_options.get("options"), "resume", None), ( - f"--resume was set even though restore_cli_session returned False: " + f"--resume was set even though download_transcript returned None: " f"{captured_options}" ) assert any(isinstance(e, StreamStart) for e in events) diff --git a/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py b/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py index e5ba184f4f..666e55fbba 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py +++ b/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py @@ -365,7 +365,7 @@ def create_security_hooks( trigger = _sanitize(str(input_data.get("trigger", "auto")), max_len=50) # Sanitize untrusted input: strip control chars for logging AND # for the value passed downstream. read_compacted_entries() - # validates against _projects_base() as defence-in-depth, but + # validates against projects_base() as defence-in-depth, but # sanitizing here prevents log injection and rejects obviously # malformed paths early. 
transcript_path = _sanitize( diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index 19f151f008..936f1f8df1 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -16,6 +16,7 @@ import uuid from collections.abc import AsyncGenerator, AsyncIterator from dataclasses import dataclass from dataclasses import field as dataclass_field +from pathlib import Path from typing import TYPE_CHECKING, Any, NamedTuple, NotRequired, cast if TYPE_CHECKING: @@ -92,12 +93,15 @@ from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path from ..tracking import track_user_message from ..transcript import ( _run_compression, + TranscriptDownload, cleanup_stale_project_dirs, + cli_session_path, compact_transcript, download_transcript, + extract_context_messages, + projects_base, read_compacted_entries, - restore_cli_session, - upload_cli_session, + strip_for_upload, upload_transcript, validate_transcript, ) @@ -849,6 +853,181 @@ def _make_sdk_cwd(session_id: str) -> str: return cwd +def _write_cli_session_to_disk( + content: bytes, + sdk_cwd: str, + session_id: str, + log_prefix: str, +) -> bool: + """Write downloaded CLI session bytes to disk so the CLI can --resume. + + Returns True on success, False if the path is invalid or the write fails. + Path-traversal guard: rejects paths outside the CLI projects base. 
+ """ + session_file = cli_session_path(sdk_cwd, session_id) + real_path = os.path.realpath(session_file) + _pbase = projects_base() + if not real_path.startswith(_pbase + os.sep): + logger.warning( + "%s CLI session restore path outside projects base: %s", + log_prefix, + os.path.basename(session_file), + ) + return False + try: + os.makedirs(os.path.dirname(real_path), exist_ok=True) + Path(real_path).write_bytes(content) + logger.info( + "%s Wrote CLI session to disk (%dB) for --resume", + log_prefix, + len(content), + ) + return True + except OSError as e: + logger.warning( + "%s Failed to write CLI session file %s: %s", + log_prefix, + os.path.basename(session_file), + e.strerror or str(e), + ) + return False + + +def _read_cli_session_from_disk( + sdk_cwd: str, + session_id: str, + log_prefix: str, +) -> bytes | None: + """Read the CLI session JSONL file from disk after the SDK turn. + + Returns the file bytes, or None if the file is missing, outside the + projects base, or unreadable. + Path-traversal guard: rejects paths outside the CLI projects base. + """ + session_file = cli_session_path(sdk_cwd, session_id) + real_path = os.path.realpath(session_file) + _pbase = projects_base() + if not real_path.startswith(_pbase + os.sep): + logger.warning( + "%s CLI session file outside projects base, skipping upload: %s", + log_prefix, + os.path.basename(real_path), + ) + return None + try: + raw_bytes = Path(real_path).read_bytes() + except FileNotFoundError: + logger.debug( + "%s CLI session file not found, skipping upload: %s", + log_prefix, + os.path.basename(session_file), + ) + return None + except OSError as e: + logger.warning( + "%s Failed to read CLI session file %s: %s", + log_prefix, + os.path.basename(session_file), + e.strerror or str(e), + ) + return None + + # Strip stale thinking blocks and metadata entries before uploading. 
+ # Thinking blocks from non-last turns can be massive; keeping them causes + # the CLI to auto-compact its session when the context window fills up, + # silently losing conversation history. + try: + raw_text = raw_bytes.decode("utf-8") + stripped_text = strip_for_upload(raw_text) + stripped_bytes = stripped_text.encode("utf-8") + except UnicodeDecodeError: + logger.warning("%s CLI session is not valid UTF-8, uploading raw", log_prefix) + return raw_bytes + except (OSError, ValueError) as e: + # OSError: encode/decode I/O failure; ValueError: malformed JSONL in strip. + # Other unexpected exceptions are not silently swallowed here so they propagate + # to the outer OSError handler and are logged with exc_info. + logger.warning( + "%s Failed to strip CLI session, uploading raw: %s", log_prefix, e + ) + return raw_bytes + + if len(stripped_bytes) < len(raw_bytes): + # Write back locally so same-pod turns also benefit. + try: + Path(real_path).write_bytes(stripped_bytes) + logger.info( + "%s Stripped CLI session: %dB → %dB", + log_prefix, + len(raw_bytes), + len(stripped_bytes), + ) + except OSError as e: + # write_bytes failed — stripped content is still valid for GCS upload even + # though the local write-back failed (same-pod optimization silently skipped). + logger.warning( + "%s Failed to write back stripped CLI session: %s", + log_prefix, + e.strerror or str(e), + ) + return stripped_bytes + + +def _process_cli_restore( + cli_restore: TranscriptDownload, + sdk_cwd: str, + session_id: str, + log_prefix: str, +) -> tuple[str, bool]: + """Validate and write a restored CLI session to disk. + + Decodes bytes → UTF-8, strips progress entries and stale thinking blocks, + validates the result, then writes the stripped content to disk so the CLI + can ``--resume`` from it. + + Returns ``(stripped_content, success)`` where ``success=False`` means the + content was invalid or the disk write failed (caller should skip --resume). 
+ """ + try: + raw_bytes = cli_restore.content + raw_str = ( + raw_bytes.decode("utf-8") if isinstance(raw_bytes, bytes) else raw_bytes + ) + except UnicodeDecodeError: + logger.warning( + "%s CLI session content is not valid UTF-8, skipping", log_prefix + ) + return "", False + + stripped = strip_for_upload(raw_str) + is_valid = validate_transcript(stripped) + # Use len(raw_str) rather than len(cli_restore.content) so the unit is always + # characters (raw_str is always str at this point regardless of input type). + # lines_stripped = original lines minus remaining lines after stripping. + _original_lines = len(raw_str.strip().split("\n")) if raw_str.strip() else 0 + _remaining_lines = len(stripped.strip().split("\n")) if stripped.strip() else 0 + logger.info( + "%s Restored CLI session: %dB raw, %d lines stripped, msg_count=%d, valid=%s", + log_prefix, + len(raw_str), + _original_lines - _remaining_lines, + cli_restore.message_count, + is_valid, + ) + if not is_valid: + logger.warning( + "%s CLI session content invalid after strip — running without --resume", + log_prefix, + ) + return "", False + + stripped_bytes = stripped.encode("utf-8") + if not _write_cli_session_to_disk(stripped_bytes, sdk_cwd, session_id, log_prefix): + return "", False + + return stripped, True + + async def _cleanup_sdk_tool_results(cwd: str) -> None: """Remove SDK session artifacts for a specific working directory. @@ -922,8 +1101,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]: result.append(block) else: logger.warning( - f"[SDK] Unknown content block type: {type(block).__name__}. " - f"This may indicate a new SDK version with additional block types." + "[SDK] Unknown content block type: %s." 
+ " This may indicate a new SDK version with additional block types.", + type(block).__name__, ) return result @@ -978,10 +1158,11 @@ async def _compress_messages( if result.was_compacted: logger.info( - f"[SDK] Context compacted: {result.original_token_count} -> " - f"{result.token_count} tokens " - f"({result.messages_summarized} summarized, " - f"{result.messages_dropped} dropped)" + "[SDK] Context compacted: %d -> %d tokens (%d summarized, %d dropped)", + result.original_token_count, + result.token_count, + result.messages_summarized, + result.messages_dropped, ) # Convert compressed dicts back to ChatMessages return [ @@ -1048,11 +1229,17 @@ def _session_messages_to_transcript(messages: list[ChatMessage]) -> str: ) if blocks: builder.append_assistant(blocks) - elif msg.role == "tool" and msg.tool_call_id: - builder.append_tool_result( - tool_use_id=msg.tool_call_id, - content=msg.content or "", - ) + elif msg.role == "tool": + if msg.tool_call_id: + builder.append_tool_result( + tool_use_id=msg.tool_call_id, + content=msg.content or "", + ) + else: + # Malformed tool message — no tool_call_id to link to an + # assistant tool_use block. Skip to avoid an unmatched + # tool_result entry in the builder (which would confuse --resume). + logger.warning("[SDK] Skipping tool gap message with no tool_call_id") return builder.to_jsonl() @@ -1098,6 +1285,7 @@ async def _build_query_message( transcript_msg_count: int, session_id: str, target_tokens: int | None = None, + prior_messages: "list[ChatMessage] | None" = None, ) -> tuple[str, bool]: """Build the query message with appropriate context. 
@@ -1203,15 +1391,16 @@ async def _build_query_message( ) return current_message, False + source = prior_messages if prior_messages is not None else prior logger.warning( - "[SDK] [%s] No --resume for %d-message session — compressing" - " full session history (pod affinity issue or first turn after" - " restore failure); target_tokens=%s", + "[SDK] [%s] No --resume for %d-message session — compressing context " + "(source=%s, target_tokens=%s)", session_id[:8], msg_count, + "transcript+gap" if prior_messages is not None else "full-db", target_tokens, ) - compressed, was_compressed = await _compress_messages(prior, target_tokens) + compressed, was_compressed = await _compress_messages(source, target_tokens) history_context = _format_conversation_context(compressed) if history_context: logger.info( @@ -1228,7 +1417,7 @@ async def _build_query_message( "[SDK] [%s] Fallback context empty after compression" " (%d messages) — sending message without history", session_id[:8], - len(prior), + len(source), ) return current_message, False @@ -2233,6 +2422,163 @@ async def _seed_transcript( return _seeded, True, len(_prior) +@dataclass +class _RestoreResult: + """Return value from ``_restore_cli_session_for_turn``.""" + + transcript_content: str = "" + transcript_covers_prefix: bool = True + use_resume: bool = False + resume_file: str | None = None + transcript_msg_count: int = 0 + baseline_download: "TranscriptDownload | None" = None + context_messages: "list[ChatMessage] | None" = None + + +async def _restore_cli_session_for_turn( + user_id: str | None, + session_id: str, + session: "ChatSession", + sdk_cwd: str, + transcript_builder: "TranscriptBuilder", + log_prefix: str, +) -> _RestoreResult: + """Download, validate and restore a CLI session for ``--resume`` on this turn. + + Performs a single GCS round-trip to fetch the session bytes + message_count + watermark. Falls back to DB-message reconstruction when GCS has no session + (first turn or upload missed). 
+ + Returns a ``_RestoreResult`` with all transcript-related state ready for the + caller to merge into its local variables. + """ + result = _RestoreResult() + + if not (config.claude_agent_use_resume and user_id and len(session.messages) > 1): + return result + + try: + cli_restore = await download_transcript( + user_id, session_id, log_prefix=log_prefix + ) + except Exception as restore_err: + logger.warning( + "%s CLI session restore failed, continuing without --resume: %s", + log_prefix, + restore_err, + ) + cli_restore = None + + # Only attempt --resume for SDK-written transcripts. + # Baseline-written transcripts use TranscriptBuilder format (synthetic IDs, + # stripped fields) that may not be valid for --resume. + if cli_restore is not None and cli_restore.mode != "sdk": + logger.info( + "%s Transcript written by mode=%r — skipping --resume, " + "will use transcript content + gap for context", + log_prefix, + cli_restore.mode, + ) + result.baseline_download = cli_restore # keep for extract_context_messages + cli_restore = None + + # Validate, strip, and write to disk — delegate to helper to reduce + # function complexity. Writing an invalid/corrupt file to disk then + # falling back to "no --resume" would cause the CLI to fail with + # "Session ID already in use" because the file exists at the expected + # session path, so we validate BEFORE any disk write. + stripped = "" + if cli_restore is not None and sdk_cwd: + stripped, ok = _process_cli_restore( + cli_restore, sdk_cwd, session_id, log_prefix + ) + if not ok: + result.transcript_covers_prefix = False + cli_restore = None + + if cli_restore is None and sdk_cwd: + # Validation failed or GCS returned no session. Delete any + # existing local session file so the CLI doesn't reject the + # session_id with "Session ID already in use". T1 may have + # left a valid file at this path; we clear it so the fallback + # path (session_id= without --resume) can create a new session. 
+ _stale_path = os.path.realpath(cli_session_path(sdk_cwd, session_id)) + if Path(_stale_path).exists() and _stale_path.startswith( + projects_base() + os.sep + ): + try: + Path(_stale_path).unlink() + logger.debug( + "%s Removed stale local CLI session file for clean fallback", + log_prefix, + ) + except OSError as _unlink_err: + logger.debug( + "%s Failed to remove stale local session file: %s", + log_prefix, + _unlink_err, + ) + + if cli_restore is not None: + result.transcript_content = stripped + transcript_builder.load_previous(stripped, log_prefix=log_prefix) + result.use_resume = True + result.resume_file = session_id + result.transcript_msg_count = cli_restore.message_count + return result + + # No valid --resume source (mode="baseline" or no GCS file). + # Build context from transcript content + gap, falling back to full DB. + # extract_context_messages handles both: non-None baseline_download uses + # the compacted transcript + gap; None falls back to all prior DB messages. + context_msgs = extract_context_messages(result.baseline_download, session.messages) + result.context_messages = context_msgs + result.transcript_msg_count = ( + result.baseline_download.message_count + if result.baseline_download is not None + and result.baseline_download.message_count > 0 + else len(session.messages) - 1 + ) + result.transcript_covers_prefix = True + logger.info( + "%s Context built from %s: %d messages (transcript watermark=%d, " + "will inject as <conversation_history>)", + log_prefix, + ( + "baseline transcript + gap" + if result.baseline_download is not None + else "DB fallback" + ), + len(context_msgs), + result.transcript_msg_count, + ) + + # Load baseline transcript content into builder so the upload path has accurate state. 
+ # Also sets result.transcript_content so the _seed_transcript guard in the caller + # (``not transcript_content``) does not overwrite this builder state with a DB + # reconstruction — which would duplicate entries since load_previous appends. + if result.baseline_download is not None: + try: + raw_for_builder = result.baseline_download.content + if isinstance(raw_for_builder, bytes): + raw_for_builder = raw_for_builder.decode("utf-8") + stripped = strip_for_upload(raw_for_builder) + if validate_transcript(stripped): + transcript_builder.load_previous(stripped, log_prefix=log_prefix) + result.transcript_content = stripped + except (UnicodeDecodeError, ValueError, OSError) as _load_err: + # UnicodeDecodeError: non-UTF-8 content; ValueError: malformed JSONL in + # strip_for_upload; OSError: encode/decode I/O failure. Unexpected + # exceptions propagate so programming errors are not silently masked. + logger.debug( + "%s Could not load baseline transcript into builder: %s", + log_prefix, + _load_err, + ) + + return result + + async def stream_chat_completion_sdk( session_id: str, message: str | None = None, @@ -2427,28 +2773,9 @@ async def stream_chat_completion_sdk( return sandbox - async def _fetch_transcript(): - """Download transcript for --resume if applicable.""" - if not ( - config.claude_agent_use_resume and user_id and len(session.messages) > 1 - ): - return None - try: - return await download_transcript( - user_id, session_id, log_prefix=log_prefix - ) - except Exception as transcript_err: - logger.warning( - "%s Transcript download failed, continuing without --resume: %s", - log_prefix, - transcript_err, - ) - return None - - e2b_sandbox, (base_system_prompt, understanding), dl = await asyncio.gather( + e2b_sandbox, (base_system_prompt, understanding) = await asyncio.gather( _setup_e2b(), _build_system_prompt(user_id if not has_history else None), - _fetch_transcript(), ) use_e2b = e2b_sandbox is not None @@ -2473,95 +2800,17 @@ async def 
stream_chat_completion_sdk( warm_ctx = await fetch_warm_context(user_id, message or "") or "" - # Process transcript download result and restore CLI native session. - # The CLI native session file (uploaded after each turn) is the - # source of truth for --resume. Our custom JSONL (TranscriptEntry) - # is loaded into the builder for future upload_transcript calls. - transcript_msg_count = 0 - if dl: - is_valid = validate_transcript(dl.content) - dl_lines = dl.content.strip().split("\n") if dl.content else [] - logger.info( - "%s Downloaded transcript: %dB, %d lines, msg_count=%d, valid=%s", - log_prefix, - len(dl.content), - len(dl_lines), - dl.message_count, - is_valid, - ) - if is_valid: - # Load previous FULL context into builder for state tracking. - transcript_content = dl.content - transcript_builder.load_previous(dl.content, log_prefix=log_prefix) - # Restore CLI's native session file so --resume session_id works. - # Falls back gracefully if not available (first turn or upload missed). - # user_id is guaranteed non-None here: _fetch_transcript only sets dl - # when `config.claude_agent_use_resume and user_id` is truthy. - cli_restored = user_id is not None and await restore_cli_session( - user_id, session_id, sdk_cwd, log_prefix=log_prefix - ) - if cli_restored: - use_resume = True - resume_file = session_id # CLI --resume expects UUID, not file path - transcript_msg_count = dl.message_count - logger.info( - "%s Using --resume %s (%dB transcript, msg_count=%d)", - log_prefix, - session_id[:8], - len(dl.content), - transcript_msg_count, - ) - else: - # Builder loaded but CLI native session not available. - # --resume will not be used this turn; upload after turn - # will seed the native session for the next turn. - # - # Still record transcript_msg_count so _build_query_message - # can use the transcript-aware gap path (inject only new - # messages since the transcript end) instead of compressing - # the full DB history. 
This avoids prompt-too-long on - # large sessions where the CLI session is temporarily - # unavailable (e.g. mixed-version rolling deployment). - transcript_msg_count = dl.message_count - logger.info( - "%s CLI session not restored — running without" - " --resume this turn (transcript_msg_count=%d for" - " gap-aware fallback)", - log_prefix, - transcript_msg_count, - ) - else: - logger.warning("%s Transcript downloaded but invalid", log_prefix) - transcript_covers_prefix = False - elif config.claude_agent_use_resume and user_id and len(session.messages) > 1: - # No transcript in storage — reconstruct from DB messages as a - # last-resort fallback (e.g., first turn after a crash or transition). - # This path loses tool call IDs and structural fidelity but prevents - # a completely context-free response for established sessions. - prior = session.messages[:-1] - reconstructed = _session_messages_to_transcript(prior) - if reconstructed: - # Populate builder only; no --resume since there is no CLI - # native session to restore. The transcript builder state is - # still useful for the upload that seeds future native sessions. - transcript_content = reconstructed - transcript_builder.load_previous(reconstructed, log_prefix=log_prefix) - transcript_msg_count = len(prior) - transcript_covers_prefix = True - logger.info( - "%s Reconstructed transcript from %d session messages " - "(no CLI native session — running without --resume this turn)", - log_prefix, - len(prior), - ) - else: - logger.warning( - "%s No transcript available and reconstruction produced empty" - " output (%d messages in session)", - log_prefix, - len(session.messages), - ) - transcript_covers_prefix = False + # Restore CLI session — single GCS round-trip covers both --resume and builder state. + # message_count watermark lives in the companion .meta.json alongside the session file. 
+ _restore = await _restore_cli_session_for_turn( + user_id, session_id, session, sdk_cwd, transcript_builder, log_prefix + ) + transcript_content = _restore.transcript_content + transcript_covers_prefix = _restore.transcript_covers_prefix + use_resume = _restore.use_resume + resume_file = _restore.resume_file + transcript_msg_count = _restore.transcript_msg_count + restore_context_messages = _restore.context_messages yield StreamStart(messageId=message_id, sessionId=session_id) @@ -2680,14 +2929,14 @@ async def stream_chat_completion_sdk( else: # Set session_id whenever NOT resuming so the CLI writes the # native session file to a predictable path for - # upload_cli_session() after the turn. This covers: + # upload_transcript() after the turn. This covers: # • T1 fresh: no prior history, first SDK turn. # • Mode-switch T1: has_history=True (prior baseline turns in # DB) but no CLI session file was ever uploaded — the CLI has # never been invoked with this session_id before. # • T2+ without --resume (restore failed): no session file was - # restored to local storage (restore_cli_session returned - # False), so no conflict with an existing file. + # restored to local storage (download_transcript returned + # None), so no conflict with an existing file. # When --resume is active the session_id is already implied by # the resume file; passing it again would be rejected by the CLI. sdk_options_kwargs["session_id"] = session_id @@ -2780,6 +3029,7 @@ async def stream_chat_completion_sdk( use_resume, transcript_msg_count, session_id, + prior_messages=restore_context_messages, ) # If files are attached, prepare them: images become vision # content blocks in the user message, other files go to sdk_cwd. @@ -2909,15 +3159,16 @@ async def stream_chat_completion_sdk( elif "session_id" in sdk_options_kwargs: # Initial invocation used session_id (T1 or mode-switch # T1): keep it so the CLI writes the session file to the - # predictable path for upload_cli_session(). 
Storage is + # predictable path for upload_transcript(). Storage is # ephemeral per invocation, so no "Session ID already in # use" conflict occurs — no prior file was restored. sdk_options_kwargs_retry.pop("resume", None) sdk_options_kwargs_retry["session_id"] = session_id else: - # T2+ retry without --resume: do not pass --session-id. - # The T1 session file already exists at that path; re-using - # the same ID would fail with "Session ID already in use". + # T2+ retry without --resume: initial invocation used + # --resume, which restored the T1 session file to local + # storage. Re-using session_id without --resume would + # fail with "Session ID already in use". sdk_options_kwargs_retry.pop("resume", None) sdk_options_kwargs_retry.pop("session_id", None) # Recompute system_prompt for retry — ctx.use_resume may have @@ -2931,6 +3182,10 @@ async def stream_chat_completion_sdk( system_prompt, cross_user_cache=_cross_user_retry ) state.options = ClaudeAgentOptions(**sdk_options_kwargs_retry) # type: ignore[arg-type] # dynamic kwargs + # Retry intentionally omits prior_messages (transcript+gap context) and + # falls back to full session.messages[:-1] from DB — the authoritative + # source. transcript+gap is an optimisation for the first attempt only; + # on retry the extra overhead of full-DB context is acceptable. state.query_message, state.was_compacted = await _build_query_message( current_message, session, @@ -3366,86 +3621,23 @@ async def stream_chat_completion_sdk( _background_tasks.add(_ingest_task) _ingest_task.add_done_callback(_background_tasks.discard) - # --- Upload transcript for next-turn --resume --- - # TranscriptBuilder is the single source of truth. It mirrors the - # CLI's active context: on compaction, replace_entries() syncs it - # with the compacted session file. No CLI file read needed here. 
- if skip_transcript_upload: - logger.warning( - "%s Skipping transcript upload — transcript was dropped " - "during prompt-too-long recovery", - log_prefix, - ) - elif ( - config.claude_agent_use_resume - and user_id - and session is not None - and state is not None - ): - try: - transcript_upload_content = state.transcript_builder.to_jsonl() - entry_count = state.transcript_builder.entry_count - - if not transcript_upload_content: - logger.warning( - "%s No transcript to upload (builder empty)", log_prefix - ) - elif not validate_transcript(transcript_upload_content): - logger.warning( - "%s Transcript invalid, skipping upload (entries=%d)", - log_prefix, - entry_count, - ) - elif not transcript_covers_prefix: - logger.warning( - "%s Skipping transcript upload — builder does not " - "cover full session prefix (entries=%d, session=%d)", - log_prefix, - entry_count, - len(session.messages), - ) - else: - logger.info( - "%s Uploading transcript (entries=%d, bytes=%d)", - log_prefix, - entry_count, - len(transcript_upload_content), - ) - await asyncio.shield( - upload_transcript( - user_id=user_id, - session_id=session_id, - content=transcript_upload_content, - message_count=len(session.messages), - log_prefix=log_prefix, - ) - ) - except Exception as upload_err: - logger.error( - "%s Transcript upload failed in finally: %s", - log_prefix, - upload_err, - exc_info=True, - ) - # --- Upload CLI native session file for cross-pod --resume --- # The CLI writes its native session JSONL after each turn completes. - # Uploading it here enables --resume on any pod (no pod affinity needed). - # Runs after upload_transcript so both are available for the next turn. 
- # asyncio.shield: same pattern as upload_transcript above — if the - # outer finally-block coroutine is cancelled while awaiting shield, - # the CancelledError propagates (BaseException, not caught by - # `except Exception`) letting the caller handle cancellation, while - # the shielded inner coroutine continues running to completion so the - # upload is not lost. This is intentional and matches the pattern - # used for upload_transcript immediately above. + # The companion .meta.json carries the message_count watermark and mode + # so the next turn can restore both --resume context and gap-fill state + # in a single GCS round-trip via download_transcript(). + # asyncio.shield: if the outer finally-block coroutine is cancelled + # while awaiting shield, the CancelledError propagates (BaseException, + # not caught by `except Exception`) letting the caller handle + # cancellation, while the shielded inner coroutine continues running + # to completion so the upload is not lost. # # NOTE: upload is attempted regardless of state.use_resume — even when # this turn ran without --resume (restore failed or first T2+ on a new # pod), the T1 session file at the expected path may still be present # and should be re-uploaded so the next turn can resume from it. - # upload_cli_session silently skips when the file is absent, so this is - # always safe. + # _read_cli_session_from_disk returns None when the file is absent, so + # this is always safe. # # Intentionally NOT gated on skip_transcript_upload: that flag is set # when our custom JSONL transcript is dropped (transcript_lost=True on @@ -3471,14 +3663,36 @@ async def stream_chat_completion_sdk( skip_transcript_upload, ) try: - await asyncio.shield( - upload_cli_session( - user_id=user_id, - session_id=session_id, - sdk_cwd=sdk_cwd, - log_prefix=log_prefix, - ) + # Read the CLI's native session file from disk (written by the CLI + # after the turn), then upload the bytes to GCS. 
+ _cli_content = _read_cli_session_from_disk( + sdk_cwd, session_id, log_prefix ) + if _cli_content: + # Watermark = number of DB messages this transcript covers. + # len(session.messages) is accurate: the CLI session file + # was just written after the turn completed, so it covers + # all messages through this turn. Any gap from a prior + # missed upload was already detected by detect_gap and + # injected as context, so the model has the full history. + # + # Previously this used _final_tmsg_count + 2, which + # under-counted for tool-use turns (delta = 2 + 2*N_tool_calls), + # causing persistent spurious gap-fills on every subsequent turn. + # That concern was addressed by the inflated-watermark fix + # (using the GCS watermark as the anchor for gap detection), + # which makes len(session.messages) safe to use here. + _jsonl_covered = len(session.messages) + await asyncio.shield( + upload_transcript( + user_id=user_id, + session_id=session_id, + content=_cli_content, + message_count=_jsonl_covered, + mode="sdk", + log_prefix=log_prefix, + ) + ) except Exception as cli_upload_err: logger.warning( "%s CLI session upload failed in finally: %s", diff --git a/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py b/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py index 7c5e429697..3b919c6036 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py @@ -22,6 +22,7 @@ from .service import ( _iter_sdk_messages, _normalize_model_name, _reduce_context, + _restore_cli_session_for_turn, _TokenUsage, ) @@ -615,3 +616,340 @@ class TestSdkSessionIdSelection: ) assert retry.get("resume") == self.SESSION_ID assert "session_id" not in retry + + +# --------------------------------------------------------------------------- +# _restore_cli_session_for_turn — mode check +# --------------------------------------------------------------------------- + + 
+class TestRestoreCliSessionModeCheck: + """SDK skips --resume when the transcript was written by the baseline mode.""" + + @pytest.mark.asyncio + async def test_baseline_mode_transcript_skips_gcs_content(self, tmp_path): + """A transcript with mode='baseline' must not be used as the --resume source. + + The mode check discards the GCS baseline content and falls back to DB + reconstruction from session.messages instead. + """ + from datetime import UTC, datetime + + from backend.copilot.model import ChatMessage, ChatSession + from backend.copilot.transcript import TranscriptDownload + from backend.copilot.transcript_builder import TranscriptBuilder + + session = ChatSession( + session_id="test-session", + user_id="user-1", + messages=[ + ChatMessage(role="user", content="hello-unique-marker"), + ChatMessage(role="assistant", content="world-unique-marker"), + ChatMessage(role="user", content="follow up"), + ], + title="test", + usage=[], + started_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + builder = TranscriptBuilder() + # Baseline content with a sentinel that must NOT appear in the final transcript + baseline_restore = TranscriptDownload( + content=b'{"type":"user","uuid":"bad-uuid","message":{"role":"user","content":"BASELINE_SENTINEL"}}\n', + message_count=1, + mode="baseline", + ) + + import backend.copilot.sdk.service as _svc_mod + + download_mock = AsyncMock(return_value=baseline_restore) + with ( + patch( + "backend.copilot.sdk.service.download_transcript", + new=download_mock, + ), + patch.object(_svc_mod.config, "claude_agent_use_resume", True), + ): + result = await _restore_cli_session_for_turn( + user_id="user-1", + session_id="test-session", + session=session, + sdk_cwd=str(tmp_path), + transcript_builder=builder, + log_prefix="[Test]", + ) + + # download_transcript was called (attempted GCS restore) + download_mock.assert_awaited_once() + # use_resume must be False — baseline transcripts cannot be used with --resume + assert 
result.use_resume is False + # context_messages must be populated — new behaviour uses transcript content + gap + # instead of full DB reconstruction. + assert result.context_messages is not None + # The baseline transcript has 1 user message (BASELINE_SENTINEL). + # Watermark=1 but position 0 is 'user', not 'assistant', so detect_gap returns []. + # Result: 1 message from transcript, no gap. + assert len(result.context_messages) == 1 + assert "BASELINE_SENTINEL" in (result.context_messages[0].content or "") + + @pytest.mark.asyncio + async def test_sdk_mode_transcript_allows_resume(self, tmp_path): + """A valid SDK-written transcript is accepted for --resume.""" + import json as stdlib_json + from datetime import UTC, datetime + + from backend.copilot.model import ChatMessage, ChatSession + from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload + from backend.copilot.transcript_builder import TranscriptBuilder + + lines = [ + stdlib_json.dumps( + { + "type": "user", + "uuid": "uid-0", + "parentUuid": "", + "message": {"role": "user", "content": "hi"}, + } + ), + stdlib_json.dumps( + { + "type": "assistant", + "uuid": "uid-1", + "parentUuid": "uid-0", + "message": { + "role": "assistant", + "id": "msg_1", + "model": "test", + "type": "message", + "stop_reason": STOP_REASON_END_TURN, + "content": [{"type": "text", "text": "hello"}], + }, + } + ), + ] + content = ("\n".join(lines) + "\n").encode("utf-8") + + session = ChatSession( + session_id="test-session", + user_id="user-1", + messages=[ + ChatMessage(role="user", content="hi"), + ChatMessage(role="assistant", content="hello"), + ChatMessage(role="user", content="follow up"), + ], + title="test", + usage=[], + started_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + builder = TranscriptBuilder() + sdk_restore = TranscriptDownload( + content=content, + message_count=2, + mode="sdk", + ) + + import backend.copilot.sdk.service as _svc_mod + + with ( + patch( + 
"backend.copilot.sdk.service.download_transcript", + new=AsyncMock(return_value=sdk_restore), + ), + patch.object(_svc_mod.config, "claude_agent_use_resume", True), + ): + result = await _restore_cli_session_for_turn( + user_id="user-1", + session_id="test-session", + session=session, + sdk_cwd=str(tmp_path), + transcript_builder=builder, + log_prefix="[Test]", + ) + + assert result.use_resume is True + + @pytest.mark.asyncio + async def test_baseline_mode_context_messages_from_transcript_content( + self, tmp_path + ): + """mode='baseline' → context_messages populated from transcript content + gap. + + When a baseline-mode transcript exists, extract_context_messages converts + the JSONL content to ChatMessage objects and returns them in context_messages. + use_resume must remain False. + """ + import json as stdlib_json + from datetime import UTC, datetime + + from backend.copilot.model import ChatMessage, ChatSession + from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload + from backend.copilot.transcript_builder import TranscriptBuilder + + # Build a minimal valid JSONL transcript with 2 messages + lines = [ + stdlib_json.dumps( + { + "type": "user", + "uuid": "uid-0", + "parentUuid": "", + "message": {"role": "user", "content": "TRANSCRIPT_USER"}, + } + ), + stdlib_json.dumps( + { + "type": "assistant", + "uuid": "uid-1", + "parentUuid": "uid-0", + "message": { + "role": "assistant", + "id": "msg_1", + "model": "test", + "type": "message", + "stop_reason": STOP_REASON_END_TURN, + "content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT"}], + }, + } + ), + ] + content = ("\n".join(lines) + "\n").encode("utf-8") + + session = ChatSession( + session_id="test-session", + user_id="user-1", + messages=[ + ChatMessage(role="user", content="DB_USER"), + ChatMessage(role="assistant", content="DB_ASSISTANT"), + ChatMessage(role="user", content="current turn"), + ], + title="test", + usage=[], + started_at=datetime.now(UTC), + 
updated_at=datetime.now(UTC), + ) + builder = TranscriptBuilder() + baseline_restore = TranscriptDownload( + content=content, + message_count=2, + mode="baseline", + ) + + import backend.copilot.sdk.service as _svc_mod + + with ( + patch( + "backend.copilot.sdk.service.download_transcript", + new=AsyncMock(return_value=baseline_restore), + ), + patch.object(_svc_mod.config, "claude_agent_use_resume", True), + ): + result = await _restore_cli_session_for_turn( + user_id="user-1", + session_id="test-session", + session=session, + sdk_cwd=str(tmp_path), + transcript_builder=builder, + log_prefix="[Test]", + ) + + assert result.use_resume is False + assert result.context_messages is not None + # Transcript content has 2 messages, no gap (watermark=2, session prior=2) + assert len(result.context_messages) == 2 + assert result.context_messages[0].role == "user" + assert result.context_messages[1].role == "assistant" + assert "TRANSCRIPT_ASSISTANT" in (result.context_messages[1].content or "") + # transcript_content must be non-empty so the _seed_transcript guard in + # stream_chat_completion_sdk skips DB reconstruction (which would duplicate + # builder entries since load_previous appends). 
+ assert result.transcript_content != "" + + @pytest.mark.asyncio + async def test_baseline_mode_gap_present_context_includes_gap(self, tmp_path): + """mode='baseline' + gap → context_messages includes transcript msgs and gap.""" + import json as stdlib_json + from datetime import UTC, datetime + + from backend.copilot.model import ChatMessage, ChatSession + from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload + from backend.copilot.transcript_builder import TranscriptBuilder + + # Transcript covers only 2 messages; session has 4 prior + current turn + lines = [ + stdlib_json.dumps( + { + "type": "user", + "uuid": "uid-0", + "parentUuid": "", + "message": {"role": "user", "content": "TRANSCRIPT_USER_0"}, + } + ), + stdlib_json.dumps( + { + "type": "assistant", + "uuid": "uid-1", + "parentUuid": "uid-0", + "message": { + "role": "assistant", + "id": "msg_1", + "model": "test", + "type": "message", + "stop_reason": STOP_REASON_END_TURN, + "content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT_1"}], + }, + } + ), + ] + content = ("\n".join(lines) + "\n").encode("utf-8") + + session = ChatSession( + session_id="test-session", + user_id="user-1", + messages=[ + ChatMessage(role="user", content="DB_USER_0"), + ChatMessage(role="assistant", content="DB_ASSISTANT_1"), + ChatMessage(role="user", content="GAP_USER_2"), + ChatMessage(role="assistant", content="GAP_ASSISTANT_3"), + ChatMessage(role="user", content="current turn"), + ], + title="test", + usage=[], + started_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + builder = TranscriptBuilder() + baseline_restore = TranscriptDownload( + content=content, + message_count=2, # watermark=2; session has 4 prior → gap of 2 + mode="baseline", + ) + + import backend.copilot.sdk.service as _svc_mod + + with ( + patch( + "backend.copilot.sdk.service.download_transcript", + new=AsyncMock(return_value=baseline_restore), + ), + patch.object(_svc_mod.config, "claude_agent_use_resume", True), + 
): + result = await _restore_cli_session_for_turn( + user_id="user-1", + session_id="test-session", + session=session, + sdk_cwd=str(tmp_path), + transcript_builder=builder, + log_prefix="[Test]", + ) + + assert result.use_resume is False + assert result.context_messages is not None + # 2 from transcript + 2 gap messages = 4 total + assert len(result.context_messages) == 4 + roles = [m.role for m in result.context_messages] + assert roles == ["user", "assistant", "user", "assistant"] + # Gap messages come from DB (ChatMessage objects) + gap_user = result.context_messages[2] + gap_asst = result.context_messages[3] + assert gap_user.content == "GAP_USER_2" + assert gap_asst.content == "GAP_ASSISTANT_3" diff --git a/autogpt_platform/backend/backend/copilot/sdk/test_transcript_watermark.py b/autogpt_platform/backend/backend/copilot/sdk/test_transcript_watermark.py new file mode 100644 index 0000000000..592dbde82f --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/sdk/test_transcript_watermark.py @@ -0,0 +1,95 @@ +"""Unit tests for the watermark-fix logic in stream_chat_completion_sdk. + +The fix is at the upload step: when use_resume=True and transcript_msg_count>0 +we set the JSONL coverage watermark to transcript_msg_count + 2 (the pair just +recorded) instead of len(session.messages). This prevents the "inflated +watermark" bug where a stale JSONL in GCS could hide missing context from +future gap-fill checks. +""" + +from __future__ import annotations + + +def _compute_jsonl_covered( + use_resume: bool, + transcript_msg_count: int, + session_msg_count: int, +) -> int: + """Mirror the watermark computation from ``stream_chat_completion_sdk``. + + Extracted here so we can unit-test it independently without invoking the + full streaming stack. 
+ """ + if use_resume and transcript_msg_count > 0: + return transcript_msg_count + 2 + return session_msg_count + + +class TestWatermarkFix: + """Watermark computation logic — mirrors the finally-block in SDK service.""" + + def test_inflated_watermark_triggers_gap_fill(self): + """Stale JSONL (T12) with high watermark (46) → after fix, watermark=14. + + Before fix: watermark=46 → next turn's gap check (transcript_msg_count < db-1) + never fires because 46 >= 47-1=46, so context loss is silent. + After fix: watermark = 12 + 2 = 14 → gap check fires (14 < 46) and + the model receives the missing turns. + """ + # Simulate: use_resume=True, transcript covered T12 (12 msgs), DB now has 47 + use_resume = True + transcript_msg_count = 12 + session_msg_count = 47 # DB count (what old code used to set watermark) + + watermark = _compute_jsonl_covered( + use_resume, transcript_msg_count, session_msg_count + ) + + assert watermark == 14 # 12 + 2, NOT 47 + # Verify: the gap check would fire on next turn + # next-turn check: transcript_msg_count < msg_count - 1 → 14 < 47-1=46 → True + assert watermark < session_msg_count - 1 + + def test_no_false_positive_when_transcript_current(self): + """Transcript current (watermark=46, DB=47) → gap stays 0. + + When the JSONL actually covers T46 (the most recent assistant turn), + uploading watermark=46+2=48 means next turn's gap check sees + 48 >= 48-1=47 → no gap. Correct. 
+ """ + use_resume = True + transcript_msg_count = 46 + session_msg_count = 47 + + watermark = _compute_jsonl_covered( + use_resume, transcript_msg_count, session_msg_count + ) + + assert watermark == 48 # 46 + 2 + # Next turn: session has 48 msgs, watermark=48 → 48 >= 48-1=47 → no gap + next_turn_session = 48 + assert watermark >= next_turn_session - 1 + + def test_fresh_session_falls_back_to_db_count(self): + """use_resume=False → watermark = len(session.messages) (original behaviour).""" + use_resume = False + transcript_msg_count = 0 + session_msg_count = 3 + + watermark = _compute_jsonl_covered( + use_resume, transcript_msg_count, session_msg_count + ) + + assert watermark == session_msg_count + + def test_old_format_meta_zero_count_falls_back_to_db(self): + """transcript_msg_count=0 (old-format meta with no count field) → DB fallback.""" + use_resume = True + transcript_msg_count = 0 # old-format meta or not-yet-set + session_msg_count = 10 + + watermark = _compute_jsonl_covered( + use_resume, transcript_msg_count, session_msg_count + ) + + assert watermark == session_msg_count diff --git a/autogpt_platform/backend/backend/copilot/sdk/transcript.py b/autogpt_platform/backend/backend/copilot/sdk/transcript.py index cfbf01a466..d5cf3c3e94 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/transcript.py +++ b/autogpt_platform/backend/backend/copilot/sdk/transcript.py @@ -12,18 +12,20 @@ from backend.copilot.transcript import ( ENTRY_TYPE_MESSAGE, STOP_REASON_END_TURN, STRIPPABLE_TYPES, - TRANSCRIPT_STORAGE_PREFIX, TranscriptDownload, + TranscriptMode, cleanup_stale_project_dirs, + cli_session_path, compact_transcript, delete_transcript, + detect_gap, download_transcript, + extract_context_messages, + projects_base, read_compacted_entries, - restore_cli_session, strip_for_upload, strip_progress_entries, strip_stale_thinking_blocks, - upload_cli_session, upload_transcript, validate_transcript, write_transcript_to_tempfile, @@ -34,18 +36,20 @@ __all__ = [ 
"ENTRY_TYPE_MESSAGE", "STOP_REASON_END_TURN", "STRIPPABLE_TYPES", - "TRANSCRIPT_STORAGE_PREFIX", "TranscriptDownload", + "TranscriptMode", "cleanup_stale_project_dirs", + "cli_session_path", "compact_transcript", "delete_transcript", + "detect_gap", "download_transcript", + "extract_context_messages", + "projects_base", "read_compacted_entries", - "restore_cli_session", "strip_for_upload", "strip_progress_entries", "strip_stale_thinking_blocks", - "upload_cli_session", "upload_transcript", "validate_transcript", "write_transcript_to_tempfile", diff --git a/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py b/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py index 14e404a994..f8e1608094 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py @@ -297,8 +297,8 @@ class TestStripProgressEntries: class TestDeleteTranscript: @pytest.mark.asyncio - async def test_deletes_both_jsonl_and_meta(self): - """delete_transcript removes both the .jsonl and .meta.json files.""" + async def test_deletes_cli_session_and_meta(self): + """delete_transcript removes the CLI session .jsonl and .meta.json.""" mock_storage = AsyncMock() mock_storage.delete = AsyncMock() @@ -309,7 +309,7 @@ class TestDeleteTranscript: ): await delete_transcript("user-123", "session-456") - assert mock_storage.delete.call_count == 3 + assert mock_storage.delete.call_count == 2 paths = [call.args[0] for call in mock_storage.delete.call_args_list] assert any(p.endswith(".jsonl") for p in paths) assert any(p.endswith(".meta.json") for p in paths) @@ -319,7 +319,7 @@ class TestDeleteTranscript: """If .jsonl delete fails, .meta.json delete is still attempted.""" mock_storage = AsyncMock() mock_storage.delete = AsyncMock( - side_effect=[Exception("jsonl delete failed"), None, None] + side_effect=[Exception("jsonl delete failed"), None] ) with patch( @@ -330,14 +330,14 @@ class TestDeleteTranscript: 
# Should not raise await delete_transcript("user-123", "session-456") - assert mock_storage.delete.call_count == 3 + assert mock_storage.delete.call_count == 2 @pytest.mark.asyncio async def test_handles_meta_delete_failure(self): """If .meta.json delete fails, no exception propagates.""" mock_storage = AsyncMock() mock_storage.delete = AsyncMock( - side_effect=[None, Exception("meta delete failed"), None] + side_effect=[None, Exception("meta delete failed")] ) with patch( @@ -1015,7 +1015,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1044,7 +1044,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1070,7 +1070,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1096,7 +1096,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1118,7 +1118,7 @@ class TestCleanupStaleProjectDirs: nonexistent = str(tmp_path / "does-not-exist" / "projects") monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: nonexistent, ) @@ -1137,7 +1137,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", 
lambda: str(projects_dir), ) @@ -1165,7 +1165,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1189,7 +1189,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1368,3 +1368,172 @@ class TestStripStaleThinkingBlocks: # Both entries of last turn (msg_last) preserved assert lines[1]["message"]["content"][0]["type"] == "thinking" assert lines[2]["message"]["content"][0]["type"] == "text" + + +class TestProcessCliRestore: + """``_process_cli_restore`` validates, strips, and writes CLI session to disk.""" + + def test_writes_stripped_bytes_not_raw(self, tmp_path): + """Stripped bytes (not raw bytes) must be written to disk for --resume.""" + import os + import re + from pathlib import Path + from unittest.mock import patch + + from backend.copilot.sdk.service import _process_cli_restore + from backend.copilot.transcript import TranscriptDownload + + session_id = "12345678-0000-0000-0000-abcdef000001" + sdk_cwd = str(tmp_path) + projects_base_dir = str(tmp_path) + + # Build raw content with a strippable progress entry + a valid user/assistant pair + raw_content = ( + '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n' + '{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n' + '{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n' + ) + raw_bytes = raw_content.encode("utf-8") + restore = TranscriptDownload(content=raw_bytes, message_count=2, mode="sdk") + + with ( + patch( + "backend.copilot.sdk.service.projects_base", + return_value=projects_base_dir, + 
), + patch( + "backend.copilot.transcript.projects_base", + return_value=projects_base_dir, + ), + ): + stripped_str, ok = _process_cli_restore( + restore, sdk_cwd, session_id, "[Test]" + ) + + assert ok, "Expected successful restore" + + # Find the written session file + encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) + session_file = Path(projects_base_dir) / encoded_cwd / f"{session_id}.jsonl" + assert session_file.exists(), "Session file should have been written" + + written_bytes = session_file.read_bytes() + # The written bytes must be the stripped version (no progress entry) + assert ( + b"progress" not in written_bytes + ), "Raw bytes with progress entry should not have been written" + assert ( + b"hello" in written_bytes + ), "Stripped content should still contain assistant turn" + + # Written bytes must equal the stripped string re-encoded + assert written_bytes == stripped_str.encode( + "utf-8" + ), "Written bytes must equal stripped content" + + def test_invalid_content_returns_false(self): + """Content that fails validation after strip returns (empty, False).""" + from backend.copilot.sdk.service import _process_cli_restore + from backend.copilot.transcript import TranscriptDownload + + # A single progress-only entry — stripped result will be empty/invalid + raw_content = '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n' + restore = TranscriptDownload( + content=raw_content.encode("utf-8"), message_count=1, mode="sdk" + ) + + stripped_str, ok = _process_cli_restore( + restore, + "/tmp/nonexistent-sdk-cwd", + "12345678-0000-0000-0000-000000000099", + "[Test]", + ) + + assert not ok + assert stripped_str == "" + + +class TestReadCliSessionFromDisk: + """``_read_cli_session_from_disk`` reads, strips, and optionally writes back the session.""" + + def _build_session_file(self, tmp_path, session_id: str): + """Build the session file path inside tmp_path using the same encoding as cli_session_path.""" + 
import os + import re + from pathlib import Path + + sdk_cwd = str(tmp_path) + encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) + session_dir = Path(str(tmp_path)) / encoded_cwd + session_dir.mkdir(parents=True, exist_ok=True) + return sdk_cwd, session_dir / f"{session_id}.jsonl" + + def test_returns_raw_bytes_for_invalid_utf8(self, tmp_path): + """Non-UTF-8 bytes trigger UnicodeDecodeError — returns raw bytes (upload-raw fallback).""" + from unittest.mock import patch + + from backend.copilot.sdk.service import _read_cli_session_from_disk + + session_id = "12345678-0000-0000-0000-aabbccdd0001" + projects_base_dir = str(tmp_path) + sdk_cwd, session_file = self._build_session_file(tmp_path, session_id) + + # Write raw invalid UTF-8 bytes + session_file.write_bytes(b"\xff\xfe invalid utf-8\n") + + with ( + patch( + "backend.copilot.sdk.service.projects_base", + return_value=projects_base_dir, + ), + patch( + "backend.copilot.transcript.projects_base", + return_value=projects_base_dir, + ), + ): + result = _read_cli_session_from_disk(sdk_cwd, session_id, "[Test]") + + # UnicodeDecodeError path returns the raw bytes (upload-raw fallback) + assert result == b"\xff\xfe invalid utf-8\n" + + def test_write_back_oserror_still_returns_stripped_bytes(self, tmp_path): + """OSError on write-back returns stripped bytes for GCS upload (not raw).""" + from unittest.mock import patch + + from backend.copilot.sdk.service import _read_cli_session_from_disk + + session_id = "12345678-0000-0000-0000-aabbccdd0002" + projects_base_dir = str(tmp_path) + sdk_cwd, session_file = self._build_session_file(tmp_path, session_id) + + # Content with a strippable progress entry so stripped_bytes < raw_bytes + raw_content = ( + '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n' + '{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n' + 
'{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n' + ) + session_file.write_bytes(raw_content.encode("utf-8")) + # Make the file read-only so write_bytes raises OSError on the write-back + session_file.chmod(0o444) + + try: + with ( + patch( + "backend.copilot.sdk.service.projects_base", + return_value=projects_base_dir, + ), + patch( + "backend.copilot.transcript.projects_base", + return_value=projects_base_dir, + ), + ): + result = _read_cli_session_from_disk(sdk_cwd, session_id, "[Test]") + finally: + session_file.chmod(0o644) + + # Must return stripped bytes (not raw, not None) so GCS gets the clean version + assert result is not None + assert ( + b"progress" not in result + ), "Stripped bytes must not contain progress entry" + assert b"hello" in result, "Stripped bytes should contain assistant turn" diff --git a/autogpt_platform/backend/backend/copilot/service_test.py b/autogpt_platform/backend/backend/copilot/service_test.py index c4b1c3182e..ec9b13fb22 100644 --- a/autogpt_platform/backend/backend/copilot/service_test.py +++ b/autogpt_platform/backend/backend/copilot/service_test.py @@ -61,18 +61,23 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id): # (CLI version, platform). When that happens, multi-turn still works # via conversation compression (non-resume path), but we can't test # the --resume round-trip. - transcript = None + cli_session = None for _ in range(10): await asyncio.sleep(0.5) - transcript = await download_transcript(test_user_id, session.session_id) - if transcript: + cli_session = await download_transcript(test_user_id, session.session_id) + # Wait until both the session bytes AND the message_count watermark are + # present — a session with message_count=0 means the .meta.json hasn't + # been uploaded yet, so --resume on the next turn would skip gap-fill. 
+ if cli_session and cli_session.message_count > 0: break - if not transcript: + if not cli_session: return pytest.skip( "CLI did not produce a usable transcript — " "cannot test --resume round-trip in this environment" ) - logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes") + logger.info( + f"Turn 1 CLI session uploaded: {len(cli_session.content)} bytes, msg_count={cli_session.message_count}" + ) # Reload session for turn 2 session = await get_chat_session(session.session_id, test_user_id) diff --git a/autogpt_platform/backend/backend/copilot/transcript.py b/autogpt_platform/backend/backend/copilot/transcript.py index a1e11f352d..c4d3de28af 100644 --- a/autogpt_platform/backend/backend/copilot/transcript.py +++ b/autogpt_platform/backend/backend/copilot/transcript.py @@ -1,10 +1,10 @@ """JSONL transcript management for stateless multi-turn resume. The Claude Code CLI persists conversations as JSONL files (one JSON object per -line). When the SDK's ``Stop`` hook fires we read this file, strip bloat -(progress entries, metadata), and upload the result to bucket storage. On the -next turn we download the transcript, write it to a temp file, and pass -``--resume`` so the CLI can reconstruct the full conversation. +line). When the SDK's ``Stop`` hook fires the caller reads this file, strips +bloat (progress entries, metadata), and uploads the result to bucket storage. +On the next turn the caller downloads the bytes and writes them to disk before +passing ``--resume`` so the CLI can reconstruct the full conversation. Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local filesystem for self-hosted) — no DB column needed. 
@@ -20,6 +20,7 @@ import shutil import time from dataclasses import dataclass from pathlib import Path +from typing import TYPE_CHECKING, Literal from uuid import uuid4 from backend.util import json @@ -27,6 +28,9 @@ from backend.util.clients import get_openai_client from backend.util.prompt import CompressResult, compress_context from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage +if TYPE_CHECKING: + from .model import ChatMessage + logger = logging.getLogger(__name__) # UUIDs are hex + hyphens; strip everything else to prevent path injection. @@ -44,17 +48,17 @@ STRIPPABLE_TYPES = frozenset( ) +TranscriptMode = Literal["sdk", "baseline"] + + @dataclass class TranscriptDownload: - """Result of downloading a transcript with its metadata.""" - - content: str - message_count: int = 0 # session.messages length when uploaded - uploaded_at: float = 0.0 # epoch timestamp of upload + content: bytes | str + message_count: int = 0 + # "sdk" = Claude CLI native, "baseline" = TranscriptBuilder + mode: TranscriptMode = "sdk" -# Workspace storage constants — deterministic path from session_id. -TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts" # Storage prefix for the CLI's native session JSONL files (for cross-pod --resume). _CLI_SESSION_STORAGE_PREFIX = "cli-sessions" @@ -363,7 +367,7 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str: _SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-") -def _projects_base() -> str: +def projects_base() -> str: """Return the resolved path to the CLI's projects directory.""" config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude") return os.path.realpath(os.path.join(config_dir, "projects")) @@ -390,8 +394,8 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int: Returns the number of directories removed. 
""" - projects_base = _projects_base() - if not os.path.isdir(projects_base): + _pbase = projects_base() + if not os.path.isdir(_pbase): return 0 now = time.time() @@ -399,7 +403,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int: # Scoped mode: only clean up the one directory for the current session. if encoded_cwd: - target = Path(projects_base) / encoded_cwd + target = Path(_pbase) / encoded_cwd if not target.is_dir(): return 0 # Guard: only sweep copilot-generated dirs. @@ -437,7 +441,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int: # Only safe for single-tenant deployments; callers should prefer the # scoped variant by passing encoded_cwd. try: - entries = Path(projects_base).iterdir() + entries = Path(_pbase).iterdir() except OSError as e: logger.warning("[Transcript] Failed to list projects dir: %s", e) return 0 @@ -490,9 +494,9 @@ def read_compacted_entries(transcript_path: str) -> list[dict] | None: if not transcript_path: return None - projects_base = _projects_base() + _pbase = projects_base() real_path = os.path.realpath(transcript_path) - if not real_path.startswith(projects_base + os.sep): + if not real_path.startswith(_pbase + os.sep): logger.warning( "[Transcript] transcript_path outside projects base: %s", transcript_path ) @@ -611,28 +615,6 @@ def validate_transcript(content: str | None) -> bool: # --------------------------------------------------------------------------- -def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]: - """Return (workspace_id, file_id, filename) for a session's transcript. - - Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl`` - IDs are sanitized to hex+hyphen to prevent path traversal. 
- """ - return ( - TRANSCRIPT_STORAGE_PREFIX, - _sanitize_id(user_id), - f"{_sanitize_id(session_id)}.jsonl", - ) - - -def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]: - """Return (workspace_id, file_id, filename) for a session's transcript metadata.""" - return ( - TRANSCRIPT_STORAGE_PREFIX, - _sanitize_id(user_id), - f"{_sanitize_id(session_id)}.meta.json", - ) - - def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str: """Build a full storage path from (workspace_id, file_id, filename) parts.""" wid, fid, fname = parts @@ -642,24 +624,12 @@ def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str: return f"local://{wid}/{fid}/{fname}" -def _build_storage_path(user_id: str, session_id: str, backend: object) -> str: - """Build the full storage path string that ``retrieve()`` expects.""" - return _build_path_from_parts(_storage_path_parts(user_id, session_id), backend) - - -def _build_meta_storage_path(user_id: str, session_id: str, backend: object) -> str: - """Build the full storage path for the companion .meta.json file.""" - return _build_path_from_parts( - _meta_storage_path_parts(user_id, session_id), backend - ) - - # --------------------------------------------------------------------------- # CLI native session file — cross-pod --resume support # --------------------------------------------------------------------------- -def _cli_session_path(sdk_cwd: str, session_id: str) -> str: +def cli_session_path(sdk_cwd: str, session_id: str) -> str: """Expected path of the CLI's native session JSONL file. 
The CLI resolves the working directory via ``os.path.realpath``, then @@ -675,7 +645,7 @@ def _cli_session_path(sdk_cwd: str, session_id: str) -> str: """ encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) safe_id = _sanitize_id(session_id) - return os.path.join(_projects_base(), encoded_cwd, f"{safe_id}.jsonl") + return os.path.join(projects_base(), encoded_cwd, f"{safe_id}.jsonl") def _cli_session_storage_path_parts( @@ -689,209 +659,82 @@ def _cli_session_storage_path_parts( ) -async def upload_cli_session( - user_id: str, - session_id: str, - sdk_cwd: str, - log_prefix: str = "[Transcript]", -) -> None: - """Upload the CLI's native session JSONL file to remote storage. - - Called after each turn so the next turn can restore the file on any pod - (eliminating the pod-affinity requirement for --resume). - - The CLI only writes the session file after the turn completes, so this - must run in the finally block, AFTER the SDK stream has finished. - """ - session_file = _cli_session_path(sdk_cwd, session_id) - real_path = os.path.realpath(session_file) - projects_base = _projects_base() - - if not real_path.startswith(projects_base + os.sep): - logger.warning( - "%s CLI session file outside projects base, skipping upload: %s", - log_prefix, - os.path.basename(real_path), - ) - return - - try: - content = Path(real_path).read_bytes() - except FileNotFoundError: - logger.debug( - "%s CLI session file not found, skipping upload: %s", - log_prefix, - session_file, - ) - return - except OSError as e: - logger.warning("%s Failed to read CLI session file: %s", log_prefix, e) - return - - storage = await get_workspace_storage() - wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id) - try: - await storage.store( - workspace_id=wid, file_id=fid, filename=fname, content=content - ) - logger.info( - "%s Uploaded CLI session file (%dB) for cross-pod --resume", - log_prefix, - len(content), - ) - except Exception as e: - logger.warning("%s 
Failed to upload CLI session file: %s", log_prefix, e) - - -async def restore_cli_session( - user_id: str, - session_id: str, - sdk_cwd: str, - log_prefix: str = "[Transcript]", -) -> bool: - """Download and restore the CLI's native session file for --resume. - - Returns True if the file was successfully restored and --resume can be - used with the session UUID. Returns False if not available (first turn - or upload failed), in which case the caller should not set --resume. - """ - session_file = _cli_session_path(sdk_cwd, session_id) - real_path = os.path.realpath(session_file) - projects_base = _projects_base() - - if not real_path.startswith(projects_base + os.sep): - logger.warning( - "%s CLI session restore path outside projects base: %s", - log_prefix, - os.path.basename(session_file), - ) - return False - - # If the session file already exists locally (same-pod reuse), use it directly. - # Downloading from storage could overwrite a newer local version when a previous - # turn's upload failed: stored content is stale while the local file already - # contains extended history from that turn. 
- if Path(real_path).exists(): - logger.debug( - "%s CLI session file already exists locally — using it for --resume", - log_prefix, - ) - return True - - storage = await get_workspace_storage() - path = _build_path_from_parts( - _cli_session_storage_path_parts(user_id, session_id), storage +def _cli_session_meta_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]: + """Return (workspace_id, file_id, filename) for the CLI session meta file.""" + return ( + _CLI_SESSION_STORAGE_PREFIX, + _sanitize_id(user_id), + f"{_sanitize_id(session_id)}.meta.json", ) - try: - content = await storage.retrieve(path) - except FileNotFoundError: - logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix) - return False - except Exception as e: - logger.warning("%s Failed to download CLI session: %s", log_prefix, e) - return False - - try: - os.makedirs(os.path.dirname(real_path), exist_ok=True) - Path(real_path).write_bytes(content) - logger.info( - "%s Restored CLI session file (%dB) for --resume", - log_prefix, - len(content), - ) - return True - except OSError as e: - logger.warning("%s Failed to write CLI session file: %s", log_prefix, e) - return False - async def upload_transcript( user_id: str, session_id: str, - content: str, + content: bytes, message_count: int = 0, + mode: TranscriptMode = "sdk", log_prefix: str = "[Transcript]", - skip_strip: bool = False, ) -> None: - """Strip progress entries and stale thinking blocks, then upload transcript. + """Upload CLI session content to GCS with companion meta.json. - The transcript represents the FULL active context (atomic). - Each upload REPLACES the previous transcript entirely. + Pure GCS operation — no disk I/O. The caller is responsible for reading + the session file from disk before calling this function. - The executor holds a cluster lock per session, so concurrent uploads for - the same session cannot happen. 
+ Also uploads a companion .meta.json with the message_count watermark so + download_transcript can return it without a separate fetch. - Args: - content: Complete JSONL transcript (from TranscriptBuilder). - message_count: ``len(session.messages)`` at upload time. - skip_strip: When ``True``, skip the strip + re-validate pass. - Safe for builder-generated content (baseline path) which - never emits progress entries or stale thinking blocks. + Called after each turn so the next turn can restore the file on any pod + (eliminating the pod-affinity requirement for --resume). """ - if skip_strip: - # Caller guarantees the content is already clean and valid. - stripped = content - else: - # Strip metadata entries and stale thinking blocks in a single parse. - # SDK-built transcripts may have progress entries; strip for safety. - stripped = strip_for_upload(content) - if not skip_strip and not validate_transcript(stripped): - # Log entry types for debugging — helps identify why validation failed - entry_types = [ - json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?") - for line in stripped.strip().split("\n") - ] - logger.warning( - "%s Skipping upload — stripped content not valid " - "(types=%s, stripped_len=%d, raw_len=%d)", - log_prefix, - entry_types, - len(stripped), - len(content), - ) - logger.debug("%s Raw content preview: %s", log_prefix, content[:500]) - logger.debug("%s Stripped content: %s", log_prefix, stripped[:500]) - return - storage = await get_workspace_storage() - wid, fid, fname = _storage_path_parts(user_id, session_id) - encoded = stripped.encode("utf-8") - meta = {"message_count": message_count, "uploaded_at": time.time()} - mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id) + wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id) + mwid, mfid, mfname = _cli_session_meta_path_parts(user_id, session_id) + meta = {"message_count": message_count, "mode": mode, "uploaded_at": time.time()} meta_encoded = 
json.dumps(meta).encode("utf-8") - # Transcript + metadata are independent objects at different keys, so - # write them concurrently. ``return_exceptions`` keeps a metadata - # failure from sinking the transcript write. - transcript_result, metadata_result = await asyncio.gather( - storage.store( - workspace_id=wid, - file_id=fid, - filename=fname, - content=encoded, - ), - storage.store( - workspace_id=mwid, - file_id=mfid, - filename=mfname, - content=meta_encoded, - ), - return_exceptions=True, - ) - if isinstance(transcript_result, BaseException): - raise transcript_result - if isinstance(metadata_result, BaseException): - # Metadata is best-effort — the gap-fill logic in - # _build_query_message tolerates a missing metadata file. - logger.warning("%s Failed to write metadata: %s", log_prefix, metadata_result) + # Write JSONL first, meta second — sequential so a crash between the two + # leaves an orphaned JSONL (no meta) rather than an orphaned meta (wrong + # watermark / mode paired with stale or absent content). + # On any failure we roll back the other file so the pair is always absent + # together; download_transcript returns None when either file is missing. + try: + await storage.store( + workspace_id=wid, file_id=fid, filename=fname, content=content + ) + except Exception as session_err: + logger.warning( + "%s Failed to upload CLI session file: %s", log_prefix, session_err + ) + return + + try: + await storage.store( + workspace_id=mwid, file_id=mfid, filename=mfname, content=meta_encoded + ) + except Exception as meta_err: + logger.warning("%s Failed to upload CLI session meta: %s", log_prefix, meta_err) + # Roll back the JSONL so neither file exists — avoids orphaned JSONL being + # used with wrong mode/watermark defaults on the next restore. 
+ try: + session_path = _build_path_from_parts( + _cli_session_storage_path_parts(user_id, session_id), storage + ) + await storage.delete(session_path) + except Exception as rollback_err: + logger.debug( + "%s Session rollback failed (harmless — download will return None): %s", + log_prefix, + rollback_err, + ) + return logger.info( - "%s Uploaded %dB (stripped from %dB, msg_count=%d)", + "%s Uploaded CLI session (%dB, msg_count=%d, mode=%s)", log_prefix, - len(encoded), len(content), message_count, + mode, ) @@ -900,83 +743,173 @@ async def download_transcript( session_id: str, log_prefix: str = "[Transcript]", ) -> TranscriptDownload | None: - """Download transcript and metadata from bucket storage. + """Download CLI session from GCS. Returns content + message_count + mode, or None if not found. - Returns a ``TranscriptDownload`` with the JSONL content and the - ``message_count`` watermark from the upload, or ``None`` if not found. + Pure GCS operation — no disk I/O. The caller is responsible for writing + content to disk if --resume is needed. - The content and metadata fetches run concurrently since they are - independent objects in the bucket. + Returns a TranscriptDownload with the raw content, message_count watermark, + and mode on success, or None if not available (first turn or upload failed). 
""" storage = await get_workspace_storage() - path = _build_storage_path(user_id, session_id, storage) - meta_path = _build_meta_storage_path(user_id, session_id, storage) + path = _build_path_from_parts( + _cli_session_storage_path_parts(user_id, session_id), storage + ) + meta_path = _build_path_from_parts( + _cli_session_meta_path_parts(user_id, session_id), storage + ) - content_task = asyncio.create_task(storage.retrieve(path)) - meta_task = asyncio.create_task(storage.retrieve(meta_path)) content_result, meta_result = await asyncio.gather( - content_task, meta_task, return_exceptions=True + storage.retrieve(path), + storage.retrieve(meta_path), + return_exceptions=True, ) if isinstance(content_result, FileNotFoundError): - logger.debug("%s No transcript in storage", log_prefix) + logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix) return None if isinstance(content_result, BaseException): logger.warning( - "%s Failed to download transcript: %s", log_prefix, content_result + "%s Failed to download CLI session: %s", log_prefix, content_result ) return None - content = content_result.decode("utf-8") + content: bytes = content_result - # Metadata is best-effort — old transcripts won't have it. + # Parse message_count and mode from companion meta — best-effort, defaults. 
message_count = 0 - uploaded_at = 0.0 + mode: TranscriptMode = "sdk" if isinstance(meta_result, FileNotFoundError): - pass # No metadata — treat as unknown (msg_count=0 → always fill gap) + pass # No meta — old upload; default to "sdk" elif isinstance(meta_result, BaseException): - logger.debug( - "%s Failed to load transcript metadata: %s", log_prefix, meta_result - ) + logger.debug("%s Failed to load CLI session meta: %s", log_prefix, meta_result) else: - meta = json.loads(meta_result.decode("utf-8"), fallback={}) - message_count = meta.get("message_count", 0) - uploaded_at = meta.get("uploaded_at", 0.0) + try: + meta_str = meta_result.decode("utf-8") + except UnicodeDecodeError: + logger.debug("%s CLI session meta is not valid UTF-8, ignoring", log_prefix) + meta_str = None + if meta_str is not None: + meta = json.loads(meta_str, fallback={}) + if isinstance(meta, dict): + raw_count = meta.get("message_count", 0) + message_count = ( + raw_count if isinstance(raw_count, int) and raw_count >= 0 else 0 + ) + raw_mode = meta.get("mode", "sdk") + mode = raw_mode if raw_mode in ("sdk", "baseline") else "sdk" logger.info( - "%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count - ) - return TranscriptDownload( - content=content, - message_count=message_count, - uploaded_at=uploaded_at, + "%s Downloaded CLI session (%dB, msg_count=%d, mode=%s)", + log_prefix, + len(content), + message_count, + mode, ) + return TranscriptDownload(content=content, message_count=message_count, mode=mode) + + +def detect_gap( + download: TranscriptDownload, + session_messages: list[ChatMessage], +) -> list[ChatMessage]: + """Return chat-db messages after the transcript watermark (excluding current user turn). + + Returns [] if transcript is current, watermark is zero, or the watermark + position doesn't end on an assistant turn (misaligned watermark). 
+ """ + if download.message_count == 0: + return [] + wm = download.message_count + total = len(session_messages) + if wm >= total - 1: + return [] + # Sanity: position wm-1 should be an assistant turn; misaligned watermark + # means the DB messages shifted (e.g. deletion) — skip gap to avoid wrong context. + # In normal operation ``message_count`` is always written after a complete + # user→assistant exchange (never mid-turn), so the last covered position is + # always assistant. This guard fires only on data corruption or message deletion. + if session_messages[wm - 1].role != "assistant": + return [] + return list(session_messages[wm : total - 1]) + + +def extract_context_messages( + download: TranscriptDownload | None, + session_messages: "list[ChatMessage]", +) -> "list[ChatMessage]": + """Return context messages for the current turn: transcript content + gap. + + This is the shared context primitive used by both the SDK path + (``use_resume=False`` → ``<conversation_history>`` injection) and the + baseline path (OpenAI messages array). + + How it works: + + - When a transcript exists, ``TranscriptBuilder.load_previous`` preserves + ``isCompactSummary=True`` compaction entries, so the returned messages + mirror the compacted context the CLI would see via ``--resume``. + - The gap (DB messages after the transcript watermark) is always small in + normal operation; it only grows during mode switches or when an upload + was missed. + - Falls back to full DB messages when no transcript exists (first turn, + upload failure, or GCS unavailable). + - Returns *prior* messages only (excluding the current user turn at + ``session_messages[-1]``). Callers that need the current turn append + ``session_messages[-1]`` themselves. + - **Tool calls from transcript entries are flattened to text**: assistant + messages derived from the JSONL use ``_flatten_assistant_content``, which + serialises ``tool_use`` blocks as human-readable text rather than + structured ``tool_calls``. 
Gap messages (from DB) preserve their + original ``tool_calls`` field. This is the same trade-off as the old + ``_compress_session_messages(session.messages)`` approach — no regression. + + Args: + download: The ``TranscriptDownload`` from GCS, or ``None`` when no + transcript is available. ``content`` may be either ``bytes`` or + ``str`` (the baseline path decodes + strips before returning). + session_messages: All messages in the session, with the current user + turn as the last element. + + Returns: + A list of ``ChatMessage`` objects covering the prior conversation + context, suitable for injection as conversation history. + """ + from .model import ChatMessage as _ChatMessage # runtime import + + prior = session_messages[:-1] + + if download is None: + return prior + + raw_content = download.content + if not raw_content: + return prior + + # Handle both bytes (raw GCS download) and str (pre-decoded baseline path). + if isinstance(raw_content, bytes): + try: + content_str: str = raw_content.decode("utf-8") + except UnicodeDecodeError: + return prior + else: + content_str = raw_content + + raw = _transcript_to_messages(content_str) + if not raw: + return prior + + transcript_msgs = [ + _ChatMessage(role=m["role"], content=m.get("content") or "") for m in raw + ] + gap = detect_gap(download, session_messages) + return transcript_msgs + gap async def delete_transcript(user_id: str, session_id: str) -> None: - """Delete transcript and its metadata from bucket storage. - - Removes both the ``.jsonl`` transcript and the companion ``.meta.json`` - so stale ``message_count`` watermarks cannot corrupt gap-fill logic. 
- """ + """Delete CLI session JSONL and its companion .meta.json from bucket storage.""" storage = await get_workspace_storage() - path = _build_storage_path(user_id, session_id, storage) - try: - await storage.delete(path) - logger.info("[Transcript] Deleted transcript for session %s", session_id) - except Exception as e: - logger.warning("[Transcript] Failed to delete transcript: %s", e) - - # Also delete the companion .meta.json to avoid orphaned metadata. - try: - meta_path = _build_meta_storage_path(user_id, session_id, storage) - await storage.delete(meta_path) - logger.info("[Transcript] Deleted metadata for session %s", session_id) - except Exception as e: - logger.warning("[Transcript] Failed to delete metadata: %s", e) - - # Also delete the CLI native session file to prevent storage growth. try: cli_path = _build_path_from_parts( _cli_session_storage_path_parts(user_id, session_id), storage @@ -986,6 +919,15 @@ async def delete_transcript(user_id: str, session_id: str) -> None: except Exception as e: logger.warning("[Transcript] Failed to delete CLI session: %s", e) + try: + cli_meta_path = _build_path_from_parts( + _cli_session_meta_path_parts(user_id, session_id), storage + ) + await storage.delete(cli_meta_path) + logger.info("[Transcript] Deleted CLI session meta for session %s", session_id) + except Exception as e: + logger.warning("[Transcript] Failed to delete CLI session meta: %s", e) + # --------------------------------------------------------------------------- # Transcript compaction — LLM summarization for prompt-too-long recovery diff --git a/autogpt_platform/backend/backend/copilot/transcript_test.py b/autogpt_platform/backend/backend/copilot/transcript_test.py index fec869b6ac..2d624308f5 100644 --- a/autogpt_platform/backend/backend/copilot/transcript_test.py +++ b/autogpt_platform/backend/backend/copilot/transcript_test.py @@ -16,11 +16,11 @@ from .transcript import ( _flatten_assistant_content, _flatten_tool_result_content, 
_messages_to_transcript, - _meta_storage_path_parts, _rechain_tail, _sanitize_id, - _storage_path_parts, _transcript_to_messages, + detect_gap, + extract_context_messages, strip_for_upload, validate_transcript, ) @@ -64,24 +64,6 @@ class TestSanitizeId: assert _sanitize_id("!@#$%^&*()") == "unknown" -# --------------------------------------------------------------------------- -# _storage_path_parts / _meta_storage_path_parts -# --------------------------------------------------------------------------- - - -class TestStoragePathParts: - def test_returns_triple(self): - prefix, uid, fname = _storage_path_parts("user-1", "sess-2") - assert prefix == "chat-transcripts" - assert "e" in uid # hex chars from "user-1" sanitized - assert fname.endswith(".jsonl") - - def test_meta_returns_meta_json(self): - prefix, _, fname = _meta_storage_path_parts("user-1", "sess-2") - assert prefix == "chat-transcripts" - assert fname.endswith(".meta.json") - - # --------------------------------------------------------------------------- # _build_path_from_parts # --------------------------------------------------------------------------- @@ -103,24 +85,6 @@ class TestBuildPathFromParts: assert path == "local://wid/fid/file.jsonl" -# --------------------------------------------------------------------------- -# TranscriptDownload dataclass -# --------------------------------------------------------------------------- - - -class TestTranscriptDownload: - def test_defaults(self): - td = TranscriptDownload(content="hello") - assert td.content == "hello" - assert td.message_count == 0 - assert td.uploaded_at == 0.0 - - def test_custom_values(self): - td = TranscriptDownload(content="data", message_count=5, uploaded_at=123.45) - assert td.message_count == 5 - assert td.uploaded_at == 123.45 - - # --------------------------------------------------------------------------- # _flatten_assistant_content # --------------------------------------------------------------------------- @@ -733,202 
+697,203 @@ class TestValidateTranscript: class TestCliSessionPath: def test_encodes_slashes_to_dashes(self): - from .transcript import _cli_session_path, _projects_base + from .transcript import cli_session_path, projects_base sdk_cwd = "/tmp/copilot-abc" - result = _cli_session_path(sdk_cwd, "12345678-1234-1234-1234-123456789abc") - base = _projects_base() + result = cli_session_path(sdk_cwd, "12345678-1234-1234-1234-123456789abc") + base = projects_base() assert result.startswith(base) # Encoded cwd replaces '/' with '-' assert "-tmp-copilot-abc" in result assert result.endswith(".jsonl") def test_sanitizes_session_id(self): - from .transcript import _cli_session_path + from .transcript import cli_session_path - result = _cli_session_path("/tmp/cwd", "../../etc/passwd") + result = cli_session_path("/tmp/cwd", "../../etc/passwd") # _sanitize_id strips non-hex/hyphen chars; path traversal impossible assert ".." not in result assert "passwd" not in result class TestUploadCliSession: - def test_skips_upload_when_path_outside_projects_base(self, tmp_path): - """Files outside the CLI projects base are rejected without upload.""" + def test_uploads_content_bytes_successfully(self): + """Happy path: content bytes are stored as jsonl + meta.json.""" import asyncio from unittest.mock import AsyncMock, patch - from .transcript import upload_cli_session + from .transcript import upload_transcript mock_storage = AsyncMock() + content = b'{"type":"assistant"}\n' - with ( - patch( - "backend.copilot.transcript._projects_base", - return_value=str(tmp_path), - ), - # Return a path that is genuinely outside tmp_path so that - # realpath(session_file).startswith(projects_base + "/") is False - # and the boundary guard actually fires. 
- patch( - "backend.copilot.transcript._cli_session_path", - return_value="/outside/escaped/session.jsonl", - ), - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, ): asyncio.run( - upload_cli_session( + upload_transcript( user_id="user-1", - session_id="12345678-0000-0000-0000-000000000000", - sdk_cwd=str(tmp_path), + session_id="12345678-0000-0000-0000-000000000001", + content=content, ) ) - # storage.store must NOT be called — boundary guard should reject the path - mock_storage.store.assert_not_called() + # Two calls expected: session JSONL + companion .meta.json + assert mock_storage.store.call_count == 2 - def test_skips_upload_when_file_not_found(self, tmp_path): - """Missing CLI session file logs debug and skips upload silently.""" + def test_uploads_companion_meta_json_with_message_count(self): + """upload_transcript stores a companion .meta.json with message_count.""" + import asyncio + import json + from unittest.mock import AsyncMock, patch + + from .transcript import upload_transcript + + mock_storage = AsyncMock() + content = b'{"type":"assistant"}\n' + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + asyncio.run( + upload_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000010", + content=content, + message_count=5, + ) + ) + + assert mock_storage.store.call_count == 2 + # Find the meta.json store call + meta_call = next( + c + for c in mock_storage.store.call_args_list + if c.kwargs.get("filename", "").endswith(".meta.json") + ) + meta_content = json.loads(meta_call.kwargs["content"]) + assert meta_content["message_count"] == 5 + + def test_skips_upload_on_storage_failure(self): + """Storage exception on jsonl write is logged and does not propagate. 
+ + With sequential writes, JSONL failure returns early — meta store is + never called, so no rollback is needed. + """ import asyncio from unittest.mock import AsyncMock, patch - from .transcript import upload_cli_session + from .transcript import upload_transcript mock_storage = AsyncMock() - projects_base = str(tmp_path) + mock_storage.store.side_effect = RuntimeError("gcs unavailable") + content = b'{"type":"assistant"}\n' - with ( - patch( - "backend.copilot.transcript._projects_base", - return_value=projects_base, - ), - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, ): - # session file doesn't exist — should not raise + # Should not raise — failures are logged as warnings asyncio.run( - upload_cli_session( + upload_transcript( user_id="user-1", - session_id="12345678-0000-0000-0000-000000000000", - sdk_cwd=str(tmp_path), - ) - ) - - mock_storage.store.assert_not_called() - - def test_uploads_file_successfully(self, tmp_path): - """Happy path: session file exists within projects base → upload called.""" - import asyncio - from unittest.mock import AsyncMock, patch - - from .transcript import _sanitize_id, upload_cli_session - - projects_base = str(tmp_path) - session_id = "12345678-0000-0000-0000-000000000001" - sdk_cwd = str(tmp_path) - - # Build the path the same way _cli_session_path does, but using our tmp_path - # as projects_base so the boundary check passes. - # Must use the same encoding: re.sub non-alphanumeric → "-" on realpath. 
- import os - import re - - encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) - session_dir = tmp_path / encoded_cwd - session_dir.mkdir(parents=True, exist_ok=True) - session_file = session_dir / f"{_sanitize_id(session_id)}.jsonl" - session_file.write_bytes(b'{"type":"assistant"}\n') - - mock_storage = AsyncMock() - - with ( - patch( - "backend.copilot.transcript._projects_base", - return_value=projects_base, - ), - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), - ): - asyncio.run( - upload_cli_session( - user_id="user-1", - session_id=session_id, - sdk_cwd=sdk_cwd, + session_id="12345678-0000-0000-0000-000000000002", + content=content, ) ) + # Only one store call attempted (the JSONL); meta never reached mock_storage.store.assert_called_once() + mock_storage.delete.assert_not_called() - def test_skips_upload_on_oserror(self, tmp_path): - """OSError reading session file is logged as warning; upload is skipped.""" + def test_rolls_back_session_when_meta_upload_fails(self): + """When meta upload fails after JSONL succeeds, JSONL is rolled back. + + Guarantees the pair is either both present or both absent — avoids an + orphaned JSONL being used with wrong mode/watermark defaults. + """ import asyncio from unittest.mock import AsyncMock, patch - from .transcript import _sanitize_id, upload_cli_session - - projects_base = str(tmp_path) - sdk_cwd = str(tmp_path) - session_id = "12345678-0000-0000-0000-000000000002" - - # Build file at a path inside projects_base so boundary check passes. 
- import os - import re - - encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) - session_dir = tmp_path / encoded_cwd - session_dir.mkdir(parents=True, exist_ok=True) - session_file = session_dir / f"{_sanitize_id(session_id)}.jsonl" - session_file.write_bytes(b'{"type":"assistant"}\n') - # Remove read permission to trigger OSError - session_file.chmod(0o000) + from .transcript import upload_transcript mock_storage = AsyncMock() + # First store (JSONL) succeeds; second store (meta) fails + mock_storage.store.side_effect = [None, RuntimeError("meta write failed")] + content = b'{"type":"assistant"}\n' - try: - with ( - patch( - "backend.copilot.transcript._projects_base", - return_value=projects_base, - ), - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), - ): - asyncio.run( - upload_cli_session( - user_id="user-1", - session_id=session_id, - sdk_cwd=sdk_cwd, - ) + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + asyncio.run( + upload_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000099", + content=content, ) - finally: - session_file.chmod(0o644) # restore so tmp_path cleanup works + ) - mock_storage.store.assert_not_called() + # Both store calls were attempted (JSONL then meta) + assert mock_storage.store.call_count == 2 + # JSONL should be rolled back via delete + mock_storage.delete.assert_called_once() + + def test_baseline_mode_stored_in_meta(self): + """upload_transcript with mode='baseline' stores mode in companion meta.json.""" + import asyncio + import json + from unittest.mock import AsyncMock, patch + + from .transcript import upload_transcript + + mock_storage = AsyncMock() + content = b'{"type":"assistant"}\n' + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + asyncio.run( + 
upload_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000098", + content=content, + message_count=4, + mode="baseline", + ) + ) + + meta_call = next( + c + for c in mock_storage.store.call_args_list + if c.kwargs.get("filename", "").endswith(".meta.json") + ) + meta_content = json.loads(meta_call.kwargs["content"]) + assert meta_content["mode"] == "baseline" + assert meta_content["message_count"] == 4 class TestRestoreCliSession: - def test_returns_false_when_file_not_found_in_storage(self): - """Returns False (graceful degradation) when the session is missing.""" + def test_returns_none_when_file_not_found_in_storage(self): + """Returns None (graceful degradation) when the session is missing.""" import asyncio from unittest.mock import AsyncMock, patch - from .transcript import restore_cli_session + from .transcript import download_transcript mock_storage = AsyncMock() - mock_storage.retrieve.side_effect = FileNotFoundError("not found") + mock_storage.retrieve.side_effect = [ + FileNotFoundError("no session"), + FileNotFoundError("no meta"), + ] with patch( "backend.copilot.transcript.get_workspace_storage", @@ -936,144 +901,26 @@ class TestRestoreCliSession: return_value=mock_storage, ): result = asyncio.run( - restore_cli_session( + download_transcript( user_id="user-1", session_id="12345678-0000-0000-0000-000000000000", - sdk_cwd="/tmp/copilot-test", ) ) - assert result is False + assert result is None - def test_returns_false_when_restore_path_outside_projects_base(self, tmp_path): - """Path traversal guard: rejects restoration outside the projects base.""" + def test_returns_transcript_download_on_success_no_meta(self): + """Happy path with no meta.json: returns TranscriptDownload with message_count=0.""" import asyncio from unittest.mock import AsyncMock, patch - from .transcript import restore_cli_session + from .transcript import download_transcript - mock_storage = AsyncMock() - mock_storage.retrieve.return_value = 
b'{"type":"assistant"}\n' - - with ( - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), - patch( - "backend.copilot.transcript._projects_base", - return_value=str(tmp_path), - ), - # Return a path genuinely outside tmp_path so the boundary guard fires. - patch( - "backend.copilot.transcript._cli_session_path", - return_value="/outside/escaped/session.jsonl", - ), - ): - result = asyncio.run( - restore_cli_session( - user_id="user-1", - session_id="12345678-0000-0000-0000-000000000000", - sdk_cwd=str(tmp_path), - ) - ) - - assert result is False - - def test_returns_true_when_local_file_already_exists(self, tmp_path): - """Same-pod reuse: if local file exists, skip storage download and return True.""" - import asyncio - import os - import re - from pathlib import Path - from unittest.mock import AsyncMock, patch - - from .transcript import restore_cli_session - - session_id = "12345678-0000-0000-0000-000000000099" - sdk_cwd = str(tmp_path) - - # Pre-create the local session file (simulates previous turn on same pod) - projects_base = os.path.realpath(str(tmp_path)) - encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", projects_base) - session_dir = Path(projects_base) / encoded_cwd - session_dir.mkdir(parents=True, exist_ok=True) - existing_content = b'{"type":"user"}\n{"type":"assistant"}\n' - (session_dir / f"{session_id}.jsonl").write_bytes(existing_content) - - mock_storage = AsyncMock() - - with ( - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), - patch( - "backend.copilot.transcript._projects_base", - return_value=projects_base, - ), - ): - result = asyncio.run( - restore_cli_session( - user_id="user-1", - session_id=session_id, - sdk_cwd=sdk_cwd, - ) - ) - - assert result is True - # Storage should NOT have been accessed (local file was used as-is) - mock_storage.retrieve.assert_not_called() - # Local file should be 
unchanged - assert (session_dir / f"{session_id}.jsonl").read_bytes() == existing_content - - def test_returns_true_on_success(self, tmp_path): - """Happy path: storage has the session → file written → returns True.""" - import asyncio - from unittest.mock import AsyncMock, patch - - from .transcript import restore_cli_session - - projects_base = str(tmp_path) - sdk_cwd = str(tmp_path) session_id = "12345678-0000-0000-0000-000000000003" content = b'{"type":"assistant"}\n' mock_storage = AsyncMock() - mock_storage.retrieve.return_value = content - - with ( - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), - patch( - "backend.copilot.transcript._projects_base", - return_value=projects_base, - ), - ): - result = asyncio.run( - restore_cli_session( - user_id="user-1", - session_id=session_id, - sdk_cwd=sdk_cwd, - ) - ) - - assert result is True - - def test_returns_false_on_download_exception(self): - """Non-FileNotFoundError during retrieve logs warning and returns False.""" - import asyncio - from unittest.mock import AsyncMock, patch - - from .transcript import restore_cli_session - - mock_storage = AsyncMock() - mock_storage.retrieve.side_effect = RuntimeError("network error") + mock_storage.retrieve.side_effect = [content, FileNotFoundError("no meta")] with patch( "backend.copilot.transcript.get_workspace_storage", @@ -1081,11 +928,411 @@ class TestRestoreCliSession: return_value=mock_storage, ): result = asyncio.run( - restore_cli_session( + download_transcript( user_id="user-1", - session_id="12345678-0000-0000-0000-000000000004", - sdk_cwd="/tmp/copilot-test", + session_id=session_id, ) ) - assert result is False + assert isinstance(result, TranscriptDownload) + assert result.content == content + assert result.message_count == 0 + assert result.mode == "sdk" + + def test_returns_transcript_download_with_message_count_from_meta(self): + """When meta.json is present, message_count and mode 
are read from it.""" + import asyncio + import json + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + session_id = "12345678-0000-0000-0000-000000000005" + content = b'{"type":"assistant"}\n' + meta_bytes = json.dumps( + {"message_count": 7, "mode": "sdk", "uploaded_at": 1234567.0} + ).encode() + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [content, meta_bytes] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id=session_id, + ) + ) + + assert isinstance(result, TranscriptDownload) + assert result.content == content + assert result.message_count == 7 + assert result.mode == "sdk" + + def test_returns_none_on_download_exception(self): + """Non-FileNotFoundError during retrieve logs warning and returns None.""" + import asyncio + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [ + RuntimeError("network error"), + FileNotFoundError("no meta"), + ] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000004", + ) + ) + + assert result is None + + def test_baseline_mode_in_meta_returned(self): + """When meta.json contains mode='baseline', result.mode is 'baseline'.""" + import asyncio + import json + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + content = b'{"type":"assistant"}\n' + meta_bytes = json.dumps( + {"message_count": 3, "mode": "baseline", "uploaded_at": 0.0} + ).encode() + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [content, meta_bytes] + + with patch( + 
"backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000020", + ) + ) + + assert isinstance(result, TranscriptDownload) + assert result.mode == "baseline" + assert result.message_count == 3 + + def test_invalid_mode_in_meta_defaults_to_sdk(self): + """Unknown mode value in meta.json falls back to 'sdk'.""" + import asyncio + import json + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + content = b'{"type":"assistant"}\n' + meta_bytes = json.dumps({"message_count": 2, "mode": "unknown_mode"}).encode() + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [content, meta_bytes] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000021", + ) + ) + + assert isinstance(result, TranscriptDownload) + assert result.mode == "sdk" + + def test_invalid_utf8_meta_uses_defaults(self): + """Meta bytes that fail UTF-8 decode fall back to message_count=0, mode='sdk'.""" + import asyncio + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + content = b'{"type":"assistant"}\n' + bad_meta = b"\xff\xfe" + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [content, bad_meta] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000022", + ) + ) + + assert isinstance(result, TranscriptDownload) + assert result.message_count == 0 + assert result.mode == "sdk" + + def test_meta_fetch_exception_uses_defaults(self): + 
"""Non-FileNotFoundError on meta fetch still returns content with defaults.""" + import asyncio + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + content = b'{"type":"assistant"}\n' + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [content, RuntimeError("meta unavailable")] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000023", + ) + ) + + assert isinstance(result, TranscriptDownload) + assert result.content == content + assert result.message_count == 0 + assert result.mode == "sdk" + + +# --------------------------------------------------------------------------- +# detect_gap +# --------------------------------------------------------------------------- + + +def _msgs(*roles: str): + """Build a list of ChatMessage objects with the given roles.""" + from .model import ChatMessage + + return [ChatMessage(role=r, content=f"{r}-{i}") for i, r in enumerate(roles)] + + +class TestDetectGap: + """``detect_gap`` returns messages between transcript watermark and current turn.""" + + def _dl(self, message_count: int) -> TranscriptDownload: + return TranscriptDownload(content=b"", message_count=message_count, mode="sdk") + + def test_zero_watermark_returns_empty(self): + """message_count=0 means no watermark — skip gap detection.""" + dl = self._dl(0) + messages = _msgs("user", "assistant", "user") + assert detect_gap(dl, messages) == [] + + def test_watermark_covers_all_prefix_returns_empty(self): + """Transcript already covers all messages up to the current user turn.""" + # session: [user, assistant, user(current)] — wm=2 means covers up to assistant + dl = self._dl(2) + messages = _msgs("user", "assistant", "user") + assert detect_gap(dl, messages) == [] + + def test_watermark_exceeds_session_returns_empty(self): 
+ """Watermark ahead of session count (race / over-count) → no gap.""" + dl = self._dl(10) + messages = _msgs("user", "assistant", "user") + assert detect_gap(dl, messages) == [] + + def test_misaligned_watermark_not_on_assistant_returns_empty(self): + """Watermark at a user-role position is misaligned — skip gap.""" + # wm=1: position 0 is 'user', not 'assistant' → skip + dl = self._dl(1) + messages = _msgs("user", "assistant", "user", "assistant", "user") + assert detect_gap(dl, messages) == [] + + def test_returns_gap_messages(self): + """Watermark behind session — gap messages returned (excluding current turn).""" + # session: [user0, assistant1, user2, assistant3, user4(current)] + # wm=2: transcript covers [0,1]; gap = [user2, assistant3] + dl = self._dl(2) + messages = _msgs("user", "assistant", "user", "assistant", "user") + gap = detect_gap(dl, messages) + assert len(gap) == 2 + assert gap[0].role == "user" + assert gap[1].role == "assistant" + + def test_excludes_current_user_turn(self): + """The last message (current user turn) is never included in the gap.""" + # wm=2, session has 4 msgs: gap = [msg2] only (msg3 is current turn → excluded) + dl = self._dl(2) + messages = _msgs("user", "assistant", "user", "user") + gap = detect_gap(dl, messages) + assert len(gap) == 1 + assert gap[0].role == "user" + + def test_single_gap_message(self): + """One message between watermark and current turn.""" + # session: [user0, assistant1, user2, assistant3, user4(current)] + # wm=3: position 2 is 'user' → misaligned, returns [] + # use wm=4: but 4 >= total-1=4 → also empty + # wm=3 with session [u, a, u, a, u, a, u(current)]: position 2 is 'user' → empty + # Valid case: wm=2 has 3 messages (assistant at 1), wm=4 with [u,a,u,a,u,a,u]: + # let's use wm=4 with 7 messages: wm=4 >= total-1=6? no, 4<6. 
pos[3]=assistant → gap=[msg4,msg5] + # simpler: wm=2, [u0,a1,a2,u3(current)] — pos[1]=assistant, gap=[a2] only + dl = self._dl(2) + messages = _msgs("user", "assistant", "assistant", "user") + gap = detect_gap(dl, messages) + assert len(gap) == 1 + assert gap[0].role == "assistant" + + +# --------------------------------------------------------------------------- +# extract_context_messages +# --------------------------------------------------------------------------- + + +def _make_valid_transcript(*roles: str) -> str: + """Build a minimal valid JSONL transcript with the given message roles.""" + import json as stdlib_json + + from .transcript import STOP_REASON_END_TURN + + lines = [] + parent = "" + for i, role in enumerate(roles): + uid = f"uid-{i}" + entry: dict = { + "type": role, + "uuid": uid, + "parentUuid": parent, + "message": { + "role": role, + "content": f"{role} content {i}", + }, + } + if role == "assistant": + entry["message"]["id"] = f"msg_{i}" + entry["message"]["model"] = "test-model" + entry["message"]["type"] = "message" + entry["message"]["stop_reason"] = STOP_REASON_END_TURN + entry["message"]["content"] = [ + {"type": "text", "text": f"assistant content {i}"} + ] + lines.append(stdlib_json.dumps(entry)) + parent = uid + return "\n".join(lines) + "\n" + + +class TestExtractContextMessages: + """``extract_context_messages`` returns the shared context primitive.""" + + def test_none_download_returns_prior(self): + """No download → falls back to all session messages except current turn.""" + messages = _msgs("user", "assistant", "user") + result = extract_context_messages(None, messages) + assert result == messages[:-1] + assert len(result) == 2 + + def test_empty_content_download_returns_prior(self): + """Empty bytes content → falls back to all prior messages.""" + dl = TranscriptDownload(content=b"", message_count=2, mode="sdk") + messages = _msgs("user", "assistant", "user") + result = extract_context_messages(dl, messages) + assert result 
== messages[:-1] + + def test_valid_transcript_no_gap_returns_transcript_messages(self): + """Transcript covers all prior turns → only transcript messages returned.""" + # Transcript: [user, assistant] — 2 messages + # Session: [user, assistant, user(current)] — watermark=2 covers prefix + transcript_content = _make_valid_transcript("user", "assistant") + dl = TranscriptDownload( + content=transcript_content.encode("utf-8"), message_count=2, mode="sdk" + ) + messages = _msgs("user", "assistant", "user") + result = extract_context_messages(dl, messages) + # Transcript has 2 messages (user + assistant) and no gap + assert len(result) == 2 + assert result[0].role == "user" + assert result[1].role == "assistant" + + def test_valid_transcript_with_gap_returns_transcript_plus_gap(self): + """Transcript is stale → gap messages appended after transcript content.""" + # Transcript: [user, assistant] — watermark=2 + # Session: [user, assistant, user, assistant, user(current)] + # Gap: [user(2), assistant(3)] — positions 2 and 3 + transcript_content = _make_valid_transcript("user", "assistant") + dl = TranscriptDownload( + content=transcript_content.encode("utf-8"), message_count=2, mode="sdk" + ) + messages = _msgs("user", "assistant", "user", "assistant", "user") + result = extract_context_messages(dl, messages) + # 2 transcript messages + 2 gap messages = 4 + assert len(result) == 4 + assert result[0].role == "user" # transcript user + assert result[1].role == "assistant" # transcript assistant + assert result[2].role == "user" # gap user + assert result[3].role == "assistant" # gap assistant + + def test_compact_summary_entries_preserved(self): + """``isCompactSummary=True`` entries survive ``_transcript_to_messages``.""" + import json as stdlib_json + + from .transcript import STOP_REASON_END_TURN + + # Build a transcript where one entry is a compaction summary. + # isCompactSummary=True entries have type in STRIPPABLE_TYPES but are kept. 
+ compact_entry = stdlib_json.dumps( + { + "type": "summary", + "uuid": "uid-compact", + "parentUuid": "", + "isCompactSummary": True, + "message": { + "role": "user", + "content": "COMPACT_SUMMARY_CONTENT", + }, + } + ) + assistant_entry = stdlib_json.dumps( + { + "type": "assistant", + "uuid": "uid-1", + "parentUuid": "uid-compact", + "message": { + "role": "assistant", + "id": "msg_1", + "model": "test", + "type": "message", + "stop_reason": STOP_REASON_END_TURN, + "content": [{"type": "text", "text": "response after compact"}], + }, + } + ) + content = compact_entry + "\n" + assistant_entry + "\n" + dl = TranscriptDownload( + content=content.encode("utf-8"), message_count=2, mode="sdk" + ) + messages = _msgs("user", "assistant", "user") + result = extract_context_messages(dl, messages) + # Both the compact summary and the assistant response are present + assert len(result) == 2 + roles = [m.role for m in result] + assert "user" in roles # compact summary has role=user + assert "assistant" in roles + # The compact summary content is preserved + compact_msgs = [m for m in result if m.role == "user"] + assert any("COMPACT_SUMMARY_CONTENT" in (m.content or "") for m in compact_msgs) diff --git a/autogpt_platform/backend/scripts/download_transcripts.py b/autogpt_platform/backend/scripts/download_transcripts.py index 26204c3243..a9b32e8494 100644 --- a/autogpt_platform/backend/scripts/download_transcripts.py +++ b/autogpt_platform/backend/scripts/download_transcripts.py @@ -88,17 +88,19 @@ async def cmd_download(session_ids: list[str]) -> None: print(f"[{sid[:12]}] Not found in GCS") continue + content_str = ( + dl.content.decode("utf-8") if isinstance(dl.content, bytes) else dl.content + ) out = _transcript_path(sid) with open(out, "w") as f: - f.write(dl.content) + f.write(content_str) - lines = len(dl.content.strip().split("\n")) + lines = len(content_str.strip().split("\n")) meta = { "session_id": sid, "user_id": user_id, "message_count": dl.message_count, - 
"uploaded_at": dl.uploaded_at, - "transcript_bytes": len(dl.content), + "transcript_bytes": len(content_str), "transcript_lines": lines, } with open(_meta_path(sid), "w") as f: @@ -106,7 +108,7 @@ async def cmd_download(session_ids: list[str]) -> None: print( f"[{sid[:12]}] Saved: {lines} entries, " - f"{len(dl.content)} bytes, msg_count={dl.message_count}" + f"{len(content_str)} bytes, msg_count={dl.message_count}" ) print("\nDone. Run 'load' command to import into local dev environment.") @@ -227,7 +229,7 @@ async def cmd_load(session_ids: list[str]) -> None: await upload_transcript( user_id=user_id, session_id=sid, - content=content, + content=content.encode("utf-8"), message_count=msg_count, ) print(f"[{sid[:12]}] Stored transcript in local workspace storage") diff --git a/autogpt_platform/backend/test/copilot/test_transcript_watermark.py b/autogpt_platform/backend/test/copilot/test_transcript_watermark.py new file mode 100644 index 0000000000..bd88726339 --- /dev/null +++ b/autogpt_platform/backend/test/copilot/test_transcript_watermark.py @@ -0,0 +1,140 @@ +"""Unit tests for the transcript watermark (message_count) fix. + +The bug: upload used message_count=len(session.messages) (DB count). When a +prior turn's GCS upload failed silently, the JSONL on GCS was stale (e.g. +covered only T1-T12) but the meta.json watermark matched the full DB count +(e.g. 46). The next turn's gap-fill check (transcript_msg_count < msg_count-1) +never triggered, so the model silently lost context for the skipped turns. + +The fix: watermark = previous_coverage + 2 (current user+asst pair) when +use_resume=True and transcript_msg_count > 0. This ensures the watermark +reflects the JSONL content, not the DB count. + +These tests exercise _build_query_message directly to verify that gap-fill +triggers with the corrected watermark but NOT with the inflated (buggy) one. 
+""" + +from unittest.mock import MagicMock + +import pytest + +from backend.copilot.sdk.service import _build_query_message + + +def _make_messages(n_pairs: int, *, current_user: str = "current") -> list[MagicMock]: + """Build a flat list of n_pairs*2 alternating user/asst messages, plus + one trailing user message for the *current* turn.""" + msgs: list[MagicMock] = [] + for i in range(n_pairs): + u = MagicMock() + u.role = "user" + u.content = f"user message {i}" + a = MagicMock() + a.role = "assistant" + a.content = f"assistant response {i}" + msgs.extend([u, a]) + # Current turn's user message + cur = MagicMock() + cur.role = "user" + cur.content = current_user + msgs.append(cur) + return msgs + + +def _make_session(messages: list[MagicMock]) -> MagicMock: + session = MagicMock() + session.messages = messages + return session + + +@pytest.mark.asyncio +async def test_gap_fill_triggers_for_stale_jsonl(): + """Scenario: T1-T12 in JSONL (watermark=24), DB has T1-T22+Test (46 msgs). + + With the FIX: 'Test' uploaded watermark=26 (T12's 24 + 2 for 'Test'). + Next turn (T24) downloads watermark=26, DB has 47. + Gap check: 26 < 47-1=46 → TRUE → gap fills T14-T23. 
+ """ + # T23 turns in DB (46 messages) + T24 user = 47 + msgs = _make_messages(23, current_user="memory test - recall all") + assert len(msgs) == 47 + + session = _make_session(msgs) + + # Watermark as uploaded by the FIX: T12 covered 24, 'Test' +2 = 26 + result_msg, _ = await _build_query_message( + current_message="memory test - recall all", + session=session, + use_resume=True, + transcript_msg_count=26, + session_id="test-session-id", + ) + + assert "<conversation_history>" in result_msg, ( + "Expected gap-fill to inject <conversation_history> when " + "watermark=26 < msg_count-1=46" + ) + + +@pytest.mark.asyncio +async def test_no_gap_fill_when_watermark_is_current(): + """When the JSONL is fully current (watermark = DB-1), no gap injected.""" + # T23 turns in DB (46 messages) + T24 user = 47 + msgs = _make_messages(23, current_user="next message") + session = _make_session(msgs) + + result_msg, _ = await _build_query_message( + current_message="next message", + session=session, + use_resume=True, + transcript_msg_count=46, # current — no gap + session_id="test-session-id", + ) + + assert ( + "<conversation_history>" not in result_msg + ), "No gap-fill expected when watermark is current" + assert result_msg == "next message" + + +@pytest.mark.asyncio +async def test_inflated_watermark_suppresses_gap_fill(): + """Documents the original bug: inflated watermark suppresses gap-fill. + + 'Test' uploaded watermark=len(session.messages)=46 even though only 26 + messages are in the JSONL. Next turn: 46 < 47-1=46 → FALSE → no gap fill. 
+ """ + msgs = _make_messages(23, current_user="memory test") + session = _make_session(msgs) + + # Buggy watermark: inflated to DB count + result_msg, _ = await _build_query_message( + current_message="memory test", + session=session, + use_resume=True, + transcript_msg_count=46, # inflated — suppresses gap fill + session_id="test-session-id", + ) + + assert ( + "<conversation_history>" not in result_msg + ), "With inflated watermark, gap-fill is suppressed — this documents the bug" + + +@pytest.mark.asyncio +async def test_fixed_watermark_fills_same_gap(): + """Same scenario but with the FIXED watermark triggers gap-fill.""" + msgs = _make_messages(23, current_user="memory test") + session = _make_session(msgs) + + result_msg, _ = await _build_query_message( + current_message="memory test", + session=session, + use_resume=True, + transcript_msg_count=26, # fixed watermark + session_id="test-session-id", + ) + + assert ( + "<conversation_history>" in result_msg + ), "With fixed watermark=26, gap-fill triggers and injects missing turns"