diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py
index 3e731d86ac..9847dc6090 100644
--- a/autogpt_platform/backend/backend/api/features/chat/routes.py
+++ b/autogpt_platform/backend/backend/api/features/chat/routes.py
@@ -1,26 +1,29 @@
 """Chat API routes for chat session management and streaming via SSE."""
 
 import logging
-import uuid as uuid_module
 from collections.abc import AsyncGenerator
 from typing import Annotated
 
 from autogpt_libs import auth
-from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
+from fastapi import APIRouter, Depends, Query, Security
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 
 from backend.util.exceptions import NotFoundError
 
 from . import service as chat_service
-from . import stream_registry
-from .completion_handler import process_operation_failure, process_operation_success
 from .config import ChatConfig
 from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
-from .response_model import StreamFinish, StreamHeartbeat, StreamStart
 
 config = ChatConfig()
 
+SSE_RESPONSE_HEADERS = {
+    "Cache-Control": "no-cache",
+    "Connection": "keep-alive",
+    "X-Accel-Buffering": "no",  # Disable nginx buffering
+    "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
+}
+
 logger = logging.getLogger(__name__)
@@ -36,6 +39,48 @@ async def _validate_and_get_session(
     return session
 
 
+async def _create_stream_generator(
+    session_id: str,
+    message: str,
+    user_id: str | None,
+    session: ChatSession,
+    is_user_message: bool = True,
+    context: dict[str, str] | None = None,
+) -> AsyncGenerator[str, None]:
+    """Create SSE event generator for chat streaming."""
+    chunk_count = 0
+    first_chunk_type: str | None = None
+    async for chunk in chat_service.stream_chat_completion(
+        session_id,
+        message,
+        is_user_message=is_user_message,
+        user_id=user_id,
+        session=session,
+        context=context,
+    ):
+        if chunk_count < 3:
+            logger.info(
+                "Chat stream chunk",
+                extra={
+                    "session_id": session_id,
+                    "chunk_type": str(chunk.type),
+                },
+            )
+            if not first_chunk_type:
+                first_chunk_type = str(chunk.type)
+        chunk_count += 1
+        yield chunk.to_sse()
+    logger.info(
+        "Chat stream completed",
+        extra={
+            "session_id": session_id,
+            "chunk_count": chunk_count,
+            "first_chunk_type": first_chunk_type,
+        },
+    )
+    yield "data: [DONE]\n\n"
+
+
 router = APIRouter(
     tags=["chat"],
 )
@@ -59,15 +104,6 @@ class CreateSessionResponse(BaseModel):
     user_id: str | None
 
 
-class ActiveStreamInfo(BaseModel):
-    """Information about an active stream for reconnection."""
-
-    task_id: str
-    last_message_id: str  # Redis Stream message ID for resumption
-    operation_id: str  # Operation ID for completion tracking
-    tool_name: str  # Name of the tool being executed
-
-
 class SessionDetailResponse(BaseModel):
     """Response model providing complete details for a chat session, including messages."""
 
@@ -76,7 +112,6 @@
     updated_at: str
     user_id: str | None
     messages: list[dict]
-    active_stream: ActiveStreamInfo | None = None  # Present if stream is still active
 
 
 class SessionSummaryResponse(BaseModel):
@@ -95,14 +130,6 @@ class ListSessionsResponse(BaseModel):
     total: int
 
 
-class OperationCompleteRequest(BaseModel):
-    """Request model for external completion webhook."""
-
-    success: bool
-    result: dict | str | None = None
-    error: str | None = None
-
-
 # ========== Routes ==========
 
 
@@ -188,14 +215,13 @@ async def get_session(
     Retrieve the details of a specific chat session.
 
     Looks up a chat session by ID for the given user (if authenticated) and returns
     all session data including messages.
-    If there's an active stream for this session, returns the task_id for reconnection.
 
     Args:
         session_id: The unique identifier for the desired chat session.
        user_id: The optional authenticated user ID, or None for anonymous access.
 
     Returns:
-        SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
+        SessionDetailResponse: Details for the requested session; raises NotFoundError if the session does not exist.
     """
     session = await get_chat_session(session_id, user_id)
@@ -203,28 +229,11 @@ async def get_session(
         raise NotFoundError(f"Session {session_id} not found.")
 
     messages = [message.model_dump() for message in session.messages]
-
-    # Check if there's an active stream for this session
-    active_stream_info = None
-    active_task, last_message_id = await stream_registry.get_active_task_for_session(
-        session_id, user_id
+    logger.info(
+        f"Returning session {session_id}: "
+        f"message_count={len(messages)}, "
+        f"roles={[m.get('role') for m in messages]}"
     )
-    if active_task:
-        # Filter out the in-progress assistant message from the session response.
-        # The client will receive the complete assistant response through the SSE
-        # stream replay instead, preventing duplicate content.
-        if messages and messages[-1].get("role") == "assistant":
-            messages = messages[:-1]
-
-        # Use "0-0" as last_message_id to replay the stream from the beginning.
-        # Since we filtered out the cached assistant message, the client needs
-        # the full stream to reconstruct the response.
-        active_stream_info = ActiveStreamInfo(
-            task_id=active_task.task_id,
-            last_message_id="0-0",
-            operation_id=active_task.operation_id,
-            tool_name=active_task.tool_name,
-        )
 
     return SessionDetailResponse(
         id=session.session_id,
@@ -232,7 +241,6 @@ async def get_session(
         updated_at=session.updated_at.isoformat(),
         user_id=session.user_id or None,
         messages=messages,
-        active_stream=active_stream_info,
     )
 
 
@@ -252,122 +260,27 @@ async def stream_chat_post(
     - Tool call UI elements (if invoked)
     - Tool execution results
 
-    The AI generation runs in a background task that continues even if the client disconnects.
-    All chunks are written to Redis for reconnection support. If the client disconnects,
-    they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
-
     Args:
         session_id: The chat session identifier to associate with the streamed messages.
         request: Request body containing message, is_user_message, and optional context.
         user_id: Optional authenticated user ID.
 
     Returns:
-        StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
-        containing the task_id for reconnection.
+        StreamingResponse: SSE-formatted response chunks.
""" - import asyncio - session = await _validate_and_get_session(session_id, user_id) - # Create a task in the stream registry for reconnection support - task_id = str(uuid_module.uuid4()) - operation_id = str(uuid_module.uuid4()) - await stream_registry.create_task( - task_id=task_id, - session_id=session_id, - user_id=user_id, - tool_call_id="chat_stream", # Not a tool call, but needed for the model - tool_name="chat", - operation_id=operation_id, - ) - - # Background task that runs the AI generation independently of SSE connection - async def run_ai_generation(): - try: - # Emit a start event with task_id for reconnection - start_chunk = StreamStart(messageId=task_id, taskId=task_id) - await stream_registry.publish_chunk(task_id, start_chunk) - - async for chunk in chat_service.stream_chat_completion( - session_id, - request.message, - is_user_message=request.is_user_message, - user_id=user_id, - session=session, # Pass pre-fetched session to avoid double-fetch - context=request.context, - ): - # Write to Redis (subscribers will receive via XREAD) - await stream_registry.publish_chunk(task_id, chunk) - - # Mark task as completed - await stream_registry.mark_task_completed(task_id, "completed") - except Exception as e: - logger.error( - f"Error in background AI generation for session {session_id}: {e}" - ) - await stream_registry.mark_task_completed(task_id, "failed") - - # Start the AI generation in a background task - bg_task = asyncio.create_task(run_ai_generation()) - await stream_registry.set_task_asyncio_task(task_id, bg_task) - - # SSE endpoint that subscribes to the task's stream - async def event_generator() -> AsyncGenerator[str, None]: - subscriber_queue = None - try: - # Subscribe to the task stream (this replays existing messages + live updates) - subscriber_queue = await stream_registry.subscribe_to_task( - task_id=task_id, - user_id=user_id, - last_message_id="0-0", # Get all messages from the beginning - ) - - if subscriber_queue is None: - yield StreamFinish().to_sse() - yield "data: [DONE]\n\n" - return - - # Read from the subscriber queue and yield to SSE - while True: - try: - chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0) - yield chunk.to_sse() - - # Check for finish signal - if isinstance(chunk, StreamFinish): - break - except asyncio.TimeoutError: - # Send heartbeat to keep connection alive - yield StreamHeartbeat().to_sse() - - except GeneratorExit: - pass # Client disconnected - background task continues - except Exception as e: - logger.error(f"Error in SSE stream for task {task_id}: {e}") - finally: - # Unsubscribe when client disconnects or stream ends to prevent resource leak - if subscriber_queue is not None: - try: - await stream_registry.unsubscribe_from_task( - task_id, subscriber_queue - ) - except Exception as unsub_err: - logger.error( - f"Error unsubscribing from task {task_id}: {unsub_err}", - exc_info=True, - ) - # AI SDK protocol termination - always yield even if unsubscribe fails - yield "data: [DONE]\n\n" - return StreamingResponse( - event_generator(), + _create_stream_generator( + session_id=session_id, + message=request.message, + user_id=user_id, + session=session, + is_user_message=request.is_user_message, + context=request.context, + ), media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", # Disable nginx buffering - "x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header - }, + headers=SSE_RESPONSE_HEADERS, ) @@ -399,48 +312,16 @@ 
async def stream_chat_get( """ session = await _validate_and_get_session(session_id, user_id) - async def event_generator() -> AsyncGenerator[str, None]: - chunk_count = 0 - first_chunk_type: str | None = None - async for chunk in chat_service.stream_chat_completion( - session_id, - message, - is_user_message=is_user_message, - user_id=user_id, - session=session, # Pass pre-fetched session to avoid double-fetch - ): - if chunk_count < 3: - logger.info( - "Chat stream chunk", - extra={ - "session_id": session_id, - "chunk_type": str(chunk.type), - }, - ) - if not first_chunk_type: - first_chunk_type = str(chunk.type) - chunk_count += 1 - yield chunk.to_sse() - logger.info( - "Chat stream completed", - extra={ - "session_id": session_id, - "chunk_count": chunk_count, - "first_chunk_type": first_chunk_type, - }, - ) - # AI SDK protocol termination - yield "data: [DONE]\n\n" - return StreamingResponse( - event_generator(), + _create_stream_generator( + session_id=session_id, + message=message, + user_id=user_id, + session=session, + is_user_message=is_user_message, + ), media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", # Disable nginx buffering - "x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header - }, + headers=SSE_RESPONSE_HEADERS, ) @@ -470,251 +351,6 @@ async def session_assign_user( return {"status": "ok"} -# ========== Task Streaming (SSE Reconnection) ========== - - -@router.get( - "/tasks/{task_id}/stream", -) -async def stream_task( - task_id: str, - user_id: str | None = Depends(auth.get_user_id), - last_message_id: str = Query( - default="0-0", - description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.", - ), -): - """ - Reconnect to a long-running task's SSE stream. - - When a long-running operation (like agent generation) starts, the client - receives a task_id. If the connection drops, the client can reconnect - using this endpoint to resume receiving updates. - - Args: - task_id: The task ID from the operation_started response. - user_id: Authenticated user ID for ownership validation. - last_message_id: Last Redis Stream message ID received ("0-0" for full replay). - - Returns: - StreamingResponse: SSE-formatted response chunks starting after last_message_id. - - Raises: - HTTPException: 404 if task not found, 410 if task expired, 403 if access denied. - """ - # Check task existence and expiry before subscribing - task, error_code = await stream_registry.get_task_with_expiry_info(task_id) - - if error_code == "TASK_EXPIRED": - raise HTTPException( - status_code=410, - detail={ - "code": "TASK_EXPIRED", - "message": "This operation has expired. 
Please try again.", - }, - ) - - if error_code == "TASK_NOT_FOUND": - raise HTTPException( - status_code=404, - detail={ - "code": "TASK_NOT_FOUND", - "message": f"Task {task_id} not found.", - }, - ) - - # Validate ownership if task has an owner - if task and task.user_id and user_id != task.user_id: - raise HTTPException( - status_code=403, - detail={ - "code": "ACCESS_DENIED", - "message": "You do not have access to this task.", - }, - ) - - # Get subscriber queue from stream registry - subscriber_queue = await stream_registry.subscribe_to_task( - task_id=task_id, - user_id=user_id, - last_message_id=last_message_id, - ) - - if subscriber_queue is None: - raise HTTPException( - status_code=404, - detail={ - "code": "TASK_NOT_FOUND", - "message": f"Task {task_id} not found or access denied.", - }, - ) - - async def event_generator() -> AsyncGenerator[str, None]: - import asyncio - - heartbeat_interval = 15.0 # Send heartbeat every 15 seconds - try: - while True: - try: - # Wait for next chunk with timeout for heartbeats - chunk = await asyncio.wait_for( - subscriber_queue.get(), timeout=heartbeat_interval - ) - yield chunk.to_sse() - - # Check for finish signal - if isinstance(chunk, StreamFinish): - break - except asyncio.TimeoutError: - # Send heartbeat to keep connection alive - yield StreamHeartbeat().to_sse() - except Exception as e: - logger.error(f"Error in task stream {task_id}: {e}", exc_info=True) - finally: - # Unsubscribe when client disconnects or stream ends - try: - await stream_registry.unsubscribe_from_task(task_id, subscriber_queue) - except Exception as unsub_err: - logger.error( - f"Error unsubscribing from task {task_id}: {unsub_err}", - exc_info=True, - ) - # AI SDK protocol termination - always yield even if unsubscribe fails - yield "data: [DONE]\n\n" - - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - "x-vercel-ai-ui-message-stream": "v1", - }, - ) - - -@router.get( - "/tasks/{task_id}", -) -async def get_task_status( - task_id: str, - user_id: str | None = Depends(auth.get_user_id), -) -> dict: - """ - Get the status of a long-running task. - - Args: - task_id: The task ID to check. - user_id: Authenticated user ID for ownership validation. - - Returns: - dict: Task status including task_id, status, tool_name, and operation_id. - - Raises: - NotFoundError: If task_id is not found or user doesn't have access. - """ - task = await stream_registry.get_task(task_id) - - if task is None: - raise NotFoundError(f"Task {task_id} not found.") - - # Validate ownership - if task has an owner, requester must match - if task.user_id and user_id != task.user_id: - raise NotFoundError(f"Task {task_id} not found.") - - return { - "task_id": task.task_id, - "session_id": task.session_id, - "status": task.status, - "tool_name": task.tool_name, - "operation_id": task.operation_id, - "created_at": task.created_at.isoformat(), - } - - -# ========== External Completion Webhook ========== - - -@router.post( - "/operations/{operation_id}/complete", - status_code=200, -) -async def complete_operation( - operation_id: str, - request: OperationCompleteRequest, - x_api_key: str | None = Header(default=None), -) -> dict: - """ - External completion webhook for long-running operations. - - Called by Agent Generator (or other services) when an operation completes. - This triggers the stream registry to publish completion and continue LLM generation. 
- - Args: - operation_id: The operation ID to complete. - request: Completion payload with success status and result/error. - x_api_key: Internal API key for authentication. - - Returns: - dict: Status of the completion. - - Raises: - HTTPException: If API key is invalid or operation not found. - """ - # Validate internal API key - reject if not configured or invalid - if not config.internal_api_key: - logger.error( - "Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured" - ) - raise HTTPException( - status_code=503, - detail="Webhook not available: internal API key not configured", - ) - if x_api_key != config.internal_api_key: - raise HTTPException(status_code=401, detail="Invalid API key") - - # Find task by operation_id - task = await stream_registry.find_task_by_operation_id(operation_id) - if task is None: - raise HTTPException( - status_code=404, - detail=f"Operation {operation_id} not found", - ) - - logger.info( - f"Received completion webhook for operation {operation_id} " - f"(task_id={task.task_id}, success={request.success})" - ) - - if request.success: - await process_operation_success(task, request.result) - else: - await process_operation_failure(task, request.error) - - return {"status": "ok", "task_id": task.task_id} - - -# ========== Configuration ========== - - -@router.get("/config/ttl", status_code=200) -async def get_ttl_config() -> dict: - """ - Get the stream TTL configuration. - - Returns the Time-To-Live settings for chat streams, which determines - how long clients can reconnect to an active stream. - - Returns: - dict: TTL configuration with seconds and milliseconds values. - """ - return { - "stream_ttl_seconds": config.stream_ttl, - "stream_ttl_ms": config.stream_ttl * 1000, - } - - # ========== Health Check ========== diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/helpers.py b/autogpt_platform/backend/backend/api/features/chat/tools/helpers.py new file mode 100644 index 0000000000..c2d2d16769 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/helpers.py @@ -0,0 +1,77 @@ +"""Shared helpers for chat tools.""" + +from typing import Any + +from .models import ErrorResponse + + +def error_response( + message: str, session_id: str | None, **kwargs: Any +) -> ErrorResponse: + """Create standardized error response. + + Args: + message: Error message to display + session_id: Current session ID + **kwargs: Additional fields to pass to ErrorResponse + + Returns: + ErrorResponse with the given message and session_id + """ + return ErrorResponse(message=message, session_id=session_id, **kwargs) + + +def get_inputs_from_schema( + input_schema: dict[str, Any], + exclude_fields: set[str] | None = None, +) -> list[dict[str, Any]]: + """Extract input field info from JSON schema. 
+ + Args: + input_schema: JSON schema dict with 'properties' and 'required' + exclude_fields: Set of field names to exclude (e.g., credential fields) + + Returns: + List of dicts with field info (name, title, type, description, required, default) + """ + exclude = exclude_fields or set() + properties = input_schema.get("properties", {}) + required = set(input_schema.get("required", [])) + + return [ + { + "name": name, + "title": schema.get("title", name), + "type": schema.get("type", "string"), + "description": schema.get("description", ""), + "required": name in required, + "default": schema.get("default"), + } + for name, schema in properties.items() + if name not in exclude + ] + + +def format_inputs_as_markdown(inputs: list[dict[str, Any]]) -> str: + """Format input fields as a readable markdown list. + + Args: + inputs: List of input dicts from get_inputs_from_schema + + Returns: + Markdown-formatted string listing the inputs + """ + if not inputs: + return "No inputs required." + + lines = [] + for inp in inputs: + required_marker = " (required)" if inp.get("required") else "" + default = inp.get("default") + default_info = f" [default: {default}]" if default is not None else "" + description = inp.get("description", "") + desc_info = f" - {description}" if description else "" + + lines.append(f"- **{inp['name']}**{required_marker}{default_info}{desc_info}") + + return "\n".join(lines) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py index 73d4cf81f2..bc9aed4a35 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py @@ -24,6 +24,7 @@ from backend.util.timezone_utils import ( ) from .base import BaseTool +from .helpers import get_inputs_from_schema from .models import ( AgentDetails, AgentDetailsResponse, @@ -371,19 +372,7 @@ class RunAgentTool(BaseTool): def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]: """Extract inputs list from schema.""" - inputs_list = [] - if isinstance(input_schema, dict) and "properties" in input_schema: - for field_name, field_schema in input_schema["properties"].items(): - inputs_list.append( - { - "name": field_name, - "title": field_schema.get("title", field_name), - "type": field_schema.get("type", "string"), - "description": field_schema.get("description", ""), - "required": field_name in input_schema.get("required", []), - } - ) - return inputs_list + return get_inputs_from_schema(input_schema) def _get_execution_modes(self, graph: GraphModel) -> list[str]: """Get available execution modes for the graph.""" diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py index 51bb2c0575..c4e6e5ffc4 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py @@ -10,12 +10,13 @@ from pydantic_core import PydanticUndefined from backend.api.features.chat.model import ChatSession from backend.data.block import get_block from backend.data.execution import ExecutionContext -from backend.data.model import CredentialsMetaInput +from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput from backend.data.workspace import get_or_create_workspace from backend.integrations.creds_manager import IntegrationCredentialsManager 
from backend.util.exceptions import BlockError from .base import BaseTool +from .helpers import get_inputs_from_schema from .models import ( BlockOutputResponse, ErrorResponse, @@ -24,7 +25,10 @@ from .models import ( ToolResponseBase, UserReadiness, ) -from .utils import build_missing_credentials_from_field_info +from .utils import ( + build_missing_credentials_from_field_info, + match_credentials_to_requirements, +) logger = logging.getLogger(__name__) @@ -73,41 +77,22 @@ class RunBlockTool(BaseTool): def requires_auth(self) -> bool: return True - async def _check_block_credentials( + def _resolve_discriminated_credentials( self, - user_id: str, block: Any, - input_data: dict[str, Any] | None = None, - ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]: - """ - Check if user has required credentials for a block. - - Args: - user_id: User ID - block: Block to check credentials for - input_data: Input data for the block (used to determine provider via discriminator) - - Returns: - tuple[matched_credentials, missing_credentials] - """ - matched_credentials: dict[str, CredentialsMetaInput] = {} - missing_credentials: list[CredentialsMetaInput] = [] - input_data = input_data or {} - - # Get credential field info from block's input schema + input_data: dict[str, Any], + ) -> dict[str, CredentialsFieldInfo]: + """Resolve credential requirements, applying discriminator logic where needed.""" credentials_fields_info = block.input_schema.get_credentials_fields_info() - if not credentials_fields_info: - return matched_credentials, missing_credentials + return {} - # Get user's available credentials - creds_manager = IntegrationCredentialsManager() - available_creds = await creds_manager.store.get_all_creds(user_id) + resolved: dict[str, CredentialsFieldInfo] = {} for field_name, field_info in credentials_fields_info.items(): effective_field_info = field_info + if field_info.discriminator and field_info.discriminator_mapping: - # Get discriminator from input, falling back to schema default discriminator_value = input_data.get(field_info.discriminator) if discriminator_value is None: field = block.input_schema.model_fields.get( @@ -126,37 +111,34 @@ class RunBlockTool(BaseTool): f"{discriminator_value} -> {effective_field_info.provider}" ) - matching_cred = next( - ( - cred - for cred in available_creds - if cred.provider in effective_field_info.provider - and cred.type in effective_field_info.supported_types - ), - None, - ) + resolved[field_name] = effective_field_info - if matching_cred: - matched_credentials[field_name] = CredentialsMetaInput( - id=matching_cred.id, - provider=matching_cred.provider, # type: ignore - type=matching_cred.type, - title=matching_cred.title, - ) - else: - # Create a placeholder for the missing credential - provider = next(iter(effective_field_info.provider), "unknown") - cred_type = next(iter(effective_field_info.supported_types), "api_key") - missing_credentials.append( - CredentialsMetaInput( - id=field_name, - provider=provider, # type: ignore - type=cred_type, # type: ignore - title=field_name.replace("_", " ").title(), - ) - ) + return resolved - return matched_credentials, missing_credentials + async def _check_block_credentials( + self, + user_id: str, + block: Any, + input_data: dict[str, Any] | None = None, + ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]: + """ + Check if user has required credentials for a block. 
+ + Args: + user_id: User ID + block: Block to check credentials for + input_data: Input data for the block (used to determine provider via discriminator) + + Returns: + tuple[matched_credentials, missing_credentials] + """ + input_data = input_data or {} + requirements = self._resolve_discriminated_credentials(block, input_data) + + if not requirements: + return {}, [] + + return await match_credentials_to_requirements(user_id, requirements) async def _execute( self, @@ -347,27 +329,6 @@ class RunBlockTool(BaseTool): def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]: """Extract non-credential inputs from block schema.""" - inputs_list = [] schema = block.input_schema.jsonschema() - properties = schema.get("properties", {}) - required_fields = set(schema.get("required", [])) - - # Get credential field names to exclude credentials_fields = set(block.input_schema.get_credentials_fields().keys()) - - for field_name, field_schema in properties.items(): - # Skip credential fields - if field_name in credentials_fields: - continue - - inputs_list.append( - { - "name": field_name, - "title": field_schema.get("title", field_name), - "type": field_schema.get("type", "string"), - "description": field_schema.get("description", ""), - "required": field_name in required_fields, - } - ) - - return inputs_list + return get_inputs_from_schema(schema, exclude_fields=credentials_fields) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py index cda0914809..1f28f63913 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py @@ -223,6 +223,93 @@ async def get_or_create_library_agent( return library_agents[0] +async def get_user_credentials(user_id: str) -> list: + """Get all available credentials for a user.""" + creds_manager = IntegrationCredentialsManager() + return await creds_manager.store.get_all_creds(user_id) + + +def find_matching_credential( + available_creds: list, + field_info: CredentialsFieldInfo, +): + """Find a credential that matches the required provider, type, and scopes.""" + for cred in available_creds: + if cred.provider not in field_info.provider: + continue + if cred.type not in field_info.supported_types: + continue + if not _credential_has_required_scopes(cred, field_info): + continue + return cred + return None + + +def create_credential_meta_from_match(matching_cred) -> CredentialsMetaInput: + """Create a CredentialsMetaInput from a matched credential.""" + return CredentialsMetaInput( + id=matching_cred.id, + provider=matching_cred.provider, # type: ignore + type=matching_cred.type, + title=matching_cred.title, + ) + + +async def match_credentials_to_requirements( + user_id: str, + requirements: dict[str, CredentialsFieldInfo], +) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]: + """ + Match user's credentials against a dictionary of credential requirements. + + This is the core matching logic shared by both graph and block credential matching. 
+ """ + matched: dict[str, CredentialsMetaInput] = {} + missing: list[CredentialsMetaInput] = [] + + if not requirements: + return matched, missing + + available_creds = await get_user_credentials(user_id) + + for field_name, field_info in requirements.items(): + matching_cred = find_matching_credential(available_creds, field_info) + + if matching_cred: + try: + matched[field_name] = create_credential_meta_from_match(matching_cred) + except Exception as e: + logger.error( + f"Failed to create CredentialsMetaInput for field '{field_name}': " + f"provider={matching_cred.provider}, type={matching_cred.type}, " + f"credential_id={matching_cred.id}", + exc_info=True, + ) + provider = next(iter(field_info.provider), "unknown") + cred_type = next(iter(field_info.supported_types), "api_key") + missing.append( + CredentialsMetaInput( + id=field_name, + provider=provider, # type: ignore + type=cred_type, # type: ignore + title=f"{field_name} (validation failed: {e})", + ) + ) + else: + provider = next(iter(field_info.provider), "unknown") + cred_type = next(iter(field_info.supported_types), "api_key") + missing.append( + CredentialsMetaInput( + id=field_name, + provider=provider, # type: ignore + type=cred_type, # type: ignore + title=field_name.replace("_", " ").title(), + ) + ) + + return matched, missing + + async def match_user_credentials_to_graph( user_id: str, graph: GraphModel, diff --git a/autogpt_platform/backend/backend/util/validation.py b/autogpt_platform/backend/backend/util/validation.py new file mode 100644 index 0000000000..3c22bc3c4c --- /dev/null +++ b/autogpt_platform/backend/backend/util/validation.py @@ -0,0 +1,32 @@ +"""Validation utilities.""" + +import re + +_UUID_V4_PATTERN = re.compile( + r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}", + re.IGNORECASE, +) + + +def is_uuid_v4(text: str) -> bool: + """Check if text is a valid UUID v4. + + Args: + text: String to validate + + Returns: + True if the text is a valid UUID v4, False otherwise + """ + return bool(_UUID_V4_PATTERN.fullmatch(text.strip())) + + +def extract_uuids(text: str) -> list[str]: + """Extract all UUID v4 strings from text. + + Args: + text: String to search for UUIDs + + Returns: + List of unique UUIDs found (lowercase) + """ + return list({m.lower() for m in _UUID_V4_PATTERN.findall(text)})