mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into HEAD
This commit is contained in:
2
.github/workflows/platform-frontend-ci.yml
vendored
2
.github/workflows/platform-frontend-ci.yml
vendored
@@ -149,7 +149,7 @@ jobs:
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||
uses: crazy-max/ghaction-github-runtime@v3
|
||||
uses: crazy-max/ghaction-github-runtime@v4
|
||||
|
||||
- name: Set up Platform - Build Docker images (with cache)
|
||||
working-directory: autogpt_platform
|
||||
|
||||
@@ -111,13 +111,29 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
|
||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
||||
# Copy Node.js installation for Prisma
|
||||
# Copy Node.js installation for Prisma and agent-browser.
|
||||
# npm/npx are symlinks in the builder (-> ../lib/node_modules/npm/bin/*-cli.js);
|
||||
# COPY resolves them to regular files, breaking require() paths. Recreate as
|
||||
# proper symlinks so npm/npx can find their modules.
|
||||
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||
COPY --from=builder /usr/bin/npm /usr/bin/npm
|
||||
COPY --from=builder /usr/bin/npx /usr/bin/npx
|
||||
RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
|
||||
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
|
||||
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
|
||||
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
|
||||
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
|
||||
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
|
||||
fonts-liberation libfontconfig1 \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& npm install -g agent-browser \
|
||||
&& agent-browser install \
|
||||
&& rm -rf /tmp/* /root/.npm
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
# Copy only the .venv from builder (not the entire /app directory)
|
||||
|
||||
@@ -22,6 +22,7 @@ from backend.data.human_review import (
|
||||
)
|
||||
from backend.data.model import USER_TIMEZONE_NOT_SET
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.executor.utils import add_graph_execution
|
||||
|
||||
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
||||
@@ -321,10 +322,13 @@ async def process_review_action(
|
||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||
)
|
||||
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
||||
user_timezone=user_timezone,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
await add_graph_execution(
|
||||
|
||||
@@ -120,6 +120,10 @@ class UploadFileResponse(BaseModel):
|
||||
size_bytes: int
|
||||
|
||||
|
||||
class DeleteFileResponse(BaseModel):
|
||||
deleted: bool
|
||||
|
||||
|
||||
class StorageUsageResponse(BaseModel):
|
||||
used_bytes: int
|
||||
limit_bytes: int
|
||||
@@ -151,6 +155,31 @@ async def download_file(
|
||||
return await _create_file_download_response(file)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/files/{file_id}",
|
||||
summary="Delete a workspace file",
|
||||
)
|
||||
async def delete_workspace_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
file_id: str,
|
||||
) -> DeleteFileResponse:
|
||||
"""
|
||||
Soft-delete a workspace file and attempt to remove it from storage.
|
||||
|
||||
Used when a user clears a file input in the builder.
|
||||
"""
|
||||
workspace = await get_workspace(user_id)
|
||||
if workspace is None:
|
||||
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
|
||||
|
||||
manager = WorkspaceManager(user_id, workspace.id)
|
||||
deleted = await manager.delete_file(file_id)
|
||||
if not deleted:
|
||||
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return DeleteFileResponse(deleted=True)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/files/upload",
|
||||
summary="Upload file to workspace",
|
||||
@@ -218,7 +247,10 @@ async def upload_file(
|
||||
|
||||
# Write file via WorkspaceManager
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
workspace_file = await manager.write_file(content, filename)
|
||||
try:
|
||||
workspace_file = await manager.write_file(content, filename)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
|
||||
|
||||
# Post-write storage check — eliminates TOCTOU race on the quota.
|
||||
# If a concurrent upload pushed us over the limit, undo this write.
|
||||
|
||||
@@ -305,3 +305,55 @@ def test_download_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
|
||||
response = client.get("/files/some-file-id/download")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ---- Delete ----
|
||||
|
||||
|
||||
def test_delete_file_success(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting an existing file should return {"deleted": true}."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = client.delete("/files/file-aaa-bbb")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"deleted": True}
|
||||
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
|
||||
|
||||
|
||||
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting a non-existent file should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = client.delete("/files/nonexistent-id")
|
||||
assert response.status_code == 404
|
||||
assert "File not found" in response.text
|
||||
|
||||
|
||||
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting when user has no workspace should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.delete("/files/file-aaa-bbb")
|
||||
assert response.status_code == 404
|
||||
assert "Workspace not found" in response.text
|
||||
|
||||
@@ -13,6 +13,7 @@ from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from langfuse import propagate_attributes
|
||||
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
@@ -198,6 +199,20 @@ async def stream_chat_completion_baseline(
|
||||
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
# Propagate user/session context to Langfuse so all LLM calls within
|
||||
# this request are grouped under a single trace with proper attribution.
|
||||
_trace_ctx: Any = None
|
||||
try:
|
||||
_trace_ctx = propagate_attributes(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
trace_name="copilot-baseline",
|
||||
tags=["baseline"],
|
||||
)
|
||||
_trace_ctx.__enter__()
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context setup failed")
|
||||
|
||||
assistant_text = ""
|
||||
text_block_id = str(uuid.uuid4())
|
||||
text_started = False
|
||||
@@ -272,7 +287,7 @@ async def stream_chat_completion_baseline(
|
||||
yield StreamFinishStep()
|
||||
step_open = False
|
||||
|
||||
# Append the assistant message with tool_calls to context
|
||||
# Append the assistant message with tool_calls to context.
|
||||
assistant_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if round_text:
|
||||
assistant_msg["content"] = round_text
|
||||
@@ -282,7 +297,7 @@ async def stream_chat_completion_baseline(
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc["name"],
|
||||
"arguments": tc["arguments"],
|
||||
"arguments": tc["arguments"] or "{}",
|
||||
},
|
||||
}
|
||||
for tc in tool_calls_by_index.values()
|
||||
@@ -385,6 +400,13 @@ async def stream_chat_completion_baseline(
|
||||
yield StreamError(errorText=error_msg, code="baseline_error")
|
||||
# Still persist whatever we got
|
||||
finally:
|
||||
# Close Langfuse trace context
|
||||
if _trace_ctx is not None:
|
||||
try:
|
||||
_trace_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context teardown failed")
|
||||
|
||||
# Persist assistant response
|
||||
if assistant_text:
|
||||
session.messages.append(
|
||||
|
||||
@@ -62,6 +62,10 @@ class ChatConfig(BaseSettings):
|
||||
default="CoPilot Prompt",
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
langfuse_prompt_cache_ttl: int = Field(
|
||||
default=300,
|
||||
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
@@ -87,6 +91,10 @@ class ChatConfig(BaseSettings):
|
||||
description="Use --resume for multi-turn conversations instead of "
|
||||
"history compression. Falls back to compression when unavailable.",
|
||||
)
|
||||
use_claude_code_subscription: bool = Field(
|
||||
default=False,
|
||||
description="For personal/dev use: use Claude Code CLI subscription auth instead of API keys. Requires `claude login` on the host. Only works with SDK mode.",
|
||||
)
|
||||
|
||||
# E2B Sandbox Configuration
|
||||
use_e2b_sandbox: bool = Field(
|
||||
@@ -121,7 +129,7 @@ class ChatConfig(BaseSettings):
|
||||
@classmethod
|
||||
def get_e2b_api_key(cls, v):
|
||||
"""Get E2B API key from environment if not provided."""
|
||||
if v is None:
|
||||
if not v:
|
||||
v = os.getenv("CHAT_E2B_API_KEY") or os.getenv("E2B_API_KEY")
|
||||
return v
|
||||
|
||||
@@ -129,7 +137,7 @@ class ChatConfig(BaseSettings):
|
||||
@classmethod
|
||||
def get_api_key(cls, v):
|
||||
"""Get API key from environment if not provided."""
|
||||
if v is None:
|
||||
if not v:
|
||||
# Try to get from environment variables
|
||||
# First check for CHAT_API_KEY (Pydantic prefix)
|
||||
v = os.getenv("CHAT_API_KEY")
|
||||
@@ -139,13 +147,16 @@ class ChatConfig(BaseSettings):
|
||||
if not v:
|
||||
# Fall back to OPENAI_API_KEY
|
||||
v = os.getenv("OPENAI_API_KEY")
|
||||
# Note: ANTHROPIC_API_KEY is intentionally NOT included here.
|
||||
# The SDK CLI picks it up from the env directly. Including it
|
||||
# would pair it with the OpenRouter base_url, causing auth failures.
|
||||
return v
|
||||
|
||||
@field_validator("base_url", mode="before")
|
||||
@classmethod
|
||||
def get_base_url(cls, v):
|
||||
"""Get base URL from environment if not provided."""
|
||||
if v is None:
|
||||
if not v:
|
||||
# Check for OpenRouter or custom base URL
|
||||
v = os.getenv("CHAT_BASE_URL")
|
||||
if not v:
|
||||
@@ -167,6 +178,15 @@ class ChatConfig(BaseSettings):
|
||||
# Default to True (SDK enabled by default)
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("use_claude_code_subscription", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_code_subscription(cls, v):
|
||||
"""Get use_claude_code_subscription from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_CODE_SUBSCRIPTION", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return False if v is None else v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
|
||||
@@ -6,6 +6,8 @@ in a thread-local context, following the graph executor pattern.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
|
||||
@@ -108,8 +110,41 @@ class CoPilotProcessor:
|
||||
)
|
||||
self.execution_thread.start()
|
||||
|
||||
# Skip the SDK's per-request CLI version check — the bundled CLI is
|
||||
# already version-matched to the SDK package.
|
||||
os.environ.setdefault("CLAUDE_AGENT_SDK_SKIP_VERSION_CHECK", "1")
|
||||
|
||||
# Pre-warm the bundled CLI binary so the OS page-caches the ~185 MB
|
||||
# executable. First spawn pays ~1.2 s; subsequent spawns ~0.65 s.
|
||||
self._prewarm_cli()
|
||||
|
||||
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
|
||||
|
||||
def _prewarm_cli(self) -> None:
|
||||
"""Run the bundled CLI binary once to warm OS page caches."""
|
||||
try:
|
||||
from claude_agent_sdk._internal.transport.subprocess_cli import (
|
||||
SubprocessCLITransport,
|
||||
)
|
||||
|
||||
cli_path = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
|
||||
if cli_path:
|
||||
result = subprocess.run(
|
||||
[cli_path, "-v"],
|
||||
capture_output=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
logger.info(f"[CoPilotExecutor] CLI pre-warm done: {cli_path}")
|
||||
else:
|
||||
logger.warning(
|
||||
"[CoPilotExecutor] CLI pre-warm failed (rc=%d): %s",
|
||||
result.returncode, # type: ignore[reportCallIssue]
|
||||
cli_path,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"[CoPilotExecutor] CLI pre-warm skipped: {e}")
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up event-loop-bound resources before the loop is destroyed.
|
||||
|
||||
@@ -208,9 +243,10 @@ class CoPilotProcessor:
|
||||
error_msg = None
|
||||
|
||||
try:
|
||||
# Choose service based on LaunchDarkly flag
|
||||
# Choose service based on LaunchDarkly flag.
|
||||
# Claude Code subscription forces SDK mode (CLI subprocess auth).
|
||||
config = ChatConfig()
|
||||
use_sdk = await is_feature_enabled(
|
||||
use_sdk = config.use_claude_code_subscription or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
@@ -228,6 +264,8 @@ class CoPilotProcessor:
|
||||
message=entry.message if entry.message else None,
|
||||
is_user_message=entry.is_user_message,
|
||||
user_id=entry.user_id,
|
||||
context=entry.context,
|
||||
file_ids=entry.file_ids,
|
||||
):
|
||||
if cancel.is_set():
|
||||
log.info("Cancel requested, breaking stream")
|
||||
|
||||
@@ -705,19 +705,10 @@ async def update_session_title(session_id: str, title: str) -> bool:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
return False
|
||||
|
||||
# Update title in cache if it exists (instead of invalidating).
|
||||
# This prevents race conditions where cache invalidation causes
|
||||
# the frontend to see stale DB data while streaming is still in progress.
|
||||
try:
|
||||
cached = await _get_session_from_cache(session_id)
|
||||
if cached:
|
||||
cached.title = title
|
||||
await cache_chat_session(cached)
|
||||
except Exception as e:
|
||||
# Not critical - title will be correct on next full cache refresh
|
||||
logger.warning(
|
||||
f"Failed to update title in cache for session {session_id}: {e}"
|
||||
)
|
||||
# Invalidate the cache so the next access reloads from DB with the
|
||||
# updated title. This avoids a read-modify-write on the full session
|
||||
# blob, which could overwrite concurrent message updates.
|
||||
await invalidate_session_cache(session_id)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
|
||||
@@ -10,6 +10,7 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from ..model import ChatSession
|
||||
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
|
||||
@@ -26,6 +27,7 @@ async def stream_chat_completion_dummy(
|
||||
retry_count: int = 0,
|
||||
session: ChatSession | None = None,
|
||||
context: dict[str, str] | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream dummy chat completion for testing.
|
||||
|
||||
|
||||
@@ -280,7 +280,9 @@ E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
(
|
||||
"write_file",
|
||||
"Write or create a file in the cloud sandbox (/home/user). "
|
||||
"Parent directories are created automatically.",
|
||||
"Parent directories are created automatically. "
|
||||
"To copy a workspace file into the sandbox, use "
|
||||
"read_workspace_file with save_to_path instead.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
@@ -2,10 +2,13 @@
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -22,6 +25,7 @@ from claude_agent_sdk import (
|
||||
)
|
||||
from langfuse import propagate_attributes
|
||||
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
@@ -55,6 +59,7 @@ from ..service import (
|
||||
)
|
||||
from ..tools.e2b_sandbox import get_or_create_sandbox
|
||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
from ..tools.workspace_files import get_manager
|
||||
from ..tracking import track_user_message
|
||||
from .compaction import CompactionTracker, filter_compaction_messages
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
@@ -177,6 +182,13 @@ Long-running tools (create_agent, edit_agent, etc.) are handled
|
||||
asynchronously. You will receive an immediate response; the actual result
|
||||
is delivered to the user via a background stream.
|
||||
|
||||
### Large tool outputs
|
||||
When a tool output exceeds the display limit, it is automatically saved to
|
||||
the persistent workspace. The truncated output includes a
|
||||
`<tool-output-truncated>` tag with the workspace path. Use
|
||||
`read_workspace_file(path="...", offset=N, length=50000)` to retrieve
|
||||
additional sections.
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
@@ -288,15 +300,52 @@ def _resolve_sdk_model() -> str | None:
|
||||
Uses ``config.claude_agent_model`` if set, otherwise derives from
|
||||
``config.model`` by stripping the OpenRouter provider prefix (e.g.,
|
||||
``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``).
|
||||
|
||||
When ``use_claude_code_subscription`` is enabled and no explicit
|
||||
``claude_agent_model`` is set, returns ``None`` so the CLI uses the
|
||||
default model for the user's subscription plan.
|
||||
"""
|
||||
if config.claude_agent_model:
|
||||
return config.claude_agent_model
|
||||
if config.use_claude_code_subscription:
|
||||
return None
|
||||
model = config.model
|
||||
if "/" in model:
|
||||
return model.split("/", 1)[1]
|
||||
return model
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _validate_claude_code_subscription() -> None:
|
||||
"""Validate Claude CLI is installed and responds to ``--version``.
|
||||
|
||||
Cached so the blocking subprocess check runs at most once per process
|
||||
lifetime. A failure (CLI not installed) is a config error that requires
|
||||
a process restart anyway.
|
||||
"""
|
||||
claude_path = shutil.which("claude")
|
||||
if not claude_path:
|
||||
raise RuntimeError(
|
||||
"Claude Code CLI not found. Install it with: "
|
||||
"npm install -g @anthropic-ai/claude-code"
|
||||
)
|
||||
result = subprocess.run(
|
||||
[claude_path, "--version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Claude CLI check failed (exit {result.returncode}): "
|
||||
f"{result.stderr.strip()}"
|
||||
)
|
||||
logger.info(
|
||||
"Claude Code subscription mode: CLI version %s",
|
||||
result.stdout.strip(),
|
||||
)
|
||||
|
||||
|
||||
def _build_sdk_env(
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
@@ -317,7 +366,16 @@ def _build_sdk_env(
|
||||
falls back to its default credentials.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
if config.api_key and config.base_url:
|
||||
|
||||
if config.use_claude_code_subscription:
|
||||
# Claude Code subscription: let the CLI use its own logged-in auth.
|
||||
# Explicitly clear API key env vars so the subprocess doesn't pick
|
||||
# them up from the parent process and bypass subscription auth.
|
||||
_validate_claude_code_subscription()
|
||||
env["ANTHROPIC_API_KEY"] = ""
|
||||
env["ANTHROPIC_AUTH_TOKEN"] = ""
|
||||
env["ANTHROPIC_BASE_URL"] = ""
|
||||
elif config.api_key and config.base_url:
|
||||
# Strip /v1 suffix — SDK expects the base URL without a version path
|
||||
base = config.base_url.rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
@@ -330,21 +388,24 @@ def _build_sdk_env(
|
||||
# Must be explicitly empty so the CLI uses AUTH_TOKEN instead
|
||||
env["ANTHROPIC_API_KEY"] = ""
|
||||
|
||||
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
|
||||
# The ``x-session-id`` header is *required* for the Anthropic-native
|
||||
# ``/messages`` endpoint — without it broadcast silently drops the
|
||||
# trace even when org-level Langfuse integration is configured.
|
||||
def _safe(value: str) -> str:
|
||||
"""Strip CR/LF to prevent header injection, then truncate."""
|
||||
return value.replace("\r", "").replace("\n", "").strip()[:128]
|
||||
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
|
||||
# The ``x-session-id`` header is *required* for the Anthropic-native
|
||||
# ``/messages`` endpoint — without it broadcast silently drops the
|
||||
# trace even when org-level Langfuse integration is configured.
|
||||
def _safe(value: str) -> str:
|
||||
"""Strip CR/LF to prevent header injection, then truncate."""
|
||||
return value.replace("\r", "").replace("\n", "").strip()[:128]
|
||||
|
||||
headers: list[str] = []
|
||||
if session_id:
|
||||
headers.append(f"x-session-id: {_safe(session_id)}")
|
||||
if user_id:
|
||||
headers.append(f"x-user-id: {_safe(user_id)}")
|
||||
# Only inject headers when routing through OpenRouter/proxy — they're
|
||||
# meaningless (and leak internal IDs) when using subscription mode.
|
||||
if headers and env.get("ANTHROPIC_BASE_URL"):
|
||||
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(headers)
|
||||
|
||||
headers: list[str] = []
|
||||
if session_id:
|
||||
headers.append(f"x-session-id: {_safe(session_id)}")
|
||||
if user_id:
|
||||
headers.append(f"x-user-id: {_safe(user_id)}")
|
||||
if headers:
|
||||
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(headers)
|
||||
return env
|
||||
|
||||
|
||||
@@ -568,15 +629,142 @@ async def _build_query_message(
|
||||
return current_message, False
|
||||
|
||||
|
||||
# Claude API vision-supported image types.
|
||||
_VISION_MIME_TYPES = frozenset({"image/png", "image/jpeg", "image/gif", "image/webp"})
|
||||
|
||||
# Max size for embedding images directly in the user message (20 MiB raw).
|
||||
_MAX_INLINE_IMAGE_BYTES = 20 * 1024 * 1024
|
||||
|
||||
# Matches characters unsafe for filenames.
|
||||
_UNSAFE_FILENAME = re.compile(r"[^\w.\-]")
|
||||
|
||||
|
||||
def _save_to_sdk_cwd(sdk_cwd: str, filename: str, content: bytes) -> str:
|
||||
"""Write file content to the SDK ephemeral directory.
|
||||
|
||||
Returns the absolute path. Adds a numeric suffix on name collisions.
|
||||
"""
|
||||
safe = _UNSAFE_FILENAME.sub("_", filename) or "file"
|
||||
candidate = os.path.join(sdk_cwd, safe)
|
||||
if os.path.exists(candidate):
|
||||
stem, ext = os.path.splitext(safe)
|
||||
idx = 1
|
||||
while os.path.exists(candidate):
|
||||
candidate = os.path.join(sdk_cwd, f"{stem}_{idx}{ext}")
|
||||
idx += 1
|
||||
with open(candidate, "wb") as f:
|
||||
f.write(content)
|
||||
return candidate
|
||||
|
||||
|
||||
class PreparedAttachments(BaseModel):
|
||||
"""Result of preparing file attachments for a query."""
|
||||
|
||||
hint: str = ""
|
||||
"""Text hint describing the files (appended to the user message)."""
|
||||
|
||||
image_blocks: list[dict[str, Any]] = []
|
||||
"""Claude API image content blocks to embed in the user message."""
|
||||
|
||||
|
||||
async def _prepare_file_attachments(
|
||||
file_ids: list[str],
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
sdk_cwd: str,
|
||||
) -> PreparedAttachments:
|
||||
"""Download workspace files and prepare them for Claude.
|
||||
|
||||
Images (PNG/JPEG/GIF/WebP) are embedded directly as vision content blocks
|
||||
in the user message so Claude can see them without tool calls.
|
||||
|
||||
Non-image files (PDFs, text, etc.) are saved to *sdk_cwd* so the CLI's
|
||||
built-in Read tool can access them.
|
||||
|
||||
Returns a :class:`PreparedAttachments` with a text hint and any image
|
||||
content blocks.
|
||||
"""
|
||||
empty = PreparedAttachments(hint="", image_blocks=[])
|
||||
if not file_ids or not user_id:
|
||||
return empty
|
||||
|
||||
try:
|
||||
manager = await get_manager(user_id, session_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to create workspace manager for file attachments",
|
||||
exc_info=True,
|
||||
)
|
||||
return empty
|
||||
|
||||
image_blocks: list[dict[str, Any]] = []
|
||||
file_descriptions: list[str] = []
|
||||
|
||||
for fid in file_ids:
|
||||
try:
|
||||
file_info = await manager.get_file_info(fid)
|
||||
if file_info is None:
|
||||
continue
|
||||
content = await manager.read_file_by_id(fid)
|
||||
mime = (file_info.mime_type or "").split(";")[0].strip().lower()
|
||||
|
||||
# Images: embed directly in the user message as vision blocks
|
||||
if mime in _VISION_MIME_TYPES and len(content) <= _MAX_INLINE_IMAGE_BYTES:
|
||||
b64 = base64.b64encode(content).decode("ascii")
|
||||
image_blocks.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime,
|
||||
"data": b64,
|
||||
},
|
||||
}
|
||||
)
|
||||
file_descriptions.append(
|
||||
f"- {file_info.name} ({mime}, "
|
||||
f"{file_info.size_bytes:,} bytes) [embedded as image]"
|
||||
)
|
||||
else:
|
||||
# Non-image files: save to sdk_cwd for Read tool access
|
||||
local_path = _save_to_sdk_cwd(sdk_cwd, file_info.name, content)
|
||||
file_descriptions.append(
|
||||
f"- {file_info.name} ({mime}, "
|
||||
f"{file_info.size_bytes:,} bytes) saved to {local_path}"
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to prepare file %s", fid[:12], exc_info=True)
|
||||
|
||||
if not file_descriptions:
|
||||
return empty
|
||||
|
||||
noun = "file" if len(file_descriptions) == 1 else "files"
|
||||
has_non_images = len(file_descriptions) > len(image_blocks)
|
||||
read_hint = " Use the Read tool to view non-image files." if has_non_images else ""
|
||||
hint = (
|
||||
f"[The user attached {len(file_descriptions)} {noun}.{read_hint}\n"
|
||||
+ "\n".join(file_descriptions)
|
||||
+ "]"
|
||||
)
|
||||
return PreparedAttachments(hint=hint, image_blocks=image_blocks)
|
||||
|
||||
|
||||
async def stream_chat_completion_sdk(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
session: ChatSession | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream chat completion using Claude Agent SDK."""
|
||||
"""Stream chat completion using Claude Agent SDK.
|
||||
|
||||
Args:
|
||||
file_ids: Optional workspace file IDs attached to the user's message.
|
||||
Images are embedded as vision content blocks; other files are
|
||||
saved to the SDK working directory for the Read tool.
|
||||
"""
|
||||
|
||||
if session is None:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
@@ -683,54 +871,108 @@ async def stream_chat_completion_sdk(
|
||||
code="sdk_cwd_error",
|
||||
)
|
||||
return
|
||||
# Set up E2B sandbox for persistent cloud execution when configured.
|
||||
# When active, MCP file tools route directly to the sandbox filesystem
|
||||
# so bash_exec and file tools share the same /home/user directory.
|
||||
if config.use_e2b_sandbox and not config.e2b_api_key:
|
||||
logger.warning(
|
||||
"[E2B] [%s] E2B sandbox enabled but no API key configured "
|
||||
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
|
||||
session_id[:12],
|
||||
)
|
||||
if config.use_e2b_sandbox and config.e2b_api_key:
|
||||
try:
|
||||
e2b_sandbox = await get_or_create_sandbox(
|
||||
session_id,
|
||||
api_key=config.e2b_api_key,
|
||||
template=config.e2b_sandbox_template,
|
||||
timeout=config.e2b_sandbox_timeout,
|
||||
)
|
||||
except Exception as e2b_err:
|
||||
logger.error(
|
||||
"[E2B] [%s] Setup failed: %s",
|
||||
# --- Run independent async I/O operations in parallel ---
|
||||
# E2B sandbox setup, system prompt build (Langfuse + DB), and transcript
|
||||
# download are independent network calls. Running them concurrently
|
||||
# saves ~200-500ms compared to sequential execution.
|
||||
|
||||
async def _setup_e2b():
|
||||
"""Set up E2B sandbox if configured, return sandbox or None."""
|
||||
if config.use_e2b_sandbox and not config.e2b_api_key:
|
||||
logger.warning(
|
||||
"[E2B] [%s] E2B sandbox enabled but no API key configured "
|
||||
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
|
||||
session_id[:12],
|
||||
e2b_err,
|
||||
exc_info=True,
|
||||
)
|
||||
e2b_sandbox = None
|
||||
return None
|
||||
if config.use_e2b_sandbox and config.e2b_api_key:
|
||||
try:
|
||||
return await get_or_create_sandbox(
|
||||
session_id,
|
||||
api_key=config.e2b_api_key,
|
||||
template=config.e2b_sandbox_template,
|
||||
timeout=config.e2b_sandbox_timeout,
|
||||
)
|
||||
except Exception as e2b_err:
|
||||
logger.error(
|
||||
"[E2B] [%s] Setup failed: %s",
|
||||
session_id[:12],
|
||||
e2b_err,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
async def _fetch_transcript():
|
||||
"""Download transcript for --resume if applicable."""
|
||||
if not (
|
||||
config.claude_agent_use_resume and user_id and len(session.messages) > 1
|
||||
):
|
||||
return None
|
||||
try:
|
||||
return await download_transcript(user_id, session_id)
|
||||
except Exception as transcript_err:
|
||||
logger.warning(
|
||||
"[SDK] [%s] Transcript download failed, continuing without "
|
||||
"--resume: %s",
|
||||
session_id[:12],
|
||||
transcript_err,
|
||||
)
|
||||
return None
|
||||
|
||||
e2b_sandbox, (base_system_prompt, _), dl = await asyncio.gather(
|
||||
_setup_e2b(),
|
||||
_build_system_prompt(user_id, has_conversation_history=has_history),
|
||||
_fetch_transcript(),
|
||||
)
|
||||
|
||||
use_e2b = e2b_sandbox is not None
|
||||
|
||||
system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=has_history
|
||||
)
|
||||
system_prompt += (
|
||||
system_prompt = base_system_prompt + (
|
||||
_E2B_TOOL_SUPPLEMENT
|
||||
if use_e2b
|
||||
else _LOCAL_TOOL_SUPPLEMENT.format(cwd=sdk_cwd)
|
||||
)
|
||||
|
||||
# Process transcript download result
|
||||
transcript_msg_count = 0
|
||||
if dl:
|
||||
is_valid = validate_transcript(dl.content)
|
||||
if is_valid:
|
||||
logger.info(
|
||||
f"[SDK] Transcript available for session {session_id}: "
|
||||
f"{len(dl.content)}B, msg_count={dl.message_count}"
|
||||
)
|
||||
resume_file = write_transcript_to_tempfile(
|
||||
dl.content, session_id, sdk_cwd
|
||||
)
|
||||
if resume_file:
|
||||
use_resume = True
|
||||
transcript_msg_count = dl.message_count
|
||||
logger.debug(
|
||||
f"[SDK] Using --resume ({len(dl.content)}B, "
|
||||
f"msg_count={transcript_msg_count})"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] Transcript downloaded but invalid for {session_id}"
|
||||
)
|
||||
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||
logger.warning(
|
||||
f"[SDK] No transcript available for {session_id} "
|
||||
f"({len(session.messages)} messages in session)"
|
||||
)
|
||||
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
set_execution_context(user_id, session, sandbox=e2b_sandbox, sdk_cwd=sdk_cwd)
|
||||
|
||||
# Fail fast when no API credentials are available at all
|
||||
# Fail fast when no API credentials are available at all.
|
||||
sdk_env = _build_sdk_env(session_id=session_id, user_id=user_id)
|
||||
if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"):
|
||||
if not config.api_key and not config.use_claude_code_subscription:
|
||||
raise RuntimeError(
|
||||
"No API key configured. Set OPEN_ROUTER_API_KEY "
|
||||
"(or CHAT_API_KEY) for OpenRouter routing, "
|
||||
"or ANTHROPIC_API_KEY for direct Anthropic access."
|
||||
"No API key configured. Set OPEN_ROUTER_API_KEY, "
|
||||
"CHAT_API_KEY, or ANTHROPIC_API_KEY for API access, "
|
||||
"or CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true to use "
|
||||
"Claude Code CLI subscription (requires `claude login`)."
|
||||
)
|
||||
|
||||
mcp_server = create_copilot_mcp_server(use_e2b=use_e2b)
|
||||
@@ -767,37 +1009,6 @@ async def stream_chat_completion_sdk(
|
||||
on_compact=compaction.on_compact,
|
||||
)
|
||||
|
||||
# --- Resume strategy: download transcript from bucket ---
|
||||
transcript_msg_count = 0 # watermark: session.messages length at upload
|
||||
|
||||
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||
dl = await download_transcript(user_id, session_id)
|
||||
is_valid = bool(dl and validate_transcript(dl.content))
|
||||
if dl and is_valid:
|
||||
logger.info(
|
||||
f"[SDK] Transcript available for session {session_id}: "
|
||||
f"{len(dl.content)}B, msg_count={dl.message_count}"
|
||||
)
|
||||
resume_file = write_transcript_to_tempfile(
|
||||
dl.content, session_id, sdk_cwd
|
||||
)
|
||||
if resume_file:
|
||||
use_resume = True
|
||||
transcript_msg_count = dl.message_count
|
||||
logger.debug(
|
||||
f"[SDK] Using --resume ({len(dl.content)}B, "
|
||||
f"msg_count={transcript_msg_count})"
|
||||
)
|
||||
elif dl:
|
||||
logger.warning(
|
||||
f"[SDK] Transcript downloaded but invalid for {session_id}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] No transcript available for {session_id} "
|
||||
f"({len(session.messages)} messages in session)"
|
||||
)
|
||||
|
||||
allowed = get_copilot_tool_names(use_e2b=use_e2b)
|
||||
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
|
||||
sdk_options_kwargs: dict[str, Any] = {
|
||||
@@ -854,19 +1065,48 @@ async def stream_chat_completion_sdk(
|
||||
transcript_msg_count,
|
||||
session_id,
|
||||
)
|
||||
# If files are attached, prepare them: images become vision
|
||||
# content blocks in the user message, other files go to sdk_cwd.
|
||||
attachments = await _prepare_file_attachments(
|
||||
file_ids or [], user_id or "", session_id, sdk_cwd
|
||||
)
|
||||
if attachments.hint:
|
||||
query_message = f"{query_message}\n\n{attachments.hint}"
|
||||
|
||||
logger.info(
|
||||
"[SDK] [%s] Sending query — resume=%s, total_msgs=%d, query_len=%d",
|
||||
"[SDK] [%s] Sending query — resume=%s, total_msgs=%d, "
|
||||
"query_len=%d, attached_files=%d, image_blocks=%d",
|
||||
session_id[:12],
|
||||
use_resume,
|
||||
len(session.messages),
|
||||
len(query_message),
|
||||
len(file_ids) if file_ids else 0,
|
||||
len(attachments.image_blocks),
|
||||
)
|
||||
|
||||
compaction.reset_for_query()
|
||||
if was_compacted:
|
||||
for ev in compaction.emit_pre_query(session):
|
||||
yield ev
|
||||
await client.query(query_message, session_id=session_id)
|
||||
|
||||
if attachments.image_blocks:
|
||||
# Build multimodal content: image blocks + text
|
||||
content_blocks: list[dict[str, Any]] = [
|
||||
*attachments.image_blocks,
|
||||
{"type": "text", "text": query_message},
|
||||
]
|
||||
user_msg = {
|
||||
"type": "user",
|
||||
"message": {"role": "user", "content": content_blocks},
|
||||
"parent_tool_use_id": None,
|
||||
"session_id": session_id,
|
||||
}
|
||||
assert client._transport is not None # noqa: SLF001
|
||||
await client._transport.write( # noqa: SLF001
|
||||
json.dumps(user_msg) + "\n"
|
||||
)
|
||||
else:
|
||||
await client.query(query_message, session_id=session_id)
|
||||
|
||||
assistant_response = ChatMessage(role="assistant", content="")
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
147
autogpt_platform/backend/backend/copilot/sdk/service_test.py
Normal file
147
autogpt_platform/backend/backend/copilot/sdk/service_test.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Tests for SDK service helpers."""
|
||||
|
||||
import base64
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .service import _prepare_file_attachments
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeFileInfo:
|
||||
id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
_PATCH_TARGET = "backend.copilot.sdk.service.get_manager"
|
||||
|
||||
|
||||
class TestPrepareFileAttachments:
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_list_returns_empty(self, tmp_path):
|
||||
result = await _prepare_file_attachments([], "u", "s", str(tmp_path))
|
||||
assert result.hint == ""
|
||||
assert result.image_blocks == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_embedded_as_vision_block(self, tmp_path):
|
||||
"""JPEG images should become vision content blocks, not files on disk."""
|
||||
raw = b"\xff\xd8\xff\xe0fake-jpeg"
|
||||
info = _FakeFileInfo(
|
||||
id="abc",
|
||||
name="photo.jpg",
|
||||
path="/photo.jpg",
|
||||
mime_type="image/jpeg",
|
||||
size_bytes=len(raw),
|
||||
)
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.return_value = info
|
||||
mgr.read_file_by_id.return_value = raw
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(
|
||||
["abc"], "user1", "sess1", str(tmp_path)
|
||||
)
|
||||
|
||||
assert "1 file" in result.hint
|
||||
assert "photo.jpg" in result.hint
|
||||
assert "embedded as image" in result.hint
|
||||
assert len(result.image_blocks) == 1
|
||||
block = result.image_blocks[0]
|
||||
assert block["type"] == "image"
|
||||
assert block["source"]["media_type"] == "image/jpeg"
|
||||
assert block["source"]["data"] == base64.b64encode(raw).decode("ascii")
|
||||
# Image should NOT be written to disk (embedded instead)
|
||||
assert not os.path.exists(os.path.join(tmp_path, "photo.jpg"))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pdf_saved_to_disk(self, tmp_path):
|
||||
"""PDFs should be saved to disk for Read tool access, not embedded."""
|
||||
info = _FakeFileInfo("f1", "doc.pdf", "/doc.pdf", "application/pdf", 50)
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.return_value = info
|
||||
mgr.read_file_by_id.return_value = b"%PDF-1.4 fake"
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(["f1"], "u", "s", str(tmp_path))
|
||||
|
||||
assert result.image_blocks == []
|
||||
saved = tmp_path / "doc.pdf"
|
||||
assert saved.exists()
|
||||
assert saved.read_bytes() == b"%PDF-1.4 fake"
|
||||
assert str(saved) in result.hint
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_images_and_files(self, tmp_path):
|
||||
"""Images become blocks, non-images go to disk."""
|
||||
infos = {
|
||||
"id1": _FakeFileInfo("id1", "a.png", "/a.png", "image/png", 4),
|
||||
"id2": _FakeFileInfo("id2", "b.pdf", "/b.pdf", "application/pdf", 4),
|
||||
"id3": _FakeFileInfo("id3", "c.txt", "/c.txt", "text/plain", 4),
|
||||
}
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.side_effect = lambda fid: infos[fid]
|
||||
mgr.read_file_by_id.return_value = b"data"
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(
|
||||
["id1", "id2", "id3"], "u", "s", str(tmp_path)
|
||||
)
|
||||
|
||||
assert "3 files" in result.hint
|
||||
assert "a.png" in result.hint
|
||||
assert "b.pdf" in result.hint
|
||||
assert "c.txt" in result.hint
|
||||
# Only the image should be a vision block
|
||||
assert len(result.image_blocks) == 1
|
||||
assert result.image_blocks[0]["source"]["media_type"] == "image/png"
|
||||
# Non-image files should be on disk
|
||||
assert (tmp_path / "b.pdf").exists()
|
||||
assert (tmp_path / "c.txt").exists()
|
||||
# Read tool hint should appear (has non-image files)
|
||||
assert "Read tool" in result.hint
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_singular_noun(self, tmp_path):
|
||||
info = _FakeFileInfo("x", "only.txt", "/only.txt", "text/plain", 2)
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.return_value = info
|
||||
mgr.read_file_by_id.return_value = b"hi"
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(["x"], "u", "s", str(tmp_path))
|
||||
|
||||
assert "1 file." in result.hint
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_file_skipped(self, tmp_path):
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.return_value = None
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(
|
||||
["missing-id"], "u", "s", str(tmp_path)
|
||||
)
|
||||
|
||||
assert result.hint == ""
|
||||
assert result.image_blocks == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_only_no_read_hint(self, tmp_path):
|
||||
"""When all files are images, no Read tool hint should appear."""
|
||||
info = _FakeFileInfo("i1", "cat.png", "/cat.png", "image/png", 4)
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.return_value = info
|
||||
mgr.read_file_by_id.return_value = b"data"
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(["i1"], "u", "s", str(tmp_path))
|
||||
|
||||
assert "Read tool" not in result.hint
|
||||
assert len(result.image_blocks) == 1
|
||||
@@ -102,6 +102,9 @@ _current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
"_current_sandbox", default=None
|
||||
)
|
||||
# Raw SDK working directory path (e.g. /tmp/copilot-<session_id>).
|
||||
# Used by workspace tools to save binary files for the CLI's built-in Read.
|
||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||
|
||||
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
||||
@@ -140,6 +143,7 @@ def set_execution_context(
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_current_sandbox.set(sandbox)
|
||||
_current_sdk_cwd.set(sdk_cwd or "")
|
||||
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
|
||||
_pending_tool_outputs.set({})
|
||||
_stash_event.set(asyncio.Event())
|
||||
@@ -150,6 +154,11 @@ def get_current_sandbox() -> "AsyncSandbox | None":
|
||||
return _current_sandbox.get()
|
||||
|
||||
|
||||
def get_sdk_cwd() -> str:
|
||||
"""Return the SDK ephemeral working directory for the current turn."""
|
||||
return _current_sdk_cwd.get()
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
"""Get the current execution context."""
|
||||
return (
|
||||
@@ -263,61 +272,12 @@ async def _execute_tool_sync(
|
||||
result.output if isinstance(result.output, str) else json.dumps(result.output)
|
||||
)
|
||||
|
||||
content_blocks: list[dict[str, str]] = [{"type": "text", "text": text}]
|
||||
|
||||
# If the tool result contains inline image data, add an MCP image block
|
||||
# so Claude can "see" the image (e.g. read_workspace_file on a small PNG).
|
||||
image_block = _extract_image_block(text)
|
||||
if image_block:
|
||||
content_blocks.append(image_block)
|
||||
|
||||
return {
|
||||
"content": content_blocks,
|
||||
"content": [{"type": "text", "text": text}],
|
||||
"isError": not result.success,
|
||||
}
|
||||
|
||||
|
||||
# MIME types that Claude can process as image content blocks.
|
||||
_SUPPORTED_IMAGE_TYPES = frozenset(
|
||||
{"image/png", "image/jpeg", "image/gif", "image/webp"}
|
||||
)
|
||||
|
||||
|
||||
def _extract_image_block(text: str) -> dict[str, str] | None:
|
||||
"""Extract an MCP image content block from a tool result JSON string.
|
||||
|
||||
Detects workspace file responses with ``content_base64`` and an image
|
||||
MIME type, returning an MCP-format image block that allows Claude to
|
||||
"see" the image. Returns ``None`` if the result is not an inline image.
|
||||
"""
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
|
||||
mime_type = data.get("mime_type", "")
|
||||
base64_content = data.get("content_base64", "")
|
||||
|
||||
# Only inline small images — large ones would exceed Claude's limits.
|
||||
# 32 KB raw ≈ ~43 KB base64.
|
||||
_MAX_IMAGE_BASE64_BYTES = 43_000
|
||||
if (
|
||||
mime_type in _SUPPORTED_IMAGE_TYPES
|
||||
and base64_content
|
||||
and len(base64_content) <= _MAX_IMAGE_BASE64_BYTES
|
||||
):
|
||||
return {
|
||||
"type": "image",
|
||||
"data": base64_content,
|
||||
"mimeType": mime_type,
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _mcp_error(message: str) -> dict[str, Any]:
|
||||
return {
|
||||
"content": [
|
||||
@@ -423,18 +383,21 @@ _READ_TOOL_SCHEMA = {
|
||||
}
|
||||
|
||||
|
||||
# Create the MCP server configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCP result helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _text_from_mcp_result(result: dict[str, Any]) -> str:
|
||||
"""Extract concatenated text from an MCP response's content blocks."""
|
||||
content = result.get("content", [])
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "".join(parts)
|
||||
return ""
|
||||
if not isinstance(content, list):
|
||||
return ""
|
||||
return "".join(
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
)
|
||||
|
||||
|
||||
def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Tests for tool_adapter helpers: _text_from_mcp_result, truncation stash."""
|
||||
"""Tests for tool_adapter helpers: truncation, stash, context vars."""
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -7,6 +7,7 @@ from backend.util.truncate import truncate
|
||||
from .tool_adapter import (
|
||||
_MCP_MAX_CHARS,
|
||||
_text_from_mcp_result,
|
||||
get_sdk_cwd,
|
||||
pop_pending_tool_output,
|
||||
set_execution_context,
|
||||
stash_pending_tool_output,
|
||||
@@ -54,6 +55,30 @@ class TestTextFromMcpResult:
|
||||
assert _text_from_mcp_result(result) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_sdk_cwd
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetSdkCwd:
|
||||
def test_returns_empty_string_by_default(self):
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
)
|
||||
assert get_sdk_cwd() == ""
|
||||
|
||||
def test_returns_set_value(self):
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
sdk_cwd="/tmp/copilot-test-123",
|
||||
)
|
||||
assert get_sdk_cwd() == "/tmp/copilot-test-123"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# stash / pop round-trip (the mechanism _truncating relies on)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -331,10 +331,10 @@ async def upload_transcript(
|
||||
) -> None:
|
||||
"""Strip progress entries and upload transcript to bucket storage.
|
||||
|
||||
Safety: only overwrites when the new (stripped) transcript is larger than
|
||||
what is already stored. Since JSONL is append-only, the latest transcript
|
||||
is always the longest. This prevents a slow/stale background task from
|
||||
clobbering a newer upload from a concurrent turn.
|
||||
The executor holds a cluster lock per session, so concurrent uploads for
|
||||
the same session cannot happen. We always overwrite — with ``--resume``
|
||||
the CLI may compact old tool results, so neither byte size nor line count
|
||||
is a reliable proxy for "newer".
|
||||
|
||||
Args:
|
||||
message_count: ``len(session.messages)`` at upload time — used by
|
||||
@@ -353,33 +353,16 @@ async def upload_transcript(
|
||||
storage = await get_workspace_storage()
|
||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
||||
encoded = stripped.encode("utf-8")
|
||||
new_size = len(encoded)
|
||||
|
||||
# Check existing transcript size to avoid overwriting newer with older
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
content_skipped = False
|
||||
try:
|
||||
existing = await storage.retrieve(path)
|
||||
if len(existing) >= new_size:
|
||||
logger.info(
|
||||
f"[Transcript] Skipping content upload — existing ({len(existing)}B) "
|
||||
f">= new ({new_size}B) for session {session_id}"
|
||||
)
|
||||
content_skipped = True
|
||||
except (FileNotFoundError, Exception):
|
||||
pass # No existing transcript or retrieval error — proceed with upload
|
||||
await storage.store(
|
||||
workspace_id=wid,
|
||||
file_id=fid,
|
||||
filename=fname,
|
||||
content=encoded,
|
||||
)
|
||||
|
||||
if not content_skipped:
|
||||
await storage.store(
|
||||
workspace_id=wid,
|
||||
file_id=fid,
|
||||
filename=fname,
|
||||
content=encoded,
|
||||
)
|
||||
|
||||
# Always update metadata (even when content is skipped) so message_count
|
||||
# stays current. The gap-fill logic in _build_query_message relies on
|
||||
# message_count to avoid re-compressing the same messages every turn.
|
||||
# Update metadata so message_count stays current. The gap-fill logic
|
||||
# in _build_query_message relies on it to avoid re-compressing messages.
|
||||
try:
|
||||
meta = {"message_count": message_count, "uploaded_at": time.time()}
|
||||
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
|
||||
@@ -393,9 +376,8 @@ async def upload_transcript(
|
||||
logger.warning(f"[Transcript] Failed to write metadata for {session_id}: {e}")
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Uploaded {new_size}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count}, "
|
||||
f"content_skipped={content_skipped}) "
|
||||
f"[Transcript] Uploaded {len(encoded)}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count}) "
|
||||
f"for session {session_id}"
|
||||
)
|
||||
|
||||
|
||||
@@ -11,8 +11,10 @@ import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
from langfuse import get_client
|
||||
from langfuse.openai import (
|
||||
AsyncOpenAI as LangfuseAsyncOpenAI, # pyright: ignore[reportPrivateImportUsage]
|
||||
)
|
||||
|
||||
from backend.data.db_accessors import understanding_db
|
||||
from backend.data.understanding import format_understanding_for_prompt
|
||||
@@ -26,7 +28,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
settings = Settings()
|
||||
client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
|
||||
|
||||
langfuse = get_client()
|
||||
@@ -173,7 +175,6 @@ async def _get_system_prompt_template(context: str) -> str:
|
||||
"""
|
||||
if _is_langfuse_configured():
|
||||
try:
|
||||
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
||||
# Use asyncio.to_thread to avoid blocking the event loop
|
||||
# In non-production environments, fetch the latest prompt version
|
||||
# instead of the production-labeled version for easier testing
|
||||
@@ -186,7 +187,7 @@ async def _get_system_prompt_template(context: str) -> str:
|
||||
langfuse.get_prompt,
|
||||
config.langfuse_prompt_name,
|
||||
label=label,
|
||||
cache_ttl_seconds=0,
|
||||
cache_ttl_seconds=config.langfuse_prompt_cache_ttl,
|
||||
)
|
||||
return prompt.compile(users_information=context)
|
||||
except Exception as e:
|
||||
|
||||
@@ -733,7 +733,10 @@ async def mark_session_completed(
|
||||
# This is the SINGLE place that publishes StreamFinish — services and
|
||||
# the processor must NOT publish it themselves.
|
||||
try:
|
||||
await publish_chunk(turn_id, StreamFinish())
|
||||
await publish_chunk(
|
||||
turn_id,
|
||||
StreamFinish(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to publish StreamFinish for session {session_id}: {e}. "
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Base classes and shared utilities for chat tools."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -7,11 +8,98 @@ from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.truncate import truncate
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Persist full tool output to workspace when it exceeds this threshold.
|
||||
# Must be below _MAX_TOOL_OUTPUT_SIZE (100K) in response_model.py so we
|
||||
# capture the data before model_post_init middle-out truncation discards it.
|
||||
_LARGE_OUTPUT_THRESHOLD = 80_000
|
||||
|
||||
# Character budget for the middle-out preview. The total preview + wrapper
|
||||
# must stay below BOTH:
|
||||
# - _MAX_TOOL_OUTPUT_SIZE (100K) in response_model.py (our own truncation)
|
||||
# - Claude SDK's ~100 KB tool-result spill-to-disk threshold
|
||||
# to avoid double truncation/spilling. 95K + ~300 wrapper = ~95.3K, under both.
|
||||
_PREVIEW_CHARS = 95_000
|
||||
|
||||
|
||||
# Fields whose values are binary/base64 data — truncating them produces
|
||||
# garbage, so we replace them with a human-readable size summary instead.
|
||||
_BINARY_FIELD_NAMES = {"content_base64"}
|
||||
|
||||
|
||||
def _summarize_binary_fields(raw_json: str) -> str:
|
||||
"""Replace known binary fields with a size summary so truncate() doesn't
|
||||
produce garbled base64 in the middle-out preview."""
|
||||
try:
|
||||
data = json.loads(raw_json)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return raw_json
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return raw_json
|
||||
|
||||
changed = False
|
||||
for key in _BINARY_FIELD_NAMES:
|
||||
if key in data and isinstance(data[key], str) and len(data[key]) > 1_000:
|
||||
byte_size = len(data[key]) * 3 // 4 # approximate decoded size
|
||||
data[key] = f"<binary, ~{byte_size:,} bytes>"
|
||||
changed = True
|
||||
|
||||
return json.dumps(data, ensure_ascii=False) if changed else raw_json
|
||||
|
||||
|
||||
async def _persist_and_summarize(
|
||||
raw_output: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
) -> str:
|
||||
"""Persist full output to workspace and return a middle-out preview with retrieval instructions.
|
||||
|
||||
On failure, returns the original ``raw_output`` unchanged so that the
|
||||
existing ``model_post_init`` middle-out truncation handles it as before.
|
||||
"""
|
||||
file_path = f"tool-outputs/{tool_call_id}.json"
|
||||
try:
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
await manager.write_file(
|
||||
content=raw_output.encode("utf-8"),
|
||||
filename=f"{tool_call_id}.json",
|
||||
path=file_path,
|
||||
mime_type="application/json",
|
||||
overwrite=True,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist large tool output for %s",
|
||||
tool_call_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return raw_output # fall back to normal truncation
|
||||
|
||||
total = len(raw_output)
|
||||
preview = truncate(_summarize_binary_fields(raw_output), _PREVIEW_CHARS)
|
||||
retrieval = (
|
||||
f"\nFull output ({total:,} chars) saved to workspace. "
|
||||
f"Use read_workspace_file("
|
||||
f'path="{file_path}", offset=<char_offset>, length=50000) '
|
||||
f"to read any section."
|
||||
)
|
||||
return (
|
||||
f'<tool-output-truncated total_chars={total} path="{file_path}">\n'
|
||||
f"{preview}\n"
|
||||
f"{retrieval}\n"
|
||||
f"</tool-output-truncated>"
|
||||
)
|
||||
|
||||
|
||||
class BaseTool:
|
||||
"""Base class for all chat tools."""
|
||||
@@ -67,7 +155,7 @@ class BaseTool:
|
||||
"""Execute the tool with authentication check.
|
||||
|
||||
Args:
|
||||
user_id: User ID (may be anonymous like "anon_123")
|
||||
user_id: User ID (None for anonymous users)
|
||||
session_id: Chat session ID
|
||||
**kwargs: Tool-specific parameters
|
||||
|
||||
@@ -91,10 +179,21 @@ class BaseTool:
|
||||
|
||||
try:
|
||||
result = await self._execute(user_id, session, **kwargs)
|
||||
raw_output = result.model_dump_json()
|
||||
|
||||
if (
|
||||
len(raw_output) > _LARGE_OUTPUT_THRESHOLD
|
||||
and user_id
|
||||
and session.session_id
|
||||
):
|
||||
raw_output = await _persist_and_summarize(
|
||||
raw_output, user_id, session.session_id, tool_call_id
|
||||
)
|
||||
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=result.model_dump_json(),
|
||||
output=raw_output,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in {self.name}: {e}", exc_info=True)
|
||||
|
||||
194
autogpt_platform/backend/backend/copilot/tools/base_test.py
Normal file
194
autogpt_platform/backend/backend/copilot/tools/base_test.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Tests for BaseTool large-output persistence in execute()."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools.base import (
|
||||
_LARGE_OUTPUT_THRESHOLD,
|
||||
BaseTool,
|
||||
_persist_and_summarize,
|
||||
_summarize_binary_fields,
|
||||
)
|
||||
from backend.copilot.tools.models import ResponseType, ToolResponseBase
|
||||
|
||||
|
||||
class _HugeOutputTool(BaseTool):
|
||||
"""Fake tool that returns an arbitrarily large output."""
|
||||
|
||||
def __init__(self, output_size: int) -> None:
|
||||
self._output_size = output_size
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "huge_output_tool"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Returns a huge output"
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict:
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
async def _execute(self, user_id, session, **kwargs) -> ToolResponseBase:
|
||||
return ToolResponseBase(
|
||||
type=ResponseType.ERROR,
|
||||
message="x" * self._output_size,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _persist_and_summarize
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPersistAndSummarize:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_middle_out_preview_with_retrieval_instructions(self):
|
||||
raw = "A" * 200_000
|
||||
|
||||
mock_workspace = MagicMock()
|
||||
mock_workspace.id = "ws-1"
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("backend.copilot.tools.base.workspace_db", return_value=mock_db),
|
||||
patch(
|
||||
"backend.copilot.tools.base.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
),
|
||||
):
|
||||
result = await _persist_and_summarize(raw, "user-1", "session-1", "tc-123")
|
||||
|
||||
assert "<tool-output-truncated" in result
|
||||
assert "</tool-output-truncated>" in result
|
||||
assert "total_chars=200000" in result
|
||||
assert 'path="tool-outputs/tc-123.json"' in result
|
||||
assert "read_workspace_file" in result
|
||||
# Middle-out sentinel from truncate()
|
||||
assert "omitted" in result
|
||||
# Total result is much shorter than the raw output
|
||||
assert len(result) < len(raw)
|
||||
|
||||
# Verify write_file was called with full content
|
||||
mock_manager.write_file.assert_awaited_once()
|
||||
call_kwargs = mock_manager.write_file.call_args
|
||||
assert call_kwargs.kwargs["content"] == raw.encode("utf-8")
|
||||
assert call_kwargs.kwargs["path"] == "tool-outputs/tc-123.json"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_on_workspace_error(self):
|
||||
"""If workspace write fails, return raw output for normal truncation."""
|
||||
raw = "B" * 200_000
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_or_create_workspace = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
with patch("backend.copilot.tools.base.workspace_db", return_value=mock_db):
|
||||
result = await _persist_and_summarize(raw, "user-1", "session-1", "tc-fail")
|
||||
|
||||
assert result == raw # unchanged — fallback to normal truncation
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BaseTool.execute — integration with persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBaseToolExecuteLargeOutput:
|
||||
@pytest.mark.asyncio
|
||||
async def test_small_output_not_persisted(self):
|
||||
"""Outputs under the threshold go through without persistence."""
|
||||
tool = _HugeOutputTool(output_size=100)
|
||||
session = MagicMock()
|
||||
session.session_id = "s-1"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.base._persist_and_summarize",
|
||||
new_callable=AsyncMock,
|
||||
) as persist_mock:
|
||||
result = await tool.execute("user-1", session, "tc-small")
|
||||
persist_mock.assert_not_awaited()
|
||||
assert "<tool-output-truncated" not in str(result.output)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_output_persisted(self):
|
||||
"""Outputs over the threshold trigger persistence + preview."""
|
||||
tool = _HugeOutputTool(output_size=_LARGE_OUTPUT_THRESHOLD + 10_000)
|
||||
session = MagicMock()
|
||||
session.session_id = "s-1"
|
||||
|
||||
mock_workspace = MagicMock()
|
||||
mock_workspace.id = "ws-1"
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
|
||||
mock_manager = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("backend.copilot.tools.base.workspace_db", return_value=mock_db),
|
||||
patch(
|
||||
"backend.copilot.tools.base.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
),
|
||||
):
|
||||
result = await tool.execute("user-1", session, "tc-big")
|
||||
|
||||
assert "<tool-output-truncated" in str(result.output)
|
||||
assert "read_workspace_file" in str(result.output)
|
||||
mock_manager.write_file.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_persistence_without_user_id(self):
|
||||
"""Anonymous users skip persistence (no workspace)."""
|
||||
tool = _HugeOutputTool(output_size=_LARGE_OUTPUT_THRESHOLD + 10_000)
|
||||
session = MagicMock()
|
||||
session.session_id = "s-1"
|
||||
|
||||
# user_id=None → should not attempt persistence
|
||||
with patch(
|
||||
"backend.copilot.tools.base._persist_and_summarize",
|
||||
new_callable=AsyncMock,
|
||||
) as persist_mock:
|
||||
result = await tool.execute(None, session, "tc-anon")
|
||||
persist_mock.assert_not_awaited()
|
||||
# Output is set but not wrapped in <tool-output-truncated> tags
|
||||
# (it will be middle-out truncated by model_post_init instead)
|
||||
assert "<tool-output-truncated" not in str(result.output)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _summarize_binary_fields
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSummarizeBinaryFields:
|
||||
def test_replaces_large_content_base64(self):
|
||||
import json
|
||||
|
||||
data = {"content_base64": "A" * 10_000, "name": "file.png"}
|
||||
result = json.loads(_summarize_binary_fields(json.dumps(data)))
|
||||
assert result["name"] == "file.png"
|
||||
assert "<binary" in result["content_base64"]
|
||||
assert "bytes>" in result["content_base64"]
|
||||
|
||||
def test_preserves_small_content_base64(self):
|
||||
import json
|
||||
|
||||
data = {"content_base64": "AQID", "name": "tiny.bin"}
|
||||
result_str = _summarize_binary_fields(json.dumps(data))
|
||||
result = json.loads(result_str)
|
||||
assert result["content_base64"] == "AQID" # unchanged
|
||||
|
||||
def test_non_json_passthrough(self):
|
||||
raw = "not json at all"
|
||||
assert _summarize_binary_fields(raw) == raw
|
||||
|
||||
def test_no_binary_fields_unchanged(self):
|
||||
import json
|
||||
|
||||
data = {"message": "hello", "type": "info"}
|
||||
raw = json.dumps(data)
|
||||
assert _summarize_binary_fields(raw) == raw
|
||||
@@ -432,7 +432,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
class ReadWorkspaceFileTool(BaseTool):
|
||||
"""Tool for reading file content from workspace."""
|
||||
|
||||
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
|
||||
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB for text/image files
|
||||
PREVIEW_SIZE = 500
|
||||
|
||||
@property
|
||||
@@ -448,8 +448,10 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"For small text files, returns content directly. "
|
||||
"For large or binary files, returns metadata and a download URL. "
|
||||
"Optionally use 'save_to_path' to copy the file to the ephemeral "
|
||||
"working directory for processing with bash_exec or SDK tools. "
|
||||
"Use 'save_to_path' to copy the file to the working directory "
|
||||
"(sandbox or ephemeral) for processing with bash_exec or file tools. "
|
||||
"Use 'offset' and 'length' for paginated reads of large files "
|
||||
"(e.g., persisted tool outputs). "
|
||||
"Paths are scoped to the current session by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
@@ -473,9 +475,10 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
"save_to_path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"If provided, save the file to this path in the ephemeral "
|
||||
"working directory (e.g., '/tmp/copilot-.../data.csv') "
|
||||
"so it can be processed with bash_exec or SDK tools. "
|
||||
"If provided, save the file to this path in the working "
|
||||
"directory (cloud sandbox when E2B is active, or "
|
||||
"ephemeral dir otherwise) so it can be processed with "
|
||||
"bash_exec or file tools. "
|
||||
"The file content is still returned in the response."
|
||||
),
|
||||
},
|
||||
@@ -486,6 +489,20 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
"Default is false (auto-selects based on file size/type)."
|
||||
),
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Character offset to start reading from (0-based). "
|
||||
"Use with 'length' for paginated reads of large files."
|
||||
),
|
||||
},
|
||||
"length": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Maximum number of characters to return. "
|
||||
"Defaults to full file. Use with 'offset' for paginated reads."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [], # At least one must be provided
|
||||
}
|
||||
@@ -510,6 +527,8 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
save_to_path: Optional[str] = kwargs.get("save_to_path")
|
||||
force_download_url: bool = kwargs.get("force_download_url", False)
|
||||
char_offset: int = max(0, kwargs.get("offset", 0))
|
||||
char_length: Optional[int] = kwargs.get("length")
|
||||
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
@@ -532,6 +551,34 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
return result
|
||||
save_to_path = result
|
||||
|
||||
# Ranged read: return a character slice directly.
|
||||
if char_offset > 0 or char_length is not None:
|
||||
raw = cached_content or await manager.read_file_by_id(target_file_id)
|
||||
text = raw.decode("utf-8", errors="replace")
|
||||
total_chars = len(text)
|
||||
end = (
|
||||
char_offset + char_length
|
||||
if char_length is not None
|
||||
else total_chars
|
||||
)
|
||||
slice_text = text[char_offset:end]
|
||||
return WorkspaceFileContentResponse(
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type="text/plain",
|
||||
content_base64=base64.b64encode(slice_text.encode("utf-8")).decode(
|
||||
"utf-8"
|
||||
),
|
||||
message=(
|
||||
f"Read chars {char_offset}–"
|
||||
f"{char_offset + len(slice_text)} "
|
||||
f"of {total_chars:,} total "
|
||||
f"from {file_info.name}"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
is_small = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
|
||||
is_text = _is_text_mime(file_info.mime_type)
|
||||
is_image = file_info.mime_type in _IMAGE_MIME_TYPES
|
||||
|
||||
@@ -236,6 +236,65 @@ async def test_workspace_file_round_trip(setup_test_data):
|
||||
assert not any(f.file_id == file_id for f in list_resp2.files)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ranged reads (offset / length)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_read_workspace_file_with_offset_and_length(setup_test_data):
|
||||
"""Read a slice of a text file using offset and length."""
|
||||
user = setup_test_data["user"]
|
||||
session = make_session(user.id)
|
||||
|
||||
# Write a known-content file
|
||||
content = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" * 100 # 2600 chars
|
||||
write_tool = WriteWorkspaceFileTool()
|
||||
write_resp = await write_tool._execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
filename="ranged_test.txt",
|
||||
content=content,
|
||||
)
|
||||
assert isinstance(write_resp, WorkspaceWriteResponse), write_resp.message
|
||||
file_id = write_resp.file_id
|
||||
|
||||
from backend.copilot.tools.workspace_files import WorkspaceFileContentResponse
|
||||
|
||||
read_tool = ReadWorkspaceFileTool()
|
||||
|
||||
# Read with offset=100, length=50
|
||||
resp = await read_tool._execute(
|
||||
user_id=user.id, session=session, file_id=file_id, offset=100, length=50
|
||||
)
|
||||
assert isinstance(resp, WorkspaceFileContentResponse), resp.message
|
||||
decoded = base64.b64decode(resp.content_base64).decode()
|
||||
assert decoded == content[100:150]
|
||||
assert "100" in resp.message
|
||||
assert "2,600" in resp.message # total chars (comma-formatted)
|
||||
|
||||
# Read with offset only (no length) — returns from offset to end
|
||||
resp2 = await read_tool._execute(
|
||||
user_id=user.id, session=session, file_id=file_id, offset=2500
|
||||
)
|
||||
assert isinstance(resp2, WorkspaceFileContentResponse)
|
||||
decoded2 = base64.b64decode(resp2.content_base64).decode()
|
||||
assert decoded2 == content[2500:]
|
||||
assert len(decoded2) == 100
|
||||
|
||||
# Read with offset beyond file length — returns empty string
|
||||
resp3 = await read_tool._execute(
|
||||
user_id=user.id, session=session, file_id=file_id, offset=9999, length=10
|
||||
)
|
||||
assert isinstance(resp3, WorkspaceFileContentResponse)
|
||||
decoded3 = base64.b64decode(resp3.content_base64).decode()
|
||||
assert decoded3 == ""
|
||||
|
||||
# Cleanup
|
||||
delete_tool = DeleteWorkspaceFileTool()
|
||||
await delete_tool._execute(user_id=user.id, session=session, file_id=file_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_write_workspace_file_source_path(setup_test_data):
|
||||
"""E2E: write a file from ephemeral source_path to workspace."""
|
||||
|
||||
@@ -32,6 +32,7 @@ from backend.data.execution import (
|
||||
from backend.data.graph import GraphModel, Node
|
||||
from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput, GraphInput
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
get_async_execution_queue,
|
||||
@@ -891,6 +892,7 @@ async def add_graph_execution(
|
||||
if execution_context is None:
|
||||
user = await udb.get_user_by_id(user_id)
|
||||
settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
# Execution identity
|
||||
@@ -907,6 +909,8 @@ async def add_graph_execution(
|
||||
),
|
||||
# Execution hierarchy
|
||||
root_execution_id=graph_exec.id,
|
||||
# Workspace (enables workspace:// file resolution in blocks)
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -368,6 +368,12 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_workspace = mocker.MagicMock()
|
||||
mock_workspace.id = "test-workspace-id"
|
||||
mocker.patch(
|
||||
"backend.executor.utils.get_or_create_workspace",
|
||||
new=mocker.AsyncMock(return_value=mock_workspace),
|
||||
)
|
||||
|
||||
# Setup mock returns
|
||||
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
|
||||
@@ -643,6 +649,12 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_workspace = mocker.MagicMock()
|
||||
mock_workspace.id = "test-workspace-id"
|
||||
mocker.patch(
|
||||
"backend.executor.utils.get_or_create_workspace",
|
||||
new=mocker.AsyncMock(return_value=mock_workspace),
|
||||
)
|
||||
|
||||
# Setup returns - include nodes_to_skip in the tuple
|
||||
mock_validate.return_value = (
|
||||
@@ -681,6 +693,10 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
assert "nodes_to_skip" in captured_kwargs
|
||||
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
|
||||
|
||||
# Verify workspace_id is set in the execution context
|
||||
assert "execution_context" in captured_kwargs
|
||||
assert captured_kwargs["execution_context"].workspace_id == "test-workspace-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_graph_execution_in_review_status_cancels_pending_reviews(
|
||||
|
||||
14
autogpt_platform/backend/poetry.lock
generated
14
autogpt_platform/backend/poetry.lock
generated
@@ -899,17 +899,17 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "claude-agent-sdk"
|
||||
version = "0.1.39"
|
||||
version = "0.1.45"
|
||||
description = "Python SDK for Claude Code"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "claude_agent_sdk-0.1.39-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6ed6a79781f545b761b9fe467bc5ae213a103c9d3f0fe7a9dad3c01790ed58fa"},
|
||||
{file = "claude_agent_sdk-0.1.39-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:0c03b5a3772eaec42e29ea39240c7d24b760358082f2e36336db9e71dde3dda4"},
|
||||
{file = "claude_agent_sdk-0.1.39-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:d2665c9e87b6ffece590bcdd6eb9def47cde4809b0d2f66e0a61a719189be7c9"},
|
||||
{file = "claude_agent_sdk-0.1.39-py3-none-win_amd64.whl", hash = "sha256:d03324daf7076be79d2dd05944559aabf4cc11c98d3a574b992a442a7c7a26d6"},
|
||||
{file = "claude_agent_sdk-0.1.39.tar.gz", hash = "sha256:dcf0ebd5a638c9a7d9f3af7640932a9212b2705b7056e4f08bd3968a865b4268"},
|
||||
{file = "claude_agent_sdk-0.1.45-py3-none-macosx_11_0_arm64.whl", hash = "sha256:26a5cc60c3a394f5b814f6b2f67650819cbcd38c405bbdc11582b3e097b3a770"},
|
||||
{file = "claude_agent_sdk-0.1.45-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:decc741b53e0b2c10a64fd84c15acca1102077d9f99941c54905172cd95160c9"},
|
||||
{file = "claude_agent_sdk-0.1.45-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:7d48dcf4178c704e4ccbf3f1f4ebf20b3de3f03d0592086c1f3abd16b8ca441e"},
|
||||
{file = "claude_agent_sdk-0.1.45-py3-none-win_amd64.whl", hash = "sha256:d1cf34995109c513d8daabcae7208edc260b553b53462a9ac06a7c40e240a288"},
|
||||
{file = "claude_agent_sdk-0.1.45.tar.gz", hash = "sha256:97c1e981431b5af1e08c34731906ab8d4a58fe0774a04df0ea9587dcabc85151"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -8840,4 +8840,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "e7863413fda5e0a8b236e39a4c37390b52ae8c2f572c77df732abbd4280312b6"
|
||||
content-hash = "7189c9725ca42dfe6672632fe801c61248d87d3dd1259747b0ed9579b19fe088"
|
||||
|
||||
@@ -16,7 +16,7 @@ anthropic = "^0.79.0"
|
||||
apscheduler = "^3.11.1"
|
||||
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
||||
bleach = { extras = ["css"], version = "^6.2.0" }
|
||||
claude-agent-sdk = "^0.1.39" # see copilot/sdk/sdk_compat_test.py for capability checks
|
||||
claude-agent-sdk = "0.1.45" # see copilot/sdk/sdk_compat_test.py for capability checks
|
||||
click = "^8.2.0"
|
||||
cryptography = "^46.0"
|
||||
discord-py = "^2.5.2"
|
||||
|
||||
@@ -34,6 +34,7 @@ export const ContentRenderer: React.FC<{
|
||||
if (
|
||||
renderer?.name === "ImageRenderer" ||
|
||||
renderer?.name === "VideoRenderer" ||
|
||||
renderer?.name === "WorkspaceFileRenderer" ||
|
||||
!shortContent
|
||||
) {
|
||||
return (
|
||||
|
||||
@@ -64,6 +64,7 @@ export const ChatContainer = ({
|
||||
error={error}
|
||||
isLoading={isLoadingSession}
|
||||
headerSlot={headerSlot}
|
||||
sessionID={sessionId}
|
||||
/>
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
|
||||
@@ -3,19 +3,16 @@ import {
|
||||
ConversationContent,
|
||||
ConversationScrollButton,
|
||||
} from "@/components/ai-elements/conversation";
|
||||
import {
|
||||
Message,
|
||||
MessageActions,
|
||||
MessageContent,
|
||||
} from "@/components/ai-elements/message";
|
||||
import { Message, MessageContent } from "@/components/ai-elements/message";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { FileUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { TOOL_PART_PREFIX } from "../JobStatsBar/constants";
|
||||
import { TurnStatsBar } from "../JobStatsBar/TurnStatsBar";
|
||||
import { parseSpecialMarkers } from "./helpers";
|
||||
import { AssistantMessageActions } from "./components/AssistantMessageActions";
|
||||
import { MessageAttachments } from "./components/MessageAttachments";
|
||||
import { MessagePartRenderer } from "./components/MessagePartRenderer";
|
||||
import { ThinkingIndicator } from "./components/ThinkingIndicator";
|
||||
import { CopyButton } from "./components/CopyButton";
|
||||
import { TTSButton } from "./components/TTSButton";
|
||||
import { parseSpecialMarkers } from "./helpers";
|
||||
|
||||
interface Props {
|
||||
messages: UIMessage<unknown, UIDataTypes, UITools>[];
|
||||
@@ -23,6 +20,24 @@ interface Props {
|
||||
error: Error | undefined;
|
||||
isLoading: boolean;
|
||||
headerSlot?: React.ReactNode;
|
||||
sessionID?: string | null;
|
||||
}
|
||||
|
||||
/** Collect all messages belonging to a turn: the user message + every
|
||||
* assistant message up to (but not including) the next user message. */
|
||||
function getTurnMessages(
|
||||
messages: UIMessage<unknown, UIDataTypes, UITools>[],
|
||||
lastAssistantIndex: number,
|
||||
): UIMessage<unknown, UIDataTypes, UITools>[] {
|
||||
const userIndex = messages.findLastIndex(
|
||||
(m, i) => i < lastAssistantIndex && m.role === "user",
|
||||
);
|
||||
const nextUserIndex = messages.findIndex(
|
||||
(m, i) => i > lastAssistantIndex && m.role === "user",
|
||||
);
|
||||
const start = userIndex >= 0 ? userIndex : lastAssistantIndex;
|
||||
const end = nextUserIndex >= 0 ? nextUserIndex : messages.length;
|
||||
return messages.slice(start, end);
|
||||
}
|
||||
|
||||
export function ChatMessagesContainer({
|
||||
@@ -31,12 +46,10 @@ export function ChatMessagesContainer({
|
||||
error,
|
||||
isLoading,
|
||||
headerSlot,
|
||||
sessionID,
|
||||
}: Props) {
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
|
||||
// Determine if something is visibly "in-flight" in the last assistant message:
|
||||
// - Text is actively streaming (last part is non-empty text)
|
||||
// - A tool call is pending (state is input-streaming or input-available)
|
||||
const hasInflight = (() => {
|
||||
if (lastMessage?.role !== "assistant") return false;
|
||||
const parts = lastMessage.parts;
|
||||
@@ -44,13 +57,11 @@ export function ChatMessagesContainer({
|
||||
|
||||
const lastPart = parts[parts.length - 1];
|
||||
|
||||
// Text is actively being written
|
||||
if (lastPart.type === "text" && lastPart.text.trim().length > 0)
|
||||
return true;
|
||||
|
||||
// A tool call is still pending (no output yet)
|
||||
if (
|
||||
lastPart.type.startsWith("tool-") &&
|
||||
lastPart.type.startsWith(TOOL_PART_PREFIX) &&
|
||||
"state" in lastPart &&
|
||||
(lastPart.state === "input-streaming" ||
|
||||
lastPart.state === "input-available")
|
||||
@@ -80,14 +91,29 @@ export function ChatMessagesContainer({
|
||||
messageIndex === messages.length - 1 &&
|
||||
message.role === "assistant";
|
||||
|
||||
const isCurrentlyStreaming =
|
||||
isLastAssistant &&
|
||||
(status === "streaming" || status === "submitted");
|
||||
|
||||
const isAssistant = message.role === "assistant";
|
||||
|
||||
// Past assistant messages are always done; the last one
|
||||
// is done only when the stream has finished.
|
||||
const isAssistantDone =
|
||||
const nextMessage = messages[messageIndex + 1];
|
||||
const isLastInTurn =
|
||||
isAssistant &&
|
||||
(!isLastAssistant ||
|
||||
(status !== "streaming" && status !== "submitted"));
|
||||
messageIndex <= messages.length - 1 &&
|
||||
(!nextMessage || nextMessage.role === "user");
|
||||
const textParts = message.parts.filter(
|
||||
(p): p is Extract<typeof p, { type: "text" }> => p.type === "text",
|
||||
);
|
||||
const lastTextPart = textParts[textParts.length - 1];
|
||||
const hasErrorMarker =
|
||||
lastTextPart !== undefined &&
|
||||
parseSpecialMarkers(lastTextPart.text).markerType === "error";
|
||||
const showActions =
|
||||
isLastInTurn &&
|
||||
!isCurrentlyStreaming &&
|
||||
textParts.length > 0 &&
|
||||
!hasErrorMarker;
|
||||
|
||||
const fileParts = message.parts.filter(
|
||||
(p): p is FileUIPart => p.type === "file",
|
||||
@@ -110,6 +136,11 @@ export function ChatMessagesContainer({
|
||||
partIndex={i}
|
||||
/>
|
||||
))}
|
||||
{isLastInTurn && !isCurrentlyStreaming && (
|
||||
<TurnStatsBar
|
||||
turnMessages={getTurnMessages(messages, messageIndex)}
|
||||
/>
|
||||
)}
|
||||
{isLastAssistant && showThinking && (
|
||||
<ThinkingIndicator active={showThinking} />
|
||||
)}
|
||||
@@ -120,30 +151,12 @@ export function ChatMessagesContainer({
|
||||
isUser={message.role === "user"}
|
||||
/>
|
||||
)}
|
||||
{isAssistantDone &&
|
||||
(() => {
|
||||
const textParts = message.parts.filter(
|
||||
(p): p is Extract<typeof p, { type: "text" }> =>
|
||||
p.type === "text",
|
||||
);
|
||||
|
||||
// Hide actions when the message ended with an error or cancellation
|
||||
const lastTextPart = textParts[textParts.length - 1];
|
||||
if (lastTextPart) {
|
||||
const { markerType } = parseSpecialMarkers(
|
||||
lastTextPart.text,
|
||||
);
|
||||
if (markerType === "error") return null;
|
||||
}
|
||||
|
||||
const textContent = textParts.map((p) => p.text).join("\n");
|
||||
return (
|
||||
<MessageActions>
|
||||
<CopyButton text={textContent} />
|
||||
<TTSButton text={textContent} />
|
||||
</MessageActions>
|
||||
);
|
||||
})()}
|
||||
{showActions && (
|
||||
<AssistantMessageActions
|
||||
message={message}
|
||||
sessionID={sessionID ?? null}
|
||||
/>
|
||||
)}
|
||||
</Message>
|
||||
);
|
||||
})}
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
MessageAction,
|
||||
MessageActions,
|
||||
} from "@/components/ai-elements/message";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { CopySimple, ThumbsDown, ThumbsUp } from "@phosphor-icons/react";
|
||||
import { UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { useMessageFeedback } from "../useMessageFeedback";
|
||||
import { FeedbackModal } from "./FeedbackModal";
|
||||
import { TTSButton } from "./TTSButton";
|
||||
|
||||
interface Props {
|
||||
message: UIMessage<unknown, UIDataTypes, UITools>;
|
||||
sessionID: string | null;
|
||||
}
|
||||
|
||||
function extractTextFromParts(
|
||||
parts: UIMessage<unknown, UIDataTypes, UITools>["parts"],
|
||||
): string {
|
||||
return parts
|
||||
.filter((p) => p.type === "text")
|
||||
.map((p) => (p as { type: "text"; text: string }).text)
|
||||
.join("\n")
|
||||
.trim();
|
||||
}
|
||||
|
||||
export function AssistantMessageActions({ message, sessionID }: Props) {
|
||||
const {
|
||||
feedback,
|
||||
showFeedbackModal,
|
||||
handleCopy,
|
||||
handleUpvote,
|
||||
handleDownvoteClick,
|
||||
handleDownvoteSubmit,
|
||||
handleDownvoteCancel,
|
||||
} = useMessageFeedback({ sessionID, messageID: message.id });
|
||||
|
||||
const text = extractTextFromParts(message.parts);
|
||||
|
||||
return (
|
||||
<>
|
||||
<MessageActions className="mt-1">
|
||||
<MessageAction
|
||||
tooltip="Copy"
|
||||
onClick={() => handleCopy(text)}
|
||||
variant="ghost"
|
||||
size="icon-sm"
|
||||
>
|
||||
<CopySimple size={16} weight="regular" />
|
||||
</MessageAction>
|
||||
|
||||
<MessageAction
|
||||
tooltip="Good response"
|
||||
onClick={handleUpvote}
|
||||
variant="ghost"
|
||||
size="icon-sm"
|
||||
disabled={feedback === "downvote"}
|
||||
className={cn(
|
||||
feedback === "upvote" && "text-green-300 hover:text-green-300",
|
||||
feedback === "downvote" && "!opacity-20",
|
||||
)}
|
||||
>
|
||||
<ThumbsUp
|
||||
size={16}
|
||||
weight={feedback === "upvote" ? "fill" : "regular"}
|
||||
/>
|
||||
</MessageAction>
|
||||
|
||||
<MessageAction
|
||||
tooltip="Bad response"
|
||||
onClick={handleDownvoteClick}
|
||||
variant="ghost"
|
||||
size="icon-sm"
|
||||
disabled={feedback === "upvote"}
|
||||
className={cn(
|
||||
feedback === "downvote" && "text-red-300 hover:text-red-300",
|
||||
feedback === "upvote" && "!opacity-20",
|
||||
)}
|
||||
>
|
||||
<ThumbsDown
|
||||
size={16}
|
||||
weight={feedback === "downvote" ? "fill" : "regular"}
|
||||
/>
|
||||
</MessageAction>
|
||||
|
||||
<TTSButton text={text} />
|
||||
</MessageActions>
|
||||
|
||||
{showFeedbackModal && (
|
||||
<FeedbackModal
|
||||
isOpen={showFeedbackModal}
|
||||
onSubmit={handleDownvoteSubmit}
|
||||
onCancel={handleDownvoteCancel}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
"use client";
|
||||
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { useState } from "react";
|
||||
|
||||
interface Props {
|
||||
isOpen: boolean;
|
||||
onSubmit: (comment: string) => void;
|
||||
onCancel: () => void;
|
||||
}
|
||||
|
||||
export function FeedbackModal({ isOpen, onSubmit, onCancel }: Props) {
|
||||
const [comment, setComment] = useState("");
|
||||
|
||||
function handleSubmit() {
|
||||
onSubmit(comment);
|
||||
setComment("");
|
||||
}
|
||||
|
||||
function handleClose() {
|
||||
onCancel();
|
||||
setComment("");
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="What could have been better?"
|
||||
controlled={{
|
||||
isOpen,
|
||||
set: (open) => {
|
||||
if (!open) handleClose();
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="mx-auto w-[95%] space-y-4">
|
||||
<p className="text-sm text-slate-600">
|
||||
Your feedback helps us improve. Share details below.
|
||||
</p>
|
||||
<Textarea
|
||||
placeholder="Tell us what went wrong or could be improved..."
|
||||
value={comment}
|
||||
onChange={(e) => setComment(e.target.value)}
|
||||
rows={4}
|
||||
maxLength={2000}
|
||||
className="resize-none"
|
||||
/>
|
||||
<div className="flex items-center justify-between">
|
||||
<p className="text-xs text-slate-400">{comment.length}/2000</p>
|
||||
<div className="flex gap-2">
|
||||
<Button variant="outline" size="sm" onClick={handleClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button size="sm" onClick={handleSubmit}>
|
||||
Submit feedback
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -3,6 +3,11 @@ import {
|
||||
DownloadSimple as DownloadIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import type { FileUIPart } from "ai";
|
||||
import {
|
||||
globalRegistry,
|
||||
OutputItem,
|
||||
} from "@/components/contextual/OutputRenderers";
|
||||
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
|
||||
import {
|
||||
ContentCard,
|
||||
ContentCardHeader,
|
||||
@@ -15,13 +20,60 @@ interface Props {
|
||||
isUser?: boolean;
|
||||
}
|
||||
|
||||
function renderFileContent(file: FileUIPart): React.ReactNode | null {
|
||||
if (!file.url) return null;
|
||||
const metadata: OutputMetadata = {
|
||||
mimeType: file.mediaType,
|
||||
filename: file.filename,
|
||||
type: file.mediaType?.startsWith("image/")
|
||||
? "image"
|
||||
: file.mediaType?.startsWith("video/")
|
||||
? "video"
|
||||
: undefined,
|
||||
};
|
||||
const renderer = globalRegistry.getRenderer(file.url, metadata);
|
||||
if (!renderer) return null;
|
||||
return (
|
||||
<OutputItem value={file.url} metadata={metadata} renderer={renderer} />
|
||||
);
|
||||
}
|
||||
|
||||
export function MessageAttachments({ files, isUser }: Props) {
|
||||
if (files.length === 0) return null;
|
||||
|
||||
return (
|
||||
<div className="mt-2 flex flex-col gap-2">
|
||||
{files.map((file, i) =>
|
||||
isUser ? (
|
||||
{files.map((file, i) => {
|
||||
const rendered = renderFileContent(file);
|
||||
return rendered ? (
|
||||
<div
|
||||
key={`${file.filename}-${i}`}
|
||||
className={`inline-block rounded-lg border p-1.5 ${
|
||||
isUser
|
||||
? "border-purple-300 bg-purple-50"
|
||||
: "border-neutral-200 bg-neutral-50"
|
||||
}`}
|
||||
>
|
||||
{rendered}
|
||||
<div
|
||||
className={`mt-1 flex items-center gap-1 px-0.5 text-xs ${
|
||||
isUser ? "text-zinc-600" : "text-neutral-500"
|
||||
}`}
|
||||
>
|
||||
<span className="truncate">{file.filename || "file"}</span>
|
||||
{file.url && (
|
||||
<a
|
||||
href={file.url}
|
||||
download
|
||||
aria-label="Download file"
|
||||
className="ml-auto shrink-0 opacity-50 hover:opacity-100"
|
||||
>
|
||||
<DownloadIcon className="h-3.5 w-3.5" />
|
||||
</a>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
) : isUser ? (
|
||||
<div
|
||||
key={`${file.filename}-${i}`}
|
||||
className="min-w-0 rounded-lg border border-purple-300 bg-purple-100 p-3"
|
||||
@@ -77,8 +129,8 @@ export function MessageAttachments({ files, isUser }: Props) {
|
||||
</div>
|
||||
</ContentCardHeader>
|
||||
</ContentCard>
|
||||
),
|
||||
)}
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { getWebSocketToken } from "@/lib/supabase/actions";
|
||||
import { environment } from "@/services/environment";
|
||||
import { useState } from "react";
|
||||
|
||||
interface Args {
|
||||
sessionID: string | null;
|
||||
messageID: string;
|
||||
}
|
||||
|
||||
async function submitFeedbackToBackend(args: {
|
||||
sessionID: string;
|
||||
messageID: string;
|
||||
scoreName: string;
|
||||
scoreValue: number;
|
||||
comment?: string;
|
||||
}) {
|
||||
try {
|
||||
const { token } = await getWebSocketToken();
|
||||
if (!token) return;
|
||||
|
||||
await fetch(
|
||||
`${environment.getAGPTServerBaseUrl()}/api/chat/sessions/${args.sessionID}/feedback`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
message_id: args.messageID,
|
||||
score_name: args.scoreName,
|
||||
score_value: args.scoreValue,
|
||||
comment: args.comment,
|
||||
}),
|
||||
},
|
||||
);
|
||||
} catch {
|
||||
// Feedback submission is best-effort; silently ignore failures
|
||||
}
|
||||
}
|
||||
|
||||
export function useMessageFeedback({ sessionID, messageID }: Args) {
|
||||
const [feedback, setFeedback] = useState<"upvote" | "downvote" | null>(null);
|
||||
const [showFeedbackModal, setShowFeedbackModal] = useState(false);
|
||||
|
||||
async function handleCopy(text: string) {
|
||||
try {
|
||||
await navigator.clipboard.writeText(text);
|
||||
toast({ title: "Copied!", variant: "success", duration: 2000 });
|
||||
} catch {
|
||||
toast({
|
||||
title: "Failed to copy",
|
||||
variant: "destructive",
|
||||
duration: 2000,
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (sessionID) {
|
||||
submitFeedbackToBackend({
|
||||
sessionID,
|
||||
messageID,
|
||||
scoreName: "copy",
|
||||
scoreValue: 1,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function handleUpvote() {
|
||||
if (feedback) return;
|
||||
setFeedback("upvote");
|
||||
toast({
|
||||
title: "Thank you for your feedback!",
|
||||
variant: "success",
|
||||
duration: 3000,
|
||||
});
|
||||
if (sessionID) {
|
||||
submitFeedbackToBackend({
|
||||
sessionID,
|
||||
messageID,
|
||||
scoreName: "user-feedback",
|
||||
scoreValue: 1,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function handleDownvoteClick() {
|
||||
if (feedback) return;
|
||||
setFeedback("downvote");
|
||||
setShowFeedbackModal(true);
|
||||
}
|
||||
|
||||
function handleDownvoteSubmit(comment: string) {
|
||||
setShowFeedbackModal(false);
|
||||
if (sessionID) {
|
||||
submitFeedbackToBackend({
|
||||
sessionID,
|
||||
messageID,
|
||||
scoreName: "user-feedback",
|
||||
scoreValue: 0,
|
||||
comment: comment || undefined,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function handleDownvoteCancel() {
|
||||
setShowFeedbackModal(false);
|
||||
setFeedback(null);
|
||||
}
|
||||
|
||||
return {
|
||||
feedback,
|
||||
showFeedbackModal,
|
||||
handleCopy,
|
||||
handleUpvote,
|
||||
handleDownvoteClick,
|
||||
handleDownvoteSubmit,
|
||||
handleDownvoteCancel,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
import type { UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { getWorkDoneCounters } from "./useWorkDoneCounters";
|
||||
|
||||
interface Props {
|
||||
turnMessages: UIMessage<unknown, UIDataTypes, UITools>[];
|
||||
}
|
||||
|
||||
export function TurnStatsBar({ turnMessages }: Props) {
|
||||
const { counters } = getWorkDoneCounters(turnMessages);
|
||||
|
||||
if (counters.length === 0) return null;
|
||||
|
||||
return (
|
||||
<div className="mt-2 flex items-center gap-1.5">
|
||||
{counters.map(function renderCounter(counter, index) {
|
||||
return (
|
||||
<span key={counter.category} className="flex items-center gap-1">
|
||||
{index > 0 && (
|
||||
<span className="text-xs text-neutral-300">·</span>
|
||||
)}
|
||||
<span className="text-[11px] tabular-nums text-neutral-500">
|
||||
{counter.count} {counter.label}
|
||||
</span>
|
||||
</span>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
export const TOOL_PART_PREFIX = "tool-";
|
||||
@@ -0,0 +1,75 @@
|
||||
import type { UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { TOOL_PART_PREFIX } from "./constants";
|
||||
|
||||
const TOOL_TO_CATEGORY: Record<string, string> = {
|
||||
find_agent: "search",
|
||||
find_library_agent: "search",
|
||||
run_agent: "agent run",
|
||||
run_block: "block run",
|
||||
create_agent: "agent created",
|
||||
edit_agent: "agent edited",
|
||||
schedule_agent: "agent scheduled",
|
||||
};
|
||||
|
||||
const MAX_COUNTERS = 3;
|
||||
|
||||
function pluralize(label: string, count: number): string {
|
||||
if (count === 1) return label;
|
||||
|
||||
// "agent created" -> "agents created", "agent edited" -> "agents edited"
|
||||
const nounVerbMatch = label.match(
|
||||
/^(\w+)\s+(created|edited|scheduled|run)$/i,
|
||||
);
|
||||
if (nounVerbMatch) {
|
||||
return pluralizeWord(nounVerbMatch[1]) + " " + nounVerbMatch[2];
|
||||
}
|
||||
|
||||
return pluralizeWord(label);
|
||||
}
|
||||
|
||||
function pluralizeWord(word: string): string {
|
||||
if (word.endsWith("ch") || word.endsWith("sh") || word.endsWith("x"))
|
||||
return word + "es";
|
||||
return word + "s";
|
||||
}
|
||||
|
||||
export interface WorkDoneCounter {
|
||||
label: string;
|
||||
count: number;
|
||||
category: string;
|
||||
}
|
||||
|
||||
export function getWorkDoneCounters(
|
||||
messages: UIMessage<unknown, UIDataTypes, UITools>[],
|
||||
) {
|
||||
const categoryCounts = new Map<string, number>();
|
||||
|
||||
for (const message of messages) {
|
||||
if (message.role !== "assistant") continue;
|
||||
|
||||
for (const part of message.parts) {
|
||||
if (!part.type.startsWith(TOOL_PART_PREFIX)) continue;
|
||||
|
||||
const toolName = part.type.replace(TOOL_PART_PREFIX, "");
|
||||
const category = TOOL_TO_CATEGORY[toolName];
|
||||
if (!category) continue;
|
||||
|
||||
categoryCounts.set(category, (categoryCounts.get(category) ?? 0) + 1);
|
||||
}
|
||||
}
|
||||
|
||||
const counters: WorkDoneCounter[] = Array.from(categoryCounts.entries())
|
||||
.map(function toCounter([category, count]) {
|
||||
return {
|
||||
label: pluralize(category, count),
|
||||
count,
|
||||
category,
|
||||
};
|
||||
})
|
||||
.sort(function byCountDesc(a, b) {
|
||||
return b.count - a.count;
|
||||
})
|
||||
.slice(0, MAX_COUNTERS);
|
||||
|
||||
return { counters };
|
||||
}
|
||||
@@ -26,11 +26,28 @@ import {
|
||||
ContentMessage,
|
||||
} from "../../components/ToolAccordion/AccordionContent";
|
||||
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
|
||||
import {
|
||||
globalRegistry,
|
||||
OutputItem,
|
||||
} from "@/components/contextual/OutputRenderers";
|
||||
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
|
||||
|
||||
interface Props {
|
||||
part: ToolUIPart;
|
||||
}
|
||||
|
||||
function RenderMedia({
|
||||
value,
|
||||
metadata,
|
||||
}: {
|
||||
value: string;
|
||||
metadata: OutputMetadata;
|
||||
}) {
|
||||
const renderer = globalRegistry.getRenderer(value, metadata);
|
||||
if (!renderer) return null;
|
||||
return <OutputItem value={value} metadata={metadata} renderer={renderer} />;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Tool name helpers */
|
||||
/* ------------------------------------------------------------------ */
|
||||
@@ -612,14 +629,21 @@ function getFileAccordionData(
|
||||
|
||||
// Handle base64 content from workspace files
|
||||
let displayContent = content;
|
||||
const mimeType = getStringField(output, "mime_type");
|
||||
const isImage = mimeType?.startsWith("image/");
|
||||
if (output.content_base64 && typeof output.content_base64 === "string") {
|
||||
try {
|
||||
const bytes = Uint8Array.from(atob(output.content_base64), (c) =>
|
||||
c.charCodeAt(0),
|
||||
);
|
||||
displayContent = new TextDecoder().decode(bytes);
|
||||
} catch {
|
||||
displayContent = "[Binary content]";
|
||||
if (isImage) {
|
||||
// Render image inline — handled below in the JSX
|
||||
displayContent = null;
|
||||
} else {
|
||||
try {
|
||||
const bytes = Uint8Array.from(atob(output.content_base64), (c) =>
|
||||
c.charCodeAt(0),
|
||||
);
|
||||
displayContent = new TextDecoder().decode(bytes);
|
||||
} catch {
|
||||
displayContent = "[Binary content]";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -697,6 +721,17 @@ function getFileAccordionData(
|
||||
</>
|
||||
) : writtenContent ? (
|
||||
<ContentCodeBlock>{writtenContent}</ContentCodeBlock>
|
||||
) : isImage &&
|
||||
output.content_base64 &&
|
||||
typeof output.content_base64 === "string" ? (
|
||||
<RenderMedia
|
||||
value={`data:${mimeType};base64,${output.content_base64}`}
|
||||
metadata={{
|
||||
type: "image",
|
||||
mimeType: mimeType ?? undefined,
|
||||
filename: filePath ?? undefined,
|
||||
}}
|
||||
/>
|
||||
) : displayContent ? (
|
||||
<ContentCodeBlock>{displayContent}</ContentCodeBlock>
|
||||
) : null}
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
OutputItem,
|
||||
} from "@/components/contextual/OutputRenderers";
|
||||
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
|
||||
import { isWorkspaceURI, parseWorkspaceURI } from "@/lib/workspace-uri";
|
||||
import {
|
||||
ContentBadge,
|
||||
ContentCard,
|
||||
@@ -23,30 +24,23 @@ interface Props {
|
||||
|
||||
const COLLAPSED_LIMIT = 3;
|
||||
|
||||
function isWorkspaceRef(value: unknown): value is string {
|
||||
return typeof value === "string" && value.startsWith("workspace://");
|
||||
}
|
||||
|
||||
function resolveForRenderer(value: unknown): {
|
||||
value: unknown;
|
||||
metadata?: OutputMetadata;
|
||||
} {
|
||||
if (!isWorkspaceRef(value)) return { value };
|
||||
if (!isWorkspaceURI(value)) return { value };
|
||||
|
||||
const withoutPrefix = value.replace("workspace://", "");
|
||||
const fileId = withoutPrefix.split("#")[0];
|
||||
const apiPath = getGetWorkspaceDownloadFileByIdUrl(fileId);
|
||||
const parsed = parseWorkspaceURI(value);
|
||||
if (!parsed) return { value };
|
||||
|
||||
const apiPath = getGetWorkspaceDownloadFileByIdUrl(parsed.fileID);
|
||||
const url = `/api/proxy${apiPath}`;
|
||||
|
||||
const hashIndex = value.indexOf("#");
|
||||
const mimeHint =
|
||||
hashIndex !== -1 ? value.slice(hashIndex + 1) || undefined : undefined;
|
||||
|
||||
const metadata: OutputMetadata = {};
|
||||
if (mimeHint) {
|
||||
metadata.mimeType = mimeHint;
|
||||
if (mimeHint.startsWith("image/")) metadata.type = "image";
|
||||
else if (mimeHint.startsWith("video/")) metadata.type = "video";
|
||||
if (parsed.mimeType) {
|
||||
metadata.mimeType = parsed.mimeType;
|
||||
if (parsed.mimeType.startsWith("image/")) metadata.type = "image";
|
||||
else if (parsed.mimeType.startsWith("video/")) metadata.type = "video";
|
||||
}
|
||||
|
||||
return { value: url, metadata };
|
||||
@@ -71,7 +65,7 @@ function RenderOutputValue({ value }: { value: unknown }) {
|
||||
|
||||
// Fallback for audio workspace refs
|
||||
if (
|
||||
isWorkspaceRef(value) &&
|
||||
isWorkspaceURI(value) &&
|
||||
resolved.metadata?.mimeType?.startsWith("audio/")
|
||||
) {
|
||||
return (
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
OutputItem,
|
||||
} from "@/components/contextual/OutputRenderers";
|
||||
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
|
||||
import { isWorkspaceURI, parseWorkspaceURI } from "@/lib/workspace-uri";
|
||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
||||
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
|
||||
import {
|
||||
@@ -46,30 +47,23 @@ interface Props {
|
||||
part: ViewAgentOutputToolPart;
|
||||
}
|
||||
|
||||
function isWorkspaceRef(value: unknown): value is string {
|
||||
return typeof value === "string" && value.startsWith("workspace://");
|
||||
}
|
||||
|
||||
function resolveForRenderer(value: unknown): {
|
||||
value: unknown;
|
||||
metadata?: OutputMetadata;
|
||||
} {
|
||||
if (!isWorkspaceRef(value)) return { value };
|
||||
if (!isWorkspaceURI(value)) return { value };
|
||||
|
||||
const withoutPrefix = value.replace("workspace://", "");
|
||||
const fileId = withoutPrefix.split("#")[0];
|
||||
const apiPath = getGetWorkspaceDownloadFileByIdUrl(fileId);
|
||||
const parsed = parseWorkspaceURI(value);
|
||||
if (!parsed) return { value };
|
||||
|
||||
const apiPath = getGetWorkspaceDownloadFileByIdUrl(parsed.fileID);
|
||||
const url = `/api/proxy${apiPath}`;
|
||||
|
||||
const hashIndex = value.indexOf("#");
|
||||
const mimeHint =
|
||||
hashIndex !== -1 ? value.slice(hashIndex + 1) || undefined : undefined;
|
||||
|
||||
const metadata: OutputMetadata = {};
|
||||
if (mimeHint) {
|
||||
metadata.mimeType = mimeHint;
|
||||
if (mimeHint.startsWith("image/")) metadata.type = "image";
|
||||
else if (mimeHint.startsWith("video/")) metadata.type = "video";
|
||||
if (parsed.mimeType) {
|
||||
metadata.mimeType = parsed.mimeType;
|
||||
if (parsed.mimeType.startsWith("image/")) metadata.type = "image";
|
||||
else if (parsed.mimeType.startsWith("video/")) metadata.type = "video";
|
||||
}
|
||||
|
||||
return { value: url, metadata };
|
||||
@@ -94,7 +88,7 @@ function RenderOutputValue({ value }: { value: unknown }) {
|
||||
|
||||
// Fallback for audio workspace refs
|
||||
if (
|
||||
isWorkspaceRef(value) &&
|
||||
isWorkspaceURI(value) &&
|
||||
resolved.metadata?.mimeType?.startsWith("audio/")
|
||||
) {
|
||||
return (
|
||||
|
||||
@@ -6596,6 +6596,44 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/workspace/files/{file_id}": {
|
||||
"delete": {
|
||||
"tags": ["workspace"],
|
||||
"summary": "Delete a workspace file",
|
||||
"description": "Soft-delete a workspace file and attempt to remove it from storage.\n\nUsed when a user clears a file input in the builder.",
|
||||
"operationId": "deleteWorkspaceDelete a workspace file",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "file_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "File Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/DeleteFileResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/workspace/files/{file_id}/download": {
|
||||
"get": {
|
||||
"tags": ["workspace"],
|
||||
@@ -8248,6 +8286,12 @@
|
||||
"enum": ["TOP_UP", "USAGE", "GRANT", "REFUND", "CARD_CHECK"],
|
||||
"title": "CreditTransactionType"
|
||||
},
|
||||
"DeleteFileResponse": {
|
||||
"properties": { "deleted": { "type": "boolean", "title": "Deleted" } },
|
||||
"type": "object",
|
||||
"required": ["deleted"],
|
||||
"title": "DeleteFileResponse"
|
||||
},
|
||||
"DeleteGraphResponse": {
|
||||
"properties": {
|
||||
"version_counts": { "type": "integer", "title": "Version Counts" }
|
||||
|
||||
@@ -71,9 +71,13 @@ async function handleWorkspaceDownload(
|
||||
responseHeaders["Content-Disposition"] = contentDisposition;
|
||||
}
|
||||
|
||||
// Return the binary content
|
||||
const arrayBuffer = await response.arrayBuffer();
|
||||
return new NextResponse(arrayBuffer, {
|
||||
const contentLength = response.headers.get("Content-Length");
|
||||
if (contentLength) {
|
||||
responseHeaders["Content-Length"] = contentLength;
|
||||
}
|
||||
|
||||
// Stream the response body directly instead of buffering in memory
|
||||
return new NextResponse(response.body, {
|
||||
status: 200,
|
||||
headers: responseHeaders,
|
||||
});
|
||||
@@ -255,7 +259,6 @@ async function handler(
|
||||
responseBody = await handleJsonRequest(req, method, backendUrl);
|
||||
} else if (contentType?.includes("multipart/form-data")) {
|
||||
responseBody = await handleFormDataRequest(req, backendUrl);
|
||||
responseHeaders["Content-Type"] = "text/plain";
|
||||
} else if (contentType?.includes("application/x-www-form-urlencoded")) {
|
||||
responseBody = await handleUrlEncodedRequest(req, method, backendUrl);
|
||||
} else {
|
||||
|
||||
@@ -5,9 +5,8 @@ import { ButtonGroup, ButtonGroupText } from "@/components/ui/button-group";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { cjk } from "@streamdown/cjk";
|
||||
import { code } from "@/lib/streamdown-code-plugin";
|
||||
@@ -89,14 +88,10 @@ export const MessageAction = ({
|
||||
|
||||
if (tooltip) {
|
||||
return (
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>{button}</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
<p>{tooltip}</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>{button}</TooltipTrigger>
|
||||
<TooltipContent>{tooltip}</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import { FileTextIcon, TrashIcon, UploadIcon } from "@phosphor-icons/react";
|
||||
import { Cross2Icon } from "@radix-ui/react-icons";
|
||||
import { FileTextIcon, TrashIcon, UploadIcon, X } from "@phosphor-icons/react";
|
||||
import { useRef, useState } from "react";
|
||||
import { Button } from "../Button/Button";
|
||||
import { formatFileSize, getFileLabel } from "./helpers";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Progress } from "../Progress/Progress";
|
||||
import { parseWorkspaceURI } from "@/lib/workspace-uri";
|
||||
import { Text } from "../Text/Text";
|
||||
|
||||
type UploadFileResult = {
|
||||
@@ -20,6 +19,7 @@ interface BaseProps {
|
||||
value?: string;
|
||||
placeholder?: string;
|
||||
onChange: (value: string) => void;
|
||||
onDeleteFile?: (fileURI: string) => void;
|
||||
className?: string;
|
||||
maxFileSize?: number;
|
||||
accept?: string | string[];
|
||||
@@ -30,7 +30,7 @@ interface BaseProps {
|
||||
interface UploadModeProps extends BaseProps {
|
||||
mode?: "upload";
|
||||
onUploadFile: (file: File) => Promise<UploadFileResult>;
|
||||
uploadProgress: number;
|
||||
uploadProgress?: number;
|
||||
}
|
||||
|
||||
interface Base64ModeProps extends BaseProps {
|
||||
@@ -45,6 +45,7 @@ export function FileInput(props: Props) {
|
||||
const {
|
||||
value,
|
||||
onChange,
|
||||
onDeleteFile,
|
||||
className,
|
||||
maxFileSize,
|
||||
accept,
|
||||
@@ -56,8 +57,6 @@ export function FileInput(props: Props) {
|
||||
|
||||
const onUploadFile =
|
||||
mode === "upload" ? (props as UploadModeProps).onUploadFile : undefined;
|
||||
const uploadProgress =
|
||||
mode === "upload" ? (props as UploadModeProps).uploadProgress : 0;
|
||||
|
||||
const [isUploading, setIsUploading] = useState(false);
|
||||
const [uploadError, setUploadError] = useState<string | null>(null);
|
||||
@@ -69,8 +68,7 @@ export function FileInput(props: Props) {
|
||||
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const storageNote =
|
||||
"Files are stored securely and will be automatically deleted at most 24 hours after upload.";
|
||||
const storageNote = "Files are stored securely in your workspace.";
|
||||
|
||||
function acceptToString(a?: string | string[]) {
|
||||
if (!a) return "*/*";
|
||||
@@ -104,7 +102,7 @@ export function FileInput(props: Props) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const getFileLabelFromValue = (val: unknown): string => {
|
||||
function getFileLabelFromValue(val: unknown): string {
|
||||
// Handle object format from external API: { name, type, size, data }
|
||||
if (val && typeof val === "object") {
|
||||
const obj = val as Record<string, unknown>;
|
||||
@@ -124,11 +122,23 @@ export function FileInput(props: Props) {
|
||||
return "File";
|
||||
}
|
||||
|
||||
// Handle string values (data URIs or file paths)
|
||||
// Handle string values (workspace URIs, data URIs, or file paths)
|
||||
if (typeof val !== "string") {
|
||||
return "File";
|
||||
}
|
||||
|
||||
const wsURI = parseWorkspaceURI(val);
|
||||
if (wsURI) {
|
||||
if (wsURI.mimeType) {
|
||||
const parts = wsURI.mimeType.split("/");
|
||||
if (parts.length > 1) {
|
||||
return `${parts[1].toUpperCase()} file`;
|
||||
}
|
||||
return "File";
|
||||
}
|
||||
return "Uploaded file";
|
||||
}
|
||||
|
||||
if (val.startsWith("data:")) {
|
||||
const matches = val.match(/^data:([^;]+);/);
|
||||
if (matches?.[1]) {
|
||||
@@ -146,9 +156,9 @@ export function FileInput(props: Props) {
|
||||
}
|
||||
}
|
||||
return "File";
|
||||
};
|
||||
}
|
||||
|
||||
const processFileBase64 = (file: File) => {
|
||||
function processFileBase64(file: File) {
|
||||
setIsUploading(true);
|
||||
setUploadError(null);
|
||||
|
||||
@@ -168,9 +178,9 @@ export function FileInput(props: Props) {
|
||||
setIsUploading(false);
|
||||
};
|
||||
reader.readAsDataURL(file);
|
||||
};
|
||||
}
|
||||
|
||||
const uploadFile = async (file: File) => {
|
||||
async function uploadFile(file: File) {
|
||||
if (mode === "base64") {
|
||||
processFileBase64(file);
|
||||
return;
|
||||
@@ -184,6 +194,8 @@ export function FileInput(props: Props) {
|
||||
setIsUploading(true);
|
||||
setUploadError(null);
|
||||
|
||||
const oldURI = value;
|
||||
|
||||
try {
|
||||
const result = await onUploadFile(file);
|
||||
|
||||
@@ -194,15 +206,20 @@ export function FileInput(props: Props) {
|
||||
});
|
||||
|
||||
onChange(result.file_uri);
|
||||
|
||||
// Delete the old file only after the new upload succeeds
|
||||
if (oldURI && onDeleteFile) {
|
||||
onDeleteFile(oldURI);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Upload failed:", error);
|
||||
setUploadError(error instanceof Error ? error.message : "Upload failed");
|
||||
} finally {
|
||||
setIsUploading(false);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
const handleFileChange = (event: React.ChangeEvent<HTMLInputElement>) => {
|
||||
function handleFileChange(event: React.ChangeEvent<HTMLInputElement>) {
|
||||
const file = event.target.files?.[0];
|
||||
if (!file) return;
|
||||
// Validate max size
|
||||
@@ -218,21 +235,24 @@ export function FileInput(props: Props) {
|
||||
return;
|
||||
}
|
||||
uploadFile(file);
|
||||
};
|
||||
}
|
||||
|
||||
const handleFileDrop = (event: React.DragEvent<HTMLDivElement>) => {
|
||||
function handleFileDrop(event: React.DragEvent<HTMLDivElement>) {
|
||||
event.preventDefault();
|
||||
const file = event.dataTransfer.files[0];
|
||||
if (file) uploadFile(file);
|
||||
};
|
||||
}
|
||||
|
||||
const handleClear = () => {
|
||||
function handleClear() {
|
||||
if (value && onDeleteFile) {
|
||||
onDeleteFile(value);
|
||||
}
|
||||
if (inputRef.current) {
|
||||
inputRef.current.value = "";
|
||||
}
|
||||
onChange("");
|
||||
setFileInfo(null);
|
||||
};
|
||||
}
|
||||
|
||||
const displayName = placeholder || "File";
|
||||
|
||||
@@ -241,27 +261,14 @@ export function FileInput(props: Props) {
|
||||
<div className={cn("flex flex-col gap-1.5", className)}>
|
||||
<div className="nodrag flex flex-col gap-1.5">
|
||||
{isUploading ? (
|
||||
<div className="flex flex-col gap-1.5 rounded-md border border-blue-200 bg-blue-50 p-2 dark:border-blue-800 dark:bg-blue-950">
|
||||
<div className="flex items-center gap-2">
|
||||
<UploadIcon className="h-4 w-4 animate-pulse text-blue-600 dark:text-blue-400" />
|
||||
<Text
|
||||
variant="small"
|
||||
className="text-blue-700 dark:text-blue-300"
|
||||
>
|
||||
{mode === "base64" ? "Processing..." : "Uploading..."}
|
||||
</Text>
|
||||
{mode === "upload" && (
|
||||
<Text
|
||||
variant="small-medium"
|
||||
className="ml-auto text-blue-600 dark:text-blue-400"
|
||||
>
|
||||
{Math.round(uploadProgress)}%
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
{mode === "upload" && (
|
||||
<Progress value={uploadProgress} className="h-1 w-full" />
|
||||
)}
|
||||
<div className="flex items-center gap-2 rounded-md border border-blue-200 bg-blue-50 p-2 dark:border-blue-800 dark:bg-blue-950">
|
||||
<div className="h-4 w-4 animate-spin rounded-full border-2 border-blue-300 border-t-blue-600 dark:border-blue-700 dark:border-t-blue-400" />
|
||||
<Text
|
||||
variant="small"
|
||||
className="text-blue-700 dark:text-blue-300"
|
||||
>
|
||||
{mode === "base64" ? "Processing..." : "Uploading..."}
|
||||
</Text>
|
||||
</div>
|
||||
) : value ? (
|
||||
<div className="flex items-center gap-2">
|
||||
@@ -292,7 +299,7 @@ export function FileInput(props: Props) {
|
||||
onClick={handleClear}
|
||||
type="button"
|
||||
>
|
||||
<Cross2Icon className="h-3.5 w-3.5" />
|
||||
<X size={14} />
|
||||
</Button>
|
||||
</div>
|
||||
) : (
|
||||
@@ -333,26 +340,13 @@ export function FileInput(props: Props) {
|
||||
{isUploading ? (
|
||||
<div className="space-y-2">
|
||||
<div className="flex min-h-14 items-center gap-4">
|
||||
<div className="agpt-border-input flex min-h-14 w-full flex-col justify-center rounded-xl bg-zinc-50 p-4 text-sm">
|
||||
<div className="mb-2 flex items-center gap-2">
|
||||
<UploadIcon className="h-5 w-5 text-blue-600" />
|
||||
<span className="text-gray-700">
|
||||
{mode === "base64" ? "Processing..." : "Uploading..."}
|
||||
</span>
|
||||
{mode === "upload" && (
|
||||
<span className="text-gray-500">
|
||||
{Math.round(uploadProgress)}%
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
{mode === "upload" && (
|
||||
<Progress value={uploadProgress} className="w-full" />
|
||||
)}
|
||||
<div className="agpt-border-input flex min-h-14 w-full items-center gap-3 rounded-xl bg-zinc-50 p-4 text-sm">
|
||||
<div className="h-5 w-5 animate-spin rounded-full border-2 border-blue-300 border-t-blue-600" />
|
||||
<span className="text-gray-700">
|
||||
{mode === "base64" ? "Processing..." : "Uploading..."}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
{showStorageNote && mode === "upload" && (
|
||||
<p className="text-xs text-gray-500">{storageNote}</p>
|
||||
)}
|
||||
</div>
|
||||
) : value ? (
|
||||
<div className="space-y-2">
|
||||
|
||||
@@ -5,8 +5,10 @@ import { imageRenderer } from "./renderers/ImageRenderer";
|
||||
import { videoRenderer } from "./renderers/VideoRenderer";
|
||||
import { jsonRenderer } from "./renderers/JSONRenderer";
|
||||
import { markdownRenderer } from "./renderers/MarkdownRenderer";
|
||||
import { workspaceFileRenderer } from "./renderers/WorkspaceFileRenderer";
|
||||
|
||||
// Register all renderers in priority order
|
||||
globalRegistry.register(workspaceFileRenderer);
|
||||
globalRegistry.register(videoRenderer);
|
||||
globalRegistry.register(imageRenderer);
|
||||
globalRegistry.register(codeRenderer);
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
parseWorkspaceURI,
|
||||
parseWorkspaceFileID,
|
||||
isWorkspaceURI,
|
||||
buildWorkspaceURI,
|
||||
} from "@/lib/workspace-uri";
|
||||
|
||||
describe("parseWorkspaceURI", () => {
|
||||
it("parses a full workspace URI with mime type", () => {
|
||||
const result = parseWorkspaceURI("workspace://file-abc-123#image/png");
|
||||
expect(result).toEqual({ fileID: "file-abc-123", mimeType: "image/png" });
|
||||
});
|
||||
|
||||
it("parses a workspace URI without mime type", () => {
|
||||
const result = parseWorkspaceURI("workspace://file-abc-123");
|
||||
expect(result).toEqual({ fileID: "file-abc-123", mimeType: null });
|
||||
});
|
||||
|
||||
it("returns null for non-workspace URIs", () => {
|
||||
expect(parseWorkspaceURI("https://example.com")).toBeNull();
|
||||
expect(parseWorkspaceURI("data:image/png;base64,abc")).toBeNull();
|
||||
expect(parseWorkspaceURI("")).toBeNull();
|
||||
expect(parseWorkspaceURI("file:///tmp/test.txt")).toBeNull();
|
||||
});
|
||||
|
||||
it("handles empty fragment after hash as null mime type", () => {
|
||||
const result = parseWorkspaceURI("workspace://file-abc-123#");
|
||||
expect(result).toEqual({ fileID: "file-abc-123", mimeType: null });
|
||||
});
|
||||
|
||||
it("handles mime types with subtype", () => {
|
||||
const result = parseWorkspaceURI(
|
||||
"workspace://file-id#application/octet-stream",
|
||||
);
|
||||
expect(result).toEqual({
|
||||
fileID: "file-id",
|
||||
mimeType: "application/octet-stream",
|
||||
});
|
||||
});
|
||||
|
||||
it("handles UUID-style file IDs", () => {
|
||||
const uuid = "550e8400-e29b-41d4-a716-446655440000";
|
||||
const result = parseWorkspaceURI(`workspace://${uuid}#text/plain`);
|
||||
expect(result).toEqual({ fileID: uuid, mimeType: "text/plain" });
|
||||
});
|
||||
});
|
||||
|
||||
describe("parseWorkspaceFileID", () => {
|
||||
it("extracts file ID from a full workspace URI", () => {
|
||||
expect(parseWorkspaceFileID("workspace://file-abc-123#image/png")).toBe(
|
||||
"file-abc-123",
|
||||
);
|
||||
});
|
||||
|
||||
it("extracts file ID when no mime type fragment", () => {
|
||||
expect(parseWorkspaceFileID("workspace://file-abc-123")).toBe(
|
||||
"file-abc-123",
|
||||
);
|
||||
});
|
||||
|
||||
it("returns null for non-workspace URIs", () => {
|
||||
expect(parseWorkspaceFileID("https://example.com")).toBeNull();
|
||||
expect(parseWorkspaceFileID("data:image/png;base64,abc")).toBeNull();
|
||||
expect(parseWorkspaceFileID("")).toBeNull();
|
||||
});
|
||||
|
||||
it("is consistent with parseWorkspaceURI for file ID extraction", () => {
|
||||
const uris = [
|
||||
"workspace://abc#image/png",
|
||||
"workspace://abc",
|
||||
"workspace://abc#",
|
||||
"workspace://550e8400-e29b-41d4-a716-446655440000#text/plain",
|
||||
];
|
||||
|
||||
for (const uri of uris) {
|
||||
const fullParse = parseWorkspaceURI(uri);
|
||||
const idOnly = parseWorkspaceFileID(uri);
|
||||
expect(idOnly).toBe(fullParse?.fileID ?? null);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe("isWorkspaceURI", () => {
|
||||
it("returns true for workspace URIs", () => {
|
||||
expect(isWorkspaceURI("workspace://abc")).toBe(true);
|
||||
expect(isWorkspaceURI("workspace://abc#image/png")).toBe(true);
|
||||
});
|
||||
|
||||
it("returns false for non-workspace values", () => {
|
||||
expect(isWorkspaceURI("https://example.com")).toBe(false);
|
||||
expect(isWorkspaceURI("")).toBe(false);
|
||||
expect(isWorkspaceURI(null)).toBe(false);
|
||||
expect(isWorkspaceURI(undefined)).toBe(false);
|
||||
expect(isWorkspaceURI(123)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("buildWorkspaceURI", () => {
|
||||
it("builds URI with mime type", () => {
|
||||
expect(buildWorkspaceURI("file-123", "image/png")).toBe(
|
||||
"workspace://file-123#image/png",
|
||||
);
|
||||
});
|
||||
|
||||
it("builds URI without mime type", () => {
|
||||
expect(buildWorkspaceURI("file-123")).toBe("workspace://file-123");
|
||||
});
|
||||
|
||||
it("roundtrips with parseWorkspaceURI", () => {
|
||||
const uri = buildWorkspaceURI("file-abc", "text/plain");
|
||||
const parsed = parseWorkspaceURI(uri);
|
||||
expect(parsed).toEqual({ fileID: "file-abc", mimeType: "text/plain" });
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,226 @@
|
||||
import { DownloadSimple, FileText } from "@phosphor-icons/react";
|
||||
import { type ReactNode, useState } from "react";
|
||||
import {
|
||||
OutputRenderer,
|
||||
OutputMetadata,
|
||||
DownloadContent,
|
||||
CopyContent,
|
||||
} from "../types";
|
||||
import { parseWorkspaceURI, isWorkspaceURI } from "@/lib/workspace-uri";
|
||||
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
|
||||
|
||||
const imageMimeTypes = [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/bmp",
|
||||
"image/svg+xml",
|
||||
"image/webp",
|
||||
"image/x-icon",
|
||||
];
|
||||
|
||||
const videoMimeTypes = [
|
||||
"video/mp4",
|
||||
"video/webm",
|
||||
"video/ogg",
|
||||
"video/quicktime",
|
||||
"video/x-msvideo",
|
||||
"video/x-matroska",
|
||||
];
|
||||
|
||||
const audioMimeTypes = [
|
||||
"audio/mpeg",
|
||||
"audio/ogg",
|
||||
"audio/wav",
|
||||
"audio/webm",
|
||||
"audio/aac",
|
||||
"audio/flac",
|
||||
];
|
||||
|
||||
function buildDownloadURL(fileID: string): string {
|
||||
return `/api/proxy/api/workspace/files/${fileID}/download`;
|
||||
}
|
||||
|
||||
function canRenderWorkspaceFile(value: unknown): boolean {
|
||||
return isWorkspaceURI(value);
|
||||
}
|
||||
|
||||
function getFileTypeLabel(mimeType: string | null): string {
|
||||
if (!mimeType) return "File";
|
||||
const sub = mimeType.split("/")[1];
|
||||
if (!sub) return "File";
|
||||
return `${sub.toUpperCase()} file`;
|
||||
}
|
||||
|
||||
function WorkspaceImage({ src, alt }: { src: string; alt: string }) {
|
||||
const [loaded, setLoaded] = useState(false);
|
||||
|
||||
return (
|
||||
<div className="group relative">
|
||||
{!loaded && (
|
||||
<Skeleton className="absolute inset-0 h-full min-h-40 w-full rounded-md" />
|
||||
)}
|
||||
{/* eslint-disable-next-line @next/next/no-img-element */}
|
||||
<img
|
||||
src={src}
|
||||
alt={alt}
|
||||
className={`h-auto max-w-full rounded-md border border-gray-200 ${loaded ? "opacity-100" : "min-h-40 opacity-0"}`}
|
||||
loading="lazy"
|
||||
onLoad={() => setLoaded(true)}
|
||||
onError={() => setLoaded(true)}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function WorkspaceVideo({ src, mimeType }: { src: string; mimeType: string }) {
|
||||
const [loaded, setLoaded] = useState(false);
|
||||
|
||||
return (
|
||||
<div className="group relative">
|
||||
{!loaded && (
|
||||
<Skeleton className="absolute inset-0 h-full min-h-40 w-full rounded-md" />
|
||||
)}
|
||||
<video
|
||||
controls
|
||||
className={`h-auto max-w-full rounded-md border border-gray-200 ${loaded ? "opacity-100" : "min-h-40 opacity-0"}`}
|
||||
preload="metadata"
|
||||
onLoadedMetadata={() => setLoaded(true)}
|
||||
onError={() => setLoaded(true)}
|
||||
>
|
||||
<source src={src} type={mimeType} />
|
||||
Your browser does not support the video tag.
|
||||
</video>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function WorkspaceAudio({ src, mimeType }: { src: string; mimeType: string }) {
|
||||
const [loaded, setLoaded] = useState(false);
|
||||
|
||||
return (
|
||||
<div className="group relative">
|
||||
{!loaded && (
|
||||
<Skeleton className="absolute inset-0 h-full min-h-12 w-full rounded-md" />
|
||||
)}
|
||||
<audio
|
||||
controls
|
||||
preload="metadata"
|
||||
className={`w-full ${loaded ? "opacity-100" : "min-h-12 opacity-0"}`}
|
||||
onLoadedMetadata={() => setLoaded(true)}
|
||||
onError={() => setLoaded(true)}
|
||||
>
|
||||
<source src={src} type={mimeType} />
|
||||
Your browser does not support the audio tag.
|
||||
</audio>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function renderWorkspaceFile(
|
||||
value: unknown,
|
||||
metadata?: OutputMetadata,
|
||||
): ReactNode {
|
||||
const uri = parseWorkspaceURI(String(value));
|
||||
if (!uri) return null;
|
||||
|
||||
const downloadURL = buildDownloadURL(uri.fileID);
|
||||
const mimeType = uri.mimeType || metadata?.mimeType || null;
|
||||
|
||||
if (mimeType && imageMimeTypes.includes(mimeType)) {
|
||||
return (
|
||||
<WorkspaceImage src={downloadURL} alt={metadata?.filename || "Image"} />
|
||||
);
|
||||
}
|
||||
|
||||
if (mimeType && videoMimeTypes.includes(mimeType)) {
|
||||
return <WorkspaceVideo src={downloadURL} mimeType={mimeType} />;
|
||||
}
|
||||
|
||||
if (mimeType && audioMimeTypes.includes(mimeType)) {
|
||||
return <WorkspaceAudio src={downloadURL} mimeType={mimeType} />;
|
||||
}
|
||||
|
||||
// Generic file card with icon and download link
|
||||
const label = getFileTypeLabel(mimeType);
|
||||
return (
|
||||
<div className="flex items-center gap-3 rounded-lg border border-gray-200 bg-gray-50 p-3 dark:border-gray-700 dark:bg-gray-800">
|
||||
<FileText size={28} className="flex-shrink-0 text-gray-500" />
|
||||
<div className="flex min-w-0 flex-1 flex-col">
|
||||
<span className="truncate text-sm font-medium text-gray-900 dark:text-gray-100">
|
||||
{metadata?.filename || label}
|
||||
</span>
|
||||
{mimeType && (
|
||||
<span className="text-xs text-gray-500 dark:text-gray-400">
|
||||
{mimeType}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<a
|
||||
href={downloadURL}
|
||||
download
|
||||
className="flex-shrink-0 rounded-md p-1.5 text-gray-500 transition-colors hover:bg-gray-200 hover:text-gray-700 dark:hover:bg-gray-700 dark:hover:text-gray-300"
|
||||
>
|
||||
<DownloadSimple size={18} />
|
||||
</a>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function getCopyContentWorkspaceFile(
|
||||
value: unknown,
|
||||
metadata?: OutputMetadata,
|
||||
): CopyContent | null {
|
||||
const uri = parseWorkspaceURI(String(value));
|
||||
if (!uri) return null;
|
||||
|
||||
const downloadURL = buildDownloadURL(uri.fileID);
|
||||
const mimeType =
|
||||
uri.mimeType || metadata?.mimeType || "application/octet-stream";
|
||||
|
||||
return {
|
||||
mimeType,
|
||||
data: async () => {
|
||||
const response = await fetch(downloadURL);
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch file: ${response.status}`);
|
||||
}
|
||||
return await response.blob();
|
||||
},
|
||||
alternativeMimeTypes: ["text/plain"],
|
||||
fallbackText: String(value),
|
||||
};
|
||||
}
|
||||
|
||||
function getDownloadContentWorkspaceFile(
|
||||
value: unknown,
|
||||
metadata?: OutputMetadata,
|
||||
): DownloadContent | null {
|
||||
const uri = parseWorkspaceURI(String(value));
|
||||
if (!uri) return null;
|
||||
|
||||
const mimeType =
|
||||
uri.mimeType || metadata?.mimeType || "application/octet-stream";
|
||||
const ext = mimeType.split("/")[1] || "bin";
|
||||
const filename = metadata?.filename || `file.${ext}`;
|
||||
|
||||
return {
|
||||
data: buildDownloadURL(uri.fileID),
|
||||
filename,
|
||||
mimeType,
|
||||
};
|
||||
}
|
||||
|
||||
function isConcatenableWorkspaceFile(): boolean {
|
||||
return false;
|
||||
}
|
||||
|
||||
export const workspaceFileRenderer: OutputRenderer = {
|
||||
name: "WorkspaceFileRenderer",
|
||||
priority: 50, // Higher than video (45) and image (40) so it matches first
|
||||
canRender: canRenderWorkspaceFile,
|
||||
render: renderWorkspaceFile,
|
||||
getCopyContent: getCopyContentWorkspaceFile,
|
||||
getDownloadContent: getDownloadContentWorkspaceFile,
|
||||
isConcatenable: isConcatenableWorkspaceFile,
|
||||
};
|
||||
@@ -1,27 +1,30 @@
|
||||
import { WidgetProps } from "@rjsf/utils";
|
||||
import { FileInput } from "@/components/atoms/FileInput/FileInput";
|
||||
import { useWorkspaceUpload } from "./useWorkspaceUpload";
|
||||
|
||||
export const FileWidget = (props: WidgetProps) => {
|
||||
export function FileWidget(props: WidgetProps) {
|
||||
const { onChange, disabled, readonly, value, schema, formContext } = props;
|
||||
|
||||
const { size } = formContext || {};
|
||||
|
||||
const displayName = schema?.title || "File";
|
||||
const { handleUploadFile, handleDeleteFile } = useWorkspaceUpload();
|
||||
|
||||
const handleChange = (fileUri: string) => {
|
||||
onChange(fileUri);
|
||||
};
|
||||
function handleChange(fileURI: string) {
|
||||
onChange(fileURI);
|
||||
}
|
||||
|
||||
return (
|
||||
<FileInput
|
||||
variant={size === "large" ? "default" : "compact"}
|
||||
mode="base64"
|
||||
mode="upload"
|
||||
value={value}
|
||||
placeholder={displayName}
|
||||
onChange={handleChange}
|
||||
onDeleteFile={handleDeleteFile}
|
||||
onUploadFile={handleUploadFile}
|
||||
showStorageNote={false}
|
||||
className={
|
||||
disabled || readonly ? "pointer-events-none opacity-50" : undefined
|
||||
}
|
||||
/>
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
import {
|
||||
usePostWorkspaceUploadFileToWorkspace,
|
||||
useDeleteWorkspaceDeleteAWorkspaceFile,
|
||||
} from "@/app/api/__generated__/endpoints/workspace/workspace";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { parseWorkspaceFileID, buildWorkspaceURI } from "@/lib/workspace-uri";
|
||||
|
||||
export function useWorkspaceUpload() {
|
||||
const { toast } = useToast();
|
||||
|
||||
const { mutateAsync: uploadMutation } =
|
||||
usePostWorkspaceUploadFileToWorkspace();
|
||||
|
||||
const { mutate: deleteMutation } = useDeleteWorkspaceDeleteAWorkspaceFile({
|
||||
mutation: {
|
||||
onError: () => {
|
||||
toast({
|
||||
title: "Failed to delete file",
|
||||
description: "The file could not be removed from storage.",
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
async function handleUploadFile(file: File) {
|
||||
const response = await uploadMutation({ data: { file } });
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Upload failed");
|
||||
}
|
||||
const d = response.data;
|
||||
return {
|
||||
file_name: d.name,
|
||||
size: d.size_bytes,
|
||||
content_type: d.mime_type,
|
||||
file_uri: buildWorkspaceURI(d.file_id, d.mime_type),
|
||||
};
|
||||
}
|
||||
|
||||
function handleDeleteFile(fileURI: string) {
|
||||
const fileID = parseWorkspaceFileID(fileURI);
|
||||
if (!fileID) return;
|
||||
deleteMutation({ fileId: fileID });
|
||||
}
|
||||
|
||||
return { handleUploadFile, handleDeleteFile };
|
||||
}
|
||||
@@ -1,32 +1,60 @@
|
||||
import { BlockUIType } from "@/app/(platform)/build/components/types";
|
||||
import { GoogleDrivePickerInput } from "@/components/contextual/GoogleDrivePicker/GoogleDrivePickerInput";
|
||||
import { GoogleDrivePickerConfig } from "@/lib/autogpt-server-api";
|
||||
import { FieldProps, getUiOptions } from "@rjsf/utils";
|
||||
import { FieldProps, getTemplate, getUiOptions, titleId } from "@rjsf/utils";
|
||||
import { cleanUpHandleId, getHandleId, updateUiOption } from "../../helpers";
|
||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||
|
||||
export const GoogleDrivePickerField = (props: FieldProps) => {
|
||||
const { schema, uiSchema, onChange, fieldPathId, formData, registry } = props;
|
||||
const uiOptions = getUiOptions(uiSchema);
|
||||
const config: GoogleDrivePickerConfig = schema.google_drive_picker_config;
|
||||
|
||||
const { nodeId } = registry.formContext;
|
||||
const uiType = registry.formContext?.uiType;
|
||||
|
||||
const TitleFieldTemplate = getTemplate("TitleFieldTemplate", registry);
|
||||
|
||||
const handleId = getHandleId({ uiOptions, id: fieldPathId.$id, schema });
|
||||
const updatedUiSchema = updateUiOption(uiSchema, {
|
||||
handleId,
|
||||
showHandles: !!nodeId,
|
||||
});
|
||||
|
||||
const { isInputConnected } = useEdgeStore();
|
||||
const isConnected = isInputConnected(nodeId, cleanUpHandleId(handleId));
|
||||
|
||||
if (uiType === BlockUIType.INPUT) {
|
||||
return (
|
||||
<div className="rounded-3xl border border-gray-200 p-2 pl-4 text-xs text-gray-500 hover:cursor-not-allowed">
|
||||
Select files when you run the graph
|
||||
<div className="flex flex-col gap-2">
|
||||
{!isConnected && (
|
||||
<div className="rounded-3xl border border-gray-200 p-2 pl-4 text-xs text-gray-500 hover:cursor-not-allowed">
|
||||
Select files when you run the graph
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
<GoogleDrivePickerInput
|
||||
config={config}
|
||||
value={formData}
|
||||
onChange={(value) => onChange(value, fieldPathId.path)}
|
||||
className={uiOptions.className}
|
||||
showRemoveButton={true}
|
||||
<div className="flex flex-col gap-2">
|
||||
<TitleFieldTemplate
|
||||
id={titleId(fieldPathId.$id)}
|
||||
title={schema.title || ""}
|
||||
required={false}
|
||||
schema={schema}
|
||||
uiSchema={updatedUiSchema}
|
||||
registry={registry}
|
||||
/>
|
||||
{!isConnected && (
|
||||
<GoogleDrivePickerInput
|
||||
config={config}
|
||||
value={formData}
|
||||
onChange={(value) => onChange(value, fieldPathId.path)}
|
||||
className={uiOptions.className}
|
||||
showRemoveButton={true}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -44,9 +44,19 @@ export function generateUiSchemaForCustomFields(
|
||||
const customFieldId = findCustomFieldId(propSchema);
|
||||
|
||||
if (customFieldId) {
|
||||
const hasAnyOfOrOneOf =
|
||||
propSchema.anyOf || propSchema.oneOf ? true : false;
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
"ui:field": customFieldId,
|
||||
...(hasAnyOfOrOneOf && {
|
||||
"ui:options": {
|
||||
...((uiSchema[key] as Record<string, unknown>)?.[
|
||||
"ui:options"
|
||||
] as object),
|
||||
fieldReplacesAnyOrOneOf: true,
|
||||
},
|
||||
}),
|
||||
};
|
||||
// Skip further processing for custom fields
|
||||
continue;
|
||||
|
||||
54
autogpt_platform/frontend/src/lib/workspace-uri.ts
Normal file
54
autogpt_platform/frontend/src/lib/workspace-uri.ts
Normal file
@@ -0,0 +1,54 @@
|
||||
/**
|
||||
* Shared utilities for parsing and constructing workspace:// URIs.
|
||||
*
|
||||
* Format: workspace://{fileID}#{mimeType}
|
||||
* - fileID: unique identifier for the file
|
||||
* - mimeType: optional MIME type hint (e.g. "image/png")
|
||||
*/
|
||||
|
||||
export interface WorkspaceURI {
|
||||
fileID: string;
|
||||
mimeType: string | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a workspace:// URI into its components.
|
||||
* Returns null if the string is not a workspace URI.
|
||||
*/
|
||||
export function parseWorkspaceURI(value: string): WorkspaceURI | null {
|
||||
if (!value.startsWith("workspace://")) return null;
|
||||
const rest = value.slice("workspace://".length);
|
||||
const hashIndex = rest.indexOf("#");
|
||||
if (hashIndex === -1) {
|
||||
return { fileID: rest, mimeType: null };
|
||||
}
|
||||
return {
|
||||
fileID: rest.slice(0, hashIndex),
|
||||
mimeType: rest.slice(hashIndex + 1) || null,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract just the file ID from a workspace:// URI.
|
||||
* Returns null if the string is not a workspace URI.
|
||||
*/
|
||||
export function parseWorkspaceFileID(uri: string): string | null {
|
||||
const parsed = parseWorkspaceURI(uri);
|
||||
return parsed?.fileID ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a value is a workspace:// URI string.
|
||||
*/
|
||||
export function isWorkspaceURI(value: unknown): value is string {
|
||||
return typeof value === "string" && value.startsWith("workspace://");
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a workspace:// URI from a file ID and optional MIME type.
|
||||
*/
|
||||
export function buildWorkspaceURI(fileID: string, mimeType?: string): string {
|
||||
return mimeType
|
||||
? `workspace://${fileID}#${mimeType}`
|
||||
: `workspace://${fileID}`;
|
||||
}
|
||||
Reference in New Issue
Block a user