fix(backend/copilot): harden file-ref parsing — size guards, edge cases, sniffer fallback

- Add pre-read size check in read_file_bytes for local files
- Add OpenpyxlInvalidFile to parse_file_content exception list
- Add csv.Sniffer fallback for misidentified delimiters
- Add application/yaml MIME mapping
- Fix _is_tabular to reject empty header rows
- Fix _adapt_to_schema so a dict targeting List[str] is no longer wrapped in a one-element list
- Fix _apply_line_range to note when range exceeds file
- Remove inconsistent ChatSession string quoting
- Add clarifying comments for budget/size check ordering
- Add tests for all new behaviors
This commit is contained in:
Zamil Majdy
2026-03-14 21:54:52 +07:00
parent eafac037c2
commit 4934f7a766
4 changed files with 189 additions and 10 deletions

View File

@@ -121,14 +121,22 @@ def parse_file_ref(text: str) -> FileRef | None:
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
"""Slice *text* to the requested 1-indexed line range (inclusive).
When the requested range extends beyond the file, a note is appended
so the LLM knows it received the entire remaining content.
"""
if start is None and end is None:
return text
lines = text.splitlines(keepends=True)
total = len(lines)
s = (start - 1) if start is not None else 0
e = end if end is not None else len(lines)
e = end if end is not None else total
selected = list(itertools.islice(lines, s, e))
return "".join(selected)
result = "".join(selected)
if end is not None and end > total:
result += f"\n[Note: file has only {total} lines]\n"
return result
async def read_file_bytes(
@@ -162,6 +170,11 @@ async def read_file_bytes(
if is_allowed_local_path(plain, get_sdk_cwd()):
resolved = os.path.realpath(os.path.expanduser(plain))
try:
size = os.path.getsize(resolved)
if size > _MAX_BARE_REF_BYTES:
raise ValueError(
f"File too large ({size} bytes, limit {_MAX_BARE_REF_BYTES})"
)
with open(resolved, "rb") as fh:
return fh.read()
except FileNotFoundError:
@@ -222,7 +235,7 @@ async def resolve_file_ref(
async def expand_file_refs_in_string(
text: str,
user_id: str | None,
session: "ChatSession",
session: ChatSession,
*,
raise_on_error: bool = False,
) -> str:
@@ -268,6 +281,9 @@ async def expand_file_refs_in_string(
if len(content) > _MAX_EXPAND_CHARS:
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
# remaining == 0 means the budget was exactly exhausted by the
# previous ref. The elif below (len > remaining) won't catch
# this since 0 > 0 is false, so we need the <= 0 check.
if remaining <= 0:
content = "[file-ref budget exhausted: total expansion limit reached]"
elif len(content) > remaining:
@@ -322,6 +338,7 @@ def _is_tabular(parsed: Any) -> bool:
isinstance(parsed, list)
and len(parsed) >= 2
and all(isinstance(row, list) for row in parsed)
and len(parsed[0]) >= 1
and all(isinstance(h, str) for h in parsed[0])
)
@@ -377,6 +394,11 @@ def _adapt_to_schema(parsed: Any, prop_schema: dict[str, Any] | None) -> Any:
list_values = [v for v in parsed.values() if isinstance(v, list)]
if list_values:
return list_values
if items_type == "string":
# Target is List[str] — wrapping a dict would give [dict]
# which can't coerce to strings. Return unchanged and let
# pydantic surface a clear validation error.
return parsed
# Fallback: wrap in a single-element list so the block gets [dict]
# instead of pydantic flattening keys/values into a flat list.
return [parsed]
@@ -435,6 +457,9 @@ async def _expand_bare_ref(
except ValueError as exc:
raise FileRefExpansionError(str(exc)) from exc
# For known formats this rejects files >10 MB before parsing.
# For unknown formats _MAX_EXPAND_CHARS (200K chars) below is stricter,
# but this check still guards the parsing path which has no char limit.
_check_content_size(content)
# When the schema declares this parameter as "string",
@@ -491,7 +516,7 @@ async def _expand_bare_ref(
async def expand_file_refs_in_args(
args: dict[str, Any],
user_id: str | None,
session: "ChatSession",
session: ChatSession,
*,
input_schema: dict[str, Any] | None = None,
) -> dict[str, Any]:

View File

@@ -111,7 +111,8 @@ def test_apply_line_range_single_line():
def test_apply_line_range_beyond_eof():
result = _apply_line_range(TEXT, 4, 999)
assert result == "line4\nline5\n"
assert "line4\nline5\n" in result
assert "[Note: file has only 5 lines]" in result
# ---------------------------------------------------------------------------
@@ -1797,3 +1798,98 @@ async def test_non_media_string_field_still_reads_content():
)
assert result["text"] == "file content here"
# ---------------------------------------------------------------------------
# _apply_line_range — range exceeds file
# ---------------------------------------------------------------------------
def test_apply_line_range_beyond_eof_note():
    """Requesting lines past EOF yields the tail plus an explanatory note."""
    sliced = _apply_line_range(TEXT, 4, 999)
    # The surviving tail lines and the appended note must all be present.
    for fragment in ("line4", "line5", "[Note: file has only 5 lines]"):
        assert fragment in sliced
# ---------------------------------------------------------------------------
# _is_tabular — edge cases
# ---------------------------------------------------------------------------
def test_is_tabular_empty_header():
    """A list of empty rows must not be classified as tabular."""
    from backend.copilot.sdk.file_ref import _is_tabular

    verdict = _is_tabular([[], []])
    assert verdict is False
# ---------------------------------------------------------------------------
# _adapt_to_schema — non-tabular list + object target
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_adapt_non_tabular_list_to_object_target():
    """A non-tabular list resolved for an object-typed field is left as-is."""

    async def _fake_resolve(ref, *args, **kwargs):  # noqa: ARG001
        return "[1, 2, 3]"

    input_schema = {
        "type": "object",
        "properties": {"data": {"type": "object"}},
    }
    with patch(
        "backend.copilot.sdk.file_ref.resolve_file_ref",
        new=AsyncMock(side_effect=_fake_resolve),
    ):
        expanded = await expand_file_refs_in_args(
            {"data": "@@agptfile:workspace:///data.json"},
            user_id="u1",
            session=_make_session(),
            input_schema=input_schema,
        )
    # [1, 2, 3] is not tabular, so no schema adaptation should occur.
    assert expanded["data"] == [1, 2, 3]
# ---------------------------------------------------------------------------
# _adapt_to_schema — dict → List[str] target should NOT wrap
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_adapt_dict_to_list_str_target_not_wrapped():
    """A dict resolved for a List[str] field passes through unwrapped."""

    async def _fake_resolve(ref, *args, **kwargs):  # noqa: ARG001
        return "key: value"

    input_schema = {
        "type": "object",
        "properties": {
            "data": {"type": "array", "items": {"type": "string"}},
        },
    }
    with patch(
        "backend.copilot.sdk.file_ref.resolve_file_ref",
        new=AsyncMock(side_effect=_fake_resolve),
    ):
        expanded = await expand_file_refs_in_args(
            {"data": "@@agptfile:workspace:///config.yaml"},
            user_id="u1",
            session=_make_session(),
            input_schema=input_schema,
        )
    # Wrapping would produce [{"key": "value"}]; the dict must survive intact.
    assert expanded["data"] == {"key": "value"}

View File

@@ -21,9 +21,11 @@ Supported formats:
- **Excel** (``.xlsx``) — via pandas/openpyxl → ``list[list[Any]]`` with header row
(legacy ``.xls`` is **not** supported — only the modern OOXML format)
All parsers follow the **fallback contract**: if parsing fails for *any* reason,
the original content is returned unchanged (string for text formats, bytes for
binary formats). Callers should never see an exception from this module.
The **fallback contract** is enforced by :func:`parse_file_content`, not by
individual parser functions. If any parser raises, ``parse_file_content``
catches the exception and returns the original content unchanged (string for
text formats, bytes for binary formats). Callers should never see an
exception from the public API when ``strict=False``.
"""
import csv
@@ -40,6 +42,7 @@ from posixpath import splitext
from typing import Any
import yaml
from openpyxl.utils.exceptions import InvalidFileException as OpenpyxlInvalidFile
logger = logging.getLogger(__name__)
@@ -67,6 +70,7 @@ MIME_TO_FORMAT: dict[str, str] = {
"text/csv": "csv",
"text/tab-separated-values": "tsv",
"application/x-yaml": "yaml",
"application/yaml": "yaml",
"text/yaml": "yaml",
"application/toml": "toml",
"application/vnd.apache.parquet": "parquet",
@@ -146,7 +150,19 @@ def _parse_tsv(content: str) -> Any:
def _parse_delimited(content: str, *, delimiter: str) -> Any:
reader = csv.reader(io.StringIO(content), delimiter=delimiter)
rows = [row for row in reader if row]
# Require ≥1 row and ≥2 columns to qualify as tabular data.
if not rows:
return content
# If the declared delimiter produces only single-column rows, try
# sniffing the actual delimiter — catches misidentified files (e.g.
# a tab-delimited file with a .csv extension).
if len(rows[0]) == 1:
try:
dialect = csv.Sniffer().sniff(content[:8192])
if dialect.delimiter != delimiter:
reader = csv.reader(io.StringIO(content), dialect)
rows = [row for row in reader if row]
except csv.Error:
pass
if rows and len(rows[0]) >= 2:
return rows
return content
@@ -266,6 +282,7 @@ def parse_file_content(content: str | bytes, fmt: str, *, strict: bool = False)
KeyError,
TypeError,
zipfile.BadZipFile,
OpenpyxlInvalidFile,
):
if strict:
raise

View File

@@ -466,3 +466,44 @@ class TestBinaryFormats:
def test_text_formats_not_binary(self):
    """Text-based formats must never be listed in BINARY_FORMATS."""
    text_formats = ("json", "jsonl", "csv", "tsv", "yaml", "toml")
    for text_fmt in text_formats:
        assert text_fmt not in BINARY_FORMATS
# ---------------------------------------------------------------------------
# MIME mapping
# ---------------------------------------------------------------------------
class TestMimeMapping:
    """MIME-type → format inference."""

    def test_application_yaml(self):
        """The IANA ``application/yaml`` type maps to the yaml format."""
        fmt = infer_format("workspace://abc123#application/yaml")
        assert fmt == "yaml"
# ---------------------------------------------------------------------------
# CSV sniffer fallback
# ---------------------------------------------------------------------------
class TestCsvSnifferFallback:
    """Delimiter sniffing when the declared format disagrees with the data."""

    def test_tab_delimited_with_csv_format(self):
        """Tab-delimited text declared as csv is re-parsed via the sniffer."""
        tsv_text = "Name\tScore\nAlice\t90\nBob\t85"
        expected = [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
        assert parse_file_content(tsv_text, "csv") == expected

    def test_sniffer_failure_returns_content(self):
        """Single-column data with no sniffable delimiter stays raw text."""
        raw = "Name\nAlice\nBob"
        assert parse_file_content(raw, "csv") == raw
# ---------------------------------------------------------------------------
# OpenpyxlInvalidFile fallback
# ---------------------------------------------------------------------------
class TestOpenpyxlFallback:
    """Graceful fallback when openpyxl rejects the workbook bytes."""

    def test_invalid_xlsx_non_strict(self):
        """Garbage xlsx input comes back unchanged in non-strict mode."""
        garbage = b"not xlsx bytes"
        assert parse_file_content(garbage, "xlsx") == garbage