mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
5 Commits
fix/copilo
...
abhi/add-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
54bf45656a | ||
|
|
2f32217c7c | ||
|
|
7b64fbc931 | ||
|
|
1a0234c946 | ||
|
|
1e14634d3d |
@@ -11,10 +11,7 @@ from backend.blocks._base import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import parse_data_uri, resolve_media_content
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
@@ -181,8 +178,7 @@ class FileOperation(StrEnum):
|
||||
|
||||
class FileOperationInput(TypedDict):
|
||||
path: str
|
||||
# MediaFileType is a str NewType — no runtime breakage for existing callers.
|
||||
content: MediaFileType
|
||||
content: str
|
||||
operation: FileOperation
|
||||
|
||||
|
||||
@@ -279,11 +275,11 @@ class GithubMultiFileCommitBlock(Block):
|
||||
base_tree_sha = commit_data["tree"]["sha"]
|
||||
|
||||
# 3. Build tree entries for each file operation (blobs created concurrently)
|
||||
async def _create_blob(content: str, encoding: str = "utf-8") -> str:
|
||||
async def _create_blob(content: str) -> str:
|
||||
blob_url = repo_url + "/git/blobs"
|
||||
blob_response = await api.post(
|
||||
blob_url,
|
||||
json={"content": content, "encoding": encoding},
|
||||
json={"content": content, "encoding": "utf-8"},
|
||||
)
|
||||
return blob_response.json()["sha"]
|
||||
|
||||
@@ -305,19 +301,10 @@ class GithubMultiFileCommitBlock(Block):
|
||||
else:
|
||||
upsert_files.append((path, file_op.get("content", "")))
|
||||
|
||||
# Create all blobs concurrently. Data URIs (from store_media_file)
|
||||
# are sent as base64 blobs to preserve binary content.
|
||||
# Create all blobs concurrently
|
||||
if upsert_files:
|
||||
|
||||
async def _make_blob(content: str) -> str:
|
||||
parsed = parse_data_uri(content)
|
||||
if parsed is not None:
|
||||
_, b64_payload = parsed
|
||||
return await _create_blob(b64_payload, encoding="base64")
|
||||
return await _create_blob(content)
|
||||
|
||||
blob_shas = await asyncio.gather(
|
||||
*[_make_blob(content) for _, content in upsert_files]
|
||||
*[_create_blob(content) for _, content in upsert_files]
|
||||
)
|
||||
for (path, _), blob_sha in zip(upsert_files, blob_shas):
|
||||
tree_entries.append(
|
||||
@@ -371,36 +358,15 @@ class GithubMultiFileCommitBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
# Resolve media references (workspace://, data:, URLs) to data
|
||||
# URIs so _make_blob can send binary content correctly.
|
||||
resolved_files: list[FileOperationInput] = []
|
||||
for file_op in input_data.files:
|
||||
content = file_op.get("content", "")
|
||||
operation = FileOperation(file_op.get("operation", "upsert"))
|
||||
if operation != FileOperation.DELETE:
|
||||
content = await resolve_media_content(
|
||||
MediaFileType(content),
|
||||
execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
resolved_files.append(
|
||||
FileOperationInput(
|
||||
path=file_op["path"],
|
||||
content=MediaFileType(content),
|
||||
operation=operation,
|
||||
)
|
||||
)
|
||||
|
||||
sha, url = await self.multi_file_commit(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
resolved_files,
|
||||
input_data.files,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "url", url
|
||||
|
||||
@@ -8,7 +8,6 @@ from backend.blocks.github.pull_requests import (
|
||||
GithubMergePullRequestBlock,
|
||||
prepare_pr_api_url,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
# ── prepare_pr_api_url tests ──
|
||||
@@ -98,11 +97,7 @@ async def test_multi_file_commit_error_path():
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="ref update failed"):
|
||||
async for _ in block.execute(
|
||||
input_data,
|
||||
credentials=TEST_CREDENTIALS,
|
||||
execution_context=ExecutionContext(),
|
||||
):
|
||||
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -115,7 +115,7 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=420, # 7 min safety net — allows headroom for compaction retries
|
||||
default=300, # 5 min safety net — explicit per-turn pause is the primary mechanism
|
||||
description="E2B sandbox running-time timeout (seconds). "
|
||||
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
|
||||
"mechanism; this is the safety net.",
|
||||
|
||||
@@ -11,8 +11,6 @@ from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
@@ -84,17 +82,6 @@ def resolve_sandbox_path(path: str) -> str:
|
||||
return normalized
|
||||
|
||||
|
||||
async def get_workspace_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped :class:`WorkspaceManager`.
|
||||
|
||||
Placed here (rather than in ``tools/workspace_files``) so that modules
|
||||
like ``sdk/file_ref`` can import it without triggering the heavy
|
||||
``tools/__init__`` import chain.
|
||||
"""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
|
||||
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
"""Return True if *path* is within an allowed host-filesystem location.
|
||||
|
||||
|
||||
@@ -52,43 +52,11 @@ Examples:
|
||||
You can embed a reference inside any string argument, or use it as the entire
|
||||
value. Multiple references in one argument are all expanded.
|
||||
|
||||
**Structured data**: When the **entire** argument value is a single file
|
||||
reference (no surrounding text), the platform automatically parses the file
|
||||
content based on its extension or MIME type. Supported formats: JSON, JSONL,
|
||||
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
|
||||
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
|
||||
the rows will be parsed into `list[list[str]]` automatically. If the format is
|
||||
unrecognised or parsing fails, the content is returned as a plain string.
|
||||
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.
|
||||
**Type coercion**: The platform automatically coerces expanded string values
|
||||
to match the block's expected input types. For example, if a block expects
|
||||
`list[list[str]]` and you pass a string containing a JSON array (e.g. from
|
||||
an @@agptfile: expansion), the string will be parsed into the correct type.
|
||||
|
||||
**Type coercion**: The platform also coerces expanded values to match the
|
||||
block's expected input types. For example, if a block expects `list[list[str]]`
|
||||
and the expanded value is a JSON string, it will be parsed into the correct type.
|
||||
|
||||
### Media file inputs (format: "file")
|
||||
Some block inputs accept media files — their schema shows `"format": "file"`.
|
||||
These fields accept:
|
||||
- **`workspace://<file_id>`** or **`workspace://<file_id>#<mime>`** — preferred
|
||||
for large files (images, videos, PDFs). The platform passes the reference
|
||||
directly to the block without reading the content into memory.
|
||||
- **`data:<mime>;base64,<payload>`** — inline base64 data URI, suitable for
|
||||
small files only.
|
||||
|
||||
When a block input has `format: "file"`, **pass the `workspace://` URI
|
||||
directly as the value** (do NOT wrap it in `@@agptfile:`). This avoids large
|
||||
payloads in tool arguments and preserves binary content (images, videos)
|
||||
that would be corrupted by text encoding.
|
||||
|
||||
Example — committing an image file to GitHub:
|
||||
```json
|
||||
{
|
||||
"files": [{
|
||||
"path": "docs/hero.png",
|
||||
"content": "workspace://abc123#image/png",
|
||||
"operation": "upsert"
|
||||
}]
|
||||
}
|
||||
```
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
|
||||
@@ -43,7 +43,6 @@ class ResponseType(str, Enum):
|
||||
ERROR = "error"
|
||||
USAGE = "usage"
|
||||
HEARTBEAT = "heartbeat"
|
||||
STATUS = "status"
|
||||
|
||||
|
||||
class StreamBaseResponse(BaseModel):
|
||||
@@ -233,26 +232,3 @@ class StreamHeartbeat(StreamBaseResponse):
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE comment format to keep connection alive."""
|
||||
return ": heartbeat\n\n"
|
||||
|
||||
|
||||
class StreamStatus(StreamBaseResponse):
|
||||
"""Transient status notification shown to the user during long operations.
|
||||
|
||||
Used to provide feedback when the backend performs behind-the-scenes work
|
||||
(e.g., compacting conversation context on a retry) that would otherwise
|
||||
leave the user staring at an unexplained pause.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.STATUS
|
||||
message: str = Field(..., description="Human-readable status message")
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Encode as an SSE comment so the AI SDK stream parser ignores it.
|
||||
|
||||
The frontend AI SDK validates every ``data:`` line against a strict
|
||||
Zod union of known chunk types. ``"status"`` is not in that union,
|
||||
so sending it as ``data:`` would cause a schema-validation error that
|
||||
breaks the entire stream. Using an SSE comment (``:``) keeps the
|
||||
connection alive and is silently discarded by ``EventSource`` parsers.
|
||||
"""
|
||||
return f": status {self.message}\n\n"
|
||||
|
||||
@@ -3,45 +3,12 @@
|
||||
This module provides the integration layer between the Claude Agent SDK
|
||||
and the existing CoPilot tool system, enabling drop-in replacement of
|
||||
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
||||
|
||||
Submodule imports are deferred via PEP 562 ``__getattr__`` to break a
|
||||
circular import cycle::
|
||||
|
||||
sdk/__init__ → tool_adapter → copilot.tools (TOOL_REGISTRY)
|
||||
copilot.tools → run_block → sdk.file_ref (no cycle here, but…)
|
||||
sdk/__init__ → service → copilot.prompting → copilot.tools (cycle!)
|
||||
|
||||
``tool_adapter`` uses ``TOOL_REGISTRY`` at **module level** to build the
|
||||
static ``COPILOT_TOOL_NAMES`` list, so the import cannot be deferred to
|
||||
function scope without a larger refactor (moving tool-name registration
|
||||
to a separate lightweight module). The lazy-import pattern here is the
|
||||
least invasive way to break the cycle while keeping module-level constants
|
||||
intact.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from .service import stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server
|
||||
|
||||
__all__ = [
|
||||
"stream_chat_completion_sdk",
|
||||
"create_copilot_mcp_server",
|
||||
]
|
||||
|
||||
# Dispatch table for PEP 562 lazy imports. Each entry is a (module, attr)
|
||||
# pair so new exports can be added without touching __getattr__ itself.
|
||||
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
|
||||
"stream_chat_completion_sdk": (".service", "stream_chat_completion_sdk"),
|
||||
"create_copilot_mcp_server": (".tool_adapter", "create_copilot_mcp_server"),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
entry = _LAZY_IMPORTS.get(name)
|
||||
if entry is not None:
|
||||
module_path, attr = entry
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(module_path, package=__name__)
|
||||
value = getattr(module, attr)
|
||||
globals()[name] = value
|
||||
return value
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -12,7 +12,6 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
||||
from ..model import ChatMessage, ChatSession
|
||||
@@ -120,12 +119,14 @@ def filter_compaction_messages(
|
||||
filtered: list[ChatMessage] = []
|
||||
for msg in messages:
|
||||
if msg.role == "assistant" and msg.tool_calls:
|
||||
real_calls: list[dict[str, Any]] = []
|
||||
for tc in msg.tool_calls:
|
||||
if tc.get("function", {}).get("name") == COMPACTION_TOOL_NAME:
|
||||
compaction_ids.add(tc.get("id", ""))
|
||||
else:
|
||||
real_calls.append(tc)
|
||||
real_calls = [
|
||||
tc
|
||||
for tc in msg.tool_calls
|
||||
if tc.get("function", {}).get("name") != COMPACTION_TOOL_NAME
|
||||
]
|
||||
if not real_calls and not msg.content:
|
||||
continue
|
||||
if msg.role == "tool" and msg.tool_call_id in compaction_ids:
|
||||
@@ -221,7 +222,6 @@ class CompactionTracker:
|
||||
|
||||
def reset_for_query(self) -> None:
|
||||
"""Reset per-query state before a new SDK query."""
|
||||
self._compact_start.clear()
|
||||
self._done = False
|
||||
self._start_emitted = False
|
||||
self._tool_call_id = ""
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
"""Shared test fixtures for copilot SDK tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from backend.util import json
|
||||
|
||||
|
||||
def build_test_transcript(pairs: list[tuple[str, str]]) -> str:
|
||||
"""Build a minimal valid JSONL transcript from (role, content) pairs.
|
||||
|
||||
Use this helper in any copilot SDK test that needs a well-formed
|
||||
transcript without hitting the real storage layer.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
last_uuid: str | None = None
|
||||
for role, content in pairs:
|
||||
uid = str(uuid4())
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
msg: dict = {"role": role, "content": content}
|
||||
if role == "assistant":
|
||||
msg.update(
|
||||
{
|
||||
"model": "",
|
||||
"id": f"msg_{uid[:8]}",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
)
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": last_uuid,
|
||||
"message": msg,
|
||||
}
|
||||
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
last_uuid = uid
|
||||
return "\n".join(lines) + "\n"
|
||||
@@ -41,20 +41,12 @@ from typing import Any
|
||||
from backend.copilot.context import (
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
get_workspace_manager,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.workspace_files import get_manager
|
||||
from backend.util.file import parse_workspace_uri
|
||||
from backend.util.file_content_parser import (
|
||||
BINARY_FORMATS,
|
||||
MIME_TO_FORMAT,
|
||||
PARSE_EXCEPTIONS,
|
||||
infer_format_from_uri,
|
||||
parse_file_content,
|
||||
)
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
class FileRefExpansionError(Exception):
|
||||
@@ -82,8 +74,6 @@ _FILE_REF_RE = re.compile(
|
||||
_MAX_EXPAND_CHARS = 200_000
|
||||
# Maximum total characters across all @@agptfile: expansions in one string.
|
||||
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
|
||||
# Maximum raw byte size for bare ref structured parsing (10 MB).
|
||||
_MAX_BARE_REF_BYTES = 10_000_000
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -93,11 +83,6 @@ class FileRef:
|
||||
end_line: int | None # 1-indexed, inclusive
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API (top-down: main functions first, helpers below)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_file_ref(text: str) -> FileRef | None:
|
||||
"""Return a :class:`FileRef` if *text* is a bare file reference token.
|
||||
|
||||
@@ -119,6 +104,17 @@ def parse_file_ref(text: str) -> FileRef | None:
|
||||
return FileRef(uri=m.group(1), start_line=start, end_line=end)
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else len(lines)
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
return "".join(selected)
|
||||
|
||||
|
||||
async def read_file_bytes(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
@@ -134,47 +130,27 @@ async def read_file_bytes(
|
||||
if plain.startswith("workspace://"):
|
||||
if not user_id:
|
||||
raise ValueError("workspace:// file references require authentication")
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
ws = parse_workspace_uri(plain)
|
||||
try:
|
||||
data = await (
|
||||
return await (
|
||||
manager.read_file(ws.file_ref)
|
||||
if ws.is_path
|
||||
else manager.read_file_by_id(ws.file_ref)
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except (PermissionError, OSError) as exc:
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
except (AttributeError, TypeError, RuntimeError) as exc:
|
||||
# AttributeError/TypeError: workspace manager returned an
|
||||
# unexpected type or interface; RuntimeError: async runtime issues.
|
||||
logger.warning("Unexpected error reading %s: %s", plain, exc)
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
# NOTE: Workspace API does not support pre-read size checks;
|
||||
# the full file is loaded before the size guard below.
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
|
||||
if is_allowed_local_path(plain, get_sdk_cwd()):
|
||||
resolved = os.path.realpath(os.path.expanduser(plain))
|
||||
try:
|
||||
# Read with a one-byte overshoot to detect files that exceed the limit
|
||||
# without a separate os.path.getsize call (avoids TOCTOU race).
|
||||
with open(resolved, "rb") as fh:
|
||||
data = fh.read(_MAX_BARE_REF_BYTES + 1)
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large (>{_MAX_BARE_REF_BYTES} bytes, "
|
||||
f"limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
return fh.read()
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except OSError as exc:
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
@@ -186,33 +162,9 @@ async def read_file_bytes(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
) from exc
|
||||
try:
|
||||
data = bytes(await sandbox.files.read(remote, format="bytes"))
|
||||
except (FileNotFoundError, OSError, UnicodeDecodeError) as exc:
|
||||
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||
return bytes(await sandbox.files.read(remote, format="bytes"))
|
||||
except Exception as exc:
|
||||
# E2B SDK raises SandboxException subclasses (NotFoundException,
|
||||
# TimeoutException, NotEnoughSpaceException, etc.) which don't
|
||||
# inherit from standard exceptions. Import lazily to avoid a
|
||||
# hard dependency on e2b at module level.
|
||||
try:
|
||||
from e2b.exceptions import SandboxException # noqa: PLC0415
|
||||
|
||||
if isinstance(exc, SandboxException):
|
||||
raise ValueError(
|
||||
f"Failed to read from sandbox: {plain}: {exc}"
|
||||
) from exc
|
||||
except ImportError:
|
||||
pass
|
||||
# Re-raise unexpected exceptions (TypeError, AttributeError, etc.)
|
||||
# so they surface as real bugs rather than being silently masked.
|
||||
raise
|
||||
# NOTE: E2B sandbox API does not support pre-read size checks;
|
||||
# the full file is loaded before the size guard below.
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
@@ -226,13 +178,15 @@ async def resolve_file_ref(
|
||||
) -> str:
|
||||
"""Resolve a :class:`FileRef` to its text content."""
|
||||
raw = await read_file_bytes(ref.uri, user_id, session)
|
||||
return _apply_line_range(_to_str(raw), ref.start_line, ref.end_line)
|
||||
return _apply_line_range(
|
||||
raw.decode("utf-8", errors="replace"), ref.start_line, ref.end_line
|
||||
)
|
||||
|
||||
|
||||
async def expand_file_refs_in_string(
|
||||
text: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
session: "ChatSession",
|
||||
*,
|
||||
raise_on_error: bool = False,
|
||||
) -> str:
|
||||
@@ -278,9 +232,6 @@ async def expand_file_refs_in_string(
|
||||
if len(content) > _MAX_EXPAND_CHARS:
|
||||
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
|
||||
# remaining == 0 means the budget was exactly exhausted by the
|
||||
# previous ref. The elif below (len > remaining) won't catch
|
||||
# this since 0 > 0 is false, so we need the <= 0 check.
|
||||
if remaining <= 0:
|
||||
content = "[file-ref budget exhausted: total expansion limit reached]"
|
||||
elif len(content) > remaining:
|
||||
@@ -301,31 +252,13 @@ async def expand_file_refs_in_string(
|
||||
async def expand_file_refs_in_args(
|
||||
args: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
*,
|
||||
input_schema: dict[str, Any] | None = None,
|
||||
session: "ChatSession",
|
||||
) -> dict[str, Any]:
|
||||
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
|
||||
|
||||
String values are expanded in-place. Nested dicts and lists are
|
||||
traversed. Non-string scalars are returned unchanged.
|
||||
|
||||
**Bare references** (the entire argument value is a single
|
||||
``@@agptfile:...`` token with no surrounding text) are resolved and then
|
||||
parsed according to the file's extension or MIME type. See
|
||||
:mod:`backend.util.file_content_parser` for the full list of supported
|
||||
formats (JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel).
|
||||
|
||||
When *input_schema* is provided and the target property has
|
||||
``"type": "string"``, structured parsing is skipped — the raw file content
|
||||
is returned as a plain string so blocks receive the original text.
|
||||
|
||||
If the format is unrecognised or parsing fails, the content is returned as
|
||||
a plain string (the fallback).
|
||||
|
||||
**Embedded references** (``@@agptfile:`` mixed with other text) always
|
||||
produce a plain string — structured parsing only applies to bare refs.
|
||||
|
||||
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
|
||||
so the tool is *not* executed with an error string as its input. The
|
||||
caller (the MCP tool wrapper) should convert this into an MCP error
|
||||
@@ -334,382 +267,15 @@ async def expand_file_refs_in_args(
|
||||
if not args:
|
||||
return args
|
||||
|
||||
properties = (input_schema or {}).get("properties", {})
|
||||
|
||||
async def _expand(
|
||||
value: Any,
|
||||
*,
|
||||
prop_schema: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
"""Recursively expand a single argument value.
|
||||
|
||||
Strings are checked for ``@@agptfile:`` references and expanded
|
||||
(bare refs get structured parsing; embedded refs get inline
|
||||
substitution). Dicts and lists are traversed recursively,
|
||||
threading the corresponding sub-schema from *prop_schema* so
|
||||
that nested fields also receive correct type-aware expansion.
|
||||
Non-string scalars pass through unchanged.
|
||||
"""
|
||||
async def _expand(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
ref = parse_file_ref(value)
|
||||
if ref is not None:
|
||||
# MediaFileType fields: return the raw URI immediately —
|
||||
# no file reading, no format inference, no content parsing.
|
||||
if _is_media_file_field(prop_schema):
|
||||
return ref.uri
|
||||
|
||||
fmt = infer_format_from_uri(ref.uri)
|
||||
# Workspace URIs by ID (workspace://abc123) have no extension.
|
||||
# When the MIME fragment is also missing, fall back to the
|
||||
# workspace file manager's metadata for format detection.
|
||||
if fmt is None and ref.uri.startswith("workspace://"):
|
||||
fmt = await _infer_format_from_workspace(ref.uri, user_id, session)
|
||||
return await _expand_bare_ref(ref, fmt, user_id, session, prop_schema)
|
||||
|
||||
# Not a bare ref — do normal inline expansion.
|
||||
return await expand_file_refs_in_string(
|
||||
value, user_id, session, raise_on_error=True
|
||||
)
|
||||
if isinstance(value, dict):
|
||||
# When the schema says this is an object but doesn't define
|
||||
# inner properties, skip expansion — the caller (e.g.
|
||||
# RunBlockTool) will expand with the actual nested schema.
|
||||
if (
|
||||
prop_schema is not None
|
||||
and prop_schema.get("type") == "object"
|
||||
and "properties" not in prop_schema
|
||||
):
|
||||
return value
|
||||
nested_props = (prop_schema or {}).get("properties", {})
|
||||
return {
|
||||
k: await _expand(v, prop_schema=nested_props.get(k))
|
||||
for k, v in value.items()
|
||||
}
|
||||
return {k: await _expand(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
items_schema = (prop_schema or {}).get("items")
|
||||
return [await _expand(item, prop_schema=items_schema) for item in value]
|
||||
return [await _expand(item) for item in value]
|
||||
return value
|
||||
|
||||
return {k: await _expand(v, prop_schema=properties.get(k)) for k, v in args.items()}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers (used by the public functions above)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive).
|
||||
|
||||
When the requested range extends beyond the file, a note is appended
|
||||
so the LLM knows it received the entire remaining content.
|
||||
"""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
total = len(lines)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else total
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
result = "".join(selected)
|
||||
if end is not None and end > total:
|
||||
result += f"\n[Note: file has only {total} lines]\n"
|
||||
return result
|
||||
|
||||
|
||||
def _to_str(content: str | bytes) -> str:
|
||||
"""Decode *content* to a string if it is bytes, otherwise return as-is."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return content.decode("utf-8", errors="replace")
|
||||
|
||||
|
||||
def _check_content_size(content: str | bytes) -> None:
|
||||
"""Raise :class:`ValueError` if *content* exceeds the byte limit.
|
||||
|
||||
Raises ``ValueError`` (not ``FileRefExpansionError``) so that the caller
|
||||
(``_expand_bare_ref``) can unify all resolution errors into a single
|
||||
``except ValueError`` → ``FileRefExpansionError`` handler, keeping the
|
||||
error-flow consistent with ``read_file_bytes`` and ``resolve_file_ref``.
|
||||
|
||||
For ``bytes``, the length is the byte count directly. For ``str``,
|
||||
we encode to UTF-8 first because multi-byte characters (e.g. emoji)
|
||||
mean the byte size can be up to 4x the character count.
|
||||
"""
|
||||
if isinstance(content, bytes):
|
||||
size = len(content)
|
||||
else:
|
||||
char_len = len(content)
|
||||
# Fast lower bound: UTF-8 byte count >= char count.
|
||||
# If char count already exceeds the limit, reject immediately
|
||||
# without allocating an encoded copy.
|
||||
if char_len > _MAX_BARE_REF_BYTES:
|
||||
size = char_len # real byte size is even larger
|
||||
# Fast upper bound: each char is at most 4 UTF-8 bytes.
|
||||
# If worst-case is still under the limit, skip encoding entirely.
|
||||
elif char_len * 4 <= _MAX_BARE_REF_BYTES:
|
||||
return
|
||||
else:
|
||||
# Edge case: char count is under limit but multibyte chars
|
||||
# might push byte count over. Encode to get exact size.
|
||||
size = len(content.encode("utf-8"))
|
||||
if size > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large for structured parsing "
|
||||
f"({size} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
|
||||
|
||||
async def _infer_format_from_workspace(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> str | None:
|
||||
"""Look up workspace file metadata to infer the format.
|
||||
|
||||
Workspace URIs by ID (``workspace://abc123``) have no file extension.
|
||||
When the MIME fragment is also absent, we query the workspace file
|
||||
manager for the file's stored MIME type and original filename.
|
||||
"""
|
||||
if not user_id:
|
||||
return None
|
||||
try:
|
||||
ws = parse_workspace_uri(uri)
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
info = await (
|
||||
manager.get_file_info(ws.file_ref)
|
||||
if not ws.is_path
|
||||
else manager.get_file_info_by_path(ws.file_ref)
|
||||
)
|
||||
if info is None:
|
||||
return None
|
||||
# Try MIME type first, then filename extension.
|
||||
mime = (info.mime_type or "").split(";", 1)[0].strip().lower()
|
||||
return MIME_TO_FORMAT.get(mime) or infer_format_from_uri(info.name)
|
||||
except (
|
||||
ValueError,
|
||||
FileNotFoundError,
|
||||
OSError,
|
||||
PermissionError,
|
||||
AttributeError,
|
||||
TypeError,
|
||||
):
|
||||
# Expected failures: bad URI, missing file, permission denied, or
|
||||
# workspace manager returning unexpected types. Propagate anything
|
||||
# else (e.g. programming errors) so they don't get silently swallowed.
|
||||
logger.debug("workspace metadata lookup failed for %s", uri, exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def _is_media_file_field(prop_schema: dict[str, Any] | None) -> bool:
|
||||
"""Return True if *prop_schema* describes a MediaFileType field (format: file)."""
|
||||
if prop_schema is None:
|
||||
return False
|
||||
return (
|
||||
prop_schema.get("type") == "string"
|
||||
and prop_schema.get("format") == MediaFileType.string_format
|
||||
)
|
||||
|
||||
|
||||
async def _expand_bare_ref(
|
||||
ref: FileRef,
|
||||
fmt: str | None,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
prop_schema: dict[str, Any] | None,
|
||||
) -> Any:
|
||||
"""Resolve and parse a bare ``@@agptfile:`` reference.
|
||||
|
||||
This is the structured-parsing path: the file is read, optionally parsed
|
||||
according to *fmt*, and adapted to the target *prop_schema*.
|
||||
|
||||
Raises :class:`FileRefExpansionError` on resolution or parsing failure.
|
||||
|
||||
Note: MediaFileType fields (format: "file") are handled earlier in
|
||||
``_expand`` to avoid unnecessary format inference and file I/O.
|
||||
"""
|
||||
try:
|
||||
if fmt is not None and fmt in BINARY_FORMATS:
|
||||
# Binary formats need raw bytes, not UTF-8 text.
|
||||
# Line ranges are meaningless for binary formats (parquet/xlsx)
|
||||
# — ignore them and parse full bytes. Warn so the caller/model
|
||||
# knows the range was silently dropped.
|
||||
if ref.start_line is not None or ref.end_line is not None:
|
||||
logger.warning(
|
||||
"Line range [%s-%s] ignored for binary format %s (%s); "
|
||||
"binary formats are always parsed in full.",
|
||||
ref.start_line,
|
||||
ref.end_line,
|
||||
fmt,
|
||||
ref.uri,
|
||||
)
|
||||
content: str | bytes = await read_file_bytes(ref.uri, user_id, session)
|
||||
else:
|
||||
content = await resolve_file_ref(ref, user_id, session)
|
||||
except ValueError as exc:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
|
||||
# For known formats this rejects files >10 MB before parsing.
|
||||
# For unknown formats _MAX_EXPAND_CHARS (200K chars) below is stricter,
|
||||
# but this check still guards the parsing path which has no char limit.
|
||||
# _check_content_size raises ValueError, which we unify here just like
|
||||
# resolution errors above.
|
||||
try:
|
||||
_check_content_size(content)
|
||||
except ValueError as exc:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
|
||||
# When the schema declares this parameter as "string",
|
||||
# return raw file content — don't parse into a structured
|
||||
# type that would need json.dumps() serialisation.
|
||||
expect_string = (prop_schema or {}).get("type") == "string"
|
||||
if expect_string:
|
||||
if isinstance(content, bytes):
|
||||
raise FileRefExpansionError(
|
||||
f"Cannot use {fmt} file as text input: "
|
||||
f"binary formats (parquet, xlsx) must be passed "
|
||||
f"to a block that accepts structured data (list/object), "
|
||||
f"not a string-typed parameter."
|
||||
)
|
||||
return content
|
||||
|
||||
if fmt is not None:
|
||||
# Use strict mode for binary formats so we surface the
|
||||
# actual error (e.g. missing pyarrow/openpyxl, corrupt
|
||||
# file) instead of silently returning garbled bytes.
|
||||
strict = fmt in BINARY_FORMATS
|
||||
try:
|
||||
parsed = parse_file_content(content, fmt, strict=strict)
|
||||
except PARSE_EXCEPTIONS as exc:
|
||||
raise FileRefExpansionError(f"Failed to parse {fmt} file: {exc}") from exc
|
||||
# Normalize bytes fallback to str so tools never
|
||||
# receive raw bytes when parsing fails.
|
||||
if isinstance(parsed, bytes):
|
||||
parsed = _to_str(parsed)
|
||||
return _adapt_to_schema(parsed, prop_schema)
|
||||
|
||||
# Unknown format — return as plain string, but apply
|
||||
# the same per-ref character limit used by inline refs
|
||||
# to prevent injecting unexpectedly large content.
|
||||
text = _to_str(content)
|
||||
if len(text) > _MAX_EXPAND_CHARS:
|
||||
text = text[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
return text
|
||||
|
||||
|
||||
def _adapt_to_schema(parsed: Any, prop_schema: dict[str, Any] | None) -> Any:
|
||||
"""Adapt a parsed file value to better fit the target schema type.
|
||||
|
||||
When the parser returns a natural type (e.g. dict from YAML, list from CSV)
|
||||
that doesn't match the block's expected type, this function converts it to
|
||||
a more useful representation instead of relying on pydantic's generic
|
||||
coercion (which can produce awkward results like flattened dicts → lists).
|
||||
|
||||
Returns *parsed* unchanged when no adaptation is needed.
|
||||
"""
|
||||
if prop_schema is None:
|
||||
return parsed
|
||||
|
||||
target_type = prop_schema.get("type")
|
||||
|
||||
# Dict → array: delegate to helper.
|
||||
if isinstance(parsed, dict) and target_type == "array":
|
||||
return _adapt_dict_to_array(parsed, prop_schema)
|
||||
|
||||
# List → object: delegate to helper (raises for non-tabular lists).
|
||||
if isinstance(parsed, list) and target_type == "object":
|
||||
return _adapt_list_to_object(parsed)
|
||||
|
||||
# Tabular list → Any (no type): convert to list of dicts.
|
||||
# Blocks like FindInDictionaryBlock have `input: Any` which produces
|
||||
# a schema with no "type" key. Tabular [[header],[rows]] is unusable
|
||||
# for key lookup, but [{col: val}, ...] works with FindInDict's
|
||||
# list-of-dicts branch (line 195-199 in data_manipulation.py).
|
||||
if isinstance(parsed, list) and target_type is None and _is_tabular(parsed):
|
||||
return _tabular_to_list_of_dicts(parsed)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def _adapt_dict_to_array(parsed: dict, prop_schema: dict[str, Any]) -> Any:
|
||||
"""Adapt a parsed dict to an array-typed field.
|
||||
|
||||
Extracts list-valued entries when the target item type is ``array``,
|
||||
passes through unchanged when item type is ``string`` (lets pydantic error),
|
||||
or wraps in ``[parsed]`` as a fallback.
|
||||
"""
|
||||
items_type = (prop_schema.get("items") or {}).get("type")
|
||||
if items_type == "array":
|
||||
# Target is List[List[Any]] — extract list-typed values from the
|
||||
# dict as inner lists. E.g. YAML {"fruits": [{...},...]}} with
|
||||
# ConcatenateLists (List[List[Any]]) → [[{...},...]].
|
||||
list_values = [v for v in parsed.values() if isinstance(v, list)]
|
||||
if list_values:
|
||||
return list_values
|
||||
if items_type == "string":
|
||||
# Target is List[str] — wrapping a dict would give [dict]
|
||||
# which can't coerce to strings. Return unchanged and let
|
||||
# pydantic surface a clear validation error.
|
||||
return parsed
|
||||
# Fallback: wrap in a single-element list so the block gets [dict]
|
||||
# instead of pydantic flattening keys/values into a flat list.
|
||||
return [parsed]
|
||||
|
||||
|
||||
def _adapt_list_to_object(parsed: list) -> Any:
|
||||
"""Adapt a parsed list to an object-typed field.
|
||||
|
||||
Converts tabular lists to column-dicts; raises for non-tabular lists.
|
||||
"""
|
||||
if _is_tabular(parsed):
|
||||
return _tabular_to_column_dict(parsed)
|
||||
# Non-tabular list (e.g. a plain Python list from a YAML file) cannot
|
||||
# be meaningfully coerced to an object. Raise explicitly so callers
|
||||
# get a clear error rather than pydantic silently wrapping the list.
|
||||
raise FileRefExpansionError(
|
||||
"Cannot adapt a non-tabular list to an object-typed field. "
|
||||
"Expected a tabular structure ([[header], [row1], ...]) or a dict."
|
||||
)
|
||||
|
||||
|
||||
def _is_tabular(parsed: Any) -> bool:
|
||||
"""Check if parsed data is in tabular format: [[header], [row1], ...].
|
||||
|
||||
Uses isinstance checks because this is a structural type guard on
|
||||
opaque parser output (Any), not duck typing. A Protocol wouldn't
|
||||
help here — we need to verify exact list-of-lists shape.
|
||||
"""
|
||||
if not isinstance(parsed, list) or len(parsed) < 2:
|
||||
return False
|
||||
header = parsed[0]
|
||||
if not isinstance(header, list) or not header:
|
||||
return False
|
||||
if not all(isinstance(h, str) for h in header):
|
||||
return False
|
||||
return all(isinstance(row, list) for row in parsed[1:])
|
||||
|
||||
|
||||
def _tabular_to_list_of_dicts(parsed: list) -> list[dict[str, Any]]:
|
||||
"""Convert [[header], [row1], ...] → [{header[0]: row[0], ...}, ...].
|
||||
|
||||
Ragged rows (fewer columns than the header) get None for missing values.
|
||||
Extra values beyond the header length are silently dropped.
|
||||
"""
|
||||
header = parsed[0]
|
||||
return [
|
||||
dict(itertools.zip_longest(header, row[: len(header)], fillvalue=None))
|
||||
for row in parsed[1:]
|
||||
]
|
||||
|
||||
|
||||
def _tabular_to_column_dict(parsed: list) -> dict[str, list]:
|
||||
"""Convert [[header], [row1], ...] → {"col1": [val1, ...], ...}.
|
||||
|
||||
Ragged rows (fewer columns than the header) get None for missing values,
|
||||
ensuring all columns have equal length.
|
||||
"""
|
||||
header = parsed[0]
|
||||
return {
|
||||
col: [row[i] if i < len(row) else None for row in parsed[1:]]
|
||||
for i, col in enumerate(header)
|
||||
}
|
||||
return {k: await _expand(v) for k, v in args.items()}
|
||||
|
||||
@@ -175,199 +175,6 @@ async def test_expand_args_replaces_file_ref_in_nested_dict():
|
||||
assert result["count"] == 42
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_args — bare ref structured parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_json_returns_parsed_dict():
|
||||
"""Bare ref to a .json file returns parsed dict, not raw string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "data.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write('{"key": "value", "count": 42}')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{json_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == {"key": "value", "count": 42}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_csv_returns_parsed_table():
|
||||
"""Bare ref to a .csv file returns list[list[str]] table."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
csv_file = os.path.join(sdk_cwd, "data.csv")
|
||||
with open(csv_file, "w") as f:
|
||||
f.write("Name,Score\nAlice,90\nBob,85")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"input": f"@@agptfile:{csv_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["input"] == [
|
||||
["Name", "Score"],
|
||||
["Alice", "90"],
|
||||
["Bob", "85"],
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_unknown_extension_returns_string():
|
||||
"""Bare ref to a file with unknown extension returns plain string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
txt_file = os.path.join(sdk_cwd, "readme.txt")
|
||||
with open(txt_file, "w") as f:
|
||||
f.write("plain text content")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{txt_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == "plain text content"
|
||||
assert isinstance(result["data"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_invalid_json_falls_back_to_string():
|
||||
"""Bare ref to a .json file with invalid JSON falls back to string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "bad.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write("not valid json {{{")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{json_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == "not valid json {{{"
|
||||
assert isinstance(result["data"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedded_ref_always_returns_string_even_for_json():
|
||||
"""Embedded ref (text around it) returns plain string, not parsed JSON."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "data.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write('{"key": "value"}')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"prefix @@agptfile:{json_file} suffix"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert isinstance(result["data"], str)
|
||||
assert result["data"].startswith("prefix ")
|
||||
assert result["data"].endswith(" suffix")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_yaml_returns_parsed_dict():
|
||||
"""Bare ref to a .yaml file returns parsed dict."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
yaml_file = os.path.join(sdk_cwd, "config.yaml")
|
||||
with open(yaml_file, "w") as f:
|
||||
f.write("name: test\ncount: 42\n")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"config": f"@@agptfile:{yaml_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["config"] == {"name": "test", "count": 42}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_binary_with_line_range_ignores_range():
|
||||
"""Bare ref to a binary file (.parquet) with line range parses the full file.
|
||||
|
||||
Binary formats (parquet, xlsx) ignore line ranges — the full content is
|
||||
parsed and the range is silently dropped with a log warning.
|
||||
"""
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pytest.skip("pandas not installed")
|
||||
try:
|
||||
import pyarrow # noqa: F401 # pyright: ignore[reportMissingImports]
|
||||
except ImportError:
|
||||
pytest.skip("pyarrow not installed")
|
||||
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
parquet_file = os.path.join(sdk_cwd, "data.parquet")
|
||||
import io as _io
|
||||
|
||||
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
|
||||
buf = _io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
with open(parquet_file, "wb") as f:
|
||||
f.write(buf.getvalue())
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
# Line range [1-2] should be silently ignored for binary formats.
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{parquet_file}[1-2]"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
# Full file is returned despite the line range.
|
||||
assert result["data"] == [["A", "B"], [1, 4], [2, 5], [3, 6]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_toml_returns_parsed_dict():
|
||||
"""Bare ref to a .toml file returns parsed dict."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
toml_file = os.path.join(sdk_cwd, "config.toml")
|
||||
with open(toml_file, "w") as f:
|
||||
f.write('name = "test"\ncount = 42\n')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"config": f"@@agptfile:{toml_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["config"] == {"name": "test", "count": 42}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_file_handler — extended to accept workspace:// and local paths
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -412,7 +219,7 @@ async def test_read_file_handler_workspace_uri():
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", mock_session),
|
||||
), patch(
|
||||
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await _read_file_handler(
|
||||
@@ -469,7 +276,7 @@ async def test_read_file_bytes_workspace_virtual_path():
|
||||
mock_manager.read_file.return_value = b"virtual path content"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,552 +0,0 @@
|
||||
"""Tests for retry logic and transcript compaction helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _is_prompt_too_long
|
||||
from .transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_assistant_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenAssistantContent:
|
||||
def test_text_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "text", "text": "World"},
|
||||
]
|
||||
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
|
||||
|
||||
def test_tool_use_blocks(self):
|
||||
blocks = [{"type": "tool_use", "name": "read_file", "input": {}}]
|
||||
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
|
||||
|
||||
def test_mixed_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Let me read that."},
|
||||
{"type": "tool_use", "name": "Read", "input": {"path": "/foo"}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "Let me read that." in result
|
||||
assert "[tool_use: Read]" in result
|
||||
|
||||
def test_raw_strings(self):
|
||||
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
|
||||
|
||||
def test_unknown_block_type_preserved_as_placeholder(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "See this image:"},
|
||||
{"type": "image", "source": {"type": "base64", "data": "..."}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "See this image:" in result
|
||||
assert "[__image__]" in result
|
||||
|
||||
def test_empty(self):
|
||||
assert _flatten_assistant_content([]) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_tool_result_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenToolResultContent:
|
||||
def test_tool_result_with_text(self):
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "123",
|
||||
"content": [{"type": "text", "text": "file contents here"}],
|
||||
}
|
||||
]
|
||||
assert _flatten_tool_result_content(blocks) == "file contents here"
|
||||
|
||||
def test_tool_result_with_string_content(self):
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "123", "content": "ok"}]
|
||||
assert _flatten_tool_result_content(blocks) == "ok"
|
||||
|
||||
def test_text_block(self):
|
||||
blocks = [{"type": "text", "text": "plain text"}]
|
||||
assert _flatten_tool_result_content(blocks) == "plain text"
|
||||
|
||||
def test_raw_string(self):
|
||||
assert _flatten_tool_result_content(["raw"]) == "raw"
|
||||
|
||||
def test_tool_result_with_none_content(self):
|
||||
"""tool_result with content=None should produce empty string."""
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": None}]
|
||||
assert _flatten_tool_result_content(blocks) == ""
|
||||
|
||||
def test_tool_result_with_empty_list_content(self):
|
||||
"""tool_result with content=[] should produce empty string."""
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": []}]
|
||||
assert _flatten_tool_result_content(blocks) == ""
|
||||
|
||||
def test_empty(self):
|
||||
assert _flatten_tool_result_content([]) == ""
|
||||
|
||||
def test_nested_dict_without_text(self):
|
||||
"""Dict blocks without text key use json.dumps fallback."""
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "x",
|
||||
"content": [{"type": "image", "source": "data:..."}],
|
||||
}
|
||||
]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "image" in result # json.dumps fallback
|
||||
|
||||
def test_unknown_block_type_preserved_as_placeholder(self):
|
||||
blocks = [{"type": "image", "source": {"type": "base64", "data": "..."}}]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "[__image__]" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _transcript_to_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_entry(entry_type: str, role: str, content: str | list, **kwargs) -> str:
|
||||
"""Build a JSONL line for testing."""
|
||||
uid = str(uuid4())
|
||||
msg: dict = {"role": role, "content": content}
|
||||
msg.update(kwargs)
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": None,
|
||||
"message": msg,
|
||||
}
|
||||
return json.dumps(entry, separators=(",", ":"))
|
||||
|
||||
|
||||
class TestTranscriptToMessages:
|
||||
def test_basic_roundtrip(self):
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
assert messages[0] == {"role": "user", "content": "Hello"}
|
||||
assert messages[1] == {"role": "assistant", "content": "Hi"}
|
||||
|
||||
def test_skips_strippable_types(self):
|
||||
"""Progress and metadata entries are excluded."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "progress",
|
||||
"uuid": str(uuid4()),
|
||||
"parentUuid": None,
|
||||
"message": {"role": "assistant", "content": "..."},
|
||||
}
|
||||
),
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_empty_content(self):
|
||||
assert _transcript_to_messages("") == []
|
||||
|
||||
def test_tool_result_content(self):
|
||||
"""User entries with tool_result content blocks are flattened."""
|
||||
lines = [
|
||||
_make_entry(
|
||||
"user",
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "123",
|
||||
"content": "tool output",
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "tool output"
|
||||
|
||||
def test_malformed_json_lines_skipped(self):
|
||||
"""Malformed JSON lines in transcript are silently skipped."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
"this is not valid json",
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_empty_lines_skipped(self):
|
||||
"""Empty lines and whitespace-only lines are skipped."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
"",
|
||||
" ",
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_unicode_content_preserved(self):
|
||||
"""Unicode characters survive transcript roundtrip."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello 你好 🌍"),
|
||||
_make_entry(
|
||||
"assistant",
|
||||
"assistant",
|
||||
[{"type": "text", "text": "Bonjour 日本語 émojis 🎉"}],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert messages[0]["content"] == "Hello 你好 🌍"
|
||||
assert messages[1]["content"] == "Bonjour 日本語 émojis 🎉"
|
||||
|
||||
def test_entry_without_role_skipped(self):
|
||||
"""Entries with missing role in message are skipped."""
|
||||
entry_no_role = json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": str(uuid4()),
|
||||
"parentUuid": None,
|
||||
"message": {"content": "no role here"},
|
||||
}
|
||||
)
|
||||
lines = [
|
||||
entry_no_role,
|
||||
_make_entry("user", "user", "Hello"),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "Hello"
|
||||
|
||||
def test_tool_use_and_result_pairs(self):
|
||||
"""Tool use + tool result pairs are properly flattened."""
|
||||
lines = [
|
||||
_make_entry(
|
||||
"assistant",
|
||||
"assistant",
|
||||
[
|
||||
{"type": "text", "text": "Let me check."},
|
||||
{"type": "tool_use", "name": "read_file", "input": {"path": "/x"}},
|
||||
],
|
||||
),
|
||||
_make_entry(
|
||||
"user",
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "abc",
|
||||
"content": [{"type": "text", "text": "file contents"}],
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
assert "Let me check." in messages[0]["content"]
|
||||
assert "[tool_use: read_file]" in messages[0]["content"]
|
||||
assert messages[1]["content"] == "file contents"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _messages_to_transcript
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessagesToTranscript:
|
||||
def test_produces_valid_jsonl(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = result.strip().split("\n")
|
||||
assert len(lines) == 2
|
||||
for line in lines:
|
||||
parsed = json.loads(line)
|
||||
assert "type" in parsed
|
||||
assert "uuid" in parsed
|
||||
assert "message" in parsed
|
||||
|
||||
def test_assistant_has_proper_structure(self):
|
||||
messages = [{"role": "assistant", "content": "Hello"}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["type"] == "assistant"
|
||||
msg = entry["message"]
|
||||
assert msg["role"] == "assistant"
|
||||
assert msg["type"] == "message"
|
||||
assert msg["stop_reason"] == "end_turn"
|
||||
assert isinstance(msg["content"], list)
|
||||
assert msg["content"][0]["type"] == "text"
|
||||
|
||||
def test_user_has_plain_content(self):
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["type"] == "user"
|
||||
assert entry["message"]["content"] == "Hi"
|
||||
|
||||
def test_parent_uuid_chain(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "A"},
|
||||
{"role": "assistant", "content": "B"},
|
||||
{"role": "user", "content": "C"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = result.strip().split("\n")
|
||||
entries = [json.loads(line) for line in lines]
|
||||
assert entries[0]["parentUuid"] == ""
|
||||
assert entries[1]["parentUuid"] == entries[0]["uuid"]
|
||||
assert entries[2]["parentUuid"] == entries[1]["uuid"]
|
||||
|
||||
def test_empty_messages(self):
|
||||
assert _messages_to_transcript([]) == ""
|
||||
|
||||
def test_output_is_valid_transcript(self):
|
||||
"""Output should pass validate_transcript if it has assistant entries."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
assert validate_transcript(result)
|
||||
|
||||
def test_roundtrip_to_messages(self):
|
||||
"""Messages → transcript → messages preserves structure."""
|
||||
original = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
transcript = _messages_to_transcript(original)
|
||||
restored = _transcript_to_messages(transcript)
|
||||
assert len(restored) == len(original)
|
||||
for orig, rest in zip(original, restored):
|
||||
assert orig["role"] == rest["role"]
|
||||
assert orig["content"] == rest["content"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compact_transcript
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactTranscript:
|
||||
@pytest.mark.asyncio
|
||||
async def test_too_few_messages_returns_none(self):
|
||||
"""compact_transcript returns None when transcript has < 2 messages."""
|
||||
transcript = _build_transcript([("user", "Hello")])
|
||||
with patch(
|
||||
"backend.copilot.config.ChatConfig",
|
||||
return_value=type(
|
||||
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
|
||||
)(),
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_not_compacted(self):
|
||||
"""When compress_context says no compaction needed, returns None.
|
||||
The compressor couldn't reduce it, so retrying with the same
|
||||
content would fail identically."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi there"),
|
||||
]
|
||||
)
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": False,
|
||||
"messages": [],
|
||||
"original_token_count": 100,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 0,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.config.ChatConfig",
|
||||
return_value=type(
|
||||
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
),
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_compacted_transcript(self):
|
||||
"""When compaction succeeds, returns a valid compacted transcript."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi"),
|
||||
("user", "More"),
|
||||
("assistant", "Details"),
|
||||
]
|
||||
)
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "[summary]"},
|
||||
{"role": "assistant", "content": "Summarized response"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 500,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 2,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.config.ChatConfig",
|
||||
return_value=type(
|
||||
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
),
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[1]["content"] == "Summarized response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_compression_failure(self):
|
||||
"""When _run_compression raises, returns None."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi"),
|
||||
]
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.config.ChatConfig",
|
||||
return_value=type(
|
||||
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("LLM unavailable"),
|
||||
),
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_prompt_too_long
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsPromptTooLong:
|
||||
"""Unit tests for _is_prompt_too_long pattern matching."""
|
||||
|
||||
def test_prompt_is_too_long(self):
|
||||
err = RuntimeError("prompt is too long for model context")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_request_too_large(self):
|
||||
err = Exception("request too large: 250000 tokens")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_maximum_context_length(self):
|
||||
err = ValueError("maximum context length exceeded")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_context_length_exceeded(self):
|
||||
err = Exception("context_length_exceeded")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_input_tokens_exceed(self):
|
||||
err = Exception("input tokens exceed the max_tokens limit")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_input_is_too_long(self):
|
||||
err = Exception("input is too long for the model")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_content_length_exceeds(self):
|
||||
err = Exception("content length exceeds maximum")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_unrelated_error_returns_false(self):
|
||||
err = RuntimeError("network timeout")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
def test_auth_error_returns_false(self):
|
||||
err = Exception("authentication failed: invalid API key")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
def test_chained_exception_detected(self):
|
||||
"""Prompt-too-long error wrapped in another exception is detected."""
|
||||
inner = RuntimeError("prompt is too long")
|
||||
outer = Exception("SDK error")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_case_insensitive(self):
|
||||
err = Exception("PROMPT IS TOO LONG")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_old_max_tokens_exceeded_not_matched(self):
|
||||
"""The old broad 'max_tokens_exceeded' pattern was removed.
|
||||
Only 'input tokens exceed' should match now."""
|
||||
err = Exception("max_tokens_exceeded")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
@@ -226,7 +226,7 @@ class SDKResponseAdapter:
|
||||
responses.append(StreamFinish())
|
||||
|
||||
else:
|
||||
logger.debug("Unhandled SDK message type: %s", type(sdk_message).__name__)
|
||||
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
|
||||
|
||||
return responses
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -52,7 +52,7 @@ def _validate_workspace_path(
|
||||
if is_allowed_local_path(path, sdk_cwd):
|
||||
return {}
|
||||
|
||||
logger.warning("Blocked %s outside workspace: %s", tool_name, path)
|
||||
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
|
||||
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
|
||||
@@ -71,7 +71,7 @@ def _validate_tool_access(
|
||||
"""
|
||||
# Block forbidden tools
|
||||
if tool_name in BLOCKED_TOOLS:
|
||||
logger.warning("Blocked tool access attempt: %s", tool_name)
|
||||
logger.warning(f"Blocked tool access attempt: {tool_name}")
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
|
||||
"This is enforced by the platform and cannot be bypassed. "
|
||||
@@ -89,9 +89,7 @@ def _validate_tool_access(
|
||||
for pattern in DANGEROUS_PATTERNS:
|
||||
if re.search(pattern, input_str, re.IGNORECASE):
|
||||
logger.warning(
|
||||
"Blocked dangerous pattern in tool input: %s in %s",
|
||||
pattern,
|
||||
tool_name,
|
||||
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
|
||||
)
|
||||
return _deny(
|
||||
"[SECURITY] Input contains a blocked pattern. "
|
||||
@@ -113,9 +111,7 @@ def _validate_user_isolation(
|
||||
# the tool itself via _validate_ephemeral_path.
|
||||
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
||||
if path and ".." in path:
|
||||
logger.warning(
|
||||
"Blocked path traversal attempt: %s by user %s", path, user_id
|
||||
)
|
||||
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
@@ -174,7 +170,7 @@ def create_security_hooks(
|
||||
# Block background task execution first — denied calls
|
||||
# should not consume a subtask slot.
|
||||
if tool_input.get("run_in_background"):
|
||||
logger.info("[SDK] Blocked background Task, user=%s", user_id)
|
||||
logger.info(f"[SDK] Blocked background Task, user={user_id}")
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
_deny(
|
||||
@@ -185,9 +181,7 @@ def create_security_hooks(
|
||||
)
|
||||
if len(task_tool_use_ids) >= max_subtasks:
|
||||
logger.warning(
|
||||
"[SDK] Task limit reached (%d), user=%s",
|
||||
max_subtasks,
|
||||
user_id,
|
||||
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
|
||||
)
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
@@ -218,7 +212,7 @@ def create_security_hooks(
|
||||
if tool_name == "Task" and tool_use_id is not None:
|
||||
task_tool_use_ids.add(tool_use_id)
|
||||
|
||||
logger.debug("[SDK] Tool start: %s, user=%s", tool_name, user_id)
|
||||
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
|
||||
@@ -288,11 +282,8 @@ def create_security_hooks(
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
error = input_data.get("error", "Unknown error")
|
||||
logger.warning(
|
||||
"[SDK] Tool failed: %s, error=%s, user=%s, tool_use_id=%s",
|
||||
tool_name,
|
||||
str(error).replace("\n", "").replace("\r", ""),
|
||||
user_id,
|
||||
tool_use_id,
|
||||
f"[SDK] Tool failed: {tool_name}, error={error}, "
|
||||
f"user={user_id}, tool_use_id={tool_use_id}"
|
||||
)
|
||||
|
||||
_release_task_slot(tool_name, tool_use_id)
|
||||
@@ -310,19 +301,16 @@ def create_security_hooks(
|
||||
This hook provides visibility into when compaction happens.
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
trigger = input_data.get("trigger", "auto")
|
||||
# Sanitize untrusted input before logging to prevent log injection
|
||||
trigger = (
|
||||
str(input_data.get("trigger", "auto"))
|
||||
.replace("\n", "")
|
||||
.replace("\r", "")
|
||||
)
|
||||
transcript_path = (
|
||||
str(input_data.get("transcript_path", ""))
|
||||
.replace("\n", "")
|
||||
.replace("\r", "")
|
||||
)
|
||||
logger.info(
|
||||
"[SDK] Context compaction triggered: %s, user=%s, transcript_path=%s",
|
||||
"[SDK] Context compaction triggered: %s, user=%s, "
|
||||
"transcript_path=%s",
|
||||
trigger,
|
||||
user_id,
|
||||
transcript_path,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,283 +0,0 @@
|
||||
"""Unit tests for extracted service helpers.
|
||||
|
||||
Covers ``_is_prompt_too_long``, ``_reduce_context``, ``_iter_sdk_messages``,
|
||||
and the ``ReducedContext`` named tuple.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import (
|
||||
ReducedContext,
|
||||
_is_prompt_too_long,
|
||||
_iter_sdk_messages,
|
||||
_reduce_context,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_prompt_too_long
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsPromptTooLong:
|
||||
def test_direct_match(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("prompt is too long")) is True
|
||||
|
||||
def test_case_insensitive(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("PROMPT IS TOO LONG")) is True
|
||||
|
||||
def test_no_match(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("network timeout")) is False
|
||||
|
||||
def test_request_too_large(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("request too large for model")) is True
|
||||
|
||||
def test_context_length_exceeded(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("context_length_exceeded")) is True
|
||||
|
||||
def test_max_tokens_exceeded_not_matched(self) -> None:
|
||||
"""'max_tokens_exceeded' is intentionally excluded (too broad)."""
|
||||
assert _is_prompt_too_long(Exception("max_tokens_exceeded")) is False
|
||||
|
||||
def test_max_tokens_config_error_no_match(self) -> None:
|
||||
"""'max_tokens must be at least 1' should NOT match."""
|
||||
assert _is_prompt_too_long(Exception("max_tokens must be at least 1")) is False
|
||||
|
||||
def test_chained_cause(self) -> None:
|
||||
inner = Exception("prompt is too long")
|
||||
outer = RuntimeError("SDK error")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_chained_context(self) -> None:
|
||||
inner = Exception("request too large")
|
||||
outer = RuntimeError("wrapped")
|
||||
outer.__context__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_deep_chain(self) -> None:
|
||||
bottom = Exception("maximum context length")
|
||||
middle = RuntimeError("middle")
|
||||
middle.__cause__ = bottom
|
||||
top = ValueError("top")
|
||||
top.__cause__ = middle
|
||||
assert _is_prompt_too_long(top) is True
|
||||
|
||||
def test_chain_no_match(self) -> None:
|
||||
inner = Exception("rate limit exceeded")
|
||||
outer = RuntimeError("wrapped")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is False
|
||||
|
||||
def test_cycle_detection(self) -> None:
|
||||
"""Exception chain with a cycle should not infinite-loop."""
|
||||
a = Exception("error a")
|
||||
b = Exception("error b")
|
||||
a.__cause__ = b
|
||||
b.__cause__ = a # cycle
|
||||
assert _is_prompt_too_long(a) is False
|
||||
|
||||
def test_all_patterns(self) -> None:
|
||||
patterns = [
|
||||
"prompt is too long",
|
||||
"request too large",
|
||||
"maximum context length",
|
||||
"context_length_exceeded",
|
||||
"input tokens exceed",
|
||||
"input is too long",
|
||||
"content length exceeds",
|
||||
]
|
||||
for pattern in patterns:
|
||||
assert _is_prompt_too_long(Exception(pattern)) is True, pattern
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _reduce_context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReduceContext:
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_retry_compaction_success(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=compacted,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.validate_transcript",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.write_transcript_to_tempfile",
|
||||
return_value="/tmp/resume.jsonl",
|
||||
),
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert isinstance(ctx, ReducedContext)
|
||||
assert ctx.use_resume is True
|
||||
assert ctx.resume_file == "/tmp/resume.jsonl"
|
||||
assert ctx.transcript_lost is False
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_fails_drops_transcript(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.resume_file is None
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_tried_compaction_skips(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
ctx = await _reduce_context(transcript, True, "sess-123", "/tmp/cwd", "[test]")
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_transcript_drops(self) -> None:
|
||||
ctx = await _reduce_context("", False, "sess-123", "/tmp/cwd", "[test]")
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_returns_same_content_drops(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=transcript, # same content
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_tempfile_fails_drops(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=compacted,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.validate_transcript",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.write_transcript_to_tempfile",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _iter_sdk_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIterSdkMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_yields_messages(self) -> None:
|
||||
messages = ["msg1", "msg2", "msg3"]
|
||||
client = AsyncMock()
|
||||
|
||||
async def _fake_receive() -> AsyncGenerator[str]:
|
||||
for m in messages:
|
||||
yield m
|
||||
|
||||
client.receive_response = _fake_receive
|
||||
result = [msg async for msg in _iter_sdk_messages(client)]
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heartbeat_on_timeout(self) -> None:
|
||||
"""Yields None when asyncio.wait times out."""
|
||||
client = AsyncMock()
|
||||
received: list = []
|
||||
|
||||
async def _slow_receive() -> AsyncGenerator[str]:
|
||||
await asyncio.sleep(100) # never completes
|
||||
yield "never" # pragma: no cover — unreachable, yield makes this an async generator
|
||||
|
||||
client.receive_response = _slow_receive
|
||||
|
||||
with patch("backend.copilot.sdk.service._HEARTBEAT_INTERVAL", 0.01):
|
||||
count = 0
|
||||
async for msg in _iter_sdk_messages(client):
|
||||
received.append(msg)
|
||||
count += 1
|
||||
if count >= 3:
|
||||
break
|
||||
|
||||
assert all(m is None for m in received)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_propagates(self) -> None:
|
||||
client = AsyncMock()
|
||||
|
||||
async def _error_receive() -> AsyncGenerator[str]:
|
||||
raise RuntimeError("SDK crash")
|
||||
yield # pragma: no cover — unreachable, yield makes this an async generator
|
||||
|
||||
client.receive_response = _error_receive
|
||||
|
||||
with pytest.raises(RuntimeError, match="SDK crash"):
|
||||
async for _ in _iter_sdk_messages(client):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_cleanup_on_break(self) -> None:
|
||||
"""Pending task is cancelled when generator is closed."""
|
||||
client = AsyncMock()
|
||||
|
||||
async def _slow_receive() -> AsyncGenerator[str]:
|
||||
yield "first"
|
||||
await asyncio.sleep(100)
|
||||
yield "second"
|
||||
|
||||
client.receive_response = _slow_receive
|
||||
|
||||
gen = _iter_sdk_messages(client)
|
||||
first = await gen.__anext__()
|
||||
assert first == "first"
|
||||
await gen.aclose() # should cancel pending task cleanly
|
||||
@@ -20,7 +20,7 @@ class _FakeFileInfo:
|
||||
size_bytes: int
|
||||
|
||||
|
||||
_PATCH_TARGET = "backend.copilot.sdk.service.get_workspace_manager"
|
||||
_PATCH_TARGET = "backend.copilot.sdk.service.get_manager"
|
||||
|
||||
|
||||
class TestPrepareFileAttachments:
|
||||
|
||||
@@ -234,9 +234,7 @@ def create_tool_handler(base_tool: BaseTool):
|
||||
try:
|
||||
return await _execute_tool_sync(base_tool, user_id, session, args)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
|
||||
)
|
||||
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
||||
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
|
||||
|
||||
return tool_handler
|
||||
@@ -349,7 +347,7 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
:func:`get_sdk_disallowed_tools`.
|
||||
"""
|
||||
|
||||
def _truncating(fn, tool_name: str, input_schema: dict[str, Any] | None = None):
|
||||
def _truncating(fn, tool_name: str):
|
||||
"""Wrap a tool handler so its response is truncated to stay under the
|
||||
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
|
||||
response adapter before the SDK can apply its own head-truncation.
|
||||
@@ -363,9 +361,7 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
user_id, session = get_execution_context()
|
||||
if session is not None:
|
||||
try:
|
||||
args = await expand_file_refs_in_args(
|
||||
args, user_id, session, input_schema=input_schema
|
||||
)
|
||||
args = await expand_file_refs_in_args(args, user_id, session)
|
||||
except FileRefExpansionError as exc:
|
||||
return _mcp_error(
|
||||
f"@@agptfile: reference could not be resolved: {exc}. "
|
||||
@@ -393,12 +389,11 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
handler = create_tool_handler(base_tool)
|
||||
schema = _build_input_schema(base_tool)
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
base_tool.description,
|
||||
schema,
|
||||
)(_truncating(handler, tool_name, input_schema=schema))
|
||||
_build_input_schema(base_tool),
|
||||
)(_truncating(handler, tool_name))
|
||||
sdk_tools.append(decorated)
|
||||
|
||||
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
|
||||
|
||||
@@ -10,9 +10,6 @@ Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
||||
filesystem for self-hosted) — no DB column needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -20,12 +17,8 @@ import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.clients import get_openai_client
|
||||
from backend.util.prompt import CompressResult, compress_context
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -106,14 +99,7 @@ def strip_progress_entries(content: str) -> str:
|
||||
continue
|
||||
parent = entry.get("parentUuid", "")
|
||||
original_parent = parent
|
||||
# seen_parents is local per-entry (not shared across iterations) so
|
||||
# it can only detect cycles within a single ancestry walk, not across
|
||||
# entries. This is intentional: each entry's parent chain is
|
||||
# independent, and reusing a global set would incorrectly short-circuit
|
||||
# valid re-use of the same UUID as a parent in different subtrees.
|
||||
seen_parents: set[str] = set()
|
||||
while parent in stripped_uuids and parent not in seen_parents:
|
||||
seen_parents.add(parent)
|
||||
while parent in stripped_uuids:
|
||||
parent = uuid_to_parent.get(parent, "")
|
||||
if parent != original_parent:
|
||||
entry["parentUuid"] = parent
|
||||
@@ -341,7 +327,7 @@ def write_transcript_to_tempfile(
|
||||
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
|
||||
real_cwd = os.path.realpath(cwd)
|
||||
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
|
||||
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
|
||||
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -351,17 +337,17 @@ def write_transcript_to_tempfile(
|
||||
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
|
||||
)
|
||||
if not jsonl_path.startswith(real_cwd):
|
||||
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
|
||||
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
|
||||
return None
|
||||
|
||||
with open(jsonl_path, "w") as f:
|
||||
f.write(transcript_content)
|
||||
|
||||
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
|
||||
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
|
||||
return jsonl_path
|
||||
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to write resume file: %s", e)
|
||||
logger.warning(f"[Transcript] Failed to write resume file: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -422,6 +408,8 @@ def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, s
|
||||
|
||||
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
|
||||
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage
|
||||
|
||||
wid, fid, fname = parts
|
||||
if isinstance(backend, GCSWorkspaceStorage):
|
||||
blob = f"workspaces/{wid}/{fid}/{fname}"
|
||||
@@ -460,15 +448,17 @@ async def upload_transcript(
|
||||
content: Complete JSONL transcript (from TranscriptBuilder).
|
||||
message_count: ``len(session.messages)`` at upload time.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
# Strip metadata entries (progress, file-history-snapshot, etc.)
|
||||
# Note: SDK-built transcripts shouldn't have these, but strip for safety
|
||||
stripped = strip_progress_entries(content)
|
||||
if not validate_transcript(stripped):
|
||||
# Log entry types for debugging — helps identify why validation failed
|
||||
entry_types = [
|
||||
json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
|
||||
for line in stripped.strip().split("\n")
|
||||
]
|
||||
entry_types: list[str] = []
|
||||
for line in stripped.strip().split("\n"):
|
||||
entry = json.loads(line, fallback={"type": "INVALID_JSON"})
|
||||
entry_types.append(entry.get("type", "?"))
|
||||
logger.warning(
|
||||
"%s Skipping upload — stripped content not valid "
|
||||
"(types=%s, stripped_len=%d, raw_len=%d)",
|
||||
@@ -504,14 +494,11 @@ async def upload_transcript(
|
||||
content=json.dumps(meta).encode("utf-8"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
|
||||
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
|
||||
|
||||
logger.info(
|
||||
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
|
||||
log_prefix,
|
||||
len(encoded),
|
||||
len(content),
|
||||
message_count,
|
||||
f"{log_prefix} Uploaded {len(encoded)}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count})"
|
||||
)
|
||||
|
||||
|
||||
@@ -525,6 +512,8 @@ async def download_transcript(
|
||||
Returns a ``TranscriptDownload`` with the JSONL content and the
|
||||
``message_count`` watermark from the upload, or ``None`` if not found.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
@@ -532,10 +521,10 @@ async def download_transcript(
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
logger.debug("%s No transcript in storage", log_prefix)
|
||||
logger.debug(f"{log_prefix} No transcript in storage")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
|
||||
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
|
||||
return None
|
||||
|
||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||
@@ -547,14 +536,10 @@ async def download_transcript(
|
||||
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
except FileNotFoundError:
|
||||
except (FileNotFoundError, Exception):
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
except Exception as e:
|
||||
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
|
||||
|
||||
logger.info(
|
||||
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
|
||||
)
|
||||
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
@@ -568,6 +553,8 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
|
||||
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
@@ -584,280 +571,3 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
logger.info("[Transcript] Deleted metadata for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete metadata: %s", e)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transcript compaction — LLM summarization for prompt-too-long recovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# JSONL protocol values used in transcript serialization.
|
||||
STOP_REASON_END_TURN = "end_turn"
|
||||
COMPACT_MSG_ID_PREFIX = "msg_compact_"
|
||||
ENTRY_TYPE_MESSAGE = "message"
|
||||
|
||||
|
||||
def _flatten_assistant_content(blocks: list) -> str:
|
||||
"""Flatten assistant content blocks into a single plain-text string.
|
||||
|
||||
Structured ``tool_use`` blocks are converted to ``[tool_use: name]``
|
||||
placeholders. This is intentional: ``compress_context`` requires plain
|
||||
text for token counting and LLM summarization. The structural loss is
|
||||
acceptable because compaction only runs when the original transcript was
|
||||
already too large for the model — a summarized plain-text version is
|
||||
better than no context at all.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict):
|
||||
btype = block.get("type", "")
|
||||
if btype == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif btype == "tool_use":
|
||||
parts.append(f"[tool_use: {block.get('name', '?')}]")
|
||||
else:
|
||||
# Preserve non-text blocks (e.g. image) as placeholders.
|
||||
# Use __prefix__ to distinguish from literal user text.
|
||||
parts.append(f"[__{btype}__]")
|
||||
elif isinstance(block, str):
|
||||
parts.append(block)
|
||||
return "\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
def _flatten_tool_result_content(blocks: list) -> str:
|
||||
"""Flatten tool_result and other content blocks into plain text.
|
||||
|
||||
Handles nested tool_result structures, text blocks, and raw strings.
|
||||
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
|
||||
or where ``text`` is ``None``.
|
||||
|
||||
Like ``_flatten_assistant_content``, structured blocks (images, nested
|
||||
tool results) are reduced to text representations for compression.
|
||||
"""
|
||||
str_parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
inner = block.get("content") or ""
|
||||
if isinstance(inner, list):
|
||||
for sub in inner:
|
||||
if isinstance(sub, dict):
|
||||
sub_type = sub.get("type")
|
||||
if sub_type in ("image", "document"):
|
||||
# Avoid serializing base64 binary data into
|
||||
# the compaction input — use a placeholder.
|
||||
str_parts.append(f"[__{sub_type}__]")
|
||||
elif sub_type == "text" or sub.get("text") is not None:
|
||||
str_parts.append(str(sub.get("text", "")))
|
||||
else:
|
||||
str_parts.append(json.dumps(sub))
|
||||
else:
|
||||
str_parts.append(str(sub))
|
||||
else:
|
||||
str_parts.append(str(inner))
|
||||
elif isinstance(block, dict) and block.get("type") == "text":
|
||||
str_parts.append(str(block.get("text", "")))
|
||||
elif isinstance(block, dict):
|
||||
# Preserve non-text/non-tool_result blocks (e.g. image) as placeholders.
|
||||
# Use __prefix__ to distinguish from literal user text.
|
||||
btype = block.get("type", "unknown")
|
||||
str_parts.append(f"[__{btype}__]")
|
||||
elif isinstance(block, str):
|
||||
str_parts.append(block)
|
||||
return "\n".join(str_parts) if str_parts else ""
|
||||
|
||||
|
||||
def _transcript_to_messages(content: str) -> list[dict]:
|
||||
"""Convert JSONL transcript entries to plain message dicts for compression.
|
||||
|
||||
Parses each line of the JSONL *content*, skips strippable metadata entries
|
||||
(progress, file-history-snapshot, etc.), and extracts the ``role`` and
|
||||
flattened ``content`` from the ``message`` field of each remaining entry.
|
||||
|
||||
Structured content blocks (``tool_use``, ``tool_result``, images) are
|
||||
flattened to plain text via ``_flatten_assistant_content`` and
|
||||
``_flatten_tool_result_content`` so that ``compress_context`` can
|
||||
perform token counting and LLM summarization on uniform strings.
|
||||
|
||||
Returns:
|
||||
A list of ``{"role": str, "content": str}`` dicts suitable for
|
||||
``compress_context``.
|
||||
"""
|
||||
messages: list[dict] = []
|
||||
for line in content.strip().split("\n"):
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
|
||||
"isCompactSummary"
|
||||
):
|
||||
continue
|
||||
msg = entry.get("message", {})
|
||||
role = msg.get("role", "")
|
||||
if not role:
|
||||
continue
|
||||
msg_dict: dict = {"role": role}
|
||||
raw_content = msg.get("content")
|
||||
if role == "assistant" and isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_assistant_content(raw_content)
|
||||
elif isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_tool_result_content(raw_content)
|
||||
else:
|
||||
msg_dict["content"] = raw_content or ""
|
||||
messages.append(msg_dict)
|
||||
return messages
|
||||
|
||||
|
||||
def _messages_to_transcript(messages: list[dict]) -> str:
|
||||
"""Convert compressed message dicts back to JSONL transcript format.
|
||||
|
||||
Rebuilds a minimal JSONL transcript from the ``{"role", "content"}``
|
||||
dicts returned by ``compress_context``. Each message becomes one JSONL
|
||||
line with a fresh ``uuid`` / ``parentUuid`` chain so the CLI's
|
||||
``--resume`` flag can reconstruct a valid conversation tree.
|
||||
|
||||
Assistant messages are wrapped in the full ``message`` envelope
|
||||
(``id``, ``model``, ``stop_reason``, structured ``content`` blocks)
|
||||
that the CLI expects. User messages use the simpler ``{role, content}``
|
||||
form.
|
||||
|
||||
Returns:
|
||||
A newline-terminated JSONL string, or an empty string if *messages*
|
||||
is empty.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
last_uuid: str = "" # root entry uses empty string, not null
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
uid = str(uuid4())
|
||||
content = msg.get("content", "")
|
||||
if role == "assistant":
|
||||
message: dict = {
|
||||
"role": "assistant",
|
||||
"model": "",
|
||||
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
|
||||
"type": ENTRY_TYPE_MESSAGE,
|
||||
"content": [{"type": "text", "text": content}] if content else [],
|
||||
"stop_reason": STOP_REASON_END_TURN,
|
||||
"stop_sequence": None,
|
||||
}
|
||||
else:
|
||||
message = {"role": role, "content": content}
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": last_uuid,
|
||||
"message": message,
|
||||
}
|
||||
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
last_uuid = uid
|
||||
return "\n".join(lines) + "\n" if lines else ""
|
||||
|
||||
|
||||
_COMPACTION_TIMEOUT_SECONDS = 60
|
||||
_TRUNCATION_TIMEOUT_SECONDS = 30
|
||||
|
||||
|
||||
async def _run_compression(
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
log_prefix: str,
|
||||
) -> CompressResult:
|
||||
"""Run LLM-based compression with truncation fallback.
|
||||
|
||||
Uses the shared OpenAI client from ``get_openai_client()``.
|
||||
If no client is configured or the LLM call fails, falls back to
|
||||
truncation-based compression which drops older messages without
|
||||
summarization.
|
||||
|
||||
A 60-second timeout prevents a hung LLM call from blocking the
|
||||
retry path indefinitely. The truncation fallback also has a
|
||||
30-second timeout to guard against slow tokenization on very large
|
||||
transcripts.
|
||||
"""
|
||||
client = get_openai_client()
|
||||
if client is None:
|
||||
logger.warning("%s No OpenAI client configured, using truncation", log_prefix)
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=None),
|
||||
timeout=_TRUNCATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=client),
|
||||
timeout=_COMPACTION_TIMEOUT_SECONDS,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=None),
|
||||
timeout=_TRUNCATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
async def compact_transcript(
|
||||
content: str,
|
||||
*,
|
||||
model: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> str | None:
|
||||
"""Compact an oversized JSONL transcript using LLM summarization.
|
||||
|
||||
Converts transcript entries to plain messages, runs ``compress_context``
|
||||
(the same compressor used for pre-query history), and rebuilds JSONL.
|
||||
|
||||
Structured content (``tool_use`` blocks, ``tool_result`` nesting, images)
|
||||
is flattened to plain text for compression. This matches the fidelity of
|
||||
the Plan C (DB compression) fallback path, where
|
||||
``_format_conversation_context`` similarly renders tool calls as
|
||||
``You called tool: name(args)`` and results as ``Tool result: ...``.
|
||||
Neither path preserves structured API content blocks — the compacted
|
||||
context serves as text history for the LLM, which creates proper
|
||||
structured tool calls going forward.
|
||||
|
||||
Images are per-turn attachments loaded from workspace storage by file ID
|
||||
(via ``_prepare_file_attachments``), not part of the conversation history.
|
||||
They are re-attached each turn and are unaffected by compaction.
|
||||
|
||||
Returns the compacted JSONL string, or ``None`` on failure.
|
||||
|
||||
See also:
|
||||
``_compress_messages`` in ``service.py`` — compresses ``ChatMessage``
|
||||
lists for pre-query DB history. Both share ``compress_context()``
|
||||
but operate on different input formats (JSONL transcript entries
|
||||
here vs. ChatMessage dicts there).
|
||||
"""
|
||||
messages = _transcript_to_messages(content)
|
||||
if len(messages) < 2:
|
||||
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
|
||||
return None
|
||||
try:
|
||||
result = await _run_compression(messages, model, log_prefix)
|
||||
if not result.was_compacted:
|
||||
# Compressor says it's within budget, but the SDK rejected it.
|
||||
# Return None so the caller falls through to DB fallback.
|
||||
logger.warning(
|
||||
"%s Compressor reports within budget but SDK rejected — "
|
||||
"signalling failure",
|
||||
log_prefix,
|
||||
)
|
||||
return None
|
||||
logger.info(
|
||||
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
|
||||
log_prefix,
|
||||
result.original_token_count,
|
||||
result.token_count,
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
compacted = _messages_to_transcript(result.messages)
|
||||
if not validate_transcript(compacted):
|
||||
logger.warning("%s Compacted transcript failed validation", log_prefix)
|
||||
return None
|
||||
return compacted
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -68,7 +68,7 @@ class TranscriptBuilder:
|
||||
type=entry_type,
|
||||
uuid=data.get("uuid") or str(uuid4()),
|
||||
parentUuid=data.get("parentUuid"),
|
||||
isCompactSummary=data.get("isCompactSummary"),
|
||||
isCompactSummary=data.get("isCompactSummary") or None,
|
||||
message=data.get("message", {}),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Unit tests for JSONL transcript management utilities."""
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -382,7 +382,7 @@ class TestDeleteTranscript:
|
||||
mock_storage.delete = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -402,7 +402,7 @@ class TestDeleteTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -420,7 +420,7 @@ class TestDeleteTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -897,134 +897,3 @@ class TestCompactionFlowIntegration:
|
||||
output2 = builder2.to_jsonl()
|
||||
lines2 = [json.loads(line) for line in output2.strip().split("\n")]
|
||||
assert lines2[-1]["parentUuid"] == "a2"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_compression (direct tests for the 3 code paths)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunCompression:
|
||||
"""Direct tests for ``_run_compression`` covering all 3 code paths.
|
||||
|
||||
Paths:
|
||||
(a) No OpenAI client configured → truncation fallback immediately.
|
||||
(b) LLM success → returns LLM-compressed result.
|
||||
(c) LLM call raises → truncation fallback.
|
||||
"""
|
||||
|
||||
def _make_compress_result(self, was_compacted: bool, msgs=None):
|
||||
"""Build a minimal CompressResult-like object."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
return SimpleNamespace(
|
||||
was_compacted=was_compacted,
|
||||
messages=msgs or [{"role": "user", "content": "summary"}],
|
||||
original_token_count=500,
|
||||
token_count=100 if was_compacted else 500,
|
||||
messages_summarized=2 if was_compacted else 0,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_client_uses_truncation(self):
|
||||
"""Path (a): ``get_openai_client()`` returns None → truncation only."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated"}]
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=truncation_result,
|
||||
) as mock_compress,
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "hello"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called with client=None (truncation mode)
|
||||
call_kwargs = mock_compress.call_args
|
||||
assert (
|
||||
call_kwargs.kwargs.get("client") is None
|
||||
or (call_kwargs.args and call_kwargs.args[2] is None)
|
||||
or mock_compress.call_args[1].get("client") is None
|
||||
)
|
||||
assert result is truncation_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_success_returns_llm_result(self):
|
||||
"""Path (b): ``get_openai_client()`` returns a client → LLM compresses."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
llm_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "LLM summary"}]
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=llm_result,
|
||||
) as mock_compress,
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "long conversation"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called with the real client
|
||||
assert mock_compress.called
|
||||
assert result is llm_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_failure_falls_back_to_truncation(self):
|
||||
"""Path (c): LLM call raises → truncation fallback used instead."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated fallback"}]
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
call_count = [0]
|
||||
|
||||
async def _compress_side_effect(**kwargs):
|
||||
call_count[0] += 1
|
||||
if kwargs.get("client") is not None:
|
||||
raise RuntimeError("LLM timeout")
|
||||
return truncation_result
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
side_effect=_compress_side_effect,
|
||||
),
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "long conversation"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called twice: once for LLM (raises), once for truncation
|
||||
assert call_count[0] == 2
|
||||
assert result is truncation_result
|
||||
|
||||
@@ -32,7 +32,6 @@ import shutil
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.context import get_workspace_manager
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.util.request import validate_url_host
|
||||
|
||||
@@ -44,6 +43,7 @@ from .models import (
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .workspace_files import get_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -194,7 +194,7 @@ async def _save_browser_state(
|
||||
),
|
||||
}
|
||||
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
await manager.write_file(
|
||||
content=json.dumps(state).encode("utf-8"),
|
||||
filename=_STATE_FILENAME,
|
||||
@@ -218,7 +218,7 @@ async def _restore_browser_state(
|
||||
Returns True on success (or no state to restore), False on failure.
|
||||
"""
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
|
||||
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
||||
if file_info is None:
|
||||
@@ -360,7 +360,7 @@ async def close_browser_session(session_name: str, user_id: str | None = None) -
|
||||
# Delete persisted browser state (cookies, localStorage) from workspace.
|
||||
if user_id:
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_name)
|
||||
manager = await get_manager(user_id, session_name)
|
||||
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
||||
if file_info is not None:
|
||||
await manager.delete_file(file_info.id)
|
||||
|
||||
@@ -897,7 +897,7 @@ class TestHasLocalSession:
|
||||
# _save_browser_state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_workspace_manager"
|
||||
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_manager"
|
||||
|
||||
|
||||
def _make_mock_manager():
|
||||
|
||||
@@ -41,7 +41,8 @@ import contextlib
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Literal
|
||||
|
||||
from e2b import AsyncSandbox, SandboxLifecycle
|
||||
from e2b import AsyncSandbox
|
||||
from e2b.sandbox.sandbox_api import SandboxLifecycle
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from backend.copilot.constants import (
|
||||
COPILOT_SESSION_PREFIX,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.sdk.file_ref import FileRefExpansionError, expand_file_refs_in_args
|
||||
from backend.data.db_accessors import review_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
@@ -198,29 +197,6 @@ class RunBlockTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Expand @@agptfile: refs in input_data with the block's input
|
||||
# schema. The generic _truncating wrapper skips opaque object
|
||||
# properties (input_data has no declared inner properties in the
|
||||
# tool schema), so file ref tokens are still intact here.
|
||||
# Using the block's schema lets us return raw text for string-typed
|
||||
# fields and parsed structures for list/dict-typed fields.
|
||||
if input_data:
|
||||
try:
|
||||
input_data = await expand_file_refs_in_args(
|
||||
input_data,
|
||||
user_id,
|
||||
session,
|
||||
input_schema=input_schema,
|
||||
)
|
||||
except FileRefExpansionError as exc:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Failed to resolve file reference: {exc}. "
|
||||
"Ensure the file exists before referencing it."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if missing_credentials:
|
||||
# Return setup requirements response with missing credentials
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
|
||||
@@ -10,11 +10,11 @@ from pydantic import BaseModel
|
||||
from backend.copilot.context import (
|
||||
E2B_WORKDIR,
|
||||
get_current_sandbox,
|
||||
get_workspace_manager,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.sandbox import make_session_path
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
@@ -218,6 +218,12 @@ def _is_text_mime(mime_type: str) -> bool:
|
||||
return any(mime_type.startswith(t) for t in _TEXT_MIME_PREFIXES)
|
||||
|
||||
|
||||
async def get_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped WorkspaceManager."""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
|
||||
async def _resolve_file(
|
||||
manager: WorkspaceManager,
|
||||
file_id: str | None,
|
||||
@@ -380,7 +386,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
files = await manager.list_files(
|
||||
path=path_prefix, limit=limit, include_all_sessions=include_all_sessions
|
||||
)
|
||||
@@ -530,7 +536,7 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
@@ -766,7 +772,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
|
||||
try:
|
||||
await scan_content_safe(content, filename=filename)
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
rec = await manager.write_file(
|
||||
content=content,
|
||||
filename=filename,
|
||||
@@ -893,7 +899,7 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
|
||||
@@ -275,12 +275,13 @@ async def store_media_file(
|
||||
# Process file
|
||||
elif file.startswith("data:"):
|
||||
# Data URI
|
||||
parsed_uri = parse_data_uri(file)
|
||||
if parsed_uri is None:
|
||||
match = re.match(r"^data:([^;]+);base64,(.*)$", file, re.DOTALL)
|
||||
if not match:
|
||||
raise ValueError(
|
||||
"Invalid data URI format. Expected data:<mime>;base64,<data>"
|
||||
)
|
||||
mime_type, b64_content = parsed_uri
|
||||
mime_type = match.group(1).strip().lower()
|
||||
b64_content = match.group(2).strip()
|
||||
|
||||
# Generate filename and decode
|
||||
extension = _extension_from_mime(mime_type)
|
||||
@@ -414,70 +415,13 @@ def get_dir_size(path: Path) -> int:
|
||||
return total
|
||||
|
||||
|
||||
async def resolve_media_content(
|
||||
content: MediaFileType,
|
||||
execution_context: "ExecutionContext",
|
||||
*,
|
||||
return_format: MediaReturnFormat,
|
||||
) -> MediaFileType:
|
||||
"""Resolve a ``MediaFileType`` value if it is a media reference, pass through otherwise.
|
||||
|
||||
Convenience wrapper around :func:`is_media_file_ref` + :func:`store_media_file`.
|
||||
Plain text content (source code, filenames) is returned unchanged. Media
|
||||
references (``data:``, ``workspace://``, ``http(s)://``) are resolved via
|
||||
:func:`store_media_file` using *return_format*.
|
||||
|
||||
Use this when a block field is typed as ``MediaFileType`` but may contain
|
||||
either literal text or a media reference.
|
||||
"""
|
||||
if not content or not is_media_file_ref(content):
|
||||
return content
|
||||
return await store_media_file(
|
||||
content, execution_context, return_format=return_format
|
||||
)
|
||||
|
||||
|
||||
def is_media_file_ref(value: str) -> bool:
|
||||
"""Return True if *value* looks like a ``MediaFileType`` reference.
|
||||
|
||||
Detects data URIs, workspace:// references, and HTTP(S) URLs — the
|
||||
formats accepted by :func:`store_media_file`. Plain text content
|
||||
(e.g. source code, filenames) returns False.
|
||||
|
||||
Known limitation: HTTP(S) URL detection is heuristic. Any string that
|
||||
starts with ``http://`` or ``https://`` is treated as a media URL, even
|
||||
if it appears as a URL inside source-code comments or documentation.
|
||||
Blocks that produce source code or Markdown as output may therefore
|
||||
trigger false positives. Callers that need higher precision should
|
||||
inspect the string further (e.g. verify the URL is reachable or has a
|
||||
media-friendly extension).
|
||||
|
||||
Note: this does *not* match local file paths, which are ambiguous
|
||||
(could be filenames or actual paths). Blocks that need to resolve
|
||||
local paths should check for them separately.
|
||||
"""
|
||||
return value.startswith(("data:", "workspace://", "http://", "https://"))
|
||||
|
||||
|
||||
def parse_data_uri(value: str) -> tuple[str, str] | None:
|
||||
"""Parse a ``data:<mime>;base64,<payload>`` URI.
|
||||
|
||||
Returns ``(mime_type, base64_payload)`` if *value* is a valid data URI,
|
||||
or ``None`` if it is not.
|
||||
"""
|
||||
match = re.match(r"^data:([^;]+);base64,(.*)$", value, re.DOTALL)
|
||||
if not match:
|
||||
return None
|
||||
return match.group(1).strip().lower(), match.group(2).strip()
|
||||
|
||||
|
||||
def get_mime_type(file: str) -> str:
|
||||
"""
|
||||
Get the MIME type of a file, whether it's a data URI, URL, or local path.
|
||||
"""
|
||||
if file.startswith("data:"):
|
||||
parsed_uri = parse_data_uri(file)
|
||||
return parsed_uri[0] if parsed_uri else "application/octet-stream"
|
||||
match = re.match(r"^data:([^;]+);base64,", file)
|
||||
return match.group(1) if match else "application/octet-stream"
|
||||
|
||||
elif file.startswith(("http://", "https://")):
|
||||
parsed_url = urlparse(file)
|
||||
|
||||
@@ -1,375 +0,0 @@
|
||||
"""Parse file content into structured Python objects based on file format.
|
||||
|
||||
Used by the ``@@agptfile:`` expansion system to eagerly parse well-known file
|
||||
formats into native Python types *before* schema-driven coercion runs. This
|
||||
lets blocks with ``Any``-typed inputs receive structured data rather than raw
|
||||
strings, while blocks expecting strings get the value coerced back via
|
||||
``convert()``.
|
||||
|
||||
Supported formats:
|
||||
|
||||
- **JSON** (``.json``) — arrays and objects are promoted; scalars stay as strings
|
||||
- **JSON Lines** (``.jsonl``, ``.ndjson``) — each non-empty line parsed as JSON;
|
||||
when all lines are dicts with the same keys (tabular data), output is
|
||||
``list[list[Any]]`` with a header row, consistent with CSV/Parquet/Excel;
|
||||
otherwise returns a plain ``list`` of parsed values
|
||||
- **CSV** (``.csv``) — ``csv.reader`` → ``list[list[str]]``
|
||||
- **TSV** (``.tsv``) — tab-delimited → ``list[list[str]]``
|
||||
- **YAML** (``.yaml``, ``.yml``) — parsed via PyYAML; containers only
|
||||
- **TOML** (``.toml``) — parsed via stdlib ``tomllib``
|
||||
- **Parquet** (``.parquet``) — via pandas/pyarrow → ``list[list[Any]]`` with header row
|
||||
- **Excel** (``.xlsx``) — via pandas/openpyxl → ``list[list[Any]]`` with header row
|
||||
(legacy ``.xls`` is **not** supported — only the modern OOXML format)
|
||||
|
||||
The **fallback contract** is enforced by :func:`parse_file_content`, not by
|
||||
individual parser functions. If any parser raises, ``parse_file_content``
|
||||
catches the exception and returns the original content unchanged (string for
|
||||
text formats, bytes for binary formats). Callers should never see an
|
||||
exception from the public API when ``strict=False``.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import tomllib
|
||||
import zipfile
|
||||
from collections.abc import Callable
|
||||
|
||||
# posixpath.splitext handles forward-slash URI paths correctly on all platforms,
|
||||
# unlike os.path.splitext which uses platform-native separators.
|
||||
from posixpath import splitext
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extension / MIME → format label mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_EXT_TO_FORMAT: dict[str, str] = {
|
||||
".json": "json",
|
||||
".jsonl": "jsonl",
|
||||
".ndjson": "jsonl",
|
||||
".csv": "csv",
|
||||
".tsv": "tsv",
|
||||
".yaml": "yaml",
|
||||
".yml": "yaml",
|
||||
".toml": "toml",
|
||||
".parquet": "parquet",
|
||||
".xlsx": "xlsx",
|
||||
}
|
||||
|
||||
MIME_TO_FORMAT: dict[str, str] = {
|
||||
"application/json": "json",
|
||||
"application/x-ndjson": "jsonl",
|
||||
"application/jsonl": "jsonl",
|
||||
"text/csv": "csv",
|
||||
"text/tab-separated-values": "tsv",
|
||||
"application/x-yaml": "yaml",
|
||||
"application/yaml": "yaml",
|
||||
"text/yaml": "yaml",
|
||||
"application/toml": "toml",
|
||||
"application/vnd.apache.parquet": "parquet",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
|
||||
}
|
||||
|
||||
# Formats that require raw bytes rather than decoded text.
|
||||
BINARY_FORMATS: frozenset[str] = frozenset({"parquet", "xlsx"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API (top-down: main functions first, helpers below)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def infer_format_from_uri(uri: str) -> str | None:
|
||||
"""Return a format label based on URI extension or MIME fragment.
|
||||
|
||||
Returns ``None`` when the format cannot be determined — the caller should
|
||||
fall back to returning the content as a plain string.
|
||||
"""
|
||||
# 1. Check MIME fragment (workspace://abc123#application/json)
|
||||
if "#" in uri:
|
||||
_, fragment = uri.rsplit("#", 1)
|
||||
fmt = MIME_TO_FORMAT.get(fragment.lower())
|
||||
if fmt:
|
||||
return fmt
|
||||
|
||||
# 2. Check file extension from the path portion.
|
||||
# Strip the fragment first so ".json#mime" doesn't confuse splitext.
|
||||
path = uri.split("#")[0].split("?")[0]
|
||||
_, ext = splitext(path)
|
||||
fmt = _EXT_TO_FORMAT.get(ext.lower())
|
||||
if fmt is not None:
|
||||
return fmt
|
||||
|
||||
# Legacy .xls is not supported — map it so callers can produce a
|
||||
# user-friendly error instead of returning garbled binary.
|
||||
if ext.lower() == ".xls":
|
||||
return "xls"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def parse_file_content(content: str | bytes, fmt: str, *, strict: bool = False) -> Any:
|
||||
"""Parse *content* according to *fmt* and return a native Python value.
|
||||
|
||||
When *strict* is ``False`` (default), returns the original *content*
|
||||
unchanged if *fmt* is not recognised or parsing fails for any reason.
|
||||
This mode **never raises**.
|
||||
|
||||
When *strict* is ``True``, parsing errors are propagated to the caller.
|
||||
Unrecognised formats or type mismatches (e.g. text for a binary format)
|
||||
still return *content* unchanged without raising.
|
||||
"""
|
||||
if fmt == "xls":
|
||||
return (
|
||||
"[Unsupported format] Legacy .xls files are not supported. "
|
||||
"Please re-save the file as .xlsx (Excel 2007+) and upload again."
|
||||
)
|
||||
|
||||
try:
|
||||
if fmt in BINARY_FORMATS:
|
||||
parser = _BINARY_PARSERS.get(fmt)
|
||||
if parser is None:
|
||||
return content
|
||||
if isinstance(content, str):
|
||||
# Caller gave us text for a binary format — can't parse.
|
||||
return content
|
||||
return parser(content)
|
||||
|
||||
parser = _TEXT_PARSERS.get(fmt)
|
||||
if parser is None:
|
||||
return content
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8", errors="replace")
|
||||
return parser(content)
|
||||
|
||||
except PARSE_EXCEPTIONS:
|
||||
if strict:
|
||||
raise
|
||||
logger.debug("Structured parsing failed for format=%s, falling back", fmt)
|
||||
return content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exception loading helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _load_openpyxl_exception() -> type[Exception]:
|
||||
"""Return openpyxl's InvalidFileException, raising ImportError if absent."""
|
||||
from openpyxl.utils.exceptions import InvalidFileException # noqa: PLC0415
|
||||
|
||||
return InvalidFileException
|
||||
|
||||
|
||||
def _load_arrow_exception() -> type[Exception]:
|
||||
"""Return pyarrow's ArrowException, raising ImportError if absent."""
|
||||
from pyarrow import ArrowException # noqa: PLC0415
|
||||
|
||||
return ArrowException
|
||||
|
||||
|
||||
def _optional_exc(loader: "Callable[[], type[Exception]]") -> "type[Exception] | None":
|
||||
"""Return the exception class from *loader*, or ``None`` if the dep is absent."""
|
||||
try:
|
||||
return loader()
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
# Exception types that can be raised during file content parsing.
|
||||
# Shared between ``parse_file_content`` (which catches them in non-strict mode)
|
||||
# and ``file_ref._expand_bare_ref`` (which re-raises them as FileRefExpansionError).
|
||||
#
|
||||
# Optional-dependency exception types are loaded via a helper that raises
|
||||
# ``ImportError`` at *parse time* rather than silently becoming ``None`` here.
|
||||
# This ensures mypy sees clean types and missing deps surface as real errors.
|
||||
PARSE_EXCEPTIONS: tuple[type[BaseException], ...] = tuple(
|
||||
exc
|
||||
for exc in (
|
||||
json.JSONDecodeError,
|
||||
csv.Error,
|
||||
yaml.YAMLError,
|
||||
tomllib.TOMLDecodeError,
|
||||
ValueError,
|
||||
UnicodeDecodeError,
|
||||
ImportError,
|
||||
OSError,
|
||||
KeyError,
|
||||
TypeError,
|
||||
zipfile.BadZipFile,
|
||||
_optional_exc(_load_openpyxl_exception),
|
||||
# ArrowException covers ArrowIOError and ArrowCapacityError which
|
||||
# do not inherit from standard exceptions; ArrowInvalid/ArrowTypeError
|
||||
# already map to ValueError/TypeError but this catches the rest.
|
||||
_optional_exc(_load_arrow_exception),
|
||||
)
|
||||
if exc is not None
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text-based parsers (content: str → Any)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_container(parser: Callable[[str], Any], content: str) -> list | dict | str:
|
||||
"""Parse *content* and return the result only if it is a container (list/dict).
|
||||
|
||||
Scalar values (strings, numbers, booleans, None) are discarded and the
|
||||
original *content* string is returned instead. This prevents e.g. a JSON
|
||||
file containing just ``"42"`` from silently becoming an int.
|
||||
"""
|
||||
parsed = parser(content)
|
||||
if isinstance(parsed, (list, dict)):
|
||||
return parsed
|
||||
return content
|
||||
|
||||
|
||||
def _parse_json(content: str) -> list | dict | str:
|
||||
return _parse_container(json.loads, content)
|
||||
|
||||
|
||||
def _parse_jsonl(content: str) -> Any:
|
||||
lines = [json.loads(line) for line in content.splitlines() if line.strip()]
|
||||
if not lines:
|
||||
return content
|
||||
|
||||
# When every line is a dict with the same keys, convert to table format
|
||||
# (header row + data rows) — consistent with CSV/TSV/Parquet/Excel output.
|
||||
# Require ≥2 dicts so a single-line JSONL stays as [dict] (not a table).
|
||||
if len(lines) >= 2 and all(isinstance(obj, dict) for obj in lines):
|
||||
keys = list(lines[0].keys())
|
||||
# Cache as tuple to avoid O(n×k) list allocations in the all() call.
|
||||
keys_tuple = tuple(keys)
|
||||
if keys and all(tuple(obj.keys()) == keys_tuple for obj in lines[1:]):
|
||||
return [keys] + [[obj[k] for k in keys] for obj in lines]
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _parse_csv(content: str) -> Any:
|
||||
return _parse_delimited(content, delimiter=",")
|
||||
|
||||
|
||||
def _parse_tsv(content: str) -> Any:
|
||||
return _parse_delimited(content, delimiter="\t")
|
||||
|
||||
|
||||
def _parse_delimited(content: str, *, delimiter: str) -> Any:
|
||||
reader = csv.reader(io.StringIO(content), delimiter=delimiter)
|
||||
# csv.reader never yields [] — blank lines yield [""]. Filter out
|
||||
# rows where every cell is empty (i.e. truly blank lines).
|
||||
rows = [row for row in reader if _row_has_content(row)]
|
||||
if not rows:
|
||||
return content
|
||||
# If the declared delimiter produces only single-column rows, try
|
||||
# sniffing the actual delimiter — catches misidentified files (e.g.
|
||||
# a tab-delimited file with a .csv extension).
|
||||
if len(rows[0]) == 1:
|
||||
try:
|
||||
dialect = csv.Sniffer().sniff(content[:8192])
|
||||
if dialect.delimiter != delimiter:
|
||||
reader = csv.reader(io.StringIO(content), dialect)
|
||||
rows = [row for row in reader if _row_has_content(row)]
|
||||
except csv.Error:
|
||||
pass
|
||||
if rows and len(rows[0]) >= 2:
|
||||
return rows
|
||||
return content
|
||||
|
||||
|
||||
def _row_has_content(row: list[str]) -> bool:
|
||||
"""Return True when *row* contains at least one non-empty cell.
|
||||
|
||||
``csv.reader`` never yields ``[]`` — truly blank lines yield ``[""]``.
|
||||
This predicate filters those out consistently across the initial read
|
||||
and the sniffer-fallback re-read.
|
||||
"""
|
||||
return any(cell for cell in row)
|
||||
|
||||
|
||||
def _parse_yaml(content: str) -> list | dict | str:
|
||||
# NOTE: YAML anchor/alias expansion can amplify input beyond the 10MB cap.
|
||||
# safe_load prevents code execution; for production hardening consider
|
||||
# a YAML parser with expansion limits (e.g. ruamel.yaml with max_alias_count).
|
||||
if "\n---" in content or content.startswith("---\n"):
|
||||
# Multi-document YAML: only the first document is parsed; the rest
|
||||
# are silently ignored by yaml.safe_load. Warn so callers are aware.
|
||||
logger.warning(
|
||||
"Multi-document YAML detected (--- separator); "
|
||||
"only the first document will be parsed."
|
||||
)
|
||||
return _parse_container(yaml.safe_load, content)
|
||||
|
||||
|
||||
def _parse_toml(content: str) -> Any:
|
||||
parsed = tomllib.loads(content)
|
||||
# tomllib.loads always returns a dict — return it even if empty.
|
||||
return parsed
|
||||
|
||||
|
||||
_TEXT_PARSERS: dict[str, Callable[[str], Any]] = {
|
||||
"json": _parse_json,
|
||||
"jsonl": _parse_jsonl,
|
||||
"csv": _parse_csv,
|
||||
"tsv": _parse_tsv,
|
||||
"yaml": _parse_yaml,
|
||||
"toml": _parse_toml,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Binary-based parsers (content: bytes → Any)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_parquet(content: bytes) -> list[list[Any]]:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.read_parquet(io.BytesIO(content))
|
||||
return _df_to_rows(df)
|
||||
|
||||
|
||||
def _parse_xlsx(content: bytes) -> list[list[Any]]:
|
||||
import pandas as pd
|
||||
|
||||
# Explicitly specify openpyxl engine; the default engine varies by pandas
|
||||
# version and does not support legacy .xls (which is excluded by our format map).
|
||||
df = pd.read_excel(io.BytesIO(content), engine="openpyxl")
|
||||
return _df_to_rows(df)
|
||||
|
||||
|
||||
def _df_to_rows(df: Any) -> list[list[Any]]:
|
||||
"""Convert a DataFrame to ``list[list[Any]]`` with a header row.
|
||||
|
||||
NaN values are replaced with ``None`` so the result is JSON-serializable.
|
||||
Uses explicit cell-level checking because ``df.where(df.notna(), None)``
|
||||
silently converts ``None`` back to ``NaN`` in float64 columns.
|
||||
"""
|
||||
header = df.columns.tolist()
|
||||
rows = [
|
||||
[None if _is_nan(cell) else cell for cell in row] for row in df.values.tolist()
|
||||
]
|
||||
return [header] + rows
|
||||
|
||||
|
||||
def _is_nan(cell: Any) -> bool:
|
||||
"""Check if a cell value is NaN, handling non-scalar types (lists, dicts).
|
||||
|
||||
``pd.isna()`` on a list/dict returns a boolean array which raises
|
||||
``ValueError`` in a boolean context. Guard with a scalar check first.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
return bool(pd.api.types.is_scalar(cell) and pd.isna(cell))
|
||||
|
||||
|
||||
_BINARY_PARSERS: dict[str, Callable[[bytes], Any]] = {
|
||||
"parquet": _parse_parquet,
|
||||
"xlsx": _parse_xlsx,
|
||||
}
|
||||
@@ -1,624 +0,0 @@
|
||||
"""Tests for file_content_parser — format inference and structured parsing."""
|
||||
|
||||
import io
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.file_content_parser import (
|
||||
BINARY_FORMATS,
|
||||
infer_format_from_uri,
|
||||
parse_file_content,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# infer_format_from_uri
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInferFormat:
|
||||
# --- extension-based ---
|
||||
|
||||
def test_json_extension(self):
|
||||
assert infer_format_from_uri("/home/user/data.json") == "json"
|
||||
|
||||
def test_jsonl_extension(self):
|
||||
assert infer_format_from_uri("/tmp/events.jsonl") == "jsonl"
|
||||
|
||||
def test_ndjson_extension(self):
|
||||
assert infer_format_from_uri("/tmp/events.ndjson") == "jsonl"
|
||||
|
||||
def test_csv_extension(self):
|
||||
assert infer_format_from_uri("workspace:///reports/sales.csv") == "csv"
|
||||
|
||||
def test_tsv_extension(self):
|
||||
assert infer_format_from_uri("/home/user/data.tsv") == "tsv"
|
||||
|
||||
def test_yaml_extension(self):
|
||||
assert infer_format_from_uri("/home/user/config.yaml") == "yaml"
|
||||
|
||||
def test_yml_extension(self):
|
||||
assert infer_format_from_uri("/home/user/config.yml") == "yaml"
|
||||
|
||||
def test_toml_extension(self):
|
||||
assert infer_format_from_uri("/home/user/config.toml") == "toml"
|
||||
|
||||
def test_parquet_extension(self):
|
||||
assert infer_format_from_uri("/data/table.parquet") == "parquet"
|
||||
|
||||
def test_xlsx_extension(self):
|
||||
assert infer_format_from_uri("/data/spreadsheet.xlsx") == "xlsx"
|
||||
|
||||
def test_xls_extension_returns_xls_label(self):
|
||||
# Legacy .xls is mapped so callers can produce a helpful error.
|
||||
assert infer_format_from_uri("/data/old_spreadsheet.xls") == "xls"
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert infer_format_from_uri("/data/FILE.JSON") == "json"
|
||||
assert infer_format_from_uri("/data/FILE.CSV") == "csv"
|
||||
|
||||
def test_unicode_filename(self):
|
||||
assert infer_format_from_uri("/home/user/\u30c7\u30fc\u30bf.json") == "json"
|
||||
assert infer_format_from_uri("/home/user/\u00e9t\u00e9.csv") == "csv"
|
||||
|
||||
def test_unknown_extension(self):
|
||||
assert infer_format_from_uri("/home/user/readme.txt") is None
|
||||
|
||||
def test_no_extension(self):
|
||||
assert infer_format_from_uri("workspace://abc123") is None
|
||||
|
||||
# --- MIME-based ---
|
||||
|
||||
def test_mime_json(self):
|
||||
assert infer_format_from_uri("workspace://abc123#application/json") == "json"
|
||||
|
||||
def test_mime_csv(self):
|
||||
assert infer_format_from_uri("workspace://abc123#text/csv") == "csv"
|
||||
|
||||
def test_mime_tsv(self):
|
||||
assert (
|
||||
infer_format_from_uri("workspace://abc123#text/tab-separated-values")
|
||||
== "tsv"
|
||||
)
|
||||
|
||||
def test_mime_ndjson(self):
|
||||
assert (
|
||||
infer_format_from_uri("workspace://abc123#application/x-ndjson") == "jsonl"
|
||||
)
|
||||
|
||||
def test_mime_yaml(self):
|
||||
assert infer_format_from_uri("workspace://abc123#application/x-yaml") == "yaml"
|
||||
|
||||
def test_mime_xlsx(self):
|
||||
uri = "workspace://abc123#application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
assert infer_format_from_uri(uri) == "xlsx"
|
||||
|
||||
def test_mime_parquet(self):
|
||||
assert (
|
||||
infer_format_from_uri("workspace://abc123#application/vnd.apache.parquet")
|
||||
== "parquet"
|
||||
)
|
||||
|
||||
def test_unknown_mime(self):
|
||||
assert infer_format_from_uri("workspace://abc123#text/plain") is None
|
||||
|
||||
def test_unknown_mime_falls_through_to_extension(self):
|
||||
# Unknown MIME (text/plain) should fall through to extension-based detection.
|
||||
assert infer_format_from_uri("workspace:///data.csv#text/plain") == "csv"
|
||||
|
||||
# --- MIME takes precedence over extension ---
|
||||
|
||||
def test_mime_overrides_extension(self):
|
||||
# .txt extension but JSON MIME → json
|
||||
assert infer_format_from_uri("workspace:///file.txt#application/json") == "json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — JSON
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseJson:
|
||||
def test_array(self):
|
||||
result = parse_file_content("[1, 2, 3]", "json")
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
def test_object(self):
|
||||
result = parse_file_content('{"key": "value"}', "json")
|
||||
assert result == {"key": "value"}
|
||||
|
||||
def test_nested(self):
|
||||
content = json.dumps({"rows": [[1, 2], [3, 4]]})
|
||||
result = parse_file_content(content, "json")
|
||||
assert result == {"rows": [[1, 2], [3, 4]]}
|
||||
|
||||
def test_scalar_string_stays_as_string(self):
|
||||
result = parse_file_content('"hello"', "json")
|
||||
assert result == '"hello"' # original content, not parsed
|
||||
|
||||
def test_scalar_number_stays_as_string(self):
|
||||
result = parse_file_content("42", "json")
|
||||
assert result == "42"
|
||||
|
||||
def test_scalar_boolean_stays_as_string(self):
|
||||
result = parse_file_content("true", "json")
|
||||
assert result == "true"
|
||||
|
||||
def test_null_stays_as_string(self):
|
||||
result = parse_file_content("null", "json")
|
||||
assert result == "null"
|
||||
|
||||
def test_invalid_json_fallback(self):
|
||||
content = "not json at all"
|
||||
result = parse_file_content(content, "json")
|
||||
assert result == content
|
||||
|
||||
def test_empty_string_fallback(self):
|
||||
result = parse_file_content("", "json")
|
||||
assert result == ""
|
||||
|
||||
def test_bytes_input_decoded(self):
|
||||
result = parse_file_content(b"[1, 2, 3]", "json")
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — JSONL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseJsonl:
|
||||
def test_tabular_uniform_dicts_to_table_format(self):
|
||||
"""JSONL with uniform dict keys → table format (header + rows),
|
||||
consistent with CSV/TSV/Parquet/Excel output."""
|
||||
content = '{"name":"apple","color":"red"}\n{"name":"banana","color":"yellow"}\n{"name":"cherry","color":"red"}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [
|
||||
["name", "color"],
|
||||
["apple", "red"],
|
||||
["banana", "yellow"],
|
||||
["cherry", "red"],
|
||||
]
|
||||
|
||||
def test_tabular_single_key_dicts(self):
|
||||
"""JSONL with single-key uniform dicts → table format."""
|
||||
content = '{"a": 1}\n{"a": 2}\n{"a": 3}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [["a"], [1], [2], [3]]
|
||||
|
||||
def test_tabular_blank_lines_skipped(self):
|
||||
content = '{"a": 1}\n\n{"a": 2}\n'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [["a"], [1], [2]]
|
||||
|
||||
def test_heterogeneous_dicts_stay_as_list(self):
|
||||
"""JSONL with different keys across objects → list of dicts (no table)."""
|
||||
content = '{"name":"apple"}\n{"color":"red"}\n{"size":3}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [{"name": "apple"}, {"color": "red"}, {"size": 3}]
|
||||
|
||||
def test_partially_overlapping_keys_stay_as_list(self):
|
||||
"""JSONL dicts with partially overlapping keys → list of dicts."""
|
||||
content = '{"name":"apple","color":"red"}\n{"name":"banana","size":"medium"}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [
|
||||
{"name": "apple", "color": "red"},
|
||||
{"name": "banana", "size": "medium"},
|
||||
]
|
||||
|
||||
def test_mixed_types_stay_as_list(self):
|
||||
"""JSONL with non-dict lines → list of parsed values (no table)."""
|
||||
content = '1\n"hello"\n[1,2]\n'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [1, "hello", [1, 2]]
|
||||
|
||||
def test_mixed_dicts_and_non_dicts_stay_as_list(self):
|
||||
"""JSONL mixing dicts and non-dicts → list of parsed values."""
|
||||
content = '{"a": 1}\n42\n{"b": 2}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [{"a": 1}, 42, {"b": 2}]
|
||||
|
||||
def test_tabular_preserves_key_order(self):
|
||||
"""Table header should follow the key order of the first object."""
|
||||
content = '{"z": 1, "a": 2}\n{"z": 3, "a": 4}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result[0] == ["z", "a"] # order from first object
|
||||
assert result[1] == [1, 2]
|
||||
assert result[2] == [3, 4]
|
||||
|
||||
def test_single_dict_stays_as_list(self):
|
||||
"""Single-line JSONL with one dict → [dict], NOT a table.
|
||||
Tabular detection requires ≥2 dicts to avoid vacuously true all()."""
|
||||
content = '{"a": 1, "b": 2}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [{"a": 1, "b": 2}]
|
||||
|
||||
def test_tabular_with_none_values(self):
|
||||
"""Uniform keys but some null values → table with None cells."""
|
||||
content = '{"name":"apple","color":"red"}\n{"name":"banana","color":null}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [
|
||||
["name", "color"],
|
||||
["apple", "red"],
|
||||
["banana", None],
|
||||
]
|
||||
|
||||
def test_empty_file_fallback(self):
|
||||
result = parse_file_content("", "jsonl")
|
||||
assert result == ""
|
||||
|
||||
def test_all_blank_lines_fallback(self):
|
||||
result = parse_file_content("\n\n\n", "jsonl")
|
||||
assert result == "\n\n\n"
|
||||
|
||||
def test_invalid_line_fallback(self):
|
||||
content = '{"a": 1}\nnot json\n'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == content # fallback
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — CSV
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseCsv:
|
||||
def test_basic(self):
|
||||
content = "Name,Score\nAlice,90\nBob,85"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
|
||||
|
||||
def test_quoted_fields(self):
|
||||
content = 'Name,Bio\nAlice,"Loves, commas"\nBob,Simple'
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result[1] == ["Alice", "Loves, commas"]
|
||||
|
||||
def test_single_column_fallback(self):
|
||||
# Only 1 column — not tabular enough.
|
||||
content = "Name\nAlice\nBob"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == content
|
||||
|
||||
def test_empty_rows_skipped(self):
|
||||
content = "A,B\n\n1,2\n\n3,4"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["A", "B"], ["1", "2"], ["3", "4"]]
|
||||
|
||||
def test_empty_file_fallback(self):
|
||||
result = parse_file_content("", "csv")
|
||||
assert result == ""
|
||||
|
||||
def test_utf8_bom(self):
|
||||
"""CSV with a UTF-8 BOM should parse correctly (BOM stripped by decode)."""
|
||||
bom = "\ufeff"
|
||||
content = bom + "Name,Score\nAlice,90\nBob,85"
|
||||
result = parse_file_content(content, "csv")
|
||||
# The BOM may be part of the first header cell; ensure rows are still parsed.
|
||||
assert len(result) == 3
|
||||
assert result[1] == ["Alice", "90"]
|
||||
assert result[2] == ["Bob", "85"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — TSV
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseTsv:
|
||||
def test_basic(self):
|
||||
content = "Name\tScore\nAlice\t90\nBob\t85"
|
||||
result = parse_file_content(content, "tsv")
|
||||
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
|
||||
|
||||
def test_single_column_fallback(self):
|
||||
content = "Name\nAlice\nBob"
|
||||
result = parse_file_content(content, "tsv")
|
||||
assert result == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — YAML
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseYaml:
|
||||
def test_list(self):
|
||||
content = "- apple\n- banana\n- cherry"
|
||||
result = parse_file_content(content, "yaml")
|
||||
assert result == ["apple", "banana", "cherry"]
|
||||
|
||||
def test_dict(self):
|
||||
content = "name: Alice\nage: 30"
|
||||
result = parse_file_content(content, "yaml")
|
||||
assert result == {"name": "Alice", "age": 30}
|
||||
|
||||
def test_nested(self):
|
||||
content = "users:\n - name: Alice\n - name: Bob"
|
||||
result = parse_file_content(content, "yaml")
|
||||
assert result == {"users": [{"name": "Alice"}, {"name": "Bob"}]}
|
||||
|
||||
def test_scalar_stays_as_string(self):
|
||||
result = parse_file_content("hello world", "yaml")
|
||||
assert result == "hello world"
|
||||
|
||||
def test_invalid_yaml_fallback(self):
|
||||
content = ":\n :\n invalid: - -"
|
||||
result = parse_file_content(content, "yaml")
|
||||
# Malformed YAML should fall back to the original string, not raise.
|
||||
assert result == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — TOML
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseToml:
|
||||
def test_basic(self):
|
||||
content = '[server]\nhost = "localhost"\nport = 8080'
|
||||
result = parse_file_content(content, "toml")
|
||||
assert result == {"server": {"host": "localhost", "port": 8080}}
|
||||
|
||||
def test_flat(self):
|
||||
content = 'name = "test"\ncount = 42'
|
||||
result = parse_file_content(content, "toml")
|
||||
assert result == {"name": "test", "count": 42}
|
||||
|
||||
def test_empty_string_returns_empty_dict(self):
|
||||
result = parse_file_content("", "toml")
|
||||
assert result == {}
|
||||
|
||||
def test_invalid_toml_fallback(self):
|
||||
result = parse_file_content("not = [valid toml", "toml")
|
||||
assert result == "not = [valid toml"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — Parquet (binary)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
try:
|
||||
import pyarrow as _pa # noqa: F401 # pyright: ignore[reportMissingImports]
|
||||
|
||||
_has_pyarrow = True
|
||||
except ImportError:
|
||||
_has_pyarrow = False
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _has_pyarrow, reason="pyarrow not installed")
|
||||
class TestParseParquet:
|
||||
@pytest.fixture
|
||||
def parquet_bytes(self) -> bytes:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"Name": ["Alice", "Bob"], "Score": [90, 85]})
|
||||
buf = io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
return buf.getvalue()
|
||||
|
||||
def test_basic(self, parquet_bytes: bytes):
|
||||
result = parse_file_content(parquet_bytes, "parquet")
|
||||
assert result == [["Name", "Score"], ["Alice", 90], ["Bob", 85]]
|
||||
|
||||
def test_string_input_fallback(self):
|
||||
# Parquet is binary — string input can't be parsed.
|
||||
result = parse_file_content("not parquet", "parquet")
|
||||
assert result == "not parquet"
|
||||
|
||||
def test_invalid_bytes_fallback(self):
|
||||
result = parse_file_content(b"not parquet bytes", "parquet")
|
||||
assert result == b"not parquet bytes"
|
||||
|
||||
def test_empty_bytes_fallback(self):
|
||||
"""Empty binary input should return the empty bytes, not crash."""
|
||||
result = parse_file_content(b"", "parquet")
|
||||
assert result == b""
|
||||
|
||||
def test_nan_replaced_with_none(self):
|
||||
"""NaN values in Parquet must become None for JSON serializability."""
|
||||
import math
|
||||
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"A": [1.0, float("nan"), 3.0], "B": ["x", None, "z"]})
|
||||
buf = io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
result = parse_file_content(buf.getvalue(), "parquet")
|
||||
# Row with NaN in float col → None
|
||||
assert result[2][0] is None # float NaN → None
|
||||
assert result[2][1] is None # str None → None
|
||||
# Ensure no NaN leaks
|
||||
for row in result[1:]:
|
||||
for cell in row:
|
||||
if isinstance(cell, float):
|
||||
assert not math.isnan(cell), f"NaN leaked: {row}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — Excel (binary)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseExcel:
|
||||
@pytest.fixture
|
||||
def xlsx_bytes(self) -> bytes:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"Name": ["Alice", "Bob"], "Score": [90, 85]})
|
||||
buf = io.BytesIO()
|
||||
df.to_excel(buf, index=False) # type: ignore[arg-type] # BytesIO is a valid target
|
||||
return buf.getvalue()
|
||||
|
||||
def test_basic(self, xlsx_bytes: bytes):
|
||||
result = parse_file_content(xlsx_bytes, "xlsx")
|
||||
assert result == [["Name", "Score"], ["Alice", 90], ["Bob", 85]]
|
||||
|
||||
def test_string_input_fallback(self):
|
||||
result = parse_file_content("not xlsx", "xlsx")
|
||||
assert result == "not xlsx"
|
||||
|
||||
def test_invalid_bytes_fallback(self):
|
||||
result = parse_file_content(b"not xlsx bytes", "xlsx")
|
||||
assert result == b"not xlsx bytes"
|
||||
|
||||
def test_empty_bytes_fallback(self):
|
||||
"""Empty binary input should return the empty bytes, not crash."""
|
||||
result = parse_file_content(b"", "xlsx")
|
||||
assert result == b""
|
||||
|
||||
def test_nan_replaced_with_none(self):
|
||||
"""NaN values in float columns must become None for JSON serializability."""
|
||||
import math
|
||||
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"A": [1.0, float("nan"), 3.0], "B": ["x", "y", None]})
|
||||
buf = io.BytesIO()
|
||||
df.to_excel(buf, index=False) # type: ignore[arg-type]
|
||||
result = parse_file_content(buf.getvalue(), "xlsx")
|
||||
# Row with NaN in float col → None, not float('nan')
|
||||
assert result[2][0] is None # float NaN → None
|
||||
assert result[3][1] is None # str None → None
|
||||
# Ensure no NaN leaks
|
||||
for row in result[1:]: # skip header
|
||||
for cell in row:
|
||||
if isinstance(cell, float):
|
||||
assert not math.isnan(cell), f"NaN leaked: {row}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — unknown format / fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFallback:
|
||||
def test_unknown_format_returns_content(self):
|
||||
result = parse_file_content("hello world", "xml")
|
||||
assert result == "hello world"
|
||||
|
||||
def test_none_format_returns_content(self):
|
||||
# Shouldn't normally be called with unrecognised format, but must not crash.
|
||||
result = parse_file_content("hello", "unknown_format")
|
||||
assert result == "hello"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BINARY_FORMATS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBinaryFormats:
|
||||
def test_parquet_is_binary(self):
|
||||
assert "parquet" in BINARY_FORMATS
|
||||
|
||||
def test_xlsx_is_binary(self):
|
||||
assert "xlsx" in BINARY_FORMATS
|
||||
|
||||
def test_text_formats_not_binary(self):
|
||||
for fmt in ("json", "jsonl", "csv", "tsv", "yaml", "toml"):
|
||||
assert fmt not in BINARY_FORMATS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MIME mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMimeMapping:
|
||||
def test_application_yaml(self):
|
||||
assert infer_format_from_uri("workspace://abc123#application/yaml") == "yaml"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CSV sniffer fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCsvSnifferFallback:
|
||||
def test_tab_delimited_with_csv_format(self):
|
||||
"""Tab-delimited content parsed as csv should use sniffer fallback."""
|
||||
content = "Name\tScore\nAlice\t90\nBob\t85"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
|
||||
|
||||
def test_sniffer_failure_returns_content(self):
|
||||
"""When sniffer fails, single-column falls back to raw content."""
|
||||
content = "Name\nAlice\nBob"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenpyxlInvalidFile fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOpenpyxlFallback:
|
||||
def test_invalid_xlsx_non_strict(self):
|
||||
"""Invalid xlsx bytes should fall back gracefully in non-strict mode."""
|
||||
result = parse_file_content(b"not xlsx bytes", "xlsx")
|
||||
assert result == b"not xlsx bytes"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Header-only CSV
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHeaderOnlyCsv:
|
||||
def test_header_only_csv_returns_header_row(self):
|
||||
"""CSV with only a header row (no data rows) should return [[header]]."""
|
||||
content = "Name,Score"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"]]
|
||||
|
||||
def test_header_only_csv_with_trailing_newline(self):
|
||||
content = "Name,Score\n"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Binary format + line range (line range ignored for binary formats)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _has_pyarrow, reason="pyarrow not installed")
|
||||
class TestBinaryFormatLineRange:
|
||||
def test_parquet_ignores_line_range(self):
|
||||
"""Binary formats should parse the full file regardless of line range.
|
||||
|
||||
Line ranges are meaningless for binary formats (parquet/xlsx) — the
|
||||
caller (file_ref._expand_bare_ref) passes raw bytes and the parser
|
||||
should return the complete structured data.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
|
||||
buf = io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
# parse_file_content itself doesn't take a line range — this tests
|
||||
# that the full content is parsed even though the bytes could have
|
||||
# been truncated upstream (it's not, by design).
|
||||
result = parse_file_content(buf.getvalue(), "parquet")
|
||||
assert result == [["A", "B"], [1, 4], [2, 5], [3, 6]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Legacy .xls UX
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestXlsFallback:
|
||||
def test_xls_returns_helpful_error_string(self):
|
||||
"""Uploading a .xls file should produce a helpful error, not garbled binary."""
|
||||
result = parse_file_content(b"\xd0\xcf\x11\xe0garbled", "xls")
|
||||
assert isinstance(result, str)
|
||||
assert ".xlsx" in result
|
||||
assert "not supported" in result.lower()
|
||||
|
||||
def test_xls_with_string_content(self):
|
||||
result = parse_file_content("some text", "xls")
|
||||
assert isinstance(result, str)
|
||||
assert ".xlsx" in result
|
||||
@@ -8,12 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.util.file import (
|
||||
is_media_file_ref,
|
||||
parse_data_uri,
|
||||
resolve_media_content,
|
||||
store_media_file,
|
||||
)
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
@@ -349,162 +344,3 @@ class TestFileCloudIntegration:
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_media_file_ref
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsMediaFileRef:
|
||||
def test_data_uri(self):
|
||||
assert is_media_file_ref("data:image/png;base64,iVBORw0KGg==") is True
|
||||
|
||||
def test_workspace_uri(self):
|
||||
assert is_media_file_ref("workspace://abc123") is True
|
||||
|
||||
def test_workspace_uri_with_mime(self):
|
||||
assert is_media_file_ref("workspace://abc123#image/png") is True
|
||||
|
||||
def test_http_url(self):
|
||||
assert is_media_file_ref("http://example.com/image.png") is True
|
||||
|
||||
def test_https_url(self):
|
||||
assert is_media_file_ref("https://example.com/image.png") is True
|
||||
|
||||
def test_plain_text(self):
|
||||
assert is_media_file_ref("print('hello')") is False
|
||||
|
||||
def test_local_path(self):
|
||||
assert is_media_file_ref("/tmp/file.txt") is False
|
||||
|
||||
def test_empty_string(self):
|
||||
assert is_media_file_ref("") is False
|
||||
|
||||
def test_filename(self):
|
||||
assert is_media_file_ref("image.png") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_data_uri
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseDataUri:
|
||||
def test_valid_png(self):
|
||||
result = parse_data_uri("data:image/png;base64,iVBORw0KGg==")
|
||||
assert result is not None
|
||||
mime, payload = result
|
||||
assert mime == "image/png"
|
||||
assert payload == "iVBORw0KGg=="
|
||||
|
||||
def test_valid_text(self):
|
||||
result = parse_data_uri("data:text/plain;base64,SGVsbG8=")
|
||||
assert result is not None
|
||||
assert result[0] == "text/plain"
|
||||
assert result[1] == "SGVsbG8="
|
||||
|
||||
def test_mime_case_normalized(self):
|
||||
result = parse_data_uri("data:IMAGE/PNG;base64,abc")
|
||||
assert result is not None
|
||||
assert result[0] == "image/png"
|
||||
|
||||
def test_not_data_uri(self):
|
||||
assert parse_data_uri("workspace://abc123") is None
|
||||
|
||||
def test_plain_text(self):
|
||||
assert parse_data_uri("hello world") is None
|
||||
|
||||
def test_missing_base64(self):
|
||||
assert parse_data_uri("data:image/png;utf-8,abc") is None
|
||||
|
||||
def test_empty_payload(self):
|
||||
result = parse_data_uri("data:image/png;base64,")
|
||||
assert result is not None
|
||||
assert result[1] == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_media_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveMediaContent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_plain_text_passthrough(self):
|
||||
"""Plain text content (not a media ref) passes through unchanged."""
|
||||
ctx = make_test_context()
|
||||
result = await resolve_media_content(
|
||||
MediaFileType("print('hello')"),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == "print('hello')"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_string_passthrough(self):
|
||||
"""Empty string passes through unchanged."""
|
||||
ctx = make_test_context()
|
||||
result = await resolve_media_content(
|
||||
MediaFileType(""),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_media_ref_delegates_to_store(self):
|
||||
"""Media references are resolved via store_media_file."""
|
||||
ctx = make_test_context()
|
||||
with patch(
|
||||
"backend.util.file.store_media_file",
|
||||
new=AsyncMock(return_value=MediaFileType("data:image/png;base64,abc")),
|
||||
) as mock_store:
|
||||
result = await resolve_media_content(
|
||||
MediaFileType("workspace://img123"),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == "data:image/png;base64,abc"
|
||||
mock_store.assert_called_once_with(
|
||||
MediaFileType("workspace://img123"),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_uri_delegates_to_store(self):
|
||||
"""Data URIs are also resolved via store_media_file."""
|
||||
ctx = make_test_context()
|
||||
data_uri = "data:image/png;base64,iVBORw0KGg=="
|
||||
with patch(
|
||||
"backend.util.file.store_media_file",
|
||||
new=AsyncMock(return_value=MediaFileType(data_uri)),
|
||||
) as mock_store:
|
||||
result = await resolve_media_content(
|
||||
MediaFileType(data_uri),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == data_uri
|
||||
mock_store.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_https_url_delegates_to_store(self):
|
||||
"""HTTPS URLs are resolved via store_media_file."""
|
||||
ctx = make_test_context()
|
||||
with patch(
|
||||
"backend.util.file.store_media_file",
|
||||
new=AsyncMock(return_value=MediaFileType("data:image/png;base64,abc")),
|
||||
) as mock_store:
|
||||
result = await resolve_media_content(
|
||||
MediaFileType("https://example.com/image.png"),
|
||||
ctx,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
assert result == "data:image/png;base64,abc"
|
||||
mock_store.assert_called_once_with(
|
||||
MediaFileType("https://example.com/image.png"),
|
||||
ctx,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
@@ -70,10 +70,6 @@ def _msg_tokens(msg: dict, enc) -> int:
|
||||
# Count tool result tokens
|
||||
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
elif isinstance(item, dict) and item.get("type") == "text":
|
||||
# Count text block tokens (standard: "text" key, fallback: "content")
|
||||
text_val = item.get("text") or item.get("content", "")
|
||||
tool_call_tokens += _tok_len(text_val, enc)
|
||||
elif isinstance(item, dict) and "content" in item:
|
||||
# Other content types with content field
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
@@ -149,16 +145,10 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
|
||||
if len(ids) <= max_tok:
|
||||
return text # nothing to do
|
||||
|
||||
# Need at least 3 tokens (head + ellipsis + tail) for meaningful truncation
|
||||
if max_tok < 1:
|
||||
return ""
|
||||
mid = enc.encode(" … ")
|
||||
if max_tok < 3:
|
||||
return enc.decode(ids[:max_tok])
|
||||
|
||||
# Split the allowance between the two ends:
|
||||
head = max_tok // 2 - 1 # -1 for the ellipsis
|
||||
tail = max_tok - head - 1
|
||||
mid = enc.encode(" … ")
|
||||
return enc.decode(ids[:head] + mid + ids[-tail:])
|
||||
|
||||
|
||||
@@ -555,14 +545,6 @@ async def _summarize_messages_llm(
|
||||
"- Actions taken and key decisions made\n"
|
||||
"- Technical specifics (file names, tool outputs, function signatures)\n"
|
||||
"- Errors encountered and resolutions applied\n\n"
|
||||
"IMPORTANT: Preserve all concrete references verbatim — these are small but "
|
||||
"critical for continuing the conversation:\n"
|
||||
"- File paths and directory paths (e.g. /src/app/page.tsx, ./output/result.csv)\n"
|
||||
"- Image/media file paths from tool outputs\n"
|
||||
"- URLs, API endpoints, and webhook addresses\n"
|
||||
"- Resource IDs, session IDs, and identifiers\n"
|
||||
"- Tool names that were called and their key parameters\n"
|
||||
"- Environment variables, config keys, and credentials names (not values)\n\n"
|
||||
"Include ONLY the sections below that have relevant content "
|
||||
"(skip sections with nothing to report):\n\n"
|
||||
"## 1. Primary Request and Intent\n"
|
||||
@@ -570,8 +552,7 @@ async def _summarize_messages_llm(
|
||||
"## 2. Key Technical Concepts\n"
|
||||
"Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
|
||||
"## 3. Files and Resources Involved\n"
|
||||
"Specific files examined or modified, with relevant snippets and identifiers. "
|
||||
"Include exact file paths, image paths from tool outputs, and resource URLs.\n\n"
|
||||
"Specific files examined or modified, with relevant snippets and identifiers.\n\n"
|
||||
"## 4. Errors and Fixes\n"
|
||||
"Problems encountered, error messages, and their resolutions.\n\n"
|
||||
"## 5. All User Messages\n"
|
||||
@@ -585,7 +566,7 @@ async def _summarize_messages_llm(
|
||||
},
|
||||
{"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
|
||||
],
|
||||
max_tokens=2000,
|
||||
max_tokens=1500,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
@@ -705,15 +686,11 @@ async def compress_context(
|
||||
msgs = [summary_msg] + recent_msgs
|
||||
|
||||
logger.info(
|
||||
"Context summarized: %d -> %d tokens, summarized %d messages",
|
||||
original_count,
|
||||
total_tokens(),
|
||||
messages_summarized,
|
||||
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
|
||||
f"summarized {messages_summarized} messages"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Summarization failed, continuing with truncation: %s", e
|
||||
)
|
||||
logger.warning(f"Summarization failed, continuing with truncation: {e}")
|
||||
# Fall through to content truncation
|
||||
|
||||
# ---- STEP 2: Normalize content ----------------------------------------
|
||||
|
||||
89
autogpt_platform/backend/poetry.lock
generated
89
autogpt_platform/backend/poetry.lock
generated
@@ -1360,18 +1360,6 @@ files = [
|
||||
dnspython = ">=2.0.0"
|
||||
idna = ">=2.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "et-xmlfile"
|
||||
version = "2.0.0"
|
||||
description = "An implementation of lxml.xmlfile for the standard library"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa"},
|
||||
{file = "et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exa-py"
|
||||
version = "1.16.1"
|
||||
@@ -4240,21 +4228,6 @@ datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
|
||||
realtime = ["websockets (>=13,<16)"]
|
||||
voice-helpers = ["numpy (>=2.0.2)", "sounddevice (>=0.5.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "openpyxl"
|
||||
version = "3.1.5"
|
||||
description = "A Python library to read/write Excel 2010 xlsx/xlsm files"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2"},
|
||||
{file = "openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
et-xmlfile = "*"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-api"
|
||||
version = "1.39.1"
|
||||
@@ -5457,66 +5430,6 @@ files = [
|
||||
{file = "psycopg2_binary-2.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:875039274f8a2361e5207857899706da840768e2a775bf8c65e82f60b197df02"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
version = "23.0.1"
|
||||
description = "Python library for Apache Arrow"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:3fab8f82571844eb3c460f90a75583801d14ca0cc32b1acc8c361650e006fd56"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:3f91c038b95f71ddfc865f11d5876c42f343b4495535bd262c7b321b0b94507c"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:d0744403adabef53c985a7f8a082b502a368510c40d184df349a0a8754533258"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:c33b5bf406284fd0bba436ed6f6c3ebe8e311722b441d89397c54f871c6863a2"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ddf743e82f69dcd6dbbcb63628895d7161e04e56794ef80550ac6f3315eeb1d5"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e052a211c5ac9848ae15d5ec875ed0943c0221e2fcfe69eee80b604b4e703222"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:5abde149bb3ce524782d838eb67ac095cd3fd6090eba051130589793f1a7f76d"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:6f0147ee9e0386f519c952cc670eb4a8b05caa594eeffe01af0e25f699e4e9bb"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:0ae6e17c828455b6265d590100c295193f93cc5675eb0af59e49dbd00d2de350"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:fed7020203e9ef273360b9e45be52a2a47d3103caf156a30ace5247ffb51bdbd"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:26d50dee49d741ac0e82185033488d28d35be4d763ae6f321f97d1140eb7a0e9"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3c30143b17161310f151f4a2bcfe41b5ff744238c1039338779424e38579d701"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:db2190fa79c80a23fdd29fef4b8992893f024ae7c17d2f5f4db7171fa30c2c78"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:f00f993a8179e0e1c9713bcc0baf6d6c01326a406a9c23495ec1ba9c9ebf2919"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:f4b0dbfa124c0bb161f8b5ebb40f1a680b70279aa0c9901d44a2b5a20806039f"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:7707d2b6673f7de054e2e83d59f9e805939038eebe1763fe811ee8fa5c0cd1a7"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:86ff03fb9f1a320266e0de855dee4b17da6794c595d207f89bba40d16b5c78b9"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:813d99f31275919c383aab17f0f455a04f5a429c261cc411b1e9a8f5e4aaaa05"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bf5842f960cddd2ef757d486041d57c96483efc295a8c4a0e20e704cbbf39c67"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:564baf97c858ecc03ec01a41062e8f4698abc3e6e2acd79c01c2e97880a19730"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:07deae7783782ac7250989a7b2ecde9b3c343a643f82e8a4df03d93b633006f0"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:6b8fda694640b00e8af3c824f99f789e836720aa8c9379fb435d4c4953a756b8"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:8ff51b1addc469b9444b7c6f3548e19dc931b172ab234e995a60aea9f6e6025f"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:71c5be5cbf1e1cb6169d2a0980850bccb558ddc9b747b6206435313c47c37677"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:9b6f4f17b43bc39d56fec96e53fe89d94bac3eb134137964371b45352d40d0c2"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9fc13fc6c403d1337acab46a2c4346ca6c9dec5780c3c697cf8abfd5e19b6b37"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5c16ed4f53247fa3ffb12a14d236de4213a4415d127fe9cebed33d51671113e2"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:cecfb12ef629cf6be0b1887f9f86463b0dd3dc3195ae6224e74006be4736035a"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:29f7f7419a0e30264ea261fdc0e5fe63ce5a6095003db2945d7cd78df391a7e1"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:33d648dc25b51fd8055c19e4261e813dfc4d2427f068bcecc8b53d01b81b0500"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd395abf8f91c673dd3589cadc8cc1ee4e8674fa61b2e923c8dd215d9c7d1f41"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:00be9576d970c31defb5c32eb72ef585bf600ef6d0a82d5eccaae96639cf9d07"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c2139549494445609f35a5cda4eb94e2c9e4d704ce60a095b342f82460c73a83"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:7044b442f184d84e2351e5084600f0d7343d6117aabcbc1ac78eb1ae11eb4125"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-win_amd64.whl", hash = "sha256:a35581e856a2fafa12f3f54fce4331862b1cfb0bef5758347a858a4aa9d6bae8"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:5df1161da23636a70838099d4aaa65142777185cc0cdba4037a18cee7d8db9ca"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:fa8e51cb04b9f8c9c5ace6bab63af9a1f88d35c0d6cbf53e8c17c098552285e1"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:0b95a3994f015be13c63148fef8832e8a23938128c185ee951c98908a696e0eb"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:4982d71350b1a6e5cfe1af742c53dfb759b11ce14141870d05d9e540d13bc5d1"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c250248f1fe266db627921c89b47b7c06fee0489ad95b04d50353537d74d6886"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5f4763b83c11c16e5f4c15601ba6dfa849e20723b46aa2617cb4bffe8768479f"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-win_amd64.whl", hash = "sha256:3a4c85ef66c134161987c17b147d6bffdca4566f9a4c1d81a0a01cdf08414ea5"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:17cd28e906c18af486a499422740298c52d7c6795344ea5002a7720b4eadf16d"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-macosx_12_0_x86_64.whl", hash = "sha256:76e823d0e86b4fb5e1cf4a58d293036e678b5a4b03539be933d3b31f9406859f"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:a62e1899e3078bf65943078b3ad2a6ddcacf2373bc06379aac61b1e548a75814"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:df088e8f640c9fae3b1f495b3c64755c4e719091caf250f3a74d095ddf3c836d"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:46718a220d64677c93bc243af1d44b55998255427588e400677d7192671845c7"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a09f3876e87f48bc2f13583ab551f0379e5dfb83210391e68ace404181a20690"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-win_amd64.whl", hash = "sha256:527e8d899f14bd15b740cd5a54ad56b7f98044955373a17179d5956ddb93d9ce"},
|
||||
{file = "pyarrow-23.0.1.tar.gz", hash = "sha256:b8c5873e33440b2bc2f4a79d2b47017a89c5a24116c055625e6f2ee50523f019"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyasn1"
|
||||
version = "0.6.2"
|
||||
@@ -8969,4 +8882,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "86dab25684dd46e635a33bd33281a926e5626a874ecc048c34389fecf34a87d8"
|
||||
content-hash = "4e4365721cd3b68c58c237353b74adae1c64233fd4446904c335f23eb866fdca"
|
||||
|
||||
@@ -92,8 +92,6 @@ gravitas-md2gdocs = "^0.1.0"
|
||||
posthog = "^7.6.0"
|
||||
fpdf2 = "^2.8.6"
|
||||
langsmith = "^0.7.7"
|
||||
openpyxl = "^3.1.5"
|
||||
pyarrow = "^23.0.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
aiohappyeyeballs = "^2.6.1"
|
||||
|
||||
@@ -18,11 +18,8 @@ test.beforeEach(async ({ page }) => {
|
||||
await page.goto("/build");
|
||||
await buildPage.closeTutorial();
|
||||
|
||||
const [dictionaryBlock] = await buildPage.getFilteredBlocksFromAPI(
|
||||
(block) => block.name === "AddToDictionaryBlock",
|
||||
);
|
||||
|
||||
await buildPage.addBlock(dictionaryBlock);
|
||||
await buildPage.addBlockByClick("Add to Dictionary");
|
||||
await buildPage.waitForNodeOnCanvas(1);
|
||||
|
||||
await buildPage.saveAgent("Test Agent", "Test Description");
|
||||
await test
|
||||
|
||||
@@ -1,363 +1,134 @@
|
||||
// TODO: These tests were written for the old (legacy) builder.
|
||||
// They need to be updated to work with the new flow editor.
|
||||
|
||||
// Note: all the comments with //(number)! are for the docs
|
||||
//ignore them when reading the code, but if you change something,
|
||||
//make sure to update the docs! Your autoformmater will break this page,
|
||||
// so don't run it on this file.
|
||||
// --8<-- [start:BuildPageExample]
|
||||
|
||||
import test from "@playwright/test";
|
||||
import test, { expect } from "@playwright/test";
|
||||
import { BuildPage } from "./pages/build.page";
|
||||
import { LoginPage } from "./pages/login.page";
|
||||
import { hasUrl } from "./utils/assertion";
|
||||
import { getTestUser } from "./utils/auth";
|
||||
|
||||
// Reason Ignore: admonishment is in the wrong place visually with correct prettier rules
|
||||
// prettier-ignore
|
||||
test.describe.skip("Build", () => { //(1)!
|
||||
let buildPage: BuildPage; //(2)!
|
||||
test.describe("Builder", () => {
|
||||
let buildPage: BuildPage;
|
||||
|
||||
// Reason Ignore: admonishment is in the wrong place visually with correct prettier rules
|
||||
// prettier-ignore
|
||||
test.beforeEach(async ({ page }) => { //(3)! ts-ignore
|
||||
test.setTimeout(25000);
|
||||
test.beforeEach(async ({ page }) => {
|
||||
test.setTimeout(60000);
|
||||
const loginPage = new LoginPage(page);
|
||||
const testUser = await getTestUser();
|
||||
|
||||
buildPage = new BuildPage(page);
|
||||
|
||||
// Start each test with login using worker auth
|
||||
await page.goto("/login"); //(4)!
|
||||
await page.goto("/login");
|
||||
await loginPage.login(testUser.email, testUser.password);
|
||||
await hasUrl(page, "/marketplace"); //(5)!
|
||||
await buildPage.navbar.clickBuildLink();
|
||||
await hasUrl(page, "/build");
|
||||
await hasUrl(page, "/marketplace");
|
||||
|
||||
await page.goto("/build");
|
||||
await page.waitForLoadState("domcontentloaded");
|
||||
await buildPage.closeTutorial();
|
||||
});
|
||||
|
||||
// Helper function to add blocks starting with a specific letter, split into parts for parallelization
|
||||
async function addBlocksStartingWithSplit(letter: string, part: number, totalParts: number): Promise<void> {
|
||||
const blockIdsToSkip = await buildPage.getBlocksToSkip();
|
||||
const blockTypesToSkip = ["Input", "Output", "Agent", "AI"];
|
||||
const targetLetter = letter.toLowerCase();
|
||||
|
||||
const allBlocks = await buildPage.getFilteredBlocksFromAPI(block =>
|
||||
block.name[0].toLowerCase() === targetLetter &&
|
||||
!blockIdsToSkip.includes(block.id) &&
|
||||
!blockTypesToSkip.includes(block.type)
|
||||
);
|
||||
// --- Core tests ---
|
||||
|
||||
const blocksToAdd = allBlocks.filter((_, index) =>
|
||||
index % totalParts === (part - 1)
|
||||
);
|
||||
|
||||
console.log(`Adding ${blocksToAdd.length} blocks starting with "${letter}" (part ${part}/${totalParts})`);
|
||||
|
||||
for (const block of blocksToAdd) {
|
||||
await buildPage.addBlock(block);
|
||||
}
|
||||
}
|
||||
|
||||
// Reason Ignore: admonishment is in the wrong place visually with correct prettier rules
|
||||
// prettier-ignore
|
||||
test("user can add a block", async ({ page: _page }) => { //(6)!
|
||||
await buildPage.openBlocksPanel(); //(10)!
|
||||
const blocks = await buildPage.getFilteredBlocksFromAPI(block => block.name[0].toLowerCase() === "a");
|
||||
const block = blocks.at(-1);
|
||||
if (!block) throw new Error("No block found");
|
||||
|
||||
await buildPage.addBlock(block); //(11)!
|
||||
await buildPage.closeBlocksPanel(); //(12)!
|
||||
await buildPage.hasBlock(block); //(13)!
|
||||
});
|
||||
// --8<-- [end:BuildPageExample]
|
||||
|
||||
test("user can add blocks starting with a (part 1)", async () => {
|
||||
await addBlocksStartingWithSplit("a", 1, 2);
|
||||
test("build page loads successfully", async () => {
|
||||
await expect(buildPage.isLoaded()).resolves.toBeTruthy();
|
||||
await expect(
|
||||
buildPage.getPlaywrightPage().getByTestId("blocks-control-blocks-button"),
|
||||
).toBeVisible();
|
||||
await expect(
|
||||
buildPage.getPlaywrightPage().getByTestId("save-control-save-button"),
|
||||
).toBeVisible();
|
||||
});
|
||||
|
||||
test("user can add blocks starting with a (part 2)", async () => {
|
||||
await addBlocksStartingWithSplit("a", 2, 2);
|
||||
test("user can add a block via block menu", async () => {
|
||||
const initialCount = await buildPage.getNodeCount();
|
||||
await buildPage.addBlockByClick("Store Value");
|
||||
await buildPage.waitForNodeOnCanvas(initialCount + 1);
|
||||
expect(await buildPage.getNodeCount()).toBe(initialCount + 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with b", async () => {
|
||||
await addBlocksStartingWithSplit("b", 1, 1);
|
||||
test("user can add multiple blocks", async () => {
|
||||
await buildPage.addBlockByClick("Store Value");
|
||||
await buildPage.waitForNodeOnCanvas(1);
|
||||
|
||||
await buildPage.addBlockByClick("Store Value");
|
||||
await buildPage.waitForNodeOnCanvas(2);
|
||||
|
||||
expect(await buildPage.getNodeCount()).toBe(2);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with c", async () => {
|
||||
await addBlocksStartingWithSplit("c", 1, 1);
|
||||
test("user can remove a block", async () => {
|
||||
await buildPage.addBlockByClick("Store Value");
|
||||
await buildPage.waitForNodeOnCanvas(1);
|
||||
|
||||
// Deselect, then re-select the node and delete
|
||||
await buildPage.clickCanvas();
|
||||
await buildPage.selectNode(0);
|
||||
await buildPage.deleteSelectedNodes();
|
||||
|
||||
await expect(buildPage.getNodeLocator()).toHaveCount(0, { timeout: 5000 });
|
||||
});
|
||||
|
||||
test("user can add blocks starting with d", async () => {
|
||||
await addBlocksStartingWithSplit("d", 1, 1);
|
||||
test("user can save an agent", async ({ page }) => {
|
||||
await buildPage.addBlockByClick("Store Value");
|
||||
await buildPage.waitForNodeOnCanvas(1);
|
||||
|
||||
await buildPage.saveAgent("E2E Test Agent", "Created by e2e test");
|
||||
await buildPage.waitForSaveComplete();
|
||||
|
||||
expect(page.url()).toContain("flowID=");
|
||||
});
|
||||
|
||||
test("user can add blocks starting with e", async () => {
|
||||
test.setTimeout(60000); // Increase timeout for many Exa blocks
|
||||
await addBlocksStartingWithSplit("e", 1, 2);
|
||||
});
|
||||
test("user can save and run button becomes enabled", async () => {
|
||||
await buildPage.addBlockByClick("Store Value");
|
||||
await buildPage.waitForNodeOnCanvas(1);
|
||||
|
||||
test("user can add blocks starting with e pt 2", async () => {
|
||||
test.setTimeout(60000); // Increase timeout for many Exa blocks
|
||||
await addBlocksStartingWithSplit("e", 2, 2);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with f", async () => {
|
||||
await addBlocksStartingWithSplit("f", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with g (part 1)", async () => {
|
||||
await addBlocksStartingWithSplit("g", 1, 3);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with g (part 2)", async () => {
|
||||
await addBlocksStartingWithSplit("g", 2, 3);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with g (part 3)", async () => {
|
||||
await addBlocksStartingWithSplit("g", 3, 3);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with h", async () => {
|
||||
await addBlocksStartingWithSplit("h", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with i", async () => {
|
||||
await addBlocksStartingWithSplit("i", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with j", async () => {
|
||||
await addBlocksStartingWithSplit("j", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with k", async () => {
|
||||
await addBlocksStartingWithSplit("k", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with l", async () => {
|
||||
await addBlocksStartingWithSplit("l", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with m", async () => {
|
||||
await addBlocksStartingWithSplit("m", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with n", async () => {
|
||||
await addBlocksStartingWithSplit("n", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with o", async () => {
|
||||
await addBlocksStartingWithSplit("o", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with p", async () => {
|
||||
await addBlocksStartingWithSplit("p", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with q", async () => {
|
||||
await addBlocksStartingWithSplit("q", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with r", async () => {
|
||||
await addBlocksStartingWithSplit("r", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with s (part 1)", async () => {
|
||||
await addBlocksStartingWithSplit("s", 1, 3);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with s (part 2)", async () => {
|
||||
await addBlocksStartingWithSplit("s", 2, 3);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with s (part 3)", async () => {
|
||||
await addBlocksStartingWithSplit("s", 3, 3);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with t", async () => {
|
||||
await addBlocksStartingWithSplit("t", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with u", async () => {
|
||||
await addBlocksStartingWithSplit("u", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with v", async () => {
|
||||
await addBlocksStartingWithSplit("v", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with w", async () => {
|
||||
await addBlocksStartingWithSplit("w", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with x", async () => {
|
||||
await addBlocksStartingWithSplit("x", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with y", async () => {
|
||||
await addBlocksStartingWithSplit("y", 1, 1);
|
||||
});
|
||||
|
||||
test("user can add blocks starting with z", async () => {
|
||||
await addBlocksStartingWithSplit("z", 1, 1);
|
||||
});
|
||||
|
||||
test("build navigation is accessible from navbar", async ({ page }) => {
|
||||
// Navigate somewhere else first
|
||||
await page.goto("/marketplace"); //(4)!
|
||||
|
||||
// Check that navigation to the Builder is available on the page
|
||||
await buildPage.navbar.clickBuildLink();
|
||||
|
||||
await hasUrl(page, "/build");
|
||||
await test.expect(buildPage.isLoaded()).resolves.toBeTruthy();
|
||||
});
|
||||
|
||||
test("user can add two blocks and connect them", async ({ page }) => {
|
||||
await buildPage.openBlocksPanel();
|
||||
|
||||
// Define the blocks to add
|
||||
const block1 = {
|
||||
id: "1ff065e9-88e8-4358-9d82-8dc91f622ba9",
|
||||
name: "Store Value 1",
|
||||
description: "Store Value Block 1",
|
||||
type: "Standard",
|
||||
};
|
||||
const block2 = {
|
||||
id: "1ff065e9-88e8-4358-9d82-8dc91f622ba9",
|
||||
name: "Store Value 2",
|
||||
description: "Store Value Block 2",
|
||||
type: "Standard",
|
||||
};
|
||||
|
||||
// Add the blocks
|
||||
await buildPage.addBlock(block1);
|
||||
await buildPage.addBlock(block2);
|
||||
await buildPage.closeBlocksPanel();
|
||||
|
||||
// Connect the blocks
|
||||
await buildPage.connectBlockOutputToBlockInputViaDataId(
|
||||
"1-1-output-source",
|
||||
"1-2-input-target",
|
||||
);
|
||||
|
||||
// Fill in the input for the first block
|
||||
await buildPage.fillBlockInputByPlaceholder(
|
||||
block1.id,
|
||||
"Enter input",
|
||||
"Test Value",
|
||||
"1",
|
||||
);
|
||||
|
||||
// Save the agent and wait for the URL to update
|
||||
await buildPage.saveAgent(
|
||||
"Connected Blocks Test",
|
||||
"Testing block connections",
|
||||
);
|
||||
await test.expect(page).toHaveURL(({ searchParams }) => !!searchParams.get("flowID"));
|
||||
|
||||
// Wait for the save button to be enabled again
|
||||
await buildPage.saveAgent("Runnable Agent", "Test run button");
|
||||
await buildPage.waitForSaveComplete();
|
||||
await buildPage.waitForSaveButton();
|
||||
|
||||
// Ensure the run button is enabled
|
||||
await test.expect(buildPage.isRunButtonEnabled()).resolves.toBeTruthy();
|
||||
await expect(buildPage.isRunButtonEnabled()).resolves.toBeTruthy();
|
||||
});
|
||||
|
||||
test.skip("user can build an agent with inputs and output blocks", async ({ page }, testInfo) => {
|
||||
test.setTimeout(testInfo.timeout * 10);
|
||||
// --- Copy / Paste test ---
|
||||
|
||||
// prep
|
||||
await buildPage.openBlocksPanel();
|
||||
test("user can copy and paste a node", async ({ context }) => {
|
||||
await context.grantPermissions(["clipboard-read", "clipboard-write"]);
|
||||
|
||||
// Get input block from Input category
|
||||
const inputBlocks = await buildPage.getBlocksForCategory("Input");
|
||||
const inputBlock = inputBlocks.find((b) => b.name === "Agent Input");
|
||||
if (!inputBlock) throw new Error("Input block not found");
|
||||
await buildPage.addBlock(inputBlock);
|
||||
await buildPage.addBlockByClick("Store Value");
|
||||
await buildPage.waitForNodeOnCanvas(1);
|
||||
|
||||
// Get output block from Output category
|
||||
const outputBlocks = await buildPage.getBlocksForCategory("Output");
|
||||
const outputBlock = outputBlocks.find((b) => b.name === "Agent Output");
|
||||
if (!outputBlock) throw new Error("Output block not found");
|
||||
await buildPage.addBlock(outputBlock);
|
||||
await buildPage.selectNode(0);
|
||||
await buildPage.copyViaKeyboard();
|
||||
await buildPage.pasteViaKeyboard();
|
||||
|
||||
// Get calculator block from Logic category
|
||||
const logicBlocks = await buildPage.getBlocksForCategory("Logic");
|
||||
const calculatorBlock = logicBlocks.find((b) => b.name === "Calculator");
|
||||
if (!calculatorBlock) throw new Error("Calculator block not found");
|
||||
await buildPage.addBlock(calculatorBlock);
|
||||
await buildPage.waitForNodeOnCanvas(2);
|
||||
expect(await buildPage.getNodeCount()).toBe(2);
|
||||
});
|
||||
|
||||
await buildPage.closeBlocksPanel();
|
||||
// --- Run agent test ---
|
||||
|
||||
// Wait for blocks to be fully loaded
|
||||
await page.waitForTimeout(1000);
|
||||
test("user can run an agent from the builder", async () => {
|
||||
await buildPage.addBlockByClick("Store Value");
|
||||
await buildPage.waitForNodeOnCanvas(1);
|
||||
|
||||
// Wait for blocks to be ready for connections
|
||||
await page.waitForTimeout(1000);
|
||||
// Save the agent (required before running)
|
||||
await buildPage.saveAgent("Run Test Agent", "Testing run from builder");
|
||||
await buildPage.waitForSaveComplete();
|
||||
await buildPage.waitForSaveButton();
|
||||
|
||||
await buildPage.connectBlockOutputToBlockInputViaName(
|
||||
inputBlock.id,
|
||||
"Result",
|
||||
calculatorBlock.id,
|
||||
"A",
|
||||
);
|
||||
await buildPage.connectBlockOutputToBlockInputViaName(
|
||||
inputBlock.id,
|
||||
"Result",
|
||||
calculatorBlock.id,
|
||||
"B",
|
||||
);
|
||||
await buildPage.connectBlockOutputToBlockInputViaName(
|
||||
calculatorBlock.id,
|
||||
"Result",
|
||||
outputBlock.id,
|
||||
"Value",
|
||||
);
|
||||
// Click run button
|
||||
await buildPage.clickRunButton();
|
||||
|
||||
// Wait for connections to stabilize
|
||||
await page.waitForTimeout(1000);
|
||||
// Either the run dialog appears or the agent starts running directly
|
||||
const runDialogOrRunning = await Promise.race([
|
||||
buildPage
|
||||
.getPlaywrightPage()
|
||||
.locator('[data-id="run-input-dialog-content"]')
|
||||
.waitFor({ state: "visible", timeout: 10000 })
|
||||
.then(() => "dialog"),
|
||||
buildPage
|
||||
.getPlaywrightPage()
|
||||
.locator('[data-id="stop-graph-button"]')
|
||||
.waitFor({ state: "visible", timeout: 10000 })
|
||||
.then(() => "running"),
|
||||
]).catch(() => "timeout");
|
||||
|
||||
await buildPage.fillBlockInputByPlaceholder(
|
||||
inputBlock.id,
|
||||
"Enter Name",
|
||||
"Value",
|
||||
);
|
||||
await buildPage.fillBlockInputByPlaceholder(
|
||||
outputBlock.id,
|
||||
"Enter Name",
|
||||
"Doubled",
|
||||
);
|
||||
|
||||
// Wait before changing dropdown
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
await buildPage.selectBlockInputValue(
|
||||
calculatorBlock.id,
|
||||
"Operation",
|
||||
"Add",
|
||||
);
|
||||
|
||||
// Wait before saving
|
||||
await page.waitForTimeout(1000);
|
||||
|
||||
await buildPage.saveAgent(
|
||||
"Input and Output Blocks Test",
|
||||
"Testing input and output blocks",
|
||||
);
|
||||
await test.expect(page).toHaveURL(({ searchParams }) => !!searchParams.get("flowID"));
|
||||
|
||||
// Wait for save to complete
|
||||
await page.waitForTimeout(1000);
|
||||
|
||||
// await buildPage.runAgent();
|
||||
// await buildPage.fillRunDialog({
|
||||
// Value: "10",
|
||||
// });
|
||||
// await buildPage.clickRunDialogRunButton();
|
||||
// await buildPage.waitForCompletionBadge();
|
||||
// await test
|
||||
// .expect(buildPage.isCompletionBadgeVisible())
|
||||
// .resolves.toBeTruthy();
|
||||
expect(["dialog", "running"]).toContain(runDialogOrRunning);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,44 +1,47 @@
|
||||
import { Locator, Page } from "@playwright/test";
|
||||
import { Block as APIBlock } from "../../lib/autogpt-server-api/types";
|
||||
import { beautifyString } from "../../lib/utils";
|
||||
import { expect, Locator, Page } from "@playwright/test";
|
||||
import { BasePage } from "./base.page";
|
||||
|
||||
export interface Block {
|
||||
id: string;
|
||||
name: string;
|
||||
description: string;
|
||||
type: string;
|
||||
}
|
||||
|
||||
export class BuildPage extends BasePage {
|
||||
private cachedBlocks: Record<string, Block> = {};
|
||||
|
||||
constructor(page: Page) {
|
||||
super(page);
|
||||
}
|
||||
|
||||
private getDisplayName(blockName: string): string {
|
||||
return beautifyString(blockName).replace(/ Block$/, "");
|
||||
// --- Navigation ---
|
||||
|
||||
async goto(): Promise<void> {
|
||||
await this.page.goto("/build");
|
||||
await this.page.waitForLoadState("domcontentloaded");
|
||||
}
|
||||
|
||||
async isLoaded(): Promise<boolean> {
|
||||
try {
|
||||
await this.page.waitForLoadState("domcontentloaded", { timeout: 10_000 });
|
||||
await this.page
|
||||
.locator(".react-flow")
|
||||
.waitFor({ state: "visible", timeout: 10_000 });
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
async closeTutorial(): Promise<void> {
|
||||
console.log(`closing tutorial`);
|
||||
try {
|
||||
await this.page
|
||||
.getByRole("button", { name: "Skip Tutorial", exact: true })
|
||||
.click({ timeout: 3000 });
|
||||
} catch (_error) {
|
||||
console.info("Tutorial not shown or already dismissed");
|
||||
} catch {
|
||||
// Tutorial not shown or already dismissed
|
||||
}
|
||||
}
|
||||
|
||||
// --- Block Menu ---
|
||||
|
||||
async openBlocksPanel(): Promise<void> {
|
||||
const popoverContent = this.page.locator(
|
||||
'[data-id="blocks-control-popover-content"]',
|
||||
);
|
||||
const isPanelOpen = await popoverContent.isVisible();
|
||||
|
||||
if (!isPanelOpen) {
|
||||
if (!(await popoverContent.isVisible())) {
|
||||
await this.page.getByTestId("blocks-control-blocks-button").click();
|
||||
await popoverContent.waitFor({ state: "visible", timeout: 5000 });
|
||||
}
|
||||
@@ -50,501 +53,258 @@ export class BuildPage extends BasePage {
|
||||
);
|
||||
if (await popoverContent.isVisible()) {
|
||||
await this.page.getByTestId("blocks-control-blocks-button").click();
|
||||
await popoverContent.waitFor({ state: "hidden", timeout: 5000 });
|
||||
}
|
||||
}
|
||||
|
||||
async searchBlock(searchTerm: string): Promise<void> {
|
||||
const searchInput = this.page.locator(
|
||||
'[data-id="blocks-control-search-bar"] input[type="text"]',
|
||||
);
|
||||
await searchInput.clear();
|
||||
await searchInput.fill(searchTerm);
|
||||
await this.page.waitForTimeout(300);
|
||||
}
|
||||
|
||||
private getBlockCardByName(name: string): Locator {
|
||||
const escapedName = name.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
|
||||
const exactName = new RegExp(`^\\s*${escapedName}\\s*$`, "i");
|
||||
return this.page
|
||||
.locator('[data-id^="block-card-"]')
|
||||
.filter({ has: this.page.locator("span", { hasText: exactName }) })
|
||||
.first();
|
||||
}
|
||||
|
||||
async addBlockByClick(searchTerm: string): Promise<void> {
|
||||
await this.openBlocksPanel();
|
||||
await this.searchBlock(searchTerm);
|
||||
|
||||
// Wait for any search results to appear
|
||||
const anyCard = this.page.locator('[data-id^="block-card-"]').first();
|
||||
await anyCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
// Click the card matching the search term name
|
||||
const blockCard = this.getBlockCardByName(searchTerm);
|
||||
await blockCard.waitFor({ state: "visible", timeout: 5000 });
|
||||
await blockCard.click();
|
||||
|
||||
// Close the panel so it doesn't overlay the canvas
|
||||
await this.closeBlocksPanel();
|
||||
}
|
||||
|
||||
async dragBlockToCanvas(searchTerm: string): Promise<void> {
|
||||
await this.openBlocksPanel();
|
||||
await this.searchBlock(searchTerm);
|
||||
|
||||
const anyCard = this.page.locator('[data-id^="block-card-"]').first();
|
||||
await anyCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const blockCard = this.getBlockCardByName(searchTerm);
|
||||
await blockCard.waitFor({ state: "visible", timeout: 5000 });
|
||||
|
||||
const canvas = this.page.locator(".react-flow__pane").first();
|
||||
await blockCard.dragTo(canvas);
|
||||
}
|
||||
|
||||
// --- Nodes on Canvas ---
|
||||
|
||||
getNodeLocator(index?: number): Locator {
|
||||
const locator = this.page.locator('[data-id^="custom-node-"]');
|
||||
return index !== undefined ? locator.nth(index) : locator;
|
||||
}
|
||||
|
||||
async getNodeCount(): Promise<number> {
|
||||
return await this.getNodeLocator().count();
|
||||
}
|
||||
|
||||
async waitForNodeOnCanvas(expectedCount?: number): Promise<void> {
|
||||
if (expectedCount !== undefined) {
|
||||
await expect(this.getNodeLocator()).toHaveCount(expectedCount, {
|
||||
timeout: 10000,
|
||||
});
|
||||
} else {
|
||||
await this.getNodeLocator()
|
||||
.first()
|
||||
.waitFor({ state: "visible", timeout: 10000 });
|
||||
}
|
||||
}
|
||||
|
||||
async selectNode(index: number = 0): Promise<void> {
|
||||
const node = this.getNodeLocator(index);
|
||||
await node.click();
|
||||
}
|
||||
|
||||
async selectAllNodes(): Promise<void> {
|
||||
await this.page.locator(".react-flow__pane").first().click();
|
||||
const isMac = process.platform === "darwin";
|
||||
await this.page.keyboard.press(isMac ? "Meta+a" : "Control+a");
|
||||
}
|
||||
|
||||
async deleteSelectedNodes(): Promise<void> {
|
||||
await this.page.keyboard.press("Backspace");
|
||||
}
|
||||
|
||||
// --- Connections (Edges) ---
|
||||
|
||||
async connectNodes(
|
||||
sourceNodeIndex: number,
|
||||
targetNodeIndex: number,
|
||||
): Promise<void> {
|
||||
// Get the node wrapper elements to scope handle search
|
||||
const sourceNode = this.getNodeLocator(sourceNodeIndex);
|
||||
const targetNode = this.getNodeLocator(targetNodeIndex);
|
||||
|
||||
// ReactFlow renders Handle components as .react-flow__handle elements
|
||||
// Output handles have class .react-flow__handle-right (Position.Right)
|
||||
// Input handles have class .react-flow__handle-left (Position.Left)
|
||||
const sourceHandle = sourceNode
|
||||
.locator(".react-flow__handle-right")
|
||||
.first();
|
||||
const targetHandle = targetNode.locator(".react-flow__handle-left").first();
|
||||
|
||||
// Get precise center coordinates using evaluate to avoid CSS transform issues
|
||||
const getHandleCenter = async (locator: Locator) => {
|
||||
const el = await locator.elementHandle();
|
||||
if (!el) throw new Error("Handle element not found");
|
||||
const rect = await el.evaluate((node) => {
|
||||
const r = node.getBoundingClientRect();
|
||||
return { x: r.x + r.width / 2, y: r.y + r.height / 2 };
|
||||
});
|
||||
return rect;
|
||||
};
|
||||
|
||||
const source = await getHandleCenter(sourceHandle);
|
||||
const target = await getHandleCenter(targetHandle);
|
||||
|
||||
// ReactFlow requires a proper drag sequence with intermediate moves
|
||||
await this.page.mouse.move(source.x, source.y);
|
||||
await this.page.mouse.down();
|
||||
// Move in steps to trigger ReactFlow's connection detection
|
||||
const steps = 20;
|
||||
for (let i = 1; i <= steps; i++) {
|
||||
const ratio = i / steps;
|
||||
await this.page.mouse.move(
|
||||
source.x + (target.x - source.x) * ratio,
|
||||
source.y + (target.y - source.y) * ratio,
|
||||
);
|
||||
}
|
||||
await this.page.mouse.up();
|
||||
}
|
||||
|
||||
async getEdgeCount(): Promise<number> {
|
||||
return await this.page.locator(".react-flow__edge").count();
|
||||
}
|
||||
|
||||
// --- Save ---
|
||||
|
||||
async saveAgent(
|
||||
name: string = "Test Agent",
|
||||
description: string = "",
|
||||
): Promise<void> {
|
||||
console.log(`Saving agent '${name}' with description '${description}'`);
|
||||
await this.page.getByTestId("save-control-save-button").click();
|
||||
await this.page.getByTestId("save-control-name-input").fill(name);
|
||||
await this.page
|
||||
.getByTestId("save-control-description-input")
|
||||
.fill(description);
|
||||
|
||||
const nameInput = this.page.getByTestId("save-control-name-input");
|
||||
await nameInput.waitFor({ state: "visible", timeout: 5000 });
|
||||
await nameInput.fill(name);
|
||||
|
||||
if (description) {
|
||||
await this.page
|
||||
.getByTestId("save-control-description-input")
|
||||
.fill(description);
|
||||
}
|
||||
|
||||
await this.page.getByTestId("save-control-save-agent-button").click();
|
||||
}
|
||||
|
||||
async getBlocksFromAPI(): Promise<Block[]> {
|
||||
if (Object.keys(this.cachedBlocks).length > 0) {
|
||||
return Object.values(this.cachedBlocks);
|
||||
}
|
||||
async waitForSaveComplete(): Promise<void> {
|
||||
await expect(this.page).toHaveURL(/flowID=/, { timeout: 15000 });
|
||||
}
|
||||
|
||||
console.log(`Getting blocks from API request`);
|
||||
|
||||
// Make direct API request using the page's request context
|
||||
const response = await this.page.request.get(
|
||||
"http://localhost:3000/api/proxy/api/blocks",
|
||||
async waitForSaveButton(): Promise<void> {
|
||||
await this.page.waitForSelector(
|
||||
'[data-testid="save-control-save-button"]:not([disabled])',
|
||||
{ timeout: 10000 },
|
||||
);
|
||||
const apiBlocks: APIBlock[] = await response.json();
|
||||
|
||||
console.log(`Found ${apiBlocks.length} blocks from API`);
|
||||
|
||||
// Convert API blocks to test Block format
|
||||
const blocks = apiBlocks.map((block) => ({
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
description: block.description,
|
||||
type: block.uiType,
|
||||
}));
|
||||
|
||||
this.cachedBlocks = blocks.reduce(
|
||||
(acc, block) => {
|
||||
acc[block.id] = block;
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, Block>,
|
||||
);
|
||||
return blocks;
|
||||
}
|
||||
|
||||
async getFilteredBlocksFromAPI(
|
||||
filterFn: (block: Block) => boolean,
|
||||
): Promise<Block[]> {
|
||||
console.log(`Getting filtered blocks from API`);
|
||||
const blocks = await this.getBlocksFromAPI();
|
||||
return blocks.filter(filterFn);
|
||||
}
|
||||
|
||||
async addBlock(block: Block): Promise<void> {
|
||||
console.log(`Adding block ${block.name} (${block.id}) to agent`);
|
||||
|
||||
await this.openBlocksPanel();
|
||||
|
||||
const searchInput = this.page.locator(
|
||||
'[data-id="blocks-control-search-bar"] input[type="text"]',
|
||||
);
|
||||
|
||||
const displayName = this.getDisplayName(block.name);
|
||||
await searchInput.clear();
|
||||
await searchInput.fill(displayName);
|
||||
|
||||
const blockCardId = block.id.replace(/[^a-zA-Z0-9]/g, "");
|
||||
const blockCard = this.page.locator(
|
||||
`[data-id="block-card-${blockCardId}"]`,
|
||||
);
|
||||
|
||||
await blockCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
await blockCard.click();
|
||||
}
|
||||
|
||||
async hasBlock(_block: Block) {
|
||||
// In the new flow editor, verify a node exists on the canvas
|
||||
const node = this.page.locator('[data-id^="custom-node-"]').first();
|
||||
await node.isVisible();
|
||||
}
|
||||
|
||||
async getBlockInputs(blockId: string): Promise<string[]> {
|
||||
console.log(`Getting block ${blockId} inputs`);
|
||||
try {
|
||||
const node = this.page.locator(`[data-blockid="${blockId}"]`).first();
|
||||
const inputsData = await node.getAttribute("data-inputs");
|
||||
return inputsData ? JSON.parse(inputsData) : [];
|
||||
} catch (error) {
|
||||
console.error("Error getting block inputs:", error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
async selectBlockCategory(category: string): Promise<void> {
|
||||
console.log(`Selecting block category: ${category}`);
|
||||
await this.page.getByText(category, { exact: true }).click();
|
||||
// Wait for the blocks to load after category selection
|
||||
await this.page.waitForTimeout(3000);
|
||||
}
|
||||
|
||||
async getBlocksForCategory(category: string): Promise<Block[]> {
|
||||
console.log(`Getting blocks for category: ${category}`);
|
||||
|
||||
// Clear any existing search to ensure we see all blocks in the category
|
||||
const searchInput = this.page.locator(
|
||||
'[data-id="blocks-control-search-bar"] input[type="text"]',
|
||||
);
|
||||
await searchInput.clear();
|
||||
|
||||
// Wait for search to clear
|
||||
await this.page.waitForTimeout(300);
|
||||
|
||||
// Select the category first
|
||||
await this.selectBlockCategory(category);
|
||||
|
||||
try {
|
||||
const blockFinder = this.page.locator('[data-id^="block-card-"]');
|
||||
await blockFinder.first().waitFor();
|
||||
const blocks = await blockFinder.all();
|
||||
|
||||
console.log(`found ${blocks.length} blocks in category ${category}`);
|
||||
|
||||
const results = await Promise.all(
|
||||
blocks.map(async (block) => {
|
||||
try {
|
||||
const fullId = (await block.getAttribute("data-id")) || "";
|
||||
const id = fullId.replace("block-card-", "");
|
||||
const nameElement = block.locator('[data-testid^="block-name-"]');
|
||||
const descriptionElement = block.locator(
|
||||
'[data-testid^="block-description-"]',
|
||||
);
|
||||
|
||||
const name = (await nameElement.textContent()) || "";
|
||||
const description = (await descriptionElement.textContent()) || "";
|
||||
const type = (await nameElement.getAttribute("data-type")) || "";
|
||||
|
||||
return {
|
||||
id,
|
||||
name: name.trim(),
|
||||
type: type.trim(),
|
||||
description: description.trim(),
|
||||
};
|
||||
} catch (elementError) {
|
||||
console.error("Error processing block:", elementError);
|
||||
return null;
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
// Filter out any null results from errors
|
||||
return results.filter((block): block is Block => block !== null);
|
||||
} catch (error) {
|
||||
console.error(`Error getting blocks for category ${category}:`, error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
async _buildBlockSelector(blockId: string, dataId?: string): Promise<string> {
|
||||
const selector = dataId
|
||||
? `[data-id="${dataId}"] [data-blockid="${blockId}"]`
|
||||
: `[data-blockid="${blockId}"]`;
|
||||
return selector;
|
||||
}
|
||||
|
||||
private async moveBlockToViewportPosition(
|
||||
blockSelector: string,
|
||||
options: { xRatio?: number; yRatio?: number } = {},
|
||||
): Promise<void> {
|
||||
const { xRatio = 0.5, yRatio = 0.5 } = options;
|
||||
const blockLocator = this.page.locator(blockSelector).first();
|
||||
|
||||
await blockLocator.waitFor({ state: "visible" });
|
||||
|
||||
const boundingBox = await blockLocator.boundingBox();
|
||||
const viewport = this.page.viewportSize();
|
||||
|
||||
if (!boundingBox || !viewport) {
|
||||
return;
|
||||
}
|
||||
|
||||
const currentX = boundingBox.x + boundingBox.width / 2;
|
||||
const currentY = boundingBox.y + boundingBox.height / 2;
|
||||
|
||||
const targetX = viewport.width * xRatio;
|
||||
const targetY = viewport.height * yRatio;
|
||||
|
||||
const distance = Math.hypot(targetX - currentX, targetY - currentY);
|
||||
|
||||
if (distance < 5) {
|
||||
return;
|
||||
}
|
||||
|
||||
await this.page.mouse.move(currentX, currentY);
|
||||
await this.page.mouse.down();
|
||||
await this.page.mouse.move(targetX, targetY, { steps: 15 });
|
||||
await this.page.mouse.up();
|
||||
await this.page.waitForTimeout(200);
|
||||
}
|
||||
|
||||
async getBlockById(blockId: string, dataId?: string): Promise<Locator> {
|
||||
console.log(`getting block ${blockId} with dataId ${dataId}`);
|
||||
return this.page.locator(await this._buildBlockSelector(blockId, dataId));
|
||||
}
|
||||
|
||||
// dataId is optional, if provided, it will start the search with that container, otherwise it will start with the blockId
|
||||
// this is useful if you have multiple blocks with the same id, but different dataIds which you should have when adding a block to the graph.
|
||||
// Do note that once you run an agent, the dataId will change, so you will need to update the tests to use the new dataId or not use the same block in tests that run an agent
|
||||
async fillBlockInputByPlaceholder(
|
||||
blockId: string,
|
||||
placeholder: string,
|
||||
value: string,
|
||||
dataId?: string,
|
||||
): Promise<void> {
|
||||
console.log(
|
||||
`filling block input ${placeholder} with value ${value} of block ${blockId}`,
|
||||
);
|
||||
const block = await this.getBlockById(blockId, dataId);
|
||||
const input = block.getByPlaceholder(placeholder);
|
||||
await input.fill(value);
|
||||
}
|
||||
|
||||
async selectBlockInputValue(
|
||||
blockId: string,
|
||||
inputName: string,
|
||||
value: string,
|
||||
dataId?: string,
|
||||
): Promise<void> {
|
||||
console.log(
|
||||
`selecting value ${value} for input ${inputName} of block ${blockId}`,
|
||||
);
|
||||
// First get the button that opens the dropdown
|
||||
const baseSelector = await this._buildBlockSelector(blockId, dataId);
|
||||
|
||||
// Find the combobox button within the input handle container
|
||||
const comboboxSelector = `${baseSelector} [data-id="input-handle-${inputName.toLowerCase()}"] button[role="combobox"]`;
|
||||
|
||||
try {
|
||||
// Click the combobox to open it
|
||||
await this.page.click(comboboxSelector);
|
||||
|
||||
// Wait a moment for the dropdown to open
|
||||
await this.page.waitForTimeout(100);
|
||||
|
||||
// Select the option from the dropdown
|
||||
// The actual selector for the option might need adjustment based on the dropdown structure
|
||||
await this.page.getByRole("option", { name: value }).click();
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error selecting value "${value}" for input "${inputName}":`,
|
||||
error,
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async fillBlockInputByLabel(
|
||||
blockId: string,
|
||||
label: string,
|
||||
value: string,
|
||||
): Promise<void> {
|
||||
console.log(`filling block input ${label} with value ${value}`);
|
||||
const block = await this.getBlockById(blockId);
|
||||
const input = block.getByLabel(label);
|
||||
await input.fill(value);
|
||||
}
|
||||
|
||||
async connectBlockOutputToBlockInputViaDataId(
|
||||
blockOutputId: string,
|
||||
blockInputId: string,
|
||||
): Promise<void> {
|
||||
console.log(
|
||||
`connecting block output ${blockOutputId} to block input ${blockInputId}`,
|
||||
);
|
||||
try {
|
||||
// Locate the output element
|
||||
const outputElement = this.page.locator(`[data-id="${blockOutputId}"]`);
|
||||
// Locate the input element
|
||||
const inputElement = this.page.locator(`[data-id="${blockInputId}"]`);
|
||||
|
||||
await outputElement.dragTo(inputElement);
|
||||
} catch (error) {
|
||||
console.error("Error connecting block output to input:", error);
|
||||
}
|
||||
}
|
||||
|
||||
async connectBlockOutputToBlockInputViaName(
|
||||
startBlockId: string,
|
||||
startBlockOutputName: string,
|
||||
endBlockId: string,
|
||||
endBlockInputName: string,
|
||||
startDataId?: string,
|
||||
endDataId?: string,
|
||||
): Promise<void> {
|
||||
console.log(
|
||||
`connecting block output ${startBlockOutputName} of block ${startBlockId} to block input ${endBlockInputName} of block ${endBlockId}`,
|
||||
);
|
||||
|
||||
const startBlockBase = await this._buildBlockSelector(
|
||||
startBlockId,
|
||||
startDataId,
|
||||
);
|
||||
|
||||
const endBlockBase = await this._buildBlockSelector(endBlockId, endDataId);
|
||||
|
||||
await this.moveBlockToViewportPosition(startBlockBase, { xRatio: 0.35 });
|
||||
await this.moveBlockToViewportPosition(endBlockBase, { xRatio: 0.65 });
|
||||
|
||||
const startBlockOutputSelector = `${startBlockBase} [data-testid="output-handle-${startBlockOutputName.toLowerCase()}"]`;
|
||||
const endBlockInputSelector = `${endBlockBase} [data-testid="input-handle-${endBlockInputName.toLowerCase()}"]`;
|
||||
|
||||
console.log("Start block selector:", startBlockOutputSelector);
|
||||
console.log("End block selector:", endBlockInputSelector);
|
||||
|
||||
const startElement = this.page.locator(startBlockOutputSelector);
|
||||
const endElement = this.page.locator(endBlockInputSelector);
|
||||
|
||||
await startElement.scrollIntoViewIfNeeded();
|
||||
await this.page.waitForTimeout(200);
|
||||
|
||||
await endElement.scrollIntoViewIfNeeded();
|
||||
await this.page.waitForTimeout(200);
|
||||
|
||||
await startElement.dragTo(endElement);
|
||||
}
|
||||
|
||||
async isLoaded(): Promise<boolean> {
|
||||
console.log(`checking if build page is loaded`);
|
||||
try {
|
||||
await this.page.waitForLoadState("domcontentloaded", { timeout: 10_000 });
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// --- Run ---
|
||||
|
||||
async isRunButtonEnabled(): Promise<boolean> {
|
||||
console.log(`checking if run button is enabled`);
|
||||
const runButton = this.page.locator('[data-id="run-graph-button"]');
|
||||
return await runButton.isEnabled();
|
||||
}
|
||||
|
||||
async runAgent(): Promise<void> {
|
||||
console.log(`clicking run button`);
|
||||
async clickRunButton(): Promise<void> {
|
||||
const runButton = this.page.locator('[data-id="run-graph-button"]');
|
||||
await runButton.click();
|
||||
await this.page.waitForTimeout(1000);
|
||||
await runButton.click();
|
||||
}
|
||||
|
||||
async fillRunDialog(inputs: Record<string, string>): Promise<void> {
|
||||
console.log(`filling run dialog`);
|
||||
for (const [key, value] of Object.entries(inputs)) {
|
||||
await this.page.getByTestId(`agent-input-${key}`).fill(value);
|
||||
// --- Undo / Redo ---
|
||||
|
||||
async isUndoEnabled(): Promise<boolean> {
|
||||
const btn = this.page.locator('[data-id="undo-button"]');
|
||||
return !(await btn.isDisabled());
|
||||
}
|
||||
|
||||
async isRedoEnabled(): Promise<boolean> {
|
||||
const btn = this.page.locator('[data-id="redo-button"]');
|
||||
return !(await btn.isDisabled());
|
||||
}
|
||||
|
||||
async clickUndo(): Promise<void> {
|
||||
await this.page.locator('[data-id="undo-button"]').click();
|
||||
}
|
||||
|
||||
async clickRedo(): Promise<void> {
|
||||
await this.page.locator('[data-id="redo-button"]').click();
|
||||
}
|
||||
|
||||
// --- Copy / Paste ---
|
||||
|
||||
async copyViaKeyboard(): Promise<void> {
|
||||
const isMac = process.platform === "darwin";
|
||||
await this.page.keyboard.press(isMac ? "Meta+c" : "Control+c");
|
||||
}
|
||||
|
||||
async pasteViaKeyboard(): Promise<void> {
|
||||
const isMac = process.platform === "darwin";
|
||||
await this.page.keyboard.press(isMac ? "Meta+v" : "Control+v");
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
async fillBlockInputByPlaceholder(
|
||||
placeholder: string,
|
||||
value: string,
|
||||
nodeIndex: number = 0,
|
||||
): Promise<void> {
|
||||
const node = this.getNodeLocator(nodeIndex);
|
||||
const input = node.getByPlaceholder(placeholder);
|
||||
await input.fill(value);
|
||||
}
|
||||
|
||||
async clickCanvas(): Promise<void> {
|
||||
const pane = this.page.locator(".react-flow__pane").first();
|
||||
const box = await pane.boundingBox();
|
||||
if (box) {
|
||||
// Click in the center of the canvas to avoid sidebar/toolbar overlaps
|
||||
await pane.click({
|
||||
position: { x: box.width / 2, y: box.height / 2 },
|
||||
});
|
||||
} else {
|
||||
await pane.click();
|
||||
}
|
||||
}
|
||||
async clickRunDialogRunButton(): Promise<void> {
|
||||
console.log(`clicking run button`);
|
||||
await this.page.getByTestId("agent-run-button").click();
|
||||
|
||||
getPlaywrightPage(): Page {
|
||||
return this.page;
|
||||
}
|
||||
|
||||
async waitForCompletionBadge(): Promise<void> {
|
||||
console.log(`waiting for completion badge`);
|
||||
await this.page.waitForSelector(
|
||||
'[data-id^="badge-"][data-id$="-COMPLETED"]',
|
||||
);
|
||||
}
|
||||
|
||||
async waitForSaveButton(): Promise<void> {
|
||||
console.log(`waiting for save button`);
|
||||
await this.page.waitForSelector(
|
||||
'[data-testid="save-control-save-button"]:not([disabled])',
|
||||
);
|
||||
}
|
||||
|
||||
async isCompletionBadgeVisible(): Promise<boolean> {
|
||||
console.log(`checking for completion badge`);
|
||||
const completionBadge = this.page
|
||||
.locator('[data-id^="badge-"][data-id$="-COMPLETED"]')
|
||||
.first();
|
||||
return await completionBadge.isVisible();
|
||||
}
|
||||
|
||||
async waitForVersionField(): Promise<void> {
|
||||
console.log(`waiting for version field`);
|
||||
|
||||
// wait for the url to have the flowID
|
||||
await this.page.waitForSelector(
|
||||
'[data-testid="save-control-version-output"]',
|
||||
);
|
||||
}
|
||||
|
||||
async getDictionaryBlockDetails(): Promise<Block> {
|
||||
return {
|
||||
id: "dummy-id-1",
|
||||
name: "Add to Dictionary",
|
||||
description: "Add to Dictionary",
|
||||
type: "Standard",
|
||||
};
|
||||
}
|
||||
|
||||
async getCalculatorBlockDetails(): Promise<Block> {
|
||||
return {
|
||||
id: "dummy-id-2",
|
||||
name: "Calculator",
|
||||
description: "Calculator",
|
||||
type: "Standard",
|
||||
};
|
||||
}
|
||||
|
||||
async waitForSaveDialogClose(): Promise<void> {
|
||||
console.log(`waiting for save dialog to close`);
|
||||
|
||||
await this.page.waitForSelector(
|
||||
'[data-id="save-control-popover-content"]',
|
||||
{ state: "hidden" },
|
||||
);
|
||||
}
|
||||
|
||||
async getGithubTriggerBlockDetails(): Promise<Block[]> {
|
||||
return [
|
||||
{
|
||||
id: "6c60ec01-8128-419e-988f-96a063ee2fea",
|
||||
name: "Github Trigger",
|
||||
description:
|
||||
"This block triggers on pull request events and outputs the event type and payload.",
|
||||
type: "Standard",
|
||||
},
|
||||
{
|
||||
id: "551e0a35-100b-49b7-89b8-3031322239b6",
|
||||
name: "Github Star Trigger",
|
||||
description:
|
||||
"This block triggers on star events and outputs the event type and payload.",
|
||||
type: "Standard",
|
||||
},
|
||||
{
|
||||
id: "2052dd1b-74e1-46ac-9c87-c7a0e057b60b",
|
||||
name: "Github Release Trigger",
|
||||
description:
|
||||
"This block triggers on release events and outputs the event type and payload.",
|
||||
type: "Standard",
|
||||
},
|
||||
{
|
||||
id: "b2605464-e486-4bf4-aad3-d8a213c8a48a",
|
||||
name: "Github Issue Trigger",
|
||||
description:
|
||||
"This block triggers on issue events and outputs the event type and payload.",
|
||||
type: "Standard",
|
||||
},
|
||||
{
|
||||
id: "87f847b3-d81a-424e-8e89-acadb5c9d52b",
|
||||
name: "Github Discussion Trigger",
|
||||
description:
|
||||
"This block triggers on discussion events and outputs the event type and payload.",
|
||||
type: "Standard",
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
async nextTutorialStep(): Promise<void> {
|
||||
console.log(`clicking next tutorial step`);
|
||||
await this.page.getByRole("button", { name: "Next" }).click();
|
||||
}
|
||||
|
||||
async getBlocksToSkip(): Promise<string[]> {
|
||||
return [
|
||||
(await this.getGithubTriggerBlockDetails()).map((b) => b.id),
|
||||
// MCP Tool block requires an interactive dialog (server URL + OAuth) before
|
||||
// it can be placed, so it can't be tested via the standard "add block" flow.
|
||||
"a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
|
||||
].flat();
|
||||
}
|
||||
|
||||
async createDummyAgent() {
|
||||
async createDummyAgent(): Promise<void> {
|
||||
await this.closeTutorial();
|
||||
await this.openBlocksPanel();
|
||||
|
||||
const searchInput = this.page.locator(
|
||||
'[data-id="blocks-control-search-bar"] input[type="text"]',
|
||||
);
|
||||
|
||||
await searchInput.clear();
|
||||
await searchInput.fill("Add to Dictionary");
|
||||
|
||||
const blockCard = this.page.locator('[data-id^="block-card-"]').first();
|
||||
try {
|
||||
await blockCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
await blockCard.click();
|
||||
} catch (error) {
|
||||
console.log("Could not find Add to Dictionary block:", error);
|
||||
}
|
||||
|
||||
await this.addBlockByClick("Add to Dictionary");
|
||||
await this.waitForNodeOnCanvas(1);
|
||||
await this.saveAgent("Test Agent", "Test Description");
|
||||
await this.waitForSaveComplete();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user