mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-08 13:55:06 -05:00
Compare commits
4 Commits
otto/copil
...
ntindle/go
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f4f81bc4fc | ||
|
|
c5abc01f25 | ||
|
|
8b7053c1de | ||
|
|
e00c1202ad |
1320
autogpt_platform/autogpt_libs/poetry.lock
generated
1320
autogpt_platform/autogpt_libs/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -11,15 +11,15 @@ python = ">=3.10,<4.0"
|
||||
colorama = "^0.4.6"
|
||||
cryptography = "^45.0"
|
||||
expiringdict = "^1.2.2"
|
||||
fastapi = "^0.128.0"
|
||||
google-cloud-logging = "^3.13.0"
|
||||
launchdarkly-server-sdk = "^9.14.1"
|
||||
pydantic = "^2.12.5"
|
||||
pydantic-settings = "^2.12.0"
|
||||
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
|
||||
fastapi = "^0.116.1"
|
||||
google-cloud-logging = "^3.12.1"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
pydantic = "^2.11.7"
|
||||
pydantic-settings = "^2.10.1"
|
||||
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
||||
redis = "^6.2.0"
|
||||
supabase = "^2.27.2"
|
||||
uvicorn = "^0.40.0"
|
||||
supabase = "^2.16.0"
|
||||
uvicorn = "^0.35.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pyright = "^1.1.404"
|
||||
|
||||
@@ -1,29 +1,26 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import logging
|
||||
import uuid as uuid_module
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Query, Security
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
SSE_RESPONSE_HEADERS = {
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
}
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,48 +36,6 @@ async def _validate_and_get_session(
|
||||
return session
|
||||
|
||||
|
||||
async def _create_stream_generator(
|
||||
session_id: str,
|
||||
message: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
is_user_message: bool = True,
|
||||
context: dict[str, str] | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Create SSE event generator for chat streaming."""
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
message,
|
||||
is_user_message=is_user_message,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
context=context,
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
logger.info(
|
||||
"Chat stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
tags=["chat"],
|
||||
)
|
||||
@@ -104,6 +59,15 @@ class CreateSessionResponse(BaseModel):
|
||||
user_id: str | None
|
||||
|
||||
|
||||
class ActiveStreamInfo(BaseModel):
|
||||
"""Information about an active stream for reconnection."""
|
||||
|
||||
task_id: str
|
||||
last_message_id: str # Redis Stream message ID for resumption
|
||||
operation_id: str # Operation ID for completion tracking
|
||||
tool_name: str # Name of the tool being executed
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
"""Response model providing complete details for a chat session, including messages."""
|
||||
|
||||
@@ -112,6 +76,7 @@ class SessionDetailResponse(BaseModel):
|
||||
updated_at: str
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
@@ -130,6 +95,14 @@ class ListSessionsResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class OperationCompleteRequest(BaseModel):
|
||||
"""Request model for external completion webhook."""
|
||||
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@@ -215,13 +188,14 @@ async def get_session(
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
If there's an active stream for this session, returns the task_id for reconnection.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session, or None if not found.
|
||||
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
|
||||
|
||||
"""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
@@ -229,11 +203,28 @@ async def get_session(
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
logger.info(
|
||||
f"Returning session {session_id}: "
|
||||
f"message_count={len(messages)}, "
|
||||
f"roles={[m.get('role') for m in messages]}"
|
||||
|
||||
# Check if there's an active stream for this session
|
||||
active_stream_info = None
|
||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
)
|
||||
if active_task:
|
||||
# Filter out the in-progress assistant message from the session response.
|
||||
# The client will receive the complete assistant response through the SSE
|
||||
# stream replay instead, preventing duplicate content.
|
||||
if messages and messages[-1].get("role") == "assistant":
|
||||
messages = messages[:-1]
|
||||
|
||||
# Use "0-0" as last_message_id to replay the stream from the beginning.
|
||||
# Since we filtered out the cached assistant message, the client needs
|
||||
# the full stream to reconstruct the response.
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
task_id=active_task.task_id,
|
||||
last_message_id="0-0",
|
||||
operation_id=active_task.operation_id,
|
||||
tool_name=active_task.tool_name,
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
@@ -241,6 +232,7 @@ async def get_session(
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
)
|
||||
|
||||
|
||||
@@ -260,27 +252,122 @@ async def stream_chat_post(
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
The AI generation runs in a background task that continues even if the client disconnects.
|
||||
All chunks are written to Redis for reconnection support. If the client disconnects,
|
||||
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
|
||||
containing the task_id for reconnection.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
task_id = str(uuid_module.uuid4())
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
await stream_registry.create_task(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id="chat_stream", # Not a tool call, but needed for the model
|
||||
tool_name="chat",
|
||||
operation_id=operation_id,
|
||||
)
|
||||
|
||||
# Background task that runs the AI generation independently of SSE connection
|
||||
async def run_ai_generation():
|
||||
try:
|
||||
# Emit a start event with task_id for reconnection
|
||||
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
||||
await stream_registry.publish_chunk(task_id, start_chunk)
|
||||
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
# Write to Redis (subscribers will receive via XREAD)
|
||||
await stream_registry.publish_chunk(task_id, chunk)
|
||||
|
||||
# Mark task as completed
|
||||
await stream_registry.mark_task_completed(task_id, "completed")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in background AI generation for session {session_id}: {e}"
|
||||
)
|
||||
await stream_registry.mark_task_completed(task_id, "failed")
|
||||
|
||||
# Start the AI generation in a background task
|
||||
bg_task = asyncio.create_task(run_ai_generation())
|
||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||
|
||||
# SSE endpoint that subscribes to the task's stream
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
subscriber_queue = None
|
||||
try:
|
||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id="0-0", # Get all messages from the beginning
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
# Read from the subscriber queue and yield to SSE
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# Send heartbeat to keep connection alive
|
||||
yield StreamHeartbeat().to_sse()
|
||||
|
||||
except GeneratorExit:
|
||||
pass # Client disconnected - background task continues
|
||||
except Exception as e:
|
||||
logger.error(f"Error in SSE stream for task {task_id}: {e}")
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
task_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
_create_stream_generator(
|
||||
session_id=session_id,
|
||||
message=request.message,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
),
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers=SSE_RESPONSE_HEADERS,
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -312,16 +399,48 @@ async def stream_chat_get(
|
||||
"""
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
return StreamingResponse(
|
||||
_create_stream_generator(
|
||||
session_id=session_id,
|
||||
message=message,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
message,
|
||||
is_user_message=is_user_message,
|
||||
),
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
logger.info(
|
||||
"Chat stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers=SSE_RESPONSE_HEADERS,
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -351,6 +470,251 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Task Streaming (SSE Reconnection) ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}/stream",
|
||||
)
|
||||
async def stream_task(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
last_message_id: str = Query(
|
||||
default="0-0",
|
||||
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Reconnect to a long-running task's SSE stream.
|
||||
|
||||
When a long-running operation (like agent generation) starts, the client
|
||||
receives a task_id. If the connection drops, the client can reconnect
|
||||
using this endpoint to resume receiving updates.
|
||||
|
||||
Args:
|
||||
task_id: The task ID from the operation_started response.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
|
||||
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
|
||||
"""
|
||||
# Check task existence and expiry before subscribing
|
||||
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
|
||||
|
||||
if error_code == "TASK_EXPIRED":
|
||||
raise HTTPException(
|
||||
status_code=410,
|
||||
detail={
|
||||
"code": "TASK_EXPIRED",
|
||||
"message": "This operation has expired. Please try again.",
|
||||
},
|
||||
)
|
||||
|
||||
if error_code == "TASK_NOT_FOUND":
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "TASK_NOT_FOUND",
|
||||
"message": f"Task {task_id} not found.",
|
||||
},
|
||||
)
|
||||
|
||||
# Validate ownership if task has an owner
|
||||
if task and task.user_id and user_id != task.user_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"code": "ACCESS_DENIED",
|
||||
"message": "You do not have access to this task.",
|
||||
},
|
||||
)
|
||||
|
||||
# Get subscriber queue from stream registry
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "TASK_NOT_FOUND",
|
||||
"message": f"Task {task_id} not found or access denied.",
|
||||
},
|
||||
)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import asyncio
|
||||
|
||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Wait for next chunk with timeout for heartbeats
|
||||
chunk = await asyncio.wait_for(
|
||||
subscriber_queue.get(), timeout=heartbeat_interval
|
||||
)
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# Send heartbeat to keep connection alive
|
||||
yield StreamHeartbeat().to_sse()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}",
|
||||
)
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
) -> dict:
|
||||
"""
|
||||
Get the status of a long-running task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to check.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
|
||||
Returns:
|
||||
dict: Task status including task_id, status, tool_name, and operation_id.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If task_id is not found or user doesn't have access.
|
||||
"""
|
||||
task = await stream_registry.get_task(task_id)
|
||||
|
||||
if task is None:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
# Validate ownership - if task has an owner, requester must match
|
||||
if task.user_id and user_id != task.user_id:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
return {
|
||||
"task_id": task.task_id,
|
||||
"session_id": task.session_id,
|
||||
"status": task.status,
|
||||
"tool_name": task.tool_name,
|
||||
"operation_id": task.operation_id,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# ========== External Completion Webhook ==========
|
||||
|
||||
|
||||
@router.post(
|
||||
"/operations/{operation_id}/complete",
|
||||
status_code=200,
|
||||
)
|
||||
async def complete_operation(
|
||||
operation_id: str,
|
||||
request: OperationCompleteRequest,
|
||||
x_api_key: str | None = Header(default=None),
|
||||
) -> dict:
|
||||
"""
|
||||
External completion webhook for long-running operations.
|
||||
|
||||
Called by Agent Generator (or other services) when an operation completes.
|
||||
This triggers the stream registry to publish completion and continue LLM generation.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID to complete.
|
||||
request: Completion payload with success status and result/error.
|
||||
x_api_key: Internal API key for authentication.
|
||||
|
||||
Returns:
|
||||
dict: Status of the completion.
|
||||
|
||||
Raises:
|
||||
HTTPException: If API key is invalid or operation not found.
|
||||
"""
|
||||
# Validate internal API key - reject if not configured or invalid
|
||||
if not config.internal_api_key:
|
||||
logger.error(
|
||||
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Webhook not available: internal API key not configured",
|
||||
)
|
||||
if x_api_key != config.internal_api_key:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
# Find task by operation_id
|
||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Operation {operation_id} not found",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Received completion webhook for operation {operation_id} "
|
||||
f"(task_id={task.task_id}, success={request.success})"
|
||||
)
|
||||
|
||||
if request.success:
|
||||
await process_operation_success(task, request.result)
|
||||
else:
|
||||
await process_operation_failure(task, request.error)
|
||||
|
||||
return {"status": "ok", "task_id": task.task_id}
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
@router.get("/config/ttl", status_code=200)
|
||||
async def get_ttl_config() -> dict:
|
||||
"""
|
||||
Get the stream TTL configuration.
|
||||
|
||||
Returns the Time-To-Live settings for chat streams, which determines
|
||||
how long clients can reconnect to an active stream.
|
||||
|
||||
Returns:
|
||||
dict: TTL configuration with seconds and milliseconds values.
|
||||
"""
|
||||
return {
|
||||
"stream_ttl_seconds": config.stream_ttl,
|
||||
"stream_ttl_ms": config.stream_ttl * 1000,
|
||||
}
|
||||
|
||||
|
||||
# ========== Health Check ==========
|
||||
|
||||
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
"""Shared helpers for chat tools."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def get_inputs_from_schema(
|
||||
input_schema: dict[str, Any],
|
||||
exclude_fields: set[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Extract input field info from JSON schema."""
|
||||
if not isinstance(input_schema, dict):
|
||||
return []
|
||||
|
||||
exclude = exclude_fields or set()
|
||||
properties = input_schema.get("properties", {})
|
||||
required = set(input_schema.get("required", []))
|
||||
|
||||
return [
|
||||
{
|
||||
"name": name,
|
||||
"title": schema.get("title", name),
|
||||
"type": schema.get("type", "string"),
|
||||
"description": schema.get("description", ""),
|
||||
"required": name in required,
|
||||
"default": schema.get("default"),
|
||||
}
|
||||
for name, schema in properties.items()
|
||||
if name not in exclude
|
||||
]
|
||||
@@ -24,7 +24,6 @@ from backend.util.timezone_utils import (
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .helpers import get_inputs_from_schema
|
||||
from .models import (
|
||||
AgentDetails,
|
||||
AgentDetailsResponse,
|
||||
@@ -262,7 +261,7 @@ class RunAgentTool(BaseTool):
|
||||
),
|
||||
requirements={
|
||||
"credentials": requirements_creds_list,
|
||||
"inputs": get_inputs_from_schema(graph.input_schema),
|
||||
"inputs": self._get_inputs_list(graph.input_schema),
|
||||
"execution_modes": self._get_execution_modes(graph),
|
||||
},
|
||||
),
|
||||
@@ -370,6 +369,22 @@ class RunAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Extract inputs list from schema."""
|
||||
inputs_list = []
|
||||
if isinstance(input_schema, dict) and "properties" in input_schema:
|
||||
for field_name, field_schema in input_schema["properties"].items():
|
||||
inputs_list.append(
|
||||
{
|
||||
"name": field_name,
|
||||
"title": field_schema.get("title", field_name),
|
||||
"type": field_schema.get("type", "string"),
|
||||
"description": field_schema.get("description", ""),
|
||||
"required": field_name in input_schema.get("required", []),
|
||||
}
|
||||
)
|
||||
return inputs_list
|
||||
|
||||
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
||||
"""Get available execution modes for the graph."""
|
||||
trigger_info = graph.trigger_setup_info
|
||||
@@ -383,7 +398,7 @@ class RunAgentTool(BaseTool):
|
||||
suffix: str,
|
||||
) -> str:
|
||||
"""Build a message describing available inputs for an agent."""
|
||||
inputs_list = get_inputs_from_schema(graph.input_schema)
|
||||
inputs_list = self._get_inputs_list(graph.input_schema)
|
||||
required_names = [i["name"] for i in inputs_list if i["required"]]
|
||||
optional_names = [i["name"] for i in inputs_list if not i["required"]]
|
||||
|
||||
|
||||
@@ -10,13 +10,12 @@ from pydantic_core import PydanticUndefined
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.block import get_block
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .base import BaseTool
|
||||
from .helpers import get_inputs_from_schema
|
||||
from .models import (
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
@@ -25,10 +24,7 @@ from .models import (
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
from .utils import (
|
||||
build_missing_credentials_from_field_info,
|
||||
match_credentials_to_requirements,
|
||||
)
|
||||
from .utils import build_missing_credentials_from_field_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -77,22 +73,41 @@ class RunBlockTool(BaseTool):
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
def _resolve_discriminated_credentials(
|
||||
async def _check_block_credentials(
|
||||
self,
|
||||
user_id: str,
|
||||
block: Any,
|
||||
input_data: dict[str, Any],
|
||||
) -> dict[str, CredentialsFieldInfo]:
|
||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
if not credentials_fields_info:
|
||||
return {}
|
||||
input_data: dict[str, Any] | None = None,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Check if user has required credentials for a block.
|
||||
|
||||
resolved: dict[str, CredentialsFieldInfo] = {}
|
||||
Args:
|
||||
user_id: User ID
|
||||
block: Block to check credentials for
|
||||
input_data: Input data for the block (used to determine provider via discriminator)
|
||||
|
||||
Returns:
|
||||
tuple[matched_credentials, missing_credentials]
|
||||
"""
|
||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
||||
missing_credentials: list[CredentialsMetaInput] = []
|
||||
input_data = input_data or {}
|
||||
|
||||
# Get credential field info from block's input schema
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
|
||||
if not credentials_fields_info:
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
# Get user's available credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
effective_field_info = field_info
|
||||
|
||||
if field_info.discriminator and field_info.discriminator_mapping:
|
||||
# Get discriminator from input, falling back to schema default
|
||||
discriminator_value = input_data.get(field_info.discriminator)
|
||||
if discriminator_value is None:
|
||||
field = block.input_schema.model_fields.get(
|
||||
@@ -111,34 +126,37 @@ class RunBlockTool(BaseTool):
|
||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||
)
|
||||
|
||||
resolved[field_name] = effective_field_info
|
||||
matching_cred = next(
|
||||
(
|
||||
cred
|
||||
for cred in available_creds
|
||||
if cred.provider in effective_field_info.provider
|
||||
and cred.type in effective_field_info.supported_types
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return resolved
|
||||
if matching_cred:
|
||||
matched_credentials[field_name] = CredentialsMetaInput(
|
||||
id=matching_cred.id,
|
||||
provider=matching_cred.provider, # type: ignore
|
||||
type=matching_cred.type,
|
||||
title=matching_cred.title,
|
||||
)
|
||||
else:
|
||||
# Create a placeholder for the missing credential
|
||||
provider = next(iter(effective_field_info.provider), "unknown")
|
||||
cred_type = next(iter(effective_field_info.supported_types), "api_key")
|
||||
missing_credentials.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=field_name.replace("_", " ").title(),
|
||||
)
|
||||
)
|
||||
|
||||
async def _check_block_credentials(
|
||||
self,
|
||||
user_id: str,
|
||||
block: Any,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Check if user has required credentials for a block.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
block: Block to check credentials for
|
||||
input_data: Input data for the block (used to determine provider via discriminator)
|
||||
|
||||
Returns:
|
||||
tuple[matched_credentials, missing_credentials]
|
||||
"""
|
||||
input_data = input_data or {}
|
||||
requirements = self._resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
if not requirements:
|
||||
return {}, []
|
||||
|
||||
return await match_credentials_to_requirements(user_id, requirements)
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
@@ -329,6 +347,27 @@ class RunBlockTool(BaseTool):
|
||||
|
||||
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
||||
"""Extract non-credential inputs from block schema."""
|
||||
inputs_list = []
|
||||
schema = block.input_schema.jsonschema()
|
||||
properties = schema.get("properties", {})
|
||||
required_fields = set(schema.get("required", []))
|
||||
|
||||
# Get credential field names to exclude
|
||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
|
||||
inputs_list.append(
|
||||
{
|
||||
"name": field_name,
|
||||
"title": field_schema.get("title", field_name),
|
||||
"type": field_schema.get("type", "string"),
|
||||
"description": field_schema.get("description", ""),
|
||||
"required": field_name in required_fields,
|
||||
}
|
||||
)
|
||||
|
||||
return inputs_list
|
||||
|
||||
@@ -117,7 +117,7 @@ def build_missing_credentials_from_graph(
|
||||
preserving all supported credential types for each field.
|
||||
"""
|
||||
matched_keys = set(matched_credentials.keys()) if matched_credentials else set()
|
||||
aggregated_fields = graph.aggregate_credentials_inputs()
|
||||
aggregated_fields = graph.regular_credentials_inputs
|
||||
|
||||
return {
|
||||
field_key: _serialize_missing_credential(field_key, field_info)
|
||||
@@ -223,103 +223,6 @@ async def get_or_create_library_agent(
|
||||
return library_agents[0]
|
||||
|
||||
|
||||
async def get_user_credentials(user_id: str) -> list:
|
||||
"""Get all available credentials for a user."""
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
return await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
|
||||
def find_matching_credential(
|
||||
available_creds: list,
|
||||
field_info: CredentialsFieldInfo,
|
||||
check_scopes: bool = True,
|
||||
):
|
||||
"""Find a credential that matches the required provider, type, and optionally scopes."""
|
||||
for cred in available_creds:
|
||||
if cred.provider not in field_info.provider:
|
||||
continue
|
||||
if cred.type not in field_info.supported_types:
|
||||
continue
|
||||
if check_scopes and not _credential_has_required_scopes(cred, field_info):
|
||||
continue
|
||||
return cred
|
||||
return None
|
||||
|
||||
|
||||
def create_credential_meta_from_match(matching_cred) -> CredentialsMetaInput:
|
||||
"""Create a CredentialsMetaInput from a matched credential."""
|
||||
return CredentialsMetaInput(
|
||||
id=matching_cred.id,
|
||||
provider=matching_cred.provider, # type: ignore
|
||||
type=matching_cred.type,
|
||||
title=matching_cred.title,
|
||||
)
|
||||
|
||||
|
||||
async def match_credentials_to_requirements(
|
||||
user_id: str,
|
||||
requirements: dict[str, CredentialsFieldInfo],
|
||||
check_scopes: bool = True,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Match user's credentials against a dictionary of credential requirements.
|
||||
|
||||
This is the core matching logic shared by both graph and block credential matching.
|
||||
|
||||
Args:
|
||||
user_id: User ID to fetch credentials for
|
||||
requirements: Dict mapping field names to CredentialsFieldInfo
|
||||
check_scopes: Whether to verify OAuth2 scopes match requirements (default True).
|
||||
Set to False to preserve original run_block behavior which didn't check scopes.
|
||||
"""
|
||||
matched: dict[str, CredentialsMetaInput] = {}
|
||||
missing: list[CredentialsMetaInput] = []
|
||||
|
||||
if not requirements:
|
||||
return matched, missing
|
||||
|
||||
available_creds = await get_user_credentials(user_id)
|
||||
|
||||
for field_name, field_info in requirements.items():
|
||||
matching_cred = find_matching_credential(
|
||||
available_creds, field_info, check_scopes=check_scopes
|
||||
)
|
||||
|
||||
if matching_cred:
|
||||
try:
|
||||
matched[field_name] = create_credential_meta_from_match(matching_cred)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create CredentialsMetaInput for field '{field_name}': "
|
||||
f"provider={matching_cred.provider}, type={matching_cred.type}, "
|
||||
f"credential_id={matching_cred.id}",
|
||||
exc_info=True,
|
||||
)
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||
missing.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=f"{field_name} (validation failed: {e})",
|
||||
)
|
||||
)
|
||||
else:
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||
missing.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=field_name.replace("_", " ").title(),
|
||||
)
|
||||
)
|
||||
|
||||
return matched, missing
|
||||
|
||||
|
||||
async def match_user_credentials_to_graph(
|
||||
user_id: str,
|
||||
graph: GraphModel,
|
||||
@@ -341,7 +244,7 @@ async def match_user_credentials_to_graph(
|
||||
missing_creds: list[str] = []
|
||||
|
||||
# Get aggregated credentials requirements from the graph
|
||||
aggregated_creds = graph.aggregate_credentials_inputs()
|
||||
aggregated_creds = graph.regular_credentials_inputs
|
||||
logger.debug(
|
||||
f"Matching credentials for graph {graph.id}: {len(aggregated_creds)} required"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Tests for chat tools utility functions."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.model import CredentialsFieldInfo
|
||||
|
||||
|
||||
def _make_regular_field() -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["github"],
|
||||
"credentials_types": ["api_key"],
|
||||
"is_auto_credential": False,
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
|
||||
|
||||
def test_build_missing_credentials_excludes_auto_creds():
|
||||
"""
|
||||
build_missing_credentials_from_graph() should use regular_credentials_inputs
|
||||
and thus exclude auto_credentials from the "missing" set.
|
||||
"""
|
||||
from backend.api.features.chat.tools.utils import (
|
||||
build_missing_credentials_from_graph,
|
||||
)
|
||||
|
||||
regular_field = _make_regular_field()
|
||||
|
||||
mock_graph = MagicMock()
|
||||
# regular_credentials_inputs should only return the non-auto field
|
||||
mock_graph.regular_credentials_inputs = {
|
||||
"github_api_key": (regular_field, {("node-1", "credentials")}, True),
|
||||
}
|
||||
|
||||
result = build_missing_credentials_from_graph(mock_graph, matched_credentials=None)
|
||||
|
||||
# Should include the regular credential
|
||||
assert "github_api_key" in result
|
||||
# Should NOT include the auto_credential (not in regular_credentials_inputs)
|
||||
assert "google_oauth2" not in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_match_user_credentials_excludes_auto_creds():
|
||||
"""
|
||||
match_user_credentials_to_graph() should use regular_credentials_inputs
|
||||
and thus exclude auto_credentials from matching.
|
||||
"""
|
||||
from backend.api.features.chat.tools.utils import match_user_credentials_to_graph
|
||||
|
||||
regular_field = _make_regular_field()
|
||||
|
||||
mock_graph = MagicMock()
|
||||
mock_graph.id = "test-graph"
|
||||
# regular_credentials_inputs returns only non-auto fields
|
||||
mock_graph.regular_credentials_inputs = {
|
||||
"github_api_key": (regular_field, {("node-1", "credentials")}, True),
|
||||
}
|
||||
|
||||
# Mock the credentials manager to return no credentials
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.utils.IntegrationCredentialsManager"
|
||||
) as MockCredsMgr:
|
||||
mock_store = AsyncMock()
|
||||
mock_store.get_all_creds.return_value = []
|
||||
MockCredsMgr.return_value.store = mock_store
|
||||
|
||||
matched, missing = await match_user_credentials_to_graph(
|
||||
user_id="test-user", graph=mock_graph
|
||||
)
|
||||
|
||||
# No credentials available, so github should be missing
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
assert "github_api_key" in missing[0]
|
||||
@@ -1103,7 +1103,7 @@ async def create_preset_from_graph_execution(
|
||||
raise NotFoundError(
|
||||
f"Graph #{graph_execution.graph_id} not found or accessible"
|
||||
)
|
||||
elif len(graph.aggregate_credentials_inputs()) > 0:
|
||||
elif len(graph.regular_credentials_inputs) > 0:
|
||||
raise ValueError(
|
||||
f"Graph execution #{graph_exec_id} can't be turned into a preset "
|
||||
"because it was run before this feature existed "
|
||||
|
||||
@@ -478,7 +478,7 @@ class ExaCreateOrFindWebsetBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
try:
|
||||
webset = await aexa.websets.get(id=input_data.external_id)
|
||||
webset = aexa.websets.get(id=input_data.external_id)
|
||||
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
|
||||
|
||||
yield "webset", webset_result
|
||||
@@ -494,7 +494,7 @@ class ExaCreateOrFindWebsetBlock(Block):
|
||||
count=input_data.search_count,
|
||||
)
|
||||
|
||||
webset = await aexa.websets.create(
|
||||
webset = aexa.websets.create(
|
||||
params=CreateWebsetParameters(
|
||||
search=search_params,
|
||||
external_id=input_data.external_id,
|
||||
@@ -554,7 +554,7 @@ class ExaUpdateWebsetBlock(Block):
|
||||
if input_data.metadata is not None:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)
|
||||
sdk_webset = aexa.websets.update(id=input_data.webset_id, params=payload)
|
||||
|
||||
status_str = (
|
||||
sdk_webset.status.value
|
||||
@@ -617,7 +617,7 @@ class ExaListWebsetsBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
response = await aexa.websets.list(
|
||||
response = aexa.websets.list(
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
@@ -678,7 +678,7 @@ class ExaGetWebsetBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
sdk_webset = aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
status_str = (
|
||||
sdk_webset.status.value
|
||||
@@ -748,7 +748,7 @@ class ExaDeleteWebsetBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_webset = await aexa.websets.delete(id=input_data.webset_id)
|
||||
deleted_webset = aexa.websets.delete(id=input_data.webset_id)
|
||||
|
||||
status_str = (
|
||||
deleted_webset.status.value
|
||||
@@ -798,7 +798,7 @@ class ExaCancelWebsetBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)
|
||||
canceled_webset = aexa.websets.cancel(id=input_data.webset_id)
|
||||
|
||||
status_str = (
|
||||
canceled_webset.status.value
|
||||
@@ -968,7 +968,7 @@ class ExaPreviewWebsetBlock(Block):
|
||||
entity["description"] = input_data.entity_description
|
||||
payload["entity"] = entity
|
||||
|
||||
sdk_preview = await aexa.websets.preview(params=payload)
|
||||
sdk_preview = aexa.websets.preview(params=payload)
|
||||
|
||||
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
||||
|
||||
@@ -1051,7 +1051,7 @@ class ExaWebsetStatusBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
status = (
|
||||
webset.status.value
|
||||
@@ -1185,7 +1185,7 @@ class ExaWebsetSummaryBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
# Extract basic info
|
||||
webset_id = webset.id
|
||||
@@ -1211,7 +1211,7 @@ class ExaWebsetSummaryBlock(Block):
|
||||
total_items = 0
|
||||
|
||||
if input_data.include_sample_items and input_data.sample_size > 0:
|
||||
items_response = await aexa.websets.items.list(
|
||||
items_response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||
)
|
||||
sample_items_data = [
|
||||
@@ -1362,7 +1362,7 @@ class ExaWebsetReadyCheckBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Get webset details
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
status = (
|
||||
webset.status.value
|
||||
|
||||
@@ -202,7 +202,7 @@ class ExaCreateEnrichmentBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_enrichment = await aexa.websets.enrichments.create(
|
||||
sdk_enrichment = aexa.websets.enrichments.create(
|
||||
webset_id=input_data.webset_id, params=payload
|
||||
)
|
||||
|
||||
@@ -223,7 +223,7 @@ class ExaCreateEnrichmentBlock(Block):
|
||||
items_enriched = 0
|
||||
|
||||
while time.time() - poll_start < input_data.polling_timeout:
|
||||
current_enrich = await aexa.websets.enrichments.get(
|
||||
current_enrich = aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=enrichment_id
|
||||
)
|
||||
current_status = (
|
||||
@@ -234,7 +234,7 @@ class ExaCreateEnrichmentBlock(Block):
|
||||
|
||||
if current_status in ["completed", "failed", "cancelled"]:
|
||||
# Estimate items from webset searches
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
if webset.searches:
|
||||
for search in webset.searches:
|
||||
if search.progress:
|
||||
@@ -329,7 +329,7 @@ class ExaGetEnrichmentBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_enrichment = await aexa.websets.enrichments.get(
|
||||
sdk_enrichment = aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
@@ -474,7 +474,7 @@ class ExaDeleteEnrichmentBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_enrichment = await aexa.websets.enrichments.delete(
|
||||
deleted_enrichment = aexa.websets.enrichments.delete(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
@@ -525,13 +525,13 @@ class ExaCancelEnrichmentBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
canceled_enrichment = await aexa.websets.enrichments.cancel(
|
||||
canceled_enrichment = aexa.websets.enrichments.cancel(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
# Try to estimate how many items were enriched before cancellation
|
||||
items_enriched = 0
|
||||
items_response = await aexa.websets.items.list(
|
||||
items_response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id, limit=100
|
||||
)
|
||||
|
||||
|
||||
@@ -222,7 +222,7 @@ class ExaCreateImportBlock(Block):
|
||||
def _create_test_mock():
|
||||
"""Create test mocks for the AsyncExa SDK."""
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Create mock SDK import object
|
||||
mock_import = MagicMock()
|
||||
@@ -247,7 +247,7 @@ class ExaCreateImportBlock(Block):
|
||||
return {
|
||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||
websets=MagicMock(
|
||||
imports=MagicMock(create=AsyncMock(return_value=mock_import))
|
||||
imports=MagicMock(create=lambda *args, **kwargs: mock_import)
|
||||
)
|
||||
)
|
||||
}
|
||||
@@ -294,7 +294,7 @@ class ExaCreateImportBlock(Block):
|
||||
if input_data.metadata:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_import = await aexa.websets.imports.create(
|
||||
sdk_import = aexa.websets.imports.create(
|
||||
params=payload, csv_data=input_data.csv_data
|
||||
)
|
||||
|
||||
@@ -360,7 +360,7 @@ class ExaGetImportBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)
|
||||
sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)
|
||||
|
||||
import_obj = ImportModel.from_sdk(sdk_import)
|
||||
|
||||
@@ -426,7 +426,7 @@ class ExaListImportsBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
response = await aexa.websets.imports.list(
|
||||
response = aexa.websets.imports.list(
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
@@ -474,9 +474,7 @@ class ExaDeleteImportBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_import = await aexa.websets.imports.delete(
|
||||
import_id=input_data.import_id
|
||||
)
|
||||
deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)
|
||||
|
||||
yield "import_id", deleted_import.id
|
||||
yield "success", "true"
|
||||
@@ -575,14 +573,14 @@ class ExaExportWebsetBlock(Block):
|
||||
}
|
||||
)
|
||||
|
||||
# Create async iterator for list_all
|
||||
async def async_item_iterator(*args, **kwargs):
|
||||
for item in [mock_item1, mock_item2]:
|
||||
yield item
|
||||
# Create mock iterator
|
||||
mock_items = [mock_item1, mock_item2]
|
||||
|
||||
return {
|
||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||
websets=MagicMock(items=MagicMock(list_all=async_item_iterator))
|
||||
websets=MagicMock(
|
||||
items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
@@ -604,7 +602,7 @@ class ExaExportWebsetBlock(Block):
|
||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||
)
|
||||
|
||||
async for sdk_item in item_iterator:
|
||||
for sdk_item in item_iterator:
|
||||
if len(all_items) >= input_data.max_items:
|
||||
break
|
||||
|
||||
|
||||
@@ -178,7 +178,7 @@ class ExaGetWebsetItemBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_item = await aexa.websets.items.get(
|
||||
sdk_item = aexa.websets.items.get(
|
||||
webset_id=input_data.webset_id, id=input_data.item_id
|
||||
)
|
||||
|
||||
@@ -269,7 +269,7 @@ class ExaListWebsetItemsBlock(Block):
|
||||
response = None
|
||||
|
||||
while time.time() - start_time < input_data.wait_timeout:
|
||||
response = await aexa.websets.items.list(
|
||||
response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id,
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
@@ -282,13 +282,13 @@ class ExaListWebsetItemsBlock(Block):
|
||||
interval = min(interval * 1.2, 10)
|
||||
|
||||
if not response:
|
||||
response = await aexa.websets.items.list(
|
||||
response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id,
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
else:
|
||||
response = await aexa.websets.items.list(
|
||||
response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id,
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
@@ -340,7 +340,7 @@ class ExaDeleteWebsetItemBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_item = await aexa.websets.items.delete(
|
||||
deleted_item = aexa.websets.items.delete(
|
||||
webset_id=input_data.webset_id, id=input_data.item_id
|
||||
)
|
||||
|
||||
@@ -408,7 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
|
||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||
)
|
||||
|
||||
async for sdk_item in item_iterator:
|
||||
for sdk_item in item_iterator:
|
||||
if len(all_items) >= input_data.max_items:
|
||||
break
|
||||
|
||||
@@ -475,7 +475,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
entity_type = "unknown"
|
||||
if webset.searches:
|
||||
@@ -495,7 +495,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
||||
# Get sample items if requested
|
||||
sample_items: List[WebsetItemModel] = []
|
||||
if input_data.sample_size > 0:
|
||||
items_response = await aexa.websets.items.list(
|
||||
items_response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||
)
|
||||
# Convert to our stable models
|
||||
@@ -569,7 +569,7 @@ class ExaGetNewItemsBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Get items starting from cursor
|
||||
response = await aexa.websets.items.list(
|
||||
response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id,
|
||||
cursor=input_data.since_cursor,
|
||||
limit=input_data.max_items,
|
||||
|
||||
@@ -233,7 +233,7 @@ class ExaCreateMonitorBlock(Block):
|
||||
def _create_test_mock():
|
||||
"""Create test mocks for the AsyncExa SDK."""
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Create mock SDK monitor object
|
||||
mock_monitor = MagicMock()
|
||||
@@ -263,7 +263,7 @@ class ExaCreateMonitorBlock(Block):
|
||||
return {
|
||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||
websets=MagicMock(
|
||||
monitors=MagicMock(create=AsyncMock(return_value=mock_monitor))
|
||||
monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
|
||||
)
|
||||
)
|
||||
}
|
||||
@@ -320,7 +320,7 @@ class ExaCreateMonitorBlock(Block):
|
||||
if input_data.metadata:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_monitor = await aexa.websets.monitors.create(params=payload)
|
||||
sdk_monitor = aexa.websets.monitors.create(params=payload)
|
||||
|
||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||
|
||||
@@ -384,7 +384,7 @@ class ExaGetMonitorBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
||||
sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
||||
|
||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||
|
||||
@@ -476,7 +476,7 @@ class ExaUpdateMonitorBlock(Block):
|
||||
if input_data.metadata is not None:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_monitor = await aexa.websets.monitors.update(
|
||||
sdk_monitor = aexa.websets.monitors.update(
|
||||
monitor_id=input_data.monitor_id, params=payload
|
||||
)
|
||||
|
||||
@@ -522,9 +522,7 @@ class ExaDeleteMonitorBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_monitor = await aexa.websets.monitors.delete(
|
||||
monitor_id=input_data.monitor_id
|
||||
)
|
||||
deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)
|
||||
|
||||
yield "monitor_id", deleted_monitor.id
|
||||
yield "success", "true"
|
||||
@@ -581,7 +579,7 @@ class ExaListMonitorsBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
response = await aexa.websets.monitors.list(
|
||||
response = aexa.websets.monitors.list(
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
webset_id=input_data.webset_id,
|
||||
|
||||
@@ -121,7 +121,7 @@ class ExaWaitForWebsetBlock(Block):
|
||||
WebsetTargetStatus.IDLE,
|
||||
WebsetTargetStatus.ANY_COMPLETE,
|
||||
]:
|
||||
final_webset = await aexa.websets.wait_until_idle(
|
||||
final_webset = aexa.websets.wait_until_idle(
|
||||
id=input_data.webset_id,
|
||||
timeout=input_data.timeout,
|
||||
poll_interval=input_data.check_interval,
|
||||
@@ -164,7 +164,7 @@ class ExaWaitForWebsetBlock(Block):
|
||||
interval = input_data.check_interval
|
||||
while time.time() - start_time < input_data.timeout:
|
||||
# Get current webset status
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
current_status = (
|
||||
webset.status.value
|
||||
if hasattr(webset.status, "value")
|
||||
@@ -209,7 +209,7 @@ class ExaWaitForWebsetBlock(Block):
|
||||
|
||||
# Timeout reached
|
||||
elapsed = time.time() - start_time
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
final_status = (
|
||||
webset.status.value
|
||||
if hasattr(webset.status, "value")
|
||||
@@ -345,7 +345,7 @@ class ExaWaitForSearchBlock(Block):
|
||||
try:
|
||||
while time.time() - start_time < input_data.timeout:
|
||||
# Get current search status using SDK
|
||||
search = await aexa.websets.searches.get(
|
||||
search = aexa.websets.searches.get(
|
||||
webset_id=input_data.webset_id, id=input_data.search_id
|
||||
)
|
||||
|
||||
@@ -401,7 +401,7 @@ class ExaWaitForSearchBlock(Block):
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Get last known status
|
||||
search = await aexa.websets.searches.get(
|
||||
search = aexa.websets.searches.get(
|
||||
webset_id=input_data.webset_id, id=input_data.search_id
|
||||
)
|
||||
final_status = (
|
||||
@@ -503,7 +503,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
||||
try:
|
||||
while time.time() - start_time < input_data.timeout:
|
||||
# Get current enrichment status using SDK
|
||||
enrichment = await aexa.websets.enrichments.get(
|
||||
enrichment = aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
@@ -548,7 +548,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Get last known status
|
||||
enrichment = await aexa.websets.enrichments.get(
|
||||
enrichment = aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
final_status = (
|
||||
@@ -575,7 +575,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
||||
) -> tuple[list[SampleEnrichmentModel], int]:
|
||||
"""Get sample enriched data and count."""
|
||||
# Get a few items to see enrichment results using SDK
|
||||
response = await aexa.websets.items.list(webset_id=webset_id, limit=5)
|
||||
response = aexa.websets.items.list(webset_id=webset_id, limit=5)
|
||||
|
||||
sample_data: list[SampleEnrichmentModel] = []
|
||||
enriched_count = 0
|
||||
|
||||
@@ -317,7 +317,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
||||
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_search = await aexa.websets.searches.create(
|
||||
sdk_search = aexa.websets.searches.create(
|
||||
webset_id=input_data.webset_id, params=payload
|
||||
)
|
||||
|
||||
@@ -350,7 +350,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
||||
poll_start = time.time()
|
||||
|
||||
while time.time() - poll_start < input_data.polling_timeout:
|
||||
current_search = await aexa.websets.searches.get(
|
||||
current_search = aexa.websets.searches.get(
|
||||
webset_id=input_data.webset_id, id=search_id
|
||||
)
|
||||
current_status = (
|
||||
@@ -442,7 +442,7 @@ class ExaGetWebsetSearchBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_search = await aexa.websets.searches.get(
|
||||
sdk_search = aexa.websets.searches.get(
|
||||
webset_id=input_data.webset_id, id=input_data.search_id
|
||||
)
|
||||
|
||||
@@ -523,7 +523,7 @@ class ExaCancelWebsetSearchBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
canceled_search = await aexa.websets.searches.cancel(
|
||||
canceled_search = aexa.websets.searches.cancel(
|
||||
webset_id=input_data.webset_id, id=input_data.search_id
|
||||
)
|
||||
|
||||
@@ -604,7 +604,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Get webset to check existing searches
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
# Look for existing search with same query
|
||||
existing_search = None
|
||||
@@ -636,7 +636,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
||||
if input_data.entity_type != SearchEntityType.AUTO:
|
||||
payload["entity"] = {"type": input_data.entity_type.value}
|
||||
|
||||
sdk_search = await aexa.websets.searches.create(
|
||||
sdk_search = aexa.websets.searches.create(
|
||||
webset_id=input_data.webset_id, params=payload
|
||||
)
|
||||
|
||||
|
||||
@@ -596,10 +596,10 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
||||
|
||||
def get_parallel_tool_calls_param(
|
||||
llm_model: LlmModel, parallel_tool_calls: bool | None
|
||||
) -> bool | openai.Omit:
|
||||
):
|
||||
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
||||
if llm_model.startswith("o") or parallel_tool_calls is None:
|
||||
return openai.omit
|
||||
return openai.NOT_GIVEN
|
||||
return parallel_tool_calls
|
||||
|
||||
|
||||
|
||||
@@ -319,6 +319,8 @@ class BlockSchema(BaseModel):
|
||||
"credentials_provider": [config.get("provider", "google")],
|
||||
"credentials_types": [config.get("type", "oauth2")],
|
||||
"credentials_scopes": config.get("scopes"),
|
||||
"is_auto_credential": True,
|
||||
"input_field_name": info["field_name"],
|
||||
}
|
||||
result[kwarg_name] = CredentialsFieldInfo.model_validate(
|
||||
auto_schema, by_alias=True
|
||||
|
||||
@@ -447,8 +447,7 @@ class GraphModel(Graph, GraphMeta):
|
||||
@computed_field
|
||||
@property
|
||||
def credentials_input_schema(self) -> dict[str, Any]:
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
|
||||
graph_credentials_inputs = self.regular_credentials_inputs
|
||||
logger.debug(
|
||||
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
|
||||
f"{graph_credentials_inputs}"
|
||||
@@ -604,6 +603,28 @@ class GraphModel(Graph, GraphMeta):
|
||||
for key, (field_info, node_field_pairs) in combined.items()
|
||||
}
|
||||
|
||||
@property
|
||||
def regular_credentials_inputs(
|
||||
self,
|
||||
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
|
||||
"""Credentials that need explicit user mapping (CredentialsMetaInput fields)."""
|
||||
return {
|
||||
k: v
|
||||
for k, v in self.aggregate_credentials_inputs().items()
|
||||
if not v[0].is_auto_credential
|
||||
}
|
||||
|
||||
@property
|
||||
def auto_credentials_inputs(
|
||||
self,
|
||||
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
|
||||
"""Credentials embedded in file fields (_credentials_id), resolved at execution time."""
|
||||
return {
|
||||
k: v
|
||||
for k, v in self.aggregate_credentials_inputs().items()
|
||||
if v[0].is_auto_credential
|
||||
}
|
||||
|
||||
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
|
||||
"""
|
||||
Reassigns all IDs in the graph to new UUIDs.
|
||||
@@ -654,6 +675,16 @@ class GraphModel(Graph, GraphMeta):
|
||||
) and graph_id in graph_id_map:
|
||||
node.input_default["graph_id"] = graph_id_map[graph_id]
|
||||
|
||||
# Clear auto-credentials references (e.g., _credentials_id in
|
||||
# GoogleDriveFile fields) so the new user must re-authenticate
|
||||
# with their own account
|
||||
for node in graph.nodes:
|
||||
if not node.input_default:
|
||||
continue
|
||||
for key, value in node.input_default.items():
|
||||
if isinstance(value, dict) and "_credentials_id" in value:
|
||||
del value["_credentials_id"]
|
||||
|
||||
def validate_graph(
|
||||
self,
|
||||
for_run: bool = False,
|
||||
|
||||
@@ -463,3 +463,328 @@ def test_node_credentials_optional_with_other_metadata():
|
||||
assert node.credentials_optional is True
|
||||
assert node.metadata["position"] == {"x": 100, "y": 200}
|
||||
assert node.metadata["customized_name"] == "My Custom Node"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for _reassign_ids credential clearing (Fix 3: SECRT-1772)
|
||||
def test_combine_preserves_is_auto_credential_flag():
|
||||
"""
|
||||
CredentialsFieldInfo.combine() must propagate is_auto_credential and
|
||||
input_field_name to the combined result. Regression test for reviewer
|
||||
finding that combine() dropped these fields.
|
||||
"""
|
||||
from backend.data.model import CredentialsFieldInfo
|
||||
|
||||
auto_field = CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["google"],
|
||||
"credentials_types": ["oauth2"],
|
||||
"credentials_scopes": ["drive.readonly"],
|
||||
"is_auto_credential": True,
|
||||
"input_field_name": "spreadsheet",
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
|
||||
# combine() takes *args of (field_info, key) tuples
|
||||
combined = CredentialsFieldInfo.combine(
|
||||
(auto_field, ("node-1", "credentials")),
|
||||
(auto_field, ("node-2", "credentials")),
|
||||
)
|
||||
|
||||
assert len(combined) == 1
|
||||
group_key = next(iter(combined))
|
||||
combined_info, combined_keys = combined[group_key]
|
||||
|
||||
assert combined_info.is_auto_credential is True
|
||||
assert combined_info.input_field_name == "spreadsheet"
|
||||
assert combined_keys == {("node-1", "credentials"), ("node-2", "credentials")}
|
||||
|
||||
|
||||
def test_combine_preserves_regular_credential_defaults():
|
||||
"""Regular credentials should have is_auto_credential=False after combine()."""
|
||||
from backend.data.model import CredentialsFieldInfo
|
||||
|
||||
regular_field = CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["github"],
|
||||
"credentials_types": ["api_key"],
|
||||
"is_auto_credential": False,
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
|
||||
combined = CredentialsFieldInfo.combine(
|
||||
(regular_field, ("node-1", "credentials")),
|
||||
)
|
||||
|
||||
group_key = next(iter(combined))
|
||||
combined_info, _ = combined[group_key]
|
||||
|
||||
assert combined_info.is_auto_credential is False
|
||||
assert combined_info.input_field_name is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_reassign_ids_clears_credentials_id():
|
||||
"""
|
||||
[SECRT-1772] _reassign_ids should clear _credentials_id from
|
||||
GoogleDriveFile-style input_default fields so forked agents
|
||||
don't retain the original creator's credential references.
|
||||
"""
|
||||
from backend.data.graph import GraphModel
|
||||
|
||||
node = Node(
|
||||
id="node-1",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={
|
||||
"spreadsheet": {
|
||||
"_credentials_id": "original-cred-id",
|
||||
"id": "file-123",
|
||||
"name": "test.xlsx",
|
||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||
"url": "https://docs.google.com/spreadsheets/d/file-123",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
graph = Graph(
|
||||
id="test-graph",
|
||||
name="Test",
|
||||
description="Test",
|
||||
nodes=[node],
|
||||
links=[],
|
||||
)
|
||||
|
||||
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
|
||||
|
||||
# _credentials_id key should be removed (not set to None) so that
|
||||
# _acquire_auto_credentials correctly errors instead of treating it as chained data
|
||||
assert "_credentials_id" not in graph.nodes[0].input_default["spreadsheet"]
|
||||
|
||||
|
||||
def test_reassign_ids_preserves_non_credential_fields():
|
||||
"""
|
||||
Regression guard: _reassign_ids should NOT modify non-credential fields
|
||||
like name, mimeType, id, url.
|
||||
"""
|
||||
from backend.data.graph import GraphModel
|
||||
|
||||
node = Node(
|
||||
id="node-1",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={
|
||||
"spreadsheet": {
|
||||
"_credentials_id": "cred-abc",
|
||||
"id": "file-123",
|
||||
"name": "test.xlsx",
|
||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||
"url": "https://docs.google.com/spreadsheets/d/file-123",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
graph = Graph(
|
||||
id="test-graph",
|
||||
name="Test",
|
||||
description="Test",
|
||||
nodes=[node],
|
||||
links=[],
|
||||
)
|
||||
|
||||
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
|
||||
|
||||
field = graph.nodes[0].input_default["spreadsheet"]
|
||||
assert field["id"] == "file-123"
|
||||
assert field["name"] == "test.xlsx"
|
||||
assert field["mimeType"] == "application/vnd.google-apps.spreadsheet"
|
||||
assert field["url"] == "https://docs.google.com/spreadsheets/d/file-123"
|
||||
|
||||
|
||||
def test_reassign_ids_handles_no_credentials():
|
||||
"""
|
||||
Regression guard: _reassign_ids should not error when input_default
|
||||
has no dict fields with _credentials_id.
|
||||
"""
|
||||
from backend.data.graph import GraphModel
|
||||
|
||||
node = Node(
|
||||
id="node-1",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={
|
||||
"input": "some value",
|
||||
"another_input": 42,
|
||||
},
|
||||
)
|
||||
|
||||
graph = Graph(
|
||||
id="test-graph",
|
||||
name="Test",
|
||||
description="Test",
|
||||
nodes=[node],
|
||||
links=[],
|
||||
)
|
||||
|
||||
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
|
||||
|
||||
# Should not error, fields unchanged
|
||||
assert graph.nodes[0].input_default["input"] == "some value"
|
||||
assert graph.nodes[0].input_default["another_input"] == 42
|
||||
|
||||
|
||||
def test_reassign_ids_handles_multiple_credential_fields():
|
||||
"""
|
||||
[SECRT-1772] When a node has multiple dict fields with _credentials_id,
|
||||
ALL of them should be cleared.
|
||||
"""
|
||||
from backend.data.graph import GraphModel
|
||||
|
||||
node = Node(
|
||||
id="node-1",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={
|
||||
"spreadsheet": {
|
||||
"_credentials_id": "cred-1",
|
||||
"id": "file-1",
|
||||
"name": "file1.xlsx",
|
||||
},
|
||||
"doc_file": {
|
||||
"_credentials_id": "cred-2",
|
||||
"id": "file-2",
|
||||
"name": "file2.docx",
|
||||
},
|
||||
"plain_input": "not a dict",
|
||||
},
|
||||
)
|
||||
|
||||
graph = Graph(
|
||||
id="test-graph",
|
||||
name="Test",
|
||||
description="Test",
|
||||
nodes=[node],
|
||||
links=[],
|
||||
)
|
||||
|
||||
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
|
||||
|
||||
assert "_credentials_id" not in graph.nodes[0].input_default["spreadsheet"]
|
||||
assert "_credentials_id" not in graph.nodes[0].input_default["doc_file"]
|
||||
assert graph.nodes[0].input_default["plain_input"] == "not a dict"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for discriminate() field propagation
|
||||
def test_discriminate_preserves_is_auto_credential_flag():
|
||||
"""
|
||||
CredentialsFieldInfo.discriminate() must propagate is_auto_credential and
|
||||
input_field_name to the discriminated result. Regression test for
|
||||
discriminate() dropping these fields (same class of bug as combine()).
|
||||
"""
|
||||
from backend.data.model import CredentialsFieldInfo
|
||||
|
||||
auto_field = CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["google", "openai"],
|
||||
"credentials_types": ["oauth2"],
|
||||
"credentials_scopes": ["drive.readonly"],
|
||||
"is_auto_credential": True,
|
||||
"input_field_name": "spreadsheet",
|
||||
"discriminator": "model",
|
||||
"discriminator_mapping": {"gpt-4": "openai", "gemini": "google"},
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
|
||||
discriminated = auto_field.discriminate("gemini")
|
||||
|
||||
assert discriminated.is_auto_credential is True
|
||||
assert discriminated.input_field_name == "spreadsheet"
|
||||
assert discriminated.provider == frozenset(["google"])
|
||||
|
||||
|
||||
def test_discriminate_preserves_regular_credential_defaults():
|
||||
"""Regular credentials should have is_auto_credential=False after discriminate()."""
|
||||
from backend.data.model import CredentialsFieldInfo
|
||||
|
||||
regular_field = CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["google", "openai"],
|
||||
"credentials_types": ["api_key"],
|
||||
"is_auto_credential": False,
|
||||
"discriminator": "model",
|
||||
"discriminator_mapping": {"gpt-4": "openai", "gemini": "google"},
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
|
||||
discriminated = regular_field.discriminate("gpt-4")
|
||||
|
||||
assert discriminated.is_auto_credential is False
|
||||
assert discriminated.input_field_name is None
|
||||
assert discriminated.provider == frozenset(["openai"])
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for credentials_input_schema excluding auto_credentials
|
||||
def test_credentials_input_schema_excludes_auto_creds():
|
||||
"""
|
||||
GraphModel.credentials_input_schema should exclude auto_credentials
|
||||
(is_auto_credential=True) from the schema. Auto_credentials are
|
||||
transparently resolved at execution time via file picker data.
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
from backend.data.graph import GraphModel, NodeModel
|
||||
from backend.data.model import CredentialsFieldInfo
|
||||
|
||||
regular_field_info = CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["github"],
|
||||
"credentials_types": ["api_key"],
|
||||
"is_auto_credential": False,
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
|
||||
graph = GraphModel(
|
||||
id="test-graph",
|
||||
version=1,
|
||||
name="Test",
|
||||
description="Test",
|
||||
user_id="test-user",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
nodes=[
|
||||
NodeModel(
|
||||
id="node-1",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
graph_id="test-graph",
|
||||
graph_version=1,
|
||||
),
|
||||
],
|
||||
links=[],
|
||||
)
|
||||
|
||||
# Mock regular_credentials_inputs to return only the non-auto field (3-tuple)
|
||||
regular_only = {
|
||||
"github_credentials": (
|
||||
regular_field_info,
|
||||
{("node-1", "credentials")},
|
||||
True,
|
||||
),
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
type(graph),
|
||||
"regular_credentials_inputs",
|
||||
new_callable=PropertyMock,
|
||||
return_value=regular_only,
|
||||
):
|
||||
schema = graph.credentials_input_schema
|
||||
field_names = set(schema.get("properties", {}).keys())
|
||||
# Should include regular credential but NOT auto_credential
|
||||
assert "github_credentials" in field_names
|
||||
assert "google_credentials" not in field_names
|
||||
|
||||
@@ -571,6 +571,8 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
discriminator: Optional[str] = None
|
||||
discriminator_mapping: Optional[dict[str, CP]] = None
|
||||
discriminator_values: set[Any] = Field(default_factory=set)
|
||||
is_auto_credential: bool = False
|
||||
input_field_name: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def combine(
|
||||
@@ -651,6 +653,9 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
+ "_credentials"
|
||||
)
|
||||
|
||||
# Propagate is_auto_credential from the combined field.
|
||||
# All fields in a group should share the same is_auto_credential
|
||||
# value since auto and regular credentials serve different purposes.
|
||||
result[group_key] = (
|
||||
CredentialsFieldInfo[CP, CT](
|
||||
credentials_provider=combined.provider,
|
||||
@@ -659,6 +664,8 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
discriminator=combined.discriminator,
|
||||
discriminator_mapping=combined.discriminator_mapping,
|
||||
discriminator_values=set(all_discriminator_values),
|
||||
is_auto_credential=combined.is_auto_credential,
|
||||
input_field_name=combined.input_field_name,
|
||||
),
|
||||
combined_keys,
|
||||
)
|
||||
@@ -684,6 +691,8 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
discriminator=self.discriminator,
|
||||
discriminator_mapping=self.discriminator_mapping,
|
||||
discriminator_values=self.discriminator_values,
|
||||
is_auto_credential=self.is_auto_credential,
|
||||
input_field_name=self.input_field_name,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -172,6 +172,81 @@ def execute_graph(
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def _acquire_auto_credentials(
|
||||
input_model: type[BlockSchema],
|
||||
input_data: dict[str, Any],
|
||||
creds_manager: "IntegrationCredentialsManager",
|
||||
user_id: str,
|
||||
) -> tuple[dict[str, Any], list[AsyncRedisLock]]:
|
||||
"""
|
||||
Resolve auto_credentials from GoogleDriveFileField-style inputs.
|
||||
|
||||
Returns:
|
||||
(extra_exec_kwargs, locks): kwargs to inject into block execution, and
|
||||
credential locks to release after execution completes.
|
||||
"""
|
||||
extra_exec_kwargs: dict[str, Any] = {}
|
||||
locks: list[AsyncRedisLock] = []
|
||||
|
||||
# NOTE: If a block ever has multiple auto-credential fields, a ValueError
|
||||
# on a later field will strand locks acquired for earlier fields. They'll
|
||||
# auto-expire via Redis TTL, but add a try/except to release partial locks
|
||||
# if that becomes a real scenario.
|
||||
for kwarg_name, info in input_model.get_auto_credentials_fields().items():
|
||||
field_name = info["field_name"]
|
||||
field_data = input_data.get(field_name)
|
||||
|
||||
if field_data and isinstance(field_data, dict):
|
||||
# Check if _credentials_id key exists in the field data
|
||||
if "_credentials_id" in field_data:
|
||||
cred_id = field_data["_credentials_id"]
|
||||
if cred_id:
|
||||
# Credential ID provided - acquire credentials
|
||||
provider = info.get("config", {}).get(
|
||||
"provider", "external service"
|
||||
)
|
||||
file_name = field_data.get("name", "selected file")
|
||||
try:
|
||||
credentials, lock = await creds_manager.acquire(
|
||||
user_id, cred_id
|
||||
)
|
||||
locks.append(lock)
|
||||
extra_exec_kwargs[kwarg_name] = credentials
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"{provider.capitalize()} credentials for "
|
||||
f"'{file_name}' in field '{field_name}' are not "
|
||||
f"available in your account. "
|
||||
f"This can happen if the agent was created by another "
|
||||
f"user or the credentials were deleted. "
|
||||
f"Please open the agent in the builder and re-select "
|
||||
f"the file to authenticate with your own account."
|
||||
)
|
||||
# else: _credentials_id is explicitly None, skip (chained data)
|
||||
else:
|
||||
# _credentials_id key missing entirely - this is an error
|
||||
provider = info.get("config", {}).get("provider", "external service")
|
||||
file_name = field_data.get("name", "selected file")
|
||||
raise ValueError(
|
||||
f"Authentication missing for '{file_name}' in field "
|
||||
f"'{field_name}'. Please re-select the file to authenticate "
|
||||
f"with {provider.capitalize()}."
|
||||
)
|
||||
elif field_data is None and field_name not in input_data:
|
||||
# Field not in input_data at all = connected from upstream block, skip
|
||||
pass
|
||||
else:
|
||||
# field_data is None/empty but key IS in input_data = user didn't select
|
||||
provider = info.get("config", {}).get("provider", "external service")
|
||||
raise ValueError(
|
||||
f"No file selected for '{field_name}'. "
|
||||
f"Please select a file to provide "
|
||||
f"{provider.capitalize()} authentication."
|
||||
)
|
||||
|
||||
return extra_exec_kwargs, locks
|
||||
|
||||
|
||||
async def execute_node(
|
||||
node: Node,
|
||||
data: NodeExecutionEntry,
|
||||
@@ -271,41 +346,14 @@ async def execute_node(
|
||||
extra_exec_kwargs[field_name] = credentials
|
||||
|
||||
# Handle auto-generated credentials (e.g., from GoogleDriveFileInput)
|
||||
for kwarg_name, info in input_model.get_auto_credentials_fields().items():
|
||||
field_name = info["field_name"]
|
||||
field_data = input_data.get(field_name)
|
||||
if field_data and isinstance(field_data, dict):
|
||||
# Check if _credentials_id key exists in the field data
|
||||
if "_credentials_id" in field_data:
|
||||
cred_id = field_data["_credentials_id"]
|
||||
if cred_id:
|
||||
# Credential ID provided - acquire credentials
|
||||
provider = info.get("config", {}).get(
|
||||
"provider", "external service"
|
||||
)
|
||||
file_name = field_data.get("name", "selected file")
|
||||
try:
|
||||
credentials, lock = await creds_manager.acquire(
|
||||
user_id, cred_id
|
||||
)
|
||||
creds_locks.append(lock)
|
||||
extra_exec_kwargs[kwarg_name] = credentials
|
||||
except ValueError:
|
||||
# Credential was deleted or doesn't exist
|
||||
raise ValueError(
|
||||
f"Authentication expired for '{file_name}' in field '{field_name}'. "
|
||||
f"The saved {provider.capitalize()} credentials no longer exist. "
|
||||
f"Please re-select the file to re-authenticate."
|
||||
)
|
||||
# else: _credentials_id is explicitly None, skip credentials (for chained data)
|
||||
else:
|
||||
# _credentials_id key missing entirely - this is an error
|
||||
provider = info.get("config", {}).get("provider", "external service")
|
||||
file_name = field_data.get("name", "selected file")
|
||||
raise ValueError(
|
||||
f"Authentication missing for '{file_name}' in field '{field_name}'. "
|
||||
f"Please re-select the file to authenticate with {provider.capitalize()}."
|
||||
)
|
||||
auto_extra_kwargs, auto_locks = await _acquire_auto_credentials(
|
||||
input_model=input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=creds_manager,
|
||||
user_id=user_id,
|
||||
)
|
||||
extra_exec_kwargs.update(auto_extra_kwargs)
|
||||
creds_locks.extend(auto_locks)
|
||||
|
||||
output_size = 0
|
||||
|
||||
|
||||
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
Tests for auto_credentials handling in execute_node().
|
||||
|
||||
These test the _acquire_auto_credentials() helper function extracted from
|
||||
execute_node() (manager.py lines 273-308).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def google_drive_file_data():
|
||||
return {
|
||||
"valid": {
|
||||
"_credentials_id": "cred-id-123",
|
||||
"id": "file-123",
|
||||
"name": "test.xlsx",
|
||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||
},
|
||||
"chained": {
|
||||
"_credentials_id": None,
|
||||
"id": "file-456",
|
||||
"name": "chained.xlsx",
|
||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||
},
|
||||
"missing_key": {
|
||||
"id": "file-789",
|
||||
"name": "bad.xlsx",
|
||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_input_model(mocker: MockerFixture):
|
||||
"""Create a mock input model with get_auto_credentials_fields() returning one field."""
|
||||
input_model = mocker.MagicMock()
|
||||
input_model.get_auto_credentials_fields.return_value = {
|
||||
"credentials": {
|
||||
"field_name": "spreadsheet",
|
||||
"config": {
|
||||
"provider": "google",
|
||||
"type": "oauth2",
|
||||
"scopes": ["https://www.googleapis.com/auth/drive.readonly"],
|
||||
},
|
||||
}
|
||||
}
|
||||
return input_model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_creds_manager(mocker: MockerFixture):
|
||||
manager = mocker.AsyncMock()
|
||||
mock_lock = mocker.AsyncMock()
|
||||
mock_creds = mocker.MagicMock()
|
||||
mock_creds.id = "cred-id-123"
|
||||
mock_creds.provider = "google"
|
||||
manager.acquire.return_value = (mock_creds, mock_lock)
|
||||
return manager, mock_creds, mock_lock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_happy_path(
|
||||
mocker: MockerFixture,
|
||||
google_drive_file_data,
|
||||
mock_input_model,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""When field_data has a valid _credentials_id, credentials should be acquired."""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, mock_creds, mock_lock = mock_creds_manager
|
||||
input_data = {"spreadsheet": google_drive_file_data["valid"]}
|
||||
|
||||
extra_kwargs, locks = await _acquire_auto_credentials(
|
||||
input_model=mock_input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
manager.acquire.assert_called_once_with("user-1", "cred-id-123")
|
||||
assert extra_kwargs["credentials"] == mock_creds
|
||||
assert mock_lock in locks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_field_none_static_raises(
|
||||
mocker: MockerFixture,
|
||||
mock_input_model,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""
|
||||
[THE BUG FIX TEST — OPEN-2895]
|
||||
When field_data is None and the key IS in input_data (user didn't select a file),
|
||||
should raise ValueError instead of silently skipping.
|
||||
"""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, _, _ = mock_creds_manager
|
||||
# Key is present but value is None = user didn't select a file
|
||||
input_data = {"spreadsheet": None}
|
||||
|
||||
with pytest.raises(ValueError, match="No file selected"):
|
||||
await _acquire_auto_credentials(
|
||||
input_model=mock_input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_field_absent_skips(
|
||||
mocker: MockerFixture,
|
||||
mock_input_model,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""
|
||||
When the field key is NOT in input_data at all (upstream connection),
|
||||
should skip without error.
|
||||
"""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, _, _ = mock_creds_manager
|
||||
# Key not present = connected from upstream block
|
||||
input_data = {}
|
||||
|
||||
extra_kwargs, locks = await _acquire_auto_credentials(
|
||||
input_model=mock_input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
manager.acquire.assert_not_called()
|
||||
assert "credentials" not in extra_kwargs
|
||||
assert locks == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_chained_cred_id_none(
|
||||
mocker: MockerFixture,
|
||||
google_drive_file_data,
|
||||
mock_input_model,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""
|
||||
When _credentials_id is explicitly None (chained data from upstream),
|
||||
should skip credential acquisition.
|
||||
"""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, _, _ = mock_creds_manager
|
||||
input_data = {"spreadsheet": google_drive_file_data["chained"]}
|
||||
|
||||
extra_kwargs, locks = await _acquire_auto_credentials(
|
||||
input_model=mock_input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
manager.acquire.assert_not_called()
|
||||
assert "credentials" not in extra_kwargs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_missing_cred_id_key_raises(
|
||||
mocker: MockerFixture,
|
||||
google_drive_file_data,
|
||||
mock_input_model,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""
|
||||
When _credentials_id key is missing entirely from field_data dict,
|
||||
should raise ValueError.
|
||||
"""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, _, _ = mock_creds_manager
|
||||
input_data = {"spreadsheet": google_drive_file_data["missing_key"]}
|
||||
|
||||
with pytest.raises(ValueError, match="Authentication missing"):
|
||||
await _acquire_auto_credentials(
|
||||
input_model=mock_input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_ownership_mismatch_error(
|
||||
mocker: MockerFixture,
|
||||
google_drive_file_data,
|
||||
mock_input_model,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""
|
||||
[SECRT-1772] When acquire() raises ValueError (credential belongs to another user),
|
||||
the error message should mention 'not available' (not 'expired').
|
||||
"""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, _, _ = mock_creds_manager
|
||||
manager.acquire.side_effect = ValueError(
|
||||
"Credentials #cred-id-123 for user #user-2 not found"
|
||||
)
|
||||
input_data = {"spreadsheet": google_drive_file_data["valid"]}
|
||||
|
||||
with pytest.raises(ValueError, match="not available in your account"):
|
||||
await _acquire_auto_credentials(
|
||||
input_model=mock_input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-2",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_deleted_credential_error(
|
||||
mocker: MockerFixture,
|
||||
google_drive_file_data,
|
||||
mock_input_model,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""
|
||||
[SECRT-1772] When acquire() raises ValueError (credential was deleted),
|
||||
the error message should mention 'not available' (not 'expired').
|
||||
"""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, _, _ = mock_creds_manager
|
||||
manager.acquire.side_effect = ValueError(
|
||||
"Credentials #cred-id-123 for user #user-1 not found"
|
||||
)
|
||||
input_data = {"spreadsheet": google_drive_file_data["valid"]}
|
||||
|
||||
with pytest.raises(ValueError, match="not available in your account"):
|
||||
await _acquire_auto_credentials(
|
||||
input_model=mock_input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_lock_appended(
|
||||
mocker: MockerFixture,
|
||||
google_drive_file_data,
|
||||
mock_input_model,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""Lock from acquire() should be included in returned locks list."""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, _, mock_lock = mock_creds_manager
|
||||
input_data = {"spreadsheet": google_drive_file_data["valid"]}
|
||||
|
||||
extra_kwargs, locks = await _acquire_auto_credentials(
|
||||
input_model=mock_input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert len(locks) == 1
|
||||
assert locks[0] is mock_lock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_multiple_fields(
|
||||
mocker: MockerFixture,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""When there are multiple auto_credentials fields, only valid ones should acquire."""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, mock_creds, mock_lock = mock_creds_manager
|
||||
|
||||
input_model = mocker.MagicMock()
|
||||
input_model.get_auto_credentials_fields.return_value = {
|
||||
"credentials": {
|
||||
"field_name": "spreadsheet",
|
||||
"config": {"provider": "google", "type": "oauth2"},
|
||||
},
|
||||
"credentials2": {
|
||||
"field_name": "doc_file",
|
||||
"config": {"provider": "google", "type": "oauth2"},
|
||||
},
|
||||
}
|
||||
|
||||
input_data = {
|
||||
"spreadsheet": {
|
||||
"_credentials_id": "cred-id-123",
|
||||
"id": "file-1",
|
||||
"name": "file1.xlsx",
|
||||
},
|
||||
"doc_file": {
|
||||
"_credentials_id": None,
|
||||
"id": "file-2",
|
||||
"name": "chained.doc",
|
||||
},
|
||||
}
|
||||
|
||||
extra_kwargs, locks = await _acquire_auto_credentials(
|
||||
input_model=input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
# Only the first field should have acquired credentials
|
||||
manager.acquire.assert_called_once_with("user-1", "cred-id-123")
|
||||
assert "credentials" in extra_kwargs
|
||||
assert "credentials2" not in extra_kwargs
|
||||
assert len(locks) == 1
|
||||
@@ -259,7 +259,8 @@ async def _validate_node_input_credentials(
|
||||
|
||||
# Find any fields of type CredentialsMetaInput
|
||||
credentials_fields = block.input_schema.get_credentials_fields()
|
||||
if not credentials_fields:
|
||||
auto_credentials_fields = block.input_schema.get_auto_credentials_fields()
|
||||
if not credentials_fields and not auto_credentials_fields:
|
||||
continue
|
||||
|
||||
# Track if any credential field is missing for this node
|
||||
@@ -339,6 +340,47 @@ async def _validate_node_input_credentials(
|
||||
] = "Invalid credentials: type/provider mismatch"
|
||||
continue
|
||||
|
||||
# Validate auto-credentials (GoogleDriveFileField-based)
|
||||
# These have _credentials_id embedded in the file field data
|
||||
if auto_credentials_fields:
|
||||
for _kwarg_name, info in auto_credentials_fields.items():
|
||||
field_name = info["field_name"]
|
||||
# Check input_default and nodes_input_masks for the field value
|
||||
field_value = node.input_default.get(field_name)
|
||||
if nodes_input_masks and node.id in nodes_input_masks:
|
||||
field_value = nodes_input_masks[node.id].get(
|
||||
field_name, field_value
|
||||
)
|
||||
|
||||
if field_value and isinstance(field_value, dict):
|
||||
if "_credentials_id" not in field_value:
|
||||
# Key removed (e.g., on fork) — needs re-auth
|
||||
has_missing_credentials = True
|
||||
credential_errors[node.id][field_name] = (
|
||||
"Authentication missing for the selected file. "
|
||||
"Please re-select the file to authenticate with "
|
||||
"your own account."
|
||||
)
|
||||
continue
|
||||
cred_id = field_value.get("_credentials_id")
|
||||
if cred_id and isinstance(cred_id, str):
|
||||
try:
|
||||
creds_store = get_integration_credentials_store()
|
||||
creds = await creds_store.get_creds_by_id(user_id, cred_id)
|
||||
except Exception as e:
|
||||
has_missing_credentials = True
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = f"Credentials not available: {e}"
|
||||
continue
|
||||
if not creds:
|
||||
has_missing_credentials = True
|
||||
credential_errors[node.id][field_name] = (
|
||||
"The saved credentials are not available "
|
||||
"for your account. Please re-select the file to "
|
||||
"authenticate with your own account."
|
||||
)
|
||||
|
||||
# If node has optional credentials and any are missing, mark for skipping
|
||||
# But only if there are no other errors for this node
|
||||
if (
|
||||
@@ -370,8 +412,9 @@ def make_node_credentials_input_map(
|
||||
"""
|
||||
result: dict[str, dict[str, JsonValue]] = {}
|
||||
|
||||
# Get aggregated credentials fields for the graph
|
||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||
# Only map regular credentials (not auto_credentials, which are resolved
|
||||
# at execution time from _credentials_id in file field data)
|
||||
graph_cred_inputs = graph.regular_credentials_inputs
|
||||
|
||||
for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
|
||||
# Best-effort map: skip missing items
|
||||
|
||||
@@ -907,3 +907,335 @@ async def test_stop_graph_execution_cascades_to_child_with_reviews(
|
||||
|
||||
# Verify both parent and child status updates
|
||||
assert mock_execution_db.update_graph_execution_stats.call_count >= 1
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for auto_credentials validation in _validate_node_input_credentials
|
||||
# (Fix 3: SECRT-1772 + Fix 4: Path 4)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_auto_creds_valid(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
[SECRT-1772] When a node has auto_credentials with a valid _credentials_id
|
||||
that exists in the store, validation should pass without errors.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-auto-creds"
|
||||
mock_node.credentials_optional = False
|
||||
mock_node.input_default = {
|
||||
"spreadsheet": {
|
||||
"_credentials_id": "valid-cred-id",
|
||||
"id": "file-123",
|
||||
"name": "test.xlsx",
|
||||
}
|
||||
}
|
||||
|
||||
mock_block = mocker.MagicMock()
|
||||
# No regular credentials fields
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {}
|
||||
# Has auto_credentials fields
|
||||
mock_block.input_schema.get_auto_credentials_fields.return_value = {
|
||||
"credentials": {
|
||||
"field_name": "spreadsheet",
|
||||
"config": {"provider": "google", "type": "oauth2"},
|
||||
}
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
# Mock the credentials store to return valid credentials
|
||||
mock_store = mocker.MagicMock()
|
||||
mock_creds = mocker.MagicMock()
|
||||
mock_creds.id = "valid-cred-id"
|
||||
mock_store.get_creds_by_id = mocker.AsyncMock(return_value=mock_creds)
|
||||
mocker.patch(
|
||||
"backend.executor.utils.get_integration_credentials_store",
|
||||
return_value=mock_store,
|
||||
)
|
||||
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
assert mock_node.id not in errors
|
||||
assert mock_node.id not in nodes_to_skip
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_auto_creds_missing(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
[SECRT-1772] When a node has auto_credentials with a _credentials_id
|
||||
that doesn't exist for the current user, validation should report an error.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-bad-auto-creds"
|
||||
mock_node.credentials_optional = False
|
||||
mock_node.input_default = {
|
||||
"spreadsheet": {
|
||||
"_credentials_id": "other-users-cred-id",
|
||||
"id": "file-123",
|
||||
"name": "test.xlsx",
|
||||
}
|
||||
}
|
||||
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {}
|
||||
mock_block.input_schema.get_auto_credentials_fields.return_value = {
|
||||
"credentials": {
|
||||
"field_name": "spreadsheet",
|
||||
"config": {"provider": "google", "type": "oauth2"},
|
||||
}
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
# Mock the credentials store to return None (cred not found for this user)
|
||||
mock_store = mocker.MagicMock()
|
||||
mock_store.get_creds_by_id = mocker.AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.executor.utils.get_integration_credentials_store",
|
||||
return_value=mock_store,
|
||||
)
|
||||
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="different-user",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
assert mock_node.id in errors
|
||||
assert "spreadsheet" in errors[mock_node.id]
|
||||
assert "not available" in errors[mock_node.id]["spreadsheet"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_both_regular_and_auto(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
[SECRT-1772] A node that has BOTH regular credentials AND auto_credentials
|
||||
should have both validated.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-both-creds"
|
||||
mock_node.credentials_optional = False
|
||||
mock_node.input_default = {
|
||||
"credentials": {
|
||||
"id": "regular-cred-id",
|
||||
"provider": "github",
|
||||
"type": "api_key",
|
||||
},
|
||||
"spreadsheet": {
|
||||
"_credentials_id": "auto-cred-id",
|
||||
"id": "file-123",
|
||||
"name": "test.xlsx",
|
||||
},
|
||||
}
|
||||
|
||||
mock_credentials_field_type = mocker.MagicMock()
|
||||
mock_credentials_meta = mocker.MagicMock()
|
||||
mock_credentials_meta.id = "regular-cred-id"
|
||||
mock_credentials_meta.provider = "github"
|
||||
mock_credentials_meta.type = "api_key"
|
||||
mock_credentials_field_type.model_validate.return_value = mock_credentials_meta
|
||||
|
||||
mock_block = mocker.MagicMock()
|
||||
# Regular credentials field
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||
"credentials": mock_credentials_field_type,
|
||||
}
|
||||
# Auto-credentials field
|
||||
mock_block.input_schema.get_auto_credentials_fields.return_value = {
|
||||
"auto_credentials": {
|
||||
"field_name": "spreadsheet",
|
||||
"config": {"provider": "google", "type": "oauth2"},
|
||||
}
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
# Mock the credentials store to return valid credentials for both
|
||||
mock_store = mocker.MagicMock()
|
||||
mock_regular_creds = mocker.MagicMock()
|
||||
mock_regular_creds.id = "regular-cred-id"
|
||||
mock_regular_creds.provider = "github"
|
||||
mock_regular_creds.type = "api_key"
|
||||
|
||||
mock_auto_creds = mocker.MagicMock()
|
||||
mock_auto_creds.id = "auto-cred-id"
|
||||
|
||||
def get_creds_side_effect(user_id, cred_id):
|
||||
if cred_id == "regular-cred-id":
|
||||
return mock_regular_creds
|
||||
elif cred_id == "auto-cred-id":
|
||||
return mock_auto_creds
|
||||
return None
|
||||
|
||||
mock_store.get_creds_by_id = mocker.AsyncMock(side_effect=get_creds_side_effect)
|
||||
mocker.patch(
|
||||
"backend.executor.utils.get_integration_credentials_store",
|
||||
return_value=mock_store,
|
||||
)
|
||||
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Both should validate successfully - no errors
|
||||
assert mock_node.id not in errors
|
||||
assert mock_node.id not in nodes_to_skip
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_auto_creds_skipped_when_none(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
When a node has auto_credentials but the field value has _credentials_id=None
|
||||
(e.g., from upstream connection), validation should skip it without error.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-chained-auto-creds"
|
||||
mock_node.credentials_optional = False
|
||||
mock_node.input_default = {
|
||||
"spreadsheet": {
|
||||
"_credentials_id": None,
|
||||
"id": "file-123",
|
||||
"name": "test.xlsx",
|
||||
}
|
||||
}
|
||||
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {}
|
||||
mock_block.input_schema.get_auto_credentials_fields.return_value = {
|
||||
"credentials": {
|
||||
"field_name": "spreadsheet",
|
||||
"config": {"provider": "google", "type": "oauth2"},
|
||||
}
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# No error - chained data with None cred_id is valid
|
||||
assert mock_node.id not in errors
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for CredentialsFieldInfo auto_credential tag (Fix 4: Path 4)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_credentials_field_info_auto_credential_tag():
|
||||
"""
|
||||
[Path 4] CredentialsFieldInfo should support is_auto_credential and
|
||||
input_field_name fields for distinguishing auto from regular credentials.
|
||||
"""
|
||||
from backend.data.model import CredentialsFieldInfo
|
||||
|
||||
# Regular credential should have is_auto_credential=False by default
|
||||
regular = CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["github"],
|
||||
"credentials_types": ["api_key"],
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
assert regular.is_auto_credential is False
|
||||
assert regular.input_field_name is None
|
||||
|
||||
# Auto credential should have is_auto_credential=True
|
||||
auto = CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["google"],
|
||||
"credentials_types": ["oauth2"],
|
||||
"is_auto_credential": True,
|
||||
"input_field_name": "spreadsheet",
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
assert auto.is_auto_credential is True
|
||||
assert auto.input_field_name == "spreadsheet"
|
||||
|
||||
|
||||
def test_make_node_credentials_input_map_excludes_auto_creds(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
[Path 4] make_node_credentials_input_map should only include regular credentials,
|
||||
not auto_credentials (which are resolved at execution time).
|
||||
"""
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.executor.utils import make_node_credentials_input_map
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# Create a mock graph with aggregate_credentials_inputs that returns
|
||||
# both regular and auto credentials
|
||||
mock_graph = mocker.MagicMock()
|
||||
|
||||
regular_field_info = CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["github"],
|
||||
"credentials_types": ["api_key"],
|
||||
"is_auto_credential": False,
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
|
||||
# Mock regular_credentials_inputs property (auto_credentials are excluded)
|
||||
mock_graph.regular_credentials_inputs = {
|
||||
"github_creds": (regular_field_info, {("node-1", "credentials")}, True),
|
||||
}
|
||||
|
||||
graph_credentials_input = {
|
||||
"github_creds": CredentialsMetaInput(
|
||||
id="cred-123",
|
||||
provider=ProviderName("github"),
|
||||
type="api_key",
|
||||
),
|
||||
}
|
||||
|
||||
result = make_node_credentials_input_map(mock_graph, graph_credentials_input)
|
||||
|
||||
# Regular credentials should be mapped
|
||||
assert "node-1" in result
|
||||
assert "credentials" in result["node-1"]
|
||||
|
||||
# Auto credentials should NOT appear in the result
|
||||
# (they would have been mapped to the kwarg_name "credentials" not "spreadsheet")
|
||||
for node_id, fields in result.items():
|
||||
for field_name, value in fields.items():
|
||||
# Verify no auto-credential phantom entries
|
||||
if isinstance(value, dict):
|
||||
assert "_credentials_id" not in value
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
"""Validation utilities."""
|
||||
|
||||
import re
|
||||
|
||||
_UUID_V4_PATTERN = re.compile(
|
||||
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def is_uuid_v4(text: str) -> bool:
|
||||
return bool(_UUID_V4_PATTERN.fullmatch(text.strip()))
|
||||
|
||||
|
||||
def extract_uuids(text: str) -> list[str]:
|
||||
return sorted({m.lower() for m in _UUID_V4_PATTERN.findall(text)})
|
||||
6751
autogpt_platform/backend/poetry.lock
generated
6751
autogpt_platform/backend/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -21,7 +21,7 @@ cryptography = "^45.0"
|
||||
discord-py = "^2.5.2"
|
||||
e2b-code-interpreter = "^1.5.2"
|
||||
elevenlabs = "^1.50.0"
|
||||
fastapi = "^0.128.0"
|
||||
fastapi = "^0.116.1"
|
||||
feedparser = "^6.0.11"
|
||||
flake8 = "^7.3.0"
|
||||
google-api-python-client = "^2.177.0"
|
||||
@@ -35,7 +35,7 @@ jinja2 = "^3.1.6"
|
||||
jsonref = "^1.1.0"
|
||||
jsonschema = "^4.25.0"
|
||||
langfuse = "^3.11.0"
|
||||
launchdarkly-server-sdk = "^9.14.1"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
mem0ai = "^0.1.115"
|
||||
moviepy = "^2.1.2"
|
||||
ollama = "^0.5.1"
|
||||
@@ -52,8 +52,8 @@ prometheus-client = "^0.22.1"
|
||||
prometheus-fastapi-instrumentator = "^7.0.0"
|
||||
psutil = "^7.0.0"
|
||||
psycopg2-binary = "^2.9.10"
|
||||
pydantic = { extras = ["email"], version = "^2.12.5" }
|
||||
pydantic-settings = "^2.12.0"
|
||||
pydantic = { extras = ["email"], version = "^2.11.7" }
|
||||
pydantic-settings = "^2.10.1"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
python-dotenv = "^1.1.1"
|
||||
@@ -65,11 +65,11 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal
|
||||
sqlalchemy = "^2.0.40"
|
||||
strenum = "^0.4.9"
|
||||
stripe = "^11.5.0"
|
||||
supabase = "2.27.2"
|
||||
supabase = "2.17.0"
|
||||
tenacity = "^9.1.2"
|
||||
todoist-api-python = "^2.1.7"
|
||||
tweepy = "^4.16.0"
|
||||
uvicorn = { extras = ["standard"], version = "^0.40.0" }
|
||||
uvicorn = { extras = ["standard"], version = "^0.35.0" }
|
||||
websockets = "^15.0"
|
||||
youtube-transcript-api = "^1.2.1"
|
||||
yt-dlp = "2025.12.08"
|
||||
|
||||
@@ -12307,9 +12307,7 @@
|
||||
"title": "Location"
|
||||
},
|
||||
"msg": { "type": "string", "title": "Message" },
|
||||
"type": { "type": "string", "title": "Error Type" },
|
||||
"input": { "title": "Input" },
|
||||
"ctx": { "type": "object", "title": "Context" }
|
||||
"type": { "type": "string", "title": "Error Type" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["loc", "msg", "type"],
|
||||
|
||||
@@ -4,7 +4,9 @@ import { loadScript } from "@/services/scripts/scripts";
|
||||
export async function loadGoogleAPIPicker(): Promise<void> {
|
||||
validateWindow();
|
||||
|
||||
await loadScript("https://apis.google.com/js/api.js");
|
||||
await loadScript("https://apis.google.com/js/api.js", {
|
||||
referrerPolicy: "no-referrer-when-downgrade",
|
||||
});
|
||||
|
||||
const googleAPI = window.gapi;
|
||||
if (!googleAPI) {
|
||||
@@ -27,7 +29,9 @@ export async function loadGoogleIdentityServices(): Promise<void> {
|
||||
throw new Error("Google Identity Services cannot load on server");
|
||||
}
|
||||
|
||||
await loadScript("https://accounts.google.com/gsi/client");
|
||||
await loadScript("https://accounts.google.com/gsi/client", {
|
||||
referrerPolicy: "no-referrer-when-downgrade",
|
||||
});
|
||||
|
||||
const google = window.google;
|
||||
if (!google?.accounts?.oauth2) {
|
||||
|
||||
Reference in New Issue
Block a user