Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-04 20:05:11 -05:00)

Compare commits: master...otto/secrt (4 commits)

| SHA1 |
|---|
| e5aad862ce |
| 4769a281cc |
| 96ca9daefe |
| c1aa684743 |
New file (imported below as `backend.api.features.chat.tools.binary_output_processor`), @@ -0,0 +1,123 @@:

```python
"""Save binary block outputs to workspace, return references instead of base64."""

import base64
import binascii
import hashlib
import logging
import uuid
from typing import Any

from backend.util.workspace import WorkspaceManager

logger = logging.getLogger(__name__)

BINARY_FIELDS = {"png", "jpeg", "pdf"}  # Base64 encoded
TEXT_FIELDS = {"svg"}  # Large text, save raw
SAVEABLE_FIELDS = BINARY_FIELDS | TEXT_FIELDS
SIZE_THRESHOLD = 1024  # Only process content > 1KB (string length, not decoded size)


async def process_binary_outputs(
    outputs: dict[str, list[Any]],
    workspace_manager: WorkspaceManager,
    block_name: str,
) -> dict[str, list[Any]]:
    """
    Replace binary data in block outputs with workspace:// references.

    Deduplicates identical content within a single call (e.g., same PDF
    appearing in both main_result and results).
    """
    cache: dict[str, str] = {}  # content_hash -> workspace_ref

    processed: dict[str, list[Any]] = {}
    for name, items in outputs.items():
        processed_items: list[Any] = []
        for item in items:
            processed_items.append(
                await _process_item(item, workspace_manager, block_name, cache)
            )
        processed[name] = processed_items
    return processed


async def _process_item(
    item: Any, wm: WorkspaceManager, block: str, cache: dict
) -> Any:
    """Recursively process an item, handling dicts and lists."""
    if isinstance(item, dict):
        return await _process_dict(item, wm, block, cache)
    if isinstance(item, list):
        processed: list[Any] = []
        for i in item:
            processed.append(await _process_item(i, wm, block, cache))
        return processed
    return item


async def _process_dict(
    data: dict, wm: WorkspaceManager, block: str, cache: dict
) -> dict:
    """Process a dict, saving binary fields and recursing into nested structures."""
    result: dict[str, Any] = {}

    for key, value in data.items():
        if (
            key in SAVEABLE_FIELDS
            and isinstance(value, str)
            and len(value) > SIZE_THRESHOLD
        ):
            content_hash = hashlib.sha256(value.encode()).hexdigest()

            if content_hash in cache:
                result[key] = cache[content_hash]
            elif ref := await _save(value, key, wm, block):
                cache[content_hash] = ref
                result[key] = ref
            else:
                result[key] = value  # Save failed, keep original

        elif isinstance(value, dict):
            result[key] = await _process_dict(value, wm, block, cache)
        elif isinstance(value, list):
            processed: list[Any] = []
            for i in value:
                processed.append(await _process_item(i, wm, block, cache))
            result[key] = processed
        else:
            result[key] = value

    return result


async def _save(value: str, field: str, wm: WorkspaceManager, block: str) -> str | None:
    """Save content to workspace, return workspace:// reference or None on failure."""
    try:
        if field in BINARY_FIELDS:
            content = _decode_base64(value)
            if content is None:
                return None
        else:
            content = value.encode("utf-8")

        ext = {"jpeg": "jpg"}.get(field, field)
        filename = f"{block.lower().replace(' ', '_')[:20]}_{field}_{uuid.uuid4().hex[:12]}.{ext}"

        file = await wm.write_file(content=content, filename=filename)
        return f"workspace://{file.id}"

    except Exception as e:
        logger.error(f"Failed to save {field} to workspace for block '{block}': {e}")
        return None


def _decode_base64(value: str) -> bytes | None:
    """Decode base64, handling data URI format. Returns None on failure."""
    try:
        if value.startswith("data:"):
            value = value.split(",", 1)[1] if "," in value else value
        # Normalize padding and use strict validation to prevent corrupted data
        padded = value + "=" * (-len(value) % 4)
        return base64.b64decode(padded, validate=True)
    except (binascii.Error, ValueError):
        return None
```
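A note on the padding normalization in `_decode_base64`: `-len(value) % 4` yields exactly the number of `=` characters needed to pad the input to a multiple of four before strict decoding. A standalone check (sample strings are illustrative, not from the diff):

```python
import base64

for s in ("QQ", "QUE", "QUJD"):  # missing 2, 1, and 0 padding chars
    padded = s + "=" * (-len(s) % 4)
    print(padded, base64.b64decode(padded, validate=True))
# QQ== b'A'
# QUE= b'AA'
# QUJD b'ABC'
```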
||||
In the chat `RunBlockTool` module:

```diff
@@ -8,12 +8,16 @@ from typing import Any
 from pydantic_core import PydanticUndefined

 from backend.api.features.chat.model import ChatSession
+from backend.api.features.chat.tools.binary_output_processor import (
+    process_binary_outputs,
+)
 from backend.data.block import get_block
 from backend.data.execution import ExecutionContext
 from backend.data.model import CredentialsMetaInput
 from backend.data.workspace import get_or_create_workspace
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util.exceptions import BlockError
+from backend.util.workspace import WorkspaceManager

 from .base import BaseTool
 from .models import (
@@ -321,11 +325,19 @@ class RunBlockTool(BaseTool):
             ):
                 outputs[output_name].append(output_data)

+            # Save binary outputs to workspace to prevent context bloat
+            workspace_manager = WorkspaceManager(
+                user_id, workspace.id, session.session_id
+            )
+            processed_outputs = await process_binary_outputs(
+                dict(outputs), workspace_manager, block.name
+            )
+
             return BlockOutputResponse(
                 message=f"Block '{block.name}' executed successfully",
                 block_id=block_id,
                 block_name=block.name,
-                outputs=dict(outputs),
+                outputs=processed_outputs,
                 success=True,
                 session_id=session_id,
             )
```
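For orientation, the effect on a block response is that inline base64 becomes a workspace reference, roughly like this (shapes illustrative, matching the tests below):

```python
# Hypothetical block output before and after process_binary_outputs():
before = {"result": [{"png": "iVBORw0KGgo...<kilobytes of base64>", "text": "ok"}]}
after = {"result": [{"png": "workspace://file-123", "text": "ok"}]}
```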
New test file for the binary output processor, @@ -0,0 +1,92 @@:

```python
import base64
from unittest.mock import AsyncMock, MagicMock

import pytest

from backend.api.features.chat.tools.binary_output_processor import (
    _decode_base64,
    process_binary_outputs,
)


@pytest.fixture
def workspace_manager():
    wm = AsyncMock()
    wm.write_file = AsyncMock(return_value=MagicMock(id="file-123"))
    return wm


class TestDecodeBase64:
    def test_raw_base64(self):
        assert _decode_base64(base64.b64encode(b"test").decode()) == b"test"

    def test_data_uri(self):
        encoded = base64.b64encode(b"test").decode()
        assert _decode_base64(f"data:image/png;base64,{encoded}") == b"test"

    def test_invalid_returns_none(self):
        assert _decode_base64("not base64!!!") is None


class TestProcessBinaryOutputs:
    @pytest.mark.asyncio
    async def test_saves_large_binary(self, workspace_manager):
        content = base64.b64encode(b"x" * 2000).decode()
        outputs = {"result": [{"png": content, "text": "ok"}]}

        result = await process_binary_outputs(outputs, workspace_manager, "Test")

        assert result["result"][0]["png"] == "workspace://file-123"
        assert result["result"][0]["text"] == "ok"

    @pytest.mark.asyncio
    async def test_skips_small_content(self, workspace_manager):
        content = base64.b64encode(b"tiny").decode()
        outputs = {"result": [{"png": content}]}

        result = await process_binary_outputs(outputs, workspace_manager, "Test")

        assert result["result"][0]["png"] == content
        workspace_manager.write_file.assert_not_called()

    @pytest.mark.asyncio
    async def test_deduplicates_identical_content(self, workspace_manager):
        content = base64.b64encode(b"x" * 2000).decode()
        outputs = {"a": [{"pdf": content}], "b": [{"pdf": content}]}

        result = await process_binary_outputs(outputs, workspace_manager, "Test")

        assert result["a"][0]["pdf"] == result["b"][0]["pdf"] == "workspace://file-123"
        assert workspace_manager.write_file.call_count == 1

    @pytest.mark.asyncio
    async def test_failure_preserves_original(self, workspace_manager):
        workspace_manager.write_file.side_effect = Exception("Storage error")
        content = base64.b64encode(b"x" * 2000).decode()

        result = await process_binary_outputs(
            {"r": [{"png": content}]}, workspace_manager, "Test"
        )

        assert result["r"][0]["png"] == content

    @pytest.mark.asyncio
    async def test_handles_nested_structures(self, workspace_manager):
        content = base64.b64encode(b"x" * 2000).decode()
        outputs = {"result": [{"outer": {"inner": {"png": content}}}]}

        result = await process_binary_outputs(outputs, workspace_manager, "Test")

        assert result["result"][0]["outer"]["inner"]["png"] == "workspace://file-123"

    @pytest.mark.asyncio
    async def test_handles_lists_in_output(self, workspace_manager):
        content = base64.b64encode(b"x" * 2000).decode()
        outputs = {"result": [{"images": [{"png": content}, {"png": content}]}]}

        result = await process_binary_outputs(outputs, workspace_manager, "Test")

        assert result["result"][0]["images"][0]["png"] == "workspace://file-123"
        assert result["result"][0]["images"][1]["png"] == "workspace://file-123"
        # Deduplication should still work
        assert workspace_manager.write_file.call_count == 1
```
In the credentials-matching helpers:

```diff
@@ -8,7 +8,12 @@ from backend.api.features.library import model as library_model
 from backend.api.features.store import db as store_db
 from backend.data import graph as graph_db
 from backend.data.graph import GraphModel
-from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
+from backend.data.model import (
+    CredentialsFieldInfo,
+    CredentialsMetaInput,
+    HostScopedCredentials,
+    OAuth2Credentials,
+)
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util.exceptions import NotFoundError
@@ -273,7 +278,14 @@ async def match_user_credentials_to_graph(
             for cred in available_creds
             if cred.provider in credential_requirements.provider
             and cred.type in credential_requirements.supported_types
-            and _credential_has_required_scopes(cred, credential_requirements)
+            and (
+                cred.type != "oauth2"
+                or _credential_has_required_scopes(cred, credential_requirements)
+            )
+            and (
+                cred.type != "host_scoped"
+                or _credential_is_for_host(cred, credential_requirements)
+            )
         ),
         None,
     )
@@ -318,19 +330,10 @@ async def match_user_credentials_to_graph(


 def _credential_has_required_scopes(
-    credential: Credentials,
+    credential: OAuth2Credentials,
     requirements: CredentialsFieldInfo,
 ) -> bool:
-    """
-    Check if a credential has all the scopes required by the block.
-
-    For OAuth2 credentials, verifies that the credential's scopes are a superset
-    of the required scopes. For other credential types, returns True (no scope check).
-    """
-    # Only OAuth2 credentials have scopes to check
-    if credential.type != "oauth2":
-        return True
-
+    """Check if an OAuth2 credential has all the scopes required by the input."""
    # If no scopes are required, any credential matches
     if not requirements.required_scopes:
         return True
@@ -339,6 +342,22 @@ def _credential_has_required_scopes(
     return set(credential.scopes).issuperset(requirements.required_scopes)


+def _credential_is_for_host(
+    credential: HostScopedCredentials,
+    requirements: CredentialsFieldInfo,
+) -> bool:
+    """Check if a host-scoped credential matches the host required by the input."""
+    # We need to know the host to match host-scoped credentials to.
+    # Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
+    # to discriminator_values. No discriminator_values -> no host to match against.
+    if not requirements.discriminator_values:
+        return True
+
+    # Check that credential host matches required host.
+    # Host-scoped credential inputs are grouped by host, so any item from the set works.
+    return credential.matches_url(list(requirements.discriminator_values)[0])
+
+
 async def check_user_has_required_credentials(
     user_id: str,
     required_credentials: list[CredentialsMetaInput],
```
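The rewritten filter applies each type-specific check only to credentials of that type; other credential types pass through unaffected. A minimal sketch of the combined guard's behavior (helper name hypothetical):

```python
def _matches(cred_type: str, scopes_ok: bool, host_ok: bool) -> bool:
    # Mirrors: (type != "oauth2" or scopes_ok) and (type != "host_scoped" or host_ok)
    return (cred_type != "oauth2" or scopes_ok) and (
        cred_type != "host_scoped" or host_ok
    )

assert _matches("api_key", scopes_ok=False, host_ok=False)  # other types unaffected
assert not _matches("oauth2", scopes_ok=False, host_ok=True)  # scope check applies
assert _matches("host_scoped", scopes_ok=False, host_ok=True)  # host check applies
```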
In `backend.data.model`:

```diff
@@ -19,7 +19,6 @@ from typing import (
     cast,
     get_args,
 )
-from urllib.parse import urlparse
 from uuid import uuid4

 from prisma.enums import CreditTransactionType, OnboardingStep
@@ -42,6 +41,7 @@ from typing_extensions import TypedDict

 from backend.integrations.providers import ProviderName
 from backend.util.json import loads as json_loads
+from backend.util.request import parse_url
 from backend.util.settings import Secrets

 # Type alias for any provider name (including custom ones)
```
```diff
@@ -397,19 +397,25 @@ class HostScopedCredentials(_BaseCredentials):
     def matches_url(self, url: str) -> bool:
         """Check if this credential should be applied to the given URL."""

-        parsed_url = urlparse(url)
-        # Extract hostname without port
-        request_host = parsed_url.hostname
+        request_host, request_port = _extract_host_from_url(url)
+        cred_scope_host, cred_scope_port = _extract_host_from_url(self.host)
         if not request_host:
             return False

-        # Simple host matching - exact match or wildcard subdomain match
-        if self.host == request_host:
+        # If a port is specified in credential host, the request host port must match
+        if cred_scope_port is not None and request_port != cred_scope_port:
+            return False
+        # Non-standard ports are only allowed if explicitly specified in credential host
+        elif cred_scope_port is None and request_port not in (80, 443, None):
+            return False
+
+        # Simple host matching
+        if cred_scope_host == request_host:
             return True

         # Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
-        if self.host.startswith("*."):
-            domain = self.host[2:]  # Remove "*."
+        if cred_scope_host.startswith("*."):
+            domain = cred_scope_host[2:]  # Remove "*."
             return request_host.endswith(f".{domain}") or request_host == domain

         return False
```
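The new port logic accepts default ports (80, 443, or none) when the credential host omits a port, and requires an exact match when it specifies one. A standalone sketch of just that rule:

```python
def port_ok(cred_port: int | None, request_port: int | None) -> bool:
    if cred_port is not None:
        # Explicit port in credential host: request port must match exactly
        return request_port == cred_port
    # No port in credential host: only standard/absent ports are allowed
    return request_port in (80, 443, None)

assert port_ok(None, None) and port_ok(None, 443)
assert not port_ok(None, 8080)  # non-standard port needs an explicit credential port
assert port_ok(8080, 8080) and not port_ok(8080, 9090)
```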
```diff
@@ -551,13 +557,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
     )


-def _extract_host_from_url(url: str) -> str:
-    """Extract host from URL for grouping host-scoped credentials."""
+def _extract_host_from_url(url: str) -> tuple[str, int | None]:
+    """Extract host and port from URL for grouping host-scoped credentials."""
     try:
-        parsed = urlparse(url)
-        return parsed.hostname or url
+        parsed = parse_url(url)
+        return parsed.hostname or url, parsed.port
     except Exception:
-        return ""
+        return "", None


 class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
```
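A simplified standalone sketch of the helper's new contract (the real `parse_url` also strips stray slashes and backslashes):

```python
from urllib.parse import urlparse

def extract_host(url: str) -> tuple[str, int | None]:
    if "://" not in url:
        url = f"http://{url}"  # parse_url() likewise defaults the scheme
    parsed = urlparse(url)
    return parsed.hostname or url, parsed.port

assert extract_host("https://api.example.com:8443/v1") == ("api.example.com", 8443)
assert extract_host("api.example.com") == ("api.example.com", None)
```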
```diff
@@ -606,7 +612,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
         providers = frozenset(
             [cast(CP, "http")]
             + [
-                cast(CP, _extract_host_from_url(str(value)))
+                cast(CP, parse_url(str(value)).netloc)
                 for value in field.discriminator_values
             ]
         )
```
```diff
@@ -79,10 +79,23 @@ class TestHostScopedCredentials:
             headers={"Authorization": SecretStr("Bearer token")},
         )

-        assert creds.matches_url("http://localhost:8080/api/v1")
+        # Non-standard ports require explicit port in credential host
+        assert not creds.matches_url("http://localhost:8080/api/v1")
         assert creds.matches_url("https://localhost:443/secure/endpoint")
         assert creds.matches_url("http://localhost/simple")

+    def test_matches_url_with_explicit_port(self):
+        """Test URL matching with explicit port in credential host."""
+        creds = HostScopedCredentials(
+            provider="custom",
+            host="localhost:8080",
+            headers={"Authorization": SecretStr("Bearer token")},
+        )
+
+        assert creds.matches_url("http://localhost:8080/api/v1")
+        assert not creds.matches_url("http://localhost:3000/api/v1")
+        assert not creds.matches_url("http://localhost/simple")
+
     def test_empty_headers_dict(self):
         """Test HostScopedCredentials with empty headers."""
         creds = HostScopedCredentials(
@@ -128,8 +141,20 @@ class TestHostScopedCredentials:
             ("*.example.com", "https://sub.api.example.com/test", True),
             ("*.example.com", "https://example.com/test", True),
             ("*.example.com", "https://example.org/test", False),
-            ("localhost", "http://localhost:3000/test", True),
+            # Non-standard ports require explicit port in credential host
+            ("localhost", "http://localhost:3000/test", False),
+            ("localhost:3000", "http://localhost:3000/test", True),
             ("localhost", "http://127.0.0.1:3000/test", False),
+            # IPv6 addresses (frontend stores with brackets via URL.hostname)
+            ("[::1]", "http://[::1]/test", True),
+            ("[::1]", "http://[::1]:80/test", True),
+            ("[::1]", "https://[::1]:443/test", True),
+            ("[::1]", "http://[::1]:8080/test", False),  # Non-standard port
+            ("[::1]:8080", "http://[::1]:8080/test", True),
+            ("[::1]:8080", "http://[::1]:9090/test", False),
+            ("[2001:db8::1]", "http://[2001:db8::1]/path", True),
+            ("[2001:db8::1]", "https://[2001:db8::1]:443/path", True),
+            ("[2001:db8::1]", "http://[2001:db8::ff]/path", False),
         ],
     )
     def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
```
In `backend.util.request`:

```diff
@@ -157,12 +157,7 @@ async def validate_url(
     is_trusted: Boolean indicating if the hostname is in trusted_origins
     ip_addresses: List of IP addresses for the host; empty if the host is trusted
     """
-    # Canonicalize URL
-    url = url.strip("/ ").replace("\\", "/")
-    parsed = urlparse(url)
-    if not parsed.scheme:
-        url = f"http://{url}"
-        parsed = urlparse(url)
+    parsed = parse_url(url)

     # Check scheme
     if parsed.scheme not in ALLOWED_SCHEMES:
@@ -220,6 +215,17 @@ async def validate_url(
     )


+def parse_url(url: str) -> URL:
+    """Canonicalizes and parses a URL string."""
+    url = url.strip("/ ").replace("\\", "/")
+
+    # Ensure scheme is present for proper parsing
+    if not re.match(r"[a-z0-9+.\-]+://", url):
+        url = f"http://{url}"
+
+    return urlparse(url)
+
+
 def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL:
     """
     Pins a URL to a specific IP address to prevent DNS rebinding attacks.
```
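Note that `parse_url` detects a missing scheme with a regex on `://` rather than trusting `urlparse().scheme`, which can misread `host:port` strings as `scheme:path`. A runnable approximation of the canonicalization:

```python
import re
from urllib.parse import urlparse

def parse_url_sketch(url: str):
    url = url.strip("/ ").replace("\\", "/")
    if not re.match(r"[a-z0-9+.\-]+://", url):
        url = f"http://{url}"
    return urlparse(url)

assert parse_url_sketch("example.com/a\\b").geturl() == "http://example.com/a/b"
assert parse_url_sketch("localhost:8080").port == 8080  # host:port is not a scheme
assert parse_url_sketch("https://example.com/").scheme == "https"
```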
In the frontend `HostScopedCredentialsModal`:

```diff
@@ -41,7 +41,17 @@ export function HostScopedCredentialsModal({
   const currentHost = currentUrl ? getHostFromUrl(currentUrl) : "";

   const formSchema = z.object({
-    host: z.string().min(1, "Host is required"),
+    host: z
+      .string()
+      .min(1, "Host is required")
+      .refine((val) => !/^[a-zA-Z][a-zA-Z\d+\-.]*:\/\//.test(val), {
+        message: "Enter only the host (e.g. api.example.com), not a full URL",
+      })
+      .refine((val) => !val.includes("/"), {
+        message:
+          "Enter only the host (e.g. api.example.com), without a trailing path. " +
+          "You may specify a port (e.g. api.example.com:8080) if needed.",
+      }),
     title: z.string().optional(),
     headers: z.record(z.string()).optional(),
   });
```
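For reference, a Python sketch of the same host-field rules (semantics assumed equivalent to the zod refinements above):

```python
import re

def host_input_error(value: str) -> str | None:
    """Return a validation message, or None if the host value is acceptable."""
    if not value:
        return "Host is required"
    if re.match(r"^[a-zA-Z][a-zA-Z\d+\-.]*://", value):
        return "Enter only the host (e.g. api.example.com), not a full URL"
    if "/" in value:
        return "Enter only the host, without a trailing path"
    return None

assert host_input_error("api.example.com") is None
assert host_input_error("api.example.com:8080") is None  # ports are allowed
assert host_input_error("https://api.example.com") is not None
assert host_input_error("api.example.com/v1") is not None
```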