diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py
index cab51543b1..9847dc6090 100644
--- a/autogpt_platform/backend/backend/api/features/chat/routes.py
+++ b/autogpt_platform/backend/backend/api/features/chat/routes.py
@@ -17,6 +17,13 @@ from .model import ChatSession, create_chat_session, get_chat_session, get_user_

 config = ChatConfig()

+SSE_RESPONSE_HEADERS = {
+    "Cache-Control": "no-cache",
+    "Connection": "keep-alive",
+    "X-Accel-Buffering": "no",  # Disable nginx buffering
+    "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
+}
+
 logger = logging.getLogger(__name__)

@@ -32,6 +39,48 @@ async def _validate_and_get_session(
     return session


+async def _create_stream_generator(
+    session_id: str,
+    message: str,
+    user_id: str | None,
+    session: ChatSession,
+    is_user_message: bool = True,
+    context: dict[str, str] | None = None,
+) -> AsyncGenerator[str, None]:
+    """Create SSE event generator for chat streaming."""
+    chunk_count = 0
+    first_chunk_type: str | None = None
+    async for chunk in chat_service.stream_chat_completion(
+        session_id,
+        message,
+        is_user_message=is_user_message,
+        user_id=user_id,
+        session=session,
+        context=context,
+    ):
+        if chunk_count < 3:
+            logger.info(
+                "Chat stream chunk",
+                extra={
+                    "session_id": session_id,
+                    "chunk_type": str(chunk.type),
+                },
+            )
+        if not first_chunk_type:
+            first_chunk_type = str(chunk.type)
+        chunk_count += 1
+        yield chunk.to_sse()
+    logger.info(
+        "Chat stream completed",
+        extra={
+            "session_id": session_id,
+            "chunk_count": chunk_count,
+            "first_chunk_type": first_chunk_type,
+        },
+    )
+    yield "data: [DONE]\n\n"  # AI SDK protocol termination
+
+
 router = APIRouter(
     tags=["chat"],
 )
@@ -221,49 +270,17 @@ async def stream_chat_post(
     """
     session = await _validate_and_get_session(session_id, user_id)

-    async def event_generator() -> AsyncGenerator[str, None]:
-        chunk_count = 0
-        first_chunk_type: str | None = None
-        async for chunk in chat_service.stream_chat_completion(
-            session_id,
-            request.message,
-            is_user_message=request.is_user_message,
-            user_id=user_id,
-            session=session,  # Pass pre-fetched session to avoid double-fetch
-            context=request.context,
-        ):
-            if chunk_count < 3:
-                logger.info(
-                    "Chat stream chunk",
-                    extra={
-                        "session_id": session_id,
-                        "chunk_type": str(chunk.type),
-                    },
-                )
-            if not first_chunk_type:
-                first_chunk_type = str(chunk.type)
-            chunk_count += 1
-            yield chunk.to_sse()
-        logger.info(
-            "Chat stream completed",
-            extra={
-                "session_id": session_id,
-                "chunk_count": chunk_count,
-                "first_chunk_type": first_chunk_type,
-            },
-        )
-        # AI SDK protocol termination
-        yield "data: [DONE]\n\n"
-
     return StreamingResponse(
-        event_generator(),
+        _create_stream_generator(
+            session_id=session_id,
+            message=request.message,
+            user_id=user_id,
+            session=session,
+            is_user_message=request.is_user_message,
+            context=request.context,
+        ),
         media_type="text/event-stream",
-        headers={
-            "Cache-Control": "no-cache",
-            "Connection": "keep-alive",
-            "X-Accel-Buffering": "no",  # Disable nginx buffering
-            "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
-        },
+        headers=SSE_RESPONSE_HEADERS,
     )


@@ -295,48 +312,16 @@ async def stream_chat_get(
     """
     session = await _validate_and_get_session(session_id, user_id)

-    async def event_generator() -> AsyncGenerator[str, None]:
-        chunk_count = 0
-        first_chunk_type: str | None = None
-        async for chunk in chat_service.stream_chat_completion(
-            session_id,
-            message,
-            is_user_message=is_user_message,
-            user_id=user_id,
-            session=session,  # Pass pre-fetched session to avoid double-fetch
-        ):
-            if chunk_count < 3:
-                logger.info(
-                    "Chat stream chunk",
-                    extra={
-                        "session_id": session_id,
-                        "chunk_type": str(chunk.type),
-                    },
-                )
-            if not first_chunk_type:
-                first_chunk_type = str(chunk.type)
-            chunk_count += 1
-            yield chunk.to_sse()
-        logger.info(
-            "Chat stream completed",
-            extra={
-                "session_id": session_id,
-                "chunk_count": chunk_count,
-                "first_chunk_type": first_chunk_type,
-            },
-        )
-        # AI SDK protocol termination
-        yield "data: [DONE]\n\n"
-
     return StreamingResponse(
-        event_generator(),
+        _create_stream_generator(
+            session_id=session_id,
+            message=message,
+            user_id=user_id,
+            session=session,
+            is_user_message=is_user_message,
+        ),
         media_type="text/event-stream",
-        headers={
-            "Cache-Control": "no-cache",
-            "Connection": "keep-alive",
-            "X-Accel-Buffering": "no",  # Disable nginx buffering
-            "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
-        },
+        headers=SSE_RESPONSE_HEADERS,
    )
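For orientation, here is a minimal sketch of a client consuming the refactored streaming endpoint. It assumes httpx; the route path and request payload are illustrative guesses (the router paths are not shown in this diff), but the `data: [DONE]` terminator and SSE framing match the code above.

```python
# Hedged sketch: consume the chat SSE stream until the AI SDK terminator.
# The "/chat/{session_id}/stream" path is hypothetical, not from this diff.
import httpx


async def read_chat_stream(base_url: str, session_id: str, message: str) -> list[str]:
    events: list[str] = []
    async with httpx.AsyncClient(base_url=base_url, timeout=None) as client:
        async with client.stream(
            "POST",
            f"/chat/{session_id}/stream",
            json={"message": message, "is_user_message": True},
        ) as response:
            async for line in response.aiter_lines():
                if line == "data: [DONE]":  # protocol termination sentinel
                    break
                if line.startswith("data: "):
                    events.append(line.removeprefix("data: "))
    return events
```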
diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/helpers.py b/autogpt_platform/backend/backend/api/features/chat/tools/helpers.py
new file mode 100644
index 0000000000..c2d2d16769
--- /dev/null
+++ b/autogpt_platform/backend/backend/api/features/chat/tools/helpers.py
@@ -0,0 +1,77 @@
+"""Shared helpers for chat tools."""
+
+from typing import Any
+
+from .models import ErrorResponse
+
+
+def error_response(
+    message: str, session_id: str | None, **kwargs: Any
+) -> ErrorResponse:
+    """Create standardized error response.
+
+    Args:
+        message: Error message to display
+        session_id: Current session ID
+        **kwargs: Additional fields to pass to ErrorResponse
+
+    Returns:
+        ErrorResponse with the given message and session_id
+    """
+    return ErrorResponse(message=message, session_id=session_id, **kwargs)
+
+
+def get_inputs_from_schema(
+    input_schema: dict[str, Any],
+    exclude_fields: set[str] | None = None,
+) -> list[dict[str, Any]]:
+    """Extract input field info from JSON schema.
+
+    Args:
+        input_schema: JSON schema dict with 'properties' and 'required'
+        exclude_fields: Set of field names to exclude (e.g., credential fields)
+
+    Returns:
+        List of dicts with field info (name, title, type, description, required, default)
+    """
+    exclude = exclude_fields or set()
+    properties = input_schema.get("properties", {})
+    required = set(input_schema.get("required", []))
+
+    return [
+        {
+            "name": name,
+            "title": schema.get("title", name),
+            "type": schema.get("type", "string"),
+            "description": schema.get("description", ""),
+            "required": name in required,
+            "default": schema.get("default"),
+        }
+        for name, schema in properties.items()
+        if name not in exclude
+    ]
+
+
+def format_inputs_as_markdown(inputs: list[dict[str, Any]]) -> str:
+    """Format input fields as a readable markdown list.
+
+    Args:
+        inputs: List of input dicts from get_inputs_from_schema
+
+    Returns:
+        Markdown-formatted string listing the inputs
+    """
+    if not inputs:
+        return "No inputs required."
+
+    lines = []
+    for inp in inputs:
+        required_marker = " (required)" if inp.get("required") else ""
+        default = inp.get("default")
+        default_info = f" [default: {default}]" if default is not None else ""
+        description = inp.get("description", "")
+        desc_info = f" - {description}" if description else ""
+
+        lines.append(f"- **{inp['name']}**{required_marker}{default_info}{desc_info}")
+
+    return "\n".join(lines)
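A quick sketch of how the two new helpers compose. The JSON schema below is hand-written for illustration; the expected output follows directly from the formatting rules above.

```python
# Exercise get_inputs_from_schema + format_inputs_as_markdown with a toy schema.
from backend.api.features.chat.tools.helpers import (
    format_inputs_as_markdown,
    get_inputs_from_schema,
)

schema = {
    "properties": {
        "query": {"title": "Query", "type": "string", "description": "Search text"},
        "limit": {"title": "Limit", "type": "integer", "default": 10},
        "credentials": {"title": "Credentials"},  # excluded below
    },
    "required": ["query"],
}

inputs = get_inputs_from_schema(schema, exclude_fields={"credentials"})
print(format_inputs_as_markdown(inputs))
# - **query** (required) - Search text
# - **limit** [default: 10]
```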
diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py
index 73d4cf81f2..bc9aed4a35 100644
--- a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py
+++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py
@@ -24,6 +24,7 @@ from backend.util.timezone_utils import (
 )

 from .base import BaseTool
+from .helpers import get_inputs_from_schema
 from .models import (
     AgentDetails,
     AgentDetailsResponse,
@@ -371,19 +372,7 @@ class RunAgentTool(BaseTool):

     def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
         """Extract inputs list from schema."""
-        inputs_list = []
-        if isinstance(input_schema, dict) and "properties" in input_schema:
-            for field_name, field_schema in input_schema["properties"].items():
-                inputs_list.append(
-                    {
-                        "name": field_name,
-                        "title": field_schema.get("title", field_name),
-                        "type": field_schema.get("type", "string"),
-                        "description": field_schema.get("description", ""),
-                        "required": field_name in input_schema.get("required", []),
-                    }
-                )
-        return inputs_list
+        return get_inputs_from_schema(input_schema)

     def _get_execution_modes(self, graph: GraphModel) -> list[str]:
         """Get available execution modes for the graph."""
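One behavioral note on this delegation, shown below with a toy schema: each entry now carries a `default` key, and the old `isinstance` guard is gone (the helper uses `dict.get`, so a dict without `"properties"` still yields an empty list, while a non-dict argument would now raise).

```python
# Toy check of what RunAgentTool._get_inputs_list now returns via the shared
# helper; the schema is made up for illustration.
from backend.api.features.chat.tools.helpers import get_inputs_from_schema

toy_schema = {"properties": {"name": {"type": "string"}}, "required": ["name"]}

assert get_inputs_from_schema(toy_schema) == [
    {
        "name": "name",
        "title": "name",  # falls back to the field name when no title is set
        "type": "string",
        "description": "",
        "required": True,
        "default": None,  # new key compared to the old inline implementation
    }
]
assert get_inputs_from_schema({}) == []  # missing "properties" is tolerated
```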
diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py
index 51bb2c0575..c4e6e5ffc4 100644
--- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py
+++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py
@@ -10,12 +10,13 @@ from pydantic_core import PydanticUndefined
 from backend.api.features.chat.model import ChatSession
 from backend.data.block import get_block
 from backend.data.execution import ExecutionContext
-from backend.data.model import CredentialsMetaInput
+from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
 from backend.data.workspace import get_or_create_workspace
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util.exceptions import BlockError

 from .base import BaseTool
+from .helpers import get_inputs_from_schema
 from .models import (
     BlockOutputResponse,
     ErrorResponse,
@@ -24,7 +25,10 @@ from .models import (
     ToolResponseBase,
     UserReadiness,
 )
-from .utils import build_missing_credentials_from_field_info
+from .utils import (
+    build_missing_credentials_from_field_info,
+    match_credentials_to_requirements,
+)

 logger = logging.getLogger(__name__)

@@ -73,41 +77,22 @@ class RunBlockTool(BaseTool):
     def requires_auth(self) -> bool:
         return True

-    async def _check_block_credentials(
+    def _resolve_discriminated_credentials(
         self,
-        user_id: str,
         block: Any,
-        input_data: dict[str, Any] | None = None,
-    ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
-        """
-        Check if user has required credentials for a block.
-
-        Args:
-            user_id: User ID
-            block: Block to check credentials for
-            input_data: Input data for the block (used to determine provider via discriminator)
-
-        Returns:
-            tuple[matched_credentials, missing_credentials]
-        """
-        matched_credentials: dict[str, CredentialsMetaInput] = {}
-        missing_credentials: list[CredentialsMetaInput] = []
-        input_data = input_data or {}
-
-        # Get credential field info from block's input schema
+        input_data: dict[str, Any],
+    ) -> dict[str, CredentialsFieldInfo]:
+        """Resolve credential requirements, applying discriminator logic where needed."""
         credentials_fields_info = block.input_schema.get_credentials_fields_info()
-
         if not credentials_fields_info:
-            return matched_credentials, missing_credentials
+            return {}

-        # Get user's available credentials
-        creds_manager = IntegrationCredentialsManager()
-        available_creds = await creds_manager.store.get_all_creds(user_id)
+        resolved: dict[str, CredentialsFieldInfo] = {}

         for field_name, field_info in credentials_fields_info.items():
             effective_field_info = field_info
+
             if field_info.discriminator and field_info.discriminator_mapping:
-                # Get discriminator from input, falling back to schema default
                 discriminator_value = input_data.get(field_info.discriminator)
                 if discriminator_value is None:
                     field = block.input_schema.model_fields.get(
@@ -126,37 +111,34 @@ class RunBlockTool(BaseTool):
                         f"{discriminator_value} -> {effective_field_info.provider}"
                     )

-            matching_cred = next(
-                (
-                    cred
-                    for cred in available_creds
-                    if cred.provider in effective_field_info.provider
-                    and cred.type in effective_field_info.supported_types
-                ),
-                None,
-            )
+            resolved[field_name] = effective_field_info

-            if matching_cred:
-                matched_credentials[field_name] = CredentialsMetaInput(
-                    id=matching_cred.id,
-                    provider=matching_cred.provider,  # type: ignore
-                    type=matching_cred.type,
-                    title=matching_cred.title,
-                )
-            else:
-                # Create a placeholder for the missing credential
-                provider = next(iter(effective_field_info.provider), "unknown")
-                cred_type = next(iter(effective_field_info.supported_types), "api_key")
-                missing_credentials.append(
-                    CredentialsMetaInput(
-                        id=field_name,
-                        provider=provider,  # type: ignore
-                        type=cred_type,  # type: ignore
-                        title=field_name.replace("_", " ").title(),
-                    )
-                )
+        return resolved

-        return matched_credentials, missing_credentials
+    async def _check_block_credentials(
+        self,
+        user_id: str,
+        block: Any,
+        input_data: dict[str, Any] | None = None,
+    ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
+        """
+        Check if user has required credentials for a block.
+
+        Args:
+            user_id: User ID
+            block: Block to check credentials for
+            input_data: Input data for the block (used to determine provider via discriminator)
+
+        Returns:
+            tuple[matched_credentials, missing_credentials]
+        """
+        input_data = input_data or {}
+        requirements = self._resolve_discriminated_credentials(block, input_data)
+
+        if not requirements:
+            return {}, []
+
+        return await match_credentials_to_requirements(user_id, requirements)

     async def _execute(
         self,
@@ -347,27 +329,6 @@

     def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
         """Extract non-credential inputs from block schema."""
-        inputs_list = []
         schema = block.input_schema.jsonschema()
-        properties = schema.get("properties", {})
-        required_fields = set(schema.get("required", []))
-
-        # Get credential field names to exclude
         credentials_fields = set(block.input_schema.get_credentials_fields().keys())
-
-        for field_name, field_schema in properties.items():
-            # Skip credential fields
-            if field_name in credentials_fields:
-                continue
-
-            inputs_list.append(
-                {
-                    "name": field_name,
-                    "title": field_schema.get("title", field_name),
-                    "type": field_schema.get("type", "string"),
-                    "description": field_schema.get("description", ""),
-                    "required": field_name in required_fields,
-                }
-            )
-
-        return inputs_list
+        return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
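The discriminator flow is hard to follow in diff form. Below is a simplified, self-contained model of what `_resolve_discriminated_credentials` does, under the assumption (suggested by the logging line above) that `discriminator_mapping` narrows the requirement to a provider chosen by another input's value; `FieldInfo` here is a toy stand-in, not the real `backend.data.model.CredentialsFieldInfo`.

```python
# Toy model of discriminator resolution: the credentials requirement for a
# block can depend on another input's value (e.g. "model" selects the provider).
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class FieldInfo:  # stand-in for CredentialsFieldInfo
    provider: frozenset[str]
    discriminator: str | None = None
    discriminator_mapping: dict[str, frozenset[str]] | None = None


def resolve(field: FieldInfo, input_data: dict) -> FieldInfo:
    """Narrow the provider set when the discriminating input is present."""
    if field.discriminator and field.discriminator_mapping:
        value = input_data.get(field.discriminator)
        if value in field.discriminator_mapping:
            return replace(field, provider=field.discriminator_mapping[value])
    return field


llm = FieldInfo(
    provider=frozenset({"openai", "anthropic"}),
    discriminator="model",
    discriminator_mapping={
        "gpt-4o": frozenset({"openai"}),
        "claude-3-5-sonnet": frozenset({"anthropic"}),
    },
)

assert resolve(llm, {"model": "gpt-4o"}).provider == frozenset({"openai"})
assert resolve(llm, {}).provider == frozenset({"openai", "anthropic"})
```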
diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py
index 0046d0b249..463f7ad6fe 100644
--- a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py
+++ b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py
@@ -225,6 +225,93 @@ async def get_or_create_library_agent(
     return library_agents[0]


+async def get_user_credentials(user_id: str) -> list:
+    """Get all available credentials for a user."""
+    creds_manager = IntegrationCredentialsManager()
+    return await creds_manager.store.get_all_creds(user_id)
+
+
+def find_matching_credential(
+    available_creds: list,
+    field_info: CredentialsFieldInfo,
+):
+    """Find a credential that matches the required provider, type, and scopes."""
+    for cred in available_creds:
+        if cred.provider not in field_info.provider:
+            continue
+        if cred.type not in field_info.supported_types:
+            continue
+        if not _credential_has_required_scopes(cred, field_info):
+            continue
+        return cred
+    return None
+
+
+def create_credential_meta_from_match(matching_cred) -> CredentialsMetaInput:
+    """Create a CredentialsMetaInput from a matched credential."""
+    return CredentialsMetaInput(
+        id=matching_cred.id,
+        provider=matching_cred.provider,  # type: ignore
+        type=matching_cred.type,
+        title=matching_cred.title,
+    )
+
+
+async def match_credentials_to_requirements(
+    user_id: str,
+    requirements: dict[str, CredentialsFieldInfo],
+) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
+    """
+    Match a user's credentials against a dictionary of credential requirements.
+
+    This is the core matching logic shared by both graph and block credential matching.
+    """
+    matched: dict[str, CredentialsMetaInput] = {}
+    missing: list[CredentialsMetaInput] = []
+
+    if not requirements:
+        return matched, missing
+
+    available_creds = await get_user_credentials(user_id)
+
+    for field_name, field_info in requirements.items():
+        matching_cred = find_matching_credential(available_creds, field_info)
+
+        if matching_cred:
+            try:
+                matched[field_name] = create_credential_meta_from_match(matching_cred)
+            except Exception as e:
+                logger.error(
+                    f"Failed to create CredentialsMetaInput for field '{field_name}': "
+                    f"provider={matching_cred.provider}, type={matching_cred.type}, "
+                    f"credential_id={matching_cred.id}",
+                    exc_info=True,
+                )
+                provider = next(iter(field_info.provider), "unknown")
+                cred_type = next(iter(field_info.supported_types), "api_key")
+                missing.append(
+                    CredentialsMetaInput(
+                        id=field_name,
+                        provider=provider,  # type: ignore
+                        type=cred_type,  # type: ignore
+                        title=f"{field_name} (validation failed: {e})",
+                    )
+                )
+        else:
+            # Create a placeholder for the missing credential
+            provider = next(iter(field_info.provider), "unknown")
+            cred_type = next(iter(field_info.supported_types), "api_key")
+            missing.append(
+                CredentialsMetaInput(
+                    id=field_name,
+                    provider=provider,  # type: ignore
+                    type=cred_type,  # type: ignore
+                    title=field_name.replace("_", " ").title(),
+                )
+            )
+
+    return matched, missing
+
+
 async def match_user_credentials_to_graph(
     user_id: str,
     graph: GraphModel,
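With the matcher factored out, both the block and graph call sites reduce to the same pattern. A rough usage sketch follows; `check_before_run` is a hypothetical wrapper (not part of this PR), and `requirements` would come from `get_credentials_fields_info()` on a block or from graph aggregation.

```python
# Hypothetical gate around execution, built on the shared matcher.
from backend.api.features.chat.tools.utils import match_credentials_to_requirements


async def check_before_run(user_id: str, requirements) -> bool:
    """Proceed only when every credential requirement has a match."""
    matched, missing = await match_credentials_to_requirements(user_id, requirements)
    if missing:
        # Placeholders carry a human-readable title, e.g. "Github Creds".
        print("Needs setup:", [m.title for m in missing])
        return False
    print("Using credentials:", {field: meta.id for field, meta in matched.items()})
    return True
```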
+ """ + matched: dict[str, CredentialsMetaInput] = {} + missing: list[CredentialsMetaInput] = [] + + if not requirements: + return matched, missing + + available_creds = await get_user_credentials(user_id) + + for field_name, field_info in requirements.items(): + matching_cred = find_matching_credential(available_creds, field_info) + + if matching_cred: + try: + matched[field_name] = create_credential_meta_from_match(matching_cred) + except Exception as e: + logger.error( + f"Failed to create CredentialsMetaInput for field '{field_name}': " + f"provider={matching_cred.provider}, type={matching_cred.type}, " + f"credential_id={matching_cred.id}", + exc_info=True, + ) + provider = next(iter(field_info.provider), "unknown") + cred_type = next(iter(field_info.supported_types), "api_key") + missing.append( + CredentialsMetaInput( + id=field_name, + provider=provider, # type: ignore + type=cred_type, # type: ignore + title=f"{field_name} (validation failed: {e})", + ) + ) + else: + provider = next(iter(field_info.provider), "unknown") + cred_type = next(iter(field_info.supported_types), "api_key") + missing.append( + CredentialsMetaInput( + id=field_name, + provider=provider, # type: ignore + type=cred_type, # type: ignore + title=field_name.replace("_", " ").title(), + ) + ) + + return matched, missing + + async def match_user_credentials_to_graph( user_id: str, graph: GraphModel, diff --git a/autogpt_platform/backend/backend/util/validation.py b/autogpt_platform/backend/backend/util/validation.py new file mode 100644 index 0000000000..3c22bc3c4c --- /dev/null +++ b/autogpt_platform/backend/backend/util/validation.py @@ -0,0 +1,32 @@ +"""Validation utilities.""" + +import re + +_UUID_V4_PATTERN = re.compile( + r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}", + re.IGNORECASE, +) + + +def is_uuid_v4(text: str) -> bool: + """Check if text is a valid UUID v4. + + Args: + text: String to validate + + Returns: + True if the text is a valid UUID v4, False otherwise + """ + return bool(_UUID_V4_PATTERN.fullmatch(text.strip())) + + +def extract_uuids(text: str) -> list[str]: + """Extract all UUID v4 strings from text. + + Args: + text: String to search for UUIDs + + Returns: + List of unique UUIDs found (lowercase) + """ + return list({m.lower() for m in _UUID_V4_PATTERN.findall(text)})