Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-02-04 20:05:11 -05:00

Comparing commits: release-v0...otto/secrt (1 commit)

Commit d6b76e672c

New file: backend/api/features/chat/tools/binary_output_processor.py
(path inferred from the import in the RunBlockTool hunk below)
@@ -0,0 +1,204 @@
"""
Save binary block outputs to workspace, return references instead of base64.

This module post-processes block execution outputs to detect and save binary
content (from code execution results) to the workspace, returning workspace://
references instead of raw base64 data. This reduces LLM output token usage
by ~97% for file generation tasks.

Detection is field-name based, targeting the standard e2b CodeExecutionResult
fields: png, jpeg, pdf, svg. Other image-producing blocks already use
store_media_file() and don't need this post-processing.
"""
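# Back-of-envelope for the ~97% figure above (illustrative, not part of this
# diff): base64 inflates binary content by 4/3, so a 1 MB image costs roughly
# 1.4 million characters of model context, while a workspace reference costs ~20.
#
#   size = 1_048_576                         # 1 MB of binary output
#   b64_chars = -(-size // 3) * 4            # ceil division -> 1,398,104 chars
#   ref_chars = len("workspace://file-123")  # 20 chars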
import base64
import binascii
import hashlib
import logging
import uuid
from typing import Any

from backend.util.file import sanitize_filename
from backend.util.workspace import WorkspaceManager

logger = logging.getLogger(__name__)

# Field names that contain binary data (base64 encoded)
BINARY_FIELDS = {"png", "jpeg", "pdf"}

# Field names that contain large text data (not base64, save as-is)
TEXT_FIELDS = {"svg"}

# Combined set for quick lookup
SAVEABLE_FIELDS = BINARY_FIELDS | TEXT_FIELDS

# Only process content larger than this (string length, not decoded size)
SIZE_THRESHOLD = 1024  # 1KB
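# Taken together, the constants above gate what gets offloaded. A value is
# saved only when all three checks pass (hypothetical helper, not in this
# diff; the same predicate appears inline in _process_dict below):
#
#   def _is_saveable(key: str, value: object) -> bool:
#       return (
#           key in SAVEABLE_FIELDS
#           and isinstance(value, str)
#           and len(value) > SIZE_THRESHOLD
#       )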
async def process_binary_outputs(
    outputs: dict[str, list[Any]],
    workspace_manager: WorkspaceManager,
    block_name: str,
) -> dict[str, list[Any]]:
    """
    Replace binary data in block outputs with workspace:// references.

    Scans outputs for known binary field names (png, jpeg, pdf, svg) and saves
    large content to the workspace. Returns processed outputs with base64 data
    replaced by workspace:// references.

    Deduplicates identical content within a single call using content hashing.

    Args:
        outputs: Block execution outputs (dict of output_name -> list of values)
        workspace_manager: WorkspaceManager instance with session scoping
        block_name: Name of the block (used in generated filenames)

    Returns:
        Processed outputs with binary data replaced by workspace references
    """
    cache: dict[str, str] = {}  # content_hash -> workspace_ref

    processed: dict[str, list[Any]] = {}
    for name, items in outputs.items():
        processed_items: list[Any] = []
        for item in items:
            processed_items.append(
                await _process_item(item, workspace_manager, block_name, cache)
            )
        processed[name] = processed_items
    return processed
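# For a sense of the transformation (values illustrative; the file id comes
# from the workspace write):
#
#   outputs = {"result": [{"png": "<~2 MB of base64>", "stdout": "done"}]}
#   processed = await process_binary_outputs(outputs, wm, "CodeExecutionBlock")
#   # processed == {"result": [{"png": "workspace://file-123", "stdout": "done"}]}
#
# Non-saveable fields like "stdout" pass through untouched.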
async def _process_item(
    item: Any,
    wm: WorkspaceManager,
    block: str,
    cache: dict[str, str],
) -> Any:
    """Recursively process an item, handling dicts and lists."""
    if isinstance(item, dict):
        return await _process_dict(item, wm, block, cache)
    if isinstance(item, list):
        processed: list[Any] = []
        for i in item:
            processed.append(await _process_item(i, wm, block, cache))
        return processed
    return item
async def _process_dict(
    data: dict[str, Any],
    wm: WorkspaceManager,
    block: str,
    cache: dict[str, str],
) -> dict[str, Any]:
    """Process a dict, saving binary fields and recursing into nested structures."""
    result: dict[str, Any] = {}

    for key, value in data.items():
        if (
            key in SAVEABLE_FIELDS
            and isinstance(value, str)
            and len(value) > SIZE_THRESHOLD
        ):
            # Determine content bytes based on field type
            if key in BINARY_FIELDS:
                content = _decode_base64(value)
                if content is None:
                    # Decode failed, keep original value
                    result[key] = value
                    continue
            else:
                # TEXT_FIELDS: encode as UTF-8
                content = value.encode("utf-8")

            # Hash decoded content for deduplication
            content_hash = hashlib.sha256(content).hexdigest()

            if content_hash in cache:
                # Reuse existing workspace reference
                result[key] = cache[content_hash]
            elif ref := await _save_content(content, key, wm, block):
                # Save succeeded, cache and use reference
                cache[content_hash] = ref
                result[key] = ref
            else:
                # Save failed, keep original value
                result[key] = value

        elif isinstance(value, dict):
            result[key] = await _process_dict(value, wm, block, cache)
        elif isinstance(value, list):
            processed: list[Any] = []
            for i in value:
                processed.append(await _process_item(i, wm, block, cache))
            result[key] = processed
        else:
            result[key] = value

    return result
async def _save_content(
    content: bytes,
    field: str,
    wm: WorkspaceManager,
    block: str,
) -> str | None:
    """
    Save content to workspace, return workspace:// reference.

    Args:
        content: Decoded binary content to save
        field: Field name (used for extension)
        wm: WorkspaceManager instance
        block: Block name (used in filename)

    Returns:
        workspace://file-id reference, or None if save failed
    """
    try:
        # Map field name to file extension
        ext = {"jpeg": "jpg"}.get(field, field)

        # Sanitize block name for safe filename
        safe_block = sanitize_filename(block.lower())[:20]
        filename = f"{safe_block}_{field}_{uuid.uuid4().hex[:12]}.{ext}"

        file = await wm.write_file(content=content, filename=filename)
        return f"workspace://{file.id}"

    except Exception as e:
        logger.error(f"Failed to save {field} to workspace for block '{block}': {e}")
        return None
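# Example of a generated filename for block "TestBlock" and field "jpeg"
# (uuid hex digits illustrative):
#
#   testblock_jpeg_4f1a9c02be77.jpg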
def _decode_base64(value: str) -> bytes | None:
    """
    Decode base64 string, handling both raw base64 and data URI formats.

    Args:
        value: Base64 string or data URI (data:<mime>;base64,<payload>)

    Returns:
        Decoded bytes, or None if decoding failed
    """
    try:
        # Handle data URI format
        if value.startswith("data:"):
            if "," in value:
                value = value.split(",", 1)[1]
            else:
                # Malformed data URI, no comma separator
                return None

        # Normalize padding (handle missing = chars)
        padded = value + "=" * (-len(value) % 4)

        # Strict validation to prevent corrupted data
        return base64.b64decode(padded, validate=True)

    except (binascii.Error, ValueError):
        return None
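One subtlety worth seeing in action: in Python, -len(value) % 4 yields exactly the
number of "=" characters needed to pad to a multiple of four, so truncated payloads
and data URIs both decode. A minimal sketch (values taken from the tests below):

    from backend.api.features.chat.tools.binary_output_processor import _decode_base64

    assert _decode_base64("dGVzdA") == b"test"                          # len 6 -> 2 "=" added
    assert _decode_base64("data:image/png;base64,dGVzdA==") == b"test"  # URI prefix stripped
    assert _decode_base64("data:image/png;base64") is None              # no comma -> malformed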
@@ -8,12 +8,16 @@ from typing import Any
 from pydantic_core import PydanticUndefined
 
 from backend.api.features.chat.model import ChatSession
+from backend.api.features.chat.tools.binary_output_processor import (
+    process_binary_outputs,
+)
 from backend.data.block import get_block
 from backend.data.execution import ExecutionContext
 from backend.data.model import CredentialsMetaInput
 from backend.data.workspace import get_or_create_workspace
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util.exceptions import BlockError
+from backend.util.workspace import WorkspaceManager
 
 from .base import BaseTool
 from .models import (
@@ -321,11 +325,20 @@ class RunBlockTool(BaseTool):
         ):
             outputs[output_name].append(output_data)
 
+        # Save binary outputs to workspace to prevent context bloat
+        # (code execution results with png/jpeg/pdf/svg fields)
+        workspace_manager = WorkspaceManager(
+            user_id, workspace.id, session.session_id
+        )
+        processed_outputs = await process_binary_outputs(
+            dict(outputs), workspace_manager, block.name
+        )
+
         return BlockOutputResponse(
             message=f"Block '{block.name}' executed successfully",
             block_id=block_id,
             block_name=block.name,
-            outputs=dict(outputs),
+            outputs=processed_outputs,
             success=True,
             session_id=session_id,
         )
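Net effect on the tool's response: outputs now carries short references instead
of raw base64, e.g. (values illustrative):

    # Before this change:
    #   response.outputs == {"result": [{"png": "<~2 MB of base64>"}]}
    # After:
    #   response.outputs == {"result": [{"png": "workspace://file-123"}]}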
New file: unit tests for binary_output_processor (path not shown on the page)

@@ -0,0 +1,204 @@
"""Unit tests for binary_output_processor module."""

import base64
from unittest.mock import AsyncMock, MagicMock

import pytest

from backend.api.features.chat.tools.binary_output_processor import (
    _decode_base64,
    process_binary_outputs,
)


@pytest.fixture
def workspace_manager():
    """Create a mock WorkspaceManager."""
    mock = MagicMock()
    mock_file = MagicMock()
    mock_file.id = "file-123"
    mock.write_file = AsyncMock(return_value=mock_file)
    return mock


class TestDecodeBase64:
    """Tests for _decode_base64 function."""

    def test_raw_base64(self):
        """Decode raw base64 string."""
        encoded = base64.b64encode(b"test content").decode()
        result = _decode_base64(encoded)
        assert result == b"test content"

    def test_data_uri(self):
        """Decode base64 from data URI format."""
        content = b"test content"
        encoded = base64.b64encode(content).decode()
        data_uri = f"data:image/png;base64,{encoded}"
        result = _decode_base64(data_uri)
        assert result == content

    def test_invalid_base64(self):
        """Return None for invalid base64."""
        result = _decode_base64("not valid base64!!!")
        assert result is None

    def test_missing_padding(self):
        """Handle base64 with missing padding."""
        # base64.b64encode(b"test") = "dGVzdA=="
        # Remove padding
        result = _decode_base64("dGVzdA")
        assert result == b"test"

    def test_malformed_data_uri(self):
        """Return None for data URI without comma."""
        result = _decode_base64("data:image/png;base64")
        assert result is None


class TestProcessBinaryOutputs:
    """Tests for process_binary_outputs function."""

    @pytest.mark.asyncio
    async def test_saves_large_png(self, workspace_manager):
        """Large PNG content should be saved to workspace."""
        # Create content larger than SIZE_THRESHOLD (1KB)
        large_content = b"x" * 2000
        encoded = base64.b64encode(large_content).decode()
        outputs = {"result": [{"png": encoded}]}

        result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        assert result["result"][0]["png"] == "workspace://file-123"
        workspace_manager.write_file.assert_called_once()
        call_kwargs = workspace_manager.write_file.call_args.kwargs
        assert call_kwargs["content"] == large_content
        assert "testblock_png_" in call_kwargs["filename"]
        assert call_kwargs["filename"].endswith(".png")

    @pytest.mark.asyncio
    async def test_preserves_small_content(self, workspace_manager):
        """Small content should be preserved as-is."""
        small_content = base64.b64encode(b"tiny").decode()
        outputs = {"result": [{"png": small_content}]}

        result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        assert result["result"][0]["png"] == small_content
        workspace_manager.write_file.assert_not_called()

    @pytest.mark.asyncio
    async def test_deduplicates_identical_content(self, workspace_manager):
        """Identical content should only be saved once."""
        large_content = b"x" * 2000
        encoded = base64.b64encode(large_content).decode()
        outputs = {
            "main_result": [{"png": encoded}],
            "results": [{"png": encoded}],
        }

        result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        # Both should have the same workspace reference
        assert result["main_result"][0]["png"] == "workspace://file-123"
        assert result["results"][0]["png"] == "workspace://file-123"
        # But write_file should only be called once
        assert workspace_manager.write_file.call_count == 1

    @pytest.mark.asyncio
    async def test_graceful_degradation_on_save_failure(self, workspace_manager):
        """Original content should be preserved if save fails."""
        workspace_manager.write_file = AsyncMock(side_effect=Exception("Save failed"))
        large_content = b"x" * 2000
        encoded = base64.b64encode(large_content).decode()
        outputs = {"result": [{"png": encoded}]}

        result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        # Original content should be preserved
        assert result["result"][0]["png"] == encoded

    @pytest.mark.asyncio
    async def test_handles_nested_structures(self, workspace_manager):
        """Should traverse nested dicts and lists."""
        large_content = b"x" * 2000
        encoded = base64.b64encode(large_content).decode()
        outputs = {
            "result": [
                {
                    "nested": {
                        "deep": {
                            "png": encoded,
                        }
                    }
                }
            ]
        }

        result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        assert result["result"][0]["nested"]["deep"]["png"] == "workspace://file-123"

    @pytest.mark.asyncio
    async def test_handles_svg_as_text(self, workspace_manager):
        """SVG should be saved as UTF-8 text, not base64 decoded."""
        svg_content = "<svg>" + "x" * 2000 + "</svg>"
        outputs = {"result": [{"svg": svg_content}]}

        result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        assert result["result"][0]["svg"] == "workspace://file-123"
        call_kwargs = workspace_manager.write_file.call_args.kwargs
        # SVG should be UTF-8 encoded, not base64 decoded
        assert call_kwargs["content"] == svg_content.encode("utf-8")
        assert call_kwargs["filename"].endswith(".svg")

    @pytest.mark.asyncio
    async def test_ignores_unknown_fields(self, workspace_manager):
        """Fields not in SAVEABLE_FIELDS should be ignored."""
        large_content = "x" * 2000  # Large text in an unknown field
        outputs = {"result": [{"unknown_field": large_content}]}

        result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        assert result["result"][0]["unknown_field"] == large_content
        workspace_manager.write_file.assert_not_called()

    @pytest.mark.asyncio
    async def test_handles_jpeg_extension(self, workspace_manager):
        """JPEG files should use .jpg extension."""
        large_content = b"x" * 2000
        encoded = base64.b64encode(large_content).decode()
        outputs = {"result": [{"jpeg": encoded}]}

        await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        call_kwargs = workspace_manager.write_file.call_args.kwargs
        assert call_kwargs["filename"].endswith(".jpg")

    @pytest.mark.asyncio
    async def test_handles_data_uri_in_binary_field(self, workspace_manager):
        """Data URI format in binary fields should be properly decoded."""
        large_content = b"x" * 2000
        encoded = base64.b64encode(large_content).decode()
        data_uri = f"data:image/png;base64,{encoded}"
        outputs = {"result": [{"png": data_uri}]}

        result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        assert result["result"][0]["png"] == "workspace://file-123"
        call_kwargs = workspace_manager.write_file.call_args.kwargs
        assert call_kwargs["content"] == large_content

    @pytest.mark.asyncio
    async def test_invalid_base64_preserves_original(self, workspace_manager):
        """Invalid base64 in a binary field should preserve the original value."""
        invalid_content = "not valid base64!!!" + "x" * 2000
        outputs = {"result": [{"png": invalid_content}]}

        processed = await process_binary_outputs(
            outputs, workspace_manager, "TestBlock"
        )

        assert processed["result"][0]["png"] == invalid_content
        workspace_manager.write_file.assert_not_called()
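These are standard pytest tests; the async cases rely on pytest-asyncio (implied
by the @pytest.mark.asyncio markers). A minimal invocation, assuming the backend
package is importable from the working directory:

    pytest -q -k "TestDecodeBase64 or TestProcessBinaryOutputs"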