Compare commits

..

9 Commits

Author SHA1 Message Date
Otto
883067e2a4 style: move _credential_has_required_scopes after match_user_credentials_to_graph
Both find_matching_credential and match_user_credentials_to_graph use this
helper, so it belongs after both of them.
2026-02-03 18:46:03 +00:00
Otto
85bd23cea4 style: reorder credential functions in proper top-down order
High-level functions first, helpers below:
1. match_credentials_to_requirements (main entry point)
2. get_user_credentials
3. find_matching_credential
4. create_credential_meta_from_match
5. _credential_has_required_scopes (lowest-level helper)
2026-02-03 18:43:52 +00:00
Otto
5c2cae3eb3 refactor: address Pwuts review feedback
- Remove unused error_response() and format_inputs_as_markdown() from helpers.py
- Remove _get_inputs_list() wrapper from run_agent.py, use get_inputs_from_schema directly
- Fix type annotations: get_user_credentials, find_matching_credential, create_credential_meta_from_match
- Remove check_scopes parameter - always check scopes (the originally missing scope check was a bug, not intended behavior)
- Reorder _credential_has_required_scopes to be defined before find_matching_credential
2026-02-03 16:57:38 +00:00
Otto
d8909af967 fix: enable scope checking for block credentials (consistency with graphs)
Previously run_block didn't check OAuth2 scopes while run_agent did.
Now both use the same scope-checking logic for credential matching.
2026-02-03 13:19:15 +00:00
Otto
ff8ca11845 fix: preserve original credential matching behavior
- Add check_scopes parameter to find_matching_credential and
  match_credentials_to_requirements (default True)
- run_block uses check_scopes=False to preserve original behavior
  (original run_block did not verify OAuth2 scopes)
- Add isinstance check to get_inputs_from_schema for safety
  (original returned [] if input_schema wasn't a dict)
2026-02-03 13:15:59 +00:00
Zamil Majdy
4d6471a7eb Merge branch 'dev' into otto/copilot-cleanup-dev-v2 2026-02-03 20:08:03 +07:00
Otto
7dc53071e8 fix(backend): Add retry and error handling to block initialization (#11946)
## Summary
Adds retry logic and graceful error handling to `initialize_blocks()` to
prevent transient DB errors from crashing server startup.

## Problem
When a transient database error occurs during block initialization
(e.g., Prisma P1017 "Server has closed the connection"), the entire
server fails to start. This is overly aggressive since:
1. Blocks are already registered in memory
2. The DB sync is primarily for tracking/schema storage
3. One flaky connection shouldn't prevent the server from starting

**Triggered by:** [Sentry
AUTOGPT-SERVER-7PW](https://significant-gravitas.sentry.io/issues/7238733543/)

## Solution
- Add retry decorator (3 attempts with exponential backoff) for DB
operations
- On failure after retries, log a warning and continue to the next block
- Blocks remain available in memory even if DB sync fails
- Log summary of any failed blocks at the end
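
A minimal sketch of the retry shape described above. The PR itself reuses the existing `func_retry` decorator from `backend.util.retry`, whose implementation is not part of this diff; the name `retry_async` and the defaults below are assumptions for illustration only:

```python
# Illustrative only: not the real func_retry from backend.util.retry.
# Shows a 3-attempt exponential-backoff retry for async callables.
import asyncio
import functools
import logging

logger = logging.getLogger(__name__)


def retry_async(attempts: int = 3, base_delay: float = 1.0):
    """Retry an async function with exponential backoff (1s, 2s, 4s, ...)."""

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            for attempt in range(1, attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    if attempt == attempts:
                        raise  # let the caller decide how to degrade
                    delay = base_delay * 2 ** (attempt - 1)
                    logger.warning(
                        "Attempt %d/%d failed (%s); retrying in %.1fs",
                        attempt, attempts, e, delay,
                    )
                    await asyncio.sleep(delay)

        return wrapper

    return decorator
```

In `initialize_blocks()` the per-block DB sync is wrapped this way, and any exception that survives the retries is caught, logged as a warning, and skipped so the block stays available in memory (see the `block.py` diff below).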

## Changes
- `autogpt_platform/backend/backend/data/block.py`: Wrap block DB sync
in retry logic with graceful fallback

## Testing
- Existing block initialization behavior unchanged on success
- On transient DB errors: retries up to 3 times, then continues with
warning
2026-02-03 12:43:30 +00:00
Otto
dcdd886067 chore: remove docstrings and use sorted() for deterministic UUID ordering 2026-02-03 12:29:31 +00:00
Otto
6098c5eed6 refactor(copilot): code cleanup - extract shared helpers and reduce duplication
- Create util/validation.py with UUID validation helpers
- Create tools/helpers.py with shared utilities (get_inputs_from_schema, etc.)
- Add shared credential matching utilities to utils.py
- Refactor run_block to use shared matching with discriminator support
- Extract _create_stream_generator in routes.py
- Update run_agent.py to use shared helpers

Preserves discriminator logic for multi-provider credential matching.
2026-02-03 12:15:24 +00:00
9 changed files with 268 additions and 616 deletions

View File

@@ -17,6 +17,13 @@ from .model import ChatSession, create_chat_session, get_chat_session, get_user_
config = ChatConfig()
SSE_RESPONSE_HEADERS = {
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
}
logger = logging.getLogger(__name__)
@@ -32,6 +39,48 @@ async def _validate_and_get_session(
return session
async def _create_stream_generator(
session_id: str,
message: str,
user_id: str | None,
session: ChatSession,
is_user_message: bool = True,
context: dict[str, str] | None = None,
) -> AsyncGenerator[str, None]:
"""Create SSE event generator for chat streaming."""
chunk_count = 0
first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion(
session_id,
message,
is_user_message=is_user_message,
user_id=user_id,
session=session,
context=context,
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
yield "data: [DONE]\n\n"
router = APIRouter(
tags=["chat"],
)
@@ -221,49 +270,17 @@ async def stream_chat_post(
"""
session = await _validate_and_get_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
chunk_count = 0
first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
_create_stream_generator(
session_id=session_id,
message=request.message,
user_id=user_id,
session=session,
is_user_message=request.is_user_message,
context=request.context,
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Disable nginx buffering
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
},
headers=SSE_RESPONSE_HEADERS,
)
@@ -295,48 +312,16 @@ async def stream_chat_get(
"""
session = await _validate_and_get_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
chunk_count = 0
first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion(
session_id,
message,
is_user_message=is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
_create_stream_generator(
session_id=session_id,
message=message,
user_id=user_id,
session=session,
is_user_message=is_user_message,
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Disable nginx buffering
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
},
headers=SSE_RESPONSE_HEADERS,
)

View File

@@ -1,204 +0,0 @@
"""
Save binary block outputs to workspace, return references instead of base64.
This module post-processes block execution outputs to detect and save binary
content (from code execution results) to the workspace, returning workspace://
references instead of raw base64 data. This reduces LLM output token usage
by ~97% for file generation tasks.
Detection is field-name based, targeting the standard e2b CodeExecutionResult
fields: png, jpeg, pdf, svg. Other image-producing blocks already use
store_media_file() and don't need this post-processing.
"""
import base64
import binascii
import hashlib
import logging
import uuid
from typing import Any
from backend.util.file import sanitize_filename
from backend.util.workspace import WorkspaceManager
logger = logging.getLogger(__name__)
# Field names that contain binary data (base64 encoded)
BINARY_FIELDS = {"png", "jpeg", "pdf"}
# Field names that contain large text data (not base64, save as-is)
TEXT_FIELDS = {"svg"}
# Combined set for quick lookup
SAVEABLE_FIELDS = BINARY_FIELDS | TEXT_FIELDS
# Only process content larger than this (string length, not decoded size)
SIZE_THRESHOLD = 1024 # 1KB
async def process_binary_outputs(
outputs: dict[str, list[Any]],
workspace_manager: WorkspaceManager,
block_name: str,
) -> dict[str, list[Any]]:
"""
Replace binary data in block outputs with workspace:// references.
Scans outputs for known binary field names (png, jpeg, pdf, svg) and saves
large content to the workspace. Returns processed outputs with base64 data
replaced by workspace:// references.
Deduplicates identical content within a single call using content hashing.
Args:
outputs: Block execution outputs (dict of output_name -> list of values)
workspace_manager: WorkspaceManager instance with session scoping
block_name: Name of the block (used in generated filenames)
Returns:
Processed outputs with binary data replaced by workspace references
"""
cache: dict[str, str] = {} # content_hash -> workspace_ref
processed: dict[str, list[Any]] = {}
for name, items in outputs.items():
processed_items: list[Any] = []
for item in items:
processed_items.append(
await _process_item(item, workspace_manager, block_name, cache)
)
processed[name] = processed_items
return processed
async def _process_item(
item: Any,
wm: WorkspaceManager,
block: str,
cache: dict[str, str],
) -> Any:
"""Recursively process an item, handling dicts and lists."""
if isinstance(item, dict):
return await _process_dict(item, wm, block, cache)
if isinstance(item, list):
processed: list[Any] = []
for i in item:
processed.append(await _process_item(i, wm, block, cache))
return processed
return item
async def _process_dict(
data: dict[str, Any],
wm: WorkspaceManager,
block: str,
cache: dict[str, str],
) -> dict[str, Any]:
"""Process a dict, saving binary fields and recursing into nested structures."""
result: dict[str, Any] = {}
for key, value in data.items():
if (
key in SAVEABLE_FIELDS
and isinstance(value, str)
and len(value) > SIZE_THRESHOLD
):
# Determine content bytes based on field type
if key in BINARY_FIELDS:
content = _decode_base64(value)
if content is None:
# Decode failed, keep original value
result[key] = value
continue
else:
# TEXT_FIELDS: encode as UTF-8
content = value.encode("utf-8")
# Hash decoded content for deduplication
content_hash = hashlib.sha256(content).hexdigest()
if content_hash in cache:
# Reuse existing workspace reference
result[key] = cache[content_hash]
elif ref := await _save_content(content, key, wm, block):
# Save succeeded, cache and use reference
cache[content_hash] = ref
result[key] = ref
else:
# Save failed, keep original value
result[key] = value
elif isinstance(value, dict):
result[key] = await _process_dict(value, wm, block, cache)
elif isinstance(value, list):
processed: list[Any] = []
for i in value:
processed.append(await _process_item(i, wm, block, cache))
result[key] = processed
else:
result[key] = value
return result
async def _save_content(
content: bytes,
field: str,
wm: WorkspaceManager,
block: str,
) -> str | None:
"""
Save content to workspace, return workspace:// reference.
Args:
content: Decoded binary content to save
field: Field name (used for extension)
wm: WorkspaceManager instance
block: Block name (used in filename)
Returns:
workspace://file-id reference, or None if save failed
"""
try:
# Map field name to file extension
ext = {"jpeg": "jpg"}.get(field, field)
# Sanitize block name for safe filename
safe_block = sanitize_filename(block.lower())[:20]
filename = f"{safe_block}_{field}_{uuid.uuid4().hex[:12]}.{ext}"
file = await wm.write_file(content=content, filename=filename)
return f"workspace://{file.id}"
except Exception as e:
logger.error(f"Failed to save {field} to workspace for block '{block}': {e}")
return None
def _decode_base64(value: str) -> bytes | None:
"""
Decode base64 string, handling both raw base64 and data URI formats.
Args:
value: Base64 string or data URI (data:<mime>;base64,<payload>)
Returns:
Decoded bytes, or None if decoding failed
"""
try:
# Handle data URI format
if value.startswith("data:"):
if "," in value:
value = value.split(",", 1)[1]
else:
# Malformed data URI, no comma separator
return None
# Normalize padding (handle missing = chars)
padded = value + "=" * (-len(value) % 4)
# Strict validation to prevent corrupted data
return base64.b64decode(padded, validate=True)
except (binascii.Error, ValueError):
return None

View File

@@ -0,0 +1,29 @@
"""Shared helpers for chat tools."""
from typing import Any
def get_inputs_from_schema(
input_schema: dict[str, Any],
exclude_fields: set[str] | None = None,
) -> list[dict[str, Any]]:
"""Extract input field info from JSON schema."""
if not isinstance(input_schema, dict):
return []
exclude = exclude_fields or set()
properties = input_schema.get("properties", {})
required = set(input_schema.get("required", []))
return [
{
"name": name,
"title": schema.get("title", name),
"type": schema.get("type", "string"),
"description": schema.get("description", ""),
"required": name in required,
"default": schema.get("default"),
}
for name, schema in properties.items()
if name not in exclude
]
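
Illustrative usage of the helper above; the schema literal is invented for this example and is not taken from the diff:

```python
# Hypothetical schema, for illustration only.
schema = {
    "properties": {
        "query": {"title": "Query", "type": "string", "description": "Search text"},
        "limit": {"type": "integer", "default": 10},
        "credentials": {"type": "object"},
    },
    "required": ["query"],
}

inputs = get_inputs_from_schema(schema, exclude_fields={"credentials"})
# inputs[0] == {"name": "query", "title": "Query", "type": "string",
#               "description": "Search text", "required": True, "default": None}
# inputs[1] == {"name": "limit", "title": "limit", "type": "integer",
#               "description": "", "required": False, "default": 10}
```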

View File

@@ -24,6 +24,7 @@ from backend.util.timezone_utils import (
)
from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import (
AgentDetails,
AgentDetailsResponse,
@@ -261,7 +262,7 @@ class RunAgentTool(BaseTool):
),
requirements={
"credentials": requirements_creds_list,
"inputs": self._get_inputs_list(graph.input_schema),
"inputs": get_inputs_from_schema(graph.input_schema),
"execution_modes": self._get_execution_modes(graph),
},
),
@@ -369,22 +370,6 @@ class RunAgentTool(BaseTool):
session_id=session_id,
)
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
"""Extract inputs list from schema."""
inputs_list = []
if isinstance(input_schema, dict) and "properties" in input_schema:
for field_name, field_schema in input_schema["properties"].items():
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in input_schema.get("required", []),
}
)
return inputs_list
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
"""Get available execution modes for the graph."""
trigger_info = graph.trigger_setup_info
@@ -398,7 +383,7 @@ class RunAgentTool(BaseTool):
suffix: str,
) -> str:
"""Build a message describing available inputs for an agent."""
inputs_list = self._get_inputs_list(graph.input_schema)
inputs_list = get_inputs_from_schema(graph.input_schema)
required_names = [i["name"] for i in inputs_list if i["required"]]
optional_names = [i["name"] for i in inputs_list if not i["required"]]

View File

@@ -8,18 +8,15 @@ from typing import Any
from pydantic_core import PydanticUndefined
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.binary_output_processor import (
process_binary_outputs,
)
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsMetaInput
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from backend.util.workspace import WorkspaceManager
from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import (
BlockOutputResponse,
ErrorResponse,
@@ -28,7 +25,10 @@ from .models import (
ToolResponseBase,
UserReadiness,
)
from .utils import build_missing_credentials_from_field_info
from .utils import (
build_missing_credentials_from_field_info,
match_credentials_to_requirements,
)
logger = logging.getLogger(__name__)
@@ -77,41 +77,22 @@ class RunBlockTool(BaseTool):
def requires_auth(self) -> bool:
return True
async def _check_block_credentials(
def _resolve_discriminated_credentials(
self,
user_id: str,
block: Any,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Args:
user_id: User ID
block: Block to check credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
input_data = input_data or {}
# Get credential field info from block's input schema
input_data: dict[str, Any],
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return matched_credentials, missing_credentials
return {}
# Get user's available credentials
creds_manager = IntegrationCredentialsManager()
available_creds = await creds_manager.store.get_all_creds(user_id)
resolved: dict[str, CredentialsFieldInfo] = {}
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
# Get discriminator from input, falling back to schema default
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
@@ -130,37 +111,34 @@ class RunBlockTool(BaseTool):
f"{discriminator_value} -> {effective_field_info.provider}"
)
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in effective_field_info.provider
and cred.type in effective_field_info.supported_types
),
None,
)
resolved[field_name] = effective_field_info
if matching_cred:
matched_credentials[field_name] = CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
else:
# Create a placeholder for the missing credential
provider = next(iter(effective_field_info.provider), "unknown")
cred_type = next(iter(effective_field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return resolved
return matched_credentials, missing_credentials
async def _check_block_credentials(
self,
user_id: str,
block: Any,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Args:
user_id: User ID
block: Block to check credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple[matched_credentials, missing_credentials]
"""
input_data = input_data or {}
requirements = self._resolve_discriminated_credentials(block, input_data)
if not requirements:
return {}, []
return await match_credentials_to_requirements(user_id, requirements)
async def _execute(
self,
@@ -325,20 +303,11 @@ class RunBlockTool(BaseTool):
):
outputs[output_name].append(output_data)
# Save binary outputs to workspace to prevent context bloat
# (code execution results with png/jpeg/pdf/svg fields)
workspace_manager = WorkspaceManager(
user_id, workspace.id, session.session_id
)
processed_outputs = await process_binary_outputs(
dict(outputs), workspace_manager, block.name
)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,
block_name=block.name,
outputs=processed_outputs,
outputs=dict(outputs),
success=True,
session_id=session_id,
)
@@ -360,27 +329,6 @@ class RunBlockTool(BaseTool):
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
"""Extract non-credential inputs from block schema."""
inputs_list = []
schema = block.input_schema.jsonschema()
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))
# Get credential field names to exclude
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
for field_name, field_schema in properties.items():
# Skip credential fields
if field_name in credentials_fields:
continue
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in required_fields,
}
)
return inputs_list
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)

View File

@@ -1,204 +0,0 @@
"""Unit tests for binary_output_processor module."""
import base64
from unittest.mock import AsyncMock, MagicMock
import pytest
from backend.api.features.chat.tools.binary_output_processor import (
_decode_base64,
process_binary_outputs,
)
@pytest.fixture
def workspace_manager():
"""Create a mock WorkspaceManager."""
mock = MagicMock()
mock_file = MagicMock()
mock_file.id = "file-123"
mock.write_file = AsyncMock(return_value=mock_file)
return mock
class TestDecodeBase64:
"""Tests for _decode_base64 function."""
def test_raw_base64(self):
"""Decode raw base64 string."""
encoded = base64.b64encode(b"test content").decode()
result = _decode_base64(encoded)
assert result == b"test content"
def test_data_uri(self):
"""Decode base64 from data URI format."""
content = b"test content"
encoded = base64.b64encode(content).decode()
data_uri = f"data:image/png;base64,{encoded}"
result = _decode_base64(data_uri)
assert result == content
def test_invalid_base64(self):
"""Return None for invalid base64."""
result = _decode_base64("not valid base64!!!")
assert result is None
def test_missing_padding(self):
"""Handle base64 with missing padding."""
# base64.b64encode(b"test") = "dGVzdA=="
# Remove padding
result = _decode_base64("dGVzdA")
assert result == b"test"
def test_malformed_data_uri(self):
"""Return None for data URI without comma."""
result = _decode_base64("data:image/png;base64")
assert result is None
class TestProcessBinaryOutputs:
"""Tests for process_binary_outputs function."""
@pytest.mark.asyncio
async def test_saves_large_png(self, workspace_manager):
"""Large PNG content should be saved to workspace."""
# Create content larger than SIZE_THRESHOLD (1KB)
large_content = b"x" * 2000
encoded = base64.b64encode(large_content).decode()
outputs = {"result": [{"png": encoded}]}
result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")
assert result["result"][0]["png"] == "workspace://file-123"
workspace_manager.write_file.assert_called_once()
call_kwargs = workspace_manager.write_file.call_args.kwargs
assert call_kwargs["content"] == large_content
assert "testblock_png_" in call_kwargs["filename"]
assert call_kwargs["filename"].endswith(".png")
@pytest.mark.asyncio
async def test_preserves_small_content(self, workspace_manager):
"""Small content should be preserved as-is."""
small_content = base64.b64encode(b"tiny").decode()
outputs = {"result": [{"png": small_content}]}
result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")
assert result["result"][0]["png"] == small_content
workspace_manager.write_file.assert_not_called()
@pytest.mark.asyncio
async def test_deduplicates_identical_content(self, workspace_manager):
"""Identical content should only be saved once."""
large_content = b"x" * 2000
encoded = base64.b64encode(large_content).decode()
outputs = {
"main_result": [{"png": encoded}],
"results": [{"png": encoded}],
}
result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")
# Both should have the same workspace reference
assert result["main_result"][0]["png"] == "workspace://file-123"
assert result["results"][0]["png"] == "workspace://file-123"
# But write_file should only be called once
assert workspace_manager.write_file.call_count == 1
@pytest.mark.asyncio
async def test_graceful_degradation_on_save_failure(self, workspace_manager):
"""Original content should be preserved if save fails."""
workspace_manager.write_file = AsyncMock(side_effect=Exception("Save failed"))
large_content = b"x" * 2000
encoded = base64.b64encode(large_content).decode()
outputs = {"result": [{"png": encoded}]}
result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")
# Original content should be preserved
assert result["result"][0]["png"] == encoded
@pytest.mark.asyncio
async def test_handles_nested_structures(self, workspace_manager):
"""Should traverse nested dicts and lists."""
large_content = b"x" * 2000
encoded = base64.b64encode(large_content).decode()
outputs = {
"result": [
{
"nested": {
"deep": {
"png": encoded,
}
}
}
]
}
result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")
assert result["result"][0]["nested"]["deep"]["png"] == "workspace://file-123"
@pytest.mark.asyncio
async def test_handles_svg_as_text(self, workspace_manager):
"""SVG should be saved as UTF-8 text, not base64 decoded."""
svg_content = "<svg>" + "x" * 2000 + "</svg>"
outputs = {"result": [{"svg": svg_content}]}
result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")
assert result["result"][0]["svg"] == "workspace://file-123"
call_kwargs = workspace_manager.write_file.call_args.kwargs
# SVG should be UTF-8 encoded, not base64 decoded
assert call_kwargs["content"] == svg_content.encode("utf-8")
assert call_kwargs["filename"].endswith(".svg")
@pytest.mark.asyncio
async def test_ignores_unknown_fields(self, workspace_manager):
"""Fields not in SAVEABLE_FIELDS should be ignored."""
large_content = "x" * 2000 # Large text in an unknown field
outputs = {"result": [{"unknown_field": large_content}]}
result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")
assert result["result"][0]["unknown_field"] == large_content
workspace_manager.write_file.assert_not_called()
@pytest.mark.asyncio
async def test_handles_jpeg_extension(self, workspace_manager):
"""JPEG files should use .jpg extension."""
large_content = b"x" * 2000
encoded = base64.b64encode(large_content).decode()
outputs = {"result": [{"jpeg": encoded}]}
await process_binary_outputs(outputs, workspace_manager, "TestBlock")
call_kwargs = workspace_manager.write_file.call_args.kwargs
assert call_kwargs["filename"].endswith(".jpg")
@pytest.mark.asyncio
async def test_handles_data_uri_in_binary_field(self, workspace_manager):
"""Data URI format in binary fields should be properly decoded."""
large_content = b"x" * 2000
encoded = base64.b64encode(large_content).decode()
data_uri = f"data:image/png;base64,{encoded}"
outputs = {"result": [{"png": data_uri}]}
result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")
assert result["result"][0]["png"] == "workspace://file-123"
call_kwargs = workspace_manager.write_file.call_args.kwargs
assert call_kwargs["content"] == large_content
@pytest.mark.asyncio
async def test_invalid_base64_preserves_original(self, workspace_manager):
"""Invalid base64 in a binary field should preserve the original value."""
invalid_content = "not valid base64!!!" + "x" * 2000
outputs = {"result": [{"png": invalid_content}]}
processed = await process_binary_outputs(
outputs, workspace_manager, "TestBlock"
)
assert processed["result"][0]["png"] == invalid_content
workspace_manager.write_file.assert_not_called()

View File

@@ -225,6 +225,95 @@ async def get_or_create_library_agent(
return library_agents[0]
async def match_credentials_to_requirements(
user_id: str,
requirements: dict[str, CredentialsFieldInfo],
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Match user's credentials against a dictionary of credential requirements.
This is the core matching logic shared by both graph and block credential matching.
"""
matched: dict[str, CredentialsMetaInput] = {}
missing: list[CredentialsMetaInput] = []
if not requirements:
return matched, missing
available_creds = await get_user_credentials(user_id)
for field_name, field_info in requirements.items():
matching_cred = find_matching_credential(available_creds, field_info)
if matching_cred:
try:
matched[field_name] = create_credential_meta_from_match(matching_cred)
except Exception as e:
logger.error(
f"Failed to create CredentialsMetaInput for field '{field_name}': "
f"provider={matching_cred.provider}, type={matching_cred.type}, "
f"credential_id={matching_cred.id}",
exc_info=True,
)
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=f"{field_name} (validation failed: {e})",
)
)
else:
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched, missing
async def get_user_credentials(user_id: str) -> list[Credentials]:
"""Get all available credentials for a user."""
creds_manager = IntegrationCredentialsManager()
return await creds_manager.store.get_all_creds(user_id)
def find_matching_credential(
available_creds: list[Credentials],
field_info: CredentialsFieldInfo,
) -> Credentials | None:
"""Find a credential that matches the required provider, type, and scopes."""
for cred in available_creds:
if cred.provider not in field_info.provider:
continue
if cred.type not in field_info.supported_types:
continue
if not _credential_has_required_scopes(cred, field_info):
continue
return cred
return None
def create_credential_meta_from_match(
matching_cred: Credentials,
) -> CredentialsMetaInput:
"""Create a CredentialsMetaInput from a matched credential."""
return CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
async def match_user_credentials_to_graph(
user_id: str,
graph: GraphModel,
@@ -321,21 +410,11 @@ def _credential_has_required_scopes(
credential: Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""
Check if a credential has all the scopes required by the block.
For OAuth2 credentials, verifies that the credential's scopes are a superset
of the required scopes. For other credential types, returns True (no scope check).
"""
# Only OAuth2 credentials have scopes to check
"""Check if a credential has all the scopes required by the block."""
if credential.type != "oauth2":
return True
# If no scopes are required, any credential matches
if not requirements.required_scopes:
return True
# Check that credential scopes are a superset of required scopes
return set(credential.scopes).issuperset(requirements.required_scopes)

View File

@@ -873,14 +873,13 @@ def is_block_auth_configured(
async def initialize_blocks() -> None:
# First, sync all provider costs to blocks
# Imported here to avoid circular import
from backend.sdk.cost_integration import sync_all_provider_costs
from backend.util.retry import func_retry
sync_all_provider_costs()
for cls in get_blocks().values():
block = cls()
@func_retry
async def sync_block_to_db(block: Block) -> None:
existing_block = await AgentBlock.prisma().find_first(
where={"OR": [{"id": block.id}, {"name": block.name}]}
)
@@ -893,7 +892,7 @@ async def initialize_blocks() -> None:
outputSchema=json.dumps(block.output_schema.jsonschema()),
)
)
continue
return
input_schema = json.dumps(block.input_schema.jsonschema())
output_schema = json.dumps(block.output_schema.jsonschema())
@@ -913,6 +912,25 @@ async def initialize_blocks() -> None:
},
)
failed_blocks: list[str] = []
for cls in get_blocks().values():
block = cls()
try:
await sync_block_to_db(block)
except Exception as e:
logger.warning(
f"Failed to sync block {block.name} to database: {e}. "
"Block is still available in memory.",
exc_info=True,
)
failed_blocks.append(block.name)
if failed_blocks:
logger.error(
f"Failed to sync {len(failed_blocks)} block(s) to database: "
f"{', '.join(failed_blocks)}. These blocks are still available in memory."
)
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
def get_block(block_id: str) -> AnyBlockSchema | None:

View File

@@ -0,0 +1,16 @@
"""Validation utilities."""
import re
_UUID_V4_PATTERN = re.compile(
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
re.IGNORECASE,
)
def is_uuid_v4(text: str) -> bool:
return bool(_UUID_V4_PATTERN.fullmatch(text.strip()))
def extract_uuids(text: str) -> list[str]:
return sorted({m.lower() for m in _UUID_V4_PATTERN.findall(text)})
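
Illustrative behaviour of the two helpers above (the UUID values are made up for the example, not taken from the repository):

```python
# Hypothetical values, for illustration only.
assert is_uuid_v4("3f2b8c9e-1a4d-4e6f-9b2a-7c8d9e0f1a2b") is True
assert is_uuid_v4("not-a-uuid") is False

# Matching is case-insensitive; results are lowercased, de-duplicated, sorted.
assert extract_uuids(
    "ran agent 3F2B8C9E-1A4D-4E6F-9B2A-7C8D9E0F1A2B twice"
) == ["3f2b8c9e-1a4d-4e6f-9b2a-7c8d9e0f1a2b"]
```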