mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(copilot): workspace file tools, context reconstruction, transcript upload protection (#12164)
## Summary

- **Workspace file tools**: `write_workspace_file` now accepts plain text `content`, `source_path` (copy from ephemeral disk), and graceful fallback for invalid base64. `read_workspace_file` gains `save_to_path` to download workspace files to the ephemeral working directory. Both validate paths against session-specific ephemeral directory.
- **Context reconstruction**: `_format_conversation_context` now includes tool call summaries and tool results (not just user/assistant text), fixing agent amnesia when transcript is unavailable or stale.
- **Transcript upload protection**: Moved transcript upload from inside the inner `try` block to the `finally` block, ensuring it always runs even on streaming exceptions — prevents transcript loss that caused staleness on subsequent turns.
- **Agent inactivity timeout**: Configurable timeout (default 300s) kills hung Claude agents that stop producing SDK messages.
- **SDK system prompt**: Restructured with clear sections for shell commands, two storage systems, file transfer workflows, and long-running tools.
- **Path validation hardening**: `_validate_ephemeral_path` uses `os.path.realpath` for both session dir and target path, fixing macOS `/tmp` → `/private/tmp` symlink mismatch. Empty-string params normalised to `None` to prevent dispatch assertion failures.
## Test plan

- [x] `_format_conversation_context` — empty, user, assistant, tool calls, tool results, full conversation (query_builder_test.py)
- [x] `_build_query_message` — resume up-to-date, stale transcript gap, zero msg count, no resume single/multi (query_builder_test.py)
- [x] `_validate_ephemeral_path` — valid path, traversal, cross-session, symlink escape, nested (workspace_files_test.py)
- [x] `_resolve_write_content` — no sources, multiple sources, plain text, base64, invalid base64, source_path, not found, outside ephemeral, empty strings (workspace_files_test.py)
- [ ] Verify transcript upload occurs even after streaming error
- [ ] Verify agent inactivity timeout kills hung agents (300s default)

---

Co-authored-by: Otto (AGPT) <otto@agpt.co>
This commit is contained in:
@@ -0,0 +1,221 @@
|
||||
"""Tests for _format_conversation_context and _build_query_message."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.sdk.service import (
|
||||
_build_query_message,
|
||||
_format_conversation_context,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _format_conversation_context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_format_empty_list():
    """An empty message list produces no context string."""
    formatted = _format_conversation_context([])
    assert formatted is None
|
||||
|
||||
|
||||
def test_format_none_content_messages():
    """A single user message whose content is None produces no context."""
    history = [ChatMessage(role="user", content=None)]
    formatted = _format_conversation_context(history)
    assert formatted is None
|
||||
|
||||
|
||||
def test_format_user_message():
    """User text appears as ``User: ...`` wrapped in conversation_history tags."""
    history = [ChatMessage(role="user", content="hello")]
    formatted = _format_conversation_context(history)

    assert formatted is not None
    assert "User: hello" in formatted
    # The whole context must be enclosed by the opening and closing tags.
    opening, closing = "<conversation_history>", "</conversation_history>"
    assert formatted.startswith(opening)
    assert formatted.endswith(closing)
|
||||
|
||||
|
||||
def test_format_assistant_text():
    """Assistant text is rendered with the ``You responded:`` prefix."""
    history = [ChatMessage(role="assistant", content="hi there")]
    formatted = _format_conversation_context(history)

    assert formatted is not None
    assert "You responded: hi there" in formatted
|
||||
|
||||
|
||||
def test_format_assistant_tool_calls():
    """A content-less assistant message with a tool call is summarised as name(args)."""
    search_call = {"function": {"name": "search", "arguments": '{"q": "test"}'}}
    history = [
        ChatMessage(role="assistant", content=None, tool_calls=[search_call]),
    ]
    formatted = _format_conversation_context(history)

    assert formatted is not None
    assert 'You called tool: search({"q": "test"})' in formatted
|
||||
|
||||
|
||||
def test_format_tool_result():
    """Tool message content is rendered with the ``Tool result:`` prefix."""
    payload = '{"result": "ok"}'
    history = [ChatMessage(role="tool", content=payload)]
    formatted = _format_conversation_context(history)

    assert formatted is not None
    assert 'Tool result: {"result": "ok"}' in formatted
|
||||
|
||||
|
||||
def test_format_tool_result_none_content():
    """A tool message with no content still yields a ``Tool result: `` line."""
    history = [ChatMessage(role="tool", content=None)]
    formatted = _format_conversation_context(history)

    assert formatted is not None
    assert "Tool result: " in formatted
|
||||
|
||||
|
||||
def test_format_full_conversation():
    """A user / assistant+tool-call / tool / assistant exchange keeps every entry."""
    find_call = {"function": {"name": "find_agents", "arguments": '{"q": "test"}'}}
    history = [
        ChatMessage(role="user", content="find agents"),
        ChatMessage(
            role="assistant",
            content="I'll search for agents.",
            tool_calls=[find_call],
        ),
        ChatMessage(role="tool", content='[{"id": "1", "name": "Agent1"}]'),
        ChatMessage(role="assistant", content="Found Agent1."),
    ]
    formatted = _format_conversation_context(history)

    assert formatted is not None
    # Each kind of message must leave its own marker in the output.
    expected_fragments = (
        "User: find agents",
        "You responded: I'll search for agents.",
        "You called tool: find_agents",
        "Tool result:",
        "You responded: Found Agent1.",
    )
    for fragment in expected_fragments:
        assert fragment in formatted
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_query_message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session(messages: list[ChatMessage]) -> ChatSession:
    """Create a bare-bones ChatSession that wraps ``messages``.

    ``started_at`` and ``updated_at`` share one UTC timestamp taken at call time.
    """
    timestamp = datetime.now(UTC)
    session_fields = {
        "session_id": "test-session",
        "user_id": "user-1",
        "messages": messages,
        "title": "test",
        "usage": [],
        "started_at": timestamp,
        "updated_at": timestamp,
    }
    return ChatSession(**session_fields)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_query_resume_up_to_date():
    """With --resume and a transcript that covers all prior messages, the raw message is returned."""
    history = [
        ChatMessage(role="user", content="hello"),
        ChatMessage(role="assistant", content="hi"),
        ChatMessage(role="user", content="what's new?"),
    ]
    chat_session = _make_session(history)

    query = await _build_query_message(
        "what's new?",
        chat_session,
        use_resume=True,
        transcript_msg_count=2,
        session_id="test-session",
    )

    # transcript_msg_count == msg_count - 1, so there is no gap to fill in.
    assert query == "what's new?"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_query_resume_stale_transcript():
    """With --resume and a stale transcript, the missed turns are prepended as context."""
    history = [
        ChatMessage(role="user", content="turn 1"),
        ChatMessage(role="assistant", content="reply 1"),
        ChatMessage(role="user", content="turn 2"),
        ChatMessage(role="assistant", content="reply 2"),
        ChatMessage(role="user", content="turn 3"),
    ]
    chat_session = _make_session(history)

    query = await _build_query_message(
        "turn 3",
        chat_session,
        use_resume=True,
        transcript_msg_count=2,
        session_id="test-session",
    )

    # The transcript covers two of five messages; the second turn is the gap.
    assert "<conversation_history>" in query
    assert "turn 2" in query
    assert "reply 2" in query
    assert "Now, the user says:\nturn 3" in query
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_query_resume_zero_msg_count():
    """With --resume but ``transcript_msg_count=0``, the raw message is returned."""
    history = [
        ChatMessage(role="user", content="hello"),
        ChatMessage(role="assistant", content="hi"),
        ChatMessage(role="user", content="new msg"),
    ]
    chat_session = _make_session(history)

    query = await _build_query_message(
        "new msg",
        chat_session,
        use_resume=True,
        transcript_msg_count=0,
        session_id="test-session",
    )

    assert query == "new msg"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_query_no_resume_single_message():
    """Without --resume and with only one message, the raw message is returned."""
    only_message = ChatMessage(role="user", content="first")
    chat_session = _make_session([only_message])

    query = await _build_query_message(
        "first",
        chat_session,
        use_resume=False,
        transcript_msg_count=0,
        session_id="test-session",
    )

    assert query == "first"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_query_no_resume_multi_message(monkeypatch):
    """Without --resume and with several messages, history is compressed and prepended."""
    history = [
        ChatMessage(role="user", content="older question"),
        ChatMessage(role="assistant", content="older answer"),
        ChatMessage(role="user", content="new question"),
    ]
    chat_session = _make_session(history)

    # Stand-in for _compress_conversation_history: hand back every message
    # except the last one, unchanged.
    async def _fake_compress(target_session):
        return target_session.messages[:-1]

    monkeypatch.setattr(
        "backend.copilot.sdk.service._compress_conversation_history",
        _fake_compress,
    )

    query = await _build_query_message(
        "new question",
        chat_session,
        use_resume=False,
        transcript_msg_count=0,
        session_id="test-session",
    )

    assert "<conversation_history>" in query
    assert "older question" in query
    assert "older answer" in query
    assert "Now, the user says:\nnew question" in query
|
||||
@@ -124,20 +124,20 @@ def _validate_user_isolation(
|
||||
"""Validate that tool calls respect user isolation."""
|
||||
# For workspace file tools, ensure path doesn't escape
|
||||
if "workspace" in tool_name.lower():
|
||||
# The "path" param is a cloud storage key (e.g. "/ASEAN/report.md")
|
||||
# where a leading "/" is normal. Only check for ".." traversal.
|
||||
# Filesystem paths (source_path, save_to_path) are validated inside
|
||||
# the tool itself via _validate_ephemeral_path.
|
||||
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
||||
if path:
|
||||
# Check for path traversal
|
||||
if ".." in path or path.startswith("/"):
|
||||
logger.warning(
|
||||
f"Blocked path traversal attempt: {path} by user {user_id}"
|
||||
)
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": "Path traversal not allowed",
|
||||
}
|
||||
if path and ".." in path:
|
||||
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": "Path traversal not allowed",
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
@@ -165,11 +165,12 @@ def test_workspace_path_traversal_blocked():
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_workspace_absolute_path_blocked():
|
||||
def test_workspace_absolute_path_allowed():
|
||||
"""Workspace 'path' is a cloud storage key — leading '/' is normal."""
|
||||
result = _validate_user_isolation(
|
||||
"workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
|
||||
"workspace_read", {"path": "/ASEAN/report.md"}, user_id="user-1"
|
||||
)
|
||||
assert _is_denied(result)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_workspace_normal_path_allowed():
|
||||
|
||||
@@ -69,6 +69,7 @@ class CapturedTranscript:
|
||||
|
||||
path: str = ""
|
||||
sdk_session_id: str = ""
|
||||
raw_content: str = ""
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
@@ -87,19 +88,44 @@ _SDK_TOOL_SUPPLEMENT = """
|
||||
|
||||
## Tool notes
|
||||
|
||||
### Shell commands
|
||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
||||
for shell commands — it runs in a network-isolated sandbox.
|
||||
- **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the
|
||||
same working directory. Files created by one are readable by the other.
|
||||
- **IMPORTANT — File persistence**: Your working directory is **ephemeral** —
|
||||
files are lost between turns. When you create or modify important files
|
||||
(code, configs, outputs), you MUST save them using `write_workspace_file`
|
||||
so they persist. Use `read_workspace_file` and `list_workspace_files` to
|
||||
access files saved in previous turns. If a "Files from previous turns"
|
||||
section is present above, those files are available via `read_workspace_file`.
|
||||
- Long-running tools (create_agent, edit_agent, etc.) are handled
|
||||
asynchronously. You will receive an immediate response; the actual result
|
||||
is delivered to the user via a background stream.
|
||||
|
||||
### Two storage systems — CRITICAL to understand
|
||||
|
||||
1. **Ephemeral working directory** (`/tmp/copilot-<session>/`):
|
||||
- Shared by SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec`
|
||||
- Files here are **lost between turns** — do NOT rely on them persisting
|
||||
- Use for temporary work: running scripts, processing data, etc.
|
||||
|
||||
2. **Persistent workspace** (cloud storage):
|
||||
- Files here **survive across turns and sessions**
|
||||
- Use `write_workspace_file` to save important files (code, outputs, configs)
|
||||
- Use `read_workspace_file` to retrieve previously saved files
|
||||
- Use `list_workspace_files` to see what files you've saved before
|
||||
- Call `list_workspace_files(include_all_sessions=True)` to see files from
|
||||
all sessions
|
||||
|
||||
### Moving files between ephemeral and persistent storage
|
||||
- **Ephemeral → Persistent**: Use `write_workspace_file` with either:
|
||||
- `content` param (plain text) — for text files
|
||||
- `source_path` param — to copy any file directly from the ephemeral dir
|
||||
- **Persistent → Ephemeral**: Use `read_workspace_file` with `save_to_path`
|
||||
param to download a workspace file to the ephemeral dir for processing
|
||||
|
||||
### File persistence workflow
|
||||
When you create or modify important files (code, configs, outputs), you MUST:
|
||||
1. Save them using `write_workspace_file` so they persist
|
||||
2. At the start of a new turn, call `list_workspace_files` to see what files
|
||||
are available from previous turns
|
||||
|
||||
### Long-running tools
|
||||
Long-running tools (create_agent, edit_agent, etc.) are handled
|
||||
asynchronously. You will receive an immediate response; the actual result
|
||||
is delivered to the user via a background stream.
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
@@ -380,11 +406,9 @@ async def _compress_conversation_history(
|
||||
def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
|
||||
"""Format conversation messages into a context prefix for the user message.
|
||||
|
||||
Returns a string like:
|
||||
<conversation_history>
|
||||
User: hello
|
||||
You responded: Hi! How can I help?
|
||||
</conversation_history>
|
||||
Includes user messages, assistant text, tool call summaries, and
|
||||
tool result summaries so the agent retains full context about what
|
||||
tools were invoked and their outcomes.
|
||||
|
||||
Returns None if there are no messages to format.
|
||||
"""
|
||||
@@ -393,18 +417,21 @@ def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
|
||||
|
||||
lines: list[str] = []
|
||||
for msg in messages:
|
||||
if not msg.content:
|
||||
continue
|
||||
if msg.role == "user":
|
||||
lines.append(f"User: {msg.content}")
|
||||
if msg.content:
|
||||
lines.append(f"User: {msg.content}")
|
||||
elif msg.role == "assistant":
|
||||
lines.append(f"You responded: {msg.content}")
|
||||
if msg.content:
|
||||
lines.append(f"You responded: {msg.content}")
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
func = tc.get("function", {})
|
||||
tool_name = func.get("name", "unknown")
|
||||
tool_args = func.get("arguments", "")
|
||||
lines.append(f"You called tool: {tool_name}({tool_args})")
|
||||
elif msg.role == "tool":
|
||||
# Include tool error/denial outcomes so the agent doesn't
|
||||
# hallucinate that blocked or failed operations succeeded.
|
||||
content = msg.content
|
||||
if _is_tool_error_or_denial(content):
|
||||
lines.append(f"Tool result: {content}")
|
||||
content = msg.content or ""
|
||||
lines.append(f"Tool result: {content}")
|
||||
|
||||
if not lines:
|
||||
return None
|
||||
@@ -437,6 +464,44 @@ def _is_tool_error_or_denial(content: str | None) -> bool:
|
||||
)
|
||||
|
||||
|
||||
async def _build_query_message(
    current_message: str,
    session: ChatSession,
    use_resume: bool,
    transcript_msg_count: int,
    session_id: str,
) -> str:
    """Return the text to send as the query, adding history context when needed.

    Three cases:

    * Resume with a known transcript size: if the transcript covers fewer
      than all-but-the-last session messages, the missed messages are
      formatted and prepended; otherwise the message is returned unchanged.
    * No resume and more than one session message: the compressed history
      is formatted and prepended.
    * Anything else (including resume with ``transcript_msg_count == 0``):
      the message is returned unchanged.

    Args:
        current_message: The new user message.
        session: Chat session whose ``messages`` list is inspected.
        use_resume: Whether a transcript is being resumed.
        transcript_msg_count: Number of session messages the transcript covers.
        session_id: Used only in the fallback log line.

    Returns:
        ``current_message``, optionally prefixed by a context block and
        ``"Now, the user says:"``.
    """
    total_messages = len(session.messages)

    if use_resume and transcript_msg_count > 0:
        # The final session message is excluded from the gap by the slice below.
        if transcript_msg_count >= total_messages - 1:
            return current_message

        missed = session.messages[transcript_msg_count:-1]
        missed_context = _format_conversation_context(missed)
        if not missed_context:
            return current_message

        logger.info(
            f"[SDK] Transcript stale: covers {transcript_msg_count} "
            f"of {total_messages} messages, compressing {len(missed)} missed"
        )
        return f"{missed_context}\n\nNow, the user says:\n{current_message}"

    # Resume without a usable watermark, or nothing older than the new message.
    if use_resume or total_messages <= 1:
        return current_message

    logger.warning(
        f"[SDK] Using compression fallback for session "
        f"{session_id} ({total_messages} messages) — no transcript for --resume"
    )
    compressed_history = await _compress_conversation_history(session)
    history_prefix = _format_conversation_context(compressed_history)
    if history_prefix:
        return f"{history_prefix}\n\nNow, the user says:\n{current_message}"

    return current_message
|
||||
|
||||
|
||||
async def stream_chat_completion_sdk(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
@@ -501,10 +566,12 @@ async def stream_chat_completion_sdk(
|
||||
yield StreamStart(messageId=message_id, taskId=task_id)
|
||||
|
||||
stream_completed = False
|
||||
# Initialise sdk_cwd before the try so the finally can reference it
|
||||
# even if _make_sdk_cwd raises (in that case it stays as "").
|
||||
# Initialise variables before the try so the finally block can
|
||||
# always attempt transcript upload regardless of errors.
|
||||
sdk_cwd = ""
|
||||
use_resume = False
|
||||
resume_file: str | None = None
|
||||
captured_transcript = CapturedTranscript()
|
||||
|
||||
try:
|
||||
# Use a session-specific temp dir to avoid cleanup race conditions
|
||||
@@ -534,12 +601,23 @@ async def stream_chat_completion_sdk(
|
||||
sdk_model = _resolve_sdk_model()
|
||||
|
||||
# --- Transcript capture via Stop hook ---
|
||||
captured_transcript = CapturedTranscript()
|
||||
|
||||
# Read the file content immediately — the SDK may clean up
|
||||
# the file before our finally block runs.
|
||||
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
|
||||
captured_transcript.path = transcript_path
|
||||
captured_transcript.sdk_session_id = sdk_session_id
|
||||
logger.debug(f"[SDK] Stop hook: path={transcript_path!r}")
|
||||
content = read_transcript_file(transcript_path)
|
||||
if content:
|
||||
captured_transcript.raw_content = content
|
||||
logger.info(
|
||||
f"[SDK] Stop hook: captured {len(content)}B from "
|
||||
f"{transcript_path}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] Stop hook: transcript file empty/missing at "
|
||||
f"{transcript_path}"
|
||||
)
|
||||
|
||||
security_hooks = create_security_hooks(
|
||||
user_id,
|
||||
@@ -549,13 +627,16 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
|
||||
# --- Resume strategy: download transcript from bucket ---
|
||||
resume_file: str | None = None
|
||||
use_resume = False
|
||||
transcript_msg_count = 0 # watermark: session.messages length at upload
|
||||
|
||||
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||
dl = await download_transcript(user_id, session_id)
|
||||
if dl and validate_transcript(dl.content):
|
||||
is_valid = bool(dl and validate_transcript(dl.content))
|
||||
if dl and is_valid:
|
||||
logger.info(
|
||||
f"[SDK] Transcript available for session {session_id}: "
|
||||
f"{len(dl.content)}B, msg_count={dl.message_count}"
|
||||
)
|
||||
resume_file = write_transcript_to_tempfile(
|
||||
dl.content, session_id, sdk_cwd
|
||||
)
|
||||
@@ -566,6 +647,15 @@ async def stream_chat_completion_sdk(
|
||||
f"[SDK] Using --resume ({len(dl.content)}B, "
|
||||
f"msg_count={transcript_msg_count})"
|
||||
)
|
||||
elif dl:
|
||||
logger.warning(
|
||||
f"[SDK] Transcript downloaded but invalid for {session_id}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] No transcript available for {session_id} "
|
||||
f"({len(session.messages)} messages in session)"
|
||||
)
|
||||
|
||||
sdk_options_kwargs: dict[str, Any] = {
|
||||
"system_prompt": system_prompt,
|
||||
@@ -602,50 +692,13 @@ async def stream_chat_completion_sdk(
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
# Build query: with --resume the CLI already has full
|
||||
# context, so we only send the new message. Without
|
||||
# resume, compress history into a context prefix.
|
||||
#
|
||||
# Hybrid mode: if the transcript is stale (upload missed
|
||||
# some turns), compress only the gap and prepend it so
|
||||
# the agent has transcript context + missed turns.
|
||||
query_message = current_message
|
||||
current_msg_count = len(session.messages)
|
||||
|
||||
if use_resume and transcript_msg_count > 0:
|
||||
# Transcript covers messages[0..M-1]. Current session
|
||||
# has N messages (last one is the new user msg).
|
||||
# Gap = messages[M .. N-2] (everything between upload
|
||||
# and the current turn).
|
||||
# When transcript_msg_count == 0 (no metadata), we trust
|
||||
# the transcript is up-to-date and skip gap detection to
|
||||
# avoid duplicating the full history.
|
||||
if transcript_msg_count < current_msg_count - 1:
|
||||
gap = session.messages[transcript_msg_count:-1]
|
||||
gap_context = _format_conversation_context(gap)
|
||||
if gap_context:
|
||||
logger.info(
|
||||
f"[SDK] Transcript stale: covers {transcript_msg_count} "
|
||||
f"of {current_msg_count} messages, compressing "
|
||||
f"{len(gap)} missed messages"
|
||||
)
|
||||
query_message = (
|
||||
f"{gap_context}\n\n"
|
||||
f"Now, the user says:\n{current_message}"
|
||||
)
|
||||
elif not use_resume and current_msg_count > 1:
|
||||
logger.warning(
|
||||
f"[SDK] Using compression fallback for session "
|
||||
f"{session_id} ({current_msg_count} messages) — "
|
||||
f"no transcript available for --resume"
|
||||
)
|
||||
compressed = await _compress_conversation_history(session)
|
||||
history_context = _format_conversation_context(compressed)
|
||||
if history_context:
|
||||
query_message = (
|
||||
f"{history_context}\n\n"
|
||||
f"Now, the user says:\n{current_message}"
|
||||
)
|
||||
query_message = await _build_query_message(
|
||||
current_message,
|
||||
session,
|
||||
use_resume,
|
||||
transcript_msg_count,
|
||||
session_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[SDK] Sending query ({len(session.messages)} msgs, "
|
||||
@@ -822,6 +875,38 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
yield StreamFinish()
|
||||
finally:
|
||||
# --- Upload transcript for next-turn --resume ---
|
||||
# This MUST run in finally so the transcript is uploaded even when
|
||||
# the streaming loop raises an exception. The CLI uses
|
||||
# appendFileSync, so whatever was written before the error/SIGTERM
|
||||
# is safely on disk and still useful for the next turn.
|
||||
if config.claude_agent_use_resume and user_id:
|
||||
try:
|
||||
# Prefer content captured in the Stop hook (read before
|
||||
# cleanup removes the file). Fall back to the resume
|
||||
# file when the stop hook didn't fire (e.g. error before
|
||||
# completion) so we don't lose the prior transcript.
|
||||
raw_transcript = captured_transcript.raw_content or None
|
||||
if not raw_transcript and use_resume and resume_file:
|
||||
raw_transcript = read_transcript_file(resume_file)
|
||||
|
||||
if raw_transcript:
|
||||
await asyncio.shield(
|
||||
_try_upload_transcript(
|
||||
user_id,
|
||||
session_id,
|
||||
raw_transcript,
|
||||
message_count=len(session.messages),
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[SDK] No transcript to upload for {session_id}")
|
||||
except Exception as upload_err:
|
||||
logger.error(
|
||||
f"[SDK] Transcript upload failed in finally: {upload_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if sdk_cwd:
|
||||
_cleanup_sdk_tool_results(sdk_cwd)
|
||||
|
||||
|
||||
@@ -2,11 +2,13 @@
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.sandbox import make_session_path
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
@@ -18,6 +20,151 @@ from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_write_content(
|
||||
content_text: str | None,
|
||||
content_b64: str | None,
|
||||
source_path: str | None,
|
||||
session_id: str,
|
||||
) -> bytes | ErrorResponse:
|
||||
"""Resolve file content from exactly one of three input sources.
|
||||
|
||||
Returns the raw bytes on success, or an ``ErrorResponse`` on validation
|
||||
failure (wrong number of sources, invalid path, file not found, etc.).
|
||||
"""
|
||||
# Normalise empty strings to None so counting and dispatch stay in sync.
|
||||
if content_text is not None and content_text == "":
|
||||
content_text = None
|
||||
if content_b64 is not None and content_b64 == "":
|
||||
content_b64 = None
|
||||
if source_path is not None and source_path == "":
|
||||
source_path = None
|
||||
|
||||
sources_provided = sum(
|
||||
x is not None for x in [content_text, content_b64, source_path]
|
||||
)
|
||||
if sources_provided == 0:
|
||||
return ErrorResponse(
|
||||
message="Please provide one of: content, content_base64, or source_path",
|
||||
session_id=session_id,
|
||||
)
|
||||
if sources_provided > 1:
|
||||
return ErrorResponse(
|
||||
message="Provide only one of: content, content_base64, or source_path",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if source_path is not None:
|
||||
validated = _validate_ephemeral_path(
|
||||
source_path, param_name="source_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated, ErrorResponse):
|
||||
return validated
|
||||
try:
|
||||
with open(validated, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=f"Source file not found: {source_path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read source file: {e}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if content_b64 is not None:
|
||||
try:
|
||||
return base64.b64decode(content_b64)
|
||||
except Exception:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Invalid base64 encoding in content_base64. "
|
||||
"Please encode the file content with standard base64, "
|
||||
"or use the 'content' parameter for plain text, "
|
||||
"or 'source_path' to copy from the working directory."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
assert content_text is not None
|
||||
return content_text.encode("utf-8")
|
||||
|
||||
|
||||
def _validate_ephemeral_path(
    path: str, *, param_name: str, session_id: str
) -> ErrorResponse | str:
    """Check that *path* resolves to a location inside this session's directory.

    Both the session directory (``make_session_path(session_id)``) and
    *path* are passed through ``os.path.realpath`` before comparison. The
    resolved path must start with the resolved session directory followed
    by a path separator, so a sibling directory sharing the same name
    prefix is rejected.

    Args:
        path: Filesystem path to check.
        param_name: Parameter name quoted in the error message.
        session_id: Session whose directory bounds the check.

    Returns:
        The resolved real path on success, or an ``ErrorResponse`` when the
        path is not inside the session directory.
    """
    resolved_root = os.path.realpath(make_session_path(session_id))
    resolved_target = os.path.realpath(path)

    # The trailing separator stops "<root>-other/..." from matching "<root>".
    if resolved_target.startswith(resolved_root + os.sep):
        return resolved_target

    return ErrorResponse(
        message=(
            f"{param_name} must be within the ephemeral working "
            f"directory ({make_session_path(session_id)})"
        ),
        session_id=session_id,
    )
|
||||
|
||||
|
||||
_TEXT_MIME_PREFIXES = (
|
||||
"text/",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"application/javascript",
|
||||
"application/x-python",
|
||||
"application/x-sh",
|
||||
)
|
||||
|
||||
_IMAGE_MIME_TYPES = {"image/png", "image/jpeg", "image/gif", "image/webp"}
|
||||
|
||||
|
||||
def _is_text_mime(mime_type: str) -> bool:
    """Return True when ``mime_type`` starts with any entry of ``_TEXT_MIME_PREFIXES``."""
    # str.startswith accepts a tuple and matches if any element is a prefix.
    return mime_type.startswith(_TEXT_MIME_PREFIXES)
|
||||
|
||||
|
||||
async def _get_manager(user_id: str, session_id: str) -> WorkspaceManager:
    """Build a WorkspaceManager for ``user_id`` bound to ``session_id``.

    The user's workspace is obtained with ``get_or_create_workspace`` and its
    id is passed to the manager.
    """
    db = workspace_db()
    user_workspace = await db.get_or_create_workspace(user_id)
    return WorkspaceManager(user_id, user_workspace.id, session_id)
|
||||
|
||||
|
||||
async def _resolve_file(
|
||||
manager: WorkspaceManager,
|
||||
file_id: str | None,
|
||||
path: str | None,
|
||||
session_id: str,
|
||||
) -> tuple[str, Any] | ErrorResponse:
|
||||
"""Resolve a file by file_id or path.
|
||||
|
||||
Returns ``(target_file_id, file_info)`` on success, or an
|
||||
``ErrorResponse`` if the file was not found.
|
||||
"""
|
||||
if file_id:
|
||||
file_info = await manager.get_file_info(file_id)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {file_id}", session_id=session_id
|
||||
)
|
||||
return file_id, file_info
|
||||
|
||||
assert path is not None
|
||||
file_info = await manager.get_file_info_by_path(path)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found at path: {path}", session_id=session_id
|
||||
)
|
||||
return file_info.id, file_info
|
||||
|
||||
|
||||
class WorkspaceFileInfoData(BaseModel):
|
||||
"""Data model for workspace file information (not a response itself)."""
|
||||
|
||||
@@ -68,6 +215,8 @@ class WorkspaceWriteResponse(ToolResponseBase):
|
||||
name: str
|
||||
path: str
|
||||
size_bytes: int
|
||||
source: str | None = None # "content", "base64", or "copied from <path>"
|
||||
content_preview: str | None = None # First 200 chars for text files
|
||||
|
||||
|
||||
class WorkspaceDeleteResponse(ToolResponseBase):
|
||||
@@ -136,11 +285,9 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
message="Authentication required", session_id=session_id
|
||||
)
|
||||
|
||||
path_prefix: Optional[str] = kwargs.get("path_prefix")
|
||||
@@ -148,20 +295,13 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
files = await manager.list_files(
|
||||
path=path_prefix,
|
||||
limit=limit,
|
||||
include_all_sessions=include_all_sessions,
|
||||
path=path_prefix, limit=limit, include_all_sessions=include_all_sessions
|
||||
)
|
||||
total = await manager.get_file_count(
|
||||
path=path_prefix,
|
||||
include_all_sessions=include_all_sessions,
|
||||
path=path_prefix, include_all_sessions=include_all_sessions
|
||||
)
|
||||
|
||||
file_infos = [
|
||||
WorkspaceFileInfoData(
|
||||
file_id=f.id,
|
||||
@@ -172,19 +312,27 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
)
|
||||
for f in files
|
||||
]
|
||||
scope = "all sessions" if include_all_sessions else "current session"
|
||||
total_size = sum(f.size_bytes for f in file_infos)
|
||||
|
||||
# Build a human-readable summary so the agent can relay details.
|
||||
lines = [f"Found {len(files)} file(s) in workspace ({scope}):"]
|
||||
for f in file_infos:
|
||||
lines.append(f" - {f.path} ({f.size_bytes:,} bytes, {f.mime_type})")
|
||||
if total > len(files):
|
||||
lines.append(f" ... and {total - len(files)} more")
|
||||
lines.append(f"Total size: {total_size:,} bytes")
|
||||
|
||||
scope_msg = "all sessions" if include_all_sessions else "current session"
|
||||
return WorkspaceFileListResponse(
|
||||
files=file_infos,
|
||||
total_count=total,
|
||||
message=f"Found {len(files)} files in workspace ({scope_msg})",
|
||||
message="\n".join(lines),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing workspace files: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to list workspace files: {str(e)}",
|
||||
message=f"Failed to list workspace files: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -193,10 +341,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
class ReadWorkspaceFileTool(BaseTool):
|
||||
"""Tool for reading file content from workspace."""
|
||||
|
||||
# Size threshold for returning full content vs metadata+URL
|
||||
# Files larger than this return metadata with download URL to prevent context bloat
|
||||
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
|
||||
# Preview size for text files
|
||||
PREVIEW_SIZE = 500
|
||||
|
||||
@property
|
||||
@@ -212,6 +357,8 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"For small text files, returns content directly. "
|
||||
"For large or binary files, returns metadata and a download URL. "
|
||||
"Optionally use 'save_to_path' to copy the file to the ephemeral "
|
||||
"working directory for processing with bash_exec or SDK tools. "
|
||||
"Paths are scoped to the current session by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
@@ -232,6 +379,15 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
"Scoped to current session by default."
|
||||
),
|
||||
},
|
||||
"save_to_path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"If provided, save the file to this path in the ephemeral "
|
||||
"working directory (e.g., '/tmp/copilot-.../data.csv') "
|
||||
"so it can be processed with bash_exec or SDK tools. "
|
||||
"The file content is still returned in the response."
|
||||
),
|
||||
},
|
||||
"force_download_url": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
@@ -247,18 +403,6 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
def _is_text_mime_type(self, mime_type: str) -> bool:
|
||||
"""Check if the MIME type is a text-based type."""
|
||||
text_types = [
|
||||
"text/",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"application/javascript",
|
||||
"application/x-python",
|
||||
"application/x-sh",
|
||||
]
|
||||
return any(mime_type.startswith(t) for t in text_types)
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
@@ -266,117 +410,112 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
message="Authentication required", session_id=session_id
|
||||
)
|
||||
|
||||
file_id: Optional[str] = kwargs.get("file_id")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
save_to_path: Optional[str] = kwargs.get("save_to_path")
|
||||
force_download_url: bool = kwargs.get("force_download_url", False)
|
||||
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide either file_id or path",
|
||||
session_id=session_id,
|
||||
message="Please provide either file_id or path", session_id=session_id
|
||||
)
|
||||
|
||||
# Validate and resolve save_to_path (use sanitized real path).
|
||||
if save_to_path:
|
||||
validated_save = _validate_ephemeral_path(
|
||||
save_to_path, param_name="save_to_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated_save, ErrorResponse):
|
||||
return validated_save
|
||||
save_to_path = validated_save
|
||||
|
||||
try:
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
target_file_id, file_info = resolved
|
||||
|
||||
# Get file info
|
||||
if file_id:
|
||||
file_info = await manager.get_file_info(file_id)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {file_id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_id
|
||||
else:
|
||||
# path is guaranteed to be non-None here due to the check above
|
||||
assert path is not None
|
||||
file_info = await manager.get_file_info_by_path(path)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found at path: {path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_info.id
|
||||
# If save_to_path, read + save; cache bytes for possible inline reuse.
|
||||
cached_content: bytes | None = None
|
||||
if save_to_path:
|
||||
cached_content = await manager.read_file_by_id(target_file_id)
|
||||
dir_path = os.path.dirname(save_to_path)
|
||||
if dir_path:
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
with open(save_to_path, "wb") as f:
|
||||
f.write(cached_content)
|
||||
|
||||
# Decide whether to return inline content or metadata+URL
|
||||
is_small_file = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
|
||||
is_text_file = self._is_text_mime_type(file_info.mime_type)
|
||||
|
||||
# Return inline content for small text/image files (unless force_download_url)
|
||||
is_image_file = file_info.mime_type in {
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
}
|
||||
if (
|
||||
is_small_file
|
||||
and (is_text_file or is_image_file)
|
||||
and not force_download_url
|
||||
):
|
||||
content = await manager.read_file_by_id(target_file_id)
|
||||
content_b64 = base64.b64encode(content).decode("utf-8")
|
||||
is_small = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
|
||||
is_text = _is_text_mime(file_info.mime_type)
|
||||
is_image = file_info.mime_type in _IMAGE_MIME_TYPES
|
||||
|
||||
# Inline content for small text/image files
|
||||
if is_small and (is_text or is_image) and not force_download_url:
|
||||
content = cached_content or await manager.read_file_by_id(
|
||||
target_file_id
|
||||
)
|
||||
msg = (
|
||||
f"Read {file_info.name} from workspace:{file_info.path} "
|
||||
f"({file_info.size_bytes:,} bytes, {file_info.mime_type})"
|
||||
)
|
||||
if save_to_path:
|
||||
msg += f" — also saved to {save_to_path}"
|
||||
return WorkspaceFileContentResponse(
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type=file_info.mime_type,
|
||||
content_base64=content_b64,
|
||||
message=f"Successfully read file: {file_info.name}",
|
||||
content_base64=base64.b64encode(content).decode("utf-8"),
|
||||
message=msg,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Return metadata + workspace:// reference for large or binary files
|
||||
# This prevents context bloat (100KB file = ~133KB as base64)
|
||||
# Use workspace:// format so frontend urlTransform can add proxy prefix
|
||||
download_url = f"workspace://{target_file_id}"
|
||||
|
||||
# Generate preview for text files
|
||||
# Metadata + download URL for large/binary files
|
||||
preview: str | None = None
|
||||
if is_text_file:
|
||||
if is_text:
|
||||
try:
|
||||
content = await manager.read_file_by_id(target_file_id)
|
||||
preview_text = content[: self.PREVIEW_SIZE].decode(
|
||||
"utf-8", errors="replace"
|
||||
raw = cached_content or await manager.read_file_by_id(
|
||||
target_file_id
|
||||
)
|
||||
if len(content) > self.PREVIEW_SIZE:
|
||||
preview_text += "..."
|
||||
preview = preview_text
|
||||
preview = raw[: self.PREVIEW_SIZE].decode("utf-8", errors="replace")
|
||||
if len(raw) > self.PREVIEW_SIZE:
|
||||
preview += "..."
|
||||
except Exception:
|
||||
pass # Preview is optional
|
||||
pass
|
||||
|
||||
msg = (
|
||||
f"File: {file_info.name} at workspace:{file_info.path} "
|
||||
f"({file_info.size_bytes:,} bytes, {file_info.mime_type})"
|
||||
)
|
||||
if save_to_path:
|
||||
msg += f" — saved to {save_to_path}"
|
||||
else:
|
||||
msg += (
|
||||
" — use read_workspace_file with this file_id to retrieve content"
|
||||
)
|
||||
return WorkspaceFileMetadataResponse(
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type=file_info.mime_type,
|
||||
size_bytes=file_info.size_bytes,
|
||||
download_url=download_url,
|
||||
download_url=f"workspace://{target_file_id}",
|
||||
preview=preview,
|
||||
message=f"File: {file_info.name} ({file_info.size_bytes} bytes). Use download_url to retrieve content.",
|
||||
message=msg,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except FileNotFoundError as e:
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
return ErrorResponse(message=str(e), session_id=session_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read workspace file: {str(e)}",
|
||||
message=f"Failed to read workspace file: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -395,7 +534,9 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
"Write or create a file in the user's persistent workspace (cloud storage). "
|
||||
"These files survive across sessions. "
|
||||
"For ephemeral session files, use the SDK Write tool instead. "
|
||||
"Provide the content as a base64-encoded string. "
|
||||
"Provide content as plain text via 'content', OR base64-encoded via "
|
||||
"'content_base64', OR copy a file from the ephemeral working directory "
|
||||
"via 'source_path'. Exactly one of these three is required. "
|
||||
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
||||
"Files are saved to the current session's folder by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
@@ -410,9 +551,30 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
"type": "string",
|
||||
"description": "Name for the file (e.g., 'report.pdf')",
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Plain text content to write. Use this for text files "
|
||||
"(code, configs, documents, etc.). "
|
||||
"Mutually exclusive with content_base64 and source_path."
|
||||
),
|
||||
},
|
||||
"content_base64": {
|
||||
"type": "string",
|
||||
"description": "Base64-encoded file content",
|
||||
"description": (
|
||||
"Base64-encoded file content. Use this for binary files "
|
||||
"(images, PDFs, etc.). "
|
||||
"Mutually exclusive with content and source_path."
|
||||
),
|
||||
},
|
||||
"source_path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Path to a file in the ephemeral working directory to "
|
||||
"copy to workspace (e.g., '/tmp/copilot-.../output.csv'). "
|
||||
"Use this to persist files created by bash_exec or SDK Write. "
|
||||
"Mutually exclusive with content and content_base64."
|
||||
),
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
@@ -434,7 +596,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
"description": "Whether to overwrite if file exists at path (default: false)",
|
||||
},
|
||||
},
|
||||
"required": ["filename", "content_base64"],
|
||||
"required": ["filename"],
|
||||
}
|
||||
|
||||
@property
|
||||
@@ -448,82 +610,92 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
message="Authentication required", session_id=session_id
|
||||
)
|
||||
|
||||
filename: str = kwargs.get("filename", "")
|
||||
content_b64: str = kwargs.get("content_base64", "")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
mime_type: Optional[str] = kwargs.get("mime_type")
|
||||
overwrite: bool = kwargs.get("overwrite", False)
|
||||
|
||||
if not filename:
|
||||
return ErrorResponse(
|
||||
message="Please provide a filename",
|
||||
session_id=session_id,
|
||||
message="Please provide a filename", session_id=session_id
|
||||
)
|
||||
|
||||
if not content_b64:
|
||||
return ErrorResponse(
|
||||
message="Please provide content_base64",
|
||||
session_id=session_id,
|
||||
)
|
||||
source_path_arg: str | None = kwargs.get("source_path")
|
||||
content_text: str | None = kwargs.get("content")
|
||||
content_b64: str | None = kwargs.get("content_base64")
|
||||
|
||||
# Decode content
|
||||
try:
|
||||
content = base64.b64decode(content_b64)
|
||||
except Exception:
|
||||
return ErrorResponse(
|
||||
message="Invalid base64-encoded content",
|
||||
session_id=session_id,
|
||||
)
|
||||
resolved = _resolve_write_content(
|
||||
content_text,
|
||||
content_b64,
|
||||
source_path_arg,
|
||||
session_id,
|
||||
)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
content: bytes = resolved
|
||||
|
||||
# Check size
|
||||
max_file_size = Config().max_file_size_mb * 1024 * 1024
|
||||
if len(content) > max_file_size:
|
||||
max_size = Config().max_file_size_mb * 1024 * 1024
|
||||
if len(content) > max_size:
|
||||
return ErrorResponse(
|
||||
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Virus scan
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
file_record = await manager.write_file(
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
rec = await manager.write_file(
|
||||
content=content,
|
||||
filename=filename,
|
||||
path=path,
|
||||
mime_type=mime_type,
|
||||
overwrite=overwrite,
|
||||
path=kwargs.get("path"),
|
||||
mime_type=kwargs.get("mime_type"),
|
||||
overwrite=kwargs.get("overwrite", False),
|
||||
)
|
||||
|
||||
# Build informative source label and message.
|
||||
if source_path_arg:
|
||||
source = f"copied from {source_path_arg}"
|
||||
msg = (
|
||||
f"Copied {source_path_arg} → workspace:{rec.path} "
|
||||
f"({rec.size_bytes:,} bytes)"
|
||||
)
|
||||
elif content_b64:
|
||||
source = "base64"
|
||||
msg = (
|
||||
f"Wrote {rec.name} to workspace ({rec.size_bytes:,} bytes, "
|
||||
f"decoded from base64)"
|
||||
)
|
||||
else:
|
||||
source = "content"
|
||||
msg = f"Wrote {rec.name} to workspace ({rec.size_bytes:,} bytes)"
|
||||
|
||||
# Include a short preview for text content.
|
||||
preview: str | None = None
|
||||
if _is_text_mime(rec.mime_type):
|
||||
try:
|
||||
preview = content[:200].decode("utf-8", errors="replace")
|
||||
if len(content) > 200:
|
||||
preview += "..."
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return WorkspaceWriteResponse(
|
||||
file_id=file_record.id,
|
||||
name=file_record.name,
|
||||
path=file_record.path,
|
||||
size_bytes=file_record.size_bytes,
|
||||
message=f"Successfully wrote file: {file_record.name}",
|
||||
file_id=rec.id,
|
||||
name=rec.name,
|
||||
path=rec.path,
|
||||
size_bytes=rec.size_bytes,
|
||||
source=source,
|
||||
content_preview=preview,
|
||||
message=msg,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
return ErrorResponse(message=str(e), session_id=session_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to write workspace file: {str(e)}",
|
||||
message=f"Failed to write workspace file: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -576,61 +748,42 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
message="Authentication required", session_id=session_id
|
||||
)
|
||||
|
||||
file_id: Optional[str] = kwargs.get("file_id")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide either file_id or path",
|
||||
session_id=session_id,
|
||||
message="Please provide either file_id or path", session_id=session_id
|
||||
)
|
||||
|
||||
try:
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
target_file_id, file_info = resolved
|
||||
|
||||
# Determine the file_id to delete
|
||||
target_file_id: str
|
||||
if file_id:
|
||||
target_file_id = file_id
|
||||
else:
|
||||
# path is guaranteed to be non-None here due to the check above
|
||||
assert path is not None
|
||||
file_info = await manager.get_file_info_by_path(path)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found at path: {path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_info.id
|
||||
|
||||
success = await manager.delete_file(target_file_id)
|
||||
|
||||
if not success:
|
||||
if not await manager.delete_file(target_file_id):
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {target_file_id}",
|
||||
session_id=session_id,
|
||||
message=f"File not found: {target_file_id}", session_id=session_id
|
||||
)
|
||||
|
||||
return WorkspaceDeleteResponse(
|
||||
file_id=target_file_id,
|
||||
success=True,
|
||||
message="File deleted successfully",
|
||||
message=(
|
||||
f"Deleted {file_info.name} from workspace:{file_info.path} "
|
||||
f"({file_info.size_bytes:,} bytes)"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to delete workspace file: {str(e)}",
|
||||
message=f"Failed to delete workspace file: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,267 @@
|
||||
"""Tests for workspace file tool helpers and path validation."""
|
||||
|
||||
import base64
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools._test_data import make_session, setup_test_data
|
||||
from backend.copilot.tools.workspace_files import (
|
||||
DeleteWorkspaceFileTool,
|
||||
ListWorkspaceFilesTool,
|
||||
ReadWorkspaceFileTool,
|
||||
WorkspaceDeleteResponse,
|
||||
WorkspaceFileListResponse,
|
||||
WorkspaceWriteResponse,
|
||||
WriteWorkspaceFileTool,
|
||||
_resolve_write_content,
|
||||
_validate_ephemeral_path,
|
||||
)
|
||||
|
||||
# Re-export so pytest discovers the session-scoped fixture
|
||||
setup_test_data = setup_test_data
|
||||
|
||||
# We need to mock make_session_path to return a known temp dir for tests.
|
||||
# The real one uses WORKSPACE_PREFIX = "/tmp/copilot-"
|
||||
|
||||
|
||||
@pytest.fixture
def ephemeral_dir(tmp_path, monkeypatch):
    """Provide a temp directory standing in for the ephemeral session dir.

    ``make_session_path`` in the ``workspace_files`` module is patched so
    that it returns this directory regardless of the session ID passed.
    """
    fake_root = tmp_path / "copilot-test-session"
    fake_root.mkdir()

    def _fake_session_path(session_id):
        return str(fake_root)

    monkeypatch.setattr(
        "backend.copilot.tools.workspace_files.make_session_path",
        _fake_session_path,
    )
    return fake_root
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_ephemeral_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateEphemeralPath:
    """Checks for ``_validate_ephemeral_path`` against a patched session dir."""

    def test_valid_path(self, ephemeral_dir):
        """A file directly inside the session dir is returned as its real path."""
        inside = ephemeral_dir / "file.txt"
        inside.touch()
        raw = str(inside)
        outcome = _validate_ephemeral_path(raw, param_name="test", session_id="s1")
        assert isinstance(outcome, str)
        assert outcome == os.path.realpath(raw)

    def test_path_traversal_rejected(self, ephemeral_dir):
        """A ``..`` component that leaves the session dir yields an error."""
        from backend.copilot.tools.models import ErrorResponse

        escaping = str(ephemeral_dir / ".." / "etc" / "passwd")
        outcome = _validate_ephemeral_path(
            escaping, param_name="test", session_id="s1"
        )
        assert isinstance(outcome, ErrorResponse)

    def test_different_session_rejected(self, ephemeral_dir, tmp_path):
        """A file in a sibling directory of the session dir yields an error."""
        from backend.copilot.tools.models import ErrorResponse

        foreign_dir = tmp_path / "copilot-evil-session"
        foreign_dir.mkdir()
        foreign_file = foreign_dir / "steal.txt"
        foreign_file.touch()
        outcome = _validate_ephemeral_path(
            str(foreign_file), param_name="test", session_id="s1"
        )
        assert isinstance(outcome, ErrorResponse)

    def test_symlink_escape_rejected(self, ephemeral_dir, tmp_path):
        """Symlink inside session dir pointing outside should be rejected."""
        from backend.copilot.tools.models import ErrorResponse

        secret = tmp_path / "secret.txt"
        secret.write_text("secret")
        link = ephemeral_dir / "link.txt"
        link.symlink_to(secret)
        outcome = _validate_ephemeral_path(
            str(link), param_name="test", session_id="s1"
        )
        assert isinstance(outcome, ErrorResponse)

    def test_nested_path_valid(self, ephemeral_dir):
        """A file in a nested subdirectory of the session dir is accepted."""
        deep_dir = ephemeral_dir / "subdir" / "deep"
        deep_dir.mkdir(parents=True)
        csv_file = deep_dir / "data.csv"
        csv_file.touch()
        outcome = _validate_ephemeral_path(
            str(csv_file), param_name="test", session_id="s1"
        )
        assert isinstance(outcome, str)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_write_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveWriteContent:
    """Checks for ``_resolve_write_content`` source selection and decoding."""

    def test_no_sources_returns_error(self):
        """Supplying none of the three sources is an error."""
        from backend.copilot.tools.models import ErrorResponse

        outcome = _resolve_write_content(None, None, None, "s1")
        assert isinstance(outcome, ErrorResponse)

    def test_multiple_sources_returns_error(self):
        """Supplying both text and base64 content is an error."""
        from backend.copilot.tools.models import ErrorResponse

        outcome = _resolve_write_content("text", "b64data", None, "s1")
        assert isinstance(outcome, ErrorResponse)

    def test_plain_text_content(self):
        """Plain text comes back as the equivalent bytes."""
        outcome = _resolve_write_content("hello world", None, None, "s1")
        assert outcome == b"hello world"

    def test_base64_content(self):
        """Valid base64 input is decoded to the original bytes."""
        payload = b"binary data"
        encoded = base64.b64encode(payload).decode()
        outcome = _resolve_write_content(None, encoded, None, "s1")
        assert outcome == payload

    def test_invalid_base64_returns_error(self):
        """Undecodable base64 yields an error whose message mentions base64."""
        from backend.copilot.tools.models import ErrorResponse

        outcome = _resolve_write_content(None, "not-valid-b64!!!", None, "s1")
        assert isinstance(outcome, ErrorResponse)
        assert "base64" in outcome.message.lower()

    def test_source_path(self, ephemeral_dir):
        """A file inside the session dir is read and returned as bytes."""
        src = ephemeral_dir / "input.txt"
        src.write_bytes(b"file content")
        outcome = _resolve_write_content(None, None, str(src), "s1")
        assert outcome == b"file content"

    def test_source_path_not_found(self, ephemeral_dir):
        """A non-existent source path inside the session dir is an error."""
        from backend.copilot.tools.models import ErrorResponse

        absent = str(ephemeral_dir / "nope.txt")
        outcome = _resolve_write_content(None, None, absent, "s1")
        assert isinstance(outcome, ErrorResponse)

    def test_source_path_outside_ephemeral(self, ephemeral_dir, tmp_path):
        """A source path outside the session dir is an error."""
        from backend.copilot.tools.models import ErrorResponse

        stray = tmp_path / "outside.txt"
        stray.write_text("nope")
        outcome = _resolve_write_content(None, None, str(stray), "s1")
        assert isinstance(outcome, ErrorResponse)

    def test_empty_string_sources_treated_as_none(self):
        """All empty strings behave the same as providing no sources."""
        from backend.copilot.tools.models import ErrorResponse

        outcome = _resolve_write_content("", "", "", "s1")
        assert isinstance(outcome, ErrorResponse)

    def test_empty_string_source_path_with_text(self):
        """Empty base64/source_path are ignored, so only the text counts."""
        outcome = _resolve_write_content("hello", "", "", "s1")
        assert outcome == b"hello"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E2E: workspace file tool round-trip (write → list → read → delete)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_workspace_file_round_trip(setup_test_data):
    """E2E: write a file, list it, read it back (with save_to_path), then delete it.

    Each step asserts the concrete success response type so that an
    ``ErrorResponse`` from any tool fails the test with its message.
    """
    from backend.copilot.tools.sandbox import make_session_path
    from backend.copilot.tools.workspace_files import WorkspaceFileContentResponse

    user = setup_test_data["user"]
    session = make_session(user.id)
    session_id = session.session_id

    # ---- Write ----
    write_tool = WriteWorkspaceFileTool()
    write_resp = await write_tool._execute(
        user_id=user.id,
        session=session,
        filename="test_round_trip.txt",
        content="Hello from e2e test!",
    )
    assert isinstance(write_resp, WorkspaceWriteResponse), write_resp.message
    file_id = write_resp.file_id

    # ---- List ----
    list_tool = ListWorkspaceFilesTool()
    list_resp = await list_tool._execute(user_id=user.id, session=session)
    assert isinstance(list_resp, WorkspaceFileListResponse), list_resp.message
    assert any(f.file_id == file_id for f in list_resp.files)

    # ---- Read (inline) ----
    read_tool = ReadWorkspaceFileTool()
    read_resp = await read_tool._execute(
        user_id=user.id, session=session, file_id=file_id
    )
    assert isinstance(read_resp, WorkspaceFileContentResponse), read_resp.message
    decoded = base64.b64decode(read_resp.content_base64).decode()
    assert decoded == "Hello from e2e test!"

    # ---- Read with save_to_path ----
    ephemeral_dir = make_session_path(session_id)
    os.makedirs(ephemeral_dir, exist_ok=True)
    save_path = os.path.join(ephemeral_dir, "saved_copy.txt")

    read_resp2 = await read_tool._execute(
        user_id=user.id, session=session, file_id=file_id, save_to_path=save_path
    )
    # Same file as the inline read above, so the same response type is
    # expected; a bare "is not None" check would also accept an ErrorResponse.
    assert isinstance(read_resp2, WorkspaceFileContentResponse), read_resp2.message
    decoded2 = base64.b64decode(read_resp2.content_base64).decode()
    assert decoded2 == "Hello from e2e test!"
    assert os.path.exists(save_path)
    with open(save_path) as f:
        assert f.read() == "Hello from e2e test!"

    # ---- Delete ----
    delete_tool = DeleteWorkspaceFileTool()
    del_resp = await delete_tool._execute(
        user_id=user.id, session=session, file_id=file_id
    )
    assert isinstance(del_resp, WorkspaceDeleteResponse), del_resp.message
    assert del_resp.success is True

    # Verify file is gone
    list_resp2 = await list_tool._execute(user_id=user.id, session=session)
    assert isinstance(list_resp2, WorkspaceFileListResponse)
    assert not any(f.file_id == file_id for f in list_resp2.files)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_write_workspace_file_source_path(setup_test_data):
    """E2E: persist a file from the ephemeral dir to the workspace via source_path."""
    from backend.copilot.tools.sandbox import make_session_path

    user = setup_test_data["user"]
    session = make_session(user.id)

    # Stage a file in this session's ephemeral directory.
    scratch_dir = make_session_path(session.session_id)
    os.makedirs(scratch_dir, exist_ok=True)
    staged_file = os.path.join(scratch_dir, "generated_output.csv")
    with open(staged_file, "w") as handle:
        handle.write("col1,col2\n1,2\n")

    result = await WriteWorkspaceFileTool()._execute(
        user_id=user.id,
        session=session,
        filename="output.csv",
        source_path=staged_file,
    )
    assert isinstance(result, WorkspaceWriteResponse), result.message

    # Clean up the workspace copy created above.
    await DeleteWorkspaceFileTool()._execute(
        user_id=user.id, session=session, file_id=result.file_id
    )
|
||||
@@ -10,12 +10,31 @@ import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useChat } from "@ai-sdk/react";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { DefaultChatTransport } from "ai";
|
||||
import type { UIMessage } from "ai";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useChatSession } from "./useChatSession";
|
||||
import { useLongRunningToolPolling } from "./hooks/useLongRunningToolPolling";
|
||||
|
||||
const STREAM_START_TIMEOUT_MS = 12_000;
|
||||
|
||||
/**
 * Return a copy of `messages` in which every part still in the
 * "input-streaming" or "input-available" state is finalised, so spinners stop.
 *
 * - `"cancelled"`: the part becomes `output-error` with `errorText` "Cancelled".
 * - `"completed"`: the part becomes `output-available` with an empty output.
 *
 * All other parts are passed through unchanged.
 */
function resolveInProgressTools(
  messages: UIMessage[],
  outcome: "completed" | "cancelled",
): UIMessage[] {
  return messages.map((msg) => ({
    ...msg,
    parts: msg.parts.map((part) => {
      // Parts without a `state` field have nothing to resolve.
      if (!("state" in part)) return part;

      // Only parts still waiting on tool input/output are rewritten.
      if (
        part.state !== "input-streaming" &&
        part.state !== "input-available"
      ) {
        return part;
      }

      if (outcome === "cancelled") {
        return { ...part, state: "output-error" as const, errorText: "Cancelled" };
      }
      return { ...part, state: "output-available" as const, output: "" };
    }),
  }));
}
|
||||
|
||||
export function useCopilotPage() {
|
||||
const { isUserLoading, isLoggedIn } = useSupabase();
|
||||
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
|
||||
@@ -114,23 +133,7 @@ export function useCopilotPage() {
|
||||
// the cancel API to actually stop the executor and wait for confirmation.
|
||||
async function stop() {
|
||||
sdkStop();
|
||||
|
||||
// Mark any in-progress tool parts as errored so spinners stop.
|
||||
setMessages((prev) =>
|
||||
prev.map((msg) => ({
|
||||
...msg,
|
||||
parts: msg.parts.map((part) =>
|
||||
"state" in part &&
|
||||
(part.state === "input-streaming" || part.state === "input-available")
|
||||
? {
|
||||
...part,
|
||||
state: "output-error" as const,
|
||||
errorText: "Cancelled",
|
||||
}
|
||||
: part,
|
||||
),
|
||||
})),
|
||||
);
|
||||
setMessages((prev) => resolveInProgressTools(prev, "cancelled"));
|
||||
|
||||
if (!sessionId) return;
|
||||
try {
|
||||
@@ -199,6 +202,18 @@ export function useCopilotPage() {
|
||||
resumeStream();
|
||||
}, [hasActiveStream, sessionId, hydratedMessages, status, resumeStream]);
|
||||
|
||||
// Once a stream ends (status goes "streaming" -> "ready"), settle any tool
// parts that are still showing spinners. This covers tool calls for which the
// backend sent StreamFinish without first emitting StreamToolOutputAvailable
// (e.g. SDK built-in tools).
const prevStatusRef = useRef(status);
useEffect(() => {
  // Capture the transition before overwriting the remembered status.
  const streamJustFinished =
    prevStatusRef.current === "streaming" && status === "ready";
  prevStatusRef.current = status;

  if (!streamJustFinished) return;
  setMessages((current) => resolveInProgressTools(current, "completed"));
}, [status, setMessages]);
|
||||
|
||||
// Poll session endpoint when a long-running tool (create_agent, edit_agent)
|
||||
// is in progress. When the backend completes, the session data will contain
|
||||
// the final tool output — this hook detects the change and updates messages.
|
||||
|
||||
Reference in New Issue
Block a user