From 5a30d114164d3e680e56e2ec5d7c1f6d772b5586 Mon Sep 17 00:00:00 2001 From: Otto Date: Mon, 9 Feb 2026 13:43:55 +0000 Subject: [PATCH] refactor(copilot): Code cleanup and deduplication (#11950) ## Summary Code cleanup of the AI Copilot codebase - rebased onto latest dev. ## Changes ### New Files - `backend/util/validation.py` - UUID validation helpers - `backend/api/features/chat/tools/helpers.py` - Shared tool utilities ### Credential Matching Consolidation - Added shared utilities to `utils.py` - Refactored `run_block._check_block_credentials()` with discriminator support - Extracted `_resolve_discriminated_credentials()` for multi-provider handling ### Routes Cleanup - Extracted `_create_stream_generator()` and `SSE_RESPONSE_HEADERS` ### Tool Files Cleanup - Updated `run_agent.py` and `run_block.py` to use shared helpers **WIP** - This PR will be updated incrementally. --- .../api/features/chat/tools/helpers.py | 29 +++ .../api/features/chat/tools/run_agent.py | 21 +- .../api/features/chat/tools/run_block.py | 185 +++++++----------- .../backend/api/features/chat/tools/utils.py | 96 ++++++++- 4 files changed, 201 insertions(+), 130 deletions(-) create mode 100644 autogpt_platform/backend/backend/api/features/chat/tools/helpers.py diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/helpers.py b/autogpt_platform/backend/backend/api/features/chat/tools/helpers.py new file mode 100644 index 0000000000..cf53605ac0 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/helpers.py @@ -0,0 +1,29 @@ +"""Shared helpers for chat tools.""" + +from typing import Any + + +def get_inputs_from_schema( + input_schema: dict[str, Any], + exclude_fields: set[str] | None = None, +) -> list[dict[str, Any]]: + """Extract input field info from JSON schema.""" + if not isinstance(input_schema, dict): + return [] + + exclude = exclude_fields or set() + properties = input_schema.get("properties", {}) + required = set(input_schema.get("required", [])) + + return [ + { + "name": name, + "title": schema.get("title", name), + "type": schema.get("type", "string"), + "description": schema.get("description", ""), + "required": name in required, + "default": schema.get("default"), + } + for name, schema in properties.items() + if name not in exclude + ] diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py index 73d4cf81f2..a9f19bcf62 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py @@ -24,6 +24,7 @@ from backend.util.timezone_utils import ( ) from .base import BaseTool +from .helpers import get_inputs_from_schema from .models import ( AgentDetails, AgentDetailsResponse, @@ -261,7 +262,7 @@ class RunAgentTool(BaseTool): ), requirements={ "credentials": requirements_creds_list, - "inputs": self._get_inputs_list(graph.input_schema), + "inputs": get_inputs_from_schema(graph.input_schema), "execution_modes": self._get_execution_modes(graph), }, ), @@ -369,22 +370,6 @@ class RunAgentTool(BaseTool): session_id=session_id, ) - def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]: - """Extract inputs list from schema.""" - inputs_list = [] - if isinstance(input_schema, dict) and "properties" in input_schema: - for field_name, field_schema in input_schema["properties"].items(): - inputs_list.append( - { - "name": field_name, - "title": 
field_schema.get("title", field_name), - "type": field_schema.get("type", "string"), - "description": field_schema.get("description", ""), - "required": field_name in input_schema.get("required", []), - } - ) - return inputs_list - def _get_execution_modes(self, graph: GraphModel) -> list[str]: """Get available execution modes for the graph.""" trigger_info = graph.trigger_setup_info @@ -398,7 +383,7 @@ class RunAgentTool(BaseTool): suffix: str, ) -> str: """Build a message describing available inputs for an agent.""" - inputs_list = self._get_inputs_list(graph.input_schema) + inputs_list = get_inputs_from_schema(graph.input_schema) required_names = [i["name"] for i in inputs_list if i["required"]] optional_names = [i["name"] for i in inputs_list if not i["required"]] diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py index 590f81ff23..fc4a470fdd 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py @@ -12,14 +12,15 @@ from backend.api.features.chat.tools.find_block import ( COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES, ) -from backend.data.block import get_block +from backend.data.block import AnyBlockSchema, get_block from backend.data.execution import ExecutionContext -from backend.data.model import CredentialsMetaInput +from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput from backend.data.workspace import get_or_create_workspace from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.util.exceptions import BlockError from .base import BaseTool +from .helpers import get_inputs_from_schema from .models import ( BlockOutputResponse, ErrorResponse, @@ -28,7 +29,10 @@ from .models import ( ToolResponseBase, UserReadiness, ) -from .utils import build_missing_credentials_from_field_info +from .utils import ( + build_missing_credentials_from_field_info, + match_credentials_to_requirements, +) logger = logging.getLogger(__name__) @@ -77,91 +81,6 @@ class RunBlockTool(BaseTool): def requires_auth(self) -> bool: return True - async def _check_block_credentials( - self, - user_id: str, - block: Any, - input_data: dict[str, Any] | None = None, - ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]: - """ - Check if user has required credentials for a block. 
- - Args: - user_id: User ID - block: Block to check credentials for - input_data: Input data for the block (used to determine provider via discriminator) - - Returns: - tuple[matched_credentials, missing_credentials] - """ - matched_credentials: dict[str, CredentialsMetaInput] = {} - missing_credentials: list[CredentialsMetaInput] = [] - input_data = input_data or {} - - # Get credential field info from block's input schema - credentials_fields_info = block.input_schema.get_credentials_fields_info() - - if not credentials_fields_info: - return matched_credentials, missing_credentials - - # Get user's available credentials - creds_manager = IntegrationCredentialsManager() - available_creds = await creds_manager.store.get_all_creds(user_id) - - for field_name, field_info in credentials_fields_info.items(): - effective_field_info = field_info - if field_info.discriminator and field_info.discriminator_mapping: - # Get discriminator from input, falling back to schema default - discriminator_value = input_data.get(field_info.discriminator) - if discriminator_value is None: - field = block.input_schema.model_fields.get( - field_info.discriminator - ) - if field and field.default is not PydanticUndefined: - discriminator_value = field.default - - if ( - discriminator_value - and discriminator_value in field_info.discriminator_mapping - ): - effective_field_info = field_info.discriminate(discriminator_value) - logger.debug( - f"Discriminated provider for {field_name}: " - f"{discriminator_value} -> {effective_field_info.provider}" - ) - - matching_cred = next( - ( - cred - for cred in available_creds - if cred.provider in effective_field_info.provider - and cred.type in effective_field_info.supported_types - ), - None, - ) - - if matching_cred: - matched_credentials[field_name] = CredentialsMetaInput( - id=matching_cred.id, - provider=matching_cred.provider, # type: ignore - type=matching_cred.type, - title=matching_cred.title, - ) - else: - # Create a placeholder for the missing credential - provider = next(iter(effective_field_info.provider), "unknown") - cred_type = next(iter(effective_field_info.supported_types), "api_key") - missing_credentials.append( - CredentialsMetaInput( - id=field_name, - provider=provider, # type: ignore - type=cred_type, # type: ignore - title=field_name.replace("_", " ").title(), - ) - ) - - return matched_credentials, missing_credentials - async def _execute( self, user_id: str | None, @@ -232,8 +151,8 @@ class RunBlockTool(BaseTool): logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}") creds_manager = IntegrationCredentialsManager() - matched_credentials, missing_credentials = await self._check_block_credentials( - user_id, block, input_data + matched_credentials, missing_credentials = ( + await self._resolve_block_credentials(user_id, block, input_data) ) if missing_credentials: @@ -362,29 +281,75 @@ class RunBlockTool(BaseTool): session_id=session_id, ) - def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]: + async def _resolve_block_credentials( + self, + user_id: str, + block: AnyBlockSchema, + input_data: dict[str, Any] | None = None, + ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]: + """ + Resolve credentials for a block by matching user's available credentials. 
+ + Args: + user_id: User ID + block: Block to resolve credentials for + input_data: Input data for the block (used to determine provider via discriminator) + + Returns: + tuple of (matched_credentials, missing_credentials) - matched credentials + are used for block execution, missing ones indicate setup requirements. + """ + input_data = input_data or {} + requirements = self._resolve_discriminated_credentials(block, input_data) + + if not requirements: + return {}, [] + + return await match_credentials_to_requirements(user_id, requirements) + + def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]: """Extract non-credential inputs from block schema.""" - inputs_list = [] schema = block.input_schema.jsonschema() - properties = schema.get("properties", {}) - required_fields = set(schema.get("required", [])) - - # Get credential field names to exclude credentials_fields = set(block.input_schema.get_credentials_fields().keys()) + return get_inputs_from_schema(schema, exclude_fields=credentials_fields) - for field_name, field_schema in properties.items(): - # Skip credential fields - if field_name in credentials_fields: - continue + def _resolve_discriminated_credentials( + self, + block: AnyBlockSchema, + input_data: dict[str, Any], + ) -> dict[str, CredentialsFieldInfo]: + """Resolve credential requirements, applying discriminator logic where needed.""" + credentials_fields_info = block.input_schema.get_credentials_fields_info() + if not credentials_fields_info: + return {} - inputs_list.append( - { - "name": field_name, - "title": field_schema.get("title", field_name), - "type": field_schema.get("type", "string"), - "description": field_schema.get("description", ""), - "required": field_name in required_fields, - } - ) + resolved: dict[str, CredentialsFieldInfo] = {} - return inputs_list + for field_name, field_info in credentials_fields_info.items(): + effective_field_info = field_info + + if field_info.discriminator and field_info.discriminator_mapping: + discriminator_value = input_data.get(field_info.discriminator) + if discriminator_value is None: + field = block.input_schema.model_fields.get( + field_info.discriminator + ) + if field and field.default is not PydanticUndefined: + discriminator_value = field.default + + if ( + discriminator_value + and discriminator_value in field_info.discriminator_mapping + ): + effective_field_info = field_info.discriminate(discriminator_value) + # For host-scoped credentials, add the discriminator value + # (e.g., URL) so _credential_is_for_host can match it + effective_field_info.discriminator_values.add(discriminator_value) + logger.debug( + f"Discriminated provider for {field_name}: " + f"{discriminator_value} -> {effective_field_info.provider}" + ) + + resolved[field_name] = effective_field_info + + return resolved diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py index cda0914809..80a842bf36 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py @@ -8,6 +8,7 @@ from backend.api.features.library import model as library_model from backend.api.features.store import db as store_db from backend.data.graph import GraphModel from backend.data.model import ( + Credentials, CredentialsFieldInfo, CredentialsMetaInput, HostScopedCredentials, @@ -223,6 +224,99 @@ async def get_or_create_library_agent( return library_agents[0] +async def 
match_credentials_to_requirements( + user_id: str, + requirements: dict[str, CredentialsFieldInfo], +) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]: + """ + Match user's credentials against a dictionary of credential requirements. + + This is the core matching logic shared by both graph and block credential matching. + """ + matched: dict[str, CredentialsMetaInput] = {} + missing: list[CredentialsMetaInput] = [] + + if not requirements: + return matched, missing + + available_creds = await get_user_credentials(user_id) + + for field_name, field_info in requirements.items(): + matching_cred = find_matching_credential(available_creds, field_info) + + if matching_cred: + try: + matched[field_name] = create_credential_meta_from_match(matching_cred) + except Exception as e: + logger.error( + f"Failed to create CredentialsMetaInput for field '{field_name}': " + f"provider={matching_cred.provider}, type={matching_cred.type}, " + f"credential_id={matching_cred.id}", + exc_info=True, + ) + provider = next(iter(field_info.provider), "unknown") + cred_type = next(iter(field_info.supported_types), "api_key") + missing.append( + CredentialsMetaInput( + id=field_name, + provider=provider, # type: ignore + type=cred_type, # type: ignore + title=f"{field_name} (validation failed: {e})", + ) + ) + else: + provider = next(iter(field_info.provider), "unknown") + cred_type = next(iter(field_info.supported_types), "api_key") + missing.append( + CredentialsMetaInput( + id=field_name, + provider=provider, # type: ignore + type=cred_type, # type: ignore + title=field_name.replace("_", " ").title(), + ) + ) + + return matched, missing + + +async def get_user_credentials(user_id: str) -> list[Credentials]: + """Get all available credentials for a user.""" + creds_manager = IntegrationCredentialsManager() + return await creds_manager.store.get_all_creds(user_id) + + +def find_matching_credential( + available_creds: list[Credentials], + field_info: CredentialsFieldInfo, +) -> Credentials | None: + """Find a credential that matches the required provider, type, scopes, and host.""" + for cred in available_creds: + if cred.provider not in field_info.provider: + continue + if cred.type not in field_info.supported_types: + continue + if cred.type == "oauth2" and not _credential_has_required_scopes( + cred, field_info + ): + continue + if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info): + continue + return cred + return None + + +def create_credential_meta_from_match( + matching_cred: Credentials, +) -> CredentialsMetaInput: + """Create a CredentialsMetaInput from a matched credential.""" + return CredentialsMetaInput( + id=matching_cred.id, + provider=matching_cred.provider, # type: ignore + type=matching_cred.type, + title=matching_cred.title, + ) + + async def match_user_credentials_to_graph( user_id: str, graph: GraphModel, @@ -331,8 +425,6 @@ def _credential_has_required_scopes( # If no scopes are required, any credential matches if not requirements.required_scopes: return True - - # Check that credential scopes are a superset of required scopes return set(credential.scopes).issuperset(requirements.required_scopes)
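
**Usage sketch (not part of the patch):** a minimal example of the shared `get_inputs_from_schema` helper introduced in `helpers.py` above. It assumes the `backend` package is importable from the working directory; the example schema and field names are illustrative only, not taken from the codebase.

```python
# Minimal sketch of the shared helper added in helpers.py.
# Assumes the autogpt_platform backend package is on the Python path;
# the example schema and its field names are hypothetical.
from backend.api.features.chat.tools.helpers import get_inputs_from_schema

example_schema = {
    "properties": {
        "query": {"title": "Query", "type": "string", "description": "Search text"},
        "max_results": {"type": "integer", "default": 5},
        "credentials": {"title": "Credentials"},
    },
    "required": ["query"],
}

# run_block.py excludes credential fields the same way via exclude_fields.
inputs = get_inputs_from_schema(example_schema, exclude_fields={"credentials"})

assert [i["name"] for i in inputs] == ["query", "max_results"]
assert inputs[0]["required"] and inputs[0]["default"] is None
assert not inputs[1]["required"] and inputs[1]["default"] == 5
```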