diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index cab51543b1..b1b865d1dc 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -17,6 +17,14 @@ from .model import ChatSession, create_chat_session, get_chat_session, get_user_ config = ChatConfig() +# SSE response headers for streaming +SSE_RESPONSE_HEADERS = { + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", # Disable nginx buffering + "x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header +} + logger = logging.getLogger(__name__) @@ -32,6 +40,61 @@ async def _validate_and_get_session( return session +async def _create_stream_generator( + session_id: str, + message: str, + user_id: str | None, + session: ChatSession, + is_user_message: bool = True, + context: dict[str, str] | None = None, +) -> AsyncGenerator[str, None]: + """Create SSE event generator for chat streaming. + + Args: + session_id: Chat session ID + message: User message to process + user_id: Optional authenticated user ID + session: Pre-fetched chat session + is_user_message: Whether the message is from a user + context: Optional context dict with url and content + + Yields: + SSE-formatted chunks from the chat completion stream + """ + chunk_count = 0 + first_chunk_type: str | None = None + async for chunk in chat_service.stream_chat_completion( + session_id, + message, + is_user_message=is_user_message, + user_id=user_id, + session=session, + context=context, + ): + if chunk_count < 3: + logger.info( + "Chat stream chunk", + extra={ + "session_id": session_id, + "chunk_type": str(chunk.type), + }, + ) + if not first_chunk_type: + first_chunk_type = str(chunk.type) + chunk_count += 1 + yield chunk.to_sse() + logger.info( + "Chat stream completed", + extra={ + "session_id": session_id, + "chunk_count": chunk_count, + "first_chunk_type": first_chunk_type, + }, + ) + # AI SDK protocol termination + yield "data: [DONE]\n\n" + + router = APIRouter( tags=["chat"], ) @@ -221,49 +284,17 @@ async def stream_chat_post( """ session = await _validate_and_get_session(session_id, user_id) - async def event_generator() -> AsyncGenerator[str, None]: - chunk_count = 0 - first_chunk_type: str | None = None - async for chunk in chat_service.stream_chat_completion( - session_id, - request.message, - is_user_message=request.is_user_message, - user_id=user_id, - session=session, # Pass pre-fetched session to avoid double-fetch - context=request.context, - ): - if chunk_count < 3: - logger.info( - "Chat stream chunk", - extra={ - "session_id": session_id, - "chunk_type": str(chunk.type), - }, - ) - if not first_chunk_type: - first_chunk_type = str(chunk.type) - chunk_count += 1 - yield chunk.to_sse() - logger.info( - "Chat stream completed", - extra={ - "session_id": session_id, - "chunk_count": chunk_count, - "first_chunk_type": first_chunk_type, - }, - ) - # AI SDK protocol termination - yield "data: [DONE]\n\n" - return StreamingResponse( - event_generator(), + _create_stream_generator( + session_id=session_id, + message=request.message, + user_id=user_id, + session=session, + is_user_message=request.is_user_message, + context=request.context, + ), media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", # Disable nginx buffering - "x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header - }, + headers=SSE_RESPONSE_HEADERS, ) @@ -295,48 +326,16 @@ async def stream_chat_get( """ session = await _validate_and_get_session(session_id, user_id) - async def event_generator() -> AsyncGenerator[str, None]: - chunk_count = 0 - first_chunk_type: str | None = None - async for chunk in chat_service.stream_chat_completion( - session_id, - message, - is_user_message=is_user_message, - user_id=user_id, - session=session, # Pass pre-fetched session to avoid double-fetch - ): - if chunk_count < 3: - logger.info( - "Chat stream chunk", - extra={ - "session_id": session_id, - "chunk_type": str(chunk.type), - }, - ) - if not first_chunk_type: - first_chunk_type = str(chunk.type) - chunk_count += 1 - yield chunk.to_sse() - logger.info( - "Chat stream completed", - extra={ - "session_id": session_id, - "chunk_count": chunk_count, - "first_chunk_type": first_chunk_type, - }, - ) - # AI SDK protocol termination - yield "data: [DONE]\n\n" - return StreamingResponse( - event_generator(), + _create_stream_generator( + session_id=session_id, + message=message, + user_id=user_id, + session=session, + is_user_message=is_user_message, + ), media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", # Disable nginx buffering - "x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header - }, + headers=SSE_RESPONSE_HEADERS, ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/helpers.py b/autogpt_platform/backend/backend/api/features/chat/tools/helpers.py new file mode 100644 index 0000000000..c2d2d16769 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/helpers.py @@ -0,0 +1,77 @@ +"""Shared helpers for chat tools.""" + +from typing import Any + +from .models import ErrorResponse + + +def error_response( + message: str, session_id: str | None, **kwargs: Any +) -> ErrorResponse: + """Create standardized error response. + + Args: + message: Error message to display + session_id: Current session ID + **kwargs: Additional fields to pass to ErrorResponse + + Returns: + ErrorResponse with the given message and session_id + """ + return ErrorResponse(message=message, session_id=session_id, **kwargs) + + +def get_inputs_from_schema( + input_schema: dict[str, Any], + exclude_fields: set[str] | None = None, +) -> list[dict[str, Any]]: + """Extract input field info from JSON schema. + + Args: + input_schema: JSON schema dict with 'properties' and 'required' + exclude_fields: Set of field names to exclude (e.g., credential fields) + + Returns: + List of dicts with field info (name, title, type, description, required, default) + """ + exclude = exclude_fields or set() + properties = input_schema.get("properties", {}) + required = set(input_schema.get("required", [])) + + return [ + { + "name": name, + "title": schema.get("title", name), + "type": schema.get("type", "string"), + "description": schema.get("description", ""), + "required": name in required, + "default": schema.get("default"), + } + for name, schema in properties.items() + if name not in exclude + ] + + +def format_inputs_as_markdown(inputs: list[dict[str, Any]]) -> str: + """Format input fields as a readable markdown list. + + Args: + inputs: List of input dicts from get_inputs_from_schema + + Returns: + Markdown-formatted string listing the inputs + """ + if not inputs: + return "No inputs required." + + lines = [] + for inp in inputs: + required_marker = " (required)" if inp.get("required") else "" + default = inp.get("default") + default_info = f" [default: {default}]" if default is not None else "" + description = inp.get("description", "") + desc_info = f" - {description}" if description else "" + + lines.append(f"- **{inp['name']}**{required_marker}{default_info}{desc_info}") + + return "\n".join(lines) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py index a7fa65348a..c0c7dcd49b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py @@ -24,6 +24,7 @@ from backend.util.timezone_utils import ( ) from .base import BaseTool +from .helpers import get_inputs_from_schema from .models import ( AgentDetails, AgentDetailsResponse, @@ -354,19 +355,7 @@ class RunAgentTool(BaseTool): def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]: """Extract inputs list from schema.""" - inputs_list = [] - if isinstance(input_schema, dict) and "properties" in input_schema: - for field_name, field_schema in input_schema["properties"].items(): - inputs_list.append( - { - "name": field_name, - "title": field_schema.get("title", field_name), - "type": field_schema.get("type", "string"), - "description": field_schema.get("description", ""), - "required": field_name in input_schema.get("required", []), - } - ) - return inputs_list + return get_inputs_from_schema(input_schema) def _get_execution_modes(self, graph: GraphModel) -> list[str]: """Get available execution modes for the graph.""" diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py index a59082b399..8b467104f9 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py @@ -8,12 +8,13 @@ from typing import Any from backend.api.features.chat.model import ChatSession from backend.data.block import get_block from backend.data.execution import ExecutionContext -from backend.data.model import CredentialsMetaInput +from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput from backend.data.workspace import get_or_create_workspace from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.util.exceptions import BlockError from .base import BaseTool +from .helpers import get_inputs_from_schema from .models import ( BlockOutputResponse, ErrorResponse, @@ -22,7 +23,7 @@ from .models import ( ToolResponseBase, UserReadiness, ) -from .utils import build_missing_credentials_from_field_info +from .utils import build_missing_credentials_from_field_info, match_credentials_to_requirements logger = logging.getLogger(__name__) @@ -71,6 +72,22 @@ class RunBlockTool(BaseTool): def requires_auth(self) -> bool: return True + def _get_credentials_requirements( + self, + block: Any, + ) -> dict[str, CredentialsFieldInfo]: + """ + Get credential requirements from block's input schema. + + Args: + block: Block to get credentials for + + Returns: + Dict mapping field names to CredentialsFieldInfo + """ + credentials_fields_info = block.input_schema.get_credentials_fields_info() + return credentials_fields_info if credentials_fields_info else {} + async def _check_block_credentials( self, user_id: str, @@ -82,53 +99,14 @@ class RunBlockTool(BaseTool): Returns: tuple[matched_credentials, missing_credentials] """ - matched_credentials: dict[str, CredentialsMetaInput] = {} - missing_credentials: list[CredentialsMetaInput] = [] + # Get credential requirements from block + requirements = self._get_credentials_requirements(block) - # Get credential field info from block's input schema - credentials_fields_info = block.input_schema.get_credentials_fields_info() + if not requirements: + return {}, [] - if not credentials_fields_info: - return matched_credentials, missing_credentials - - # Get user's available credentials - creds_manager = IntegrationCredentialsManager() - available_creds = await creds_manager.store.get_all_creds(user_id) - - for field_name, field_info in credentials_fields_info.items(): - # field_info.provider is a frozenset of acceptable providers - # field_info.supported_types is a frozenset of acceptable types - matching_cred = next( - ( - cred - for cred in available_creds - if cred.provider in field_info.provider - and cred.type in field_info.supported_types - ), - None, - ) - - if matching_cred: - matched_credentials[field_name] = CredentialsMetaInput( - id=matching_cred.id, - provider=matching_cred.provider, # type: ignore - type=matching_cred.type, - title=matching_cred.title, - ) - else: - # Create a placeholder for the missing credential - provider = next(iter(field_info.provider), "unknown") - cred_type = next(iter(field_info.supported_types), "api_key") - missing_credentials.append( - CredentialsMetaInput( - id=field_name, - provider=provider, # type: ignore - type=cred_type, # type: ignore - title=field_name.replace("_", " ").title(), - ) - ) - - return matched_credentials, missing_credentials + # Use shared matching logic + return await match_credentials_to_requirements(user_id, requirements) async def _execute( self, @@ -320,27 +298,7 @@ class RunBlockTool(BaseTool): def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]: """Extract non-credential inputs from block schema.""" - inputs_list = [] schema = block.input_schema.jsonschema() - properties = schema.get("properties", {}) - required_fields = set(schema.get("required", [])) - # Get credential field names to exclude credentials_fields = set(block.input_schema.get_credentials_fields().keys()) - - for field_name, field_schema in properties.items(): - # Skip credential fields - if field_name in credentials_fields: - continue - - inputs_list.append( - { - "name": field_name, - "title": field_schema.get("title", field_name), - "type": field_schema.get("type", "string"), - "description": field_schema.get("description", ""), - "required": field_name in required_fields, - } - ) - - return inputs_list + return get_inputs_from_schema(schema, exclude_fields=credentials_fields) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py index 0046d0b249..812c6ccaba 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py @@ -225,6 +225,135 @@ async def get_or_create_library_agent( return library_agents[0] +async def get_user_credentials(user_id: str) -> list: + """ + Get all available credentials for a user. + + Args: + user_id: The user's ID + + Returns: + List of user's credentials + """ + creds_manager = IntegrationCredentialsManager() + return await creds_manager.store.get_all_creds(user_id) + + +def find_matching_credential( + available_creds: list, + required_providers: frozenset[str] | set[str], + required_types: frozenset[str] | set[str], +): + """ + Find a credential that matches the required provider and type. + + Args: + available_creds: List of user's available credentials + required_providers: Set of acceptable provider names + required_types: Set of acceptable credential types + + Returns: + Matching credential or None + """ + return next( + ( + cred + for cred in available_creds + if cred.provider in required_providers + and cred.type in required_types + ), + None, + ) + + +def create_credential_meta_from_match( + matching_cred, +) -> CredentialsMetaInput: + """ + Create a CredentialsMetaInput from a matched credential. + + Args: + matching_cred: The matched credential object + + Returns: + CredentialsMetaInput instance + """ + return CredentialsMetaInput( + id=matching_cred.id, + provider=matching_cred.provider, # type: ignore + type=matching_cred.type, + title=matching_cred.title, + ) + + +async def match_credentials_to_requirements( + user_id: str, + requirements: dict[str, CredentialsFieldInfo], +) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]: + """ + Match user's credentials against a dictionary of credential requirements. + + This is the core matching logic shared by both graph and block credential matching. + + Args: + user_id: The user's ID + requirements: Dict mapping field names to CredentialsFieldInfo + + Returns: + tuple[matched_credentials dict, missing_credentials list] + """ + matched: dict[str, CredentialsMetaInput] = {} + missing: list[CredentialsMetaInput] = [] + + if not requirements: + return matched, missing + + available_creds = await get_user_credentials(user_id) + + for field_name, field_info in requirements.items(): + matching_cred = find_matching_credential( + available_creds, + field_info.provider, + field_info.supported_types, + ) + + if matching_cred: + try: + matched[field_name] = create_credential_meta_from_match(matching_cred) + except Exception as e: + logger.error( + f"Failed to create CredentialsMetaInput for field '{field_name}': " + f"provider={matching_cred.provider}, type={matching_cred.type}, " + f"credential_id={matching_cred.id}", + exc_info=True, + ) + # Add to missing with validation error + provider = next(iter(field_info.provider), "unknown") + cred_type = next(iter(field_info.supported_types), "api_key") + missing.append( + CredentialsMetaInput( + id=field_name, + provider=provider, # type: ignore + type=cred_type, # type: ignore + title=f"{field_name} (validation failed: {e})", + ) + ) + else: + # Create a placeholder for the missing credential + provider = next(iter(field_info.provider), "unknown") + cred_type = next(iter(field_info.supported_types), "api_key") + missing.append( + CredentialsMetaInput( + id=field_name, + provider=provider, # type: ignore + type=cred_type, # type: ignore + title=field_name.replace("_", " ").title(), + ) + ) + + return matched, missing + + async def match_user_credentials_to_graph( user_id: str, graph: GraphModel, @@ -242,9 +371,6 @@ async def match_user_credentials_to_graph( Returns: tuple[matched_credentials dict, missing_credential_descriptions list] """ - graph_credentials_inputs: dict[str, CredentialsMetaInput] = {} - missing_creds: list[str] = [] - # Get aggregated credentials requirements from the graph aggregated_creds = graph.aggregate_credentials_inputs() logger.debug( @@ -252,69 +378,28 @@ async def match_user_credentials_to_graph( ) if not aggregated_creds: - return graph_credentials_inputs, missing_creds + return {}, [] - # Get all available credentials for the user - creds_manager = IntegrationCredentialsManager() - available_creds = await creds_manager.store.get_all_creds(user_id) + # Convert aggregated format to simple requirements dict + requirements = { + field_name: field_info + for field_name, (field_info, _node_fields) in aggregated_creds.items() + } - # For each required credential field, find a matching user credential - # field_info.provider is a frozenset because aggregate_credentials_inputs() - # combines requirements from multiple nodes. A credential matches if its - # provider is in the set of acceptable providers. - for credential_field_name, ( - credential_requirements, - _node_fields, - ) in aggregated_creds.items(): - # Find first matching credential by provider, type, and scopes - matching_cred = next( - ( - cred - for cred in available_creds - if cred.provider in credential_requirements.provider - and cred.type in credential_requirements.supported_types - and _credential_has_required_scopes(cred, credential_requirements) - ), - None, - ) + # Use shared matching logic + matched, missing_list = await match_credentials_to_requirements(user_id, requirements) - if matching_cred: - try: - graph_credentials_inputs[credential_field_name] = CredentialsMetaInput( - id=matching_cred.id, - provider=matching_cred.provider, # type: ignore - type=matching_cred.type, - title=matching_cred.title, - ) - except Exception as e: - logger.error( - f"Failed to create CredentialsMetaInput for field '{credential_field_name}': " - f"provider={matching_cred.provider}, type={matching_cred.type}, " - f"credential_id={matching_cred.id}", - exc_info=True, - ) - missing_creds.append( - f"{credential_field_name} (validation failed: {e})" - ) - else: - # Build a helpful error message including scope requirements - error_parts = [ - f"provider in {list(credential_requirements.provider)}", - f"type in {list(credential_requirements.supported_types)}", - ] - if credential_requirements.required_scopes: - error_parts.append( - f"scopes including {list(credential_requirements.required_scopes)}" - ) - missing_creds.append( - f"{credential_field_name} (requires {', '.join(error_parts)})" - ) + # Convert missing list to string descriptions for backward compatibility + missing_descriptions = [ + f"{cred.id} (requires provider={cred.provider}, type={cred.type})" + for cred in missing_list + ] logger.info( - f"Credential matching complete: {len(graph_credentials_inputs)}/{len(aggregated_creds)} matched" + f"Credential matching complete: {len(matched)}/{len(aggregated_creds)} matched" ) - return graph_credentials_inputs, missing_creds + return matched, missing_descriptions def _credential_has_required_scopes( diff --git a/autogpt_platform/backend/backend/util/validation.py b/autogpt_platform/backend/backend/util/validation.py new file mode 100644 index 0000000000..14e6c15a49 --- /dev/null +++ b/autogpt_platform/backend/backend/util/validation.py @@ -0,0 +1,33 @@ +"""Validation utilities.""" + +import re + +# UUID v4 pattern - matches standard UUID v4 format +_UUID_V4_PATTERN = re.compile( + r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}", + re.IGNORECASE, +) + + +def is_uuid_v4(text: str) -> bool: + """Check if text is a valid UUID v4. + + Args: + text: String to validate + + Returns: + True if the text is a valid UUID v4, False otherwise + """ + return bool(_UUID_V4_PATTERN.fullmatch(text.strip())) + + +def extract_uuids(text: str) -> list[str]: + """Extract all UUID v4 strings from text. + + Args: + text: String to search for UUIDs + + Returns: + List of unique UUIDs found (lowercase) + """ + return list({m.lower() for m in _UUID_V4_PATTERN.findall(text)})