mirror of https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-04 20:05:11 -05:00

Compare commits: otto/copil ... otto/secrt
1 commit: d6b76e672c
@@ -17,13 +17,6 @@ from .model import ChatSession, create_chat_session, get_chat_session, get_user_

config = ChatConfig()

SSE_RESPONSE_HEADERS = {
    "Cache-Control": "no-cache",
    "Connection": "keep-alive",
    "X-Accel-Buffering": "no",
    "x-vercel-ai-ui-message-stream": "v1",
}


logger = logging.getLogger(__name__)
@@ -39,48 +32,6 @@ async def _validate_and_get_session(
    return session


async def _create_stream_generator(
    session_id: str,
    message: str,
    user_id: str | None,
    session: ChatSession,
    is_user_message: bool = True,
    context: dict[str, str] | None = None,
) -> AsyncGenerator[str, None]:
    """Create SSE event generator for chat streaming."""
    chunk_count = 0
    first_chunk_type: str | None = None
    async for chunk in chat_service.stream_chat_completion(
        session_id,
        message,
        is_user_message=is_user_message,
        user_id=user_id,
        session=session,
        context=context,
    ):
        if chunk_count < 3:
            logger.info(
                "Chat stream chunk",
                extra={
                    "session_id": session_id,
                    "chunk_type": str(chunk.type),
                },
            )
        if not first_chunk_type:
            first_chunk_type = str(chunk.type)
        chunk_count += 1
        yield chunk.to_sse()
    logger.info(
        "Chat stream completed",
        extra={
            "session_id": session_id,
            "chunk_count": chunk_count,
            "first_chunk_type": first_chunk_type,
        },
    )
    yield "data: [DONE]\n\n"


router = APIRouter(
    tags=["chat"],
)
@@ -270,17 +221,49 @@ async def stream_chat_post(
    """
    session = await _validate_and_get_session(session_id, user_id)

    return StreamingResponse(
        _create_stream_generator(
            session_id=session_id,
            message=request.message,
            user_id=user_id,
            session=session,
    async def event_generator() -> AsyncGenerator[str, None]:
        chunk_count = 0
        first_chunk_type: str | None = None
        async for chunk in chat_service.stream_chat_completion(
            session_id,
            request.message,
            is_user_message=request.is_user_message,
            user_id=user_id,
            session=session,  # Pass pre-fetched session to avoid double-fetch
            context=request.context,
        ),
        ):
            if chunk_count < 3:
                logger.info(
                    "Chat stream chunk",
                    extra={
                        "session_id": session_id,
                        "chunk_type": str(chunk.type),
                    },
                )
            if not first_chunk_type:
                first_chunk_type = str(chunk.type)
            chunk_count += 1
            yield chunk.to_sse()
        logger.info(
            "Chat stream completed",
            extra={
                "session_id": session_id,
                "chunk_count": chunk_count,
                "first_chunk_type": first_chunk_type,
            },
        )
        # AI SDK protocol termination
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers=SSE_RESPONSE_HEADERS,
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
            "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
        },
    )
@@ -312,16 +295,48 @@ async def stream_chat_get(
    """
    session = await _validate_and_get_session(session_id, user_id)

    return StreamingResponse(
        _create_stream_generator(
            session_id=session_id,
            message=message,
            user_id=user_id,
            session=session,
    async def event_generator() -> AsyncGenerator[str, None]:
        chunk_count = 0
        first_chunk_type: str | None = None
        async for chunk in chat_service.stream_chat_completion(
            session_id,
            message,
            is_user_message=is_user_message,
        ),
            user_id=user_id,
            session=session,  # Pass pre-fetched session to avoid double-fetch
        ):
            if chunk_count < 3:
                logger.info(
                    "Chat stream chunk",
                    extra={
                        "session_id": session_id,
                        "chunk_type": str(chunk.type),
                    },
                )
            if not first_chunk_type:
                first_chunk_type = str(chunk.type)
            chunk_count += 1
            yield chunk.to_sse()
        logger.info(
            "Chat stream completed",
            extra={
                "session_id": session_id,
                "chunk_count": chunk_count,
                "first_chunk_type": first_chunk_type,
            },
        )
        # AI SDK protocol termination
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers=SSE_RESPONSE_HEADERS,
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
            "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
        },
    )
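Both route variants above stream Server-Sent Events and terminate with a `data: [DONE]` sentinel. As a rough, hedged illustration only (the endpoint path, port, and JSON body shape are assumptions, not taken from this diff), a client could consume such a stream like this:

```python
# Hypothetical client-side sketch: read the chat SSE stream until the
# "data: [DONE]" sentinel. Endpoint path and request body are assumed.
import asyncio
import httpx


async def read_chat_stream(session_id: str, message: str) -> None:
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST",
            f"http://localhost:8000/chat/{session_id}/stream",  # assumed path
            json={"message": message},
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip blank keep-alive lines
                payload = line[len("data: "):]
                if payload == "[DONE]":
                    break  # AI SDK protocol termination
                print(payload)  # one serialized chunk per event


if __name__ == "__main__":
    asyncio.run(read_chat_stream("some-session-id", "Hello"))
```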
@@ -0,0 +1,204 @@
"""
Save binary block outputs to workspace, return references instead of base64.

This module post-processes block execution outputs to detect and save binary
content (from code execution results) to the workspace, returning workspace://
references instead of raw base64 data. This reduces LLM output token usage
by ~97% for file generation tasks.

Detection is field-name based, targeting the standard e2b CodeExecutionResult
fields: png, jpeg, pdf, svg. Other image-producing blocks already use
store_media_file() and don't need this post-processing.
"""

import base64
import binascii
import hashlib
import logging
import uuid
from typing import Any

from backend.util.file import sanitize_filename
from backend.util.workspace import WorkspaceManager

logger = logging.getLogger(__name__)

# Field names that contain binary data (base64 encoded)
BINARY_FIELDS = {"png", "jpeg", "pdf"}

# Field names that contain large text data (not base64, save as-is)
TEXT_FIELDS = {"svg"}

# Combined set for quick lookup
SAVEABLE_FIELDS = BINARY_FIELDS | TEXT_FIELDS

# Only process content larger than this (string length, not decoded size)
SIZE_THRESHOLD = 1024  # 1KB


async def process_binary_outputs(
    outputs: dict[str, list[Any]],
    workspace_manager: WorkspaceManager,
    block_name: str,
) -> dict[str, list[Any]]:
    """
    Replace binary data in block outputs with workspace:// references.

    Scans outputs for known binary field names (png, jpeg, pdf, svg) and saves
    large content to the workspace. Returns processed outputs with base64 data
    replaced by workspace:// references.

    Deduplicates identical content within a single call using content hashing.

    Args:
        outputs: Block execution outputs (dict of output_name -> list of values)
        workspace_manager: WorkspaceManager instance with session scoping
        block_name: Name of the block (used in generated filenames)

    Returns:
        Processed outputs with binary data replaced by workspace references
    """
    cache: dict[str, str] = {}  # content_hash -> workspace_ref

    processed: dict[str, list[Any]] = {}
    for name, items in outputs.items():
        processed_items: list[Any] = []
        for item in items:
            processed_items.append(
                await _process_item(item, workspace_manager, block_name, cache)
            )
        processed[name] = processed_items
    return processed


async def _process_item(
    item: Any,
    wm: WorkspaceManager,
    block: str,
    cache: dict[str, str],
) -> Any:
    """Recursively process an item, handling dicts and lists."""
    if isinstance(item, dict):
        return await _process_dict(item, wm, block, cache)
    if isinstance(item, list):
        processed: list[Any] = []
        for i in item:
            processed.append(await _process_item(i, wm, block, cache))
        return processed
    return item


async def _process_dict(
    data: dict[str, Any],
    wm: WorkspaceManager,
    block: str,
    cache: dict[str, str],
) -> dict[str, Any]:
    """Process a dict, saving binary fields and recursing into nested structures."""
    result: dict[str, Any] = {}

    for key, value in data.items():
        if (
            key in SAVEABLE_FIELDS
            and isinstance(value, str)
            and len(value) > SIZE_THRESHOLD
        ):
            # Determine content bytes based on field type
            if key in BINARY_FIELDS:
                content = _decode_base64(value)
                if content is None:
                    # Decode failed, keep original value
                    result[key] = value
                    continue
            else:
                # TEXT_FIELDS: encode as UTF-8
                content = value.encode("utf-8")

            # Hash decoded content for deduplication
            content_hash = hashlib.sha256(content).hexdigest()

            if content_hash in cache:
                # Reuse existing workspace reference
                result[key] = cache[content_hash]
            elif ref := await _save_content(content, key, wm, block):
                # Save succeeded, cache and use reference
                cache[content_hash] = ref
                result[key] = ref
            else:
                # Save failed, keep original value
                result[key] = value

        elif isinstance(value, dict):
            result[key] = await _process_dict(value, wm, block, cache)
        elif isinstance(value, list):
            processed: list[Any] = []
            for i in value:
                processed.append(await _process_item(i, wm, block, cache))
            result[key] = processed
        else:
            result[key] = value

    return result


async def _save_content(
    content: bytes,
    field: str,
    wm: WorkspaceManager,
    block: str,
) -> str | None:
    """
    Save content to workspace, return workspace:// reference.

    Args:
        content: Decoded binary content to save
        field: Field name (used for extension)
        wm: WorkspaceManager instance
        block: Block name (used in filename)

    Returns:
        workspace://file-id reference, or None if save failed
    """
    try:
        # Map field name to file extension
        ext = {"jpeg": "jpg"}.get(field, field)

        # Sanitize block name for safe filename
        safe_block = sanitize_filename(block.lower())[:20]
        filename = f"{safe_block}_{field}_{uuid.uuid4().hex[:12]}.{ext}"

        file = await wm.write_file(content=content, filename=filename)
        return f"workspace://{file.id}"

    except Exception as e:
        logger.error(f"Failed to save {field} to workspace for block '{block}': {e}")
        return None


def _decode_base64(value: str) -> bytes | None:
    """
    Decode base64 string, handling both raw base64 and data URI formats.

    Args:
        value: Base64 string or data URI (data:<mime>;base64,<payload>)

    Returns:
        Decoded bytes, or None if decoding failed
    """
    try:
        # Handle data URI format
        if value.startswith("data:"):
            if "," in value:
                value = value.split(",", 1)[1]
            else:
                # Malformed data URI, no comma separator
                return None

        # Normalize padding (handle missing = chars)
        padded = value + "=" * (-len(value) % 4)

        # Strict validation to prevent corrupted data
        return base64.b64decode(padded, validate=True)

    except (binascii.Error, ValueError):
        return None
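To make the flow described in the module docstring concrete, here is a minimal, self-contained sketch of calling `process_binary_outputs` with a stubbed `WorkspaceManager` (the stub mirrors the test fixture further down; it is not the real class, and the block name is illustrative):

```python
# Minimal sketch: run process_binary_outputs against a stubbed workspace
# manager. The stub only provides the write_file() coroutine used above.
import asyncio
import base64
from unittest.mock import AsyncMock, MagicMock

from backend.api.features.chat.tools.binary_output_processor import (
    process_binary_outputs,
)


async def main() -> None:
    wm = MagicMock()
    saved = MagicMock()
    saved.id = "file-123"
    wm.write_file = AsyncMock(return_value=saved)

    # >1KB of base64 so the PNG field crosses SIZE_THRESHOLD
    png = base64.b64encode(b"\x89PNG" + b"\x00" * 2000).decode()
    outputs = {"result": [{"png": png, "text": "small values stay inline"}]}

    processed = await process_binary_outputs(outputs, wm, "CodeExecutionBlock")
    print(processed["result"][0]["png"])  # -> workspace://file-123


asyncio.run(main())
```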
@@ -1,29 +0,0 @@
"""Shared helpers for chat tools."""

from typing import Any


def get_inputs_from_schema(
    input_schema: dict[str, Any],
    exclude_fields: set[str] | None = None,
) -> list[dict[str, Any]]:
    """Extract input field info from JSON schema."""
    if not isinstance(input_schema, dict):
        return []

    exclude = exclude_fields or set()
    properties = input_schema.get("properties", {})
    required = set(input_schema.get("required", []))

    return [
        {
            "name": name,
            "title": schema.get("title", name),
            "type": schema.get("type", "string"),
            "description": schema.get("description", ""),
            "required": name in required,
            "default": schema.get("default"),
        }
        for name, schema in properties.items()
        if name not in exclude
    ]
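For orientation, this is the output shape the helper above produces for a small, made-up schema (schema and values are illustrative only):

```python
# Illustration only; assumes get_inputs_from_schema (defined above) is in scope.
schema = {
    "properties": {
        "query": {"title": "Query", "type": "string", "description": "Search text"},
        "limit": {"type": "integer", "default": 10},
        "credentials": {"type": "object"},
    },
    "required": ["query"],
}

print(get_inputs_from_schema(schema, exclude_fields={"credentials"}))
# [
#   {"name": "query", "title": "Query", "type": "string",
#    "description": "Search text", "required": True, "default": None},
#   {"name": "limit", "title": "limit", "type": "integer",
#    "description": "", "required": False, "default": 10},
# ]
```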
@@ -24,7 +24,6 @@ from backend.util.timezone_utils import (
)

from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import (
    AgentDetails,
    AgentDetailsResponse,
@@ -262,7 +261,7 @@ class RunAgentTool(BaseTool):
            ),
            requirements={
                "credentials": requirements_creds_list,
                "inputs": get_inputs_from_schema(graph.input_schema),
                "inputs": self._get_inputs_list(graph.input_schema),
                "execution_modes": self._get_execution_modes(graph),
            },
        ),
@@ -370,6 +369,22 @@ class RunAgentTool(BaseTool):
            session_id=session_id,
        )

    def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
        """Extract inputs list from schema."""
        inputs_list = []
        if isinstance(input_schema, dict) and "properties" in input_schema:
            for field_name, field_schema in input_schema["properties"].items():
                inputs_list.append(
                    {
                        "name": field_name,
                        "title": field_schema.get("title", field_name),
                        "type": field_schema.get("type", "string"),
                        "description": field_schema.get("description", ""),
                        "required": field_name in input_schema.get("required", []),
                    }
                )
        return inputs_list

    def _get_execution_modes(self, graph: GraphModel) -> list[str]:
        """Get available execution modes for the graph."""
        trigger_info = graph.trigger_setup_info
@@ -383,7 +398,7 @@ class RunAgentTool(BaseTool):
        suffix: str,
    ) -> str:
        """Build a message describing available inputs for an agent."""
        inputs_list = get_inputs_from_schema(graph.input_schema)
        inputs_list = self._get_inputs_list(graph.input_schema)
        required_names = [i["name"] for i in inputs_list if i["required"]]
        optional_names = [i["name"] for i in inputs_list if not i["required"]]
@@ -8,15 +8,18 @@ from typing import Any
from pydantic_core import PydanticUndefined

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.binary_output_processor import (
    process_binary_outputs,
)
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from backend.util.workspace import WorkspaceManager

from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import (
    BlockOutputResponse,
    ErrorResponse,
@@ -25,10 +28,7 @@ from .models import (
    ToolResponseBase,
    UserReadiness,
)
from .utils import (
    build_missing_credentials_from_field_info,
    match_credentials_to_requirements,
)
from .utils import build_missing_credentials_from_field_info

logger = logging.getLogger(__name__)

@@ -77,22 +77,41 @@ class RunBlockTool(BaseTool):
    def requires_auth(self) -> bool:
        return True

    def _resolve_discriminated_credentials(
    async def _check_block_credentials(
        self,
        user_id: str,
        block: Any,
        input_data: dict[str, Any],
    ) -> dict[str, CredentialsFieldInfo]:
        """Resolve credential requirements, applying discriminator logic where needed."""
        credentials_fields_info = block.input_schema.get_credentials_fields_info()
        if not credentials_fields_info:
            return {}
        input_data: dict[str, Any] | None = None,
    ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
        """
        Check if user has required credentials for a block.

        resolved: dict[str, CredentialsFieldInfo] = {}
        Args:
            user_id: User ID
            block: Block to check credentials for
            input_data: Input data for the block (used to determine provider via discriminator)

        Returns:
            tuple[matched_credentials, missing_credentials]
        """
        matched_credentials: dict[str, CredentialsMetaInput] = {}
        missing_credentials: list[CredentialsMetaInput] = []
        input_data = input_data or {}

        # Get credential field info from block's input schema
        credentials_fields_info = block.input_schema.get_credentials_fields_info()

        if not credentials_fields_info:
            return matched_credentials, missing_credentials

        # Get user's available credentials
        creds_manager = IntegrationCredentialsManager()
        available_creds = await creds_manager.store.get_all_creds(user_id)

        for field_name, field_info in credentials_fields_info.items():
            effective_field_info = field_info

            if field_info.discriminator and field_info.discriminator_mapping:
                # Get discriminator from input, falling back to schema default
                discriminator_value = input_data.get(field_info.discriminator)
                if discriminator_value is None:
                    field = block.input_schema.model_fields.get(
@@ -111,34 +130,37 @@ class RunBlockTool(BaseTool):
                        f"{discriminator_value} -> {effective_field_info.provider}"
                    )

            resolved[field_name] = effective_field_info
            matching_cred = next(
                (
                    cred
                    for cred in available_creds
                    if cred.provider in effective_field_info.provider
                    and cred.type in effective_field_info.supported_types
                ),
                None,
            )

        return resolved
            if matching_cred:
                matched_credentials[field_name] = CredentialsMetaInput(
                    id=matching_cred.id,
                    provider=matching_cred.provider,  # type: ignore
                    type=matching_cred.type,
                    title=matching_cred.title,
                )
            else:
                # Create a placeholder for the missing credential
                provider = next(iter(effective_field_info.provider), "unknown")
                cred_type = next(iter(effective_field_info.supported_types), "api_key")
                missing_credentials.append(
                    CredentialsMetaInput(
                        id=field_name,
                        provider=provider,  # type: ignore
                        type=cred_type,  # type: ignore
                        title=field_name.replace("_", " ").title(),
                    )
                )

    async def _check_block_credentials(
        self,
        user_id: str,
        block: Any,
        input_data: dict[str, Any] | None = None,
    ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
        """
        Check if user has required credentials for a block.

        Args:
            user_id: User ID
            block: Block to check credentials for
            input_data: Input data for the block (used to determine provider via discriminator)

        Returns:
            tuple[matched_credentials, missing_credentials]
        """
        input_data = input_data or {}
        requirements = self._resolve_discriminated_credentials(block, input_data)

        if not requirements:
            return {}, []

        return await match_credentials_to_requirements(user_id, requirements)
        return matched_credentials, missing_credentials

    async def _execute(
        self,
@@ -303,11 +325,20 @@ class RunBlockTool(BaseTool):
        ):
            outputs[output_name].append(output_data)

        # Save binary outputs to workspace to prevent context bloat
        # (code execution results with png/jpeg/pdf/svg fields)
        workspace_manager = WorkspaceManager(
            user_id, workspace.id, session.session_id
        )
        processed_outputs = await process_binary_outputs(
            dict(outputs), workspace_manager, block.name
        )

        return BlockOutputResponse(
            message=f"Block '{block.name}' executed successfully",
            block_id=block_id,
            block_name=block.name,
            outputs=dict(outputs),
            outputs=processed_outputs,
            success=True,
            session_id=session_id,
        )
@@ -329,6 +360,27 @@ class RunBlockTool(BaseTool):

    def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
        """Extract non-credential inputs from block schema."""
        inputs_list = []
        schema = block.input_schema.jsonschema()
        properties = schema.get("properties", {})
        required_fields = set(schema.get("required", []))

        # Get credential field names to exclude
        credentials_fields = set(block.input_schema.get_credentials_fields().keys())
        return get_inputs_from_schema(schema, exclude_fields=credentials_fields)

        for field_name, field_schema in properties.items():
            # Skip credential fields
            if field_name in credentials_fields:
                continue

            inputs_list.append(
                {
                    "name": field_name,
                    "title": field_schema.get("title", field_name),
                    "type": field_schema.get("type", "string"),
                    "description": field_schema.get("description", ""),
                    "required": field_name in required_fields,
                }
            )

        return inputs_list
@@ -0,0 +1,204 @@
"""Unit tests for binary_output_processor module."""

import base64
from unittest.mock import AsyncMock, MagicMock

import pytest

from backend.api.features.chat.tools.binary_output_processor import (
    _decode_base64,
    process_binary_outputs,
)


@pytest.fixture
def workspace_manager():
    """Create a mock WorkspaceManager."""
    mock = MagicMock()
    mock_file = MagicMock()
    mock_file.id = "file-123"
    mock.write_file = AsyncMock(return_value=mock_file)
    return mock


class TestDecodeBase64:
    """Tests for _decode_base64 function."""

    def test_raw_base64(self):
        """Decode raw base64 string."""
        encoded = base64.b64encode(b"test content").decode()
        result = _decode_base64(encoded)
        assert result == b"test content"

    def test_data_uri(self):
        """Decode base64 from data URI format."""
        content = b"test content"
        encoded = base64.b64encode(content).decode()
        data_uri = f"data:image/png;base64,{encoded}"
        result = _decode_base64(data_uri)
        assert result == content

    def test_invalid_base64(self):
        """Return None for invalid base64."""
        result = _decode_base64("not valid base64!!!")
        assert result is None

    def test_missing_padding(self):
        """Handle base64 with missing padding."""
        # base64.b64encode(b"test") = "dGVzdA=="
        # Remove padding
        result = _decode_base64("dGVzdA")
        assert result == b"test"

    def test_malformed_data_uri(self):
        """Return None for data URI without comma."""
        result = _decode_base64("data:image/png;base64")
        assert result is None


class TestProcessBinaryOutputs:
    """Tests for process_binary_outputs function."""

    @pytest.mark.asyncio
    async def test_saves_large_png(self, workspace_manager):
        """Large PNG content should be saved to workspace."""
        # Create content larger than SIZE_THRESHOLD (1KB)
        large_content = b"x" * 2000
        encoded = base64.b64encode(large_content).decode()
        outputs = {"result": [{"png": encoded}]}

        result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        assert result["result"][0]["png"] == "workspace://file-123"
        workspace_manager.write_file.assert_called_once()
        call_kwargs = workspace_manager.write_file.call_args.kwargs
        assert call_kwargs["content"] == large_content
        assert "testblock_png_" in call_kwargs["filename"]
        assert call_kwargs["filename"].endswith(".png")

    @pytest.mark.asyncio
    async def test_preserves_small_content(self, workspace_manager):
        """Small content should be preserved as-is."""
        small_content = base64.b64encode(b"tiny").decode()
        outputs = {"result": [{"png": small_content}]}

        result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        assert result["result"][0]["png"] == small_content
        workspace_manager.write_file.assert_not_called()

    @pytest.mark.asyncio
    async def test_deduplicates_identical_content(self, workspace_manager):
        """Identical content should only be saved once."""
        large_content = b"x" * 2000
        encoded = base64.b64encode(large_content).decode()
        outputs = {
            "main_result": [{"png": encoded}],
            "results": [{"png": encoded}],
        }

        result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        # Both should have the same workspace reference
        assert result["main_result"][0]["png"] == "workspace://file-123"
        assert result["results"][0]["png"] == "workspace://file-123"
        # But write_file should only be called once
        assert workspace_manager.write_file.call_count == 1

    @pytest.mark.asyncio
    async def test_graceful_degradation_on_save_failure(self, workspace_manager):
        """Original content should be preserved if save fails."""
        workspace_manager.write_file = AsyncMock(side_effect=Exception("Save failed"))
        large_content = b"x" * 2000
        encoded = base64.b64encode(large_content).decode()
        outputs = {"result": [{"png": encoded}]}

        result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        # Original content should be preserved
        assert result["result"][0]["png"] == encoded

    @pytest.mark.asyncio
    async def test_handles_nested_structures(self, workspace_manager):
        """Should traverse nested dicts and lists."""
        large_content = b"x" * 2000
        encoded = base64.b64encode(large_content).decode()
        outputs = {
            "result": [
                {
                    "nested": {
                        "deep": {
                            "png": encoded,
                        }
                    }
                }
            ]
        }

        result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        assert result["result"][0]["nested"]["deep"]["png"] == "workspace://file-123"

    @pytest.mark.asyncio
    async def test_handles_svg_as_text(self, workspace_manager):
        """SVG should be saved as UTF-8 text, not base64 decoded."""
        svg_content = "<svg>" + "x" * 2000 + "</svg>"
        outputs = {"result": [{"svg": svg_content}]}

        result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        assert result["result"][0]["svg"] == "workspace://file-123"
        call_kwargs = workspace_manager.write_file.call_args.kwargs
        # SVG should be UTF-8 encoded, not base64 decoded
        assert call_kwargs["content"] == svg_content.encode("utf-8")
        assert call_kwargs["filename"].endswith(".svg")

    @pytest.mark.asyncio
    async def test_ignores_unknown_fields(self, workspace_manager):
        """Fields not in SAVEABLE_FIELDS should be ignored."""
        large_content = "x" * 2000  # Large text in an unknown field
        outputs = {"result": [{"unknown_field": large_content}]}

        result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        assert result["result"][0]["unknown_field"] == large_content
        workspace_manager.write_file.assert_not_called()

    @pytest.mark.asyncio
    async def test_handles_jpeg_extension(self, workspace_manager):
        """JPEG files should use .jpg extension."""
        large_content = b"x" * 2000
        encoded = base64.b64encode(large_content).decode()
        outputs = {"result": [{"jpeg": encoded}]}

        await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        call_kwargs = workspace_manager.write_file.call_args.kwargs
        assert call_kwargs["filename"].endswith(".jpg")

    @pytest.mark.asyncio
    async def test_handles_data_uri_in_binary_field(self, workspace_manager):
        """Data URI format in binary fields should be properly decoded."""
        large_content = b"x" * 2000
        encoded = base64.b64encode(large_content).decode()
        data_uri = f"data:image/png;base64,{encoded}"
        outputs = {"result": [{"png": data_uri}]}

        result = await process_binary_outputs(outputs, workspace_manager, "TestBlock")

        assert result["result"][0]["png"] == "workspace://file-123"
        call_kwargs = workspace_manager.write_file.call_args.kwargs
        assert call_kwargs["content"] == large_content

    @pytest.mark.asyncio
    async def test_invalid_base64_preserves_original(self, workspace_manager):
        """Invalid base64 in a binary field should preserve the original value."""
        invalid_content = "not valid base64!!!" + "x" * 2000
        outputs = {"result": [{"png": invalid_content}]}

        processed = await process_binary_outputs(
            outputs, workspace_manager, "TestBlock"
        )

        assert processed["result"][0]["png"] == invalid_content
        workspace_manager.write_file.assert_not_called()
@@ -225,95 +225,6 @@ async def get_or_create_library_agent(
    return library_agents[0]


async def match_credentials_to_requirements(
    user_id: str,
    requirements: dict[str, CredentialsFieldInfo],
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
    """
    Match user's credentials against a dictionary of credential requirements.

    This is the core matching logic shared by both graph and block credential matching.
    """
    matched: dict[str, CredentialsMetaInput] = {}
    missing: list[CredentialsMetaInput] = []

    if not requirements:
        return matched, missing

    available_creds = await get_user_credentials(user_id)

    for field_name, field_info in requirements.items():
        matching_cred = find_matching_credential(available_creds, field_info)

        if matching_cred:
            try:
                matched[field_name] = create_credential_meta_from_match(matching_cred)
            except Exception as e:
                logger.error(
                    f"Failed to create CredentialsMetaInput for field '{field_name}': "
                    f"provider={matching_cred.provider}, type={matching_cred.type}, "
                    f"credential_id={matching_cred.id}",
                    exc_info=True,
                )
                provider = next(iter(field_info.provider), "unknown")
                cred_type = next(iter(field_info.supported_types), "api_key")
                missing.append(
                    CredentialsMetaInput(
                        id=field_name,
                        provider=provider,  # type: ignore
                        type=cred_type,  # type: ignore
                        title=f"{field_name} (validation failed: {e})",
                    )
                )
        else:
            provider = next(iter(field_info.provider), "unknown")
            cred_type = next(iter(field_info.supported_types), "api_key")
            missing.append(
                CredentialsMetaInput(
                    id=field_name,
                    provider=provider,  # type: ignore
                    type=cred_type,  # type: ignore
                    title=field_name.replace("_", " ").title(),
                )
            )

    return matched, missing


async def get_user_credentials(user_id: str) -> list[Credentials]:
    """Get all available credentials for a user."""
    creds_manager = IntegrationCredentialsManager()
    return await creds_manager.store.get_all_creds(user_id)


def find_matching_credential(
    available_creds: list[Credentials],
    field_info: CredentialsFieldInfo,
) -> Credentials | None:
    """Find a credential that matches the required provider, type, and scopes."""
    for cred in available_creds:
        if cred.provider not in field_info.provider:
            continue
        if cred.type not in field_info.supported_types:
            continue
        if not _credential_has_required_scopes(cred, field_info):
            continue
        return cred
    return None


def create_credential_meta_from_match(
    matching_cred: Credentials,
) -> CredentialsMetaInput:
    """Create a CredentialsMetaInput from a matched credential."""
    return CredentialsMetaInput(
        id=matching_cred.id,
        provider=matching_cred.provider,  # type: ignore
        type=matching_cred.type,
        title=matching_cred.title,
    )


async def match_user_credentials_to_graph(
    user_id: str,
    graph: GraphModel,
@@ -410,11 +321,21 @@ def _credential_has_required_scopes(
    credential: Credentials,
    requirements: CredentialsFieldInfo,
) -> bool:
    """Check if a credential has all the scopes required by the block."""
    """
    Check if a credential has all the scopes required by the block.

    For OAuth2 credentials, verifies that the credential's scopes are a superset
    of the required scopes. For other credential types, returns True (no scope check).
    """
    # Only OAuth2 credentials have scopes to check
    if credential.type != "oauth2":
        return True

    # If no scopes are required, any credential matches
    if not requirements.required_scopes:
        return True

    # Check that credential scopes are a superset of required scopes
    return set(credential.scopes).issuperset(requirements.required_scopes)
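A quick worked example of the superset rule used above (scope names are illustrative only):

```python
# Illustrative only: an OAuth2 credential matches when its scopes cover the requirement.
cred_scopes = {"repo", "read:user"}
print(cred_scopes.issuperset({"repo"}))              # True  -> credential matches
print(cred_scopes.issuperset({"repo", "workflow"}))  # False -> credential rejected
```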
@@ -873,13 +873,14 @@ def is_block_auth_configured(


async def initialize_blocks() -> None:
    # First, sync all provider costs to blocks
    # Imported here to avoid circular import
    from backend.sdk.cost_integration import sync_all_provider_costs
    from backend.util.retry import func_retry

    sync_all_provider_costs()

    @func_retry
    async def sync_block_to_db(block: Block) -> None:
    for cls in get_blocks().values():
        block = cls()
        existing_block = await AgentBlock.prisma().find_first(
            where={"OR": [{"id": block.id}, {"name": block.name}]}
        )
@@ -892,7 +893,7 @@ async def initialize_blocks() -> None:
                outputSchema=json.dumps(block.output_schema.jsonschema()),
            )
        )
        return
        continue

        input_schema = json.dumps(block.input_schema.jsonschema())
        output_schema = json.dumps(block.output_schema.jsonschema())
@@ -912,25 +913,6 @@ async def initialize_blocks() -> None:
            },
        )

    failed_blocks: list[str] = []
    for cls in get_blocks().values():
        block = cls()
        try:
            await sync_block_to_db(block)
        except Exception as e:
            logger.warning(
                f"Failed to sync block {block.name} to database: {e}. "
                "Block is still available in memory.",
                exc_info=True,
            )
            failed_blocks.append(block.name)

    if failed_blocks:
        logger.error(
            f"Failed to sync {len(failed_blocks)} block(s) to database: "
            f"{', '.join(failed_blocks)}. These blocks are still available in memory."
        )


# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
def get_block(block_id: str) -> AnyBlockSchema | None:
@@ -1,16 +0,0 @@
"""Validation utilities."""

import re

_UUID_V4_PATTERN = re.compile(
    r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
    re.IGNORECASE,
)


def is_uuid_v4(text: str) -> bool:
    return bool(_UUID_V4_PATTERN.fullmatch(text.strip()))


def extract_uuids(text: str) -> list[str]:
    return sorted({m.lower() for m in _UUID_V4_PATTERN.findall(text)})
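For reference, the two helpers above behave like this (the UUIDs below are made-up values that happen to satisfy the v4 pattern):

```python
# Illustration only; assumes is_uuid_v4 and extract_uuids (defined above) are in scope.
print(is_uuid_v4("3f0e8e9a-9a2b-4c3d-8e1f-2a3b4c5d6e7f"))  # True
print(is_uuid_v4("not-a-uuid"))                            # False

text = "runs 3F0E8E9A-9A2B-4C3D-8E1F-2A3B4C5D6E7F and 3f0e8e9a-9a2b-4c3d-8e1f-2a3b4c5d6e7f"
print(extract_uuids(text))  # ['3f0e8e9a-9a2b-4c3d-8e1f-2a3b4c5d6e7f']  (lowercased, deduplicated)
```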