mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-04 11:55:11 -05:00
Compare commits
3 Commits
feat/text-
...
otto/copil
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e604528ea | ||
|
|
c3ec7c2880 | ||
|
|
7d9380a793 |
@@ -17,6 +17,14 @@ from .model import ChatSession, create_chat_session, get_chat_session, get_user_
|
|||||||
|
|
||||||
config = ChatConfig()
|
config = ChatConfig()
|
||||||
|
|
||||||
|
# SSE response headers for streaming
|
||||||
|
SSE_RESPONSE_HEADERS = {
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"x-vercel-ai-ui-message-stream": "v1",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -32,6 +40,60 @@ async def _validate_and_get_session(
|
|||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_stream_generator(
|
||||||
|
session_id: str,
|
||||||
|
message: str,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
is_user_message: bool = True,
|
||||||
|
context: dict[str, str] | None = None,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Create SSE event generator for chat streaming.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: Chat session ID
|
||||||
|
message: User message to process
|
||||||
|
user_id: Optional authenticated user ID
|
||||||
|
session: Pre-fetched chat session
|
||||||
|
is_user_message: Whether the message is from a user
|
||||||
|
context: Optional context dict with url and content
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
SSE-formatted chunks from the chat completion stream
|
||||||
|
"""
|
||||||
|
chunk_count = 0
|
||||||
|
first_chunk_type: str | None = None
|
||||||
|
async for chunk in chat_service.stream_chat_completion(
|
||||||
|
session_id,
|
||||||
|
message,
|
||||||
|
is_user_message=is_user_message,
|
||||||
|
user_id=user_id,
|
||||||
|
session=session,
|
||||||
|
context=context,
|
||||||
|
):
|
||||||
|
if chunk_count < 3:
|
||||||
|
logger.info(
|
||||||
|
"Chat stream chunk",
|
||||||
|
extra={
|
||||||
|
"session_id": session_id,
|
||||||
|
"chunk_type": str(chunk.type),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if not first_chunk_type:
|
||||||
|
first_chunk_type = str(chunk.type)
|
||||||
|
chunk_count += 1
|
||||||
|
yield chunk.to_sse()
|
||||||
|
logger.info(
|
||||||
|
"Chat stream completed",
|
||||||
|
extra={
|
||||||
|
"session_id": session_id,
|
||||||
|
"chunk_count": chunk_count,
|
||||||
|
"first_chunk_type": first_chunk_type,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
tags=["chat"],
|
tags=["chat"],
|
||||||
)
|
)
|
||||||
@@ -221,49 +283,17 @@ async def stream_chat_post(
|
|||||||
"""
|
"""
|
||||||
session = await _validate_and_get_session(session_id, user_id)
|
session = await _validate_and_get_session(session_id, user_id)
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
|
||||||
chunk_count = 0
|
|
||||||
first_chunk_type: str | None = None
|
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
|
||||||
session_id,
|
|
||||||
request.message,
|
|
||||||
is_user_message=request.is_user_message,
|
|
||||||
user_id=user_id,
|
|
||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
|
||||||
context=request.context,
|
|
||||||
):
|
|
||||||
if chunk_count < 3:
|
|
||||||
logger.info(
|
|
||||||
"Chat stream chunk",
|
|
||||||
extra={
|
|
||||||
"session_id": session_id,
|
|
||||||
"chunk_type": str(chunk.type),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if not first_chunk_type:
|
|
||||||
first_chunk_type = str(chunk.type)
|
|
||||||
chunk_count += 1
|
|
||||||
yield chunk.to_sse()
|
|
||||||
logger.info(
|
|
||||||
"Chat stream completed",
|
|
||||||
extra={
|
|
||||||
"session_id": session_id,
|
|
||||||
"chunk_count": chunk_count,
|
|
||||||
"first_chunk_type": first_chunk_type,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# AI SDK protocol termination
|
|
||||||
yield "data: [DONE]\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_generator(),
|
_create_stream_generator(
|
||||||
|
session_id=session_id,
|
||||||
|
message=request.message,
|
||||||
|
user_id=user_id,
|
||||||
|
session=session,
|
||||||
|
is_user_message=request.is_user_message,
|
||||||
|
context=request.context,
|
||||||
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={
|
headers=SSE_RESPONSE_HEADERS,
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
|
||||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -295,48 +325,16 @@ async def stream_chat_get(
|
|||||||
"""
|
"""
|
||||||
session = await _validate_and_get_session(session_id, user_id)
|
session = await _validate_and_get_session(session_id, user_id)
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
|
||||||
chunk_count = 0
|
|
||||||
first_chunk_type: str | None = None
|
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
|
||||||
session_id,
|
|
||||||
message,
|
|
||||||
is_user_message=is_user_message,
|
|
||||||
user_id=user_id,
|
|
||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
|
||||||
):
|
|
||||||
if chunk_count < 3:
|
|
||||||
logger.info(
|
|
||||||
"Chat stream chunk",
|
|
||||||
extra={
|
|
||||||
"session_id": session_id,
|
|
||||||
"chunk_type": str(chunk.type),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if not first_chunk_type:
|
|
||||||
first_chunk_type = str(chunk.type)
|
|
||||||
chunk_count += 1
|
|
||||||
yield chunk.to_sse()
|
|
||||||
logger.info(
|
|
||||||
"Chat stream completed",
|
|
||||||
extra={
|
|
||||||
"session_id": session_id,
|
|
||||||
"chunk_count": chunk_count,
|
|
||||||
"first_chunk_type": first_chunk_type,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# AI SDK protocol termination
|
|
||||||
yield "data: [DONE]\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_generator(),
|
_create_stream_generator(
|
||||||
|
session_id=session_id,
|
||||||
|
message=message,
|
||||||
|
user_id=user_id,
|
||||||
|
session=session,
|
||||||
|
is_user_message=is_user_message,
|
||||||
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={
|
headers=SSE_RESPONSE_HEADERS,
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
|
||||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,77 @@
|
|||||||
|
"""Shared helpers for chat tools."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .models import ErrorResponse
|
||||||
|
|
||||||
|
|
||||||
|
def error_response(
|
||||||
|
message: str, session_id: str | None, **kwargs: Any
|
||||||
|
) -> ErrorResponse:
|
||||||
|
"""Create standardized error response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message to display
|
||||||
|
session_id: Current session ID
|
||||||
|
**kwargs: Additional fields to pass to ErrorResponse
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ErrorResponse with the given message and session_id
|
||||||
|
"""
|
||||||
|
return ErrorResponse(message=message, session_id=session_id, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_inputs_from_schema(
|
||||||
|
input_schema: dict[str, Any],
|
||||||
|
exclude_fields: set[str] | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Extract input field info from JSON schema.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_schema: JSON schema dict with 'properties' and 'required'
|
||||||
|
exclude_fields: Set of field names to exclude (e.g., credential fields)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dicts with field info (name, title, type, description, required, default)
|
||||||
|
"""
|
||||||
|
exclude = exclude_fields or set()
|
||||||
|
properties = input_schema.get("properties", {})
|
||||||
|
required = set(input_schema.get("required", []))
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": name,
|
||||||
|
"title": schema.get("title", name),
|
||||||
|
"type": schema.get("type", "string"),
|
||||||
|
"description": schema.get("description", ""),
|
||||||
|
"required": name in required,
|
||||||
|
"default": schema.get("default"),
|
||||||
|
}
|
||||||
|
for name, schema in properties.items()
|
||||||
|
if name not in exclude
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def format_inputs_as_markdown(inputs: list[dict[str, Any]]) -> str:
|
||||||
|
"""Format input fields as a readable markdown list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: List of input dicts from get_inputs_from_schema
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Markdown-formatted string listing the inputs
|
||||||
|
"""
|
||||||
|
if not inputs:
|
||||||
|
return "No inputs required."
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
for inp in inputs:
|
||||||
|
required_marker = " (required)" if inp.get("required") else ""
|
||||||
|
default = inp.get("default")
|
||||||
|
default_info = f" [default: {default}]" if default is not None else ""
|
||||||
|
description = inp.get("description", "")
|
||||||
|
desc_info = f" - {description}" if description else ""
|
||||||
|
|
||||||
|
lines.append(f"- **{inp['name']}**{required_marker}{default_info}{desc_info}")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
@@ -24,6 +24,7 @@ from backend.util.timezone_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
|
from .helpers import get_inputs_from_schema
|
||||||
from .models import (
|
from .models import (
|
||||||
AgentDetails,
|
AgentDetails,
|
||||||
AgentDetailsResponse,
|
AgentDetailsResponse,
|
||||||
@@ -354,19 +355,7 @@ class RunAgentTool(BaseTool):
|
|||||||
|
|
||||||
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
|
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
|
||||||
"""Extract inputs list from schema."""
|
"""Extract inputs list from schema."""
|
||||||
inputs_list = []
|
return get_inputs_from_schema(input_schema)
|
||||||
if isinstance(input_schema, dict) and "properties" in input_schema:
|
|
||||||
for field_name, field_schema in input_schema["properties"].items():
|
|
||||||
inputs_list.append(
|
|
||||||
{
|
|
||||||
"name": field_name,
|
|
||||||
"title": field_schema.get("title", field_name),
|
|
||||||
"type": field_schema.get("type", "string"),
|
|
||||||
"description": field_schema.get("description", ""),
|
|
||||||
"required": field_name in input_schema.get("required", []),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return inputs_list
|
|
||||||
|
|
||||||
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
||||||
"""Get available execution modes for the graph."""
|
"""Get available execution modes for the graph."""
|
||||||
|
|||||||
@@ -8,12 +8,13 @@ from typing import Any
|
|||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.data.block import get_block
|
from backend.data.block import get_block
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||||
from backend.data.workspace import get_or_create_workspace
|
from backend.data.workspace import get_or_create_workspace
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import BlockError
|
from backend.util.exceptions import BlockError
|
||||||
|
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
|
from .helpers import get_inputs_from_schema
|
||||||
from .models import (
|
from .models import (
|
||||||
BlockOutputResponse,
|
BlockOutputResponse,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
@@ -22,7 +23,10 @@ from .models import (
|
|||||||
ToolResponseBase,
|
ToolResponseBase,
|
||||||
UserReadiness,
|
UserReadiness,
|
||||||
)
|
)
|
||||||
from .utils import build_missing_credentials_from_field_info
|
from .utils import (
|
||||||
|
build_missing_credentials_from_field_info,
|
||||||
|
match_credentials_to_requirements,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -71,6 +75,22 @@ class RunBlockTool(BaseTool):
|
|||||||
def requires_auth(self) -> bool:
|
def requires_auth(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _get_credentials_requirements(
|
||||||
|
self,
|
||||||
|
block: Any,
|
||||||
|
) -> dict[str, CredentialsFieldInfo]:
|
||||||
|
"""
|
||||||
|
Get credential requirements from block's input schema.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
block: Block to get credentials for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping field names to CredentialsFieldInfo
|
||||||
|
"""
|
||||||
|
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||||
|
return credentials_fields_info if credentials_fields_info else {}
|
||||||
|
|
||||||
async def _check_block_credentials(
|
async def _check_block_credentials(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
@@ -82,53 +102,12 @@ class RunBlockTool(BaseTool):
|
|||||||
Returns:
|
Returns:
|
||||||
tuple[matched_credentials, missing_credentials]
|
tuple[matched_credentials, missing_credentials]
|
||||||
"""
|
"""
|
||||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
requirements = self._get_credentials_requirements(block)
|
||||||
missing_credentials: list[CredentialsMetaInput] = []
|
|
||||||
|
|
||||||
# Get credential field info from block's input schema
|
if not requirements:
|
||||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
return {}, []
|
||||||
|
|
||||||
if not credentials_fields_info:
|
return await match_credentials_to_requirements(user_id, requirements)
|
||||||
return matched_credentials, missing_credentials
|
|
||||||
|
|
||||||
# Get user's available credentials
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
|
||||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
|
||||||
|
|
||||||
for field_name, field_info in credentials_fields_info.items():
|
|
||||||
# field_info.provider is a frozenset of acceptable providers
|
|
||||||
# field_info.supported_types is a frozenset of acceptable types
|
|
||||||
matching_cred = next(
|
|
||||||
(
|
|
||||||
cred
|
|
||||||
for cred in available_creds
|
|
||||||
if cred.provider in field_info.provider
|
|
||||||
and cred.type in field_info.supported_types
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if matching_cred:
|
|
||||||
matched_credentials[field_name] = CredentialsMetaInput(
|
|
||||||
id=matching_cred.id,
|
|
||||||
provider=matching_cred.provider, # type: ignore
|
|
||||||
type=matching_cred.type,
|
|
||||||
title=matching_cred.title,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Create a placeholder for the missing credential
|
|
||||||
provider = next(iter(field_info.provider), "unknown")
|
|
||||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
|
||||||
missing_credentials.append(
|
|
||||||
CredentialsMetaInput(
|
|
||||||
id=field_name,
|
|
||||||
provider=provider, # type: ignore
|
|
||||||
type=cred_type, # type: ignore
|
|
||||||
title=field_name.replace("_", " ").title(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return matched_credentials, missing_credentials
|
|
||||||
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
@@ -320,27 +299,7 @@ class RunBlockTool(BaseTool):
|
|||||||
|
|
||||||
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
||||||
"""Extract non-credential inputs from block schema."""
|
"""Extract non-credential inputs from block schema."""
|
||||||
inputs_list = []
|
|
||||||
schema = block.input_schema.jsonschema()
|
schema = block.input_schema.jsonschema()
|
||||||
properties = schema.get("properties", {})
|
|
||||||
required_fields = set(schema.get("required", []))
|
|
||||||
|
|
||||||
# Get credential field names to exclude
|
# Get credential field names to exclude
|
||||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||||
|
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
|
||||||
for field_name, field_schema in properties.items():
|
|
||||||
# Skip credential fields
|
|
||||||
if field_name in credentials_fields:
|
|
||||||
continue
|
|
||||||
|
|
||||||
inputs_list.append(
|
|
||||||
{
|
|
||||||
"name": field_name,
|
|
||||||
"title": field_schema.get("title", field_name),
|
|
||||||
"type": field_schema.get("type", "string"),
|
|
||||||
"description": field_schema.get("description", ""),
|
|
||||||
"required": field_name in required_fields,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return inputs_list
|
|
||||||
|
|||||||
@@ -225,6 +225,127 @@ async def get_or_create_library_agent(
|
|||||||
return library_agents[0]
|
return library_agents[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_credentials(user_id: str) -> list:
|
||||||
|
"""
|
||||||
|
Get all available credentials for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of user's credentials
|
||||||
|
"""
|
||||||
|
creds_manager = IntegrationCredentialsManager()
|
||||||
|
return await creds_manager.store.get_all_creds(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def find_matching_credential(
|
||||||
|
available_creds: list,
|
||||||
|
field_info: CredentialsFieldInfo,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Find a credential that matches the required provider, type, and scopes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
available_creds: List of user's available credentials
|
||||||
|
field_info: CredentialsFieldInfo with provider, type, and scope requirements
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Matching credential or None
|
||||||
|
"""
|
||||||
|
for cred in available_creds:
|
||||||
|
if cred.provider not in field_info.provider:
|
||||||
|
continue
|
||||||
|
if cred.type not in field_info.supported_types:
|
||||||
|
continue
|
||||||
|
if not _credential_has_required_scopes(cred, field_info):
|
||||||
|
continue
|
||||||
|
return cred
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def create_credential_meta_from_match(
|
||||||
|
matching_cred,
|
||||||
|
) -> CredentialsMetaInput:
|
||||||
|
"""
|
||||||
|
Create a CredentialsMetaInput from a matched credential.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
matching_cred: The matched credential object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CredentialsMetaInput instance
|
||||||
|
"""
|
||||||
|
return CredentialsMetaInput(
|
||||||
|
id=matching_cred.id,
|
||||||
|
provider=matching_cred.provider, # type: ignore
|
||||||
|
type=matching_cred.type,
|
||||||
|
title=matching_cred.title,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def match_credentials_to_requirements(
|
||||||
|
user_id: str,
|
||||||
|
requirements: dict[str, CredentialsFieldInfo],
|
||||||
|
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||||
|
"""
|
||||||
|
Match user's credentials against a dictionary of credential requirements.
|
||||||
|
|
||||||
|
This is the core matching logic shared by both graph and block credential matching.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
requirements: Dict mapping field names to CredentialsFieldInfo
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[matched_credentials dict, missing_credentials list]
|
||||||
|
"""
|
||||||
|
matched: dict[str, CredentialsMetaInput] = {}
|
||||||
|
missing: list[CredentialsMetaInput] = []
|
||||||
|
|
||||||
|
if not requirements:
|
||||||
|
return matched, missing
|
||||||
|
|
||||||
|
available_creds = await get_user_credentials(user_id)
|
||||||
|
|
||||||
|
for field_name, field_info in requirements.items():
|
||||||
|
matching_cred = find_matching_credential(available_creds, field_info)
|
||||||
|
|
||||||
|
if matching_cred:
|
||||||
|
try:
|
||||||
|
matched[field_name] = create_credential_meta_from_match(matching_cred)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to create CredentialsMetaInput for field '{field_name}': "
|
||||||
|
f"provider={matching_cred.provider}, type={matching_cred.type}, "
|
||||||
|
f"credential_id={matching_cred.id}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
provider = next(iter(field_info.provider), "unknown")
|
||||||
|
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||||
|
missing.append(
|
||||||
|
CredentialsMetaInput(
|
||||||
|
id=field_name,
|
||||||
|
provider=provider, # type: ignore
|
||||||
|
type=cred_type, # type: ignore
|
||||||
|
title=f"{field_name} (validation failed: {e})",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
provider = next(iter(field_info.provider), "unknown")
|
||||||
|
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||||
|
missing.append(
|
||||||
|
CredentialsMetaInput(
|
||||||
|
id=field_name,
|
||||||
|
provider=provider, # type: ignore
|
||||||
|
type=cred_type, # type: ignore
|
||||||
|
title=field_name.replace("_", " ").title(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return matched, missing
|
||||||
|
|
||||||
|
|
||||||
async def match_user_credentials_to_graph(
|
async def match_user_credentials_to_graph(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
graph: GraphModel,
|
graph: GraphModel,
|
||||||
@@ -242,9 +363,6 @@ async def match_user_credentials_to_graph(
|
|||||||
Returns:
|
Returns:
|
||||||
tuple[matched_credentials dict, missing_credential_descriptions list]
|
tuple[matched_credentials dict, missing_credential_descriptions list]
|
||||||
"""
|
"""
|
||||||
graph_credentials_inputs: dict[str, CredentialsMetaInput] = {}
|
|
||||||
missing_creds: list[str] = []
|
|
||||||
|
|
||||||
# Get aggregated credentials requirements from the graph
|
# Get aggregated credentials requirements from the graph
|
||||||
aggregated_creds = graph.aggregate_credentials_inputs()
|
aggregated_creds = graph.aggregate_credentials_inputs()
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -252,69 +370,30 @@ async def match_user_credentials_to_graph(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not aggregated_creds:
|
if not aggregated_creds:
|
||||||
return graph_credentials_inputs, missing_creds
|
return {}, []
|
||||||
|
|
||||||
# Get all available credentials for the user
|
# Convert aggregated format to simple requirements dict
|
||||||
creds_manager = IntegrationCredentialsManager()
|
requirements = {
|
||||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
field_name: field_info
|
||||||
|
for field_name, (field_info, _node_fields) in aggregated_creds.items()
|
||||||
|
}
|
||||||
|
|
||||||
# For each required credential field, find a matching user credential
|
# Use shared matching logic
|
||||||
# field_info.provider is a frozenset because aggregate_credentials_inputs()
|
matched, missing_list = await match_credentials_to_requirements(
|
||||||
# combines requirements from multiple nodes. A credential matches if its
|
user_id, requirements
|
||||||
# provider is in the set of acceptable providers.
|
|
||||||
for credential_field_name, (
|
|
||||||
credential_requirements,
|
|
||||||
_node_fields,
|
|
||||||
) in aggregated_creds.items():
|
|
||||||
# Find first matching credential by provider, type, and scopes
|
|
||||||
matching_cred = next(
|
|
||||||
(
|
|
||||||
cred
|
|
||||||
for cred in available_creds
|
|
||||||
if cred.provider in credential_requirements.provider
|
|
||||||
and cred.type in credential_requirements.supported_types
|
|
||||||
and _credential_has_required_scopes(cred, credential_requirements)
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if matching_cred:
|
# Convert missing list to string descriptions for backward compatibility
|
||||||
try:
|
missing_descriptions = [
|
||||||
graph_credentials_inputs[credential_field_name] = CredentialsMetaInput(
|
f"{cred.id} (requires provider={cred.provider}, type={cred.type})"
|
||||||
id=matching_cred.id,
|
for cred in missing_list
|
||||||
provider=matching_cred.provider, # type: ignore
|
|
||||||
type=matching_cred.type,
|
|
||||||
title=matching_cred.title,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to create CredentialsMetaInput for field '{credential_field_name}': "
|
|
||||||
f"provider={matching_cred.provider}, type={matching_cred.type}, "
|
|
||||||
f"credential_id={matching_cred.id}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
missing_creds.append(
|
|
||||||
f"{credential_field_name} (validation failed: {e})"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Build a helpful error message including scope requirements
|
|
||||||
error_parts = [
|
|
||||||
f"provider in {list(credential_requirements.provider)}",
|
|
||||||
f"type in {list(credential_requirements.supported_types)}",
|
|
||||||
]
|
]
|
||||||
if credential_requirements.required_scopes:
|
|
||||||
error_parts.append(
|
|
||||||
f"scopes including {list(credential_requirements.required_scopes)}"
|
|
||||||
)
|
|
||||||
missing_creds.append(
|
|
||||||
f"{credential_field_name} (requires {', '.join(error_parts)})"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Credential matching complete: {len(graph_credentials_inputs)}/{len(aggregated_creds)} matched"
|
f"Credential matching complete: {len(matched)}/{len(aggregated_creds)} matched"
|
||||||
)
|
)
|
||||||
|
|
||||||
return graph_credentials_inputs, missing_creds
|
return matched, missing_descriptions
|
||||||
|
|
||||||
|
|
||||||
def _credential_has_required_scopes(
|
def _credential_has_required_scopes(
|
||||||
|
|||||||
32
autogpt_platform/backend/backend/util/validation.py
Normal file
32
autogpt_platform/backend/backend/util/validation.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Validation utilities."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
_UUID_V4_PATTERN = re.compile(
|
||||||
|
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_uuid_v4(text: str) -> bool:
|
||||||
|
"""Check if text is a valid UUID v4.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: String to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the text is a valid UUID v4, False otherwise
|
||||||
|
"""
|
||||||
|
return bool(_UUID_V4_PATTERN.fullmatch(text.strip()))
|
||||||
|
|
||||||
|
|
||||||
|
def extract_uuids(text: str) -> list[str]:
|
||||||
|
"""Extract all UUID v4 strings from text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: String to search for UUIDs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of unique UUIDs found (lowercase)
|
||||||
|
"""
|
||||||
|
return list({m.lower() for m in _UUID_V4_PATTERN.findall(text)})
|
||||||
Reference in New Issue
Block a user