mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-03 19:35:15 -05:00
refactor: address Pwuts review feedback
- Remove unused error_response() and format_inputs_as_markdown() from helpers.py - Remove _get_inputs_list() wrapper from run_agent.py, use get_inputs_from_schema directly - Fix type annotations: get_user_credentials, find_matching_credential, create_credential_meta_from_match - Remove check_scopes parameter - always check scopes (original missing check was broken behavior) - Reorder _credential_has_required_scopes to be defined before find_matching_credential
This commit is contained in:
@@ -2,39 +2,12 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .models import ErrorResponse
|
||||
|
||||
|
||||
def error_response(
|
||||
message: str, session_id: str | None, **kwargs: Any
|
||||
) -> ErrorResponse:
|
||||
"""Create standardized error response.
|
||||
|
||||
Args:
|
||||
message: Error message to display
|
||||
session_id: Current session ID
|
||||
**kwargs: Additional fields to pass to ErrorResponse
|
||||
|
||||
Returns:
|
||||
ErrorResponse with the given message and session_id
|
||||
"""
|
||||
return ErrorResponse(message=message, session_id=session_id, **kwargs)
|
||||
|
||||
|
||||
def get_inputs_from_schema(
|
||||
input_schema: dict[str, Any],
|
||||
exclude_fields: set[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Extract input field info from JSON schema.
|
||||
|
||||
Args:
|
||||
input_schema: JSON schema dict with 'properties' and 'required'
|
||||
exclude_fields: Set of field names to exclude (e.g., credential fields)
|
||||
|
||||
Returns:
|
||||
List of dicts with field info (name, title, type, description, required, default)
|
||||
"""
|
||||
# Safety check: original code returned [] if input_schema wasn't a dict
|
||||
"""Extract input field info from JSON schema."""
|
||||
if not isinstance(input_schema, dict):
|
||||
return []
|
||||
|
||||
@@ -54,28 +27,3 @@ def get_inputs_from_schema(
|
||||
for name, schema in properties.items()
|
||||
if name not in exclude
|
||||
]
|
||||
|
||||
|
||||
def format_inputs_as_markdown(inputs: list[dict[str, Any]]) -> str:
|
||||
"""Format input fields as a readable markdown list.
|
||||
|
||||
Args:
|
||||
inputs: List of input dicts from get_inputs_from_schema
|
||||
|
||||
Returns:
|
||||
Markdown-formatted string listing the inputs
|
||||
"""
|
||||
if not inputs:
|
||||
return "No inputs required."
|
||||
|
||||
lines = []
|
||||
for inp in inputs:
|
||||
required_marker = " (required)" if inp.get("required") else ""
|
||||
default = inp.get("default")
|
||||
default_info = f" [default: {default}]" if default is not None else ""
|
||||
description = inp.get("description", "")
|
||||
desc_info = f" - {description}" if description else ""
|
||||
|
||||
lines.append(f"- **{inp['name']}**{required_marker}{default_info}{desc_info}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
@@ -262,7 +262,7 @@ class RunAgentTool(BaseTool):
|
||||
),
|
||||
requirements={
|
||||
"credentials": requirements_creds_list,
|
||||
"inputs": self._get_inputs_list(graph.input_schema),
|
||||
"inputs": get_inputs_from_schema(graph.input_schema),
|
||||
"execution_modes": self._get_execution_modes(graph),
|
||||
},
|
||||
),
|
||||
@@ -370,10 +370,6 @@ class RunAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Extract inputs list from schema."""
|
||||
return get_inputs_from_schema(input_schema)
|
||||
|
||||
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
||||
"""Get available execution modes for the graph."""
|
||||
trigger_info = graph.trigger_setup_info
|
||||
@@ -387,7 +383,7 @@ class RunAgentTool(BaseTool):
|
||||
suffix: str,
|
||||
) -> str:
|
||||
"""Build a message describing available inputs for an agent."""
|
||||
inputs_list = self._get_inputs_list(graph.input_schema)
|
||||
inputs_list = get_inputs_from_schema(graph.input_schema)
|
||||
required_names = [i["name"] for i in inputs_list if i["required"]]
|
||||
optional_names = [i["name"] for i in inputs_list if not i["required"]]
|
||||
|
||||
|
||||
@@ -225,30 +225,43 @@ async def get_or_create_library_agent(
|
||||
return library_agents[0]
|
||||
|
||||
|
||||
async def get_user_credentials(user_id: str) -> list:
|
||||
async def get_user_credentials(user_id: str) -> list[Credentials]:
|
||||
"""Get all available credentials for a user."""
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
return await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
|
||||
def _credential_has_required_scopes(
|
||||
credential: Credentials,
|
||||
requirements: CredentialsFieldInfo,
|
||||
) -> bool:
|
||||
"""Check if a credential has all the scopes required by the block."""
|
||||
if credential.type != "oauth2":
|
||||
return True
|
||||
if not requirements.required_scopes:
|
||||
return True
|
||||
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||
|
||||
|
||||
def find_matching_credential(
|
||||
available_creds: list,
|
||||
available_creds: list[Credentials],
|
||||
field_info: CredentialsFieldInfo,
|
||||
check_scopes: bool = True,
|
||||
):
|
||||
"""Find a credential that matches the required provider, type, and optionally scopes."""
|
||||
) -> Credentials | None:
|
||||
"""Find a credential that matches the required provider, type, and scopes."""
|
||||
for cred in available_creds:
|
||||
if cred.provider not in field_info.provider:
|
||||
continue
|
||||
if cred.type not in field_info.supported_types:
|
||||
continue
|
||||
if check_scopes and not _credential_has_required_scopes(cred, field_info):
|
||||
if not _credential_has_required_scopes(cred, field_info):
|
||||
continue
|
||||
return cred
|
||||
return None
|
||||
|
||||
|
||||
def create_credential_meta_from_match(matching_cred) -> CredentialsMetaInput:
|
||||
def create_credential_meta_from_match(
|
||||
matching_cred: Credentials,
|
||||
) -> CredentialsMetaInput:
|
||||
"""Create a CredentialsMetaInput from a matched credential."""
|
||||
return CredentialsMetaInput(
|
||||
id=matching_cred.id,
|
||||
@@ -261,18 +274,11 @@ def create_credential_meta_from_match(matching_cred) -> CredentialsMetaInput:
|
||||
async def match_credentials_to_requirements(
|
||||
user_id: str,
|
||||
requirements: dict[str, CredentialsFieldInfo],
|
||||
check_scopes: bool = True,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Match user's credentials against a dictionary of credential requirements.
|
||||
|
||||
This is the core matching logic shared by both graph and block credential matching.
|
||||
|
||||
Args:
|
||||
user_id: User ID to fetch credentials for
|
||||
requirements: Dict mapping field names to CredentialsFieldInfo
|
||||
check_scopes: Whether to verify OAuth2 scopes match requirements (default True).
|
||||
Set to False to preserve original run_block behavior which didn't check scopes.
|
||||
"""
|
||||
matched: dict[str, CredentialsMetaInput] = {}
|
||||
missing: list[CredentialsMetaInput] = []
|
||||
@@ -283,9 +289,7 @@ async def match_credentials_to_requirements(
|
||||
available_creds = await get_user_credentials(user_id)
|
||||
|
||||
for field_name, field_info in requirements.items():
|
||||
matching_cred = find_matching_credential(
|
||||
available_creds, field_info, check_scopes=check_scopes
|
||||
)
|
||||
matching_cred = find_matching_credential(available_creds, field_info)
|
||||
|
||||
if matching_cred:
|
||||
try:
|
||||
@@ -414,28 +418,6 @@ async def match_user_credentials_to_graph(
|
||||
return graph_credentials_inputs, missing_creds
|
||||
|
||||
|
||||
def _credential_has_required_scopes(
|
||||
credential: Credentials,
|
||||
requirements: CredentialsFieldInfo,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a credential has all the scopes required by the block.
|
||||
|
||||
For OAuth2 credentials, verifies that the credential's scopes are a superset
|
||||
of the required scopes. For other credential types, returns True (no scope check).
|
||||
"""
|
||||
# Only OAuth2 credentials have scopes to check
|
||||
if credential.type != "oauth2":
|
||||
return True
|
||||
|
||||
# If no scopes are required, any credential matches
|
||||
if not requirements.required_scopes:
|
||||
return True
|
||||
|
||||
# Check that credential scopes are a superset of required scopes
|
||||
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||
|
||||
|
||||
async def check_user_has_required_credentials(
|
||||
user_id: str,
|
||||
required_credentials: list[CredentialsMetaInput],
|
||||
|
||||
Reference in New Issue
Block a user