mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-06 04:45:10 -05:00
refactor: Focus binary detection on embedded base64 in stdout_logs
The previous implementation detected data URIs and pure base64 strings, but these cases already worked in CoPilot. The actual problem was code execution printing base64 to stdout with markers.

Changes:
- Remove data URI detection (no user-facing impact)
- Remove pure base64 detection (covered by embedded detection)
- Add embedded base64 pattern matching within text strings
- Handle marker patterns (---BASE64_START---/---BASE64_END---)
- Replace base64 + markers with workspace:// reference

This now solves the actual problem: ExecuteCodeBlock stdout_logs containing embedded base64 that would otherwise cost 17k+ output tokens for the LLM to re-type.
This commit is contained in:
@@ -1,15 +1,14 @@
|
||||
"""
|
||||
Content-based detection and saving of binary data in block outputs.
|
||||
Detect and save embedded binary data in block outputs.
|
||||
|
||||
This module post-processes block execution outputs to detect and save binary
|
||||
content (images, PDFs) to the workspace, returning workspace:// references
|
||||
instead of raw base64 data. This reduces LLM output token usage by ~97% for
|
||||
Scans stdout_logs and other string outputs for embedded base64 patterns,
|
||||
saves detected binary content to workspace, and replaces the base64 with
|
||||
workspace:// references. This reduces LLM output token usage by ~97% for
|
||||
file generation tasks.
|
||||
|
||||
Detection is content-based (not field-name based) because:
|
||||
- Code execution blocks return base64 in stdout_logs, not structured fields
|
||||
- The png/jpeg/pdf fields only populate from Jupyter display mechanisms
|
||||
- Other blocks use various field names: image, result, output, response, etc.
|
||||
Primary use case: ExecuteCodeBlock prints base64 to stdout, which appears
|
||||
in stdout_logs. Without this processor, the LLM would re-type the entire
|
||||
base64 string when saving files.
|
||||
"""
|
||||
|
||||
import base64
|
||||
@@ -25,40 +24,21 @@ from backend.util.workspace import WorkspaceManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Only process strings larger than this (filters out tokens, hashes, short strings)
|
||||
SIZE_THRESHOLD = 1024 # 1KB
|
||||
# Minimum decoded size to process (filters out small base64 strings)
|
||||
MIN_DECODED_SIZE = 1024 # 1KB
|
||||
|
||||
# Data URI pattern with mimetype extraction
|
||||
DATA_URI_PATTERN = re.compile(
|
||||
r"^data:([a-zA-Z0-9.+-]+/[a-zA-Z0-9.+-]+);base64,(.+)$",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
# Only process these mimetypes from data URIs (avoid text/plain, etc.)
|
||||
ALLOWED_MIMETYPES = {
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg", # Non-standard but sometimes used
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
"image/svg+xml",
|
||||
"application/pdf",
|
||||
"application/octet-stream",
|
||||
}
|
||||
|
||||
# Base64 character validation (strict - must be pure base64)
|
||||
# Allows whitespace which will be stripped before decoding (RFC 2045 line wrapping)
|
||||
BASE64_PATTERN = re.compile(r"^[A-Za-z0-9+/\s]+=*$")
|
||||
# Pattern to find base64 chunks in text (at least 100 chars to be worth checking)
|
||||
# Matches continuous base64 characters, optionally ending with = padding
|
||||
EMBEDDED_BASE64_PATTERN = re.compile(r"[A-Za-z0-9+/]{100,}={0,2}")
|
||||
|
||||
# Magic numbers for binary file detection
|
||||
# Note: WebP requires two-step detection: RIFF prefix + WEBP at offset 8
|
||||
MAGIC_SIGNATURES = [
|
||||
(b"\x89PNG\r\n\x1a\n", "png"),
|
||||
(b"\xff\xd8\xff", "jpg"),
|
||||
(b"%PDF-", "pdf"),
|
||||
(b"GIF87a", "gif"),
|
||||
(b"GIF89a", "gif"),
|
||||
(b"RIFF", "webp"), # Special case: also check content[8:12] == b'WEBP'
|
||||
(b"RIFF", "webp"), # Also check content[8:12] == b'WEBP'
|
||||
]
|
||||
|
||||
|
||||
@@ -68,12 +48,8 @@ async def process_binary_outputs(
|
||||
block_name: str,
|
||||
) -> dict[str, list[Any]]:
|
||||
"""
|
||||
Scan all string values in outputs and replace detected binary content
|
||||
with workspace:// references.
|
||||
|
||||
Uses content-based detection (data URIs, magic numbers) to find binary
|
||||
data regardless of field name. Deduplicates identical content within
|
||||
a single call using content hashing.
|
||||
Scan all string values in outputs for embedded base64 binary content.
|
||||
Save detected binaries to workspace and replace with references.
|
||||
|
||||
Args:
|
||||
outputs: Block execution outputs (dict of output_name -> list of values)
|
||||
@@ -81,7 +57,7 @@ async def process_binary_outputs(
|
||||
block_name: Name of the block (used in generated filenames)
|
||||
|
||||
Returns:
|
||||
Processed outputs with binary data replaced by workspace references
|
||||
Processed outputs with embedded base64 replaced by workspace references
|
||||
"""
|
||||
cache: dict[str, str] = {} # content_hash -> workspace_ref
|
||||
|
||||
@@ -102,7 +78,7 @@ async def _process_value(
|
||||
block: str,
|
||||
cache: dict[str, str],
|
||||
) -> Any:
|
||||
"""Recursively process a value, detecting binary content in strings."""
|
||||
"""Recursively process a value, detecting embedded base64 in strings."""
|
||||
if isinstance(value, dict):
|
||||
result = {}
|
||||
for k, v in value.items():
|
||||
@@ -110,78 +86,68 @@ async def _process_value(
|
||||
return result
|
||||
if isinstance(value, list):
|
||||
return [await _process_value(v, wm, block, cache) for v in value]
|
||||
if isinstance(value, str) and len(value) > SIZE_THRESHOLD:
|
||||
return await _try_detect_and_save(value, wm, block, cache)
|
||||
if isinstance(value, str) and len(value) > MIN_DECODED_SIZE:
|
||||
return await _extract_and_replace_base64(value, wm, block, cache)
|
||||
return value
|
||||
|
||||
|
||||
async def _try_detect_and_save(
|
||||
value: str,
|
||||
async def _extract_and_replace_base64(
|
||||
text: str,
|
||||
wm: WorkspaceManager,
|
||||
block: str,
|
||||
cache: dict[str, str],
|
||||
) -> str:
|
||||
"""Attempt to detect binary content and save it. Returns original if not binary."""
|
||||
|
||||
# Try data URI first (highest confidence - explicit mimetype)
|
||||
result = _detect_data_uri(value)
|
||||
if result:
|
||||
content, ext = result
|
||||
return await _save_binary(content, ext, wm, block, cache, value)
|
||||
|
||||
# Try raw base64 with magic number detection
|
||||
result = _detect_raw_base64(value)
|
||||
if result:
|
||||
content, ext = result
|
||||
return await _save_binary(content, ext, wm, block, cache, value)
|
||||
|
||||
return value # Not binary, return unchanged
|
||||
|
||||
|
||||
def _detect_data_uri(value: str) -> Optional[tuple[bytes, str]]:
|
||||
"""
|
||||
Detect data URI with whitelisted mimetype.
|
||||
Find embedded base64 in text, save binaries, replace with references.
|
||||
|
||||
Returns (content, extension) or None.
|
||||
Scans for base64 patterns, validates each as binary via magic numbers,
|
||||
saves valid binaries to workspace, and replaces the base64 portion
|
||||
(plus any surrounding markers) with the workspace reference.
|
||||
"""
|
||||
match = DATA_URI_PATTERN.match(value)
|
||||
if not match:
|
||||
return None
|
||||
result = text
|
||||
offset = 0
|
||||
|
||||
mimetype, b64_payload = match.groups()
|
||||
if mimetype not in ALLOWED_MIMETYPES:
|
||||
return None
|
||||
for match in EMBEDDED_BASE64_PATTERN.finditer(text):
|
||||
b64_str = match.group(0)
|
||||
|
||||
# Try to decode and validate
|
||||
detection = _decode_and_validate(b64_str)
|
||||
if detection is None:
|
||||
continue
|
||||
|
||||
content, ext = detection
|
||||
|
||||
# Save to workspace
|
||||
ref = await _save_binary(content, ext, wm, block, cache)
|
||||
if ref is None:
|
||||
continue
|
||||
|
||||
# Calculate replacement bounds (include surrounding markers if present)
|
||||
start, end = match.start(), match.end()
|
||||
start, end = _expand_to_markers(text, start, end)
|
||||
|
||||
# Apply replacement with offset adjustment
|
||||
adj_start = start + offset
|
||||
adj_end = end + offset
|
||||
result = result[:adj_start] + ref + result[adj_end:]
|
||||
offset += len(ref) - (end - start)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _decode_and_validate(b64_str: str) -> Optional[tuple[bytes, str]]:
|
||||
"""
|
||||
Decode base64 and validate it's a known binary format.
|
||||
|
||||
Returns (content, extension) if valid binary, None otherwise.
|
||||
"""
|
||||
try:
|
||||
content = base64.b64decode(b64_payload, validate=True)
|
||||
content = base64.b64decode(b64_str, validate=True)
|
||||
except (ValueError, binascii.Error):
|
||||
return None
|
||||
|
||||
ext = _mimetype_to_ext(mimetype)
|
||||
return content, ext
|
||||
|
||||
|
||||
def _detect_raw_base64(value: str) -> Optional[tuple[bytes, str]]:
|
||||
"""
|
||||
Detect raw base64 with magic number validation.
|
||||
|
||||
Only processes strings that:
|
||||
1. Look like pure base64 (regex pre-filter)
|
||||
2. Successfully decode as base64
|
||||
3. Start with a known binary file magic number
|
||||
|
||||
Returns (content, extension) or None.
|
||||
"""
|
||||
# Pre-filter: must look like base64 (allows whitespace for RFC 2045 line wrapping)
|
||||
if not BASE64_PATTERN.match(value):
|
||||
return None
|
||||
|
||||
# Strip whitespace before decoding (RFC 2045 allows line breaks in base64)
|
||||
normalized = re.sub(r"\s+", "", value)
|
||||
|
||||
try:
|
||||
content = base64.b64decode(normalized, validate=True)
|
||||
except (ValueError, binascii.Error):
|
||||
# Must meet minimum size
|
||||
if len(content) < MIN_DECODED_SIZE:
|
||||
return None
|
||||
|
||||
# Check magic numbers
|
||||
@@ -193,7 +159,47 @@ def _detect_raw_base64(value: str) -> Optional[tuple[bytes, str]]:
|
||||
continue
|
||||
return content, ext
|
||||
|
||||
return None # No magic number match = not a recognized binary format
|
||||
return None
|
||||
|
||||
|
||||
def _expand_to_markers(text: str, start: int, end: int) -> tuple[int, int]:
|
||||
"""
|
||||
Expand replacement bounds to include surrounding markers if present.
|
||||
|
||||
Handles patterns like:
|
||||
- ---BASE64_START---\\n{base64}\\n---BASE64_END---
|
||||
- [BASE64]{base64}[/BASE64]
|
||||
- Or just the raw base64
|
||||
"""
|
||||
# Common marker patterns to strip
|
||||
start_markers = [
|
||||
"---BASE64_START---\n",
|
||||
"---BASE64_START---",
|
||||
"[BASE64]\n",
|
||||
"[BASE64]",
|
||||
]
|
||||
end_markers = [
|
||||
"\n---BASE64_END---",
|
||||
"---BASE64_END---",
|
||||
"\n[/BASE64]",
|
||||
"[/BASE64]",
|
||||
]
|
||||
|
||||
# Check for start markers
|
||||
for marker in start_markers:
|
||||
marker_start = start - len(marker)
|
||||
if marker_start >= 0 and text[marker_start:start] == marker:
|
||||
start = marker_start
|
||||
break
|
||||
|
||||
# Check for end markers
|
||||
for marker in end_markers:
|
||||
marker_end = end + len(marker)
|
||||
if marker_end <= len(text) and text[end:marker_end] == marker:
|
||||
end = marker_end
|
||||
break
|
||||
|
||||
return start, end
|
||||
|
||||
|
||||
async def _save_binary(
|
||||
@@ -202,12 +208,11 @@ async def _save_binary(
|
||||
wm: WorkspaceManager,
|
||||
block: str,
|
||||
cache: dict[str, str],
|
||||
original: str,
|
||||
) -> str:
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Save binary content to workspace with deduplication.
|
||||
|
||||
Returns workspace://file-id reference, or original value on failure.
|
||||
Returns workspace://file-id reference, or None on failure.
|
||||
"""
|
||||
content_hash = hashlib.sha256(content).hexdigest()
|
||||
|
||||
@@ -224,19 +229,4 @@ async def _save_binary(
|
||||
return ref
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save binary output: {e}")
|
||||
return original # Graceful degradation
|
||||
|
||||
|
||||
def _mimetype_to_ext(mimetype: str) -> str:
|
||||
"""Convert mimetype to file extension."""
|
||||
mapping = {
|
||||
"image/png": "png",
|
||||
"image/jpeg": "jpg",
|
||||
"image/jpg": "jpg",
|
||||
"image/gif": "gif",
|
||||
"image/webp": "webp",
|
||||
"image/svg+xml": "svg",
|
||||
"application/pdf": "pdf",
|
||||
"application/octet-stream": "bin",
|
||||
}
|
||||
return mapping.get(mimetype, "bin")
|
||||
return None
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Tests for content-based binary output detection and saving."""
|
||||
"""Tests for embedded binary detection in block outputs."""
|
||||
|
||||
import base64
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
@@ -6,9 +6,8 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
|
||||
from .binary_output_processor import (
|
||||
_detect_data_uri,
|
||||
_detect_raw_base64,
|
||||
_mimetype_to_ext,
|
||||
_decode_and_validate,
|
||||
_expand_to_markers,
|
||||
process_binary_outputs,
|
||||
)
|
||||
|
||||
@@ -27,190 +26,125 @@ def mock_workspace_manager():
|
||||
return wm
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Data URI Detection Tests
|
||||
# =============================================================================
|
||||
def _make_pdf_base64(size: int = 2000) -> str:
|
||||
"""Create a valid PDF base64 string of specified size."""
|
||||
pdf_content = b"%PDF-1.4 " + b"x" * size
|
||||
return base64.b64encode(pdf_content).decode()
|
||||
|
||||
|
||||
class TestDetectDataUri:
|
||||
"""Tests for _detect_data_uri function."""
|
||||
|
||||
def test_detects_png_data_uri(self):
|
||||
"""Should detect valid PNG data URI."""
|
||||
# Minimal valid PNG (1x1 transparent)
|
||||
png_b64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
data_uri = f"data:image/png;base64,{png_b64}"
|
||||
|
||||
result = _detect_data_uri(data_uri)
|
||||
|
||||
assert result is not None
|
||||
content, ext = result
|
||||
assert ext == "png"
|
||||
assert content.startswith(b"\x89PNG")
|
||||
|
||||
def test_detects_pdf_data_uri(self):
|
||||
"""Should detect valid PDF data URI."""
|
||||
pdf_content = b"%PDF-1.4 test content"
|
||||
pdf_b64 = base64.b64encode(pdf_content).decode()
|
||||
data_uri = f"data:application/pdf;base64,{pdf_b64}"
|
||||
|
||||
result = _detect_data_uri(data_uri)
|
||||
|
||||
assert result is not None
|
||||
content, ext = result
|
||||
assert ext == "pdf"
|
||||
assert content == pdf_content
|
||||
|
||||
def test_rejects_text_plain_mimetype(self):
|
||||
"""Should reject text/plain mimetype (not in whitelist)."""
|
||||
text_b64 = base64.b64encode(b"Hello World").decode()
|
||||
data_uri = f"data:text/plain;base64,{text_b64}"
|
||||
|
||||
result = _detect_data_uri(data_uri)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_rejects_non_data_uri_string(self):
|
||||
"""Should return None for non-data-URI strings."""
|
||||
result = _detect_data_uri("https://example.com/image.png")
|
||||
assert result is None
|
||||
|
||||
def test_rejects_invalid_base64_in_data_uri(self):
|
||||
"""Should return None for data URI with invalid base64."""
|
||||
data_uri = "-valid-base64!!!"
|
||||
result = _detect_data_uri(data_uri)
|
||||
assert result is None
|
||||
|
||||
def test_handles_jpeg_mimetype(self):
|
||||
"""Should handle image/jpeg mimetype."""
|
||||
jpeg_content = b"\xff\xd8\xff\xe0test"
|
||||
jpeg_b64 = base64.b64encode(jpeg_content).decode()
|
||||
data_uri = f"data:image/jpeg;base64,{jpeg_b64}"
|
||||
|
||||
result = _detect_data_uri(data_uri)
|
||||
|
||||
assert result is not None
|
||||
_, ext = result
|
||||
assert ext == "jpg"
|
||||
def _make_png_base64(size: int = 2000) -> str:
|
||||
"""Create a valid PNG base64 string of specified size."""
|
||||
png_content = b"\x89PNG\r\n\x1a\n" + b"\x00" * size
|
||||
return base64.b64encode(png_content).decode()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Raw Base64 Detection Tests
|
||||
# Decode and Validate Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestDetectRawBase64:
|
||||
"""Tests for _detect_raw_base64 function."""
|
||||
|
||||
def test_detects_png_magic_number(self):
|
||||
"""Should detect raw base64 PNG by magic number."""
|
||||
# Minimal valid PNG
|
||||
png_b64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
|
||||
result = _detect_raw_base64(png_b64)
|
||||
|
||||
assert result is not None
|
||||
content, ext = result
|
||||
assert ext == "png"
|
||||
assert content.startswith(b"\x89PNG")
|
||||
|
||||
def test_detects_jpeg_magic_number(self):
|
||||
"""Should detect raw base64 JPEG by magic number."""
|
||||
jpeg_content = b"\xff\xd8\xff\xe0" + b"\x00" * 100
|
||||
jpeg_b64 = base64.b64encode(jpeg_content).decode()
|
||||
|
||||
result = _detect_raw_base64(jpeg_b64)
|
||||
|
||||
assert result is not None
|
||||
_, ext = result
|
||||
assert ext == "jpg"
|
||||
class TestDecodeAndValidate:
|
||||
"""Tests for _decode_and_validate function."""
|
||||
|
||||
def test_detects_pdf_magic_number(self):
|
||||
"""Should detect raw base64 PDF by magic number."""
|
||||
pdf_content = b"%PDF-1.4 " + b"x" * 100
|
||||
pdf_b64 = base64.b64encode(pdf_content).decode()
|
||||
|
||||
result = _detect_raw_base64(pdf_b64)
|
||||
|
||||
"""Should detect valid PDF by magic number."""
|
||||
pdf_b64 = _make_pdf_base64()
|
||||
result = _decode_and_validate(pdf_b64)
|
||||
assert result is not None
|
||||
_, ext = result
|
||||
content, ext = result
|
||||
assert ext == "pdf"
|
||||
assert content.startswith(b"%PDF-")
|
||||
|
||||
def test_detects_gif87a_magic_number(self):
|
||||
"""Should detect GIF87a magic number."""
|
||||
gif_content = b"GIF87a" + b"\x00" * 100
|
||||
gif_b64 = base64.b64encode(gif_content).decode()
|
||||
|
||||
result = _detect_raw_base64(gif_b64)
|
||||
def test_detects_png_magic_number(self):
|
||||
"""Should detect valid PNG by magic number."""
|
||||
png_b64 = _make_png_base64()
|
||||
result = _decode_and_validate(png_b64)
|
||||
assert result is not None
|
||||
content, ext = result
|
||||
assert ext == "png"
|
||||
|
||||
def test_detects_jpeg_magic_number(self):
|
||||
"""Should detect valid JPEG by magic number."""
|
||||
jpeg_content = b"\xff\xd8\xff\xe0" + b"\x00" * 2000
|
||||
jpeg_b64 = base64.b64encode(jpeg_content).decode()
|
||||
result = _decode_and_validate(jpeg_b64)
|
||||
assert result is not None
|
||||
_, ext = result
|
||||
assert ext == "gif"
|
||||
assert ext == "jpg"
|
||||
|
||||
def test_detects_gif89a_magic_number(self):
|
||||
"""Should detect GIF89a magic number."""
|
||||
gif_content = b"GIF89a" + b"\x00" * 100
|
||||
def test_detects_gif_magic_number(self):
|
||||
"""Should detect valid GIF by magic number."""
|
||||
gif_content = b"GIF89a" + b"\x00" * 2000
|
||||
gif_b64 = base64.b64encode(gif_content).decode()
|
||||
|
||||
result = _detect_raw_base64(gif_b64)
|
||||
|
||||
result = _decode_and_validate(gif_b64)
|
||||
assert result is not None
|
||||
_, ext = result
|
||||
assert ext == "gif"
|
||||
|
||||
def test_detects_webp_magic_number(self):
|
||||
"""Should detect WebP (RIFF + WEBP at offset 8)."""
|
||||
# WebP header: RIFF + size (4 bytes) + WEBP
|
||||
webp_content = b"RIFF\x00\x00\x00\x00WEBP" + b"\x00" * 100
|
||||
"""Should detect valid WebP by magic number."""
|
||||
webp_content = b"RIFF\x00\x00\x00\x00WEBP" + b"\x00" * 2000
|
||||
webp_b64 = base64.b64encode(webp_content).decode()
|
||||
|
||||
result = _detect_raw_base64(webp_b64)
|
||||
|
||||
result = _decode_and_validate(webp_b64)
|
||||
assert result is not None
|
||||
_, ext = result
|
||||
assert ext == "webp"
|
||||
|
||||
def test_rejects_riff_without_webp(self):
|
||||
"""Should reject RIFF files that aren't WebP (e.g., WAV)."""
|
||||
wav_content = b"RIFF\x00\x00\x00\x00WAVE" + b"\x00" * 100
|
||||
wav_b64 = base64.b64encode(wav_content).decode()
|
||||
|
||||
result = _detect_raw_base64(wav_b64)
|
||||
|
||||
def test_rejects_small_content(self):
|
||||
"""Should reject content smaller than threshold."""
|
||||
small_pdf = b"%PDF-1.4 small"
|
||||
small_b64 = base64.b64encode(small_pdf).decode()
|
||||
result = _decode_and_validate(small_b64)
|
||||
assert result is None
|
||||
|
||||
def test_rejects_non_base64_string(self):
|
||||
"""Should reject strings that don't look like base64."""
|
||||
result = _detect_raw_base64("Hello, this is regular text with spaces!")
|
||||
assert result is None
|
||||
|
||||
def test_rejects_base64_without_magic_number(self):
|
||||
"""Should reject valid base64 that doesn't have a known magic number."""
|
||||
random_content = b"This is just random text, not a binary file"
|
||||
def test_rejects_no_magic_number(self):
|
||||
"""Should reject content without recognized magic number."""
|
||||
random_content = b"This is just random text" * 100
|
||||
random_b64 = base64.b64encode(random_content).decode()
|
||||
|
||||
result = _detect_raw_base64(random_b64)
|
||||
|
||||
result = _decode_and_validate(random_b64)
|
||||
assert result is None
|
||||
|
||||
def test_rejects_invalid_base64(self):
|
||||
"""Should return None for invalid base64."""
|
||||
result = _detect_raw_base64("not-valid-base64!!!")
|
||||
"""Should reject invalid base64."""
|
||||
result = _decode_and_validate("not-valid-base64!!!")
|
||||
assert result is None
|
||||
|
||||
def test_detects_base64_with_line_breaks(self):
|
||||
"""Should detect raw base64 with RFC 2045 line breaks."""
|
||||
png_content = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
|
||||
png_b64 = base64.b64encode(png_content).decode()
|
||||
# Simulate RFC 2045 line wrapping at 76 chars
|
||||
wrapped = png_b64[:76] + "\n" + png_b64[76:]
|
||||
def test_rejects_riff_without_webp(self):
|
||||
"""Should reject RIFF files that aren't WebP (e.g., WAV)."""
|
||||
wav_content = b"RIFF\x00\x00\x00\x00WAVE" + b"\x00" * 2000
|
||||
wav_b64 = base64.b64encode(wav_content).decode()
|
||||
result = _decode_and_validate(wav_b64)
|
||||
assert result is None
|
||||
|
||||
result = _detect_raw_base64(wrapped)
|
||||
|
||||
assert result is not None
|
||||
content, ext = result
|
||||
assert ext == "png"
|
||||
assert content == png_content
|
||||
# =============================================================================
|
||||
# Marker Expansion Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestExpandToMarkers:
    """Tests for _expand_to_markers function."""

    def test_expands_base64_start_end_markers(self):
        """Should expand to include ---BASE64_START--- and ---BASE64_END---."""
        payload = "ABCDEF"
        text = "prefix\n---BASE64_START---\nABCDEF\n---BASE64_END---\nsuffix"
        # Locate the payload instead of hard-coding offsets: it sits at 26-32,
        # and an off-by-one here silently breaks the start-marker match.
        b64_start = text.index(payload)
        b64_end = b64_start + len(payload)
        start, end = _expand_to_markers(text, b64_start, b64_end)
        assert text[start:end] == "---BASE64_START---\nABCDEF\n---BASE64_END---"

    def test_expands_bracket_markers(self):
        """Should expand to include [BASE64] and [/BASE64] markers."""
        payload = "ABCDEF"
        text = "prefix[BASE64]ABCDEF[/BASE64]suffix"
        b64_start = text.index(payload)
        b64_end = b64_start + len(payload)
        start, end = _expand_to_markers(text, b64_start, b64_end)
        assert text[start:end] == "[BASE64]ABCDEF[/BASE64]"

    def test_no_expansion_without_markers(self):
        """Should not expand if no markers present."""
        payload = "ABCDEF"
        text = "prefix ABCDEF suffix"
        b64_start = text.index(payload)
        b64_end = b64_start + len(payload)
        start, end = _expand_to_markers(text, b64_start, b64_end)
        assert start == b64_start
        assert end == b64_end
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -222,183 +156,131 @@ class TestProcessBinaryOutputs:
|
||||
"""Tests for process_binary_outputs function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saves_large_png_and_returns_reference(self, mock_workspace_manager):
|
||||
"""Should save PNG > 1KB and return workspace reference."""
|
||||
# Create PNG > 1KB
|
||||
png_header = b"\x89PNG\r\n\x1a\n"
|
||||
png_content = png_header + b"\x00" * 2000
|
||||
png_b64 = base64.b64encode(png_content).decode()
|
||||
async def test_detects_embedded_pdf_in_stdout_logs(self, mock_workspace_manager):
|
||||
"""Should detect and replace embedded PDF in stdout_logs."""
|
||||
pdf_b64 = _make_pdf_base64()
|
||||
stdout = f"PDF generated!\n---BASE64_START---\n{pdf_b64}\n---BASE64_END---\n"
|
||||
|
||||
outputs = {"result": [png_b64]}
|
||||
outputs = {"stdout_logs": [stdout]}
|
||||
|
||||
result = await process_binary_outputs(
|
||||
outputs, mock_workspace_manager, "TestBlock"
|
||||
outputs, mock_workspace_manager, "ExecuteCodeBlock"
|
||||
)
|
||||
|
||||
assert result["result"][0].startswith("workspace://")
|
||||
# Should contain workspace reference, not base64
|
||||
assert "workspace://" in result["stdout_logs"][0]
|
||||
assert pdf_b64 not in result["stdout_logs"][0]
|
||||
assert "PDF generated!" in result["stdout_logs"][0]
|
||||
mock_workspace_manager.write_file.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_small_content(self, mock_workspace_manager):
|
||||
"""Should not process strings smaller than threshold."""
|
||||
small_content = "small"
|
||||
async def test_detects_embedded_png_without_markers(self, mock_workspace_manager):
|
||||
"""Should detect embedded PNG even without markers."""
|
||||
png_b64 = _make_png_base64()
|
||||
stdout = f"Image created: {png_b64} done"
|
||||
|
||||
outputs = {"result": [small_content]}
|
||||
outputs = {"stdout_logs": [stdout]}
|
||||
|
||||
result = await process_binary_outputs(
|
||||
outputs, mock_workspace_manager, "ExecuteCodeBlock"
|
||||
)
|
||||
|
||||
assert "workspace://" in result["stdout_logs"][0]
|
||||
assert "Image created:" in result["stdout_logs"][0]
|
||||
assert "done" in result["stdout_logs"][0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_small_strings(self, mock_workspace_manager):
|
||||
"""Should not process small strings."""
|
||||
outputs = {"stdout_logs": ["small output"]}
|
||||
|
||||
result = await process_binary_outputs(
|
||||
outputs, mock_workspace_manager, "TestBlock"
|
||||
)
|
||||
|
||||
assert result["result"][0] == small_content
|
||||
assert result["stdout_logs"][0] == "small output"
|
||||
mock_workspace_manager.write_file.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_non_binary_large_strings(self, mock_workspace_manager):
|
||||
"""Should preserve large strings that don't contain valid binary."""
|
||||
large_text = "A" * 5000 # Large but not base64
|
||||
|
||||
outputs = {"stdout_logs": [large_text]}
|
||||
|
||||
result = await process_binary_outputs(
|
||||
outputs, mock_workspace_manager, "TestBlock"
|
||||
)
|
||||
|
||||
assert result["stdout_logs"][0] == large_text
|
||||
mock_workspace_manager.write_file.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deduplicates_identical_content(self, mock_workspace_manager):
|
||||
"""Should save identical content only once."""
|
||||
png_header = b"\x89PNG\r\n\x1a\n"
|
||||
png_content = png_header + b"\x00" * 2000
|
||||
png_b64 = base64.b64encode(png_content).decode()
|
||||
pdf_b64 = _make_pdf_base64()
|
||||
stdout1 = f"First: {pdf_b64}"
|
||||
stdout2 = f"Second: {pdf_b64}"
|
||||
|
||||
outputs = {"result": [png_b64, png_b64]}
|
||||
outputs = {"stdout_logs": [stdout1, stdout2]}
|
||||
|
||||
result = await process_binary_outputs(
|
||||
outputs, mock_workspace_manager, "TestBlock"
|
||||
)
|
||||
|
||||
# Both should have references
|
||||
assert result["result"][0].startswith("workspace://")
|
||||
assert result["result"][1].startswith("workspace://")
|
||||
# But only one write should have happened
|
||||
assert "workspace://" in result["stdout_logs"][0]
|
||||
assert "workspace://" in result["stdout_logs"][1]
|
||||
# But only one write
|
||||
assert mock_workspace_manager.write_file.call_count == 1
|
||||
# And they should be the same reference
|
||||
assert result["result"][0] == result["result"][1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processes_nested_dict(self, mock_workspace_manager):
|
||||
"""Should recursively process nested dictionaries."""
|
||||
png_header = b"\x89PNG\r\n\x1a\n"
|
||||
png_content = png_header + b"\x00" * 2000
|
||||
png_b64 = base64.b64encode(png_content).decode()
|
||||
async def test_handles_multiple_binaries_in_one_string(
|
||||
self, mock_workspace_manager
|
||||
):
|
||||
"""Should handle multiple embedded binaries in a single string."""
|
||||
pdf_b64 = _make_pdf_base64()
|
||||
png_b64 = _make_png_base64()
|
||||
stdout = f"PDF: {pdf_b64}\nPNG: {png_b64}"
|
||||
|
||||
outputs = {"result": [{"nested": {"deep": png_b64}}]}
|
||||
outputs = {"stdout_logs": [stdout]}
|
||||
|
||||
result = await process_binary_outputs(
|
||||
outputs, mock_workspace_manager, "TestBlock"
|
||||
)
|
||||
|
||||
assert result["result"][0]["nested"]["deep"].startswith("workspace://")
|
||||
# Should have two workspace references
|
||||
assert result["stdout_logs"][0].count("workspace://") == 2
|
||||
assert mock_workspace_manager.write_file.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processes_nested_list(self, mock_workspace_manager):
|
||||
"""Should recursively process nested lists."""
|
||||
png_header = b"\x89PNG\r\n\x1a\n"
|
||||
png_content = png_header + b"\x00" * 2000
|
||||
png_b64 = base64.b64encode(png_content).decode()
|
||||
async def test_processes_nested_structures(self, mock_workspace_manager):
|
||||
"""Should recursively process nested dicts and lists."""
|
||||
pdf_b64 = _make_pdf_base64()
|
||||
|
||||
outputs = {"result": [[png_b64]]}
|
||||
outputs = {"result": [{"nested": {"deep": f"data: {pdf_b64}"}}]}
|
||||
|
||||
result = await process_binary_outputs(
|
||||
outputs, mock_workspace_manager, "TestBlock"
|
||||
)
|
||||
|
||||
assert result["result"][0][0].startswith("workspace://")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_data_uri_format(self, mock_workspace_manager):
|
||||
"""Should handle data URI format."""
|
||||
png_header = b"\x89PNG\r\n\x1a\n"
|
||||
png_content = png_header + b"\x00" * 2000
|
||||
png_b64 = base64.b64encode(png_content).decode()
|
||||
data_uri = f"data:image/png;base64,{png_b64}"
|
||||
|
||||
outputs = {"result": [data_uri]}
|
||||
|
||||
result = await process_binary_outputs(
|
||||
outputs, mock_workspace_manager, "TestBlock"
|
||||
)
|
||||
|
||||
assert result["result"][0].startswith("workspace://")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_non_binary_large_strings(self, mock_workspace_manager):
|
||||
"""Should preserve large strings that aren't binary."""
|
||||
large_text = "A" * 2000 # Large but not base64 or binary
|
||||
|
||||
outputs = {"result": [large_text]}
|
||||
|
||||
result = await process_binary_outputs(
|
||||
outputs, mock_workspace_manager, "TestBlock"
|
||||
)
|
||||
|
||||
assert result["result"][0] == large_text
|
||||
mock_workspace_manager.write_file.assert_not_called()
|
||||
assert "workspace://" in result["result"][0]["nested"]["deep"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graceful_degradation_on_save_failure(self, mock_workspace_manager):
|
||||
"""Should preserve original value if save fails."""
|
||||
"""Should preserve original on save failure."""
|
||||
mock_workspace_manager.write_file = AsyncMock(
|
||||
side_effect=Exception("Storage error")
|
||||
)
|
||||
|
||||
png_header = b"\x89PNG\r\n\x1a\n"
|
||||
png_content = png_header + b"\x00" * 2000
|
||||
png_b64 = base64.b64encode(png_content).decode()
|
||||
pdf_b64 = _make_pdf_base64()
|
||||
stdout = f"PDF: {pdf_b64}"
|
||||
|
||||
outputs = {"result": [png_b64]}
|
||||
outputs = {"stdout_logs": [stdout]}
|
||||
|
||||
result = await process_binary_outputs(
|
||||
outputs, mock_workspace_manager, "TestBlock"
|
||||
)
|
||||
|
||||
# Should return original value on failure
|
||||
assert result["result"][0] == png_b64
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_stdout_logs_field(self, mock_workspace_manager):
|
||||
"""Should detect binary in stdout_logs (the actual failing case)."""
|
||||
pdf_content = b"%PDF-1.4 " + b"x" * 2000
|
||||
pdf_b64 = base64.b64encode(pdf_content).decode()
|
||||
|
||||
outputs = {"stdout_logs": [pdf_b64]}
|
||||
|
||||
result = await process_binary_outputs(
|
||||
outputs, mock_workspace_manager, "ExecuteCodeBlock"
|
||||
)
|
||||
|
||||
assert result["stdout_logs"][0].startswith("workspace://")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Mimetype to Extension Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestMimetypeToExt:
|
||||
"""Tests for _mimetype_to_ext function."""
|
||||
|
||||
def test_png_mapping(self):
|
||||
assert _mimetype_to_ext("image/png") == "png"
|
||||
|
||||
def test_jpeg_mapping(self):
|
||||
assert _mimetype_to_ext("image/jpeg") == "jpg"
|
||||
|
||||
def test_nonstandard_jpg_mapping(self):
|
||||
assert _mimetype_to_ext("image/jpg") == "jpg"
|
||||
|
||||
def test_gif_mapping(self):
|
||||
assert _mimetype_to_ext("image/gif") == "gif"
|
||||
|
||||
def test_webp_mapping(self):
|
||||
assert _mimetype_to_ext("image/webp") == "webp"
|
||||
|
||||
def test_svg_mapping(self):
|
||||
assert _mimetype_to_ext("image/svg+xml") == "svg"
|
||||
|
||||
def test_pdf_mapping(self):
|
||||
assert _mimetype_to_ext("application/pdf") == "pdf"
|
||||
|
||||
def test_octet_stream_mapping(self):
|
||||
assert _mimetype_to_ext("application/octet-stream") == "bin"
|
||||
|
||||
def test_unknown_mimetype(self):
|
||||
assert _mimetype_to_ext("application/unknown") == "bin"
|
||||
# Should keep original since save failed
|
||||
assert pdf_b64 in result["stdout_logs"][0]
|
||||
|
||||
Reference in New Issue
Block a user