diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py index de75419d28..8f37937df5 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py @@ -77,72 +77,6 @@ class RunBlockTool(BaseTool): def requires_auth(self) -> bool: return True - def _resolve_discriminated_credentials( - self, - block: AnyBlockSchema, - input_data: dict[str, Any], - ) -> dict[str, CredentialsFieldInfo]: - """Resolve credential requirements, applying discriminator logic where needed.""" - credentials_fields_info = block.input_schema.get_credentials_fields_info() - if not credentials_fields_info: - return {} - - resolved: dict[str, CredentialsFieldInfo] = {} - - for field_name, field_info in credentials_fields_info.items(): - effective_field_info = field_info - - if field_info.discriminator and field_info.discriminator_mapping: - discriminator_value = input_data.get(field_info.discriminator) - if discriminator_value is None: - field = block.input_schema.model_fields.get( - field_info.discriminator - ) - if field and field.default is not PydanticUndefined: - discriminator_value = field.default - - if ( - discriminator_value - and discriminator_value in field_info.discriminator_mapping - ): - effective_field_info = field_info.discriminate(discriminator_value) - # For host-scoped credentials, add the discriminator value - # (e.g., URL) so _credential_is_for_host can match it - effective_field_info.discriminator_values.add(discriminator_value) - logger.debug( - f"Discriminated provider for {field_name}: " - f"{discriminator_value} -> {effective_field_info.provider}" - ) - - resolved[field_name] = effective_field_info - - return resolved - - async def _check_block_credentials( - self, - user_id: str, - block: AnyBlockSchema, - input_data: dict[str, Any] | None = None, - ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]: - """ - Check if user has required credentials for a block. - - Args: - user_id: User ID - block: Block to check credentials for - input_data: Input data for the block (used to determine provider via discriminator) - - Returns: - tuple[matched_credentials, missing_credentials] - """ - input_data = input_data or {} - requirements = self._resolve_discriminated_credentials(block, input_data) - - if not requirements: - return {}, [] - - return await match_credentials_to_requirements(user_id, requirements) - async def _execute( self, user_id: str | None, @@ -330,8 +264,74 @@ class RunBlockTool(BaseTool): session_id=session_id, ) + async def _check_block_credentials( + self, + user_id: str, + block: AnyBlockSchema, + input_data: dict[str, Any] | None = None, + ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]: + """ + Check if user has required credentials for a block. + + Args: + user_id: User ID + block: Block to check credentials for + input_data: Input data for the block (used to determine provider via discriminator) + + Returns: + tuple[matched_credentials, missing_credentials] + """ + input_data = input_data or {} + requirements = self._resolve_discriminated_credentials(block, input_data) + + if not requirements: + return {}, [] + + return await match_credentials_to_requirements(user_id, requirements) + def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]: """Extract non-credential inputs from block schema.""" schema = block.input_schema.jsonschema() credentials_fields = set(block.input_schema.get_credentials_fields().keys()) return get_inputs_from_schema(schema, exclude_fields=credentials_fields) + + def _resolve_discriminated_credentials( + self, + block: AnyBlockSchema, + input_data: dict[str, Any], + ) -> dict[str, CredentialsFieldInfo]: + """Resolve credential requirements, applying discriminator logic where needed.""" + credentials_fields_info = block.input_schema.get_credentials_fields_info() + if not credentials_fields_info: + return {} + + resolved: dict[str, CredentialsFieldInfo] = {} + + for field_name, field_info in credentials_fields_info.items(): + effective_field_info = field_info + + if field_info.discriminator and field_info.discriminator_mapping: + discriminator_value = input_data.get(field_info.discriminator) + if discriminator_value is None: + field = block.input_schema.model_fields.get( + field_info.discriminator + ) + if field and field.default is not PydanticUndefined: + discriminator_value = field.default + + if ( + discriminator_value + and discriminator_value in field_info.discriminator_mapping + ): + effective_field_info = field_info.discriminate(discriminator_value) + # For host-scoped credentials, add the discriminator value + # (e.g., URL) so _credential_is_for_host can match it + effective_field_info.discriminator_values.add(discriminator_value) + logger.debug( + f"Discriminated provider for {field_name}: " + f"{discriminator_value} -> {effective_field_info.provider}" + ) + + resolved[field_name] = effective_field_info + + return resolved