Compare commits

...

4 Commits

Author SHA1 Message Date
Otto
e5aad862ce docs: Add missing docstrings for pre-merge check
Adds docstrings to _process_item and _process_dict to meet
the 80% docstring coverage requirement.
2026-02-04 23:15:20 +00:00
Otto
4769a281cc fix: Use strict base64 validation to prevent corrupted saves
Addresses CodeRabbit review feedback:
- Add padding normalization before decoding
- Use validate=True to reject invalid characters instead of silently discarding

This prevents corrupted data from being saved to workspace.
2026-02-04 22:58:59 +00:00
Otto
96ca9daefe feat(copilot): Auto-save binary block outputs to workspace
When blocks produce binary outputs (PNG, JPEG, PDF, SVG), the data is now
automatically saved to the user's workspace and replaced with workspace://
references. This prevents:

- Massive token waste from LLM re-typing base64 strings (17,000+ tokens)
- Potential data corruption from truncation/hallucination
- Poor UX from slow character-by-character output

Implementation:
- New binary_output_processor.py module with hash-based deduplication
- Integration in run_block.py (single entry point for all block executions)
- Graceful degradation: failures preserve original data

Fixes SECRT-1887
2026-02-04 22:26:30 +00:00
Reinier van der Leer
c1aa684743 fix(platform/chat): Filter host-scoped credentials for run_agent tool (#11905)
- Fixes [SECRT-1851: \[Copilot\] `run_agent` tool doesn't filter
host-scoped credentials](https://linear.app/autogpt/issue/SECRT-1851)
- Follow-up to #11881

### Changes 🏗️

- Filter host-scoped credentials for `run_agent` tool
- Tighten validation on host input field in `HostScopedCredentialsModal`
- Use netloc (w/ port) rather than just hostname (w/o port) as host
scope

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - Create graph that requires host-scoped credentials to work
  - Create host-scoped credentials with a *different* host
  - Try to have Copilot run the graph
  - [x] -> no matching credentials available
  - Create new credentials
  - [x] -> works

---------

Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2026-02-04 16:27:14 +00:00
8 changed files with 330 additions and 37 deletions

View File

@@ -0,0 +1,123 @@
"""Save binary block outputs to workspace, return references instead of base64."""
import base64
import binascii
import hashlib
import logging
import uuid
from typing import Any
from backend.util.workspace import WorkspaceManager
logger = logging.getLogger(__name__)
BINARY_FIELDS = {"png", "jpeg", "pdf"} # Base64 encoded
TEXT_FIELDS = {"svg"} # Large text, save raw
SAVEABLE_FIELDS = BINARY_FIELDS | TEXT_FIELDS
SIZE_THRESHOLD = 1024 # Only process content > 1KB (string length, not decoded size)
async def process_binary_outputs(
outputs: dict[str, list[Any]],
workspace_manager: WorkspaceManager,
block_name: str,
) -> dict[str, list[Any]]:
"""
Replace binary data in block outputs with workspace:// references.
Deduplicates identical content within a single call (e.g., same PDF
appearing in both main_result and results).
"""
cache: dict[str, str] = {} # content_hash -> workspace_ref
processed: dict[str, list[Any]] = {}
for name, items in outputs.items():
processed_items: list[Any] = []
for item in items:
processed_items.append(
await _process_item(item, workspace_manager, block_name, cache)
)
processed[name] = processed_items
return processed
async def _process_item(
item: Any, wm: WorkspaceManager, block: str, cache: dict
) -> Any:
"""Recursively process an item, handling dicts and lists."""
if isinstance(item, dict):
return await _process_dict(item, wm, block, cache)
if isinstance(item, list):
processed: list[Any] = []
for i in item:
processed.append(await _process_item(i, wm, block, cache))
return processed
return item
async def _process_dict(
data: dict, wm: WorkspaceManager, block: str, cache: dict
) -> dict:
"""Process a dict, saving binary fields and recursing into nested structures."""
result: dict[str, Any] = {}
for key, value in data.items():
if (
key in SAVEABLE_FIELDS
and isinstance(value, str)
and len(value) > SIZE_THRESHOLD
):
content_hash = hashlib.sha256(value.encode()).hexdigest()
if content_hash in cache:
result[key] = cache[content_hash]
elif ref := await _save(value, key, wm, block):
cache[content_hash] = ref
result[key] = ref
else:
result[key] = value # Save failed, keep original
elif isinstance(value, dict):
result[key] = await _process_dict(value, wm, block, cache)
elif isinstance(value, list):
processed: list[Any] = []
for i in value:
processed.append(await _process_item(i, wm, block, cache))
result[key] = processed
else:
result[key] = value
return result
async def _save(value: str, field: str, wm: WorkspaceManager, block: str) -> str | None:
"""Save content to workspace, return workspace:// reference or None on failure."""
try:
if field in BINARY_FIELDS:
content = _decode_base64(value)
if content is None:
return None
else:
content = value.encode("utf-8")
ext = {"jpeg": "jpg"}.get(field, field)
filename = f"{block.lower().replace(' ', '_')[:20]}_{field}_{uuid.uuid4().hex[:12]}.{ext}"
file = await wm.write_file(content=content, filename=filename)
return f"workspace://{file.id}"
except Exception as e:
logger.error(f"Failed to save {field} to workspace for block '{block}': {e}")
return None
def _decode_base64(value: str) -> bytes | None:
"""Decode base64, handling data URI format. Returns None on failure."""
try:
if value.startswith("data:"):
value = value.split(",", 1)[1] if "," in value else value
# Normalize padding and use strict validation to prevent corrupted data
padded = value + "=" * (-len(value) % 4)
return base64.b64decode(padded, validate=True)
except (binascii.Error, ValueError):
return None

View File

@@ -8,12 +8,16 @@ from typing import Any
from pydantic_core import PydanticUndefined
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.binary_output_processor import (
process_binary_outputs,
)
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from backend.util.workspace import WorkspaceManager
from .base import BaseTool
from .models import (
@@ -321,11 +325,19 @@ class RunBlockTool(BaseTool):
):
outputs[output_name].append(output_data)
# Save binary outputs to workspace to prevent context bloat
workspace_manager = WorkspaceManager(
user_id, workspace.id, session.session_id
)
processed_outputs = await process_binary_outputs(
dict(outputs), workspace_manager, block.name
)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,
block_name=block.name,
outputs=dict(outputs),
outputs=processed_outputs,
success=True,
session_id=session_id,
)

View File

@@ -0,0 +1,92 @@
import base64
from unittest.mock import AsyncMock, MagicMock
import pytest
from backend.api.features.chat.tools.binary_output_processor import (
_decode_base64,
process_binary_outputs,
)
@pytest.fixture
def workspace_manager():
wm = AsyncMock()
wm.write_file = AsyncMock(return_value=MagicMock(id="file-123"))
return wm
class TestDecodeBase64:
def test_raw_base64(self):
assert _decode_base64(base64.b64encode(b"test").decode()) == b"test"
def test_data_uri(self):
encoded = base64.b64encode(b"test").decode()
assert _decode_base64(f"data:image/png;base64,{encoded}") == b"test"
def test_invalid_returns_none(self):
assert _decode_base64("not base64!!!") is None
class TestProcessBinaryOutputs:
@pytest.mark.asyncio
async def test_saves_large_binary(self, workspace_manager):
content = base64.b64encode(b"x" * 2000).decode()
outputs = {"result": [{"png": content, "text": "ok"}]}
result = await process_binary_outputs(outputs, workspace_manager, "Test")
assert result["result"][0]["png"] == "workspace://file-123"
assert result["result"][0]["text"] == "ok"
@pytest.mark.asyncio
async def test_skips_small_content(self, workspace_manager):
content = base64.b64encode(b"tiny").decode()
outputs = {"result": [{"png": content}]}
result = await process_binary_outputs(outputs, workspace_manager, "Test")
assert result["result"][0]["png"] == content
workspace_manager.write_file.assert_not_called()
@pytest.mark.asyncio
async def test_deduplicates_identical_content(self, workspace_manager):
content = base64.b64encode(b"x" * 2000).decode()
outputs = {"a": [{"pdf": content}], "b": [{"pdf": content}]}
result = await process_binary_outputs(outputs, workspace_manager, "Test")
assert result["a"][0]["pdf"] == result["b"][0]["pdf"] == "workspace://file-123"
assert workspace_manager.write_file.call_count == 1
@pytest.mark.asyncio
async def test_failure_preserves_original(self, workspace_manager):
workspace_manager.write_file.side_effect = Exception("Storage error")
content = base64.b64encode(b"x" * 2000).decode()
result = await process_binary_outputs(
{"r": [{"png": content}]}, workspace_manager, "Test"
)
assert result["r"][0]["png"] == content
@pytest.mark.asyncio
async def test_handles_nested_structures(self, workspace_manager):
content = base64.b64encode(b"x" * 2000).decode()
outputs = {"result": [{"outer": {"inner": {"png": content}}}]}
result = await process_binary_outputs(outputs, workspace_manager, "Test")
assert result["result"][0]["outer"]["inner"]["png"] == "workspace://file-123"
@pytest.mark.asyncio
async def test_handles_lists_in_output(self, workspace_manager):
content = base64.b64encode(b"x" * 2000).decode()
outputs = {"result": [{"images": [{"png": content}, {"png": content}]}]}
result = await process_binary_outputs(outputs, workspace_manager, "Test")
assert result["result"][0]["images"][0]["png"] == "workspace://file-123"
assert result["result"][0]["images"][1]["png"] == "workspace://file-123"
# Deduplication should still work
assert workspace_manager.write_file.call_count == 1

View File

@@ -8,7 +8,12 @@ from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db
from backend.data import graph as graph_db
from backend.data.graph import GraphModel
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
from backend.data.model import (
CredentialsFieldInfo,
CredentialsMetaInput,
HostScopedCredentials,
OAuth2Credentials,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import NotFoundError
@@ -273,7 +278,14 @@ async def match_user_credentials_to_graph(
for cred in available_creds
if cred.provider in credential_requirements.provider
and cred.type in credential_requirements.supported_types
and _credential_has_required_scopes(cred, credential_requirements)
and (
cred.type != "oauth2"
or _credential_has_required_scopes(cred, credential_requirements)
)
and (
cred.type != "host_scoped"
or _credential_is_for_host(cred, credential_requirements)
)
),
None,
)
@@ -318,19 +330,10 @@ async def match_user_credentials_to_graph(
def _credential_has_required_scopes(
credential: Credentials,
credential: OAuth2Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""
Check if a credential has all the scopes required by the block.
For OAuth2 credentials, verifies that the credential's scopes are a superset
of the required scopes. For other credential types, returns True (no scope check).
"""
# Only OAuth2 credentials have scopes to check
if credential.type != "oauth2":
return True
"""Check if an OAuth2 credential has all the scopes required by the input."""
# If no scopes are required, any credential matches
if not requirements.required_scopes:
return True
@@ -339,6 +342,22 @@ def _credential_has_required_scopes(
return set(credential.scopes).issuperset(requirements.required_scopes)
def _credential_is_for_host(
credential: HostScopedCredentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if a host-scoped credential matches the host required by the input."""
# We need to know the host to match host-scoped credentials to.
# Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
# to discriminator_values. No discriminator_values -> no host to match against.
if not requirements.discriminator_values:
return True
# Check that credential host matches required host.
# Host-scoped credential inputs are grouped by host, so any item from the set works.
return credential.matches_url(list(requirements.discriminator_values)[0])
async def check_user_has_required_credentials(
user_id: str,
required_credentials: list[CredentialsMetaInput],

View File

@@ -19,7 +19,6 @@ from typing import (
cast,
get_args,
)
from urllib.parse import urlparse
from uuid import uuid4
from prisma.enums import CreditTransactionType, OnboardingStep
@@ -42,6 +41,7 @@ from typing_extensions import TypedDict
from backend.integrations.providers import ProviderName
from backend.util.json import loads as json_loads
from backend.util.request import parse_url
from backend.util.settings import Secrets
# Type alias for any provider name (including custom ones)
@@ -397,19 +397,25 @@ class HostScopedCredentials(_BaseCredentials):
def matches_url(self, url: str) -> bool:
"""Check if this credential should be applied to the given URL."""
parsed_url = urlparse(url)
# Extract hostname without port
request_host = parsed_url.hostname
request_host, request_port = _extract_host_from_url(url)
cred_scope_host, cred_scope_port = _extract_host_from_url(self.host)
if not request_host:
return False
# Simple host matching - exact match or wildcard subdomain match
if self.host == request_host:
# If a port is specified in credential host, the request host port must match
if cred_scope_port is not None and request_port != cred_scope_port:
return False
# Non-standard ports are only allowed if explicitly specified in credential host
elif cred_scope_port is None and request_port not in (80, 443, None):
return False
# Simple host matching
if cred_scope_host == request_host:
return True
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
if self.host.startswith("*."):
domain = self.host[2:] # Remove "*."
if cred_scope_host.startswith("*."):
domain = cred_scope_host[2:] # Remove "*."
return request_host.endswith(f".{domain}") or request_host == domain
return False
@@ -551,13 +557,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
)
def _extract_host_from_url(url: str) -> str:
"""Extract host from URL for grouping host-scoped credentials."""
def _extract_host_from_url(url: str) -> tuple[str, int | None]:
"""Extract host and port from URL for grouping host-scoped credentials."""
try:
parsed = urlparse(url)
return parsed.hostname or url
parsed = parse_url(url)
return parsed.hostname or url, parsed.port
except Exception:
return ""
return "", None
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
@@ -606,7 +612,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
providers = frozenset(
[cast(CP, "http")]
+ [
cast(CP, _extract_host_from_url(str(value)))
cast(CP, parse_url(str(value)).netloc)
for value in field.discriminator_values
]
)

View File

@@ -79,10 +79,23 @@ class TestHostScopedCredentials:
headers={"Authorization": SecretStr("Bearer token")},
)
assert creds.matches_url("http://localhost:8080/api/v1")
# Non-standard ports require explicit port in credential host
assert not creds.matches_url("http://localhost:8080/api/v1")
assert creds.matches_url("https://localhost:443/secure/endpoint")
assert creds.matches_url("http://localhost/simple")
def test_matches_url_with_explicit_port(self):
"""Test URL matching with explicit port in credential host."""
creds = HostScopedCredentials(
provider="custom",
host="localhost:8080",
headers={"Authorization": SecretStr("Bearer token")},
)
assert creds.matches_url("http://localhost:8080/api/v1")
assert not creds.matches_url("http://localhost:3000/api/v1")
assert not creds.matches_url("http://localhost/simple")
def test_empty_headers_dict(self):
"""Test HostScopedCredentials with empty headers."""
creds = HostScopedCredentials(
@@ -128,8 +141,20 @@ class TestHostScopedCredentials:
("*.example.com", "https://sub.api.example.com/test", True),
("*.example.com", "https://example.com/test", True),
("*.example.com", "https://example.org/test", False),
("localhost", "http://localhost:3000/test", True),
# Non-standard ports require explicit port in credential host
("localhost", "http://localhost:3000/test", False),
("localhost:3000", "http://localhost:3000/test", True),
("localhost", "http://127.0.0.1:3000/test", False),
# IPv6 addresses (frontend stores with brackets via URL.hostname)
("[::1]", "http://[::1]/test", True),
("[::1]", "http://[::1]:80/test", True),
("[::1]", "https://[::1]:443/test", True),
("[::1]", "http://[::1]:8080/test", False), # Non-standard port
("[::1]:8080", "http://[::1]:8080/test", True),
("[::1]:8080", "http://[::1]:9090/test", False),
("[2001:db8::1]", "http://[2001:db8::1]/path", True),
("[2001:db8::1]", "https://[2001:db8::1]:443/path", True),
("[2001:db8::1]", "http://[2001:db8::ff]/path", False),
],
)
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):

