mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend/copilot): harden file-ref parsing — size guards, edge cases, sniffer fallback
- Add pre-read size check in read_file_bytes for local files
- Add OpenpyxlInvalidFile to parse_file_content exception list
- Add csv.Sniffer fallback for misidentified delimiters
- Add application/yaml MIME mapping
- Fix _is_tabular to reject empty header rows
- Fix _adapt_to_schema so a dict targeting List[str] is no longer wrapped incorrectly
- Fix _apply_line_range to note when the requested range exceeds the file
- Remove inconsistent ChatSession string quoting
- Add clarifying comments for budget/size check ordering
- Add tests for all new behaviors
This commit is contained in:
@@ -121,14 +121,22 @@ def parse_file_ref(text: str) -> FileRef | None:
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive).
|
||||
|
||||
When the requested range extends beyond the file, a note is appended
|
||||
so the LLM knows it received the entire remaining content.
|
||||
"""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
total = len(lines)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else len(lines)
|
||||
e = end if end is not None else total
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
return "".join(selected)
|
||||
result = "".join(selected)
|
||||
if end is not None and end > total:
|
||||
result += f"\n[Note: file has only {total} lines]\n"
|
||||
return result
|
||||
|
||||
|
||||
async def read_file_bytes(
|
||||
@@ -162,6 +170,11 @@ async def read_file_bytes(
|
||||
if is_allowed_local_path(plain, get_sdk_cwd()):
|
||||
resolved = os.path.realpath(os.path.expanduser(plain))
|
||||
try:
|
||||
size = os.path.getsize(resolved)
|
||||
if size > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large ({size} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
with open(resolved, "rb") as fh:
|
||||
return fh.read()
|
||||
except FileNotFoundError:
|
||||
@@ -222,7 +235,7 @@ async def resolve_file_ref(
|
||||
async def expand_file_refs_in_string(
|
||||
text: str,
|
||||
user_id: str | None,
|
||||
session: "ChatSession",
|
||||
session: ChatSession,
|
||||
*,
|
||||
raise_on_error: bool = False,
|
||||
) -> str:
|
||||
@@ -268,6 +281,9 @@ async def expand_file_refs_in_string(
|
||||
if len(content) > _MAX_EXPAND_CHARS:
|
||||
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
|
||||
# remaining == 0 means the budget was exactly exhausted by the
|
||||
# previous ref. The elif below (len > remaining) won't catch
|
||||
# this since 0 > 0 is false, so we need the <= 0 check.
|
||||
if remaining <= 0:
|
||||
content = "[file-ref budget exhausted: total expansion limit reached]"
|
||||
elif len(content) > remaining:
|
||||
@@ -322,6 +338,7 @@ def _is_tabular(parsed: Any) -> bool:
|
||||
isinstance(parsed, list)
|
||||
and len(parsed) >= 2
|
||||
and all(isinstance(row, list) for row in parsed)
|
||||
and len(parsed[0]) >= 1
|
||||
and all(isinstance(h, str) for h in parsed[0])
|
||||
)
|
||||
|
||||
@@ -377,6 +394,11 @@ def _adapt_to_schema(parsed: Any, prop_schema: dict[str, Any] | None) -> Any:
|
||||
list_values = [v for v in parsed.values() if isinstance(v, list)]
|
||||
if list_values:
|
||||
return list_values
|
||||
if items_type == "string":
|
||||
# Target is List[str] — wrapping a dict would give [dict]
|
||||
# which can't coerce to strings. Return unchanged and let
|
||||
# pydantic surface a clear validation error.
|
||||
return parsed
|
||||
# Fallback: wrap in a single-element list so the block gets [dict]
|
||||
# instead of pydantic flattening keys/values into a flat list.
|
||||
return [parsed]
|
||||
@@ -435,6 +457,9 @@ async def _expand_bare_ref(
|
||||
except ValueError as exc:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
|
||||
# For known formats this rejects files >10 MB before parsing.
|
||||
# For unknown formats _MAX_EXPAND_CHARS (200K chars) below is stricter,
|
||||
# but this check still guards the parsing path which has no char limit.
|
||||
_check_content_size(content)
|
||||
|
||||
# When the schema declares this parameter as "string",
|
||||
@@ -491,7 +516,7 @@ async def _expand_bare_ref(
|
||||
async def expand_file_refs_in_args(
|
||||
args: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session: "ChatSession",
|
||||
session: ChatSession,
|
||||
*,
|
||||
input_schema: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
|
||||
@@ -111,7 +111,8 @@ def test_apply_line_range_single_line():
|
||||
|
||||
def test_apply_line_range_beyond_eof():
    """A range ending past EOF returns the remaining lines plus a note."""
    result = _apply_line_range(TEXT, 4, 999)
    # Containment rather than equality: the note follows the selected lines.
    assert "line4\nline5\n" in result
    assert "[Note: file has only 5 lines]" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1797,3 +1798,98 @@ async def test_non_media_string_field_still_reads_content():
|
||||
)
|
||||
|
||||
assert result["text"] == "file content here"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _apply_line_range — range exceeds file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_apply_line_range_beyond_eof_note():
    """An end line past EOF yields the remaining lines followed by a note."""
    sliced = _apply_line_range(TEXT, 4, 999)
    for expected in ("line4", "line5", "[Note: file has only 5 lines]"):
        assert expected in sliced
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_tabular — edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_is_tabular_empty_header():
    """Rows whose header is an empty list must not count as tabular."""
    from backend.copilot.sdk.file_ref import _is_tabular

    empty_rows = [[], []]
    assert _is_tabular(empty_rows) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _adapt_to_schema — non-tabular list + object target
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_adapt_non_tabular_list_to_object_target():
    """A non-tabular list aimed at an object-typed field is left as-is."""
    input_schema = {
        "type": "object",
        "properties": {"data": {"type": "object"}},
    }
    # The resolver is stubbed to hand back a plain JSON array.
    resolver = AsyncMock(return_value="[1, 2, 3]")

    with patch("backend.copilot.sdk.file_ref.resolve_file_ref", new=resolver):
        expanded = await expand_file_refs_in_args(
            {"data": "@@agptfile:workspace:///data.json"},
            user_id="u1",
            session=_make_session(),
            input_schema=input_schema,
        )

    # No adaptation expected: the parsed list comes back untouched.
    assert expanded["data"] == [1, 2, 3]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _adapt_to_schema — dict → List[str] target should NOT wrap
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_adapt_dict_to_list_str_target_not_wrapped():
    """A dict aimed at a List[str] field is returned bare, never as [dict]."""
    list_of_str = {"type": "array", "items": {"type": "string"}}
    input_schema = {"type": "object", "properties": {"data": list_of_str}}
    # The resolver is stubbed to hand back a one-key YAML mapping.
    resolver = AsyncMock(return_value="key: value")

    with patch("backend.copilot.sdk.file_ref.resolve_file_ref", new=resolver):
        expanded = await expand_file_refs_in_args(
            {"data": "@@agptfile:workspace:///config.yaml"},
            user_id="u1",
            session=_make_session(),
            input_schema=input_schema,
        )

    # The mapping must come through unchanged rather than wrapped in a list.
    assert expanded["data"] == {"key": "value"}
|
||||
|
||||
@@ -21,9 +21,11 @@ Supported formats:
|
||||
- **Excel** (``.xlsx``) — via pandas/openpyxl → ``list[list[Any]]`` with header row
|
||||
(legacy ``.xls`` is **not** supported — only the modern OOXML format)
|
||||
|
||||
All parsers follow the **fallback contract**: if parsing fails for *any* reason,
|
||||
the original content is returned unchanged (string for text formats, bytes for
|
||||
binary formats). Callers should never see an exception from this module.
|
||||
The **fallback contract** is enforced by :func:`parse_file_content`, not by
|
||||
individual parser functions. If any parser raises, ``parse_file_content``
|
||||
catches the exception and returns the original content unchanged (string for
|
||||
text formats, bytes for binary formats). Callers should never see an
|
||||
exception from the public API when ``strict=False``.
|
||||
"""
|
||||
|
||||
import csv
|
||||
@@ -40,6 +42,7 @@ from posixpath import splitext
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from openpyxl.utils.exceptions import InvalidFileException as OpenpyxlInvalidFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -67,6 +70,7 @@ MIME_TO_FORMAT: dict[str, str] = {
|
||||
"text/csv": "csv",
|
||||
"text/tab-separated-values": "tsv",
|
||||
"application/x-yaml": "yaml",
|
||||
"application/yaml": "yaml",
|
||||
"text/yaml": "yaml",
|
||||
"application/toml": "toml",
|
||||
"application/vnd.apache.parquet": "parquet",
|
||||
@@ -146,7 +150,19 @@ def _parse_tsv(content: str) -> Any:
|
||||
def _parse_delimited(content: str, *, delimiter: str) -> Any:
|
||||
reader = csv.reader(io.StringIO(content), delimiter=delimiter)
|
||||
rows = [row for row in reader if row]
|
||||
# Require ≥1 row and ≥2 columns to qualify as tabular data.
|
||||
if not rows:
|
||||
return content
|
||||
# If the declared delimiter produces only single-column rows, try
|
||||
# sniffing the actual delimiter — catches misidentified files (e.g.
|
||||
# a tab-delimited file with a .csv extension).
|
||||
if len(rows[0]) == 1:
|
||||
try:
|
||||
dialect = csv.Sniffer().sniff(content[:8192])
|
||||
if dialect.delimiter != delimiter:
|
||||
reader = csv.reader(io.StringIO(content), dialect)
|
||||
rows = [row for row in reader if row]
|
||||
except csv.Error:
|
||||
pass
|
||||
if rows and len(rows[0]) >= 2:
|
||||
return rows
|
||||
return content
|
||||
@@ -266,6 +282,7 @@ def parse_file_content(content: str | bytes, fmt: str, *, strict: bool = False)
|
||||
KeyError,
|
||||
TypeError,
|
||||
zipfile.BadZipFile,
|
||||
OpenpyxlInvalidFile,
|
||||
):
|
||||
if strict:
|
||||
raise
|
||||
|
||||
@@ -466,3 +466,44 @@ class TestBinaryFormats:
|
||||
def test_text_formats_not_binary(self):
    """No text format may appear in BINARY_FORMATS."""
    text_formats = ("json", "jsonl", "csv", "tsv", "yaml", "toml")
    for name in text_formats:
        assert name not in BINARY_FORMATS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MIME mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMimeMapping:
    """Format inference from a MIME type carried in the URI fragment."""

    def test_application_yaml(self):
        uri = "workspace://abc123#application/yaml"
        assert infer_format(uri) == "yaml"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CSV sniffer fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCsvSnifferFallback:
    """Behaviour of the csv parser when the declared delimiter does not fit."""

    def test_tab_delimited_with_csv_format(self):
        """Tab-separated text declared as csv is still split into columns."""
        tab_separated = "Name\tScore\nAlice\t90\nBob\t85"
        expected_rows = [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
        assert parse_file_content(tab_separated, "csv") == expected_rows

    def test_sniffer_failure_returns_content(self):
        """Single-column text with no detectable delimiter comes back raw."""
        single_column = "Name\nAlice\nBob"
        assert parse_file_content(single_column, "csv") == single_column
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenpyxlInvalidFile fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOpenpyxlFallback:
    """Non-strict parsing of bytes that are not a real xlsx workbook."""

    def test_invalid_xlsx_non_strict(self):
        """Garbage xlsx bytes are handed back unchanged when not strict."""
        garbage = b"not xlsx bytes"
        assert parse_file_content(garbage, "xlsx") == garbage
|
||||
|
||||
Reference in New Issue
Block a user