diff --git a/autogpt_platform/backend/.env.default b/autogpt_platform/backend/.env.default index c01da95a03..e731f9f9bf 100644 --- a/autogpt_platform/backend/.env.default +++ b/autogpt_platform/backend/.env.default @@ -60,7 +60,8 @@ NVIDIA_API_KEY= # Graphiti Temporal Knowledge Graph Memory # Rollout controlled by LaunchDarkly flag "graphiti-memory" -# LLM/embedder keys fall back to OPEN_ROUTER_API_KEY and OPENAI_API_KEY when empty. +# LLM key falls back to CHAT_API_KEY (AutoPilot), then OPEN_ROUTER_API_KEY. +# Embedder key falls back to CHAT_OPENAI_API_KEY (AutoPilot), then OPENAI_API_KEY. GRAPHITI_FALKORDB_HOST=localhost GRAPHITI_FALKORDB_PORT=6380 GRAPHITI_FALKORDB_PASSWORD= diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index 7782785489..b5b1d0d6fe 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -62,6 +62,10 @@ from backend.copilot.tools.models import ( InputValidationErrorResponse, MCPToolOutputResponse, MCPToolsDiscoveredResponse, + MemoryForgetCandidatesResponse, + MemoryForgetConfirmResponse, + MemorySearchResponse, + MemoryStoreResponse, NeedLoginResponse, NoResultsResponse, SetupRequirementsResponse, @@ -1432,6 +1436,10 @@ ToolResponseUnion = ( | DocPageResponse | MCPToolsDiscoveredResponse | MCPToolOutputResponse + | MemoryStoreResponse + | MemorySearchResponse + | MemoryForgetCandidatesResponse + | MemoryForgetConfirmResponse ) diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index bb3906811c..a2813ad881 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -67,11 +67,15 @@ from backend.copilot.transcript import ( STOP_REASON_END_TURN, STOP_REASON_TOOL_USE, TranscriptDownload, + detect_gap, 
download_transcript, + extract_context_messages, + strip_for_upload, upload_transcript, validate_transcript, ) from backend.copilot.transcript_builder import TranscriptBuilder +from backend.util import json as util_json from backend.util.exceptions import NotFoundError from backend.util.prompt import ( compress_context, @@ -293,56 +297,69 @@ async def _baseline_llm_caller( ) tool_calls_by_index: dict[int, dict[str, str]] = {} - async for chunk in response: - if chunk.usage: - state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0 - state.turn_completion_tokens += chunk.usage.completion_tokens or 0 - # Extract cache token details when available (OpenAI / - # OpenRouter include these in prompt_tokens_details). - ptd = getattr(chunk.usage, "prompt_tokens_details", None) - if ptd: - state.turn_cache_read_tokens += ( - getattr(ptd, "cached_tokens", 0) or 0 - ) - # cache_creation_input_tokens is reported by some providers - # (e.g. Anthropic native) but not standard OpenAI streaming. - state.turn_cache_creation_tokens += ( - getattr(ptd, "cache_creation_input_tokens", 0) or 0 - ) - - delta = chunk.choices[0].delta if chunk.choices else None - if not delta: - continue - - if delta.content: - emit = state.thinking_stripper.process(delta.content) - if emit: - if not state.text_started: - state.pending_events.append( - StreamTextStart(id=state.text_block_id) + # Iterate under an inner try/finally so early exits (cancel, tool-call + # break, exception) always release the underlying httpx connection. + # Without this, openai.AsyncStream leaks the streaming response and + # the TCP socket ends up in CLOSE_WAIT until the process exits. + try: + async for chunk in response: + if chunk.usage: + state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0 + state.turn_completion_tokens += chunk.usage.completion_tokens or 0 + # Extract cache token details when available (OpenAI / + # OpenRouter include these in prompt_tokens_details). 
+ ptd = getattr(chunk.usage, "prompt_tokens_details", None) + if ptd: + state.turn_cache_read_tokens += ( + getattr(ptd, "cached_tokens", 0) or 0 + ) + # cache_creation_input_tokens is reported by some providers + # (e.g. Anthropic native) but not standard OpenAI streaming. + state.turn_cache_creation_tokens += ( + getattr(ptd, "cache_creation_input_tokens", 0) or 0 ) - state.text_started = True - round_text += emit - state.pending_events.append( - StreamTextDelta(id=state.text_block_id, delta=emit) - ) - if delta.tool_calls: - for tc in delta.tool_calls: - idx = tc.index - if idx not in tool_calls_by_index: - tool_calls_by_index[idx] = { - "id": "", - "name": "", - "arguments": "", - } - entry = tool_calls_by_index[idx] - if tc.id: - entry["id"] = tc.id - if tc.function and tc.function.name: - entry["name"] = tc.function.name - if tc.function and tc.function.arguments: - entry["arguments"] += tc.function.arguments + delta = chunk.choices[0].delta if chunk.choices else None + if not delta: + continue + + if delta.content: + emit = state.thinking_stripper.process(delta.content) + if emit: + if not state.text_started: + state.pending_events.append( + StreamTextStart(id=state.text_block_id) + ) + state.text_started = True + round_text += emit + state.pending_events.append( + StreamTextDelta(id=state.text_block_id, delta=emit) + ) + + if delta.tool_calls: + for tc in delta.tool_calls: + idx = tc.index + if idx not in tool_calls_by_index: + tool_calls_by_index[idx] = { + "id": "", + "name": "", + "arguments": "", + } + entry = tool_calls_by_index[idx] + if tc.id: + entry["id"] = tc.id + if tc.function and tc.function.name: + entry["name"] = tc.function.name + if tc.function and tc.function.arguments: + entry["arguments"] += tc.function.arguments + finally: + # Release the streaming httpx connection back to the pool on every + # exit path (normal completion, break, exception). openai.AsyncStream + # does not auto-close when the async-for loop exits early. 
+ try: + await response.close() + except Exception: + pass # Flush any buffered text held back by the thinking stripper. tail = state.thinking_stripper.flush() @@ -686,81 +703,147 @@ async def _compress_session_messages( return messages -def is_transcript_stale(dl: TranscriptDownload | None, session_msg_count: int) -> bool: - """Return ``True`` when a download doesn't cover the current session. - - A transcript is stale when it has a known ``message_count`` and that - count doesn't reach ``session_msg_count - 1`` (i.e. the session has - already advanced beyond what the stored transcript captures). - Loading a stale transcript would silently drop intermediate turns, - so callers should treat stale as "skip load, skip upload". - - An unknown ``message_count`` (``0``) is treated as **not stale** - because older transcripts uploaded before msg_count tracking - existed must still be usable. - """ - if dl is None: - return False - if not dl.message_count: - return False - return dl.message_count < session_msg_count - 1 - - -def should_upload_transcript( - user_id: str | None, transcript_covers_prefix: bool -) -> bool: +def should_upload_transcript(user_id: str | None, upload_safe: bool) -> bool: """Return ``True`` when the caller should upload the final transcript. - Uploads require a logged-in user (for the storage key) *and* a - transcript that covered the session prefix when loaded — otherwise - we'd be overwriting a more complete version in storage with a - partial one built from just the current turn. + Uploads require a logged-in user (for the storage key) *and* a safe + upload signal from ``_load_prior_transcript`` — i.e. GCS does not hold a + newer version that we'd be overwriting. """ - return bool(user_id) and transcript_covers_prefix + return bool(user_id) and upload_safe + + +def _append_gap_to_builder( + gap: list[ChatMessage], + builder: TranscriptBuilder, +) -> None: + """Append gap messages from chat-db into the TranscriptBuilder. 
+ + Converts ChatMessage (OpenAI format) to TranscriptBuilder entries + (Claude CLI JSONL format) so the uploaded transcript covers all turns. + + Pre-condition: ``gap`` always starts at a user or assistant boundary + (never mid-turn at a ``tool`` role), because ``detect_gap`` enforces + ``session_messages[wm-1].role == 'assistant'`` before returning a non-empty + gap. Any ``tool`` role messages within the gap always follow an assistant + entry that already exists in the builder or in the gap itself. + """ + for msg in gap: + if msg.role == "user": + builder.append_user(msg.content or "") + elif msg.role == "assistant": + content_blocks: list[dict] = [] + if msg.content: + content_blocks.append({"type": "text", "text": msg.content}) + if msg.tool_calls: + for tc in msg.tool_calls: + fn = tc.get("function", {}) if isinstance(tc, dict) else {} + input_data = util_json.loads(fn.get("arguments", "{}"), fallback={}) + content_blocks.append( + { + "type": "tool_use", + "id": tc.get("id", "") if isinstance(tc, dict) else "", + "name": fn.get("name", "unknown"), + "input": input_data, + } + ) + if not content_blocks: + # Fallback: ensure every assistant gap message produces an entry + # so the builder's entry count matches the gap length. + content_blocks.append({"type": "text", "text": ""}) + builder.append_assistant(content_blocks=content_blocks) + elif msg.role == "tool": + if msg.tool_call_id: + builder.append_tool_result( + tool_use_id=msg.tool_call_id, + content=msg.content or "", + ) + else: + # Malformed tool message — no tool_call_id to link to an + # assistant tool_use block. Skip to avoid an unmatched + # tool_result entry in the builder (which would confuse --resume). 
+ logger.warning( + "[Baseline] Skipping tool gap message with no tool_call_id" + ) async def _load_prior_transcript( user_id: str, session_id: str, - session_msg_count: int, + session_messages: list[ChatMessage], transcript_builder: TranscriptBuilder, -) -> bool: - """Download and load the prior transcript into ``transcript_builder``. +) -> tuple[bool, "TranscriptDownload | None"]: + """Download and load the prior CLI session into ``transcript_builder``. - Returns ``True`` when the loaded transcript fully covers the session - prefix; ``False`` otherwise (stale, missing, invalid, or download - error). Callers should suppress uploads when this returns ``False`` - to avoid overwriting a more complete version in storage. + Returns a tuple of (upload_safe, transcript_download): + - ``upload_safe`` is ``True`` when it is safe to upload at the end of this + turn. Upload is suppressed only for **download errors** (unknown GCS + state) — missing and invalid files return ``True`` because there is + nothing in GCS worth protecting against overwriting. + - ``transcript_download`` is a ``TranscriptDownload`` with str content + (pre-decoded and stripped) when available, or ``None`` when no valid + transcript could be loaded. Callers pass this to + ``extract_context_messages`` to build the LLM context. 
""" try: - dl = await download_transcript(user_id, session_id, log_prefix="[Baseline]") - except Exception as e: - logger.warning("[Baseline] Transcript download failed: %s", e) - return False - - if dl is None: - logger.debug("[Baseline] No transcript available") - return False - - if not validate_transcript(dl.content): - logger.warning("[Baseline] Downloaded transcript but invalid") - return False - - if is_transcript_stale(dl, session_msg_count): - logger.warning( - "[Baseline] Transcript stale: covers %d of %d messages, skipping", - dl.message_count, - session_msg_count, + restore = await download_transcript( + user_id, session_id, log_prefix="[Baseline]" ) - return False + except Exception as e: + logger.warning("[Baseline] Session restore failed: %s", e) + # Unknown GCS state — be conservative, skip upload. + return False, None - transcript_builder.load_previous(dl.content, log_prefix="[Baseline]") + if restore is None: + logger.debug("[Baseline] No CLI session available — will upload fresh") + # Nothing in GCS to protect; allow upload so the first baseline turn + # writes the initial transcript snapshot. + return True, None + + content_bytes = restore.content + try: + raw_str = ( + content_bytes.decode("utf-8") + if isinstance(content_bytes, bytes) + else content_bytes + ) + except UnicodeDecodeError: + logger.warning("[Baseline] CLI session content is not valid UTF-8") + # Corrupt file in GCS; overwriting with a valid one is better. + return True, None + + stripped = strip_for_upload(raw_str) + if not validate_transcript(stripped): + logger.warning("[Baseline] CLI session content invalid after strip") + # Corrupt file in GCS; overwriting with a valid one is better. 
+ return True, None + + transcript_builder.load_previous(stripped, log_prefix="[Baseline]") logger.info( - "[Baseline] Loaded transcript: %dB, msg_count=%d", - len(dl.content), - dl.message_count, + "[Baseline] Loaded CLI session: %dB, msg_count=%d", + len(content_bytes) if isinstance(content_bytes, bytes) else len(raw_str), + restore.message_count, ) - return True + + gap = detect_gap(restore, session_messages) + if gap: + _append_gap_to_builder(gap, transcript_builder) + logger.info( + "[Baseline] Filled gap: loaded %d transcript msgs + %d gap msgs from DB", + restore.message_count, + len(gap), + ) + + # Return a str-content version so extract_context_messages receives a + # pre-decoded, stripped transcript (avoids redundant decode + strip). + # TranscriptDownload.content is typed as bytes | str; we pass str here + # to avoid a redundant encode + decode round-trip. + str_restore = TranscriptDownload( + content=stripped, + message_count=restore.message_count, + mode=restore.mode, + ) + return True, str_restore async def _upload_final_transcript( @@ -794,10 +877,10 @@ async def _upload_final_transcript( upload_transcript( user_id=user_id, session_id=session_id, - content=content, + content=content.encode("utf-8"), message_count=session_msg_count, + mode="baseline", log_prefix="[Baseline]", - skip_strip=True, ) ) _background_tasks.add(upload_task) @@ -884,7 +967,7 @@ async def stream_chat_completion_baseline( # --- Transcript support (feature parity with SDK path) --- transcript_builder = TranscriptBuilder() - transcript_covers_prefix = True + transcript_upload_safe = True # Build system prompt only on the first turn to avoid mid-conversation # changes from concurrent chats updating business understanding. @@ -901,15 +984,16 @@ async def stream_chat_completion_baseline( # Run download + prompt build concurrently — both are independent I/O # on the request critical path. 
+ transcript_download: TranscriptDownload | None = None if user_id and len(session.messages) > 1: ( - transcript_covers_prefix, + (transcript_upload_safe, transcript_download), (base_system_prompt, understanding), ) = await asyncio.gather( _load_prior_transcript( user_id=user_id, session_id=session_id, - session_msg_count=len(session.messages), + session_messages=session.messages, transcript_builder=transcript_builder, ), prompt_task, @@ -940,17 +1024,23 @@ async def stream_chat_completion_baseline( graphiti_supplement = get_graphiti_supplement() if graphiti_enabled else "" system_prompt = base_system_prompt + get_baseline_supplement() + graphiti_supplement - # Warm context: pre-load relevant facts from Graphiti on first turn + # Warm context: pre-load relevant facts from Graphiti on first turn. + # Stored here but injected into the user message (not the system prompt) + # after openai_messages is built — keeps system prompt static for caching. + warm_ctx: str | None = None if graphiti_enabled and user_id and len(session.messages) <= 1: from backend.copilot.graphiti.context import fetch_warm_context warm_ctx = await fetch_warm_context(user_id, message or "") - if warm_ctx: - system_prompt += f"\n\n{warm_ctx}" - # Compress context if approaching the model's token limit + # Context path: transcript content (compacted, isCompactSummary preserved) + + # gap (DB messages after watermark) + current user turn. + # This avoids re-reading the full session history from DB on every turn. + # See extract_context_messages() in transcript.py for the shared primitive. + prior_context = extract_context_messages(transcript_download, session.messages) messages_for_context = await _compress_session_messages( - session.messages, model=active_model + prior_context + ([session.messages[-1]] if session.messages else []), + model=active_model, ) # Build OpenAI message list from session history. 
@@ -996,6 +1086,20 @@ async def stream_chat_completion_baseline( else: logger.warning("[Baseline] No user message found for context injection") + # Inject Graphiti warm context into the first user message (not the + # system prompt) so the system prompt stays static and cacheable. + # warm_ctx is already wrapped in . + # Appended AFTER user_context so stays at the very start. + if warm_ctx: + for msg in openai_messages: + if msg["role"] == "user": + existing = msg.get("content", "") + if isinstance(existing, str): + msg["content"] = f"{existing}\n\n{warm_ctx}" + break + # Do NOT append warm_ctx to user_message_for_transcript — it would + # persist stale temporal context into the transcript for future turns. + # Append user message to transcript. # Always append when the message is present and is from the user, # even on duplicate-suppressed retries (is_new_message=False). @@ -1253,8 +1357,16 @@ async def stream_chat_completion_baseline( if graphiti_enabled and user_id and message and is_user_message: from backend.copilot.graphiti.ingest import enqueue_conversation_turn + # Pass only the final assistant reply (after stripping tool-loop + # chatter) so derived-finding distillation sees the substantive + # response, not intermediate tool-planning text. 
_ingest_task = asyncio.create_task( - enqueue_conversation_turn(user_id, session_id, message) + enqueue_conversation_turn( + user_id, + session_id, + message, + assistant_msg=final_text if state else "", + ) ) _background_tasks.add(_ingest_task) _ingest_task.add_done_callback(_background_tasks.discard) @@ -1272,7 +1384,7 @@ async def stream_chat_completion_baseline( stop_reason=STOP_REASON_END_TURN, ) - if user_id and should_upload_transcript(user_id, transcript_covers_prefix): + if user_id and should_upload_transcript(user_id, transcript_upload_safe): await _upload_final_transcript( user_id=user_id, session_id=session_id, diff --git a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py index 624abb9acd..4247c76c19 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py @@ -1,7 +1,7 @@ """Integration tests for baseline transcript flow. -Exercises the real helpers in ``baseline/service.py`` that download, -validate, load, append to, backfill, and upload the transcript. +Exercises the real helpers in ``baseline/service.py`` that restore, +validate, load, append to, backfill, and upload the CLI session. Storage is mocked via ``download_transcript`` / ``upload_transcript`` patches; no network access is required. 
""" @@ -12,13 +12,14 @@ from unittest.mock import AsyncMock, patch import pytest from backend.copilot.baseline.service import ( + _append_gap_to_builder, _load_prior_transcript, _record_turn_to_transcript, _resolve_baseline_model, _upload_final_transcript, - is_transcript_stale, should_upload_transcript, ) +from backend.copilot.model import ChatMessage from backend.copilot.service import config from backend.copilot.transcript import ( STOP_REASON_END_TURN, @@ -54,6 +55,13 @@ def _make_transcript_content(*roles: str) -> str: return "\n".join(lines) + "\n" +def _make_session_messages(*roles: str) -> list[ChatMessage]: + """Build a list of ChatMessage objects matching the given roles.""" + return [ + ChatMessage(role=r, content=f"{r} message {i}") for i, r in enumerate(roles) + ] + + class TestResolveBaselineModel: """Model selection honours the per-request mode.""" @@ -68,92 +76,107 @@ class TestResolveBaselineModel: assert _resolve_baseline_model(None) == config.model def test_default_and_fast_models_same(self): - """SDK 0.1.58: both tiers now use the same model (anthropic/claude-sonnet-4).""" + """SDK defaults currently keep standard and fast on Sonnet 4.6.""" assert config.model == config.fast_model class TestLoadPriorTranscript: - """``_load_prior_transcript`` wraps the download + validate + load flow.""" + """``_load_prior_transcript`` wraps the CLI session restore + validate + load flow.""" @pytest.mark.asyncio async def test_loads_fresh_transcript(self): builder = TranscriptBuilder() content = _make_transcript_content("user", "assistant") - download = TranscriptDownload(content=content, message_count=2) + restore = TranscriptDownload( + content=content.encode("utf-8"), message_count=2, mode="sdk" + ) with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", 
session_id="session-1", - session_msg_count=3, + session_messages=_make_session_messages("user", "assistant", "user"), transcript_builder=builder, ) assert covers is True + assert dl is not None + assert dl.message_count == 2 assert builder.entry_count == 2 assert builder.last_entry_type == "assistant" @pytest.mark.asyncio - async def test_rejects_stale_transcript(self): - """msg_count strictly less than session-1 is treated as stale.""" + async def test_fills_gap_when_transcript_is_behind(self): + """When transcript covers fewer messages than session, gap is filled from DB.""" builder = TranscriptBuilder() content = _make_transcript_content("user", "assistant") - # session has 6 messages, transcript only covers 2 → stale. - download = TranscriptDownload(content=content, message_count=2) + # transcript covers 2 messages, session has 4 (plus current user turn = 5) + restore = TranscriptDownload( + content=content.encode("utf-8"), message_count=2, mode="baseline" + ) with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=6, + session_messages=_make_session_messages( + "user", "assistant", "user", "assistant", "user" + ), transcript_builder=builder, ) - assert covers is False - assert builder.is_empty + assert covers is True + assert dl is not None + # 2 from transcript + 2 gap messages (user+assistant at positions 2,3) + assert builder.entry_count == 4 @pytest.mark.asyncio - async def test_missing_transcript_returns_false(self): + async def test_missing_transcript_allows_upload(self): + """Nothing in GCS → upload is safe; the turn writes the first snapshot.""" builder = TranscriptBuilder() with patch( "backend.copilot.baseline.service.download_transcript", new=AsyncMock(return_value=None), ): - covers = await 
_load_prior_transcript( + upload_safe, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=2, + session_messages=_make_session_messages("user", "assistant"), transcript_builder=builder, ) - assert covers is False + assert upload_safe is True + assert dl is None assert builder.is_empty @pytest.mark.asyncio - async def test_invalid_transcript_returns_false(self): + async def test_invalid_transcript_allows_upload(self): + """Corrupt file in GCS → overwriting with a valid one is better.""" builder = TranscriptBuilder() - download = TranscriptDownload( - content='{"type":"progress","uuid":"a"}\n', + restore = TranscriptDownload( + content=b'{"type":"progress","uuid":"a"}\n', message_count=1, + mode="sdk", ) with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + upload_safe, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=2, + session_messages=_make_session_messages("user", "assistant"), transcript_builder=builder, ) - assert covers is False + assert upload_safe is True + assert dl is None assert builder.is_empty @pytest.mark.asyncio @@ -163,36 +186,39 @@ class TestLoadPriorTranscript: "backend.copilot.baseline.service.download_transcript", new=AsyncMock(side_effect=RuntimeError("boom")), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=2, + session_messages=_make_session_messages("user", "assistant"), transcript_builder=builder, ) assert covers is False + assert dl is None assert builder.is_empty @pytest.mark.asyncio async def test_zero_message_count_not_stale(self): - """When msg_count is 0 (unknown), staleness check is skipped.""" + """When msg_count is 0 (unknown), gap detection is skipped.""" builder = TranscriptBuilder() - download = 
TranscriptDownload( - content=_make_transcript_content("user", "assistant"), + restore = TranscriptDownload( + content=_make_transcript_content("user", "assistant").encode("utf-8"), message_count=0, + mode="sdk", ) with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=20, + session_messages=_make_session_messages(*["user"] * 20), transcript_builder=builder, ) assert covers is True + assert dl is not None assert builder.entry_count == 2 @@ -227,7 +253,7 @@ class TestUploadFinalTranscript: assert call_kwargs["user_id"] == "user-1" assert call_kwargs["session_id"] == "session-1" assert call_kwargs["message_count"] == 2 - assert "hello" in call_kwargs["content"] + assert b"hello" in call_kwargs["content"] @pytest.mark.asyncio async def test_skips_upload_when_builder_empty(self): @@ -374,17 +400,19 @@ class TestRoundTrip: @pytest.mark.asyncio async def test_full_round_trip(self): prior = _make_transcript_content("user", "assistant") - download = TranscriptDownload(content=prior, message_count=2) + restore = TranscriptDownload( + content=prior.encode("utf-8"), message_count=2, mode="sdk" + ) builder = TranscriptBuilder() with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, _ = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=3, + session_messages=_make_session_messages("user", "assistant", "user"), transcript_builder=builder, ) assert covers is True @@ -424,11 +452,11 @@ class TestRoundTrip: upload_mock.assert_awaited_once() assert upload_mock.await_args is not None uploaded = upload_mock.await_args.kwargs["content"] - assert "new question" in 
uploaded - assert "new answer" in uploaded + assert b"new question" in uploaded + assert b"new answer" in uploaded # Original content preserved in the round trip. - assert "user message 0" in uploaded - assert "assistant message 1" in uploaded + assert b"user message 0" in uploaded + assert b"assistant message 1" in uploaded @pytest.mark.asyncio async def test_backfill_append_guard(self): @@ -459,36 +487,6 @@ class TestRoundTrip: assert builder.entry_count == initial_count -class TestIsTranscriptStale: - """``is_transcript_stale`` gates prior-transcript loading.""" - - def test_none_download_is_not_stale(self): - assert is_transcript_stale(None, session_msg_count=5) is False - - def test_zero_message_count_is_not_stale(self): - """Legacy transcripts without msg_count tracking must remain usable.""" - dl = TranscriptDownload(content="", message_count=0) - assert is_transcript_stale(dl, session_msg_count=20) is False - - def test_stale_when_covers_less_than_prefix(self): - dl = TranscriptDownload(content="", message_count=2) - # session has 6 messages; transcript must cover at least 5 (6-1). - assert is_transcript_stale(dl, session_msg_count=6) is True - - def test_fresh_when_covers_full_prefix(self): - dl = TranscriptDownload(content="", message_count=5) - assert is_transcript_stale(dl, session_msg_count=6) is False - - def test_fresh_when_exceeds_prefix(self): - """Race: transcript ahead of session count is still acceptable.""" - dl = TranscriptDownload(content="", message_count=10) - assert is_transcript_stale(dl, session_msg_count=6) is False - - def test_boundary_equal_to_prefix_minus_one(self): - dl = TranscriptDownload(content="", message_count=5) - assert is_transcript_stale(dl, session_msg_count=6) is False - - class TestShouldUploadTranscript: """``should_upload_transcript`` gates the final upload.""" @@ -510,7 +508,7 @@ class TestShouldUploadTranscript: class TestTranscriptLifecycle: - """End-to-end: download → validate → build → upload. 
+ """End-to-end: restore → validate → build → upload. Simulates the full transcript lifecycle inside ``stream_chat_completion_baseline`` by mocking the storage layer and @@ -519,27 +517,29 @@ class TestTranscriptLifecycle: @pytest.mark.asyncio async def test_full_lifecycle_happy_path(self): - """Fresh download, append a turn, upload covers the session.""" + """Fresh restore, append a turn, upload covers the session.""" builder = TranscriptBuilder() prior = _make_transcript_content("user", "assistant") - download = TranscriptDownload(content=prior, message_count=2) + restore = TranscriptDownload( + content=prior.encode("utf-8"), message_count=2, mode="sdk" + ) upload_mock = AsyncMock(return_value=None) with ( patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ), patch( "backend.copilot.baseline.service.upload_transcript", new=upload_mock, ), ): - # --- 1. Download & load prior transcript --- - covers = await _load_prior_transcript( + # --- 1. Restore & load prior session --- + covers, _ = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=3, + session_messages=_make_session_messages("user", "assistant", "user"), transcript_builder=builder, ) assert covers is True @@ -559,10 +559,7 @@ class TestTranscriptLifecycle: # --- 3. 
Gate + upload --- assert ( - should_upload_transcript( - user_id="user-1", transcript_covers_prefix=covers - ) - is True + should_upload_transcript(user_id="user-1", upload_safe=covers) is True ) await _upload_final_transcript( user_id="user-1", @@ -574,20 +571,21 @@ class TestTranscriptLifecycle: upload_mock.assert_awaited_once() assert upload_mock.await_args is not None uploaded = upload_mock.await_args.kwargs["content"] - assert "follow-up question" in uploaded - assert "follow-up answer" in uploaded + assert b"follow-up question" in uploaded + assert b"follow-up answer" in uploaded # Original prior-turn content preserved. - assert "user message 0" in uploaded - assert "assistant message 1" in uploaded + assert b"user message 0" in uploaded + assert b"assistant message 1" in uploaded @pytest.mark.asyncio - async def test_lifecycle_stale_download_suppresses_upload(self): - """Stale download → covers=False → upload must be skipped.""" + async def test_lifecycle_stale_download_fills_gap(self): + """When transcript covers fewer messages, gap is filled rather than rejected.""" builder = TranscriptBuilder() - # session has 10 msgs but stored transcript only covers 2 → stale. + # session has 5 msgs but stored transcript only covers 2 → gap filled. stale = TranscriptDownload( - content=_make_transcript_content("user", "assistant"), + content=_make_transcript_content("user", "assistant").encode("utf-8"), message_count=2, + mode="baseline", ) upload_mock = AsyncMock(return_value=None) @@ -601,20 +599,18 @@ class TestTranscriptLifecycle: new=upload_mock, ), ): - covers = await _load_prior_transcript( + covers, _ = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=10, + session_messages=_make_session_messages( + "user", "assistant", "user", "assistant", "user" + ), transcript_builder=builder, ) - assert covers is False - # The caller's gate mirrors the production path. 
- assert ( - should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers) - is False - ) - upload_mock.assert_not_awaited() + assert covers is True + # Gap was filled: 2 from transcript + 2 gap messages + assert builder.entry_count == 4 @pytest.mark.asyncio async def test_lifecycle_anonymous_user_skips_upload(self): @@ -627,15 +623,11 @@ class TestTranscriptLifecycle: stop_reason=STOP_REASON_END_TURN, ) - assert ( - should_upload_transcript(user_id=None, transcript_covers_prefix=True) - is False - ) + assert should_upload_transcript(user_id=None, upload_safe=True) is False @pytest.mark.asyncio async def test_lifecycle_missing_download_still_uploads_new_content(self): - """No prior transcript → covers defaults to True in the service, - new turn should upload cleanly.""" + """No prior session → upload is safe; the turn writes the first snapshot.""" builder = TranscriptBuilder() upload_mock = AsyncMock(return_value=None) with ( @@ -648,20 +640,117 @@ class TestTranscriptLifecycle: new=upload_mock, ), ): - covers = await _load_prior_transcript( + upload_safe, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=1, + session_messages=_make_session_messages("user"), transcript_builder=builder, ) - # No download: covers is False, so the production path would - # skip upload. This protects against overwriting a future - # more-complete transcript with a single-turn snapshot. - assert covers is False + # Nothing in GCS → upload is safe so the first baseline turn + # can write the initial transcript snapshot. 
+ assert upload_safe is True + assert dl is None assert ( - should_upload_transcript( - user_id="user-1", transcript_covers_prefix=covers - ) - is False + should_upload_transcript(user_id="user-1", upload_safe=upload_safe) + is True ) - upload_mock.assert_not_awaited() + + +# --------------------------------------------------------------------------- +# _append_gap_to_builder +# --------------------------------------------------------------------------- + + +class TestAppendGapToBuilder: + """``_append_gap_to_builder`` converts ChatMessage objects to TranscriptBuilder entries.""" + + def test_user_message_appended(self): + builder = TranscriptBuilder() + msgs = [ChatMessage(role="user", content="hello")] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + assert builder.last_entry_type == "user" + + def test_assistant_text_message_appended(self): + builder = TranscriptBuilder() + msgs = [ + ChatMessage(role="user", content="q"), + ChatMessage(role="assistant", content="answer"), + ] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 2 + assert builder.last_entry_type == "assistant" + assert "answer" in builder.to_jsonl() + + def test_assistant_with_tool_calls_appended(self): + """Assistant tool_calls are recorded as tool_use blocks in the transcript.""" + builder = TranscriptBuilder() + tool_call = { + "id": "tc-1", + "type": "function", + "function": {"name": "my_tool", "arguments": '{"key":"val"}'}, + } + msgs = [ChatMessage(role="assistant", content=None, tool_calls=[tool_call])] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + jsonl = builder.to_jsonl() + assert "tool_use" in jsonl + assert "my_tool" in jsonl + assert "tc-1" in jsonl + + def test_assistant_invalid_json_args_uses_empty_dict(self): + """Malformed JSON in tool_call arguments falls back to {}.""" + builder = TranscriptBuilder() + tool_call = { + "id": "tc-bad", + "type": "function", + "function": {"name": "bad_tool", 
"arguments": "not-json"}, + } + msgs = [ChatMessage(role="assistant", content=None, tool_calls=[tool_call])] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + jsonl = builder.to_jsonl() + assert '"input":{}' in jsonl + + def test_assistant_empty_content_and_no_tools_uses_fallback(self): + """Assistant with no content and no tool_calls gets a fallback empty text block.""" + builder = TranscriptBuilder() + msgs = [ChatMessage(role="assistant", content=None)] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + jsonl = builder.to_jsonl() + assert "text" in jsonl + + def test_tool_role_with_tool_call_id_appended(self): + """Tool result messages are appended when tool_call_id is set.""" + builder = TranscriptBuilder() + # Need a preceding assistant tool_use entry + builder.append_user("use tool") + builder.append_assistant( + content_blocks=[ + {"type": "tool_use", "id": "tc-1", "name": "my_tool", "input": {}} + ] + ) + msgs = [ChatMessage(role="tool", tool_call_id="tc-1", content="result")] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 3 + assert "tool_result" in builder.to_jsonl() + + def test_tool_role_without_tool_call_id_skipped(self): + """Tool messages without tool_call_id are silently skipped.""" + builder = TranscriptBuilder() + msgs = [ChatMessage(role="tool", tool_call_id=None, content="orphan")] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 0 + + def test_tool_call_missing_function_key_uses_unknown_name(self): + """A tool_call dict with no 'function' key uses 'unknown' as the tool name.""" + builder = TranscriptBuilder() + # Tool call dict exists but 'function' sub-dict is missing entirely + msgs = [ + ChatMessage(role="assistant", content=None, tool_calls=[{"id": "tc-x"}]) + ] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + jsonl = builder.to_jsonl() + assert "unknown" in jsonl diff --git 
a/autogpt_platform/backend/backend/copilot/config.py b/autogpt_platform/backend/backend/copilot/config.py index d5418bf872..36644de680 100644 --- a/autogpt_platform/backend/backend/copilot/config.py +++ b/autogpt_platform/backend/backend/copilot/config.py @@ -29,13 +29,13 @@ class ChatConfig(BaseSettings): # OpenAI API Configuration model: str = Field( - default="anthropic/claude-sonnet-4", + default="anthropic/claude-sonnet-4-6", description="Default model for extended thinking mode. " - "Changed from Opus ($15/$75 per M) to Sonnet ($3/$15 per M) — " - "5x cheaper. Override via CHAT_MODEL env var for Opus.", + "Uses Sonnet 4.6 as the balanced default. " + "Override via CHAT_MODEL env var if you want a different default.", ) fast_model: str = Field( - default="anthropic/claude-sonnet-4", + default="anthropic/claude-sonnet-4-6", description="Model for fast mode (baseline path). Should be faster/cheaper than the default model.", ) title_model: str = Field( @@ -156,9 +156,10 @@ class ChatConfig(BaseSettings): "history compression. Falls back to compression when unavailable.", ) claude_agent_fallback_model: str = Field( - default="claude-sonnet-4-20250514", + default="", description="Fallback model when the primary model is unavailable (e.g. 529 " - "overloaded). The SDK automatically retries with this cheaper model.", + "overloaded). The SDK automatically retries with this cheaper model. " + "Empty string disables the fallback (no --fallback-model flag passed to CLI).", ) claude_agent_max_turns: int = Field( default=50, diff --git a/autogpt_platform/backend/backend/copilot/context.py b/autogpt_platform/backend/backend/copilot/context.py index 895aa6c4a1..7a22f02cb2 100644 --- a/autogpt_platform/backend/backend/copilot/context.py +++ b/autogpt_platform/backend/backend/copilot/context.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: # Allowed base directory for the Read tool. 
Public so service.py can use it # for sweep operations without depending on a private implementation detail. # Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's -# _projects_base() function. +# projects_base() function. _config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude") SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects")) diff --git a/autogpt_platform/backend/backend/copilot/graphiti/_format.py b/autogpt_platform/backend/backend/copilot/graphiti/_format.py index fb4a93e393..c6975c5c39 100644 --- a/autogpt_platform/backend/backend/copilot/graphiti/_format.py +++ b/autogpt_platform/backend/backend/copilot/graphiti/_format.py @@ -18,15 +18,24 @@ def extract_temporal_validity(edge) -> tuple[str, str]: return str(valid_from), str(valid_to) -def extract_episode_body(episode, max_len: int = 500) -> str: - """Extract the body text from an episode object, truncated to *max_len*.""" - body = str( +def extract_episode_body_raw(episode) -> str: + """Extract the full body text from an episode object (no truncation). + + Use this when the body needs to be parsed as JSON (e.g. scope filtering + on MemoryEnvelope payloads). For display purposes, use + ``extract_episode_body()`` which truncates. 
+ """ + return str( getattr(episode, "content", None) or getattr(episode, "body", None) or getattr(episode, "episode_body", None) or "" ) - return body[:max_len] + + +def extract_episode_body(episode, max_len: int = 500) -> str: + """Extract the body text from an episode object, truncated to *max_len*.""" + return extract_episode_body_raw(episode)[:max_len] def extract_episode_timestamp(episode) -> str: diff --git a/autogpt_platform/backend/backend/copilot/graphiti/client.py b/autogpt_platform/backend/backend/copilot/graphiti/client.py index 9710354915..65fcdb3abb 100644 --- a/autogpt_platform/backend/backend/copilot/graphiti/client.py +++ b/autogpt_platform/backend/backend/copilot/graphiti/client.py @@ -3,6 +3,7 @@ import asyncio import logging import re +import weakref from cachetools import TTLCache @@ -13,8 +14,36 @@ logger = logging.getLogger(__name__) _GROUP_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$") _MAX_GROUP_ID_LEN = 128 -_client_cache: TTLCache | None = None -_cache_lock = asyncio.Lock() + +# Graphiti clients wrap redis.asyncio connections whose internal Futures are +# pinned to the event loop they were first used on. The CoPilot executor runs +# one asyncio loop per worker thread, so a process-wide client cache would +# hand a loop-1-bound connection to a task running on loop 2 → RuntimeError +# "got Future attached to a different loop". Scope the cache (and its lock) +# per running loop so each loop gets its own clients. 
+class _LoopState: + __slots__ = ("cache", "lock") + + def __init__(self) -> None: + self.cache: TTLCache = _EvictingTTLCache( + maxsize=graphiti_config.client_cache_maxsize, + ttl=graphiti_config.client_cache_ttl, + ) + self.lock = asyncio.Lock() + + +_loop_state: "weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, _LoopState]" = ( + weakref.WeakKeyDictionary() +) + + +def _get_loop_state() -> _LoopState: + loop = asyncio.get_running_loop() + state = _loop_state.get(loop) + if state is None: + state = _LoopState() + _loop_state[loop] = state + return state def derive_group_id(user_id: str) -> str: @@ -88,13 +117,8 @@ class _EvictingTTLCache(TTLCache): def _get_cache() -> TTLCache: - global _client_cache - if _client_cache is None: - _client_cache = _EvictingTTLCache( - maxsize=graphiti_config.client_cache_maxsize, - ttl=graphiti_config.client_cache_ttl, - ) - return _client_cache + """Return the client cache for the current running event loop.""" + return _get_loop_state().cache async def get_graphiti_client(group_id: str): @@ -113,9 +137,10 @@ async def get_graphiti_client(group_id: str): from .falkordb_driver import AutoGPTFalkorDriver - cache = _get_cache() + state = _get_loop_state() + cache = state.cache - async with _cache_lock: + async with state.lock: if group_id in cache: return cache[group_id] diff --git a/autogpt_platform/backend/backend/copilot/graphiti/config.py b/autogpt_platform/backend/backend/copilot/graphiti/config.py index 94a452165a..08b533b6fc 100644 --- a/autogpt_platform/backend/backend/copilot/graphiti/config.py +++ b/autogpt_platform/backend/backend/copilot/graphiti/config.py @@ -20,8 +20,10 @@ class GraphitiConfig(BaseSettings): """Configuration for Graphiti memory integration. All fields use the ``GRAPHITI_`` env-var prefix, e.g. ``GRAPHITI_ENABLED``. - LLM/embedder keys fall back to the platform-wide OpenRouter and OpenAI keys - when left empty so that operators don't need to manage separate credentials. 
+ LLM/embedder keys fall back to the AutoPilot-dedicated keys + (``CHAT_API_KEY`` / ``CHAT_OPENAI_API_KEY``) so that memory costs are + tracked under AutoPilot, then to the platform-wide OpenRouter / OpenAI + keys as a last resort. """ model_config = SettingsConfigDict(env_prefix="GRAPHITI_", extra="allow") @@ -42,7 +44,7 @@ class GraphitiConfig(BaseSettings): ) llm_api_key: str = Field( default="", - description="API key for LLM — empty falls back to OPEN_ROUTER_API_KEY", + description="API key for LLM — empty falls back to CHAT_API_KEY, then OPEN_ROUTER_API_KEY", ) # Embedder (separate from LLM — embeddings go direct to OpenAI) @@ -53,7 +55,7 @@ class GraphitiConfig(BaseSettings): ) embedder_api_key: str = Field( default="", - description="API key for embedder — empty falls back to OPENAI_API_KEY", + description="API key for embedder — empty falls back to CHAT_OPENAI_API_KEY, then OPENAI_API_KEY", ) # Concurrency @@ -96,7 +98,9 @@ class GraphitiConfig(BaseSettings): def resolve_llm_api_key(self) -> str: if self.llm_api_key: return self.llm_api_key - return os.getenv("OPEN_ROUTER_API_KEY", "") + # Prefer the AutoPilot-dedicated key so memory costs are tracked + # separately from the platform-wide OpenRouter key. + return os.getenv("CHAT_API_KEY") or os.getenv("OPEN_ROUTER_API_KEY", "") def resolve_llm_base_url(self) -> str: if self.llm_base_url: @@ -106,7 +110,9 @@ class GraphitiConfig(BaseSettings): def resolve_embedder_api_key(self) -> str: if self.embedder_api_key: return self.embedder_api_key - return os.getenv("OPENAI_API_KEY", "") + # Prefer the AutoPilot-dedicated OpenAI key so memory costs are + # tracked separately from the platform-wide OpenAI key. 
+ return os.getenv("CHAT_OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY", "") def resolve_embedder_base_url(self) -> str | None: if self.embedder_base_url: diff --git a/autogpt_platform/backend/backend/copilot/graphiti/config_test.py b/autogpt_platform/backend/backend/copilot/graphiti/config_test.py index 7c7a90d7bc..efe36c8586 100644 --- a/autogpt_platform/backend/backend/copilot/graphiti/config_test.py +++ b/autogpt_platform/backend/backend/copilot/graphiti/config_test.py @@ -8,6 +8,8 @@ _ENV_VARS_TO_CLEAR = ( "GRAPHITI_FALKORDB_HOST", "GRAPHITI_FALKORDB_PORT", "GRAPHITI_FALKORDB_PASSWORD", + "CHAT_API_KEY", + "CHAT_OPENAI_API_KEY", "OPEN_ROUTER_API_KEY", "OPENAI_API_KEY", ) @@ -31,7 +33,15 @@ class TestResolveLlmApiKey: cfg = GraphitiConfig(llm_api_key="my-llm-key") assert cfg.resolve_llm_api_key() == "my-llm-key" - def test_falls_back_to_open_router_env( + def test_falls_back_to_chat_api_key_first( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("CHAT_API_KEY", "autopilot-key") + monkeypatch.setenv("OPEN_ROUTER_API_KEY", "platform-key") + cfg = GraphitiConfig(llm_api_key="") + assert cfg.resolve_llm_api_key() == "autopilot-key" + + def test_falls_back_to_open_router_when_no_chat_key( self, monkeypatch: pytest.MonkeyPatch ) -> None: monkeypatch.setenv("OPEN_ROUTER_API_KEY", "fallback-router-key") @@ -59,7 +69,15 @@ class TestResolveEmbedderApiKey: cfg = GraphitiConfig(embedder_api_key="my-embedder-key") assert cfg.resolve_embedder_api_key() == "my-embedder-key" - def test_falls_back_to_openai_api_key_env( + def test_falls_back_to_chat_openai_api_key_first( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("CHAT_OPENAI_API_KEY", "autopilot-openai-key") + monkeypatch.setenv("OPENAI_API_KEY", "platform-openai-key") + cfg = GraphitiConfig(embedder_api_key="") + assert cfg.resolve_embedder_api_key() == "autopilot-openai-key" + + def test_falls_back_to_openai_when_no_chat_openai_key( self, monkeypatch: 
pytest.MonkeyPatch ) -> None: monkeypatch.setenv("OPENAI_API_KEY", "fallback-openai-key") diff --git a/autogpt_platform/backend/backend/copilot/graphiti/context.py b/autogpt_platform/backend/backend/copilot/graphiti/context.py index 46f9855ab7..29d4e95f47 100644 --- a/autogpt_platform/backend/backend/copilot/graphiti/context.py +++ b/autogpt_platform/backend/backend/copilot/graphiti/context.py @@ -6,6 +6,7 @@ from datetime import datetime, timezone from ._format import ( extract_episode_body, + extract_episode_body_raw, extract_episode_timestamp, extract_fact, extract_temporal_validity, @@ -68,7 +69,7 @@ async def _fetch(user_id: str, message: str) -> str | None: return _format_context(edges, episodes) -def _format_context(edges, episodes) -> str: +def _format_context(edges, episodes) -> str | None: sections: list[str] = [] if edges: @@ -82,12 +83,35 @@ def _format_context(edges, episodes) -> str: if episodes: ep_lines = [] for ep in episodes: + # Use raw body (no truncation) for scope parsing — truncated + # JSON from extract_episode_body() would fail json.loads(). 
+ raw_body = extract_episode_body_raw(ep) + if _is_non_global_scope(raw_body): + continue + display_body = extract_episode_body(ep) ts = extract_episode_timestamp(ep) - body = extract_episode_body(ep) - ep_lines.append(f" - [{ts}] {body}") - sections.append( - "\n" + "\n".join(ep_lines) + "\n" - ) + ep_lines.append(f" - [{ts}] {display_body}") + if ep_lines: + sections.append( + "\n" + "\n".join(ep_lines) + "\n" + ) + + if not sections: + return None body = "\n\n".join(sections) return f"\n{body}\n" + + +def _is_non_global_scope(body: str) -> bool: + """Check if an episode body is a MemoryEnvelope with a non-global scope.""" + import json + + try: + data = json.loads(body) + if not isinstance(data, dict): + return False + scope = data.get("scope", "real:global") + return scope != "real:global" + except (json.JSONDecodeError, TypeError): + return False diff --git a/autogpt_platform/backend/backend/copilot/graphiti/context_test.py b/autogpt_platform/backend/backend/copilot/graphiti/context_test.py index 616fefa218..ce419b11ff 100644 --- a/autogpt_platform/backend/backend/copilot/graphiti/context_test.py +++ b/autogpt_platform/backend/backend/copilot/graphiti/context_test.py @@ -1,12 +1,15 @@ """Tests for Graphiti warm context retrieval.""" import asyncio +from types import SimpleNamespace from unittest.mock import AsyncMock, patch import pytest from . 
import context -from .context import fetch_warm_context +from ._format import extract_episode_body +from .context import _format_context, _is_non_global_scope, fetch_warm_context +from .memory_model import MemoryEnvelope, MemoryKind, SourceKind class TestFetchWarmContextEmptyUserId: @@ -52,3 +55,212 @@ class TestFetchWarmContextGeneralError: result = await fetch_warm_context("abc", "hello") assert result is None + + +# --------------------------------------------------------------------------- +# Bug: extract_episode_body() truncation breaks scope filtering +# --------------------------------------------------------------------------- + + +class TestFetchInternal: + """Test the internal _fetch function with mocked graphiti client.""" + + @pytest.mark.asyncio + async def test_returns_none_when_no_edges_or_episodes(self) -> None: + mock_client = AsyncMock() + mock_client.search.return_value = [] + mock_client.retrieve_episodes.return_value = [] + + with ( + patch.object(context, "derive_group_id", return_value="user_abc"), + patch.object( + context, + "get_graphiti_client", + new_callable=AsyncMock, + return_value=mock_client, + ), + ): + result = await context._fetch("test-user", "hello") + + assert result is None + + @pytest.mark.asyncio + async def test_returns_context_with_edges(self) -> None: + edge = SimpleNamespace( + fact="user likes python", + name="preference", + valid_at="2025-01-01", + invalid_at=None, + ) + mock_client = AsyncMock() + mock_client.search.return_value = [edge] + mock_client.retrieve_episodes.return_value = [] + + with ( + patch.object(context, "derive_group_id", return_value="user_abc"), + patch.object( + context, + "get_graphiti_client", + new_callable=AsyncMock, + return_value=mock_client, + ), + ): + result = await context._fetch("test-user", "hello") + + assert result is not None + assert "" in result + assert "user likes python" in result + + @pytest.mark.asyncio + async def test_returns_context_with_episodes(self) -> None: + ep = 
SimpleNamespace( + content="talked about coffee", + created_at="2025-06-01T00:00:00Z", + ) + mock_client = AsyncMock() + mock_client.search.return_value = [] + mock_client.retrieve_episodes.return_value = [ep] + + with ( + patch.object(context, "derive_group_id", return_value="user_abc"), + patch.object( + context, + "get_graphiti_client", + new_callable=AsyncMock, + return_value=mock_client, + ), + ): + result = await context._fetch("test-user", "hello") + + assert result is not None + assert "talked about coffee" in result + + +class TestFormatContextWithContent: + """Test _format_context with actual edges and episodes.""" + + def test_with_edges_only(self) -> None: + edge = SimpleNamespace( + fact="user likes coffee", + name="preference", + valid_at="2025-01-01", + invalid_at="present", + ) + result = _format_context(edges=[edge], episodes=[]) + assert result is not None + assert "" in result + assert "user likes coffee" in result + assert "" in result + + def test_with_episodes_only(self) -> None: + ep = SimpleNamespace( + content="plain conversation text", + created_at="2025-01-01T00:00:00Z", + ) + result = _format_context(edges=[], episodes=[ep]) + assert result is not None + assert "" in result + assert "plain conversation text" in result + + def test_with_both_edges_and_episodes(self) -> None: + edge = SimpleNamespace( + fact="user likes coffee", + valid_at="2025-01-01", + invalid_at=None, + ) + ep = SimpleNamespace( + content="talked about coffee", + created_at="2025-06-01T00:00:00Z", + ) + result = _format_context(edges=[edge], episodes=[ep]) + assert result is not None + assert "" in result + assert "" in result + + def test_global_scope_episode_included(self) -> None: + envelope = MemoryEnvelope(content="global note", scope="real:global") + ep = SimpleNamespace( + content=envelope.model_dump_json(), + created_at="2025-01-01T00:00:00Z", + ) + result = _format_context(edges=[], episodes=[ep]) + assert result is not None + assert "" in result + + def 
test_non_global_scope_episode_excluded(self) -> None: + envelope = MemoryEnvelope(content="project note", scope="project:crm") + ep = SimpleNamespace( + content=envelope.model_dump_json(), + created_at="2025-01-01T00:00:00Z", + ) + result = _format_context(edges=[], episodes=[ep]) + assert result is None + + +class TestIsNonGlobalScopeEdgeCases: + """Verify _is_non_global_scope handles non-dict JSON without crashing.""" + + def test_list_json_treated_as_global(self) -> None: + assert _is_non_global_scope("[1, 2, 3]") is False + + def test_string_json_treated_as_global(self) -> None: + assert _is_non_global_scope('"just a string"') is False + + def test_null_json_treated_as_global(self) -> None: + assert _is_non_global_scope("null") is False + + def test_plain_text_treated_as_global(self) -> None: + assert _is_non_global_scope("plain conversation text") is False + + +class TestIsNonGlobalScopeTruncation: + """Verify _is_non_global_scope handles long MemoryEnvelope JSON. + + extract_episode_body() truncates to 500 chars. A MemoryEnvelope with + a long content field serializes to >500 chars, so the truncated string + is invalid JSON. The except clause falls through to return False, + incorrectly treating a project-scoped episode as global. + """ + + def test_long_envelope_with_non_global_scope_detected(self) -> None: + """Long MemoryEnvelope JSON should be parsed with raw (untruncated) body.""" + envelope = MemoryEnvelope( + content="x" * 600, + source_kind=SourceKind.user_asserted, + scope="project:crm", + memory_kind=MemoryKind.fact, + ) + full_json = envelope.model_dump_json() + assert len(full_json) > 500, "precondition: JSON must exceed truncation limit" + + # With the fix: _is_non_global_scope on the raw (untruncated) body + # correctly detects the non-global scope. + assert _is_non_global_scope(full_json) is True + + # Truncated body still fails — that's expected; callers must use raw body. 
+ ep = SimpleNamespace(content=full_json) + truncated = extract_episode_body(ep) + assert _is_non_global_scope(truncated) is False # truncated JSON → parse fails + + +# --------------------------------------------------------------------------- +# Bug: empty wrapper when all episodes are non-global +# --------------------------------------------------------------------------- + + +class TestFormatContextEmptyWrapper: + """When all episodes are non-global and edges is empty, _format_context + should return None (no useful content) instead of an empty XML wrapper. + """ + + def test_returns_none_when_all_episodes_filtered(self) -> None: + envelope = MemoryEnvelope( + content="project-only note", + scope="project:crm", + ) + ep = SimpleNamespace( + content=envelope.model_dump_json(), + created_at="2025-01-01T00:00:00Z", + ) + result = _format_context(edges=[], episodes=[ep]) + assert result is None diff --git a/autogpt_platform/backend/backend/copilot/graphiti/ingest.py b/autogpt_platform/backend/backend/copilot/graphiti/ingest.py index e36f521a35..58d086e55c 100644 --- a/autogpt_platform/backend/backend/copilot/graphiti/ingest.py +++ b/autogpt_platform/backend/backend/copilot/graphiti/ingest.py @@ -7,17 +7,45 @@ ingestion while keeping it fire-and-forget from the caller's perspective. import asyncio import logging +import weakref from datetime import datetime, timezone from graphiti_core.nodes import EpisodeType from .client import derive_group_id, get_graphiti_client +from .memory_model import MemoryEnvelope, MemoryKind, MemoryStatus, SourceKind logger = logging.getLogger(__name__) -_user_queues: dict[str, asyncio.Queue] = {} -_user_workers: dict[str, asyncio.Task] = {} -_workers_lock = asyncio.Lock() + +# The CoPilot executor runs one asyncio loop per worker thread, and +# asyncio.Queue / asyncio.Lock / asyncio.Task are all bound to the loop they +# were first used on. 
A process-wide worker registry would hand a loop-1-bound +# Queue to a coroutine running on loop 2 → RuntimeError "Future attached to a +# different loop". Scope the registry per running loop so each loop has its +# own queues, workers, and lock. Entries auto-clean when the loop is GC'd. +class _LoopIngestState: + __slots__ = ("user_queues", "user_workers", "workers_lock") + + def __init__(self) -> None: + self.user_queues: dict[str, asyncio.Queue] = {} + self.user_workers: dict[str, asyncio.Task] = {} + self.workers_lock = asyncio.Lock() + + +_loop_state: ( + "weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, _LoopIngestState]" +) = weakref.WeakKeyDictionary() + + +def _get_loop_state() -> _LoopIngestState: + loop = asyncio.get_running_loop() + state = _loop_state.get(loop) + if state is None: + state = _LoopIngestState() + _loop_state[loop] = state + return state + # Idle workers are cleaned up after this many seconds of inactivity. _WORKER_IDLE_TIMEOUT = 60 @@ -37,6 +65,10 @@ async def _ingestion_worker(user_id: str, queue: asyncio.Queue) -> None: Exits after ``_WORKER_IDLE_TIMEOUT`` seconds of inactivity so that idle workers don't leak memory indefinitely. """ + # Snapshot the loop-local state at task start so cleanup always runs + # against the same state dict the worker was registered in, even if the + # worker is cancelled from another task. + state = _get_loop_state() try: while True: try: @@ -63,20 +95,25 @@ async def _ingestion_worker(user_id: str, queue: asyncio.Queue) -> None: raise finally: # Clean up so the next message re-creates the worker. - _user_queues.pop(user_id, None) - _user_workers.pop(user_id, None) + state.user_queues.pop(user_id, None) + state.user_workers.pop(user_id, None) async def enqueue_conversation_turn( user_id: str, session_id: str, user_msg: str, + assistant_msg: str = "", ) -> None: """Enqueue a conversation turn for async background ingestion. 
This returns almost immediately — the actual graphiti-core ``add_episode()`` call (which triggers LLM entity extraction) runs in a background worker task. + + If ``assistant_msg`` is provided and contains substantive findings + (not just acknowledgments), a separate derived-finding episode is + queued with ``source_kind=assistant_derived`` and ``status=tentative``. """ if not user_id: return @@ -117,6 +154,35 @@ async def enqueue_conversation_turn( "Graphiti ingestion queue full for user %s — dropping episode", user_id[:12], ) + return + + # --- Derived-finding lane --- + # If the assistant response is substantive, distill it into a + # structured finding with tentative status. + if assistant_msg and _is_finding_worthy(assistant_msg): + finding = _distill_finding(assistant_msg) + if finding: + envelope = MemoryEnvelope( + content=finding, + source_kind=SourceKind.assistant_derived, + memory_kind=MemoryKind.finding, + status=MemoryStatus.tentative, + provenance=f"session:{session_id}", + ) + try: + queue.put_nowait( + { + "name": f"finding_{session_id}", + "episode_body": envelope.model_dump_json(), + "source": EpisodeType.json, + "source_description": f"Assistant-derived finding in session {session_id}", + "reference_time": datetime.now(timezone.utc), + "group_id": group_id, + "custom_extraction_instructions": CUSTOM_EXTRACTION_INSTRUCTIONS, + } + ) + except asyncio.QueueFull: + pass # user canonical episode already queued — finding is best-effort async def enqueue_episode( @@ -126,12 +192,18 @@ async def enqueue_episode( name: str, episode_body: str, source_description: str = "Conversation memory", + is_json: bool = False, ) -> bool: """Enqueue an arbitrary episode for background ingestion. Used by ``MemoryStoreTool`` so that explicit memory-store calls go through the same per-user serialization queue as conversation turns. + Args: + is_json: When ``True``, ingest as ``EpisodeType.json`` (for + structured ``MemoryEnvelope`` payloads). 
Otherwise uses + ``EpisodeType.text``. + Returns ``True`` if the episode was queued, ``False`` if it was dropped. """ if not user_id: @@ -145,12 +217,14 @@ async def enqueue_episode( queue = await _ensure_worker(user_id) + source = EpisodeType.json if is_json else EpisodeType.text + try: queue.put_nowait( { "name": name, "episode_body": episode_body, - "source": EpisodeType.text, + "source": source, "source_description": source_description, "reference_time": datetime.now(timezone.utc), "group_id": group_id, @@ -170,18 +244,19 @@ async def _ensure_worker(user_id: str) -> asyncio.Queue: """Create a queue and worker for *user_id* if one doesn't exist. Returns the queue directly so callers don't need to look it up from - ``_user_queues`` (which avoids a TOCTOU race if the worker times out + the state dict (which avoids a TOCTOU race if the worker times out and cleans up between this call and the put_nowait). """ - async with _workers_lock: - if user_id not in _user_queues: + state = _get_loop_state() + async with state.workers_lock: + if user_id not in state.user_queues: q: asyncio.Queue = asyncio.Queue(maxsize=100) - _user_queues[user_id] = q - _user_workers[user_id] = asyncio.create_task( + state.user_queues[user_id] = q + state.user_workers[user_id] = asyncio.create_task( _ingestion_worker(user_id, q), name=f"graphiti-ingest-{user_id[:12]}", ) - return _user_queues[user_id] + return state.user_queues[user_id] async def _resolve_user_name(user_id: str) -> str: @@ -195,3 +270,58 @@ async def _resolve_user_name(user_id: str) -> str: except Exception: logger.debug("Could not resolve user name for %s", user_id[:12]) return "User" + + +# --- Derived-finding distillation --- + +# Phrases that indicate workflow chatter, not substantive findings. 
+_CHATTER_PREFIXES = ( + "done", + "got it", + "sure, i", + "sure!", + "ok", + "okay", + "i've created", + "i've updated", + "i've sent", + "i'll ", + "let me ", + "a sign-in button", + "please click", +) + +# Minimum length for an assistant message to be considered finding-worthy. +_MIN_FINDING_LENGTH = 150 + + +def _is_finding_worthy(assistant_msg: str) -> bool: + """Heuristic gate: is this assistant response worth distilling into a finding? + + Skips short acknowledgments, workflow chatter, and UI prompts. + Only passes through responses that likely contain substantive + factual content (research results, analysis, conclusions). + """ + if len(assistant_msg) < _MIN_FINDING_LENGTH: + return False + + lower = assistant_msg.lower().strip() + for prefix in _CHATTER_PREFIXES: + if lower.startswith(prefix): + return False + + return True + + +def _distill_finding(assistant_msg: str) -> str | None: + """Extract the core finding from an assistant response. + + For now, uses a simple truncation approach. Phase 3+ could use + a lightweight LLM call for proper distillation. + """ + # Take the first 500 chars as the finding content. + # Strip markdown formatting artifacts. + content = assistant_msg.strip() + if len(content) > 500: + content = content[:500] + "..." + return content if content else None diff --git a/autogpt_platform/backend/backend/copilot/graphiti/ingest_test.py b/autogpt_platform/backend/backend/copilot/graphiti/ingest_test.py index 3aebd283a5..6cb9c5fbaf 100644 --- a/autogpt_platform/backend/backend/copilot/graphiti/ingest_test.py +++ b/autogpt_platform/backend/backend/copilot/graphiti/ingest_test.py @@ -8,21 +8,9 @@ import pytest from . import ingest - -def _clean_module_state() -> None: - """Reset module-level state to avoid cross-test contamination.""" - ingest._user_queues.clear() - ingest._user_workers.clear() - - -@pytest.fixture(autouse=True) -def _reset_state(): - _clean_module_state() - yield - # Cancel any lingering worker tasks. 
- for task in ingest._user_workers.values(): - task.cancel() - _clean_module_state() +# Per-loop state in ingest.py auto-isolates between tests: pytest-asyncio +# creates a fresh event loop per test function, and the WeakKeyDictionary +# forgets the previous loop's state when it is GC'd. No manual reset needed. class TestIngestionWorkerExceptionHandling: @@ -75,7 +63,7 @@ class TestEnqueueConversationTurn: user_msg="hi", ) # No queue should have been created. - assert len(ingest._user_queues) == 0 + assert len(ingest._get_loop_state().user_queues) == 0 class TestQueueFullScenario: @@ -106,7 +94,7 @@ class TestQueueFullScenario: # Replace the queue with one that is already full. tiny_q: asyncio.Queue = asyncio.Queue(maxsize=1) tiny_q.put_nowait({"dummy": True}) - ingest._user_queues[user_id] = tiny_q + ingest._get_loop_state().user_queues[user_id] = tiny_q # Should not raise even though the queue is full. await ingest.enqueue_conversation_turn( @@ -162,6 +150,149 @@ class TestResolveUserName: assert name == "User" +class TestEnqueueEpisode: + @pytest.mark.asyncio + async def test_enqueue_episode_returns_true_on_success(self) -> None: + with ( + patch.object(ingest, "derive_group_id", return_value="user_abc"), + patch.object( + ingest, "_ensure_worker", new_callable=AsyncMock + ) as mock_worker, + ): + q: asyncio.Queue = asyncio.Queue(maxsize=100) + mock_worker.return_value = q + + result = await ingest.enqueue_episode( + user_id="abc", + session_id="sess1", + name="test_ep", + episode_body="hello", + is_json=False, + ) + assert result is True + assert not q.empty() + + @pytest.mark.asyncio + async def test_enqueue_episode_returns_false_for_empty_user(self) -> None: + result = await ingest.enqueue_episode( + user_id="", + session_id="sess1", + name="test_ep", + episode_body="hello", + ) + assert result is False + + @pytest.mark.asyncio + async def test_enqueue_episode_returns_false_on_invalid_user(self) -> None: + with patch.object(ingest, "derive_group_id", 
side_effect=ValueError("bad id")): + result = await ingest.enqueue_episode( + user_id="bad", + session_id="sess1", + name="test_ep", + episode_body="hello", + ) + assert result is False + + @pytest.mark.asyncio + async def test_enqueue_episode_json_mode(self) -> None: + with ( + patch.object(ingest, "derive_group_id", return_value="user_abc"), + patch.object( + ingest, "_ensure_worker", new_callable=AsyncMock + ) as mock_worker, + ): + q: asyncio.Queue = asyncio.Queue(maxsize=100) + mock_worker.return_value = q + + result = await ingest.enqueue_episode( + user_id="abc", + session_id="sess1", + name="test_ep", + episode_body='{"content": "hello"}', + is_json=True, + ) + assert result is True + item = q.get_nowait() + from graphiti_core.nodes import EpisodeType + + assert item["source"] == EpisodeType.json + + +class TestDerivedFindingLane: + @pytest.mark.asyncio + async def test_finding_worthy_message_enqueues_two_episodes(self) -> None: + """A substantive assistant message should enqueue both the user + episode and a derived-finding episode.""" + long_msg = "The analysis reveals significant growth patterns " + "x" * 200 + + with ( + patch.object(ingest, "derive_group_id", return_value="user_abc"), + patch.object( + ingest, "_ensure_worker", new_callable=AsyncMock + ) as mock_worker, + patch( + "backend.copilot.graphiti.ingest._resolve_user_name", + new_callable=AsyncMock, + return_value="Alice", + ), + ): + q: asyncio.Queue = asyncio.Queue(maxsize=100) + mock_worker.return_value = q + + await ingest.enqueue_conversation_turn( + user_id="abc", + session_id="sess1", + user_msg="tell me about growth", + assistant_msg=long_msg, + ) + # Should have 2 items: user episode + derived finding + assert q.qsize() == 2 + + @pytest.mark.asyncio + async def test_short_assistant_msg_skips_finding(self) -> None: + with ( + patch.object(ingest, "derive_group_id", return_value="user_abc"), + patch.object( + ingest, "_ensure_worker", new_callable=AsyncMock + ) as mock_worker, + patch( 
+ "backend.copilot.graphiti.ingest._resolve_user_name", + new_callable=AsyncMock, + return_value="Alice", + ), + ): + q: asyncio.Queue = asyncio.Queue(maxsize=100) + mock_worker.return_value = q + + await ingest.enqueue_conversation_turn( + user_id="abc", + session_id="sess1", + user_msg="hi", + assistant_msg="ok", + ) + # Only 1 item: the user episode (no finding for short msg) + assert q.qsize() == 1 + + +class TestDerivedFindingDistillation: + """_is_finding_worthy and _distill_finding gate derived-finding creation.""" + + def test_short_message_not_finding_worthy(self) -> None: + assert ingest._is_finding_worthy("ok") is False + + def test_chatter_prefix_not_finding_worthy(self) -> None: + assert ingest._is_finding_worthy("done " + "x" * 200) is False + + def test_long_substantive_message_is_finding_worthy(self) -> None: + msg = "The quarterly revenue analysis shows a 15% increase " + "x" * 200 + assert ingest._is_finding_worthy(msg) is True + + def test_distill_finding_truncates_to_500(self) -> None: + result = ingest._distill_finding("x" * 600) + assert result is not None + assert len(result) == 503 # 500 + "..." + + class TestWorkerIdleTimeout: @pytest.mark.asyncio async def test_worker_cleans_up_on_idle(self) -> None: @@ -169,9 +300,10 @@ class TestWorkerIdleTimeout: queue: asyncio.Queue = asyncio.Queue(maxsize=10) # Pre-populate state so cleanup can remove entries. - ingest._user_queues[user_id] = queue + state = ingest._get_loop_state() + state.user_queues[user_id] = queue task_sentinel = MagicMock() - ingest._user_workers[user_id] = task_sentinel + state.user_workers[user_id] = task_sentinel original_timeout = ingest._WORKER_IDLE_TIMEOUT ingest._WORKER_IDLE_TIMEOUT = 0.05 @@ -181,5 +313,5 @@ class TestWorkerIdleTimeout: ingest._WORKER_IDLE_TIMEOUT = original_timeout # After idle timeout the worker should have cleaned up. 
- assert user_id not in ingest._user_queues - assert user_id not in ingest._user_workers + assert user_id not in state.user_queues + assert user_id not in state.user_workers diff --git a/autogpt_platform/backend/backend/copilot/graphiti/memory_model.py b/autogpt_platform/backend/backend/copilot/graphiti/memory_model.py new file mode 100644 index 0000000000..d8105cb731 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/graphiti/memory_model.py @@ -0,0 +1,118 @@ +"""Generic memory metadata model for Graphiti episodes. + +Domain-agnostic envelope that works across business, fiction, research, +personal life, and arbitrary knowledge domains. Designed so retrieval +can distinguish user-asserted facts from assistant-derived findings +and filter by scope. +""" + +from enum import Enum + +from pydantic import BaseModel, Field + + +class SourceKind(str, Enum): + user_asserted = "user_asserted" + assistant_derived = "assistant_derived" + tool_observed = "tool_observed" + + +class MemoryKind(str, Enum): + fact = "fact" + preference = "preference" + rule = "rule" + finding = "finding" + plan = "plan" + event = "event" + procedure = "procedure" + + +class MemoryStatus(str, Enum): + active = "active" + tentative = "tentative" + superseded = "superseded" + contradicted = "contradicted" + + +class RuleMemory(BaseModel): + """Structured representation of a standing instruction or rule. + + Preserves the exact user intent rather than relying on LLM + extraction to reconstruct it from prose. + """ + + instruction: str = Field( + description="The actionable instruction (e.g. 'CC Sarah on client communications')" + ) + actor: str | None = Field( + default=None, description="Who performs or is subject to the rule" + ) + trigger: str | None = Field( + default=None, + description="When the rule applies (e.g. 'client-related communications')", + ) + negation: str | None = Field( + default=None, + description="What NOT to do, if applicable (e.g. 
'do not use SMTP')", + ) + + +class ProcedureStep(BaseModel): + """A single step in a multi-step procedure.""" + + order: int = Field(description="Step number (1-based)") + action: str = Field(description="What to do in this step") + tool: str | None = Field(default=None, description="Tool or service to use") + condition: str | None = Field(default=None, description="When/if this step applies") + negation: str | None = Field( + default=None, description="What NOT to do in this step" + ) + + +class ProcedureMemory(BaseModel): + """Structured representation of a multi-step workflow. + + Steps with ordering, tools, conditions, and negations that don't + decompose cleanly into fact triples. + """ + + description: str = Field(description="What this procedure accomplishes") + steps: list[ProcedureStep] = Field(default_factory=list) + + +class MemoryEnvelope(BaseModel): + """Structured wrapper for explicit memory storage. + + Serialized as JSON and ingested via ``EpisodeType.json`` so that + Graphiti extracts entities from the ``content`` field while the + metadata fields survive as episode-level context. + + For ``memory_kind=rule``, populate the ``rule`` field with a + ``RuleMemory`` to preserve the exact instruction. For + ``memory_kind=procedure``, populate ``procedure`` with a + ``ProcedureMemory`` for structured steps. 
+ """ + + content: str = Field( + description="The memory content — the actual fact, rule, or finding" + ) + source_kind: SourceKind = Field(default=SourceKind.user_asserted) + scope: str = Field( + default="real:global", + description="Namespace: 'real:global', 'project:', 'book:', 'session:<id>'", + ) + memory_kind: MemoryKind = Field(default=MemoryKind.fact) + status: MemoryStatus = Field(default=MemoryStatus.active) + confidence: float | None = Field(default=None, ge=0.0, le=1.0) + provenance: str | None = Field( + default=None, + description="Origin reference — session_id, tool_call_id, or URL", + ) + rule: RuleMemory | None = Field( + default=None, + description="Structured rule data — populate when memory_kind=rule", + ) + procedure: ProcedureMemory | None = Field( + default=None, + description="Structured procedure data — populate when memory_kind=procedure", + ) diff --git a/autogpt_platform/backend/backend/copilot/permissions.py b/autogpt_platform/backend/backend/copilot/permissions.py index cc01a124c4..a30ee282f7 100644 --- a/autogpt_platform/backend/backend/copilot/permissions.py +++ b/autogpt_platform/backend/backend/copilot/permissions.py @@ -89,6 +89,8 @@ ToolName = Literal[ "get_mcp_guide", "list_folders", "list_workspace_files", + "memory_forget_confirm", + "memory_forget_search", "memory_search", "memory_store", "move_agents_to_folder", diff --git a/autogpt_platform/backend/backend/copilot/prompting.py b/autogpt_platform/backend/backend/copilot/prompting.py index ec136933e9..ed436733dd 100644 --- a/autogpt_platform/backend/backend/copilot/prompting.py +++ b/autogpt_platform/backend/backend/copilot/prompting.py @@ -174,6 +174,7 @@ sandbox so `bash_exec` can access it for further processing. The exact sandbox path is shown in the `[Sandbox copy available at ...]` note. ### GitHub CLI (`gh`) and git +- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before asking them to connect it. 
- If the user has connected their GitHub account, both `gh` and `git` are pre-authenticated — use them directly without any manual login step. `git` HTTPS operations (clone, push, pull) work automatically. diff --git a/autogpt_platform/backend/backend/copilot/sdk/mode_switch_context_test.py b/autogpt_platform/backend/backend/copilot/sdk/mode_switch_context_test.py index 5e1ef41979..212fca189b 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/mode_switch_context_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/mode_switch_context_test.py @@ -8,7 +8,7 @@ Cross-mode transcript flow ========================== Both ``baseline/service.py`` (fast mode) and ``sdk/service.py`` (extended_thinking -mode) read and write the same JSONL transcript store via +mode) read and write the same CLI session store via ``backend.copilot.transcript.upload_transcript`` / ``download_transcript``. @@ -250,8 +250,9 @@ class TestSdkToFastModeSwitch: @pytest.mark.asyncio async def test_scenario_s_baseline_loads_sdk_transcript(self): - """Scenario S: SDK-written transcript is accepted by baseline's load helper.""" + """Scenario S: SDK-written CLI session is accepted by baseline's load helper.""" from backend.copilot.baseline.service import _load_prior_transcript + from backend.copilot.model import ChatMessage from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload from backend.copilot.transcript_builder import TranscriptBuilder @@ -267,33 +268,41 @@ class TestSdkToFastModeSwitch: sdk_transcript = builder_sdk.to_jsonl() # Baseline session now has those 2 SDK messages + 1 new baseline message. 
- download = TranscriptDownload(content=sdk_transcript, message_count=2) + restore = TranscriptDownload( + content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk" + ) baseline_builder = TranscriptBuilder() with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=3, # 2 SDK + 1 new baseline + session_messages=[ + ChatMessage(role="user", content="sdk-question"), + ChatMessage(role="assistant", content="sdk-answer"), + ChatMessage(role="user", content="baseline-question"), + ], transcript_builder=baseline_builder, ) - # Transcript is valid and covers the prefix. + # CLI session is valid and covers the prefix. assert covers is True + assert dl is not None assert baseline_builder.entry_count == 2 @pytest.mark.asyncio async def test_scenario_s_stale_sdk_transcript_not_loaded(self): - """Scenario S (stale): SDK transcript is stale — baseline does not load it. + """Scenario S (stale): SDK CLI session is stale — baseline does not load it. - If SDK mode produced more turns than the transcript captured (e.g. - upload failed on one turn), the baseline rejects the stale transcript + If SDK mode produced more turns than the session captured (e.g. + upload failed on one turn), the baseline rejects the stale session to avoid injecting an incomplete history. """ from backend.copilot.baseline.service import _load_prior_transcript + from backend.copilot.model import ChatMessage from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload from backend.copilot.transcript_builder import TranscriptBuilder @@ -306,21 +315,33 @@ class TestSdkToFastModeSwitch: ) sdk_transcript = builder_sdk.to_jsonl() - # Transcript covers only 2 messages but session has 10 (many SDK turns). 
- download = TranscriptDownload(content=sdk_transcript, message_count=2) + # Session covers only 2 messages but session has 10 (many SDK turns). + # With watermark=2 and 10 total messages, detect_gap will fill the gap + # by appending messages 2..8 (positions 2 to total-2). + restore = TranscriptDownload( + content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk" + ) + + # Build a session with 10 alternating user/assistant messages + current user + session_messages = [ + ChatMessage(role="user" if i % 2 == 0 else "assistant", content=f"msg-{i}") + for i in range(10) + ] baseline_builder = TranscriptBuilder() with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=10, + session_messages=session_messages, transcript_builder=baseline_builder, ) - # Stale transcript must be rejected. - assert covers is False - assert baseline_builder.is_empty + # With gap filling, covers is True and gap messages are appended. 
+ assert covers is True + assert dl is not None + # 2 from transcript + 7 gap messages (positions 2..8, excluding last user turn) + assert baseline_builder.entry_count == 9 diff --git a/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py b/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py index 9305320fea..17b54797b8 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py @@ -86,15 +86,14 @@ class TestResolveFallbackModel: assert result == "claude-sonnet-4.5-20250514" def test_default_value(self): - """Default fallback model resolves to a valid string.""" + """Default fallback model resolves to None (disabled by default).""" cfg = _make_config() with patch(f"{_SVC}.config", cfg): from backend.copilot.sdk.service import _resolve_fallback_model result = _resolve_fallback_model() - assert result is not None - assert "sonnet" in result.lower() or "claude" in result.lower() + assert result is None # --------------------------------------------------------------------------- @@ -198,8 +197,7 @@ class TestConfigDefaults: def test_fallback_model_default(self): cfg = _make_config() - assert cfg.claude_agent_fallback_model - assert "sonnet" in cfg.claude_agent_fallback_model.lower() + assert cfg.claude_agent_fallback_model == "" def test_max_turns_default(self): cfg = _make_config() diff --git a/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py b/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py index a48d7def3d..60c65f00ce 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py @@ -27,6 +27,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from backend.copilot.transcript import ( + TranscriptDownload, _flatten_assistant_content, _flatten_tool_result_content, _messages_to_transcript, @@ 
-999,14 +1000,15 @@ def _make_sdk_patches( f"{_SVC}.download_transcript", dict( new_callable=AsyncMock, - return_value=MagicMock(content=original_transcript, message_count=2), + return_value=TranscriptDownload( + content=original_transcript.encode("utf-8"), + message_count=2, + mode="sdk", + ), ), ), - ( - f"{_SVC}.restore_cli_session", - dict(new_callable=AsyncMock, return_value=True), - ), - (f"{_SVC}.upload_cli_session", dict(new_callable=AsyncMock)), + (f"{_SVC}.strip_for_upload", dict(return_value=original_transcript)), + (f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)), (f"{_SVC}.validate_transcript", dict(return_value=True)), ( f"{_SVC}.compact_transcript", @@ -1037,7 +1039,6 @@ def _make_sdk_patches( claude_agent_fallback_model=None, ), ), - (f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)), (f"{_SVC}.get_user_tier", dict(new_callable=AsyncMock, return_value=None)), ] @@ -1914,14 +1915,14 @@ class TestStreamChatCompletionRetryIntegration: compacted_transcript=None, client_side_effect=_client_factory, ) - # Override restore_cli_session to return False (CLI native session unavailable) + # Override download_transcript to return None (CLI native session unavailable) patches = [ ( ( - f"{_SVC}.restore_cli_session", - dict(new_callable=AsyncMock, return_value=False), + f"{_SVC}.download_transcript", + dict(new_callable=AsyncMock, return_value=None), ) - if p[0] == f"{_SVC}.restore_cli_session" + if p[0] == f"{_SVC}.download_transcript" else p ) for p in patches @@ -1944,7 +1945,7 @@ class TestStreamChatCompletionRetryIntegration: # captured_options holds {"options": ClaudeAgentOptions}, so check # the attribute directly rather than dict keys. 
assert not getattr(captured_options.get("options"), "resume", None), ( - f"--resume was set even though restore_cli_session returned False: " + f"--resume was set even though download_transcript returned None: " f"{captured_options}" ) assert any(isinstance(e, StreamStart) for e in events) diff --git a/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py b/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py index e5ba184f4f..666e55fbba 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py +++ b/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py @@ -365,7 +365,7 @@ def create_security_hooks( trigger = _sanitize(str(input_data.get("trigger", "auto")), max_len=50) # Sanitize untrusted input: strip control chars for logging AND # for the value passed downstream. read_compacted_entries() - # validates against _projects_base() as defence-in-depth, but + # validates against projects_base() as defence-in-depth, but # sanitizing here prevents log injection and rejects obviously # malformed paths early. 
transcript_path = _sanitize( diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index a55380ea6e..936f1f8df1 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -16,6 +16,7 @@ import uuid from collections.abc import AsyncGenerator, AsyncIterator from dataclasses import dataclass from dataclasses import field as dataclass_field +from pathlib import Path from typing import TYPE_CHECKING, Any, NamedTuple, NotRequired, cast if TYPE_CHECKING: @@ -92,12 +93,15 @@ from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path from ..tracking import track_user_message from ..transcript import ( _run_compression, + TranscriptDownload, cleanup_stale_project_dirs, + cli_session_path, compact_transcript, download_transcript, + extract_context_messages, + projects_base, read_compacted_entries, - restore_cli_session, - upload_cli_session, + strip_for_upload, upload_transcript, validate_transcript, ) @@ -849,6 +853,181 @@ def _make_sdk_cwd(session_id: str) -> str: return cwd +def _write_cli_session_to_disk( + content: bytes, + sdk_cwd: str, + session_id: str, + log_prefix: str, +) -> bool: + """Write downloaded CLI session bytes to disk so the CLI can --resume. + + Returns True on success, False if the path is invalid or the write fails. + Path-traversal guard: rejects paths outside the CLI projects base. 
+ """ + session_file = cli_session_path(sdk_cwd, session_id) + real_path = os.path.realpath(session_file) + _pbase = projects_base() + if not real_path.startswith(_pbase + os.sep): + logger.warning( + "%s CLI session restore path outside projects base: %s", + log_prefix, + os.path.basename(session_file), + ) + return False + try: + os.makedirs(os.path.dirname(real_path), exist_ok=True) + Path(real_path).write_bytes(content) + logger.info( + "%s Wrote CLI session to disk (%dB) for --resume", + log_prefix, + len(content), + ) + return True + except OSError as e: + logger.warning( + "%s Failed to write CLI session file %s: %s", + log_prefix, + os.path.basename(session_file), + e.strerror or str(e), + ) + return False + + +def _read_cli_session_from_disk( + sdk_cwd: str, + session_id: str, + log_prefix: str, +) -> bytes | None: + """Read the CLI session JSONL file from disk after the SDK turn. + + Returns the file bytes, or None if the file is missing, outside the + projects base, or unreadable. + Path-traversal guard: rejects paths outside the CLI projects base. + """ + session_file = cli_session_path(sdk_cwd, session_id) + real_path = os.path.realpath(session_file) + _pbase = projects_base() + if not real_path.startswith(_pbase + os.sep): + logger.warning( + "%s CLI session file outside projects base, skipping upload: %s", + log_prefix, + os.path.basename(real_path), + ) + return None + try: + raw_bytes = Path(real_path).read_bytes() + except FileNotFoundError: + logger.debug( + "%s CLI session file not found, skipping upload: %s", + log_prefix, + os.path.basename(session_file), + ) + return None + except OSError as e: + logger.warning( + "%s Failed to read CLI session file %s: %s", + log_prefix, + os.path.basename(session_file), + e.strerror or str(e), + ) + return None + + # Strip stale thinking blocks and metadata entries before uploading. 
+ # Thinking blocks from non-last turns can be massive; keeping them causes + # the CLI to auto-compact its session when the context window fills up, + # silently losing conversation history. + try: + raw_text = raw_bytes.decode("utf-8") + stripped_text = strip_for_upload(raw_text) + stripped_bytes = stripped_text.encode("utf-8") + except UnicodeDecodeError: + logger.warning("%s CLI session is not valid UTF-8, uploading raw", log_prefix) + return raw_bytes + except (OSError, ValueError) as e: + # OSError: encode/decode I/O failure; ValueError: malformed JSONL in strip. + # Other unexpected exceptions are not silently swallowed here so they propagate + # to the outer OSError handler and are logged with exc_info. + logger.warning( + "%s Failed to strip CLI session, uploading raw: %s", log_prefix, e + ) + return raw_bytes + + if len(stripped_bytes) < len(raw_bytes): + # Write back locally so same-pod turns also benefit. + try: + Path(real_path).write_bytes(stripped_bytes) + logger.info( + "%s Stripped CLI session: %dB → %dB", + log_prefix, + len(raw_bytes), + len(stripped_bytes), + ) + except OSError as e: + # write_bytes failed — stripped content is still valid for GCS upload even + # though the local write-back failed (same-pod optimization silently skipped). + logger.warning( + "%s Failed to write back stripped CLI session: %s", + log_prefix, + e.strerror or str(e), + ) + return stripped_bytes + + +def _process_cli_restore( + cli_restore: TranscriptDownload, + sdk_cwd: str, + session_id: str, + log_prefix: str, +) -> tuple[str, bool]: + """Validate and write a restored CLI session to disk. + + Decodes bytes → UTF-8, strips progress entries and stale thinking blocks, + validates the result, then writes the stripped content to disk so the CLI + can ``--resume`` from it. + + Returns ``(stripped_content, success)`` where ``success=False`` means the + content was invalid or the disk write failed (caller should skip --resume). 
+ """ + try: + raw_bytes = cli_restore.content + raw_str = ( + raw_bytes.decode("utf-8") if isinstance(raw_bytes, bytes) else raw_bytes + ) + except UnicodeDecodeError: + logger.warning( + "%s CLI session content is not valid UTF-8, skipping", log_prefix + ) + return "", False + + stripped = strip_for_upload(raw_str) + is_valid = validate_transcript(stripped) + # Use len(raw_str) rather than len(cli_restore.content) so the unit is always + # characters (raw_str is always str at this point regardless of input type). + # lines_stripped = original lines minus remaining lines after stripping. + _original_lines = len(raw_str.strip().split("\n")) if raw_str.strip() else 0 + _remaining_lines = len(stripped.strip().split("\n")) if stripped.strip() else 0 + logger.info( + "%s Restored CLI session: %dB raw, %d lines stripped, msg_count=%d, valid=%s", + log_prefix, + len(raw_str), + _original_lines - _remaining_lines, + cli_restore.message_count, + is_valid, + ) + if not is_valid: + logger.warning( + "%s CLI session content invalid after strip — running without --resume", + log_prefix, + ) + return "", False + + stripped_bytes = stripped.encode("utf-8") + if not _write_cli_session_to_disk(stripped_bytes, sdk_cwd, session_id, log_prefix): + return "", False + + return stripped, True + + async def _cleanup_sdk_tool_results(cwd: str) -> None: """Remove SDK session artifacts for a specific working directory. @@ -922,8 +1101,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]: result.append(block) else: logger.warning( - f"[SDK] Unknown content block type: {type(block).__name__}. " - f"This may indicate a new SDK version with additional block types." + "[SDK] Unknown content block type: %s." 
+ " This may indicate a new SDK version with additional block types.", + type(block).__name__, ) return result @@ -978,10 +1158,11 @@ async def _compress_messages( if result.was_compacted: logger.info( - f"[SDK] Context compacted: {result.original_token_count} -> " - f"{result.token_count} tokens " - f"({result.messages_summarized} summarized, " - f"{result.messages_dropped} dropped)" + "[SDK] Context compacted: %d -> %d tokens (%d summarized, %d dropped)", + result.original_token_count, + result.token_count, + result.messages_summarized, + result.messages_dropped, ) # Convert compressed dicts back to ChatMessages return [ @@ -1048,11 +1229,17 @@ def _session_messages_to_transcript(messages: list[ChatMessage]) -> str: ) if blocks: builder.append_assistant(blocks) - elif msg.role == "tool" and msg.tool_call_id: - builder.append_tool_result( - tool_use_id=msg.tool_call_id, - content=msg.content or "", - ) + elif msg.role == "tool": + if msg.tool_call_id: + builder.append_tool_result( + tool_use_id=msg.tool_call_id, + content=msg.content or "", + ) + else: + # Malformed tool message — no tool_call_id to link to an + # assistant tool_use block. Skip to avoid an unmatched + # tool_result entry in the builder (which would confuse --resume). + logger.warning("[SDK] Skipping tool gap message with no tool_call_id") return builder.to_jsonl() @@ -1098,6 +1285,7 @@ async def _build_query_message( transcript_msg_count: int, session_id: str, target_tokens: int | None = None, + prior_messages: "list[ChatMessage] | None" = None, ) -> tuple[str, bool]: """Build the query message with appropriate context. 
@@ -1203,15 +1391,16 @@ async def _build_query_message( ) return current_message, False + source = prior_messages if prior_messages is not None else prior logger.warning( - "[SDK] [%s] No --resume for %d-message session — compressing" - " full session history (pod affinity issue or first turn after" - " restore failure); target_tokens=%s", + "[SDK] [%s] No --resume for %d-message session — compressing context " + "(source=%s, target_tokens=%s)", session_id[:8], msg_count, + "transcript+gap" if prior_messages is not None else "full-db", target_tokens, ) - compressed, was_compressed = await _compress_messages(prior, target_tokens) + compressed, was_compressed = await _compress_messages(source, target_tokens) history_context = _format_conversation_context(compressed) if history_context: logger.info( @@ -1228,7 +1417,7 @@ async def _build_query_message( "[SDK] [%s] Fallback context empty after compression" " (%d messages) — sending message without history", session_id[:8], - len(prior), + len(source), ) return current_message, False @@ -1994,6 +2183,39 @@ async def _run_stream_attempt( # --- Dispatch adapter responses --- adapter_responses = state.adapter.convert_message(sdk_msg) + + # Pre-create the new assistant message in the session BEFORE + # yielding any events so it survives a GeneratorExit (client + # disconnect) that interrupts the yield loop at StreamStartStep. + # + # Without this, the sequence is: + # tool result saved → intermediate flush → StreamStartStep + # yield → GeneratorExit → finally saves session with + # last_role=tool (the text response was generated but never + # appended because _dispatch_response(StreamTextDelta) was + # skipped). + # + # We only pre-create when: + # 1. Tool results were received this turn (has_tool_results). + # 2. The prior assistant message is already appended + # (has_appended_assistant) — so this is a post-tool turn. + # 3. 
This batch contains StreamTextDelta — text IS coming, so + # we won't leave a spurious empty message for tool-only turns. + # + # Subsequent StreamTextDelta dispatches accumulate content into + # acc.assistant_response in-place (ChatMessage is mutable), so + # the DB record is updated without a second append. + if ( + acc.has_tool_results + and acc.has_appended_assistant + and any(isinstance(r, StreamTextDelta) for r in adapter_responses) + ): + acc.assistant_response = ChatMessage(role="assistant", content="") + acc.accumulated_tool_calls = [] + acc.has_tool_results = False + ctx.session.messages.append(acc.assistant_response) + # acc.has_appended_assistant stays True — placeholder is live + # When StreamFinish is in this batch (ResultMessage), flush any # text buffered by the thinking stripper and inject it as a # StreamTextDelta BEFORE the StreamTextEnd so the Vercel AI SDK @@ -2200,6 +2422,163 @@ async def _seed_transcript( return _seeded, True, len(_prior) +@dataclass +class _RestoreResult: + """Return value from ``_restore_cli_session_for_turn``.""" + + transcript_content: str = "" + transcript_covers_prefix: bool = True + use_resume: bool = False + resume_file: str | None = None + transcript_msg_count: int = 0 + baseline_download: "TranscriptDownload | None" = None + context_messages: "list[ChatMessage] | None" = None + + +async def _restore_cli_session_for_turn( + user_id: str | None, + session_id: str, + session: "ChatSession", + sdk_cwd: str, + transcript_builder: "TranscriptBuilder", + log_prefix: str, +) -> _RestoreResult: + """Download, validate and restore a CLI session for ``--resume`` on this turn. + + Performs a single GCS round-trip to fetch the session bytes + message_count + watermark. Falls back to DB-message reconstruction when GCS has no session + (first turn or upload missed). + + Returns a ``_RestoreResult`` with all transcript-related state ready for the + caller to merge into its local variables. 
+ """ + result = _RestoreResult() + + if not (config.claude_agent_use_resume and user_id and len(session.messages) > 1): + return result + + try: + cli_restore = await download_transcript( + user_id, session_id, log_prefix=log_prefix + ) + except Exception as restore_err: + logger.warning( + "%s CLI session restore failed, continuing without --resume: %s", + log_prefix, + restore_err, + ) + cli_restore = None + + # Only attempt --resume for SDK-written transcripts. + # Baseline-written transcripts use TranscriptBuilder format (synthetic IDs, + # stripped fields) that may not be valid for --resume. + if cli_restore is not None and cli_restore.mode != "sdk": + logger.info( + "%s Transcript written by mode=%r — skipping --resume, " + "will use transcript content + gap for context", + log_prefix, + cli_restore.mode, + ) + result.baseline_download = cli_restore # keep for extract_context_messages + cli_restore = None + + # Validate, strip, and write to disk — delegate to helper to reduce + # function complexity. Writing an invalid/corrupt file to disk then + # falling back to "no --resume" would cause the CLI to fail with + # "Session ID already in use" because the file exists at the expected + # session path, so we validate BEFORE any disk write. + stripped = "" + if cli_restore is not None and sdk_cwd: + stripped, ok = _process_cli_restore( + cli_restore, sdk_cwd, session_id, log_prefix + ) + if not ok: + result.transcript_covers_prefix = False + cli_restore = None + + if cli_restore is None and sdk_cwd: + # Validation failed or GCS returned no session. Delete any + # existing local session file so the CLI doesn't reject the + # session_id with "Session ID already in use". T1 may have + # left a valid file at this path; we clear it so the fallback + # path (session_id= without --resume) can create a new session. 
+ _stale_path = os.path.realpath(cli_session_path(sdk_cwd, session_id)) + if Path(_stale_path).exists() and _stale_path.startswith( + projects_base() + os.sep + ): + try: + Path(_stale_path).unlink() + logger.debug( + "%s Removed stale local CLI session file for clean fallback", + log_prefix, + ) + except OSError as _unlink_err: + logger.debug( + "%s Failed to remove stale local session file: %s", + log_prefix, + _unlink_err, + ) + + if cli_restore is not None: + result.transcript_content = stripped + transcript_builder.load_previous(stripped, log_prefix=log_prefix) + result.use_resume = True + result.resume_file = session_id + result.transcript_msg_count = cli_restore.message_count + return result + + # No valid --resume source (mode="baseline" or no GCS file). + # Build context from transcript content + gap, falling back to full DB. + # extract_context_messages handles both: non-None baseline_download uses + # the compacted transcript + gap; None falls back to all prior DB messages. + context_msgs = extract_context_messages(result.baseline_download, session.messages) + result.context_messages = context_msgs + result.transcript_msg_count = ( + result.baseline_download.message_count + if result.baseline_download is not None + and result.baseline_download.message_count > 0 + else len(session.messages) - 1 + ) + result.transcript_covers_prefix = True + logger.info( + "%s Context built from %s: %d messages (transcript watermark=%d, " + "will inject as <conversation_history>)", + log_prefix, + ( + "baseline transcript + gap" + if result.baseline_download is not None + else "DB fallback" + ), + len(context_msgs), + result.transcript_msg_count, + ) + + # Load baseline transcript content into builder so the upload path has accurate state. 
+ # Also sets result.transcript_content so the _seed_transcript guard in the caller + # (``not transcript_content``) does not overwrite this builder state with a DB + # reconstruction — which would duplicate entries since load_previous appends. + if result.baseline_download is not None: + try: + raw_for_builder = result.baseline_download.content + if isinstance(raw_for_builder, bytes): + raw_for_builder = raw_for_builder.decode("utf-8") + stripped = strip_for_upload(raw_for_builder) + if validate_transcript(stripped): + transcript_builder.load_previous(stripped, log_prefix=log_prefix) + result.transcript_content = stripped + except (UnicodeDecodeError, ValueError, OSError) as _load_err: + # UnicodeDecodeError: non-UTF-8 content; ValueError: malformed JSONL in + # strip_for_upload; OSError: encode/decode I/O failure. Unexpected + # exceptions propagate so programming errors are not silently masked. + logger.debug( + "%s Could not load baseline transcript into builder: %s", + log_prefix, + _load_err, + ) + + return result + + async def stream_chat_completion_sdk( session_id: str, message: str | None = None, @@ -2337,6 +2716,7 @@ async def stream_chat_completion_sdk( turn_cache_creation_tokens = 0 turn_cost_usd: float | None = None graphiti_enabled = False + pre_attempt_msg_count = 0 # Defaults ensure the finally block can always reference these safely even when # an early return (e.g. sdk_cwd error) skips their normal assignment below. 
sdk_model: str | None = None @@ -2393,28 +2773,9 @@ async def stream_chat_completion_sdk( return sandbox - async def _fetch_transcript(): - """Download transcript for --resume if applicable.""" - if not ( - config.claude_agent_use_resume and user_id and len(session.messages) > 1 - ): - return None - try: - return await download_transcript( - user_id, session_id, log_prefix=log_prefix - ) - except Exception as transcript_err: - logger.warning( - "%s Transcript download failed, continuing without --resume: %s", - log_prefix, - transcript_err, - ) - return None - - e2b_sandbox, (base_system_prompt, understanding), dl = await asyncio.gather( + e2b_sandbox, (base_system_prompt, understanding) = await asyncio.gather( _setup_e2b(), _build_system_prompt(user_id if not has_history else None), - _fetch_transcript(), ) use_e2b = e2b_sandbox is not None @@ -2439,95 +2800,17 @@ async def stream_chat_completion_sdk( warm_ctx = await fetch_warm_context(user_id, message or "") or "" - # Process transcript download result and restore CLI native session. - # The CLI native session file (uploaded after each turn) is the - # source of truth for --resume. Our custom JSONL (TranscriptEntry) - # is loaded into the builder for future upload_transcript calls. - transcript_msg_count = 0 - if dl: - is_valid = validate_transcript(dl.content) - dl_lines = dl.content.strip().split("\n") if dl.content else [] - logger.info( - "%s Downloaded transcript: %dB, %d lines, msg_count=%d, valid=%s", - log_prefix, - len(dl.content), - len(dl_lines), - dl.message_count, - is_valid, - ) - if is_valid: - # Load previous FULL context into builder for state tracking. - transcript_content = dl.content - transcript_builder.load_previous(dl.content, log_prefix=log_prefix) - # Restore CLI's native session file so --resume session_id works. - # Falls back gracefully if not available (first turn or upload missed). 
- # user_id is guaranteed non-None here: _fetch_transcript only sets dl - # when `config.claude_agent_use_resume and user_id` is truthy. - cli_restored = user_id is not None and await restore_cli_session( - user_id, session_id, sdk_cwd, log_prefix=log_prefix - ) - if cli_restored: - use_resume = True - resume_file = session_id # CLI --resume expects UUID, not file path - transcript_msg_count = dl.message_count - logger.info( - "%s Using --resume %s (%dB transcript, msg_count=%d)", - log_prefix, - session_id[:8], - len(dl.content), - transcript_msg_count, - ) - else: - # Builder loaded but CLI native session not available. - # --resume will not be used this turn; upload after turn - # will seed the native session for the next turn. - # - # Still record transcript_msg_count so _build_query_message - # can use the transcript-aware gap path (inject only new - # messages since the transcript end) instead of compressing - # the full DB history. This avoids prompt-too-long on - # large sessions where the CLI session is temporarily - # unavailable (e.g. mixed-version rolling deployment). - transcript_msg_count = dl.message_count - logger.info( - "%s CLI session not restored — running without" - " --resume this turn (transcript_msg_count=%d for" - " gap-aware fallback)", - log_prefix, - transcript_msg_count, - ) - else: - logger.warning("%s Transcript downloaded but invalid", log_prefix) - transcript_covers_prefix = False - elif config.claude_agent_use_resume and user_id and len(session.messages) > 1: - # No transcript in storage — reconstruct from DB messages as a - # last-resort fallback (e.g., first turn after a crash or transition). - # This path loses tool call IDs and structural fidelity but prevents - # a completely context-free response for established sessions. - prior = session.messages[:-1] - reconstructed = _session_messages_to_transcript(prior) - if reconstructed: - # Populate builder only; no --resume since there is no CLI - # native session to restore. 
The transcript builder state is - # still useful for the upload that seeds future native sessions. - transcript_content = reconstructed - transcript_builder.load_previous(reconstructed, log_prefix=log_prefix) - transcript_msg_count = len(prior) - transcript_covers_prefix = True - logger.info( - "%s Reconstructed transcript from %d session messages " - "(no CLI native session — running without --resume this turn)", - log_prefix, - len(prior), - ) - else: - logger.warning( - "%s No transcript available and reconstruction produced empty" - " output (%d messages in session)", - log_prefix, - len(session.messages), - ) - transcript_covers_prefix = False + # Restore CLI session — single GCS round-trip covers both --resume and builder state. + # message_count watermark lives in the companion .meta.json alongside the session file. + _restore = await _restore_cli_session_for_turn( + user_id, session_id, session, sdk_cwd, transcript_builder, log_prefix + ) + transcript_content = _restore.transcript_content + transcript_covers_prefix = _restore.transcript_covers_prefix + use_resume = _restore.use_resume + resume_file = _restore.resume_file + transcript_msg_count = _restore.transcript_msg_count + restore_context_messages = _restore.context_messages yield StreamStart(messageId=message_id, sessionId=session_id) @@ -2646,14 +2929,14 @@ async def stream_chat_completion_sdk( else: # Set session_id whenever NOT resuming so the CLI writes the # native session file to a predictable path for - # upload_cli_session() after the turn. This covers: + # upload_transcript() after the turn. This covers: # • T1 fresh: no prior history, first SDK turn. # • Mode-switch T1: has_history=True (prior baseline turns in # DB) but no CLI session file was ever uploaded — the CLI has # never been invoked with this session_id before. # • T2+ without --resume (restore failed): no session file was - # restored to local storage (restore_cli_session returned - # False), so no conflict with an existing file. 
+ # restored to local storage (download_transcript returned + # None), so no conflict with an existing file. # When --resume is active the session_id is already implied by # the resume file; passing it again would be rejected by the CLI. sdk_options_kwargs["session_id"] = session_id @@ -2746,6 +3029,7 @@ async def stream_chat_completion_sdk( use_resume, transcript_msg_count, session_id, + prior_messages=restore_context_messages, ) # If files are attached, prepare them: images become vision # content blocks in the user message, other files go to sdk_cwd. @@ -2755,6 +3039,9 @@ async def stream_chat_completion_sdk( if attachments.hint: query_message = f"{query_message}\n\n{attachments.hint}" + # warm_ctx is injected via inject_user_context above (warm_ctx= kwarg). + # No separate injection needed here. + # When running without --resume and no prior transcript in storage, # seed the transcript builder from compressed DB messages so that # upload_transcript saves a compact version for future turns. @@ -2872,15 +3159,16 @@ async def stream_chat_completion_sdk( elif "session_id" in sdk_options_kwargs: # Initial invocation used session_id (T1 or mode-switch # T1): keep it so the CLI writes the session file to the - # predictable path for upload_cli_session(). Storage is + # predictable path for upload_transcript(). Storage is # ephemeral per invocation, so no "Session ID already in # use" conflict occurs — no prior file was restored. sdk_options_kwargs_retry.pop("resume", None) sdk_options_kwargs_retry["session_id"] = session_id else: - # T2+ retry without --resume: do not pass --session-id. - # The T1 session file already exists at that path; re-using - # the same ID would fail with "Session ID already in use". + # T2+ retry without --resume: initial invocation used + # --resume, which restored the T1 session file to local + # storage. Re-using session_id without --resume would + # fail with "Session ID already in use". 
sdk_options_kwargs_retry.pop("resume", None) sdk_options_kwargs_retry.pop("session_id", None) # Recompute system_prompt for retry — ctx.use_resume may have @@ -2894,6 +3182,10 @@ async def stream_chat_completion_sdk( system_prompt, cross_user_cache=_cross_user_retry ) state.options = ClaudeAgentOptions(**sdk_options_kwargs_retry) # type: ignore[arg-type] # dynamic kwargs + # Retry intentionally omits prior_messages (transcript+gap context) and + # falls back to full session.messages[:-1] from DB — the authoritative + # source. transcript+gap is an optimisation for the first attempt only; + # on retry the extra overhead of full-DB context is acceptable. state.query_message, state.was_compacted = await _build_query_message( current_message, session, @@ -2904,6 +3196,8 @@ async def stream_chat_completion_sdk( ) if attachments.hint: state.query_message = f"{state.query_message}\n\n{attachments.hint}" + # warm_ctx is already baked into current_message via + # inject_user_context — no separate injection needed. state.adapter = SDKResponseAdapter( message_id=message_id, session_id=session_id ) @@ -3308,92 +3602,42 @@ async def stream_chat_completion_sdk( if graphiti_enabled and user_id and message and is_user_message: from ..graphiti.ingest import enqueue_conversation_turn + # Extract last assistant message from THIS TURN only (not all + # session history) to avoid distilling stale content from prior + # turns when the current turn errors before producing output. 
+ _this_turn_msgs = ( + session.messages[pre_attempt_msg_count:] if session else [] + ) + _assistant_msgs = [ + m.content or "" for m in _this_turn_msgs if m.role == "assistant" + ] + _last_assistant = _assistant_msgs[-1] if _assistant_msgs else "" + _ingest_task = asyncio.create_task( - enqueue_conversation_turn(user_id, session_id, message) + enqueue_conversation_turn( + user_id, session_id, message, assistant_msg=_last_assistant + ) ) _background_tasks.add(_ingest_task) _ingest_task.add_done_callback(_background_tasks.discard) - # --- Upload transcript for next-turn --resume --- - # TranscriptBuilder is the single source of truth. It mirrors the - # CLI's active context: on compaction, replace_entries() syncs it - # with the compacted session file. No CLI file read needed here. - if skip_transcript_upload: - logger.warning( - "%s Skipping transcript upload — transcript was dropped " - "during prompt-too-long recovery", - log_prefix, - ) - elif ( - config.claude_agent_use_resume - and user_id - and session is not None - and state is not None - ): - try: - transcript_upload_content = state.transcript_builder.to_jsonl() - entry_count = state.transcript_builder.entry_count - - if not transcript_upload_content: - logger.warning( - "%s No transcript to upload (builder empty)", log_prefix - ) - elif not validate_transcript(transcript_upload_content): - logger.warning( - "%s Transcript invalid, skipping upload (entries=%d)", - log_prefix, - entry_count, - ) - elif not transcript_covers_prefix: - logger.warning( - "%s Skipping transcript upload — builder does not " - "cover full session prefix (entries=%d, session=%d)", - log_prefix, - entry_count, - len(session.messages), - ) - else: - logger.info( - "%s Uploading transcript (entries=%d, bytes=%d)", - log_prefix, - entry_count, - len(transcript_upload_content), - ) - await asyncio.shield( - upload_transcript( - user_id=user_id, - session_id=session_id, - content=transcript_upload_content, - 
message_count=len(session.messages), - log_prefix=log_prefix, - ) - ) - except Exception as upload_err: - logger.error( - "%s Transcript upload failed in finally: %s", - log_prefix, - upload_err, - exc_info=True, - ) - # --- Upload CLI native session file for cross-pod --resume --- # The CLI writes its native session JSONL after each turn completes. - # Uploading it here enables --resume on any pod (no pod affinity needed). - # Runs after upload_transcript so both are available for the next turn. - # asyncio.shield: same pattern as upload_transcript above — if the - # outer finally-block coroutine is cancelled while awaiting shield, - # the CancelledError propagates (BaseException, not caught by - # `except Exception`) letting the caller handle cancellation, while - # the shielded inner coroutine continues running to completion so the - # upload is not lost. This is intentional and matches the pattern - # used for upload_transcript immediately above. + # The companion .meta.json carries the message_count watermark and mode + # so the next turn can restore both --resume context and gap-fill state + # in a single GCS round-trip via download_transcript(). + # asyncio.shield: if the outer finally-block coroutine is cancelled + # while awaiting shield, the CancelledError propagates (BaseException, + # not caught by `except Exception`) letting the caller handle + # cancellation, while the shielded inner coroutine continues running + # to completion so the upload is not lost. # # NOTE: upload is attempted regardless of state.use_resume — even when # this turn ran without --resume (restore failed or first T2+ on a new # pod), the T1 session file at the expected path may still be present # and should be re-uploaded so the next turn can resume from it. - # upload_cli_session silently skips when the file is absent, so this is - # always safe. + # _read_cli_session_from_disk returns None when the file is absent, so + # this is always safe. 
# # Intentionally NOT gated on skip_transcript_upload: that flag is set # when our custom JSONL transcript is dropped (transcript_lost=True on @@ -3419,14 +3663,36 @@ async def stream_chat_completion_sdk( skip_transcript_upload, ) try: - await asyncio.shield( - upload_cli_session( - user_id=user_id, - session_id=session_id, - sdk_cwd=sdk_cwd, - log_prefix=log_prefix, - ) + # Read the CLI's native session file from disk (written by the CLI + # after the turn), then upload the bytes to GCS. + _cli_content = _read_cli_session_from_disk( + sdk_cwd, session_id, log_prefix ) + if _cli_content: + # Watermark = number of DB messages this transcript covers. + # len(session.messages) is accurate: the CLI session file + # was just written after the turn completed, so it covers + # all messages through this turn. Any gap from a prior + # missed upload was already detected by detect_gap and + # injected as context, so the model has the full history. + # + # Previously this used _final_tmsg_count + 2, which + # under-counted for tool-use turns (delta = 2 + 2*N_tool_calls), + # causing persistent spurious gap-fills on every subsequent turn. + # That concern was addressed by the inflated-watermark fix + # (using the GCS watermark as the anchor for gap detection), + # which makes len(session.messages) safe to use here. 
+ _jsonl_covered = len(session.messages) + await asyncio.shield( + upload_transcript( + user_id=user_id, + session_id=session_id, + content=_cli_content, + message_count=_jsonl_covered, + mode="sdk", + log_prefix=log_prefix, + ) + ) except Exception as cli_upload_err: logger.warning( "%s CLI session upload failed in finally: %s", diff --git a/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py b/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py index 470858dc55..3b919c6036 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py @@ -22,6 +22,7 @@ from .service import ( _iter_sdk_messages, _normalize_model_name, _reduce_context, + _restore_cli_session_for_turn, _TokenUsage, ) @@ -392,7 +393,9 @@ class TestNormalizeModelName: def test_sonnet_openrouter_model(self): """Sonnet model as stored in config (OpenRouter-prefixed) strips cleanly.""" - assert _normalize_model_name("anthropic/claude-sonnet-4") == "claude-sonnet-4" + assert ( + _normalize_model_name("anthropic/claude-sonnet-4-6") == "claude-sonnet-4-6" + ) # --------------------------------------------------------------------------- @@ -613,3 +616,340 @@ class TestSdkSessionIdSelection: ) assert retry.get("resume") == self.SESSION_ID assert "session_id" not in retry + + +# --------------------------------------------------------------------------- +# _restore_cli_session_for_turn — mode check +# --------------------------------------------------------------------------- + + +class TestRestoreCliSessionModeCheck: + """SDK skips --resume when the transcript was written by the baseline mode.""" + + @pytest.mark.asyncio + async def test_baseline_mode_transcript_skips_gcs_content(self, tmp_path): + """A transcript with mode='baseline' must not be used as the --resume source. 
+ + The mode check discards the GCS baseline content and falls back to DB + reconstruction from session.messages instead. + """ + from datetime import UTC, datetime + + from backend.copilot.model import ChatMessage, ChatSession + from backend.copilot.transcript import TranscriptDownload + from backend.copilot.transcript_builder import TranscriptBuilder + + session = ChatSession( + session_id="test-session", + user_id="user-1", + messages=[ + ChatMessage(role="user", content="hello-unique-marker"), + ChatMessage(role="assistant", content="world-unique-marker"), + ChatMessage(role="user", content="follow up"), + ], + title="test", + usage=[], + started_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + builder = TranscriptBuilder() + # Baseline content with a sentinel that must NOT appear in the final transcript + baseline_restore = TranscriptDownload( + content=b'{"type":"user","uuid":"bad-uuid","message":{"role":"user","content":"BASELINE_SENTINEL"}}\n', + message_count=1, + mode="baseline", + ) + + import backend.copilot.sdk.service as _svc_mod + + download_mock = AsyncMock(return_value=baseline_restore) + with ( + patch( + "backend.copilot.sdk.service.download_transcript", + new=download_mock, + ), + patch.object(_svc_mod.config, "claude_agent_use_resume", True), + ): + result = await _restore_cli_session_for_turn( + user_id="user-1", + session_id="test-session", + session=session, + sdk_cwd=str(tmp_path), + transcript_builder=builder, + log_prefix="[Test]", + ) + + # download_transcript was called (attempted GCS restore) + download_mock.assert_awaited_once() + # use_resume must be False — baseline transcripts cannot be used with --resume + assert result.use_resume is False + # context_messages must be populated — new behaviour uses transcript content + gap + # instead of full DB reconstruction. + assert result.context_messages is not None + # The baseline transcript has 1 user message (BASELINE_SENTINEL). 
+ # Watermark=1 but position 0 is 'user', not 'assistant', so detect_gap returns []. + # Result: 1 message from transcript, no gap. + assert len(result.context_messages) == 1 + assert "BASELINE_SENTINEL" in (result.context_messages[0].content or "") + + @pytest.mark.asyncio + async def test_sdk_mode_transcript_allows_resume(self, tmp_path): + """A valid SDK-written transcript is accepted for --resume.""" + import json as stdlib_json + from datetime import UTC, datetime + + from backend.copilot.model import ChatMessage, ChatSession + from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload + from backend.copilot.transcript_builder import TranscriptBuilder + + lines = [ + stdlib_json.dumps( + { + "type": "user", + "uuid": "uid-0", + "parentUuid": "", + "message": {"role": "user", "content": "hi"}, + } + ), + stdlib_json.dumps( + { + "type": "assistant", + "uuid": "uid-1", + "parentUuid": "uid-0", + "message": { + "role": "assistant", + "id": "msg_1", + "model": "test", + "type": "message", + "stop_reason": STOP_REASON_END_TURN, + "content": [{"type": "text", "text": "hello"}], + }, + } + ), + ] + content = ("\n".join(lines) + "\n").encode("utf-8") + + session = ChatSession( + session_id="test-session", + user_id="user-1", + messages=[ + ChatMessage(role="user", content="hi"), + ChatMessage(role="assistant", content="hello"), + ChatMessage(role="user", content="follow up"), + ], + title="test", + usage=[], + started_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + builder = TranscriptBuilder() + sdk_restore = TranscriptDownload( + content=content, + message_count=2, + mode="sdk", + ) + + import backend.copilot.sdk.service as _svc_mod + + with ( + patch( + "backend.copilot.sdk.service.download_transcript", + new=AsyncMock(return_value=sdk_restore), + ), + patch.object(_svc_mod.config, "claude_agent_use_resume", True), + ): + result = await _restore_cli_session_for_turn( + user_id="user-1", + session_id="test-session", + 
session=session, + sdk_cwd=str(tmp_path), + transcript_builder=builder, + log_prefix="[Test]", + ) + + assert result.use_resume is True + + @pytest.mark.asyncio + async def test_baseline_mode_context_messages_from_transcript_content( + self, tmp_path + ): + """mode='baseline' → context_messages populated from transcript content + gap. + + When a baseline-mode transcript exists, extract_context_messages converts + the JSONL content to ChatMessage objects and returns them in context_messages. + use_resume must remain False. + """ + import json as stdlib_json + from datetime import UTC, datetime + + from backend.copilot.model import ChatMessage, ChatSession + from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload + from backend.copilot.transcript_builder import TranscriptBuilder + + # Build a minimal valid JSONL transcript with 2 messages + lines = [ + stdlib_json.dumps( + { + "type": "user", + "uuid": "uid-0", + "parentUuid": "", + "message": {"role": "user", "content": "TRANSCRIPT_USER"}, + } + ), + stdlib_json.dumps( + { + "type": "assistant", + "uuid": "uid-1", + "parentUuid": "uid-0", + "message": { + "role": "assistant", + "id": "msg_1", + "model": "test", + "type": "message", + "stop_reason": STOP_REASON_END_TURN, + "content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT"}], + }, + } + ), + ] + content = ("\n".join(lines) + "\n").encode("utf-8") + + session = ChatSession( + session_id="test-session", + user_id="user-1", + messages=[ + ChatMessage(role="user", content="DB_USER"), + ChatMessage(role="assistant", content="DB_ASSISTANT"), + ChatMessage(role="user", content="current turn"), + ], + title="test", + usage=[], + started_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + builder = TranscriptBuilder() + baseline_restore = TranscriptDownload( + content=content, + message_count=2, + mode="baseline", + ) + + import backend.copilot.sdk.service as _svc_mod + + with ( + patch( + 
"backend.copilot.sdk.service.download_transcript", + new=AsyncMock(return_value=baseline_restore), + ), + patch.object(_svc_mod.config, "claude_agent_use_resume", True), + ): + result = await _restore_cli_session_for_turn( + user_id="user-1", + session_id="test-session", + session=session, + sdk_cwd=str(tmp_path), + transcript_builder=builder, + log_prefix="[Test]", + ) + + assert result.use_resume is False + assert result.context_messages is not None + # Transcript content has 2 messages, no gap (watermark=2, session prior=2) + assert len(result.context_messages) == 2 + assert result.context_messages[0].role == "user" + assert result.context_messages[1].role == "assistant" + assert "TRANSCRIPT_ASSISTANT" in (result.context_messages[1].content or "") + # transcript_content must be non-empty so the _seed_transcript guard in + # stream_chat_completion_sdk skips DB reconstruction (which would duplicate + # builder entries since load_previous appends). + assert result.transcript_content != "" + + @pytest.mark.asyncio + async def test_baseline_mode_gap_present_context_includes_gap(self, tmp_path): + """mode='baseline' + gap → context_messages includes transcript msgs and gap.""" + import json as stdlib_json + from datetime import UTC, datetime + + from backend.copilot.model import ChatMessage, ChatSession + from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload + from backend.copilot.transcript_builder import TranscriptBuilder + + # Transcript covers only 2 messages; session has 4 prior + current turn + lines = [ + stdlib_json.dumps( + { + "type": "user", + "uuid": "uid-0", + "parentUuid": "", + "message": {"role": "user", "content": "TRANSCRIPT_USER_0"}, + } + ), + stdlib_json.dumps( + { + "type": "assistant", + "uuid": "uid-1", + "parentUuid": "uid-0", + "message": { + "role": "assistant", + "id": "msg_1", + "model": "test", + "type": "message", + "stop_reason": STOP_REASON_END_TURN, + "content": [{"type": "text", "text": 
"TRANSCRIPT_ASSISTANT_1"}], + }, + } + ), + ] + content = ("\n".join(lines) + "\n").encode("utf-8") + + session = ChatSession( + session_id="test-session", + user_id="user-1", + messages=[ + ChatMessage(role="user", content="DB_USER_0"), + ChatMessage(role="assistant", content="DB_ASSISTANT_1"), + ChatMessage(role="user", content="GAP_USER_2"), + ChatMessage(role="assistant", content="GAP_ASSISTANT_3"), + ChatMessage(role="user", content="current turn"), + ], + title="test", + usage=[], + started_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + builder = TranscriptBuilder() + baseline_restore = TranscriptDownload( + content=content, + message_count=2, # watermark=2; session has 4 prior → gap of 2 + mode="baseline", + ) + + import backend.copilot.sdk.service as _svc_mod + + with ( + patch( + "backend.copilot.sdk.service.download_transcript", + new=AsyncMock(return_value=baseline_restore), + ), + patch.object(_svc_mod.config, "claude_agent_use_resume", True), + ): + result = await _restore_cli_session_for_turn( + user_id="user-1", + session_id="test-session", + session=session, + sdk_cwd=str(tmp_path), + transcript_builder=builder, + log_prefix="[Test]", + ) + + assert result.use_resume is False + assert result.context_messages is not None + # 2 from transcript + 2 gap messages = 4 total + assert len(result.context_messages) == 4 + roles = [m.role for m in result.context_messages] + assert roles == ["user", "assistant", "user", "assistant"] + # Gap messages come from DB (ChatMessage objects) + gap_user = result.context_messages[2] + gap_asst = result.context_messages[3] + assert gap_user.content == "GAP_USER_2" + assert gap_asst.content == "GAP_ASSISTANT_3" diff --git a/autogpt_platform/backend/backend/copilot/sdk/session_persistence_test.py b/autogpt_platform/backend/backend/copilot/sdk/session_persistence_test.py new file mode 100644 index 0000000000..ea7b128927 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/sdk/session_persistence_test.py 
@@ -0,0 +1,217 @@ +"""Tests for the pre-create assistant message logic that prevents +last_role=tool after client disconnect. + +Reproduces the bug where: + 1. Tool result is saved by intermediate flush → last_role=tool + 2. SDK generates a text response + 3. GeneratorExit at StreamStartStep yield (client disconnect) + 4. _dispatch_response(StreamTextDelta) is never called + 5. Session saved with last_role=tool instead of last_role=assistant + +The fix: before yielding any events, pre-create the assistant message in +ctx.session.messages when has_tool_results=True and a StreamTextDelta is +present in adapter_responses. This test verifies the resulting accumulator +state allows correct content accumulation by _dispatch_response. +""" + +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +from backend.copilot.model import ChatMessage, ChatSession +from backend.copilot.response_model import StreamStartStep, StreamTextDelta +from backend.copilot.sdk.service import _dispatch_response, _StreamAccumulator + +_NOW = datetime(2024, 1, 1, tzinfo=timezone.utc) + + +def _make_session() -> ChatSession: + return ChatSession( + session_id="test", + user_id="test-user", + title="test", + messages=[], + usage=[], + started_at=_NOW, + updated_at=_NOW, + ) + + +def _make_ctx(session: ChatSession | None = None) -> MagicMock: + ctx = MagicMock() + ctx.session = session or _make_session() + ctx.log_prefix = "[test]" + return ctx + + +def _make_state() -> MagicMock: + state = MagicMock() + state.transcript_builder = MagicMock() + return state + + +def _simulate_pre_create(acc: _StreamAccumulator, ctx: MagicMock) -> None: + """Mirror the pre-create block from _run_stream_attempt so tests + can verify its effect without invoking the full async generator. + + Keep in sync with the block in service.py _run_stream_attempt + (search: "Pre-create the new assistant message"). 
+ """ + acc.assistant_response = ChatMessage(role="assistant", content="") + acc.accumulated_tool_calls = [] + acc.has_tool_results = False + ctx.session.messages.append(acc.assistant_response) + # acc.has_appended_assistant stays True + + +class TestPreCreateAssistantMessage: + """Verify that the pre-create logic correctly seeds the session message + and that subsequent _dispatch_response(StreamTextDelta) accumulates + content in-place without a double-append.""" + + def test_pre_create_adds_message_to_session(self) -> None: + """After pre-create, session has one assistant message.""" + session = _make_session() + ctx = _make_ctx(session) + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[], + has_appended_assistant=True, + has_tool_results=True, + ) + + _simulate_pre_create(acc, ctx) + + assert len(session.messages) == 1 + assert session.messages[-1].role == "assistant" + assert session.messages[-1].content == "" + + def test_pre_create_resets_tool_result_flag(self) -> None: + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[], + has_appended_assistant=True, + has_tool_results=True, + ) + ctx = _make_ctx() + _simulate_pre_create(acc, ctx) + + assert acc.has_tool_results is False + + def test_pre_create_resets_accumulated_tool_calls(self) -> None: + existing_call = { + "id": "call_1", + "type": "function", + "function": {"name": "bash"}, + } + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[existing_call], + has_appended_assistant=True, + has_tool_results=True, + ) + ctx = _make_ctx() + _simulate_pre_create(acc, ctx) + + assert acc.accumulated_tool_calls == [] + + def test_text_delta_accumulates_in_preexisting_message(self) -> None: + """StreamTextDelta after pre-create updates the already-appended message + in-place — no double-append.""" + session = 
_make_session() + ctx = _make_ctx(session) + state = _make_state() + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[], + has_appended_assistant=True, + has_tool_results=True, + ) + + _simulate_pre_create(acc, ctx) + assert len(session.messages) == 1 + + # Simulate the first text delta arriving after pre-create + delta = StreamTextDelta(id="t1", delta="Hello world") + _dispatch_response(delta, acc, ctx, state, False, "[test]") + + # Still only one message (no double-append) + assert len(session.messages) == 1 + # Content accumulated in the pre-created message + assert session.messages[-1].content == "Hello world" + assert session.messages[-1].role == "assistant" + + def test_subsequent_deltas_append_to_content(self) -> None: + """Multiple deltas build up the full response text.""" + session = _make_session() + ctx = _make_ctx(session) + state = _make_state() + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[], + has_appended_assistant=True, + has_tool_results=True, + ) + + _simulate_pre_create(acc, ctx) + + for word in ["You're ", "right ", "about ", "that."]: + _dispatch_response( + StreamTextDelta(id="t1", delta=word), acc, ctx, state, False, "[test]" + ) + + assert len(session.messages) == 1 + assert session.messages[-1].content == "You're right about that." 
+ + def test_pre_create_not_triggered_without_tool_results(self) -> None: + """Pre-create condition requires has_tool_results=True; no-op otherwise.""" + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[], + has_appended_assistant=True, + has_tool_results=False, # no prior tool results + ) + ctx = _make_ctx() + + # Condition is False — simulate: do nothing + if acc.has_tool_results and acc.has_appended_assistant: + _simulate_pre_create(acc, ctx) + + assert len(ctx.session.messages) == 0 + + def test_pre_create_not_triggered_when_not_yet_appended(self) -> None: + """Pre-create requires has_appended_assistant=True.""" + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[], + has_appended_assistant=False, # first turn, nothing appended yet + has_tool_results=True, + ) + ctx = _make_ctx() + + if acc.has_tool_results and acc.has_appended_assistant: + _simulate_pre_create(acc, ctx) + + assert len(ctx.session.messages) == 0 + + def test_pre_create_not_triggered_without_text_delta(self) -> None: + """Pre-create is skipped when adapter_responses has no StreamTextDelta + (e.g. a tool-only batch). 
Verifies the third guard condition.""" + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[], + has_appended_assistant=True, + has_tool_results=True, + ) + ctx = _make_ctx() + adapter_responses = [StreamStartStep()] # no StreamTextDelta + + if ( + acc.has_tool_results + and acc.has_appended_assistant + and any(isinstance(r, StreamTextDelta) for r in adapter_responses) + ): + _simulate_pre_create(acc, ctx) + + assert len(ctx.session.messages) == 0 diff --git a/autogpt_platform/backend/backend/copilot/sdk/test_transcript_watermark.py b/autogpt_platform/backend/backend/copilot/sdk/test_transcript_watermark.py new file mode 100644 index 0000000000..592dbde82f --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/sdk/test_transcript_watermark.py @@ -0,0 +1,95 @@ +"""Unit tests for the watermark-fix logic in stream_chat_completion_sdk. + +The fix is at the upload step: when use_resume=True and transcript_msg_count>0 +we set the JSONL coverage watermark to transcript_msg_count + 2 (the pair just +recorded) instead of len(session.messages). This prevents the "inflated +watermark" bug where a stale JSONL in GCS could hide missing context from +future gap-fill checks. +""" + +from __future__ import annotations + + +def _compute_jsonl_covered( + use_resume: bool, + transcript_msg_count: int, + session_msg_count: int, +) -> int: + """Mirror the watermark computation from ``stream_chat_completion_sdk``. + + Extracted here so we can unit-test it independently without invoking the + full streaming stack. + """ + if use_resume and transcript_msg_count > 0: + return transcript_msg_count + 2 + return session_msg_count + + +class TestWatermarkFix: + """Watermark computation logic — mirrors the finally-block in SDK service.""" + + def test_inflated_watermark_triggers_gap_fill(self): + """Stale JSONL (T12) with high watermark (46) → after fix, watermark=14. 
+ + Before fix: watermark=46 → next turn's gap check (transcript_msg_count < db-1) + never fires because 46 >= 47-1=46, so context loss is silent. + After fix: watermark = 12 + 2 = 14 → gap check fires (14 < 46) and + the model receives the missing turns. + """ + # Simulate: use_resume=True, transcript covered T12 (12 msgs), DB now has 47 + use_resume = True + transcript_msg_count = 12 + session_msg_count = 47 # DB count (what old code used to set watermark) + + watermark = _compute_jsonl_covered( + use_resume, transcript_msg_count, session_msg_count + ) + + assert watermark == 14 # 12 + 2, NOT 47 + # Verify: the gap check would fire on next turn + # next-turn check: transcript_msg_count < msg_count - 1 → 14 < 47-1=46 → True + assert watermark < session_msg_count - 1 + + def test_no_false_positive_when_transcript_current(self): + """Transcript current (watermark=46, DB=47) → gap stays 0. + + When the JSONL actually covers T46 (the most recent assistant turn), + uploading watermark=46+2=48 means next turn's gap check sees + 48 >= 48-1=47 → no gap. Correct. 
+ """ + use_resume = True + transcript_msg_count = 46 + session_msg_count = 47 + + watermark = _compute_jsonl_covered( + use_resume, transcript_msg_count, session_msg_count + ) + + assert watermark == 48 # 46 + 2 + # Next turn: session has 48 msgs, watermark=48 → 48 >= 48-1=47 → no gap + next_turn_session = 48 + assert watermark >= next_turn_session - 1 + + def test_fresh_session_falls_back_to_db_count(self): + """use_resume=False → watermark = len(session.messages) (original behaviour).""" + use_resume = False + transcript_msg_count = 0 + session_msg_count = 3 + + watermark = _compute_jsonl_covered( + use_resume, transcript_msg_count, session_msg_count + ) + + assert watermark == session_msg_count + + def test_old_format_meta_zero_count_falls_back_to_db(self): + """transcript_msg_count=0 (old-format meta with no count field) → DB fallback.""" + use_resume = True + transcript_msg_count = 0 # old-format meta or not-yet-set + session_msg_count = 10 + + watermark = _compute_jsonl_covered( + use_resume, transcript_msg_count, session_msg_count + ) + + assert watermark == session_msg_count diff --git a/autogpt_platform/backend/backend/copilot/sdk/transcript.py b/autogpt_platform/backend/backend/copilot/sdk/transcript.py index cfbf01a466..d5cf3c3e94 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/transcript.py +++ b/autogpt_platform/backend/backend/copilot/sdk/transcript.py @@ -12,18 +12,20 @@ from backend.copilot.transcript import ( ENTRY_TYPE_MESSAGE, STOP_REASON_END_TURN, STRIPPABLE_TYPES, - TRANSCRIPT_STORAGE_PREFIX, TranscriptDownload, + TranscriptMode, cleanup_stale_project_dirs, + cli_session_path, compact_transcript, delete_transcript, + detect_gap, download_transcript, + extract_context_messages, + projects_base, read_compacted_entries, - restore_cli_session, strip_for_upload, strip_progress_entries, strip_stale_thinking_blocks, - upload_cli_session, upload_transcript, validate_transcript, write_transcript_to_tempfile, @@ -34,18 +36,20 @@ __all__ = [ 
"ENTRY_TYPE_MESSAGE", "STOP_REASON_END_TURN", "STRIPPABLE_TYPES", - "TRANSCRIPT_STORAGE_PREFIX", "TranscriptDownload", + "TranscriptMode", "cleanup_stale_project_dirs", + "cli_session_path", "compact_transcript", "delete_transcript", + "detect_gap", "download_transcript", + "extract_context_messages", + "projects_base", "read_compacted_entries", - "restore_cli_session", "strip_for_upload", "strip_progress_entries", "strip_stale_thinking_blocks", - "upload_cli_session", "upload_transcript", "validate_transcript", "write_transcript_to_tempfile", diff --git a/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py b/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py index 14e404a994..f8e1608094 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py @@ -297,8 +297,8 @@ class TestStripProgressEntries: class TestDeleteTranscript: @pytest.mark.asyncio - async def test_deletes_both_jsonl_and_meta(self): - """delete_transcript removes both the .jsonl and .meta.json files.""" + async def test_deletes_cli_session_and_meta(self): + """delete_transcript removes the CLI session .jsonl and .meta.json.""" mock_storage = AsyncMock() mock_storage.delete = AsyncMock() @@ -309,7 +309,7 @@ class TestDeleteTranscript: ): await delete_transcript("user-123", "session-456") - assert mock_storage.delete.call_count == 3 + assert mock_storage.delete.call_count == 2 paths = [call.args[0] for call in mock_storage.delete.call_args_list] assert any(p.endswith(".jsonl") for p in paths) assert any(p.endswith(".meta.json") for p in paths) @@ -319,7 +319,7 @@ class TestDeleteTranscript: """If .jsonl delete fails, .meta.json delete is still attempted.""" mock_storage = AsyncMock() mock_storage.delete = AsyncMock( - side_effect=[Exception("jsonl delete failed"), None, None] + side_effect=[Exception("jsonl delete failed"), None] ) with patch( @@ -330,14 +330,14 @@ class TestDeleteTranscript: 
# Should not raise await delete_transcript("user-123", "session-456") - assert mock_storage.delete.call_count == 3 + assert mock_storage.delete.call_count == 2 @pytest.mark.asyncio async def test_handles_meta_delete_failure(self): """If .meta.json delete fails, no exception propagates.""" mock_storage = AsyncMock() mock_storage.delete = AsyncMock( - side_effect=[None, Exception("meta delete failed"), None] + side_effect=[None, Exception("meta delete failed")] ) with patch( @@ -1015,7 +1015,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1044,7 +1044,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1070,7 +1070,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1096,7 +1096,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1118,7 +1118,7 @@ class TestCleanupStaleProjectDirs: nonexistent = str(tmp_path / "does-not-exist" / "projects") monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: nonexistent, ) @@ -1137,7 +1137,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", 
lambda: str(projects_dir), ) @@ -1165,7 +1165,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1189,7 +1189,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1368,3 +1368,172 @@ class TestStripStaleThinkingBlocks: # Both entries of last turn (msg_last) preserved assert lines[1]["message"]["content"][0]["type"] == "thinking" assert lines[2]["message"]["content"][0]["type"] == "text" + + +class TestProcessCliRestore: + """``_process_cli_restore`` validates, strips, and writes CLI session to disk.""" + + def test_writes_stripped_bytes_not_raw(self, tmp_path): + """Stripped bytes (not raw bytes) must be written to disk for --resume.""" + import os + import re + from pathlib import Path + from unittest.mock import patch + + from backend.copilot.sdk.service import _process_cli_restore + from backend.copilot.transcript import TranscriptDownload + + session_id = "12345678-0000-0000-0000-abcdef000001" + sdk_cwd = str(tmp_path) + projects_base_dir = str(tmp_path) + + # Build raw content with a strippable progress entry + a valid user/assistant pair + raw_content = ( + '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n' + '{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n' + '{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n' + ) + raw_bytes = raw_content.encode("utf-8") + restore = TranscriptDownload(content=raw_bytes, message_count=2, mode="sdk") + + with ( + patch( + "backend.copilot.sdk.service.projects_base", + return_value=projects_base_dir, + 
), + patch( + "backend.copilot.transcript.projects_base", + return_value=projects_base_dir, + ), + ): + stripped_str, ok = _process_cli_restore( + restore, sdk_cwd, session_id, "[Test]" + ) + + assert ok, "Expected successful restore" + + # Find the written session file + encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) + session_file = Path(projects_base_dir) / encoded_cwd / f"{session_id}.jsonl" + assert session_file.exists(), "Session file should have been written" + + written_bytes = session_file.read_bytes() + # The written bytes must be the stripped version (no progress entry) + assert ( + b"progress" not in written_bytes + ), "Raw bytes with progress entry should not have been written" + assert ( + b"hello" in written_bytes + ), "Stripped content should still contain assistant turn" + + # Written bytes must equal the stripped string re-encoded + assert written_bytes == stripped_str.encode( + "utf-8" + ), "Written bytes must equal stripped content" + + def test_invalid_content_returns_false(self): + """Content that fails validation after strip returns (empty, False).""" + from backend.copilot.sdk.service import _process_cli_restore + from backend.copilot.transcript import TranscriptDownload + + # A single progress-only entry — stripped result will be empty/invalid + raw_content = '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n' + restore = TranscriptDownload( + content=raw_content.encode("utf-8"), message_count=1, mode="sdk" + ) + + stripped_str, ok = _process_cli_restore( + restore, + "/tmp/nonexistent-sdk-cwd", + "12345678-0000-0000-0000-000000000099", + "[Test]", + ) + + assert not ok + assert stripped_str == "" + + +class TestReadCliSessionFromDisk: + """``_read_cli_session_from_disk`` reads, strips, and optionally writes back the session.""" + + def _build_session_file(self, tmp_path, session_id: str): + """Build the session file path inside tmp_path using the same encoding as cli_session_path.""" + 
import os + import re + from pathlib import Path + + sdk_cwd = str(tmp_path) + encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) + session_dir = Path(str(tmp_path)) / encoded_cwd + session_dir.mkdir(parents=True, exist_ok=True) + return sdk_cwd, session_dir / f"{session_id}.jsonl" + + def test_returns_raw_bytes_for_invalid_utf8(self, tmp_path): + """Non-UTF-8 bytes trigger UnicodeDecodeError — returns raw bytes (upload-raw fallback).""" + from unittest.mock import patch + + from backend.copilot.sdk.service import _read_cli_session_from_disk + + session_id = "12345678-0000-0000-0000-aabbccdd0001" + projects_base_dir = str(tmp_path) + sdk_cwd, session_file = self._build_session_file(tmp_path, session_id) + + # Write raw invalid UTF-8 bytes + session_file.write_bytes(b"\xff\xfe invalid utf-8\n") + + with ( + patch( + "backend.copilot.sdk.service.projects_base", + return_value=projects_base_dir, + ), + patch( + "backend.copilot.transcript.projects_base", + return_value=projects_base_dir, + ), + ): + result = _read_cli_session_from_disk(sdk_cwd, session_id, "[Test]") + + # UnicodeDecodeError path returns the raw bytes (upload-raw fallback) + assert result == b"\xff\xfe invalid utf-8\n" + + def test_write_back_oserror_still_returns_stripped_bytes(self, tmp_path): + """OSError on write-back returns stripped bytes for GCS upload (not raw).""" + from unittest.mock import patch + + from backend.copilot.sdk.service import _read_cli_session_from_disk + + session_id = "12345678-0000-0000-0000-aabbccdd0002" + projects_base_dir = str(tmp_path) + sdk_cwd, session_file = self._build_session_file(tmp_path, session_id) + + # Content with a strippable progress entry so stripped_bytes < raw_bytes + raw_content = ( + '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n' + '{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n' + 
'{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n' + ) + session_file.write_bytes(raw_content.encode("utf-8")) + # Make the file read-only so write_bytes raises OSError on the write-back + session_file.chmod(0o444) + + try: + with ( + patch( + "backend.copilot.sdk.service.projects_base", + return_value=projects_base_dir, + ), + patch( + "backend.copilot.transcript.projects_base", + return_value=projects_base_dir, + ), + ): + result = _read_cli_session_from_disk(sdk_cwd, session_id, "[Test]") + finally: + session_file.chmod(0o644) + + # Must return stripped bytes (not raw, not None) so GCS gets the clean version + assert result is not None + assert ( + b"progress" not in result + ), "Stripped bytes must not contain progress entry" + assert b"hello" in result, "Stripped bytes should contain assistant turn" diff --git a/autogpt_platform/backend/backend/copilot/service_test.py b/autogpt_platform/backend/backend/copilot/service_test.py index c4b1c3182e..ec9b13fb22 100644 --- a/autogpt_platform/backend/backend/copilot/service_test.py +++ b/autogpt_platform/backend/backend/copilot/service_test.py @@ -61,18 +61,23 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id): # (CLI version, platform). When that happens, multi-turn still works # via conversation compression (non-resume path), but we can't test # the --resume round-trip. - transcript = None + cli_session = None for _ in range(10): await asyncio.sleep(0.5) - transcript = await download_transcript(test_user_id, session.session_id) - if transcript: + cli_session = await download_transcript(test_user_id, session.session_id) + # Wait until both the session bytes AND the message_count watermark are + # present — a session with message_count=0 means the .meta.json hasn't + # been uploaded yet, so --resume on the next turn would skip gap-fill. 
+ if cli_session and cli_session.message_count > 0: break - if not transcript: + if not cli_session: return pytest.skip( "CLI did not produce a usable transcript — " "cannot test --resume round-trip in this environment" ) - logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes") + logger.info( + f"Turn 1 CLI session uploaded: {len(cli_session.content)} bytes, msg_count={cli_session.message_count}" + ) # Reload session for turn 2 session = await get_chat_session(session.session_id, test_user_id) diff --git a/autogpt_platform/backend/backend/copilot/tools/__init__.py b/autogpt_platform/backend/backend/copilot/tools/__init__.py index c4913a9411..75a0a8f4e4 100644 --- a/autogpt_platform/backend/backend/copilot/tools/__init__.py +++ b/autogpt_platform/backend/backend/copilot/tools/__init__.py @@ -26,6 +26,7 @@ from .fix_agent import FixAgentGraphTool from .get_agent_building_guide import GetAgentBuildingGuideTool from .get_doc_page import GetDocPageTool from .get_mcp_guide import GetMCPGuideTool +from .graphiti_forget import MemoryForgetConfirmTool, MemoryForgetSearchTool from .graphiti_search import MemorySearchTool from .graphiti_store import MemoryStoreTool from .manage_folders import ( @@ -66,6 +67,8 @@ TOOL_REGISTRY: dict[str, BaseTool] = { "find_block": FindBlockTool(), "find_library_agent": FindLibraryAgentTool(), # Graphiti memory tools + "memory_forget_confirm": MemoryForgetConfirmTool(), + "memory_forget_search": MemoryForgetSearchTool(), "memory_search": MemorySearchTool(), "memory_store": MemoryStoreTool(), # Folder management tools diff --git a/autogpt_platform/backend/backend/copilot/tools/graphiti_forget.py b/autogpt_platform/backend/backend/copilot/tools/graphiti_forget.py new file mode 100644 index 0000000000..c3a30a583e --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/tools/graphiti_forget.py @@ -0,0 +1,349 @@ +"""Two-step tool for targeted memory deletion. 
+ +Step 1 (memory_forget_search): search for matching facts, return candidates. +Step 2 (memory_forget_confirm): delete specific edges by UUID after user confirms. +""" + +import logging +from typing import Any + +from backend.copilot.graphiti._format import extract_fact, extract_temporal_validity +from backend.copilot.graphiti.client import derive_group_id, get_graphiti_client +from backend.copilot.graphiti.config import is_enabled_for_user +from backend.copilot.model import ChatSession + +from .base import BaseTool +from .models import ( + ErrorResponse, + MemoryForgetCandidatesResponse, + MemoryForgetConfirmResponse, + ToolResponseBase, +) + +logger = logging.getLogger(__name__) + + +class MemoryForgetSearchTool(BaseTool): + """Search for memories to forget — returns candidates for user confirmation.""" + + @property + def name(self) -> str: + return "memory_forget_search" + + @property + def description(self) -> str: + return ( + "Search for stored memories matching a description so the user can " + "choose which to delete. Returns candidate facts with UUIDs. " + "Use memory_forget_confirm with the UUIDs to actually delete them." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Natural language description of what to forget (e.g. 
'the Q2 marketing budget')", + }, + }, + "required": ["query"], + } + + @property + def requires_auth(self) -> bool: + return True + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + *, + query: str = "", + **kwargs, + ) -> ToolResponseBase: + if not user_id: + return ErrorResponse( + message="Authentication required.", + session_id=session.session_id, + ) + + if not await is_enabled_for_user(user_id): + return ErrorResponse( + message="Memory features are not enabled for your account.", + session_id=session.session_id, + ) + + if not query: + return ErrorResponse( + message="A search query is required to find memories to forget.", + session_id=session.session_id, + ) + + try: + group_id = derive_group_id(user_id) + except ValueError: + return ErrorResponse( + message="Invalid user ID for memory operations.", + session_id=session.session_id, + ) + + try: + client = await get_graphiti_client(group_id) + edges = await client.search( + query=query, + group_ids=[group_id], + num_results=10, + ) + except Exception: + logger.warning( + "Memory forget search failed for user %s", user_id[:12], exc_info=True + ) + return ErrorResponse( + message="Memory search is temporarily unavailable.", + session_id=session.session_id, + ) + + if not edges: + return MemoryForgetCandidatesResponse( + message="No matching memories found.", + session_id=session.session_id, + candidates=[], + ) + + candidates = [] + for e in edges: + edge_uuid = getattr(e, "uuid", None) or getattr(e, "id", None) + if not edge_uuid: + continue + fact = extract_fact(e) + valid_from, valid_to = extract_temporal_validity(e) + candidates.append( + { + "uuid": str(edge_uuid), + "fact": fact, + "valid_from": str(valid_from), + "valid_to": str(valid_to), + } + ) + + return MemoryForgetCandidatesResponse( + message=f"Found {len(candidates)} candidate(s). 
Show these to the user and ask which to delete, then call memory_forget_confirm with the UUIDs.", + session_id=session.session_id, + candidates=candidates, + ) + + +class MemoryForgetConfirmTool(BaseTool): + """Delete specific memory edges by UUID after user confirmation. + + Supports both soft delete (temporal invalidation — reversible) and + hard delete (remove from graph — irreversible, for GDPR). + """ + + @property + def name(self) -> str: + return "memory_forget_confirm" + + @property + def description(self) -> str: + return ( + "Delete specific memories by UUID. Use after memory_forget_search " + "returns candidates and the user confirms which to delete. " + "Default is soft delete (marks as expired but keeps history). " + "Set hard_delete=true for permanent removal (GDPR)." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "uuids": { + "type": "array", + "items": {"type": "string"}, + "description": "List of edge UUIDs to delete (from memory_forget_search results)", + }, + "hard_delete": { + "type": "boolean", + "description": "If true, permanently removes edges from the graph (GDPR). Default false (soft delete — marks as expired).", + "default": False, + }, + }, + "required": ["uuids"], + } + + @property + def requires_auth(self) -> bool: + return True + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + *, + uuids: list[str] | None = None, + hard_delete: bool = False, + **kwargs, + ) -> ToolResponseBase: + if not user_id: + return ErrorResponse( + message="Authentication required.", + session_id=session.session_id, + ) + + if not await is_enabled_for_user(user_id): + return ErrorResponse( + message="Memory features are not enabled for your account.", + session_id=session.session_id, + ) + + if not uuids: + return ErrorResponse( + message="At least one UUID is required. 
Use memory_forget_search first.", + session_id=session.session_id, + ) + + try: + group_id = derive_group_id(user_id) + except ValueError: + return ErrorResponse( + message="Invalid user ID for memory operations.", + session_id=session.session_id, + ) + + try: + client = await get_graphiti_client(group_id) + except Exception: + logger.warning( + "Failed to get Graphiti client for user %s", user_id[:12], exc_info=True + ) + return ErrorResponse( + message="Memory service is temporarily unavailable.", + session_id=session.session_id, + ) + + driver = getattr(client, "graph_driver", None) or getattr( + client, "driver", None + ) + if not driver: + return ErrorResponse( + message="Could not access graph driver for deletion.", + session_id=session.session_id, + ) + + if hard_delete: + deleted, failed = await _hard_delete_edges(driver, uuids, user_id) + mode = "permanently deleted" + else: + deleted, failed = await _soft_delete_edges(driver, uuids, user_id) + mode = "invalidated" + + return MemoryForgetConfirmResponse( + message=( + f"{len(deleted)} memory edge(s) {mode}." + + (f" {len(failed)} failed." if failed else "") + ), + session_id=session.session_id, + deleted_uuids=deleted, + failed_uuids=failed, + ) + + +async def _soft_delete_edges( + driver, uuids: list[str], user_id: str +) -> tuple[list[str], list[str]]: + """Temporal invalidation — mark edges as expired without removing them. + + Sets ``invalid_at`` and ``expired_at`` to now, which excludes them + from default search results while preserving history. + + Matches the same edge types as ``_hard_delete_edges`` so that edges of + any type (RELATES_TO, MENTIONS, HAS_MEMBER) can be soft-deleted. 
+ """ + deleted = [] + failed = [] + for uuid in uuids: + try: + records, _, _ = await driver.execute_query( + """ + MATCH ()-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->() + SET e.invalid_at = datetime(), + e.expired_at = datetime() + RETURN e.uuid AS uuid + """, + uuid=uuid, + ) + if records: + deleted.append(uuid) + else: + failed.append(uuid) + except Exception: + logger.warning( + "Failed to soft-delete edge %s for user %s", + uuid, + user_id[:12], + exc_info=True, + ) + failed.append(uuid) + return deleted, failed + + +async def _hard_delete_edges( + driver, uuids: list[str], user_id: str +) -> tuple[list[str], list[str]]: + """Permanent removal — delete edges and clean up back-references. + + Uses graphiti's ``Edge.delete()`` pattern (handles MENTIONS, + RELATES_TO, HAS_MEMBER in one query). Does NOT delete orphaned + entity nodes — they may have summaries, embeddings, or future + connections. Cleans up episode ``entity_edges`` back-references. + """ + deleted = [] + failed = [] + for uuid in uuids: + try: + # Use WITH to capture the uuid before DELETE so we don't + # access properties of deleted relationships (FalkorDB #1393). + # Single atomic query avoids TOCTOU between check and delete. + records, _, _ = await driver.execute_query( + """ + MATCH ()-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->() + WITH e.uuid AS uuid, e + DELETE e + RETURN uuid + """, + uuid=uuid, + ) + if not records: + failed.append(uuid) + continue + # Edge was deleted — report success regardless of cleanup outcome. + deleted.append(uuid) + # Clean up episode back-references (best-effort). 
+ try: + await driver.execute_query( + """ + MATCH (ep:Episodic) + WHERE $uuid IN ep.entity_edges + SET ep.entity_edges = [x IN ep.entity_edges WHERE x <> $uuid] + """, + uuid=uuid, + ) + except Exception: + logger.warning( + "Edge %s deleted but back-ref cleanup failed for user %s", + uuid, + user_id[:12], + exc_info=True, + ) + except Exception: + logger.warning( + "Failed to hard-delete edge %s for user %s", + uuid, + user_id[:12], + exc_info=True, + ) + failed.append(uuid) + return deleted, failed diff --git a/autogpt_platform/backend/backend/copilot/tools/graphiti_forget_test.py b/autogpt_platform/backend/backend/copilot/tools/graphiti_forget_test.py new file mode 100644 index 0000000000..94bbeb5d4f --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/tools/graphiti_forget_test.py @@ -0,0 +1,77 @@ +"""Tests for graphiti_forget delete helpers.""" + +from unittest.mock import AsyncMock + +import pytest + +from backend.copilot.tools.graphiti_forget import _hard_delete_edges, _soft_delete_edges + + +class TestSoftDeleteOverReportsSuccess: + """_soft_delete_edges always appends UUID to deleted list even when + the Cypher MATCH found no edge (query succeeds but matches nothing). + """ + + @pytest.mark.asyncio + async def test_reports_failure_when_no_edge_matched(self) -> None: + driver = AsyncMock() + # execute_query returns empty result set — no edge matched + driver.execute_query.return_value = ([], None, None) + + deleted, failed = await _soft_delete_edges( + driver, ["nonexistent-uuid"], "test-user" + ) + # Should NOT report success when nothing was actually updated + assert deleted == [], f"over-reported success: {deleted}" + assert failed == ["nonexistent-uuid"] + + +class TestSoftDeleteNoMatchReportsFailure: + """When the query returns empty records (no edge with that UUID exists + in the database), _soft_delete_edges should report it as failed. 
+ """ + + @pytest.mark.asyncio + async def test_soft_delete_handles_non_relates_to_edge(self) -> None: + driver = AsyncMock() + # Simulate: RELATES_TO match returns nothing (edge is MENTIONS type) + driver.execute_query.return_value = ([], None, None) + + deleted, failed = await _soft_delete_edges( + driver, ["mentions-edge-uuid"], "test-user" + ) + # With the bug, this reports success even though nothing was updated + assert "mentions-edge-uuid" not in deleted + + +class TestHardDeleteBasicFlow: + """Verify _hard_delete_edges calls the right queries.""" + + @pytest.mark.asyncio + async def test_hard_delete_calls_both_queries(self) -> None: + driver = AsyncMock() + # First call (delete) returns a matched record, second (cleanup) returns empty + driver.execute_query.side_effect = [ + ([{"uuid": "uuid-1"}], None, None), + ([], None, None), + ] + + deleted, failed = await _hard_delete_edges(driver, ["uuid-1"], "test-user") + assert deleted == ["uuid-1"] + assert failed == [] + # Should call: 1) delete edge, 2) clean episode back-refs + assert driver.execute_query.call_count == 2 + + @pytest.mark.asyncio + async def test_hard_delete_reports_failure_when_no_edge_matched(self) -> None: + driver = AsyncMock() + # Delete query returns no records — edge not found + driver.execute_query.return_value = ([], None, None) + + deleted, failed = await _hard_delete_edges( + driver, ["nonexistent-uuid"], "test-user" + ) + assert deleted == [] + assert failed == ["nonexistent-uuid"] + # Only the delete query should run — cleanup skipped + assert driver.execute_query.call_count == 1 diff --git a/autogpt_platform/backend/backend/copilot/tools/graphiti_search.py b/autogpt_platform/backend/backend/copilot/tools/graphiti_search.py index 27f47a6b29..0aef554bbf 100644 --- a/autogpt_platform/backend/backend/copilot/tools/graphiti_search.py +++ b/autogpt_platform/backend/backend/copilot/tools/graphiti_search.py @@ -7,6 +7,7 @@ from typing import Any from backend.copilot.graphiti._format 
import ( extract_episode_body, + extract_episode_body_raw, extract_episode_timestamp, extract_fact, extract_temporal_validity, @@ -52,6 +53,15 @@ class MemorySearchTool(BaseTool): "description": "Maximum number of results to return", "default": 15, }, + "scope": { + "type": "string", + "description": ( + "Optional scope filter. When set, only memories matching " + "this scope are returned (hard filter). " + "Examples: 'real:global', 'project:crm', 'book:my-novel'. " + "Omit to search all scopes." + ), + }, }, "required": ["query"], } @@ -67,6 +77,7 @@ class MemorySearchTool(BaseTool): *, query: str = "", limit: int = 15, + scope: str = "", **kwargs, ) -> ToolResponseBase: if not user_id: @@ -122,7 +133,14 @@ class MemorySearchTool(BaseTool): ) facts = _format_edges(edges) - recent = _format_episodes(episodes) + + # Scope hard-filter: if a scope was requested, filter episodes + # whose MemoryEnvelope JSON contains a different scope. + # Skip redundant _format_episodes() when scope is set. + if scope: + recent = _filter_episodes_by_scope(episodes, scope) + else: + recent = _format_episodes(episodes) if not facts and not recent: return MemorySearchResponse( @@ -132,9 +150,10 @@ class MemorySearchTool(BaseTool): recent_episodes=[], ) + scope_note = f" (scope filter: {scope})" if scope else "" return MemorySearchResponse( message=( - f"Found {len(facts)} relationship facts and {len(recent)} stored memories. " + f"Found {len(facts)} relationship facts and {len(recent)} stored memories{scope_note}. " "Use BOTH sections to answer — stored memories often contain operational " "rules and instructions that relationship facts summarize." ), @@ -160,3 +179,35 @@ def _format_episodes(episodes) -> list[str]: body = extract_episode_body(ep) results.append(f"[{ts}] {body}") return results + + +def _filter_episodes_by_scope(episodes, scope: str) -> list[str]: + """Filter episodes by scope — hard filter on MemoryEnvelope JSON content. 
+ + Episodes that are plain conversation text (not JSON envelopes) are + included by default since they have no scope metadata and belong + to the implicit ``real:global`` scope. + + Uses ``extract_episode_body_raw`` (no truncation) for JSON parsing + so that long MemoryEnvelope payloads are parsed correctly. + """ + import json + + results = [] + for ep in episodes: + raw_body = extract_episode_body_raw(ep) + try: + data = json.loads(raw_body) + if not isinstance(data, dict): + raise TypeError("non-dict JSON") + ep_scope = data.get("scope", "real:global") + if ep_scope != scope: + continue + except (json.JSONDecodeError, TypeError): + # Not JSON or non-dict JSON — plain conversation episode, treat as real:global + if scope != "real:global": + continue + display_body = extract_episode_body(ep) + ts = extract_episode_timestamp(ep) + results.append(f"[{ts}] {display_body}") + return results diff --git a/autogpt_platform/backend/backend/copilot/tools/graphiti_search_test.py b/autogpt_platform/backend/backend/copilot/tools/graphiti_search_test.py new file mode 100644 index 0000000000..99e2de78ea --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/tools/graphiti_search_test.py @@ -0,0 +1,64 @@ +"""Tests for graphiti_search helper functions.""" + +from types import SimpleNamespace + +from backend.copilot.graphiti.memory_model import MemoryEnvelope, MemoryKind, SourceKind +from backend.copilot.tools.graphiti_search import ( + _filter_episodes_by_scope, + _format_episodes, +) + + +class TestFilterEpisodesByScopeTruncation: + """extract_episode_body() truncates to 500 chars. A MemoryEnvelope + with a long content field exceeds that limit, producing invalid JSON. + _filter_episodes_by_scope then treats it as a plain-text episode + (real:global), leaking project-scoped data into global results. 
+ """ + + def test_long_envelope_filtered_by_scope(self) -> None: + envelope = MemoryEnvelope( + content="x" * 600, + source_kind=SourceKind.user_asserted, + scope="project:crm", + memory_kind=MemoryKind.fact, + ) + ep = SimpleNamespace( + content=envelope.model_dump_json(), + created_at="2025-01-01T00:00:00Z", + ) + # Requesting real:global scope — this project:crm episode should be excluded + results = _filter_episodes_by_scope([ep], "real:global") + assert ( + results == [] + ), f"project-scoped episode leaked into global results: {results}" + + def test_short_envelope_filtered_correctly(self) -> None: + """Short envelopes (under 500 chars) are parsed correctly.""" + envelope = MemoryEnvelope( + content="short note", + scope="project:crm", + ) + ep = SimpleNamespace( + content=envelope.model_dump_json(), + created_at="2025-01-01T00:00:00Z", + ) + results = _filter_episodes_by_scope([ep], "real:global") + assert results == [] + + +class TestRedundantFormatting: + """_format_episodes is called even when scope filter will overwrite it. + Not a correctness bug, but verify the scope path doesn't depend on it. 
+ """ + + def test_scope_filter_independent_of_format_episodes(self) -> None: + envelope = MemoryEnvelope(content="note", scope="real:global") + ep = SimpleNamespace( + content=envelope.model_dump_json(), + created_at="2025-01-01T00:00:00Z", + ) + from_format = _format_episodes([ep]) + from_scope = _filter_episodes_by_scope([ep], "real:global") + assert len(from_format) == 1 + assert len(from_scope) == 1 diff --git a/autogpt_platform/backend/backend/copilot/tools/graphiti_store.py b/autogpt_platform/backend/backend/copilot/tools/graphiti_store.py index 6e75eb2ed4..3112820e54 100644 --- a/autogpt_platform/backend/backend/copilot/tools/graphiti_store.py +++ b/autogpt_platform/backend/backend/copilot/tools/graphiti_store.py @@ -5,6 +5,15 @@ from typing import Any from backend.copilot.graphiti.config import is_enabled_for_user from backend.copilot.graphiti.ingest import enqueue_episode +from backend.copilot.graphiti.memory_model import ( + MemoryEnvelope, + MemoryKind, + MemoryStatus, + ProcedureMemory, + ProcedureStep, + RuleMemory, + SourceKind, +) from backend.copilot.model import ChatSession from .base import BaseTool @@ -26,7 +35,7 @@ class MemoryStoreTool(BaseTool): "Store a memory or fact about the user for future recall. " "Use when the user shares preferences, business context, decisions, " "relationships, or other important information worth remembering " - "across sessions." + "across sessions. Supports optional metadata for scoping and classification." 
) @property @@ -47,6 +56,94 @@ class MemoryStoreTool(BaseTool): "description": "Context about where this info came from", "default": "Conversation memory", }, + "source_kind": { + "type": "string", + "enum": [e.value for e in SourceKind], + "description": "Who asserted this: user_asserted (default), assistant_derived, or tool_observed", + "default": "user_asserted", + }, + "scope": { + "type": "string", + "description": "Namespace for this memory: 'real:global' (default), 'project:<name>', 'book:<title>'", + "default": "real:global", + }, + "memory_kind": { + "type": "string", + "enum": [e.value for e in MemoryKind], + "description": "Type of memory: fact (default), preference, rule, finding, plan, event, procedure", + "default": "fact", + }, + "rule": { + "type": "object", + "description": ( + "Structured rule data — use when memory_kind=rule to preserve " + "exact operational instructions. Example: " + '{"instruction": "CC Sarah on client communications", ' + '"actor": "Sarah", "trigger": "client-related communications"}' + ), + "properties": { + "instruction": { + "type": "string", + "description": "The actionable instruction", + }, + "actor": { + "type": "string", + "description": "Who performs or is subject to the rule", + }, + "trigger": { + "type": "string", + "description": "When the rule applies", + }, + "negation": { + "type": "string", + "description": "What NOT to do, if applicable", + }, + }, + "required": ["instruction"], + }, + "procedure": { + "type": "object", + "description": ( + "Structured procedure data — use when memory_kind=procedure " + "for multi-step workflows with ordering, tools, and conditions." 
+ ), + "properties": { + "description": { + "type": "string", + "description": "What this procedure accomplishes", + }, + "steps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "order": { + "type": "integer", + "description": "Step number", + }, + "action": { + "type": "string", + "description": "What to do", + }, + "tool": { + "type": "string", + "description": "Tool or service to use", + }, + "condition": { + "type": "string", + "description": "When this step applies", + }, + "negation": { + "type": "string", + "description": "What NOT to do", + }, + }, + "required": ["order", "action"], + }, + }, + }, + "required": ["description", "steps"], + }, }, "required": ["name", "content"], } @@ -63,6 +160,11 @@ class MemoryStoreTool(BaseTool): name: str = "", content: str = "", source_description: str = "Conversation memory", + source_kind: str = "user_asserted", + scope: str = "real:global", + memory_kind: str = "fact", + rule: dict | None = None, + procedure: dict | None = None, **kwargs, ) -> ToolResponseBase: if not user_id: @@ -83,12 +185,53 @@ class MemoryStoreTool(BaseTool): session_id=session.session_id, ) + rule_model = None + if rule and memory_kind == "rule": + try: + rule_model = RuleMemory(**rule) + except Exception: + logger.warning("Invalid rule data, storing as plain fact") + memory_kind = "fact" + + procedure_model = None + if procedure and memory_kind == "procedure": + try: + steps = [ProcedureStep(**s) for s in procedure.get("steps", [])] + procedure_model = ProcedureMemory( + description=procedure.get("description", content), + steps=steps, + ) + except Exception: + logger.warning("Invalid procedure data, storing as plain fact") + memory_kind = "fact" + + try: + resolved_source = SourceKind(source_kind) + except ValueError: + resolved_source = SourceKind.user_asserted + try: + resolved_kind = MemoryKind(memory_kind) + except ValueError: + resolved_kind = MemoryKind.fact + + envelope = MemoryEnvelope( + content=content, + 
source_kind=resolved_source, + scope=scope, + memory_kind=resolved_kind, + status=MemoryStatus.active, + provenance=session.session_id, + rule=rule_model, + procedure=procedure_model, + ) + queued = await enqueue_episode( user_id, session.session_id, name=name, - episode_body=content, + episode_body=envelope.model_dump_json(), source_description=source_description, + is_json=True, ) if not queued: diff --git a/autogpt_platform/backend/backend/copilot/tools/graphiti_store_test.py b/autogpt_platform/backend/backend/copilot/tools/graphiti_store_test.py index 3742355d76..21224d39c0 100644 --- a/autogpt_platform/backend/backend/copilot/tools/graphiti_store_test.py +++ b/autogpt_platform/backend/backend/copilot/tools/graphiti_store_test.py @@ -1,5 +1,6 @@ """Tests for MemoryStoreTool.""" +import json from datetime import UTC, datetime from unittest.mock import AsyncMock, patch @@ -153,13 +154,14 @@ class TestMemoryStoreTool: assert "queued for storage" in result.message assert result.session_id == "test-session" - mock_enqueue.assert_awaited_once_with( - "user-1", - "test-session", - name="user_prefers_python", - episode_body="The user prefers Python over JavaScript.", - source_description="Direct statement", - ) + mock_enqueue.assert_awaited_once() + call_kwargs = mock_enqueue.await_args.kwargs + assert call_kwargs["name"] == "user_prefers_python" + assert call_kwargs["source_description"] == "Direct statement" + assert call_kwargs["is_json"] is True + envelope = json.loads(call_kwargs["episode_body"]) + assert envelope["content"] == "The user prefers Python over JavaScript." 
+ assert envelope["memory_kind"] == "fact" @pytest.mark.asyncio async def test_store_success_uses_default_source_description(self): @@ -187,10 +189,132 @@ class TestMemoryStoreTool: ) assert isinstance(result, MemoryStoreResponse) - mock_enqueue.assert_awaited_once_with( - "user-1", - "test-session", - name="some_fact", - episode_body="A fact worth remembering.", - source_description="Conversation memory", - ) + mock_enqueue.assert_awaited_once() + call_kwargs = mock_enqueue.await_args.kwargs + assert call_kwargs["name"] == "some_fact" + assert call_kwargs["source_description"] == "Conversation memory" + assert call_kwargs["is_json"] is True + envelope = json.loads(call_kwargs["episode_body"]) + assert envelope["content"] == "A fact worth remembering." + + @pytest.mark.asyncio + async def test_store_invalid_source_kind_falls_back(self): + """Invalid enum values should fall back to defaults, not crash.""" + tool = MemoryStoreTool() + session = _make_session() + + mock_enqueue = AsyncMock() + + with ( + patch( + "backend.copilot.tools.graphiti_store.is_enabled_for_user", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "backend.copilot.tools.graphiti_store.enqueue_episode", + mock_enqueue, + ), + ): + result = await tool._execute( + user_id="user-1", + session=session, + name="some_fact", + content="A fact.", + source_kind="INVALID_SOURCE", + memory_kind="INVALID_KIND", + ) + + assert isinstance(result, MemoryStoreResponse) + envelope = json.loads(mock_enqueue.await_args.kwargs["episode_body"]) + assert envelope["source_kind"] == "user_asserted" + assert envelope["memory_kind"] == "fact" + + @pytest.mark.asyncio + async def test_store_valid_enum_values_preserved(self): + tool = MemoryStoreTool() + session = _make_session() + + mock_enqueue = AsyncMock() + + with ( + patch( + "backend.copilot.tools.graphiti_store.is_enabled_for_user", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "backend.copilot.tools.graphiti_store.enqueue_episode", + 
mock_enqueue, + ), + ): + result = await tool._execute( + user_id="user-1", + session=session, + name="rule_1", + content="Always CC Sarah.", + source_kind="user_asserted", + memory_kind="rule", + ) + + assert isinstance(result, MemoryStoreResponse) + envelope = json.loads(mock_enqueue.await_args.kwargs["episode_body"]) + assert envelope["source_kind"] == "user_asserted" + assert envelope["memory_kind"] == "rule" + + @pytest.mark.asyncio + async def test_store_queue_full_returns_error(self): + tool = MemoryStoreTool() + session = _make_session() + + with ( + patch( + "backend.copilot.tools.graphiti_store.is_enabled_for_user", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "backend.copilot.tools.graphiti_store.enqueue_episode", + new_callable=AsyncMock, + return_value=False, + ), + ): + result = await tool._execute( + user_id="user-1", + session=session, + name="pref", + content="likes python", + ) + + assert isinstance(result, ErrorResponse) + assert "queue" in result.message.lower() + + @pytest.mark.asyncio + async def test_store_with_scope(self): + tool = MemoryStoreTool() + session = _make_session() + + mock_enqueue = AsyncMock() + + with ( + patch( + "backend.copilot.tools.graphiti_store.is_enabled_for_user", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "backend.copilot.tools.graphiti_store.enqueue_episode", + mock_enqueue, + ), + ): + result = await tool._execute( + user_id="user-1", + session=session, + name="project_note", + content="CRM uses PostgreSQL.", + scope="project:crm", + ) + + assert isinstance(result, MemoryStoreResponse) + envelope = json.loads(mock_enqueue.await_args.kwargs["episode_body"]) + assert envelope["scope"] == "project:crm" diff --git a/autogpt_platform/backend/backend/copilot/tools/models.py b/autogpt_platform/backend/backend/copilot/tools/models.py index bf211e2da7..90aa3d51db 100644 --- a/autogpt_platform/backend/backend/copilot/tools/models.py +++ 
b/autogpt_platform/backend/backend/copilot/tools/models.py @@ -84,6 +84,8 @@ class ResponseType(str, Enum): # Graphiti memory MEMORY_STORE = "memory_store" MEMORY_SEARCH = "memory_search" + MEMORY_FORGET_CANDIDATES = "memory_forget_candidates" + MEMORY_FORGET_CONFIRM = "memory_forget_confirm" # Base response model @@ -712,3 +714,18 @@ class MemorySearchResponse(ToolResponseBase): type: ResponseType = ResponseType.MEMORY_SEARCH facts: list[str] = Field(default_factory=list) recent_episodes: list[str] = Field(default_factory=list) + + +class MemoryForgetCandidatesResponse(ToolResponseBase): + """Response with candidate memories to forget.""" + + type: ResponseType = ResponseType.MEMORY_FORGET_CANDIDATES + candidates: list[dict[str, str]] = Field(default_factory=list) + + +class MemoryForgetConfirmResponse(ToolResponseBase): + """Response after deleting specific memory edges.""" + + type: ResponseType = ResponseType.MEMORY_FORGET_CONFIRM + deleted_uuids: list[str] = Field(default_factory=list) + failed_uuids: list[str] = Field(default_factory=list) diff --git a/autogpt_platform/backend/backend/copilot/transcript.py b/autogpt_platform/backend/backend/copilot/transcript.py index a1e11f352d..c4d3de28af 100644 --- a/autogpt_platform/backend/backend/copilot/transcript.py +++ b/autogpt_platform/backend/backend/copilot/transcript.py @@ -1,10 +1,10 @@ """JSONL transcript management for stateless multi-turn resume. The Claude Code CLI persists conversations as JSONL files (one JSON object per -line). When the SDK's ``Stop`` hook fires we read this file, strip bloat -(progress entries, metadata), and upload the result to bucket storage. On the -next turn we download the transcript, write it to a temp file, and pass -``--resume`` so the CLI can reconstruct the full conversation. +line). When the SDK's ``Stop`` hook fires the caller reads this file, strips +bloat (progress entries, metadata), and uploads the result to bucket storage. 
+On the next turn the caller downloads the bytes and writes them to disk before +passing ``--resume`` so the CLI can reconstruct the full conversation. Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local filesystem for self-hosted) — no DB column needed. @@ -20,6 +20,7 @@ import shutil import time from dataclasses import dataclass from pathlib import Path +from typing import TYPE_CHECKING, Literal from uuid import uuid4 from backend.util import json @@ -27,6 +28,9 @@ from backend.util.clients import get_openai_client from backend.util.prompt import CompressResult, compress_context from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage +if TYPE_CHECKING: + from .model import ChatMessage + logger = logging.getLogger(__name__) # UUIDs are hex + hyphens; strip everything else to prevent path injection. @@ -44,17 +48,17 @@ STRIPPABLE_TYPES = frozenset( ) +TranscriptMode = Literal["sdk", "baseline"] + + @dataclass class TranscriptDownload: - """Result of downloading a transcript with its metadata.""" - - content: str - message_count: int = 0 # session.messages length when uploaded - uploaded_at: float = 0.0 # epoch timestamp of upload + content: bytes | str + message_count: int = 0 + # "sdk" = Claude CLI native, "baseline" = TranscriptBuilder + mode: TranscriptMode = "sdk" -# Workspace storage constants — deterministic path from session_id. -TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts" # Storage prefix for the CLI's native session JSONL files (for cross-pod --resume). 
_CLI_SESSION_STORAGE_PREFIX = "cli-sessions" @@ -363,7 +367,7 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str: _SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-") -def _projects_base() -> str: +def projects_base() -> str: """Return the resolved path to the CLI's projects directory.""" config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude") return os.path.realpath(os.path.join(config_dir, "projects")) @@ -390,8 +394,8 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int: Returns the number of directories removed. """ - projects_base = _projects_base() - if not os.path.isdir(projects_base): + _pbase = projects_base() + if not os.path.isdir(_pbase): return 0 now = time.time() @@ -399,7 +403,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int: # Scoped mode: only clean up the one directory for the current session. if encoded_cwd: - target = Path(projects_base) / encoded_cwd + target = Path(_pbase) / encoded_cwd if not target.is_dir(): return 0 # Guard: only sweep copilot-generated dirs. @@ -437,7 +441,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int: # Only safe for single-tenant deployments; callers should prefer the # scoped variant by passing encoded_cwd. 
try: - entries = Path(projects_base).iterdir() + entries = Path(_pbase).iterdir() except OSError as e: logger.warning("[Transcript] Failed to list projects dir: %s", e) return 0 @@ -490,9 +494,9 @@ def read_compacted_entries(transcript_path: str) -> list[dict] | None: if not transcript_path: return None - projects_base = _projects_base() + _pbase = projects_base() real_path = os.path.realpath(transcript_path) - if not real_path.startswith(projects_base + os.sep): + if not real_path.startswith(_pbase + os.sep): logger.warning( "[Transcript] transcript_path outside projects base: %s", transcript_path ) @@ -611,28 +615,6 @@ def validate_transcript(content: str | None) -> bool: # --------------------------------------------------------------------------- -def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]: - """Return (workspace_id, file_id, filename) for a session's transcript. - - Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl`` - IDs are sanitized to hex+hyphen to prevent path traversal. 
- """ - return ( - TRANSCRIPT_STORAGE_PREFIX, - _sanitize_id(user_id), - f"{_sanitize_id(session_id)}.jsonl", - ) - - -def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]: - """Return (workspace_id, file_id, filename) for a session's transcript metadata.""" - return ( - TRANSCRIPT_STORAGE_PREFIX, - _sanitize_id(user_id), - f"{_sanitize_id(session_id)}.meta.json", - ) - - def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str: """Build a full storage path from (workspace_id, file_id, filename) parts.""" wid, fid, fname = parts @@ -642,24 +624,12 @@ def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str: return f"local://{wid}/{fid}/{fname}" -def _build_storage_path(user_id: str, session_id: str, backend: object) -> str: - """Build the full storage path string that ``retrieve()`` expects.""" - return _build_path_from_parts(_storage_path_parts(user_id, session_id), backend) - - -def _build_meta_storage_path(user_id: str, session_id: str, backend: object) -> str: - """Build the full storage path for the companion .meta.json file.""" - return _build_path_from_parts( - _meta_storage_path_parts(user_id, session_id), backend - ) - - # --------------------------------------------------------------------------- # CLI native session file — cross-pod --resume support # --------------------------------------------------------------------------- -def _cli_session_path(sdk_cwd: str, session_id: str) -> str: +def cli_session_path(sdk_cwd: str, session_id: str) -> str: """Expected path of the CLI's native session JSONL file. 
The CLI resolves the working directory via ``os.path.realpath``, then @@ -675,7 +645,7 @@ def _cli_session_path(sdk_cwd: str, session_id: str) -> str: """ encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) safe_id = _sanitize_id(session_id) - return os.path.join(_projects_base(), encoded_cwd, f"{safe_id}.jsonl") + return os.path.join(projects_base(), encoded_cwd, f"{safe_id}.jsonl") def _cli_session_storage_path_parts( @@ -689,209 +659,82 @@ def _cli_session_storage_path_parts( ) -async def upload_cli_session( - user_id: str, - session_id: str, - sdk_cwd: str, - log_prefix: str = "[Transcript]", -) -> None: - """Upload the CLI's native session JSONL file to remote storage. - - Called after each turn so the next turn can restore the file on any pod - (eliminating the pod-affinity requirement for --resume). - - The CLI only writes the session file after the turn completes, so this - must run in the finally block, AFTER the SDK stream has finished. - """ - session_file = _cli_session_path(sdk_cwd, session_id) - real_path = os.path.realpath(session_file) - projects_base = _projects_base() - - if not real_path.startswith(projects_base + os.sep): - logger.warning( - "%s CLI session file outside projects base, skipping upload: %s", - log_prefix, - os.path.basename(real_path), - ) - return - - try: - content = Path(real_path).read_bytes() - except FileNotFoundError: - logger.debug( - "%s CLI session file not found, skipping upload: %s", - log_prefix, - session_file, - ) - return - except OSError as e: - logger.warning("%s Failed to read CLI session file: %s", log_prefix, e) - return - - storage = await get_workspace_storage() - wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id) - try: - await storage.store( - workspace_id=wid, file_id=fid, filename=fname, content=content - ) - logger.info( - "%s Uploaded CLI session file (%dB) for cross-pod --resume", - log_prefix, - len(content), - ) - except Exception as e: - logger.warning("%s 
Failed to upload CLI session file: %s", log_prefix, e) - - -async def restore_cli_session( - user_id: str, - session_id: str, - sdk_cwd: str, - log_prefix: str = "[Transcript]", -) -> bool: - """Download and restore the CLI's native session file for --resume. - - Returns True if the file was successfully restored and --resume can be - used with the session UUID. Returns False if not available (first turn - or upload failed), in which case the caller should not set --resume. - """ - session_file = _cli_session_path(sdk_cwd, session_id) - real_path = os.path.realpath(session_file) - projects_base = _projects_base() - - if not real_path.startswith(projects_base + os.sep): - logger.warning( - "%s CLI session restore path outside projects base: %s", - log_prefix, - os.path.basename(session_file), - ) - return False - - # If the session file already exists locally (same-pod reuse), use it directly. - # Downloading from storage could overwrite a newer local version when a previous - # turn's upload failed: stored content is stale while the local file already - # contains extended history from that turn. 
- if Path(real_path).exists(): - logger.debug( - "%s CLI session file already exists locally — using it for --resume", - log_prefix, - ) - return True - - storage = await get_workspace_storage() - path = _build_path_from_parts( - _cli_session_storage_path_parts(user_id, session_id), storage +def _cli_session_meta_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]: + """Return (workspace_id, file_id, filename) for the CLI session meta file.""" + return ( + _CLI_SESSION_STORAGE_PREFIX, + _sanitize_id(user_id), + f"{_sanitize_id(session_id)}.meta.json", ) - try: - content = await storage.retrieve(path) - except FileNotFoundError: - logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix) - return False - except Exception as e: - logger.warning("%s Failed to download CLI session: %s", log_prefix, e) - return False - - try: - os.makedirs(os.path.dirname(real_path), exist_ok=True) - Path(real_path).write_bytes(content) - logger.info( - "%s Restored CLI session file (%dB) for --resume", - log_prefix, - len(content), - ) - return True - except OSError as e: - logger.warning("%s Failed to write CLI session file: %s", log_prefix, e) - return False - async def upload_transcript( user_id: str, session_id: str, - content: str, + content: bytes, message_count: int = 0, + mode: TranscriptMode = "sdk", log_prefix: str = "[Transcript]", - skip_strip: bool = False, ) -> None: - """Strip progress entries and stale thinking blocks, then upload transcript. + """Upload CLI session content to GCS with companion meta.json. - The transcript represents the FULL active context (atomic). - Each upload REPLACES the previous transcript entirely. + Pure GCS operation — no disk I/O. The caller is responsible for reading + the session file from disk before calling this function. - The executor holds a cluster lock per session, so concurrent uploads for - the same session cannot happen. 
+ Also uploads a companion .meta.json with the message_count watermark so + download_transcript can return it without a separate fetch. - Args: - content: Complete JSONL transcript (from TranscriptBuilder). - message_count: ``len(session.messages)`` at upload time. - skip_strip: When ``True``, skip the strip + re-validate pass. - Safe for builder-generated content (baseline path) which - never emits progress entries or stale thinking blocks. + Called after each turn so the next turn can restore the file on any pod + (eliminating the pod-affinity requirement for --resume). """ - if skip_strip: - # Caller guarantees the content is already clean and valid. - stripped = content - else: - # Strip metadata entries and stale thinking blocks in a single parse. - # SDK-built transcripts may have progress entries; strip for safety. - stripped = strip_for_upload(content) - if not skip_strip and not validate_transcript(stripped): - # Log entry types for debugging — helps identify why validation failed - entry_types = [ - json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?") - for line in stripped.strip().split("\n") - ] - logger.warning( - "%s Skipping upload — stripped content not valid " - "(types=%s, stripped_len=%d, raw_len=%d)", - log_prefix, - entry_types, - len(stripped), - len(content), - ) - logger.debug("%s Raw content preview: %s", log_prefix, content[:500]) - logger.debug("%s Stripped content: %s", log_prefix, stripped[:500]) - return - storage = await get_workspace_storage() - wid, fid, fname = _storage_path_parts(user_id, session_id) - encoded = stripped.encode("utf-8") - meta = {"message_count": message_count, "uploaded_at": time.time()} - mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id) + wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id) + mwid, mfid, mfname = _cli_session_meta_path_parts(user_id, session_id) + meta = {"message_count": message_count, "mode": mode, "uploaded_at": time.time()} meta_encoded = 
json.dumps(meta).encode("utf-8") - # Transcript + metadata are independent objects at different keys, so - # write them concurrently. ``return_exceptions`` keeps a metadata - # failure from sinking the transcript write. - transcript_result, metadata_result = await asyncio.gather( - storage.store( - workspace_id=wid, - file_id=fid, - filename=fname, - content=encoded, - ), - storage.store( - workspace_id=mwid, - file_id=mfid, - filename=mfname, - content=meta_encoded, - ), - return_exceptions=True, - ) - if isinstance(transcript_result, BaseException): - raise transcript_result - if isinstance(metadata_result, BaseException): - # Metadata is best-effort — the gap-fill logic in - # _build_query_message tolerates a missing metadata file. - logger.warning("%s Failed to write metadata: %s", log_prefix, metadata_result) + # Write JSONL first, meta second — sequential so a crash between the two + # leaves an orphaned JSONL (no meta) rather than an orphaned meta (wrong + # watermark / mode paired with stale or absent content). + # On any failure we roll back the other file so the pair is always absent + # together; download_transcript returns None when either file is missing. + try: + await storage.store( + workspace_id=wid, file_id=fid, filename=fname, content=content + ) + except Exception as session_err: + logger.warning( + "%s Failed to upload CLI session file: %s", log_prefix, session_err + ) + return + + try: + await storage.store( + workspace_id=mwid, file_id=mfid, filename=mfname, content=meta_encoded + ) + except Exception as meta_err: + logger.warning("%s Failed to upload CLI session meta: %s", log_prefix, meta_err) + # Roll back the JSONL so neither file exists — avoids orphaned JSONL being + # used with wrong mode/watermark defaults on the next restore. 
+ try: + session_path = _build_path_from_parts( + _cli_session_storage_path_parts(user_id, session_id), storage + ) + await storage.delete(session_path) + except Exception as rollback_err: + logger.debug( + "%s Session rollback failed (harmless — download will return None): %s", + log_prefix, + rollback_err, + ) + return logger.info( - "%s Uploaded %dB (stripped from %dB, msg_count=%d)", + "%s Uploaded CLI session (%dB, msg_count=%d, mode=%s)", log_prefix, - len(encoded), len(content), message_count, + mode, ) @@ -900,83 +743,173 @@ async def download_transcript( session_id: str, log_prefix: str = "[Transcript]", ) -> TranscriptDownload | None: - """Download transcript and metadata from bucket storage. + """Download CLI session from GCS. Returns content + message_count + mode, or None if not found. - Returns a ``TranscriptDownload`` with the JSONL content and the - ``message_count`` watermark from the upload, or ``None`` if not found. + Pure GCS operation — no disk I/O. The caller is responsible for writing + content to disk if --resume is needed. - The content and metadata fetches run concurrently since they are - independent objects in the bucket. + Returns a TranscriptDownload with the raw content, message_count watermark, + and mode on success, or None if not available (first turn or upload failed). 
""" storage = await get_workspace_storage() - path = _build_storage_path(user_id, session_id, storage) - meta_path = _build_meta_storage_path(user_id, session_id, storage) + path = _build_path_from_parts( + _cli_session_storage_path_parts(user_id, session_id), storage + ) + meta_path = _build_path_from_parts( + _cli_session_meta_path_parts(user_id, session_id), storage + ) - content_task = asyncio.create_task(storage.retrieve(path)) - meta_task = asyncio.create_task(storage.retrieve(meta_path)) content_result, meta_result = await asyncio.gather( - content_task, meta_task, return_exceptions=True + storage.retrieve(path), + storage.retrieve(meta_path), + return_exceptions=True, ) if isinstance(content_result, FileNotFoundError): - logger.debug("%s No transcript in storage", log_prefix) + logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix) return None if isinstance(content_result, BaseException): logger.warning( - "%s Failed to download transcript: %s", log_prefix, content_result + "%s Failed to download CLI session: %s", log_prefix, content_result ) return None - content = content_result.decode("utf-8") + content: bytes = content_result - # Metadata is best-effort — old transcripts won't have it. + # Parse message_count and mode from companion meta — best-effort, defaults. 
message_count = 0 - uploaded_at = 0.0 + mode: TranscriptMode = "sdk" if isinstance(meta_result, FileNotFoundError): - pass # No metadata — treat as unknown (msg_count=0 → always fill gap) + pass # No meta — old upload; default to "sdk" elif isinstance(meta_result, BaseException): - logger.debug( - "%s Failed to load transcript metadata: %s", log_prefix, meta_result - ) + logger.debug("%s Failed to load CLI session meta: %s", log_prefix, meta_result) else: - meta = json.loads(meta_result.decode("utf-8"), fallback={}) - message_count = meta.get("message_count", 0) - uploaded_at = meta.get("uploaded_at", 0.0) + try: + meta_str = meta_result.decode("utf-8") + except UnicodeDecodeError: + logger.debug("%s CLI session meta is not valid UTF-8, ignoring", log_prefix) + meta_str = None + if meta_str is not None: + meta = json.loads(meta_str, fallback={}) + if isinstance(meta, dict): + raw_count = meta.get("message_count", 0) + message_count = ( + raw_count if isinstance(raw_count, int) and raw_count >= 0 else 0 + ) + raw_mode = meta.get("mode", "sdk") + mode = raw_mode if raw_mode in ("sdk", "baseline") else "sdk" logger.info( - "%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count - ) - return TranscriptDownload( - content=content, - message_count=message_count, - uploaded_at=uploaded_at, + "%s Downloaded CLI session (%dB, msg_count=%d, mode=%s)", + log_prefix, + len(content), + message_count, + mode, ) + return TranscriptDownload(content=content, message_count=message_count, mode=mode) + + +def detect_gap( + download: TranscriptDownload, + session_messages: list[ChatMessage], +) -> list[ChatMessage]: + """Return chat-db messages after the transcript watermark (excluding current user turn). + + Returns [] if transcript is current, watermark is zero, or the watermark + position doesn't end on an assistant turn (misaligned watermark). 
+ """ + if download.message_count == 0: + return [] + wm = download.message_count + total = len(session_messages) + if wm >= total - 1: + return [] + # Sanity: position wm-1 should be an assistant turn; misaligned watermark + # means the DB messages shifted (e.g. deletion) — skip gap to avoid wrong context. + # In normal operation ``message_count`` is always written after a complete + # user→assistant exchange (never mid-turn), so the last covered position is + # always assistant. This guard fires only on data corruption or message deletion. + if session_messages[wm - 1].role != "assistant": + return [] + return list(session_messages[wm : total - 1]) + + +def extract_context_messages( + download: TranscriptDownload | None, + session_messages: "list[ChatMessage]", +) -> "list[ChatMessage]": + """Return context messages for the current turn: transcript content + gap. + + This is the shared context primitive used by both the SDK path + (``use_resume=False`` → ``<conversation_history>`` injection) and the + baseline path (OpenAI messages array). + + How it works: + + - When a transcript exists, ``TranscriptBuilder.load_previous`` preserves + ``isCompactSummary=True`` compaction entries, so the returned messages + mirror the compacted context the CLI would see via ``--resume``. + - The gap (DB messages after the transcript watermark) is always small in + normal operation; it only grows during mode switches or when an upload + was missed. + - Falls back to full DB messages when no transcript exists (first turn, + upload failure, or GCS unavailable). + - Returns *prior* messages only (excluding the current user turn at + ``session_messages[-1]``). Callers that need the current turn append + ``session_messages[-1]`` themselves. + - **Tool calls from transcript entries are flattened to text**: assistant + messages derived from the JSONL use ``_flatten_assistant_content``, which + serialises ``tool_use`` blocks as human-readable text rather than + structured ``tool_calls``. 
Gap messages (from DB) preserve their + original ``tool_calls`` field. This is the same trade-off as the old + ``_compress_session_messages(session.messages)`` approach — no regression. + + Args: + download: The ``TranscriptDownload`` from GCS, or ``None`` when no + transcript is available. ``content`` may be either ``bytes`` or + ``str`` (the baseline path decodes + strips before returning). + session_messages: All messages in the session, with the current user + turn as the last element. + + Returns: + A list of ``ChatMessage`` objects covering the prior conversation + context, suitable for injection as conversation history. + """ + from .model import ChatMessage as _ChatMessage # runtime import + + prior = session_messages[:-1] + + if download is None: + return prior + + raw_content = download.content + if not raw_content: + return prior + + # Handle both bytes (raw GCS download) and str (pre-decoded baseline path). + if isinstance(raw_content, bytes): + try: + content_str: str = raw_content.decode("utf-8") + except UnicodeDecodeError: + return prior + else: + content_str = raw_content + + raw = _transcript_to_messages(content_str) + if not raw: + return prior + + transcript_msgs = [ + _ChatMessage(role=m["role"], content=m.get("content") or "") for m in raw + ] + gap = detect_gap(download, session_messages) + return transcript_msgs + gap async def delete_transcript(user_id: str, session_id: str) -> None: - """Delete transcript and its metadata from bucket storage. - - Removes both the ``.jsonl`` transcript and the companion ``.meta.json`` - so stale ``message_count`` watermarks cannot corrupt gap-fill logic. 
- """ + """Delete CLI session JSONL and its companion .meta.json from bucket storage.""" storage = await get_workspace_storage() - path = _build_storage_path(user_id, session_id, storage) - try: - await storage.delete(path) - logger.info("[Transcript] Deleted transcript for session %s", session_id) - except Exception as e: - logger.warning("[Transcript] Failed to delete transcript: %s", e) - - # Also delete the companion .meta.json to avoid orphaned metadata. - try: - meta_path = _build_meta_storage_path(user_id, session_id, storage) - await storage.delete(meta_path) - logger.info("[Transcript] Deleted metadata for session %s", session_id) - except Exception as e: - logger.warning("[Transcript] Failed to delete metadata: %s", e) - - # Also delete the CLI native session file to prevent storage growth. try: cli_path = _build_path_from_parts( _cli_session_storage_path_parts(user_id, session_id), storage @@ -986,6 +919,15 @@ async def delete_transcript(user_id: str, session_id: str) -> None: except Exception as e: logger.warning("[Transcript] Failed to delete CLI session: %s", e) + try: + cli_meta_path = _build_path_from_parts( + _cli_session_meta_path_parts(user_id, session_id), storage + ) + await storage.delete(cli_meta_path) + logger.info("[Transcript] Deleted CLI session meta for session %s", session_id) + except Exception as e: + logger.warning("[Transcript] Failed to delete CLI session meta: %s", e) + # --------------------------------------------------------------------------- # Transcript compaction — LLM summarization for prompt-too-long recovery diff --git a/autogpt_platform/backend/backend/copilot/transcript_test.py b/autogpt_platform/backend/backend/copilot/transcript_test.py index fec869b6ac..15ed9662be 100644 --- a/autogpt_platform/backend/backend/copilot/transcript_test.py +++ b/autogpt_platform/backend/backend/copilot/transcript_test.py @@ -16,11 +16,11 @@ from .transcript import ( _flatten_assistant_content, _flatten_tool_result_content, 
_messages_to_transcript, - _meta_storage_path_parts, _rechain_tail, _sanitize_id, - _storage_path_parts, _transcript_to_messages, + detect_gap, + extract_context_messages, strip_for_upload, validate_transcript, ) @@ -64,24 +64,6 @@ class TestSanitizeId: assert _sanitize_id("!@#$%^&*()") == "unknown" -# --------------------------------------------------------------------------- -# _storage_path_parts / _meta_storage_path_parts -# --------------------------------------------------------------------------- - - -class TestStoragePathParts: - def test_returns_triple(self): - prefix, uid, fname = _storage_path_parts("user-1", "sess-2") - assert prefix == "chat-transcripts" - assert "e" in uid # hex chars from "user-1" sanitized - assert fname.endswith(".jsonl") - - def test_meta_returns_meta_json(self): - prefix, _, fname = _meta_storage_path_parts("user-1", "sess-2") - assert prefix == "chat-transcripts" - assert fname.endswith(".meta.json") - - # --------------------------------------------------------------------------- # _build_path_from_parts # --------------------------------------------------------------------------- @@ -103,24 +85,6 @@ class TestBuildPathFromParts: assert path == "local://wid/fid/file.jsonl" -# --------------------------------------------------------------------------- -# TranscriptDownload dataclass -# --------------------------------------------------------------------------- - - -class TestTranscriptDownload: - def test_defaults(self): - td = TranscriptDownload(content="hello") - assert td.content == "hello" - assert td.message_count == 0 - assert td.uploaded_at == 0.0 - - def test_custom_values(self): - td = TranscriptDownload(content="data", message_count=5, uploaded_at=123.45) - assert td.message_count == 5 - assert td.uploaded_at == 123.45 - - # --------------------------------------------------------------------------- # _flatten_assistant_content # --------------------------------------------------------------------------- @@ -733,118 
+697,241 @@ class TestValidateTranscript: class TestCliSessionPath: def test_encodes_slashes_to_dashes(self): - from .transcript import _cli_session_path, _projects_base + from .transcript import cli_session_path, projects_base sdk_cwd = "/tmp/copilot-abc" - result = _cli_session_path(sdk_cwd, "12345678-1234-1234-1234-123456789abc") - base = _projects_base() + result = cli_session_path(sdk_cwd, "12345678-1234-1234-1234-123456789abc") + base = projects_base() assert result.startswith(base) # Encoded cwd replaces '/' with '-' assert "-tmp-copilot-abc" in result assert result.endswith(".jsonl") def test_sanitizes_session_id(self): - from .transcript import _cli_session_path + from .transcript import cli_session_path - result = _cli_session_path("/tmp/cwd", "../../etc/passwd") + result = cli_session_path("/tmp/cwd", "../../etc/passwd") # _sanitize_id strips non-hex/hyphen chars; path traversal impossible assert ".." not in result assert "passwd" not in result class TestUploadCliSession: - def test_skips_upload_when_path_outside_projects_base(self, tmp_path): - """Files outside the CLI projects base are rejected without upload.""" + def test_uploads_content_bytes_successfully(self): + """Happy path: content bytes are stored as jsonl + meta.json.""" import asyncio from unittest.mock import AsyncMock, patch - from .transcript import upload_cli_session + from .transcript import upload_transcript mock_storage = AsyncMock() + content = b'{"type":"assistant"}\n' - with ( - patch( - "backend.copilot.transcript._projects_base", - return_value=str(tmp_path), - ), - # Return a path that is genuinely outside tmp_path so that - # realpath(session_file).startswith(projects_base + "/") is False - # and the boundary guard actually fires. 
- patch( - "backend.copilot.transcript._cli_session_path", - return_value="/outside/escaped/session.jsonl", - ), - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, ): asyncio.run( - upload_cli_session( + upload_transcript( user_id="user-1", - session_id="12345678-0000-0000-0000-000000000000", - sdk_cwd=str(tmp_path), + session_id="12345678-0000-0000-0000-000000000001", + content=content, ) ) - # storage.store must NOT be called — boundary guard should reject the path - mock_storage.store.assert_not_called() + # Two calls expected: session JSONL + companion .meta.json + assert mock_storage.store.call_count == 2 - def test_skips_upload_when_file_not_found(self, tmp_path): - """Missing CLI session file logs debug and skips upload silently.""" + def test_uploads_companion_meta_json_with_message_count(self): + """upload_transcript stores a companion .meta.json with message_count.""" + import asyncio + import json + from unittest.mock import AsyncMock, patch + + from .transcript import upload_transcript + + mock_storage = AsyncMock() + content = b'{"type":"assistant"}\n' + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + asyncio.run( + upload_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000010", + content=content, + message_count=5, + ) + ) + + assert mock_storage.store.call_count == 2 + # Find the meta.json store call + meta_call = next( + c + for c in mock_storage.store.call_args_list + if c.kwargs.get("filename", "").endswith(".meta.json") + ) + meta_content = json.loads(meta_call.kwargs["content"]) + assert meta_content["message_count"] == 5 + + def test_skips_upload_on_storage_failure(self): + """Storage exception on jsonl write is logged and does not propagate. 
+ + With sequential writes, JSONL failure returns early — meta store is + never called, so no rollback is needed. + """ import asyncio from unittest.mock import AsyncMock, patch - from .transcript import upload_cli_session + from .transcript import upload_transcript mock_storage = AsyncMock() - projects_base = str(tmp_path) + mock_storage.store.side_effect = RuntimeError("gcs unavailable") + content = b'{"type":"assistant"}\n' - with ( - patch( - "backend.copilot.transcript._projects_base", - return_value=projects_base, - ), - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, ): - # session file doesn't exist — should not raise + # Should not raise — failures are logged as warnings asyncio.run( - upload_cli_session( + upload_transcript( user_id="user-1", - session_id="12345678-0000-0000-0000-000000000000", - sdk_cwd=str(tmp_path), + session_id="12345678-0000-0000-0000-000000000002", + content=content, ) ) - mock_storage.store.assert_not_called() + # Only one store call attempted (the JSONL); meta never reached + mock_storage.store.assert_called_once() + mock_storage.delete.assert_not_called() - def test_uploads_file_successfully(self, tmp_path): - """Happy path: session file exists within projects base → upload called.""" + def test_rolls_back_session_when_meta_upload_fails(self): + """When meta upload fails after JSONL succeeds, JSONL is rolled back. + + Guarantees the pair is either both present or both absent — avoids an + orphaned JSONL being used with wrong mode/watermark defaults. 
+ """ import asyncio from unittest.mock import AsyncMock, patch + from .transcript import upload_transcript + + mock_storage = AsyncMock() + # First store (JSONL) succeeds; second store (meta) fails + mock_storage.store.side_effect = [None, RuntimeError("meta write failed")] + content = b'{"type":"assistant"}\n' + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + asyncio.run( + upload_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000099", + content=content, + ) + ) + + # Both store calls were attempted (JSONL then meta) + assert mock_storage.store.call_count == 2 + # JSONL should be rolled back via delete + mock_storage.delete.assert_called_once() + + def test_baseline_mode_stored_in_meta(self): + """upload_transcript with mode='baseline' stores mode in companion meta.json.""" + import asyncio + import json + from unittest.mock import AsyncMock, patch + + from .transcript import upload_transcript + + mock_storage = AsyncMock() + content = b'{"type":"assistant"}\n' + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + asyncio.run( + upload_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000098", + content=content, + message_count=4, + mode="baseline", + ) + ) + + meta_call = next( + c + for c in mock_storage.store.call_args_list + if c.kwargs.get("filename", "").endswith(".meta.json") + ) + meta_content = json.loads(meta_call.kwargs["content"]) + assert meta_content["mode"] == "baseline" + assert meta_content["message_count"] == 4 + + def test_strips_session_before_upload_and_writes_back(self, tmp_path): + """Strippable entries (progress, thinking blocks) are removed before upload. + + The stripped content is written back to disk (so same-pod turns benefit) + and the smaller bytes are uploaded to GCS. 
+ """ + import asyncio + import os + import re + from unittest.mock import AsyncMock, patch + from .transcript import _sanitize_id, upload_cli_session projects_base = str(tmp_path) - session_id = "12345678-0000-0000-0000-000000000001" + session_id = "12345678-0000-0000-0000-000000000010" sdk_cwd = str(tmp_path) - # Build the path the same way _cli_session_path does, but using our tmp_path - # as projects_base so the boundary check passes. - # Must use the same encoding: re.sub non-alphanumeric → "-" on realpath. - import os - import re - encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) session_dir = tmp_path / encoded_cwd session_dir.mkdir(parents=True, exist_ok=True) session_file = session_dir / f"{_sanitize_id(session_id)}.jsonl" - session_file.write_bytes(b'{"type":"assistant"}\n') + + # A CLI session with a progress entry (strippable) and a real assistant message. + import json + + progress_entry = { + "type": "progress", + "uuid": "p1", + "parentUuid": "u1", + "data": {"type": "bash_progress", "stdout": "running..."}, + } + user_entry = { + "type": "user", + "uuid": "u1", + "message": {"role": "user", "content": "hello"}, + } + asst_entry = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "u1", + "message": {"role": "assistant", "content": "world"}, + } + raw_content = ( + json.dumps(progress_entry) + + "\n" + + json.dumps(user_entry) + + "\n" + + json.dumps(asst_entry) + + "\n" + ) + raw_bytes = raw_content.encode("utf-8") + session_file.write_bytes(raw_bytes) mock_storage = AsyncMock() @@ -867,68 +954,142 @@ class TestUploadCliSession: ) ) + # Upload should have been called with stripped bytes (no progress entry). 
mock_storage.store.assert_called_once() + stored_content: bytes = mock_storage.store.call_args.kwargs["content"] + stored_lines = stored_content.decode("utf-8").strip().split("\n") + stored_types = [json.loads(line).get("type") for line in stored_lines] + assert "progress" not in stored_types + assert "user" in stored_types + assert "assistant" in stored_types + # Stripped bytes should be smaller than raw. + assert len(stored_content) < len(raw_bytes) + # File on disk should also be the stripped version. + disk_content = session_file.read_bytes() + assert disk_content == stored_content - def test_skips_upload_on_oserror(self, tmp_path): - """OSError reading session file is logged as warning; upload is skipped.""" + def test_strips_stale_thinking_blocks_before_upload(self, tmp_path): + """Thinking blocks in non-last assistant turns are stripped to reduce size.""" import asyncio + import json + import os + import re from unittest.mock import AsyncMock, patch from .transcript import _sanitize_id, upload_cli_session projects_base = str(tmp_path) + session_id = "12345678-0000-0000-0000-000000000011" sdk_cwd = str(tmp_path) - session_id = "12345678-0000-0000-0000-000000000002" - - # Build file at a path inside projects_base so boundary check passes. - import os - import re encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) session_dir = tmp_path / encoded_cwd session_dir.mkdir(parents=True, exist_ok=True) session_file = session_dir / f"{_sanitize_id(session_id)}.jsonl" - session_file.write_bytes(b'{"type":"assistant"}\n') - # Remove read permission to trigger OSError - session_file.chmod(0o000) + + # Two turns: first assistant has thinking block (stale), second doesn't. 
+ u1 = { + "type": "user", + "uuid": "u1", + "message": {"role": "user", "content": "q1"}, + } + a1_with_thinking = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "u1", + "message": { + "id": "msg_a1", + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "A" * 5000}, + {"type": "text", "text": "answer1"}, + ], + }, + } + u2 = { + "type": "user", + "uuid": "u2", + "parentUuid": "a1", + "message": {"role": "user", "content": "q2"}, + } + a2_no_thinking = { + "type": "assistant", + "uuid": "a2", + "parentUuid": "u2", + "message": { + "id": "msg_a2", + "role": "assistant", + "content": [{"type": "text", "text": "answer2"}], + }, + } + raw_content = ( + json.dumps(u1) + + "\n" + + json.dumps(a1_with_thinking) + + "\n" + + json.dumps(u2) + + "\n" + + json.dumps(a2_no_thinking) + + "\n" + ) + raw_bytes = raw_content.encode("utf-8") + session_file.write_bytes(raw_bytes) mock_storage = AsyncMock() - try: - with ( - patch( - "backend.copilot.transcript._projects_base", - return_value=projects_base, - ), - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), - ): - asyncio.run( - upload_cli_session( - user_id="user-1", - session_id=session_id, - sdk_cwd=sdk_cwd, - ) + with ( + patch( + "backend.copilot.transcript._projects_base", + return_value=projects_base, + ), + patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ), + ): + asyncio.run( + upload_cli_session( + user_id="user-1", + session_id=session_id, + sdk_cwd=sdk_cwd, ) - finally: - session_file.chmod(0o644) # restore so tmp_path cleanup works + ) - mock_storage.store.assert_not_called() + stored_content: bytes = mock_storage.store.call_args.kwargs["content"] + stored_lines = stored_content.decode("utf-8").strip().split("\n") + + # a1 should have its thinking block stripped (it's not the last assistant turn). 
+ a1_stored = json.loads(stored_lines[1]) + a1_content = a1_stored["message"]["content"] + assert all( + b["type"] != "thinking" for b in a1_content + ), "stale thinking block should be stripped from a1" + assert any( + b["type"] == "text" for b in a1_content + ), "text block should be kept in a1" + + # a2 (last turn) should be unchanged. + a2_stored = json.loads(stored_lines[3]) + assert a2_stored["message"]["content"] == [{"type": "text", "text": "answer2"}] + + # Stripped bytes smaller than raw. + assert len(stored_content) < len(raw_bytes) class TestRestoreCliSession: - def test_returns_false_when_file_not_found_in_storage(self): - """Returns False (graceful degradation) when the session is missing.""" + def test_returns_none_when_file_not_found_in_storage(self): + """Returns None (graceful degradation) when the session is missing.""" import asyncio from unittest.mock import AsyncMock, patch - from .transcript import restore_cli_session + from .transcript import download_transcript mock_storage = AsyncMock() - mock_storage.retrieve.side_effect = FileNotFoundError("not found") + mock_storage.retrieve.side_effect = [ + FileNotFoundError("no session"), + FileNotFoundError("no meta"), + ] with patch( "backend.copilot.transcript.get_workspace_storage", @@ -936,144 +1097,26 @@ class TestRestoreCliSession: return_value=mock_storage, ): result = asyncio.run( - restore_cli_session( + download_transcript( user_id="user-1", session_id="12345678-0000-0000-0000-000000000000", - sdk_cwd="/tmp/copilot-test", ) ) - assert result is False + assert result is None - def test_returns_false_when_restore_path_outside_projects_base(self, tmp_path): - """Path traversal guard: rejects restoration outside the projects base.""" + def test_returns_transcript_download_on_success_no_meta(self): + """Happy path with no meta.json: returns TranscriptDownload with message_count=0.""" import asyncio from unittest.mock import AsyncMock, patch - from .transcript import restore_cli_session + from 
.transcript import download_transcript - mock_storage = AsyncMock() - mock_storage.retrieve.return_value = b'{"type":"assistant"}\n' - - with ( - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), - patch( - "backend.copilot.transcript._projects_base", - return_value=str(tmp_path), - ), - # Return a path genuinely outside tmp_path so the boundary guard fires. - patch( - "backend.copilot.transcript._cli_session_path", - return_value="/outside/escaped/session.jsonl", - ), - ): - result = asyncio.run( - restore_cli_session( - user_id="user-1", - session_id="12345678-0000-0000-0000-000000000000", - sdk_cwd=str(tmp_path), - ) - ) - - assert result is False - - def test_returns_true_when_local_file_already_exists(self, tmp_path): - """Same-pod reuse: if local file exists, skip storage download and return True.""" - import asyncio - import os - import re - from pathlib import Path - from unittest.mock import AsyncMock, patch - - from .transcript import restore_cli_session - - session_id = "12345678-0000-0000-0000-000000000099" - sdk_cwd = str(tmp_path) - - # Pre-create the local session file (simulates previous turn on same pod) - projects_base = os.path.realpath(str(tmp_path)) - encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", projects_base) - session_dir = Path(projects_base) / encoded_cwd - session_dir.mkdir(parents=True, exist_ok=True) - existing_content = b'{"type":"user"}\n{"type":"assistant"}\n' - (session_dir / f"{session_id}.jsonl").write_bytes(existing_content) - - mock_storage = AsyncMock() - - with ( - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), - patch( - "backend.copilot.transcript._projects_base", - return_value=projects_base, - ), - ): - result = asyncio.run( - restore_cli_session( - user_id="user-1", - session_id=session_id, - sdk_cwd=sdk_cwd, - ) - ) - - assert result is True - # Storage should NOT have been 
accessed (local file was used as-is) - mock_storage.retrieve.assert_not_called() - # Local file should be unchanged - assert (session_dir / f"{session_id}.jsonl").read_bytes() == existing_content - - def test_returns_true_on_success(self, tmp_path): - """Happy path: storage has the session → file written → returns True.""" - import asyncio - from unittest.mock import AsyncMock, patch - - from .transcript import restore_cli_session - - projects_base = str(tmp_path) - sdk_cwd = str(tmp_path) session_id = "12345678-0000-0000-0000-000000000003" content = b'{"type":"assistant"}\n' mock_storage = AsyncMock() - mock_storage.retrieve.return_value = content - - with ( - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), - patch( - "backend.copilot.transcript._projects_base", - return_value=projects_base, - ), - ): - result = asyncio.run( - restore_cli_session( - user_id="user-1", - session_id=session_id, - sdk_cwd=sdk_cwd, - ) - ) - - assert result is True - - def test_returns_false_on_download_exception(self): - """Non-FileNotFoundError during retrieve logs warning and returns False.""" - import asyncio - from unittest.mock import AsyncMock, patch - - from .transcript import restore_cli_session - - mock_storage = AsyncMock() - mock_storage.retrieve.side_effect = RuntimeError("network error") + mock_storage.retrieve.side_effect = [content, FileNotFoundError("no meta")] with patch( "backend.copilot.transcript.get_workspace_storage", @@ -1081,11 +1124,411 @@ class TestRestoreCliSession: return_value=mock_storage, ): result = asyncio.run( - restore_cli_session( + download_transcript( user_id="user-1", - session_id="12345678-0000-0000-0000-000000000004", - sdk_cwd="/tmp/copilot-test", + session_id=session_id, ) ) - assert result is False + assert isinstance(result, TranscriptDownload) + assert result.content == content + assert result.message_count == 0 + assert result.mode == "sdk" + + def 
test_returns_transcript_download_with_message_count_from_meta(self): + """When meta.json is present, message_count and mode are read from it.""" + import asyncio + import json + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + session_id = "12345678-0000-0000-0000-000000000005" + content = b'{"type":"assistant"}\n' + meta_bytes = json.dumps( + {"message_count": 7, "mode": "sdk", "uploaded_at": 1234567.0} + ).encode() + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [content, meta_bytes] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id=session_id, + ) + ) + + assert isinstance(result, TranscriptDownload) + assert result.content == content + assert result.message_count == 7 + assert result.mode == "sdk" + + def test_returns_none_on_download_exception(self): + """Non-FileNotFoundError during retrieve logs warning and returns None.""" + import asyncio + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [ + RuntimeError("network error"), + FileNotFoundError("no meta"), + ] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000004", + ) + ) + + assert result is None + + def test_baseline_mode_in_meta_returned(self): + """When meta.json contains mode='baseline', result.mode is 'baseline'.""" + import asyncio + import json + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + content = b'{"type":"assistant"}\n' + meta_bytes = json.dumps( + {"message_count": 3, "mode": "baseline", "uploaded_at": 0.0} + ).encode() + 
+ mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [content, meta_bytes] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000020", + ) + ) + + assert isinstance(result, TranscriptDownload) + assert result.mode == "baseline" + assert result.message_count == 3 + + def test_invalid_mode_in_meta_defaults_to_sdk(self): + """Unknown mode value in meta.json falls back to 'sdk'.""" + import asyncio + import json + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + content = b'{"type":"assistant"}\n' + meta_bytes = json.dumps({"message_count": 2, "mode": "unknown_mode"}).encode() + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [content, meta_bytes] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000021", + ) + ) + + assert isinstance(result, TranscriptDownload) + assert result.mode == "sdk" + + def test_invalid_utf8_meta_uses_defaults(self): + """Meta bytes that fail UTF-8 decode fall back to message_count=0, mode='sdk'.""" + import asyncio + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + content = b'{"type":"assistant"}\n' + bad_meta = b"\xff\xfe" + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [content, bad_meta] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000022", + ) + ) + + assert isinstance(result, TranscriptDownload) + assert result.message_count == 
0 + assert result.mode == "sdk" + + def test_meta_fetch_exception_uses_defaults(self): + """Non-FileNotFoundError on meta fetch still returns content with defaults.""" + import asyncio + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + content = b'{"type":"assistant"}\n' + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [content, RuntimeError("meta unavailable")] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000023", + ) + ) + + assert isinstance(result, TranscriptDownload) + assert result.content == content + assert result.message_count == 0 + assert result.mode == "sdk" + + +# --------------------------------------------------------------------------- +# detect_gap +# --------------------------------------------------------------------------- + + +def _msgs(*roles: str): + """Build a list of ChatMessage objects with the given roles.""" + from .model import ChatMessage + + return [ChatMessage(role=r, content=f"{r}-{i}") for i, r in enumerate(roles)] + + +class TestDetectGap: + """``detect_gap`` returns messages between transcript watermark and current turn.""" + + def _dl(self, message_count: int) -> TranscriptDownload: + return TranscriptDownload(content=b"", message_count=message_count, mode="sdk") + + def test_zero_watermark_returns_empty(self): + """message_count=0 means no watermark — skip gap detection.""" + dl = self._dl(0) + messages = _msgs("user", "assistant", "user") + assert detect_gap(dl, messages) == [] + + def test_watermark_covers_all_prefix_returns_empty(self): + """Transcript already covers all messages up to the current user turn.""" + # session: [user, assistant, user(current)] — wm=2 means covers up to assistant + dl = self._dl(2) + messages = _msgs("user", "assistant", "user") + assert 
detect_gap(dl, messages) == [] + + def test_watermark_exceeds_session_returns_empty(self): + """Watermark ahead of session count (race / over-count) → no gap.""" + dl = self._dl(10) + messages = _msgs("user", "assistant", "user") + assert detect_gap(dl, messages) == [] + + def test_misaligned_watermark_not_on_assistant_returns_empty(self): + """Watermark at a user-role position is misaligned — skip gap.""" + # wm=1: position 0 is 'user', not 'assistant' → skip + dl = self._dl(1) + messages = _msgs("user", "assistant", "user", "assistant", "user") + assert detect_gap(dl, messages) == [] + + def test_returns_gap_messages(self): + """Watermark behind session — gap messages returned (excluding current turn).""" + # session: [user0, assistant1, user2, assistant3, user4(current)] + # wm=2: transcript covers [0,1]; gap = [user2, assistant3] + dl = self._dl(2) + messages = _msgs("user", "assistant", "user", "assistant", "user") + gap = detect_gap(dl, messages) + assert len(gap) == 2 + assert gap[0].role == "user" + assert gap[1].role == "assistant" + + def test_excludes_current_user_turn(self): + """The last message (current user turn) is never included in the gap.""" + # wm=2, session has 4 msgs: gap = [msg2] only (msg3 is current turn → excluded) + dl = self._dl(2) + messages = _msgs("user", "assistant", "user", "user") + gap = detect_gap(dl, messages) + assert len(gap) == 1 + assert gap[0].role == "user" + + def test_single_gap_message(self): + """One message between watermark and current turn.""" + # session: [user0, assistant1, assistant2, user3(current)] + # wm=2: transcript covers [user0, assistant1]; position wm-1=1 is an
'assistant', so the watermark is aligned. + # Gap = everything after the watermark and before the current user turn: [assistant2] only. + dl = self._dl(2) + messages = _msgs("user", "assistant", "assistant", "user") + gap = detect_gap(dl, messages) + assert len(gap) == 1 + assert gap[0].role == "assistant" + + +# --------------------------------------------------------------------------- +# extract_context_messages +# --------------------------------------------------------------------------- + + +def _make_valid_transcript(*roles: str) -> str: + """Build a minimal valid JSONL transcript with the given message roles.""" + import json as stdlib_json + + from .transcript import STOP_REASON_END_TURN + + lines = [] + parent = "" + for i, role in enumerate(roles): + uid = f"uid-{i}" + entry: dict = { + "type": role, + "uuid": uid, + "parentUuid": parent, + "message": { + "role": role, + "content": f"{role} content {i}", + }, + } + if role == "assistant": + entry["message"]["id"] = f"msg_{i}" + entry["message"]["model"] = "test-model" + entry["message"]["type"] = "message" + entry["message"]["stop_reason"] = STOP_REASON_END_TURN + entry["message"]["content"] = [ + {"type": "text", "text": f"assistant content {i}"} + ] + lines.append(stdlib_json.dumps(entry)) + parent = uid + return "\n".join(lines) + "\n" + + +class TestExtractContextMessages: + """``extract_context_messages`` returns the shared context primitive.""" + + def test_none_download_returns_prior(self): + """No download → falls back to all session messages except current turn.""" + messages = _msgs("user", "assistant", "user") + result = extract_context_messages(None, messages) + assert result == messages[:-1] + assert len(result) == 2 + + def test_empty_content_download_returns_prior(self): + """Empty bytes content → falls back to all prior messages.""" + dl = TranscriptDownload(content=b"", message_count=2, mode="sdk") + messages = _msgs("user", "assistant", "user") + result = extract_context_messages(dl, messages) + assert result 
== messages[:-1] + + def test_valid_transcript_no_gap_returns_transcript_messages(self): + """Transcript covers all prior turns → only transcript messages returned.""" + # Transcript: [user, assistant] — 2 messages + # Session: [user, assistant, user(current)] — watermark=2 covers prefix + transcript_content = _make_valid_transcript("user", "assistant") + dl = TranscriptDownload( + content=transcript_content.encode("utf-8"), message_count=2, mode="sdk" + ) + messages = _msgs("user", "assistant", "user") + result = extract_context_messages(dl, messages) + # Transcript has 2 messages (user + assistant) and no gap + assert len(result) == 2 + assert result[0].role == "user" + assert result[1].role == "assistant" + + def test_valid_transcript_with_gap_returns_transcript_plus_gap(self): + """Transcript is stale → gap messages appended after transcript content.""" + # Transcript: [user, assistant] — watermark=2 + # Session: [user, assistant, user, assistant, user(current)] + # Gap: [user(2), assistant(3)] — positions 2 and 3 + transcript_content = _make_valid_transcript("user", "assistant") + dl = TranscriptDownload( + content=transcript_content.encode("utf-8"), message_count=2, mode="sdk" + ) + messages = _msgs("user", "assistant", "user", "assistant", "user") + result = extract_context_messages(dl, messages) + # 2 transcript messages + 2 gap messages = 4 + assert len(result) == 4 + assert result[0].role == "user" # transcript user + assert result[1].role == "assistant" # transcript assistant + assert result[2].role == "user" # gap user + assert result[3].role == "assistant" # gap assistant + + def test_compact_summary_entries_preserved(self): + """``isCompactSummary=True`` entries survive ``_transcript_to_messages``.""" + import json as stdlib_json + + from .transcript import STOP_REASON_END_TURN + + # Build a transcript where one entry is a compaction summary. + # isCompactSummary=True entries have type in STRIPPABLE_TYPES but are kept. 
+ compact_entry = stdlib_json.dumps( + { + "type": "summary", + "uuid": "uid-compact", + "parentUuid": "", + "isCompactSummary": True, + "message": { + "role": "user", + "content": "COMPACT_SUMMARY_CONTENT", + }, + } + ) + assistant_entry = stdlib_json.dumps( + { + "type": "assistant", + "uuid": "uid-1", + "parentUuid": "uid-compact", + "message": { + "role": "assistant", + "id": "msg_1", + "model": "test", + "type": "message", + "stop_reason": STOP_REASON_END_TURN, + "content": [{"type": "text", "text": "response after compact"}], + }, + } + ) + content = compact_entry + "\n" + assistant_entry + "\n" + dl = TranscriptDownload( + content=content.encode("utf-8"), message_count=2, mode="sdk" + ) + messages = _msgs("user", "assistant", "user") + result = extract_context_messages(dl, messages) + # Both the compact summary and the assistant response are present + assert len(result) == 2 + roles = [m.role for m in result] + assert "user" in roles # compact summary has role=user + assert "assistant" in roles + # The compact summary content is preserved + compact_msgs = [m for m in result if m.role == "user"] + assert any("COMPACT_SUMMARY_CONTENT" in (m.content or "") for m in compact_msgs) diff --git a/autogpt_platform/backend/backend/util/architecture_test.py b/autogpt_platform/backend/backend/util/architecture_test.py new file mode 100644 index 0000000000..b3cf457911 --- /dev/null +++ b/autogpt_platform/backend/backend/util/architecture_test.py @@ -0,0 +1,134 @@ +""" +Architectural tests for the backend package. + +Each rule here exists to prevent a *class* of bug, not to police style. +When adding a rule, document the incident or failure mode that motivated +it so future maintainers know whether the rule still earns its keep. +""" + +import ast +import pathlib + +BACKEND_ROOT = pathlib.Path(__file__).resolve().parents[1] + + +# --------------------------------------------------------------------------- +# Rule: no process-wide @cached(...) 
around event-loop-bound async clients +# --------------------------------------------------------------------------- +# +# Motivation: `backend.util.cache.cached` stores its result in a process-wide +# dict for ttl_seconds. Async clients (AsyncOpenAI, httpx.AsyncClient, +# AsyncRabbitMQ, supabase AClient, ...) wrap connection pools whose internal +# asyncio primitives lazily bind to the first event loop that uses them. The +# executor runs two long-lived loops on separate threads; once the cache is +# populated from loop A, any subsequent call from loop B raises +# `RuntimeError: ... bound to a different event loop`, surfaced as an opaque +# `APIConnectionError: Connection error.` and poisons the cache for a full +# TTL window. +# +# Use `per_loop_cached` (keyed on id(running loop)) or construct per-call. + +LOOP_BOUND_TYPES = frozenset( + { + "AsyncOpenAI", + "LangfuseAsyncOpenAI", + "AsyncClient", # httpx, openai internal + "AsyncRabbitMQ", + "AClient", # supabase async + "AsyncRedisExecutionEventBus", + } +) + +# Pre-existing offenders tracked for future cleanup. Exclude from this test +# so the rule can still catch NEW violations without blocking unrelated PRs. 
+_KNOWN_OFFENDERS = frozenset( + { + "util/clients.py get_async_supabase", + "util/clients.py get_openai_client", + } +) + + +def _decorator_name(node: ast.expr) -> str | None: + if isinstance(node, ast.Call): + return _decorator_name(node.func) + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Attribute): + return node.attr + return None + + +def _annotation_names(annotation: ast.expr | None) -> set[str]: + if annotation is None: + return set() + if isinstance(annotation, ast.Constant) and isinstance(annotation.value, str): + try: + parsed = ast.parse(annotation.value, mode="eval").body + except SyntaxError: + return set() + return _annotation_names(parsed) + names: set[str] = set() + for child in ast.walk(annotation): + if isinstance(child, ast.Name): + names.add(child.id) + elif isinstance(child, ast.Attribute): + names.add(child.attr) + return names + + +def _iter_backend_py_files(): + for path in BACKEND_ROOT.rglob("*.py"): + if "__pycache__" in path.parts: + continue + yield path + + +def test_known_offenders_use_posix_separators(): + """_KNOWN_OFFENDERS must use forward slashes since the comparison key + is built from pathlib.Path.relative_to() which uses OS-native separators. + On Windows this would be backslashes, causing false positives. + + Ensure the key construction normalises to forward slashes. + """ + for entry in _KNOWN_OFFENDERS: + path_part = entry.split()[0] + assert "\\" not in path_part, ( + f"_KNOWN_OFFENDERS entry uses backslash: {entry!r}. " + "Use forward slashes — the test should normalise Path separators." 
+ ) + + +def test_no_process_cached_loop_bound_clients(): + offenders: list[str] = [] + for py in _iter_backend_py_files(): + try: + tree = ast.parse(py.read_text(encoding="utf-8"), filename=str(py)) + except SyntaxError: + continue + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + decorators = {_decorator_name(d) for d in node.decorator_list} + if "cached" not in decorators: + continue + bound = _annotation_names(node.returns) & LOOP_BOUND_TYPES + if bound: + rel = py.relative_to(BACKEND_ROOT) + key = f"{rel.as_posix()} {node.name}" + if key in _KNOWN_OFFENDERS: + continue + offenders.append( + f"{rel}:{node.lineno} {node.name}() -> {sorted(bound)}" + ) + + assert not offenders, ( + "Process-wide @cached(...) must not wrap functions returning event-" + "loop-bound async clients. These objects lazily bind their connection " + "pool to the first event loop that uses them; caching them across " + "loops poisons the cache and surfaces as opaque connection errors.\n\n" + "Offenders:\n " + "\n ".join(offenders) + "\n\n" + "Fix: construct the client per-call, or introduce a per-loop factory " + "keyed on id(asyncio.get_running_loop()). See " + "backend/util/clients.py::get_openai_client for context." 
+ ) diff --git a/autogpt_platform/backend/scripts/download_transcripts.py b/autogpt_platform/backend/scripts/download_transcripts.py index 26204c3243..a9b32e8494 100644 --- a/autogpt_platform/backend/scripts/download_transcripts.py +++ b/autogpt_platform/backend/scripts/download_transcripts.py @@ -88,17 +88,19 @@ async def cmd_download(session_ids: list[str]) -> None: print(f"[{sid[:12]}] Not found in GCS") continue + content_str = ( + dl.content.decode("utf-8") if isinstance(dl.content, bytes) else dl.content + ) out = _transcript_path(sid) with open(out, "w") as f: - f.write(dl.content) + f.write(content_str) - lines = len(dl.content.strip().split("\n")) + lines = len(content_str.strip().split("\n")) meta = { "session_id": sid, "user_id": user_id, "message_count": dl.message_count, - "uploaded_at": dl.uploaded_at, - "transcript_bytes": len(dl.content), + "transcript_bytes": len(content_str), "transcript_lines": lines, } with open(_meta_path(sid), "w") as f: @@ -106,7 +108,7 @@ async def cmd_download(session_ids: list[str]) -> None: print( f"[{sid[:12]}] Saved: {lines} entries, " - f"{len(dl.content)} bytes, msg_count={dl.message_count}" + f"{len(content_str)} bytes, msg_count={dl.message_count}" ) print("\nDone. Run 'load' command to import into local dev environment.") @@ -227,7 +229,7 @@ async def cmd_load(session_ids: list[str]) -> None: await upload_transcript( user_id=user_id, session_id=sid, - content=content, + content=content.encode("utf-8"), message_count=msg_count, ) print(f"[{sid[:12]}] Stored transcript in local workspace storage") diff --git a/autogpt_platform/backend/test/copilot/test_transcript_watermark.py b/autogpt_platform/backend/test/copilot/test_transcript_watermark.py new file mode 100644 index 0000000000..bd88726339 --- /dev/null +++ b/autogpt_platform/backend/test/copilot/test_transcript_watermark.py @@ -0,0 +1,140 @@ +"""Unit tests for the transcript watermark (message_count) fix. 
+ +The bug: upload used message_count=len(session.messages) (DB count). When a +prior turn's GCS upload failed silently, the JSONL on GCS was stale (e.g. +covered only T1-T12) but the meta.json watermark matched the full DB count +(e.g. 46). The next turn's gap-fill check (transcript_msg_count < msg_count-1) +never triggered, so the model silently lost context for the skipped turns. + +The fix: watermark = previous_coverage + 2 (current user+asst pair) when +use_resume=True and transcript_msg_count > 0. This ensures the watermark +reflects the JSONL content, not the DB count. + +These tests exercise _build_query_message directly to verify that gap-fill +triggers with the corrected watermark but NOT with the inflated (buggy) one. +""" + +from unittest.mock import MagicMock + +import pytest + +from backend.copilot.sdk.service import _build_query_message + + +def _make_messages(n_pairs: int, *, current_user: str = "current") -> list[MagicMock]: + """Build a flat list of n_pairs*2 alternating user/asst messages, plus + one trailing user message for the *current* turn.""" + msgs: list[MagicMock] = [] + for i in range(n_pairs): + u = MagicMock() + u.role = "user" + u.content = f"user message {i}" + a = MagicMock() + a.role = "assistant" + a.content = f"assistant response {i}" + msgs.extend([u, a]) + # Current turn's user message + cur = MagicMock() + cur.role = "user" + cur.content = current_user + msgs.append(cur) + return msgs + + +def _make_session(messages: list[MagicMock]) -> MagicMock: + session = MagicMock() + session.messages = messages + return session + + +@pytest.mark.asyncio +async def test_gap_fill_triggers_for_stale_jsonl(): + """Scenario: T1-T12 in JSONL (watermark=24), DB has T1-T22+Test (46 msgs). + + With the FIX: 'Test' uploaded watermark=26 (T12's 24 + 2 for 'Test'). + Next turn (T24) downloads watermark=26, DB has 47. + Gap check: 26 < 47-1=46 → TRUE → gap fills T14-T23. 
+ """ + # T23 turns in DB (46 messages) + T24 user = 47 + msgs = _make_messages(23, current_user="memory test - recall all") + assert len(msgs) == 47 + + session = _make_session(msgs) + + # Watermark as uploaded by the FIX: T12 covered 24, 'Test' +2 = 26 + result_msg, _ = await _build_query_message( + current_message="memory test - recall all", + session=session, + use_resume=True, + transcript_msg_count=26, + session_id="test-session-id", + ) + + assert "<conversation_history>" in result_msg, ( + "Expected gap-fill to inject <conversation_history> when " + "watermark=26 < msg_count-1=46" + ) + + +@pytest.mark.asyncio +async def test_no_gap_fill_when_watermark_is_current(): + """When the JSONL is fully current (watermark = DB-1), no gap injected.""" + # T23 turns in DB (46 messages) + T24 user = 47 + msgs = _make_messages(23, current_user="next message") + session = _make_session(msgs) + + result_msg, _ = await _build_query_message( + current_message="next message", + session=session, + use_resume=True, + transcript_msg_count=46, # current — no gap + session_id="test-session-id", + ) + + assert ( + "<conversation_history>" not in result_msg + ), "No gap-fill expected when watermark is current" + assert result_msg == "next message" + + +@pytest.mark.asyncio +async def test_inflated_watermark_suppresses_gap_fill(): + """Documents the original bug: inflated watermark suppresses gap-fill. + + 'Test' uploaded watermark=len(session.messages)=46 even though only 26 + messages are in the JSONL. Next turn: 46 < 47-1=46 → FALSE → no gap fill. 
+ """ + msgs = _make_messages(23, current_user="memory test") + session = _make_session(msgs) + + # Buggy watermark: inflated to DB count + result_msg, _ = await _build_query_message( + current_message="memory test", + session=session, + use_resume=True, + transcript_msg_count=46, # inflated — suppresses gap fill + session_id="test-session-id", + ) + + assert ( + "<conversation_history>" not in result_msg + ), "With inflated watermark, gap-fill is suppressed — this documents the bug" + + +@pytest.mark.asyncio +async def test_fixed_watermark_fills_same_gap(): + """Same scenario but with the FIXED watermark triggers gap-fill.""" + msgs = _make_messages(23, current_user="memory test") + session = _make_session(msgs) + + result_msg, _ = await _build_query_message( + current_message="memory test", + session=session, + use_resume=True, + transcript_msg_count=26, # fixed watermark + session_id="test-session-id", + ) + + assert ( + "<conversation_history>" in result_msg + ), "With fixed watermark=26, gap-fill triggers and injects missing turns" diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx index 3a55fabf1d..186c8d96fe 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx @@ -110,7 +110,7 @@ export const Flow = () => { event.preventDefault(); }} maxZoom={2} - minZoom={0.1} + minZoom={0.05} onDragOver={onDragOver} onDrop={onDrop} nodesDraggable={!isLocked} diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 7ae74323ec..5ff9dd14df 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -1346,7 +1346,15 @@ { "$ref": 
"#/components/schemas/MCPToolsDiscoveredResponse" }, - { "$ref": "#/components/schemas/MCPToolOutputResponse" } + { "$ref": "#/components/schemas/MCPToolOutputResponse" }, + { "$ref": "#/components/schemas/MemoryStoreResponse" }, + { "$ref": "#/components/schemas/MemorySearchResponse" }, + { + "$ref": "#/components/schemas/MemoryForgetCandidatesResponse" + }, + { + "$ref": "#/components/schemas/MemoryForgetConfirmResponse" + } ], "title": "Response Getv2[Dummy] Tool Response Type Export For Codegen" } @@ -11543,6 +11551,103 @@ "title": "MarketplaceListingCreator", "description": "Creator information for a marketplace listing." }, + "MemoryForgetCandidatesResponse": { + "properties": { + "type": { + "$ref": "#/components/schemas/ResponseType", + "default": "memory_forget_candidates" + }, + "message": { "type": "string", "title": "Message" }, + "session_id": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Session Id" + }, + "candidates": { + "items": { + "additionalProperties": { "type": "string" }, + "type": "object" + }, + "type": "array", + "title": "Candidates" + } + }, + "type": "object", + "required": ["message"], + "title": "MemoryForgetCandidatesResponse", + "description": "Response with candidate memories to forget." + }, + "MemoryForgetConfirmResponse": { + "properties": { + "type": { + "$ref": "#/components/schemas/ResponseType", + "default": "memory_forget_confirm" + }, + "message": { "type": "string", "title": "Message" }, + "session_id": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Session Id" + }, + "deleted_uuids": { + "items": { "type": "string" }, + "type": "array", + "title": "Deleted Uuids" + }, + "failed_uuids": { + "items": { "type": "string" }, + "type": "array", + "title": "Failed Uuids" + } + }, + "type": "object", + "required": ["message"], + "title": "MemoryForgetConfirmResponse", + "description": "Response after deleting specific memory edges." 
+ }, + "MemorySearchResponse": { + "properties": { + "type": { + "$ref": "#/components/schemas/ResponseType", + "default": "memory_search" + }, + "message": { "type": "string", "title": "Message" }, + "session_id": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Session Id" + }, + "facts": { + "items": { "type": "string" }, + "type": "array", + "title": "Facts" + }, + "recent_episodes": { + "items": { "type": "string" }, + "type": "array", + "title": "Recent Episodes" + } + }, + "type": "object", + "required": ["message"], + "title": "MemorySearchResponse", + "description": "Response when memories are searched." + }, + "MemoryStoreResponse": { + "properties": { + "type": { + "$ref": "#/components/schemas/ResponseType", + "default": "memory_store" + }, + "message": { "type": "string", "title": "Message" }, + "session_id": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Session Id" + }, + "memory_name": { "type": "string", "title": "Memory Name" } + }, + "type": "object", + "required": ["message", "memory_name"], + "title": "MemoryStoreResponse", + "description": "Response when a memory is stored." + }, "Message": { "properties": { "query": { "type": "string", "title": "Query" }, @@ -12939,7 +13044,9 @@ "feature_request_search", "feature_request_created", "memory_store", - "memory_search" + "memory_search", + "memory_forget_candidates", + "memory_forget_confirm" ], "title": "ResponseType", "description": "Types of tool responses."