View File

@@ -157,12 +157,7 @@ async def validate_url(
is_trusted: Boolean indicating if the hostname is in trusted_origins
ip_addresses: List of IP addresses for the host; empty if the host is trusted
"""
# Canonicalize URL
url = url.strip("/ ").replace("\\", "/")
parsed = urlparse(url)
if not parsed.scheme:
url = f"http://{url}"
parsed = urlparse(url)
parsed = parse_url(url)
# Check scheme
if parsed.scheme not in ALLOWED_SCHEMES:
@@ -220,6 +215,17 @@ async def validate_url(
)
def parse_url(url: str) -> URL:
"""Canonicalizes and parses a URL string."""
url = url.strip("/ ").replace("\\", "/")
# Ensure scheme is present for proper parsing
if not re.match(r"[a-z0-9+.\-]+://", url):
url = f"http://{url}"
return urlparse(url)
def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL:
"""
Pins a URL to a specific IP address to prevent DNS rebinding attacks.

View File

@@ -41,7 +41,17 @@ export function HostScopedCredentialsModal({
const currentHost = currentUrl ? getHostFromUrl(currentUrl) : "";
const formSchema = z.object({
host: z.string().min(1, "Host is required"),
host: z
.string()
.min(1, "Host is required")
.refine((val) => !/^[a-zA-Z][a-zA-Z\d+\-.]*:\/\//.test(val), {
message: "Enter only the host (e.g. api.example.com), not a full URL",
})
.refine((val) => !val.includes("/"), {
message:
"Enter only the host (e.g. api.example.com), without a trailing path. " +
"You may specify a port (e.g. api.example.com:8080) if needed.",
}),
title: z.string().optional(),
headers: z.record(z.string()).optional(),
});