mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
17 Commits
fix/copilo
...
feat/async
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ec7c7ebea2 | ||
|
|
8ef8bec14f | ||
|
|
9b3e25d98e | ||
|
|
0bc098acb1 | ||
|
|
d78e0ee122 | ||
|
|
0e72e1f5e7 | ||
|
|
163b0b3c9d | ||
|
|
ef42b17e3b | ||
|
|
a18ffd0b21 | ||
|
|
e40c8c70ce | ||
|
|
9cdcd6793f | ||
|
|
fc64f83331 | ||
|
|
7718c49f05 | ||
|
|
0a1591fce2 | ||
|
|
681bb7c2b4 | ||
|
|
0818cd6683 | ||
|
|
7a39bdfaf8 |
@@ -190,5 +190,8 @@ ZEROBOUNCE_API_KEY=
|
||||
POSTHOG_API_KEY=
|
||||
POSTHOG_HOST=https://eu.i.posthog.com
|
||||
|
||||
# Tally Form Integration (pre-populate business understanding on signup)
|
||||
TALLY_API_KEY=
|
||||
|
||||
# Other Services
|
||||
AUTOMOD_API_KEY=
|
||||
|
||||
@@ -2,23 +2,19 @@
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid as uuid_module
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
from uuid import uuid4
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.completion_handler import (
|
||||
process_operation_failure,
|
||||
process_operation_success,
|
||||
)
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.executor.utils import enqueue_copilot_task
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
@@ -46,9 +42,6 @@ from backend.copilot.tools.models import (
|
||||
InputValidationErrorResponse,
|
||||
NeedLoginResponse,
|
||||
NoResultsResponse,
|
||||
OperationInProgressResponse,
|
||||
OperationPendingResponse,
|
||||
OperationStartedResponse,
|
||||
SetupRequirementsResponse,
|
||||
SuggestedGoalResponse,
|
||||
UnderstandingUpdatedResponse,
|
||||
@@ -99,10 +92,8 @@ class CreateSessionResponse(BaseModel):
|
||||
class ActiveStreamInfo(BaseModel):
|
||||
"""Information about an active stream for reconnection."""
|
||||
|
||||
task_id: str
|
||||
turn_id: str
|
||||
last_message_id: str # Redis Stream message ID for resumption
|
||||
operation_id: str # Operation ID for completion tracking
|
||||
tool_name: str # Name of the tool being executed
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
@@ -132,12 +123,11 @@ class ListSessionsResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class OperationCompleteRequest(BaseModel):
|
||||
"""Request model for external completion webhook."""
|
||||
class CancelSessionResponse(BaseModel):
|
||||
"""Response model for the cancel session endpoint."""
|
||||
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
cancelled: bool
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
@@ -262,7 +252,7 @@ async def get_session(
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
If there's an active stream for this session, returns the task_id for reconnection.
|
||||
If there's an active stream for this session, returns active_stream info for reconnection.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
@@ -280,28 +270,21 @@ async def get_session(
|
||||
|
||||
# Check if there's an active stream for this session
|
||||
active_stream_info = None
|
||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
|
||||
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
|
||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||
)
|
||||
if active_task:
|
||||
# Filter out the in-progress assistant message from the session response.
|
||||
# The client will receive the complete assistant response through the SSE
|
||||
# stream replay instead, preventing duplicate content.
|
||||
if messages and messages[-1].get("role") == "assistant":
|
||||
messages = messages[:-1]
|
||||
|
||||
# Use "0-0" as last_message_id to replay the stream from the beginning.
|
||||
# Since we filtered out the cached assistant message, the client needs
|
||||
# the full stream to reconstruct the response.
|
||||
if active_session:
|
||||
# Keep the assistant message (including tool_calls) so the frontend can
|
||||
# render the correct tool UI (e.g. CreateAgent with mini game).
|
||||
# convertChatSessionToUiMessages handles isComplete=false by setting
|
||||
# tool parts without output to state "input-available".
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
task_id=active_task.task_id,
|
||||
last_message_id="0-0",
|
||||
operation_id=active_task.operation_id,
|
||||
tool_name=active_task.tool_name,
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
@@ -314,6 +297,51 @@ async def get_session(
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/cancel",
|
||||
status_code=200,
|
||||
)
|
||||
async def cancel_session_task(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CancelSessionResponse:
|
||||
"""Cancel the active streaming task for a session.
|
||||
|
||||
Publishes a cancel event to the executor via RabbitMQ FANOUT, then
|
||||
polls Redis until the task status flips from ``running`` or a timeout
|
||||
(5 s) is reached. Returns only after the cancellation is confirmed.
|
||||
"""
|
||||
await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
active_session, _ = await stream_registry.get_active_session(session_id, user_id)
|
||||
if not active_session:
|
||||
return CancelSessionResponse(cancelled=True, reason="no_active_session")
|
||||
|
||||
await enqueue_cancel_task(session_id)
|
||||
logger.info(f"[CANCEL] Published cancel for session ...{session_id[-8:]}")
|
||||
|
||||
# Poll until the executor confirms the task is no longer running.
|
||||
poll_interval = 0.5
|
||||
max_wait = 5.0
|
||||
waited = 0.0
|
||||
while waited < max_wait:
|
||||
await asyncio.sleep(poll_interval)
|
||||
waited += poll_interval
|
||||
session_state = await stream_registry.get_session(session_id)
|
||||
if session_state is None or session_state.status != "running":
|
||||
logger.info(
|
||||
f"[CANCEL] Session ...{session_id[-8:]} confirmed stopped "
|
||||
f"(status={session_state.status if session_state else 'gone'}) after {waited:.1f}s"
|
||||
)
|
||||
return CancelSessionResponse(cancelled=True)
|
||||
|
||||
logger.warning(
|
||||
f"[CANCEL] Session ...{session_id[-8:]} not confirmed after {max_wait}s, force-completing"
|
||||
)
|
||||
await stream_registry.mark_session_completed(session_id, error_message="Cancelled")
|
||||
return CancelSessionResponse(cancelled=True)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
@@ -331,16 +359,15 @@ async def stream_chat_post(
|
||||
- Tool execution results
|
||||
|
||||
The AI generation runs in a background task that continues even if the client disconnects.
|
||||
All chunks are written to Redis for reconnection support. If the client disconnects,
|
||||
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
|
||||
All chunks are written to a per-turn Redis stream for reconnection support. If the client
|
||||
disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
|
||||
containing the task_id for reconnection.
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
@@ -387,35 +414,35 @@ async def stream_chat_post(
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
task_id = str(uuid_module.uuid4())
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
log_meta["task_id"] = task_id
|
||||
turn_id = str(uuid4())
|
||||
log_meta["turn_id"] = turn_id
|
||||
|
||||
task_create_start = time.perf_counter()
|
||||
await stream_registry.create_task(
|
||||
task_id=task_id,
|
||||
session_create_start = time.perf_counter()
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id="chat_stream", # Not a tool call, but needed for the model
|
||||
tool_call_id="chat_stream",
|
||||
tool_name="chat",
|
||||
operation_id=operation_id,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
|
||||
f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
|
||||
"duration_ms": (time.perf_counter() - session_create_start) * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
await enqueue_copilot_task(
|
||||
task_id=task_id,
|
||||
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
|
||||
subscribe_from_id = "0-0"
|
||||
|
||||
await enqueue_copilot_turn(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
operation_id=operation_id,
|
||||
message=request.message,
|
||||
turn_id=turn_id,
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
)
|
||||
@@ -432,7 +459,7 @@ async def stream_chat_post(
|
||||
|
||||
event_gen_start = time_module.perf_counter()
|
||||
logger.info(
|
||||
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
|
||||
f"[TIMING] event_generator STARTED, turn={turn_id}, session={session_id}, "
|
||||
f"user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
@@ -440,11 +467,12 @@ async def stream_chat_post(
|
||||
first_chunk_yielded = False
|
||||
chunks_yielded = 0
|
||||
try:
|
||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
# Subscribe from the position we captured before enqueuing
|
||||
# This avoids replaying old messages while catching all new ones
|
||||
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
last_message_id="0-0", # Get all messages from the beginning
|
||||
last_message_id=subscribe_from_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
@@ -527,19 +555,19 @@ async def stream_chat_post(
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
task_id, subscriber_queue
|
||||
await stream_registry.unsubscribe_from_session(
|
||||
session_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||
f"Error unsubscribing from session {session_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
total_time = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
|
||||
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||
f"turn={turn_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
@@ -586,17 +614,22 @@ async def resume_session_stream(
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
active_task, _last_id = await stream_registry.get_active_task_for_session(
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
|
||||
if not active_task:
|
||||
if not active_session:
|
||||
return Response(status_code=204)
|
||||
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=active_task.task_id,
|
||||
# Subscribe from the beginning ("0-0") to replay all chunks for this turn.
|
||||
# This is necessary because hydrated messages filter out incomplete tool calls
|
||||
# to avoid "No tool invocation found" errors. The resume stream delivers
|
||||
# those tool calls fresh with proper SDK state.
|
||||
# The AI SDK's deduplication will handle any duplicate chunks.
|
||||
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
last_message_id="0-0", # Full replay so useChat rebuilds the message
|
||||
last_message_id="0-0",
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
@@ -632,12 +665,12 @@ async def resume_session_stream(
|
||||
logger.error(f"Error in resume stream for session {session_id}: {e}")
|
||||
finally:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
active_task.task_id, subscriber_queue
|
||||
await stream_registry.unsubscribe_from_session(
|
||||
session_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
|
||||
f"Error unsubscribing from session {active_session.session_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info(
|
||||
@@ -688,229 +721,6 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Task Streaming (SSE Reconnection) ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}/stream",
|
||||
)
|
||||
async def stream_task(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
last_message_id: str = Query(
|
||||
default="0-0",
|
||||
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Reconnect to a long-running task's SSE stream.
|
||||
|
||||
When a long-running operation (like agent generation) starts, the client
|
||||
receives a task_id. If the connection drops, the client can reconnect
|
||||
using this endpoint to resume receiving updates.
|
||||
|
||||
Args:
|
||||
task_id: The task ID from the operation_started response.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
|
||||
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
|
||||
"""
|
||||
# Check task existence and expiry before subscribing
|
||||
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
|
||||
|
||||
if error_code == "TASK_EXPIRED":
|
||||
raise HTTPException(
|
||||
status_code=410,
|
||||
detail={
|
||||
"code": "TASK_EXPIRED",
|
||||
"message": "This operation has expired. Please try again.",
|
||||
},
|
||||
)
|
||||
|
||||
if error_code == "TASK_NOT_FOUND":
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "TASK_NOT_FOUND",
|
||||
"message": f"Task {task_id} not found.",
|
||||
},
|
||||
)
|
||||
|
||||
# Validate ownership if task has an owner
|
||||
if task and task.user_id and user_id != task.user_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"code": "ACCESS_DENIED",
|
||||
"message": "You do not have access to this task.",
|
||||
},
|
||||
)
|
||||
|
||||
# Get subscriber queue from stream registry
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "TASK_NOT_FOUND",
|
||||
"message": f"Task {task_id} not found or access denied.",
|
||||
},
|
||||
)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Wait for next chunk with timeout for heartbeats
|
||||
chunk = await asyncio.wait_for(
|
||||
subscriber_queue.get(), timeout=heartbeat_interval
|
||||
)
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# Send heartbeat to keep connection alive
|
||||
yield StreamHeartbeat().to_sse()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}",
|
||||
)
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
) -> dict:
|
||||
"""
|
||||
Get the status of a long-running task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to check.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
|
||||
Returns:
|
||||
dict: Task status including task_id, status, tool_name, and operation_id.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If task_id is not found or user doesn't have access.
|
||||
"""
|
||||
task = await stream_registry.get_task(task_id)
|
||||
|
||||
if task is None:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
# Validate ownership - if task has an owner, requester must match
|
||||
if task.user_id and user_id != task.user_id:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
return {
|
||||
"task_id": task.task_id,
|
||||
"session_id": task.session_id,
|
||||
"status": task.status,
|
||||
"tool_name": task.tool_name,
|
||||
"operation_id": task.operation_id,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# ========== External Completion Webhook ==========
|
||||
|
||||
|
||||
@router.post(
|
||||
"/operations/{operation_id}/complete",
|
||||
status_code=200,
|
||||
)
|
||||
async def complete_operation(
|
||||
operation_id: str,
|
||||
request: OperationCompleteRequest,
|
||||
x_api_key: str | None = Header(default=None),
|
||||
) -> dict:
|
||||
"""
|
||||
External completion webhook for long-running operations.
|
||||
|
||||
Called by Agent Generator (or other services) when an operation completes.
|
||||
This triggers the stream registry to publish completion and continue LLM generation.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID to complete.
|
||||
request: Completion payload with success status and result/error.
|
||||
x_api_key: Internal API key for authentication.
|
||||
|
||||
Returns:
|
||||
dict: Status of the completion.
|
||||
|
||||
Raises:
|
||||
HTTPException: If API key is invalid or operation not found.
|
||||
"""
|
||||
# Validate internal API key - reject if not configured or invalid
|
||||
if not config.internal_api_key:
|
||||
logger.error(
|
||||
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Webhook not available: internal API key not configured",
|
||||
)
|
||||
if x_api_key != config.internal_api_key:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
# Find task by operation_id
|
||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Operation {operation_id} not found",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Received completion webhook for operation {operation_id} "
|
||||
f"(task_id={task.task_id}, success={request.success})"
|
||||
)
|
||||
|
||||
if request.success:
|
||||
await process_operation_success(task, request.result)
|
||||
else:
|
||||
await process_operation_failure(task, request.error)
|
||||
|
||||
return {"status": "ok", "task_id": task.task_id}
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
@@ -991,9 +801,6 @@ ToolResponseUnion = (
|
||||
| BlockOutputResponse
|
||||
| DocSearchResultsResponse
|
||||
| DocPageResponse
|
||||
| OperationStartedResponse
|
||||
| OperationPendingResponse
|
||||
| OperationInProgressResponse
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -126,6 +126,9 @@ v1_router = APIRouter()
|
||||
########################################################
|
||||
|
||||
|
||||
_tally_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/auth/user",
|
||||
summary="Get or create user",
|
||||
@@ -134,6 +137,24 @@ v1_router = APIRouter()
|
||||
)
|
||||
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
user = await get_or_create_user(user_data)
|
||||
|
||||
# Fire-and-forget: populate business understanding from Tally form.
|
||||
# We use created_at proximity instead of an is_new flag because
|
||||
# get_or_create_user is cached — a separate is_new return value would be
|
||||
# unreliable on repeated calls within the cache TTL.
|
||||
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
|
||||
if age_seconds < 30:
|
||||
try:
|
||||
from backend.data.tally import populate_understanding_from_tally
|
||||
|
||||
task = asyncio.create_task(
|
||||
populate_understanding_from_tally(user.id, user.email)
|
||||
)
|
||||
_tally_background_tasks.add(task)
|
||||
task.add_done_callback(_tally_background_tasks.discard)
|
||||
except Exception:
|
||||
logger.debug("Failed to start Tally population task", exc_info=True)
|
||||
|
||||
return user.model_dump()
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from io import BytesIO
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
@@ -43,6 +43,7 @@ def test_get_or_create_user_route(
|
||||
) -> None:
|
||||
"""Test get or create user endpoint"""
|
||||
mock_user = Mock()
|
||||
mock_user.created_at = datetime.now(timezone.utc)
|
||||
mock_user.model_dump.return_value = {
|
||||
"id": test_user_id,
|
||||
"email": "test@example.com",
|
||||
|
||||
@@ -42,10 +42,6 @@ import backend.integrations.webhooks.utils
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.copilot.completion_consumer import (
|
||||
start_completion_consumer,
|
||||
stop_completion_consumer,
|
||||
)
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
@@ -123,21 +119,9 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
# Start chat completion consumer for Redis Streams notifications
|
||||
try:
|
||||
await start_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not start chat completion consumer: {e}")
|
||||
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
# Stop chat completion consumer
|
||||
try:
|
||||
await stop_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error stopping chat completion consumer: {e}")
|
||||
|
||||
try:
|
||||
await shutdown_cloud_storage_handler()
|
||||
except Exception as e:
|
||||
|
||||
@@ -24,7 +24,7 @@ def run_processes(*processes: "AppProcess", **kwargs):
|
||||
# Run the last process in the foreground.
|
||||
processes[-1].start(background=False, **kwargs)
|
||||
finally:
|
||||
for process in processes:
|
||||
for process in reversed(processes):
|
||||
try:
|
||||
process.stop()
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,349 +0,0 @@
|
||||
"""Redis Streams consumer for operation completion messages.
|
||||
|
||||
This module provides a consumer (ChatCompletionConsumer) that listens for
|
||||
completion notifications (OperationCompleteMessage) from external services
|
||||
(like Agent Generator) and triggers the appropriate stream registry and
|
||||
chat service updates via process_operation_success/process_operation_failure.
|
||||
|
||||
Why Redis Streams instead of RabbitMQ?
|
||||
--------------------------------------
|
||||
While the project typically uses RabbitMQ for async task queues (e.g., execution
|
||||
queue), Redis Streams was chosen for chat completion notifications because:
|
||||
|
||||
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
|
||||
Streams (via stream_registry) for message persistence and replay. Using Redis
|
||||
Streams for completion notifications keeps all chat streaming infrastructure
|
||||
in one system, simplifying operations and reducing cross-system coordination.
|
||||
|
||||
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
|
||||
allowing consumers to replay missed messages after reconnection. This aligns
|
||||
with the SSE reconnection pattern where clients can resume from last_message_id.
|
||||
|
||||
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
|
||||
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
|
||||
recovering from dead consumers - ideal for the completion callback pattern.
|
||||
|
||||
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
|
||||
stream_registry) provides lower latency than an additional RabbitMQ hop.
|
||||
|
||||
5. **Atomicity with Task State**: Completion processing often needs to update
|
||||
task metadata stored in Redis. Keeping both in Redis enables simpler
|
||||
transactional semantics without distributed coordination.
|
||||
|
||||
The consumer uses Redis Streams with consumer groups for reliable message
|
||||
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
|
||||
stale pending messages from dead consumers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from pydantic import BaseModel
|
||||
from redis.exceptions import ResponseError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class OperationCompleteMessage(BaseModel):
|
||||
"""Message format for operation completion notifications."""
|
||||
|
||||
operation_id: str
|
||||
task_id: str
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionConsumer:
|
||||
"""Consumer for chat operation completion messages from Redis Streams.
|
||||
|
||||
Database operations are handled through the chat_db() accessor, which
|
||||
routes through DatabaseManager RPC when Prisma is not directly connected.
|
||||
|
||||
Uses Redis consumer groups to allow multiple platform pods to consume
|
||||
messages reliably with automatic redelivery on failure.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._consumer_task: asyncio.Task | None = None
|
||||
self._running = False
|
||||
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the completion consumer."""
|
||||
if self._running:
|
||||
logger.warning("Completion consumer already running")
|
||||
return
|
||||
|
||||
# Create consumer group if it doesn't exist
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.xgroup_create(
|
||||
config.stream_completion_name,
|
||||
config.stream_consumer_group,
|
||||
id="0",
|
||||
mkstream=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Created consumer group '{config.stream_consumer_group}' "
|
||||
f"on stream '{config.stream_completion_name}'"
|
||||
)
|
||||
except ResponseError as e:
|
||||
if "BUSYGROUP" in str(e):
|
||||
logger.debug(
|
||||
f"Consumer group '{config.stream_consumer_group}' already exists"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
self._running = True
|
||||
self._consumer_task = asyncio.create_task(self._consume_messages())
|
||||
logger.info(
|
||||
f"Chat completion consumer started (consumer: {self._consumer_name})"
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the completion consumer."""
|
||||
self._running = False
|
||||
|
||||
if self._consumer_task:
|
||||
self._consumer_task.cancel()
|
||||
try:
|
||||
await self._consumer_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._consumer_task = None
|
||||
|
||||
logger.info("Chat completion consumer stopped")
|
||||
|
||||
async def _consume_messages(self) -> None:
|
||||
"""Main message consumption loop with retry logic."""
|
||||
max_retries = 10
|
||||
retry_delay = 5 # seconds
|
||||
retry_count = 0
|
||||
block_timeout = 5000 # milliseconds
|
||||
|
||||
while self._running and retry_count < max_retries:
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
|
||||
# Reset retry count on successful connection
|
||||
retry_count = 0
|
||||
|
||||
while self._running:
|
||||
# First, claim any stale pending messages from dead consumers
|
||||
# Redis does NOT auto-redeliver pending messages; we must explicitly
|
||||
# claim them using XAUTOCLAIM
|
||||
try:
|
||||
claimed_result = await redis.xautoclaim(
|
||||
name=config.stream_completion_name,
|
||||
groupname=config.stream_consumer_group,
|
||||
consumername=self._consumer_name,
|
||||
min_idle_time=config.stream_claim_min_idle_ms,
|
||||
start_id="0-0",
|
||||
count=10,
|
||||
)
|
||||
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
|
||||
if claimed_result and len(claimed_result) >= 2:
|
||||
claimed_entries = claimed_result[1]
|
||||
if claimed_entries:
|
||||
logger.info(
|
||||
f"Claimed {len(claimed_entries)} stale pending messages"
|
||||
)
|
||||
for entry_id, data in claimed_entries:
|
||||
if not self._running:
|
||||
return
|
||||
await self._process_entry(redis, entry_id, data)
|
||||
except Exception as e:
|
||||
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
|
||||
|
||||
# Read new messages from the stream
|
||||
messages = await redis.xreadgroup(
|
||||
groupname=config.stream_consumer_group,
|
||||
consumername=self._consumer_name,
|
||||
streams={config.stream_completion_name: ">"},
|
||||
block=block_timeout,
|
||||
count=10,
|
||||
)
|
||||
|
||||
if not messages:
|
||||
continue
|
||||
|
||||
for stream_name, entries in messages:
|
||||
for entry_id, data in entries:
|
||||
if not self._running:
|
||||
return
|
||||
await self._process_entry(redis, entry_id, data)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Consumer cancelled")
|
||||
return
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
logger.error(
|
||||
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
if self._running and retry_count < max_retries:
|
||||
await asyncio.sleep(retry_delay)
|
||||
else:
|
||||
logger.error("Max retries reached, stopping consumer")
|
||||
return
|
||||
|
||||
async def _process_entry(
|
||||
self, redis: Any, entry_id: str, data: dict[str, Any]
|
||||
) -> None:
|
||||
"""Process a single stream entry and acknowledge it on success.
|
||||
|
||||
Args:
|
||||
redis: Redis client connection
|
||||
entry_id: The stream entry ID
|
||||
data: The entry data dict
|
||||
"""
|
||||
try:
|
||||
# Handle the message
|
||||
message_data = data.get("data")
|
||||
if message_data:
|
||||
await self._handle_message(
|
||||
message_data.encode()
|
||||
if isinstance(message_data, str)
|
||||
else message_data
|
||||
)
|
||||
|
||||
# Acknowledge the message after successful processing
|
||||
await redis.xack(
|
||||
config.stream_completion_name,
|
||||
config.stream_consumer_group,
|
||||
entry_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing completion message {entry_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Message remains in pending state and will be claimed by
|
||||
# XAUTOCLAIM after min_idle_time expires
|
||||
|
||||
async def _handle_message(self, body: bytes) -> None:
|
||||
"""Handle a completion message."""
|
||||
try:
|
||||
data = orjson.loads(body)
|
||||
message = OperationCompleteMessage(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse completion message: {e}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Received completion for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id}, success={message.success})"
|
||||
)
|
||||
|
||||
# Find task in registry
|
||||
task = await stream_registry.find_task_by_operation_id(message.operation_id)
|
||||
if task is None:
|
||||
task = await stream_registry.get_task(message.task_id)
|
||||
|
||||
if task is None:
|
||||
logger.warning(
|
||||
f"[COMPLETION] Task not found for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id})"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Found task: task_id={task.task_id}, "
|
||||
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
|
||||
)
|
||||
|
||||
# Guard against empty task fields
|
||||
if not task.task_id or not task.session_id or not task.tool_call_id:
|
||||
logger.error(
|
||||
f"[COMPLETION] Task has empty critical fields! "
|
||||
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
|
||||
f"tool_call_id={task.tool_call_id!r}"
|
||||
)
|
||||
return
|
||||
|
||||
if message.success:
|
||||
await self._handle_success(task, message)
|
||||
else:
|
||||
await self._handle_failure(task, message)
|
||||
|
||||
async def _handle_success(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle successful operation completion."""
|
||||
await process_operation_success(task, message.result)
|
||||
|
||||
async def _handle_failure(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle failed operation completion."""
|
||||
await process_operation_failure(task, message.error)
|
||||
|
||||
|
||||
# Module-level consumer instance
|
||||
_consumer: ChatCompletionConsumer | None = None
|
||||
|
||||
|
||||
async def start_completion_consumer() -> None:
|
||||
"""Start the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer is None:
|
||||
_consumer = ChatCompletionConsumer()
|
||||
await _consumer.start()
|
||||
|
||||
|
||||
async def stop_completion_consumer() -> None:
|
||||
"""Stop the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer:
|
||||
await _consumer.stop()
|
||||
_consumer = None
|
||||
|
||||
|
||||
async def publish_operation_complete(
|
||||
operation_id: str,
|
||||
task_id: str,
|
||||
success: bool,
|
||||
result: dict | str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""Publish an operation completion message to Redis Streams.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID that completed.
|
||||
task_id: The task ID associated with the operation.
|
||||
success: Whether the operation succeeded.
|
||||
result: The result data (for success).
|
||||
error: The error message (for failure).
|
||||
"""
|
||||
message = OperationCompleteMessage(
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
success=success,
|
||||
result=result,
|
||||
error=error,
|
||||
)
|
||||
|
||||
redis = await get_redis_async()
|
||||
await redis.xadd(
|
||||
config.stream_completion_name,
|
||||
{"data": message.model_dump_json()},
|
||||
maxlen=config.stream_max_length,
|
||||
)
|
||||
logger.info(f"Published completion for operation {operation_id}")
|
||||
@@ -1,329 +0,0 @@
|
||||
"""Shared completion handling for operation success and failure.
|
||||
|
||||
This module provides common logic for handling operation completion from both:
|
||||
- The Redis Streams consumer (completion_consumer.py)
|
||||
- The HTTP webhook endpoint (routes.py)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from backend.data.db_accessors import chat_db
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .response_model import StreamError, StreamToolOutputAvailable
|
||||
from .tools.models import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools that produce agent_json that needs to be saved to library
|
||||
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
|
||||
|
||||
# Keys that should be stripped from agent_json when returning in error responses
|
||||
SENSITIVE_KEYS = frozenset(
|
||||
{
|
||||
"api_key",
|
||||
"apikey",
|
||||
"api_secret",
|
||||
"password",
|
||||
"secret",
|
||||
"credentials",
|
||||
"credential",
|
||||
"token",
|
||||
"access_token",
|
||||
"refresh_token",
|
||||
"private_key",
|
||||
"privatekey",
|
||||
"auth",
|
||||
"authorization",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_agent_json(obj: Any) -> Any:
|
||||
"""Recursively sanitize agent_json by removing sensitive keys.
|
||||
|
||||
Args:
|
||||
obj: The object to sanitize (dict, list, or primitive)
|
||||
|
||||
Returns:
|
||||
Sanitized copy with sensitive keys removed/redacted
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
return {
|
||||
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
|
||||
for k, v in obj.items()
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
return [_sanitize_agent_json(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
class ToolMessageUpdateError(Exception):
|
||||
"""Raised when updating a tool message in the database fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
async def _update_tool_message(
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
content: str,
|
||||
) -> None:
|
||||
"""Update tool message in database using the chat_db accessor.
|
||||
|
||||
Routes through DatabaseManager RPC when Prisma is not directly
|
||||
connected (e.g. in the CoPilot Executor microservice).
|
||||
|
||||
Args:
|
||||
session_id: The session ID
|
||||
tool_call_id: The tool call ID to update
|
||||
content: The new content for the message
|
||||
|
||||
Raises:
|
||||
ToolMessageUpdateError: If the database update fails.
|
||||
"""
|
||||
try:
|
||||
updated = await chat_db().update_tool_message_content(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
new_content=content,
|
||||
)
|
||||
if not updated:
|
||||
raise ToolMessageUpdateError(
|
||||
f"No message found with tool_call_id="
|
||||
f"{tool_call_id} in session {session_id}"
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to update tool message: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise ToolMessageUpdateError(
|
||||
f"Failed to update tool message for tool call #{tool_call_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
|
||||
"""Serialize result to JSON string with sensible defaults.
|
||||
|
||||
Args:
|
||||
result: The result to serialize. Can be a dict, list, string,
|
||||
number, boolean, or None.
|
||||
|
||||
Returns:
|
||||
JSON string representation of the result. Returns '{"status": "completed"}'
|
||||
only when result is explicitly None.
|
||||
"""
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
if result is None:
|
||||
return '{"status": "completed"}'
|
||||
return orjson.dumps(result).decode("utf-8")
|
||||
|
||||
|
||||
async def _save_agent_from_result(
|
||||
result: dict[str, Any],
|
||||
user_id: str | None,
|
||||
tool_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Save agent to library if result contains agent_json.
|
||||
|
||||
Args:
|
||||
result: The result dict that may contain agent_json
|
||||
user_id: The user ID to save the agent for
|
||||
tool_name: The tool name (create_agent or edit_agent)
|
||||
|
||||
Returns:
|
||||
Updated result dict with saved agent details, or original result if no agent_json
|
||||
"""
|
||||
if not user_id:
|
||||
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
|
||||
return result
|
||||
|
||||
agent_json = result.get("agent_json")
|
||||
if not agent_json:
|
||||
logger.warning(
|
||||
f"[COMPLETION] {tool_name} completed but no agent_json in result"
|
||||
)
|
||||
return result
|
||||
|
||||
try:
|
||||
from .tools.agent_generator import save_agent_to_library
|
||||
|
||||
is_update = tool_name == "edit_agent"
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id, is_update=is_update
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
|
||||
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
|
||||
)
|
||||
|
||||
# Return a response similar to AgentSavedResponse
|
||||
return {
|
||||
"type": "agent_saved",
|
||||
"message": f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
"agent_id": created_graph.id,
|
||||
"agent_name": created_graph.name,
|
||||
"library_agent_id": library_agent.id,
|
||||
"library_agent_link": f"/library/agents/{library_agent.id}",
|
||||
"agent_page_link": f"/build?flowID={created_graph.id}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to save agent to library: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Return error but don't fail the whole operation
|
||||
# Sanitize agent_json to remove sensitive keys before returning
|
||||
return {
|
||||
"type": "error",
|
||||
"message": f"Agent was generated but failed to save: {str(e)}",
|
||||
"error": str(e),
|
||||
"agent_json": _sanitize_agent_json(agent_json),
|
||||
}
|
||||
|
||||
|
||||
async def process_operation_success(
|
||||
task: stream_registry.ActiveTask,
|
||||
result: dict | str | None,
|
||||
) -> None:
|
||||
"""Handle successful operation completion.
|
||||
|
||||
Publishes the result to the stream registry, updates the database,
|
||||
generates LLM continuation, and marks the task as completed.
|
||||
|
||||
Args:
|
||||
task: The active task that completed
|
||||
result: The result data from the operation
|
||||
|
||||
Raises:
|
||||
ToolMessageUpdateError: If the database update fails. The task
|
||||
will be marked as failed instead of completed.
|
||||
"""
|
||||
# For agent generation tools, save the agent to library
|
||||
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
|
||||
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
|
||||
|
||||
# Serialize result for output (only substitute default when result is exactly None)
|
||||
result_output = result if result is not None else {"status": "completed"}
|
||||
output_str = (
|
||||
result_output
|
||||
if isinstance(result_output, str)
|
||||
else orjson.dumps(result_output).decode("utf-8")
|
||||
)
|
||||
|
||||
# Publish result to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=task.tool_call_id,
|
||||
toolName=task.tool_name,
|
||||
output=output_str,
|
||||
success=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Update pending operation in database
|
||||
# If this fails, we must not continue to mark the task as completed
|
||||
result_str = serialize_result(result)
|
||||
try:
|
||||
await _update_tool_message(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
content=result_str,
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
# DB update failed - mark task as failed to avoid inconsistent state
|
||||
logger.error(
|
||||
f"[COMPLETION] DB update failed for task {task.task_id}, "
|
||||
"marking as failed instead of completed"
|
||||
)
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText="Failed to save operation result to database"),
|
||||
)
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
raise
|
||||
|
||||
# Generate LLM continuation with streaming
|
||||
try:
|
||||
await chat_service._generate_llm_continuation_with_streaming(
|
||||
session_id=task.session_id,
|
||||
user_id=task.user_id,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to generate LLM continuation: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Mark task as completed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
|
||||
)
|
||||
|
||||
|
||||
async def process_operation_failure(
|
||||
task: stream_registry.ActiveTask,
|
||||
error: str | None,
|
||||
) -> None:
|
||||
"""Handle failed operation completion.
|
||||
|
||||
Publishes the error to the stream registry, updates the database
|
||||
with the error response, and marks the task as failed.
|
||||
|
||||
Args:
|
||||
task: The active task that failed
|
||||
error: The error message from the operation
|
||||
"""
|
||||
error_msg = error or "Operation failed"
|
||||
|
||||
# Publish error to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText=error_msg),
|
||||
)
|
||||
|
||||
# Update pending operation with error
|
||||
# If this fails, we still continue to mark the task as failed
|
||||
error_response = ErrorResponse(
|
||||
message=error_msg,
|
||||
error=error,
|
||||
)
|
||||
try:
|
||||
await _update_tool_message(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
content=error_response.model_dump_json(),
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
# DB update failed - log but continue with cleanup
|
||||
logger.error(
|
||||
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
|
||||
"continuing with cleanup"
|
||||
)
|
||||
|
||||
# Mark task as failed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")
|
||||
@@ -27,7 +27,6 @@ class ChatConfig(BaseSettings):
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# Streaming Configuration
|
||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
description="Max retries for fallback path (SDK handles retries internally)",
|
||||
@@ -37,52 +36,29 @@ class ChatConfig(BaseSettings):
|
||||
default=30, description="Maximum number of agent schedules"
|
||||
)
|
||||
|
||||
# Long-running operation configuration
|
||||
long_running_operation_ttl: int = Field(
|
||||
default=600,
|
||||
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
||||
)
|
||||
|
||||
# Stream registry configuration for SSE reconnection
|
||||
stream_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL in seconds for stream data in Redis (1 hour)",
|
||||
)
|
||||
stream_lock_ttl: int = Field(
|
||||
default=120,
|
||||
description="TTL in seconds for stream lock (2 minutes). Short timeout allows "
|
||||
"reconnection after refresh/crash without long waits.",
|
||||
)
|
||||
stream_max_length: int = Field(
|
||||
default=10000,
|
||||
description="Maximum number of messages to store per stream",
|
||||
)
|
||||
|
||||
# Redis Streams configuration for completion consumer
|
||||
stream_completion_name: str = Field(
|
||||
default="chat:completions",
|
||||
description="Redis Stream name for operation completions",
|
||||
)
|
||||
stream_consumer_group: str = Field(
|
||||
default="chat_consumers",
|
||||
description="Consumer group name for completion stream",
|
||||
)
|
||||
stream_claim_min_idle_ms: int = Field(
|
||||
default=60000,
|
||||
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
|
||||
)
|
||||
|
||||
# Redis key prefixes for stream registry
|
||||
task_meta_prefix: str = Field(
|
||||
session_meta_prefix: str = Field(
|
||||
default="chat:task:meta:",
|
||||
description="Prefix for task metadata hash keys",
|
||||
description="Prefix for session metadata hash keys",
|
||||
)
|
||||
task_stream_prefix: str = Field(
|
||||
turn_stream_prefix: str = Field(
|
||||
default="chat:stream:",
|
||||
description="Prefix for task message stream keys",
|
||||
)
|
||||
task_op_prefix: str = Field(
|
||||
default="chat:task:op:",
|
||||
description="Prefix for operation ID to task ID mapping keys",
|
||||
)
|
||||
internal_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
|
||||
description="Prefix for turn message stream keys",
|
||||
)
|
||||
|
||||
# Langfuse Prompt Management Configuration
|
||||
@@ -154,14 +130,6 @@ class ChatConfig(BaseSettings):
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
return v
|
||||
|
||||
@field_validator("internal_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_internal_api_key(cls, v):
|
||||
"""Get internal API key from environment if not provided."""
|
||||
if v is None:
|
||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||
return v
|
||||
|
||||
@field_validator("use_claude_agent_sdk", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_agent_sdk(cls, v):
|
||||
|
||||
@@ -3,8 +3,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import (
|
||||
@@ -92,10 +93,9 @@ async def add_chat_message(
|
||||
function_call: dict[str, Any] | None = None,
|
||||
) -> ChatMessage:
|
||||
"""Add a message to a chat session."""
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput directly
|
||||
# because Prisma's TypedDict validation rejects optional fields set to None.
|
||||
# We only include fields that have values, then cast at the end.
|
||||
data: dict[str, Any] = {
|
||||
# Build ChatMessageCreateInput with only non-None values
|
||||
# (Prisma TypedDict rejects optional fields set to None)
|
||||
data: ChatMessageCreateInput = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": role,
|
||||
"sequence": sequence,
|
||||
@@ -123,7 +123,7 @@ async def add_chat_message(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
),
|
||||
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
|
||||
PrismaChatMessage.prisma().create(data=data),
|
||||
)
|
||||
return ChatMessage.from_db(message)
|
||||
|
||||
@@ -132,58 +132,93 @@ async def add_chat_messages_batch(
|
||||
session_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
start_sequence: int,
|
||||
) -> list[ChatMessage]:
|
||||
) -> int:
|
||||
"""Add multiple messages to a chat session in a batch.
|
||||
|
||||
Uses a transaction for atomicity - if any message creation fails,
|
||||
the entire batch is rolled back.
|
||||
Uses collision detection with retry: tries to create messages starting
|
||||
at start_sequence. If a unique constraint violation occurs (e.g., the
|
||||
streaming loop and long-running callback race), queries the latest
|
||||
sequence and retries with the correct offset. This avoids unnecessary
|
||||
upserts and DB queries in the common case (no collision).
|
||||
|
||||
Returns:
|
||||
Next sequence number for the next message to be inserted. This equals
|
||||
start_sequence + len(messages) and allows callers to update their
|
||||
counters even when collision detection adjusts start_sequence.
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
# No messages to add - return current count
|
||||
return start_sequence
|
||||
|
||||
created_messages = []
|
||||
max_retries = 5
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Single timestamp for all messages and session update
|
||||
now = datetime.now(UTC)
|
||||
|
||||
async with db.transaction() as tx:
|
||||
for i, msg in enumerate(messages):
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput
|
||||
# directly because Prisma's TypedDict validation rejects optional fields
|
||||
# set to None. We only include fields that have values, then cast.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
async with db.transaction() as tx:
|
||||
# Build all message data
|
||||
messages_data = []
|
||||
for i, msg in enumerate(messages):
|
||||
# Build ChatMessageCreateInput with only non-None values
|
||||
# (Prisma TypedDict rejects optional fields set to None)
|
||||
# Note: create_many doesn't support nested creates, use sessionId directly
|
||||
data: ChatMessageCreateInput = {
|
||||
"sessionId": session_id,
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
"createdAt": now,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma(tx).create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
created_messages.append(created)
|
||||
messages_data.append(data)
|
||||
|
||||
# Update session's updatedAt timestamp within the same transaction.
|
||||
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
|
||||
# separately via update_chat_session() after streaming completes.
|
||||
await PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
# Run create_many and session update in parallel within transaction
|
||||
# Both use the same timestamp for consistency
|
||||
await asyncio.gather(
|
||||
PrismaChatMessage.prisma(tx).create_many(data=messages_data),
|
||||
PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": now},
|
||||
),
|
||||
)
|
||||
|
||||
return [ChatMessage.from_db(m) for m in created_messages]
|
||||
# Return next sequence number for counter sync
|
||||
return start_sequence + len(messages)
|
||||
|
||||
except UniqueViolationError:
|
||||
if attempt < max_retries - 1:
|
||||
# Collision detected - query MAX(sequence)+1 and retry with correct offset
|
||||
logger.info(
|
||||
f"Collision detected for session {session_id} at sequence "
|
||||
f"{start_sequence}, querying DB for latest sequence"
|
||||
)
|
||||
start_sequence = await get_next_sequence(session_id)
|
||||
logger.info(
|
||||
f"Retrying batch insert with start_sequence={start_sequence}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# Max retries exceeded - propagate error
|
||||
raise
|
||||
|
||||
# Should never reach here due to raise in exception handler
|
||||
raise RuntimeError(f"Failed to insert messages after {max_retries} attempts")
|
||||
|
||||
|
||||
async def get_user_chat_sessions(
|
||||
@@ -237,10 +272,20 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
return False
|
||||
|
||||
|
||||
async def get_chat_session_message_count(session_id: str) -> int:
|
||||
"""Get the number of messages in a chat session."""
|
||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
||||
return count
|
||||
async def get_next_sequence(session_id: str) -> int:
|
||||
"""Get the next sequence number for a new message in this session.
|
||||
|
||||
Uses MAX(sequence) + 1 for robustness. Returns 0 if no messages exist.
|
||||
More robust than COUNT(*) because it's immune to deleted messages.
|
||||
|
||||
Optimized to select only the sequence column using raw SQL.
|
||||
The unique index on (sessionId, sequence) makes this query fast.
|
||||
"""
|
||||
results = await db.query_raw_with_schema(
|
||||
'SELECT "sequence" FROM {schema_prefix}"ChatMessage" WHERE "sessionId" = $1 ORDER BY "sequence" DESC LIMIT 1',
|
||||
session_id,
|
||||
)
|
||||
return 0 if not results else results[0]["sequence"] + 1
|
||||
|
||||
|
||||
async def update_tool_message_content(
|
||||
|
||||
@@ -25,7 +25,7 @@ from backend.util.process import AppProcess
|
||||
from backend.util.retry import continuous_retry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .processor import execute_copilot_task, init_worker
|
||||
from .processor import execute_copilot_turn, init_worker
|
||||
from .utils import (
|
||||
COPILOT_CANCEL_QUEUE_NAME,
|
||||
COPILOT_EXECUTION_QUEUE_NAME,
|
||||
@@ -181,13 +181,13 @@ class CoPilotExecutor(AppProcess):
|
||||
self._executor.shutdown(wait=False)
|
||||
|
||||
# Release any remaining locks
|
||||
for task_id, lock in list(self._task_locks.items()):
|
||||
for session_id, lock in list(self._task_locks.items()):
|
||||
try:
|
||||
lock.release()
|
||||
logger.info(f"[cleanup {pid}] Released lock for {task_id}")
|
||||
logger.info(f"[cleanup {pid}] Released lock for {session_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[cleanup {pid}] Failed to release lock for {task_id}: {e}"
|
||||
f"[cleanup {pid}] Failed to release lock for {session_id}: {e}"
|
||||
)
|
||||
|
||||
logger.info(f"[cleanup {pid}] Graceful shutdown completed")
|
||||
@@ -267,20 +267,20 @@ class CoPilotExecutor(AppProcess):
|
||||
):
|
||||
"""Handle cancel message from FANOUT exchange."""
|
||||
request = CancelCoPilotEvent.model_validate_json(body)
|
||||
task_id = request.task_id
|
||||
if not task_id:
|
||||
logger.warning("Cancel message missing 'task_id'")
|
||||
session_id = request.session_id
|
||||
if not session_id:
|
||||
logger.warning("Cancel message missing 'session_id'")
|
||||
return
|
||||
if task_id not in self.active_tasks:
|
||||
logger.debug(f"Cancel received for {task_id} but not active")
|
||||
if session_id not in self.active_tasks:
|
||||
logger.debug(f"Cancel received for {session_id} but not active")
|
||||
return
|
||||
|
||||
_, cancel_event = self.active_tasks[task_id]
|
||||
logger.info(f"Received cancel for {task_id}")
|
||||
_, cancel_event = self.active_tasks[session_id]
|
||||
logger.info(f"Received cancel for {session_id}")
|
||||
if not cancel_event.is_set():
|
||||
cancel_event.set()
|
||||
else:
|
||||
logger.debug(f"Cancel already set for {task_id}")
|
||||
logger.debug(f"Cancel already set for {session_id}")
|
||||
|
||||
def _handle_run_message(
|
||||
self,
|
||||
@@ -352,12 +352,12 @@ class CoPilotExecutor(AppProcess):
|
||||
ack_message(reject=True, requeue=False)
|
||||
return
|
||||
|
||||
task_id = entry.task_id
|
||||
session_id = entry.session_id
|
||||
|
||||
# Check for local duplicate - task is already running on this executor
|
||||
if task_id in self.active_tasks:
|
||||
# Check for local duplicate - session is already running on this executor
|
||||
if session_id in self.active_tasks:
|
||||
logger.warning(
|
||||
f"Task {task_id} already running locally, rejecting duplicate"
|
||||
f"Session {session_id} already running locally, rejecting duplicate"
|
||||
)
|
||||
ack_message(reject=True, requeue=False)
|
||||
return
|
||||
@@ -365,53 +365,53 @@ class CoPilotExecutor(AppProcess):
|
||||
# Try to acquire cluster-wide lock
|
||||
cluster_lock = ClusterLock(
|
||||
redis=redis.get_redis(),
|
||||
key=f"copilot:task:{task_id}:lock",
|
||||
key=f"copilot:session:{session_id}:lock",
|
||||
owner_id=self.executor_id,
|
||||
timeout=settings.config.cluster_lock_timeout,
|
||||
)
|
||||
current_owner = cluster_lock.try_acquire()
|
||||
if current_owner != self.executor_id:
|
||||
if current_owner is not None:
|
||||
logger.warning(f"Task {task_id} already running on pod {current_owner}")
|
||||
logger.warning(
|
||||
f"Session {session_id} already running on pod {current_owner}"
|
||||
)
|
||||
ack_message(reject=True, requeue=False)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Could not acquire lock for {task_id} - Redis unavailable"
|
||||
f"Could not acquire lock for {session_id} - Redis unavailable"
|
||||
)
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
# Execute the task
|
||||
try:
|
||||
self._task_locks[task_id] = cluster_lock
|
||||
self._task_locks[session_id] = cluster_lock
|
||||
|
||||
logger.info(
|
||||
f"Acquired cluster lock for {task_id}, executor_id={self.executor_id}"
|
||||
f"Acquired cluster lock for {session_id}, "
|
||||
f"executor_id={self.executor_id}"
|
||||
)
|
||||
|
||||
cancel_event = threading.Event()
|
||||
future = self.executor.submit(
|
||||
execute_copilot_task, entry, cancel_event, cluster_lock
|
||||
execute_copilot_turn, entry, cancel_event, cluster_lock
|
||||
)
|
||||
self.active_tasks[task_id] = (future, cancel_event)
|
||||
self.active_tasks[session_id] = (future, cancel_event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to setup execution for {task_id}: {e}")
|
||||
logger.warning(f"Failed to setup execution for {session_id}: {e}")
|
||||
cluster_lock.release()
|
||||
if task_id in self._task_locks:
|
||||
del self._task_locks[task_id]
|
||||
if session_id in self._task_locks:
|
||||
del self._task_locks[session_id]
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
self._update_metrics()
|
||||
|
||||
def on_run_done(f: Future):
|
||||
logger.info(f"Run completed for {task_id}")
|
||||
logger.info(f"Run completed for {session_id}")
|
||||
try:
|
||||
if exec_error := f.exception():
|
||||
logger.error(f"Execution for {task_id} failed: {exec_error}")
|
||||
# Don't requeue failed tasks - they've been marked as failed
|
||||
# in the stream registry. Requeuing would cause infinite retries
|
||||
# for deterministic failures.
|
||||
logger.error(f"Execution for {session_id} failed: {exec_error}")
|
||||
ack_message(reject=True, requeue=False)
|
||||
else:
|
||||
ack_message(reject=False, requeue=False)
|
||||
@@ -419,10 +419,10 @@ class CoPilotExecutor(AppProcess):
|
||||
logger.exception(f"Error in run completion callback: {e}")
|
||||
finally:
|
||||
# Release the cluster lock
|
||||
if task_id in self._task_locks:
|
||||
logger.info(f"Releasing cluster lock for {task_id}")
|
||||
self._task_locks[task_id].release()
|
||||
del self._task_locks[task_id]
|
||||
if session_id in self._task_locks:
|
||||
logger.info(f"Releasing cluster lock for {session_id}")
|
||||
self._task_locks[session_id].release()
|
||||
del self._task_locks[session_id]
|
||||
self._cleanup_completed_tasks()
|
||||
|
||||
future.add_done_callback(on_run_done)
|
||||
@@ -433,11 +433,11 @@ class CoPilotExecutor(AppProcess):
|
||||
"""Remove completed futures from active_tasks and update metrics."""
|
||||
completed_tasks = []
|
||||
with self._active_tasks_lock:
|
||||
for task_id, (future, _) in list(self.active_tasks.items()):
|
||||
for session_id, (future, _) in list(self.active_tasks.items()):
|
||||
if future.done():
|
||||
completed_tasks.append(task_id)
|
||||
self.active_tasks.pop(task_id, None)
|
||||
logger.info(f"Cleaned up completed task {task_id}")
|
||||
completed_tasks.append(session_id)
|
||||
self.active_tasks.pop(session_id, None)
|
||||
logger.info(f"Cleaned up completed session {session_id}")
|
||||
|
||||
self._update_metrics()
|
||||
return completed_tasks
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""CoPilot execution processor - per-worker execution logic.
|
||||
|
||||
This module contains the processor class that handles CoPilot task execution
|
||||
This module contains the processor class that handles CoPilot session execution
|
||||
in a thread-local context, following the graph executor pattern.
|
||||
"""
|
||||
|
||||
@@ -12,7 +12,7 @@ import time
|
||||
from backend.copilot import service as copilot_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep
|
||||
from backend.copilot.response_model import StreamFinish
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
from backend.util.decorator import error_logged
|
||||
@@ -32,17 +32,17 @@ logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]"
|
||||
_tls = threading.local()
|
||||
|
||||
|
||||
def execute_copilot_task(
|
||||
def execute_copilot_turn(
|
||||
entry: CoPilotExecutionEntry,
|
||||
cancel: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
):
|
||||
"""Execute a CoPilot task using the thread-local processor.
|
||||
"""Execute a single CoPilot turn (user message → AI response).
|
||||
|
||||
This function is the entry point called by the thread pool executor.
|
||||
|
||||
Args:
|
||||
entry: The task payload
|
||||
entry: The turn payload
|
||||
cancel: Threading event to signal cancellation
|
||||
cluster_lock: Distributed lock for this execution
|
||||
"""
|
||||
@@ -76,16 +76,16 @@ def cleanup_worker():
|
||||
|
||||
|
||||
class CoPilotProcessor:
|
||||
"""Per-worker execution logic for CoPilot tasks.
|
||||
"""Per-worker execution logic for CoPilot sessions.
|
||||
|
||||
This class is instantiated once per worker thread and handles the execution
|
||||
of CoPilot chat generation tasks. It maintains an async event loop for
|
||||
of CoPilot chat generation sessions. It maintains an async event loop for
|
||||
running the async service code.
|
||||
|
||||
The execution flow:
|
||||
1. CoPilot task is picked from RabbitMQ queue
|
||||
2. Manager submits task to thread pool
|
||||
3. Processor executes the task in its event loop
|
||||
1. Session entry is picked from RabbitMQ queue
|
||||
2. Manager submits to thread pool
|
||||
3. Processor executes in its event loop
|
||||
4. Results are published to Redis Streams
|
||||
"""
|
||||
|
||||
@@ -139,19 +139,17 @@ class CoPilotProcessor:
|
||||
cancel: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
):
|
||||
"""Execute a CoPilot task.
|
||||
"""Execute a CoPilot turn.
|
||||
|
||||
This is the main entry point for task execution. It runs the async
|
||||
execution logic in the worker's event loop and handles errors.
|
||||
Runs the async logic in the worker's event loop and handles errors.
|
||||
|
||||
Args:
|
||||
entry: The task payload containing session and message info
|
||||
entry: The turn payload containing session and message info
|
||||
cancel: Threading event to signal cancellation
|
||||
cluster_lock: Distributed lock to prevent duplicate execution
|
||||
"""
|
||||
log = CoPilotLogMetadata(
|
||||
logging.getLogger(__name__),
|
||||
task_id=entry.task_id,
|
||||
session_id=entry.session_id,
|
||||
user_id=entry.user_id,
|
||||
)
|
||||
@@ -185,11 +183,20 @@ class CoPilotProcessor:
|
||||
elapsed = time.monotonic() - start_time
|
||||
log.info(f"Execution completed in {elapsed:.2f}s")
|
||||
|
||||
except Exception as e:
|
||||
except BaseException as e:
|
||||
elapsed = time.monotonic() - start_time
|
||||
log.error(f"Execution failed after {elapsed:.2f}s: {e}")
|
||||
# Note: _execute_async already marks the task as failed before re-raising,
|
||||
# so we don't call _mark_task_failed here to avoid duplicate error events.
|
||||
# Safety net: if _execute_async's error handler failed to mark
|
||||
# the session (e.g. RuntimeError from SDK cleanup), do it here.
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
stream_registry.mark_session_completed(
|
||||
entry.session_id, error_message=str(e) or "Unknown error"
|
||||
),
|
||||
self.execution_loop,
|
||||
).result(timeout=5.0)
|
||||
except Exception as cleanup_err:
|
||||
log.error(f"Safety net mark_session_completed failed: {cleanup_err}")
|
||||
raise
|
||||
|
||||
async def _execute_async(
|
||||
@@ -199,16 +206,16 @@ class CoPilotProcessor:
|
||||
cluster_lock: ClusterLock,
|
||||
log: CoPilotLogMetadata,
|
||||
):
|
||||
"""Async execution logic for CoPilot task.
|
||||
"""Async execution logic for a CoPilot turn.
|
||||
|
||||
This method calls the existing stream_chat_completion service function
|
||||
and publishes results to the stream registry.
|
||||
Calls the stream_chat_completion service function and publishes
|
||||
results to the stream registry.
|
||||
|
||||
Args:
|
||||
entry: The task payload
|
||||
entry: The turn payload
|
||||
cancel: Threading event to signal cancellation
|
||||
cluster_lock: Distributed lock for refresh
|
||||
log: Structured logger for this task
|
||||
log: Structured logger
|
||||
"""
|
||||
last_refresh = time.monotonic()
|
||||
refresh_interval = 30.0 # Refresh lock every 30 seconds
|
||||
@@ -228,7 +235,7 @@ class CoPilotProcessor:
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'standard'} service")
|
||||
|
||||
# Stream chat completion and publish chunks to Redis
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
async for chunk in stream_fn(
|
||||
session_id=entry.session_id,
|
||||
message=entry.message if entry.message else None,
|
||||
@@ -236,52 +243,38 @@ class CoPilotProcessor:
|
||||
user_id=entry.user_id,
|
||||
context=entry.context,
|
||||
):
|
||||
# Check for cancellation
|
||||
if cancel.is_set():
|
||||
log.info("Cancelled during streaming")
|
||||
await stream_registry.publish_chunk(
|
||||
entry.task_id, StreamError(errorText="Operation cancelled")
|
||||
)
|
||||
await stream_registry.publish_chunk(
|
||||
entry.task_id, StreamFinishStep()
|
||||
)
|
||||
await stream_registry.publish_chunk(entry.task_id, StreamFinish())
|
||||
await stream_registry.mark_task_completed(
|
||||
entry.task_id, status="failed"
|
||||
)
|
||||
return
|
||||
log.info("Cancel requested, breaking stream")
|
||||
break
|
||||
|
||||
# Refresh cluster lock periodically
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_refresh >= refresh_interval:
|
||||
cluster_lock.refresh()
|
||||
last_refresh = current_time
|
||||
|
||||
# Publish chunk to stream registry
|
||||
await stream_registry.publish_chunk(entry.task_id, chunk)
|
||||
# Skip StreamFinish — mark_session_completed publishes it.
|
||||
if isinstance(chunk, StreamFinish):
|
||||
continue
|
||||
|
||||
# Mark task as completed
|
||||
await stream_registry.mark_task_completed(entry.task_id, status="completed")
|
||||
log.info("Task completed successfully")
|
||||
try:
|
||||
await stream_registry.publish_chunk(entry.turn_id, chunk)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Error publishing chunk {type(chunk).__name__}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
log.info("Task cancelled")
|
||||
await stream_registry.mark_task_completed(entry.task_id, status="failed")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Task failed: {e}")
|
||||
await self._mark_task_failed(entry.task_id, str(e))
|
||||
raise
|
||||
|
||||
async def _mark_task_failed(self, task_id: str, error_message: str):
|
||||
"""Mark a task as failed and publish error to stream registry."""
|
||||
try:
|
||||
await stream_registry.publish_chunk(
|
||||
task_id, StreamError(errorText=error_message)
|
||||
error_message = "Operation cancelled" if cancel.is_set() else None
|
||||
await stream_registry.mark_session_completed(
|
||||
entry.session_id, error_message=error_message
|
||||
)
|
||||
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to mark task {task_id} as failed: {e}")
|
||||
|
||||
except BaseException as e:
|
||||
log.error(f"Turn failed: {e}")
|
||||
try:
|
||||
await stream_registry.mark_session_completed(
|
||||
entry.session_id, error_message=str(e) or "Unknown error"
|
||||
)
|
||||
except Exception as mark_err:
|
||||
log.error(f"mark_session_completed also failed: {mark_err}")
|
||||
raise
|
||||
|
||||
@@ -28,7 +28,7 @@ class CoPilotLogMetadata(TruncatedLogger):
|
||||
Args:
|
||||
logger: The underlying logger instance
|
||||
max_length: Maximum log message length before truncation
|
||||
**kwargs: Metadata key-value pairs (e.g., task_id="abc", session_id="xyz")
|
||||
**kwargs: Metadata key-value pairs (e.g., session_id="xyz", turn_id="abc")
|
||||
These are added to json_fields in cloud mode, or to the prefix in local mode.
|
||||
"""
|
||||
|
||||
@@ -135,18 +135,15 @@ class CoPilotExecutionEntry(BaseModel):
|
||||
This model represents a chat generation task to be processed by the executor.
|
||||
"""
|
||||
|
||||
task_id: str
|
||||
"""Unique identifier for this task (used for stream registry)"""
|
||||
|
||||
session_id: str
|
||||
"""Chat session ID"""
|
||||
"""Chat session ID (also used for dedup/locking)"""
|
||||
|
||||
turn_id: str = ""
|
||||
"""Per-turn UUID for Redis stream isolation"""
|
||||
|
||||
user_id: str | None
|
||||
"""User ID (may be None for anonymous users)"""
|
||||
|
||||
operation_id: str
|
||||
"""Operation ID for webhook callbacks and completion tracking"""
|
||||
|
||||
message: str
|
||||
"""User's message to process"""
|
||||
|
||||
@@ -160,40 +157,37 @@ class CoPilotExecutionEntry(BaseModel):
|
||||
class CancelCoPilotEvent(BaseModel):
|
||||
"""Event to cancel a CoPilot operation."""
|
||||
|
||||
task_id: str
|
||||
"""Task ID to cancel"""
|
||||
session_id: str
|
||||
"""Session ID to cancel"""
|
||||
|
||||
|
||||
# ============ Queue Publishing Helpers ============ #
|
||||
|
||||
|
||||
async def enqueue_copilot_task(
|
||||
task_id: str,
|
||||
async def enqueue_copilot_turn(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
operation_id: str,
|
||||
message: str,
|
||||
turn_id: str,
|
||||
is_user_message: bool = True,
|
||||
context: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
Args:
|
||||
task_id: Unique identifier for this task (used for stream registry)
|
||||
session_id: Chat session ID
|
||||
session_id: Chat session ID (also used for dedup/locking)
|
||||
user_id: User ID (may be None for anonymous users)
|
||||
operation_id: Operation ID for webhook callbacks and completion tracking
|
||||
message: User's message to process
|
||||
turn_id: Per-turn UUID for Redis stream isolation
|
||||
is_user_message: Whether the message is from the user (vs system/assistant)
|
||||
context: Optional context for the message (e.g., {url: str, content: str})
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
entry = CoPilotExecutionEntry(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
user_id=user_id,
|
||||
operation_id=operation_id,
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
context=context,
|
||||
@@ -205,3 +199,20 @@ async def enqueue_copilot_task(
|
||||
message=entry.model_dump_json(),
|
||||
exchange=COPILOT_EXECUTION_EXCHANGE,
|
||||
)
|
||||
|
||||
|
||||
async def enqueue_cancel_task(session_id: str) -> None:
|
||||
"""Publish a cancel request for a running CoPilot session.
|
||||
|
||||
Sends a ``CancelCoPilotEvent`` to the FANOUT exchange so all executor
|
||||
pods receive the cancellation signal.
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
event = CancelCoPilotEvent(session_id=session_id)
|
||||
queue_client = await get_async_copilot_queue()
|
||||
await queue_client.publish_message(
|
||||
routing_key="", # FANOUT ignores routing key
|
||||
message=event.model_dump_json(),
|
||||
exchange=COPILOT_CANCEL_EXCHANGE,
|
||||
)
|
||||
|
||||
@@ -432,7 +432,9 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
return session
|
||||
|
||||
|
||||
async def upsert_chat_session(session: ChatSession) -> ChatSession:
|
||||
async def upsert_chat_session(
|
||||
session: ChatSession,
|
||||
) -> ChatSession:
|
||||
"""Update a chat session in both cache and database.
|
||||
|
||||
Uses session-level locking to prevent race conditions when concurrent
|
||||
@@ -449,16 +451,18 @@ async def upsert_chat_session(session: ChatSession) -> ChatSession:
|
||||
lock = await _get_session_lock(session.session_id)
|
||||
|
||||
async with lock:
|
||||
# Get existing message count from DB for incremental saves
|
||||
existing_message_count = await chat_db().get_chat_session_message_count(
|
||||
session.session_id
|
||||
)
|
||||
# Always query DB for existing message count to ensure consistency
|
||||
existing_message_count = await chat_db().get_next_sequence(session.session_id)
|
||||
|
||||
db_error: Exception | None = None
|
||||
|
||||
# Save to database (primary storage)
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
await _save_session_to_db(
|
||||
session,
|
||||
existing_message_count,
|
||||
skip_existence_check=existing_message_count > 0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save session {session.session_id} to database: {e}"
|
||||
@@ -489,21 +493,31 @@ async def upsert_chat_session(session: ChatSession) -> ChatSession:
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
session: ChatSession, existing_message_count: int
|
||||
session: ChatSession,
|
||||
existing_message_count: int,
|
||||
*,
|
||||
skip_existence_check: bool = False,
|
||||
) -> None:
|
||||
"""Save or update a chat session in the database."""
|
||||
"""Save or update a chat session in the database.
|
||||
|
||||
Args:
|
||||
skip_existence_check: When True, skip the ``get_chat_session`` query
|
||||
and assume the session row already exists. Saves one DB round trip
|
||||
for incremental saves during streaming.
|
||||
"""
|
||||
db = chat_db()
|
||||
|
||||
# Check if session exists in DB
|
||||
existing = await db.get_chat_session(session.session_id)
|
||||
if not skip_existence_check:
|
||||
# Check if session exists in DB
|
||||
existing = await db.get_chat_session(session.session_id)
|
||||
|
||||
if not existing:
|
||||
# Create new session
|
||||
await db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
if not existing:
|
||||
# Create new session
|
||||
await db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
# Calculate total tokens from usage
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
@@ -562,9 +576,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
session.messages.append(message)
|
||||
existing_message_count = await chat_db().get_chat_session_message_count(
|
||||
session_id
|
||||
)
|
||||
existing_message_count = await chat_db().get_next_sequence(session_id)
|
||||
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
|
||||
@@ -331,3 +331,96 @@ def test_to_openai_messages_merges_split_assistants():
|
||||
tc_list = merged.get("tool_calls")
|
||||
assert tc_list is not None and len(list(tc_list)) == 1
|
||||
assert list(tc_list)[0]["id"] == "tc1"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Concurrent save collision detection #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_concurrent_saves_collision_detection(setup_test_user, test_user_id):
|
||||
"""Test that concurrent saves from streaming loop and callback handle collisions correctly.
|
||||
|
||||
Simulates the race condition where:
|
||||
1. Streaming loop starts with saved_msg_count=5
|
||||
2. Long-running callback appends message #5 and saves
|
||||
3. Streaming loop tries to save with stale count=5
|
||||
|
||||
The collision detection should handle this gracefully.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# Create a session with initial messages
|
||||
session = ChatSession.new(user_id=test_user_id)
|
||||
for i in range(3):
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="user" if i % 2 == 0 else "assistant", content=f"Message {i}"
|
||||
)
|
||||
)
|
||||
|
||||
# Save initial messages
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Simulate streaming loop and callback saving concurrently
|
||||
async def streaming_loop_save():
|
||||
"""Simulates streaming loop saving messages."""
|
||||
# Add 2 messages
|
||||
session.messages.append(ChatMessage(role="user", content="Streaming message 1"))
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content="Streaming message 2")
|
||||
)
|
||||
|
||||
# Wait a bit to let callback potentially save first
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Save (will query DB for existing count)
|
||||
return await upsert_chat_session(session)
|
||||
|
||||
async def callback_save():
|
||||
"""Simulates long-running callback saving a message."""
|
||||
# Add 1 message
|
||||
session.messages.append(
|
||||
ChatMessage(role="tool", content="Callback result", tool_call_id="tc1")
|
||||
)
|
||||
|
||||
# Save immediately (will query DB for existing count)
|
||||
return await upsert_chat_session(session)
|
||||
|
||||
# Run both saves concurrently - one will hit collision detection
|
||||
results = await asyncio.gather(streaming_loop_save(), callback_save())
|
||||
|
||||
# Both should succeed
|
||||
assert all(r is not None for r in results)
|
||||
|
||||
# Reload session from DB to verify
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
redis_key = f"chat:session:{session.session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key) # Clear cache to force DB load
|
||||
|
||||
loaded_session = await get_chat_session(session.session_id, test_user_id)
|
||||
assert loaded_session is not None
|
||||
|
||||
# Should have all 6 messages (3 initial + 2 streaming + 1 callback)
|
||||
assert len(loaded_session.messages) == 6
|
||||
|
||||
# Verify no duplicate sequences
|
||||
sequences = []
|
||||
for i, msg in enumerate(loaded_session.messages):
|
||||
# Messages should have sequential sequence numbers starting from 0
|
||||
sequences.append(i)
|
||||
|
||||
# All sequences should be unique and sequential
|
||||
assert sequences == list(range(6))
|
||||
|
||||
# Verify message content is preserved
|
||||
contents = [m.content for m in loaded_session.messages]
|
||||
assert "Message 0" in contents
|
||||
assert "Message 1" in contents
|
||||
assert "Message 2" in contents
|
||||
assert "Streaming message 1" in contents
|
||||
assert "Streaming message 2" in contents
|
||||
assert "Callback result" in contents
|
||||
|
||||
@@ -14,7 +14,6 @@ import pytest
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls_run_concurrently():
|
||||
"""Multiple tool calls should complete in ~max(delays), not sum(delays)."""
|
||||
# Import here to allow module-level mocking if needed
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
@@ -32,7 +31,6 @@ async def test_parallel_tool_calls_run_concurrently():
|
||||
for i in range(n_tools)
|
||||
]
|
||||
|
||||
# Minimal session mock
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
@@ -42,7 +40,7 @@ async def test_parallel_tool_calls_run_concurrently():
|
||||
|
||||
original_yield = None
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"],
|
||||
toolName=tc_list[idx]["function"]["name"],
|
||||
@@ -101,7 +99,7 @@ async def test_single_tool_call_works():
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
|
||||
yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")
|
||||
|
||||
@@ -144,7 +142,7 @@ async def test_retryable_error_propagates():
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
if idx == 1:
|
||||
raise KeyError("bad")
|
||||
from backend.copilot.response_model import StreamToolInputAvailable
|
||||
@@ -175,8 +173,8 @@ async def test_retryable_error_propagates():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_lock_shared():
|
||||
"""All parallel tools should receive the same lock instance."""
|
||||
async def test_session_shared_across_parallel_tools():
|
||||
"""All parallel tools should receive the same session instance."""
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
@@ -199,10 +197,10 @@ async def test_session_lock_shared():
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
observed_locks = []
|
||||
observed_sessions = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
observed_locks.append(lock)
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
observed_sessions.append(sess)
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
|
||||
)
|
||||
@@ -222,9 +220,8 @@ async def test_session_lock_shared():
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
assert len(observed_locks) == 3
|
||||
assert observed_locks[0] is observed_locks[1] is observed_locks[2]
|
||||
assert isinstance(observed_locks[0], asyncio.Lock)
|
||||
assert len(observed_sessions) == 3
|
||||
assert observed_sessions[0] is observed_sessions[1] is observed_sessions[2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -251,7 +248,7 @@ async def test_cancellation_cleans_up():
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
|
||||
)
|
||||
|
||||
@@ -5,6 +5,8 @@ This module implements the AI SDK UI Stream Protocol (v1) for streaming chat res
|
||||
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
@@ -12,6 +14,8 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from backend.util.json import dumps as json_dumps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of streaming responses following AI SDK protocol."""
|
||||
@@ -47,7 +51,8 @@ class StreamBaseResponse(BaseModel):
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format."""
|
||||
return f"data: {self.model_dump_json()}\n\n"
|
||||
json_str = self.model_dump_json(exclude_none=True)
|
||||
return f"data: {json_str}\n\n"
|
||||
|
||||
|
||||
# ========== Message Lifecycle ==========
|
||||
@@ -58,15 +63,13 @@ class StreamStart(StreamBaseResponse):
|
||||
|
||||
type: ResponseType = ResponseType.START
|
||||
messageId: str = Field(..., description="Unique message ID")
|
||||
taskId: str | None = Field(
|
||||
sessionId: str | None = Field(
|
||||
default=None,
|
||||
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
||||
description="Session ID for SSE reconnection.",
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, excluding non-protocol fields like taskId."""
|
||||
import json
|
||||
|
||||
"""Convert to SSE format, excluding non-protocol fields like sessionId."""
|
||||
data: dict[str, Any] = {
|
||||
"type": self.type.value,
|
||||
"messageId": self.messageId,
|
||||
@@ -163,8 +166,6 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, excluding non-spec fields."""
|
||||
import json
|
||||
|
||||
data = {
|
||||
"type": self.type.value,
|
||||
"toolCallId": self.toolCallId,
|
||||
|
||||
57
autogpt_platform/backend/backend/copilot/sdk/dummy.py
Normal file
57
autogpt_platform/backend/backend/copilot/sdk/dummy.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Dummy SDK service for testing copilot streaming.
|
||||
|
||||
Returns mock streaming responses without calling Claude Agent SDK.
|
||||
Enable via COPILOT_TEST_MODE=true environment variable.
|
||||
|
||||
WARNING: This is for testing only. Do not use in production.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from ..model import ChatSession
|
||||
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def stream_chat_completion_dummy(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
tool_call_response: str | None = None,
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
retry_count: int = 0,
|
||||
session: ChatSession | None = None,
|
||||
context: dict[str, str] | None = None,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream dummy chat completion for testing.
|
||||
|
||||
Returns a simple streaming response with text deltas to test:
|
||||
- Streaming infrastructure works
|
||||
- No timeout occurs
|
||||
- Text arrives in chunks
|
||||
- StreamFinish is sent by mark_session_completed
|
||||
"""
|
||||
logger.warning(
|
||||
f"[TEST MODE] Using dummy copilot streaming for session {session_id}"
|
||||
)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
# Start the stream
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
# Simulate streaming text response with delays
|
||||
dummy_response = "I counted: 1... 2... 3. All done!"
|
||||
words = dummy_response.split()
|
||||
|
||||
for i, word in enumerate(words):
|
||||
# Add space except for last word
|
||||
text = word if i == len(words) - 1 else f"{word} "
|
||||
yield StreamTextDelta(id=text_block_id, delta=text)
|
||||
# Small delay to simulate real streaming
|
||||
await asyncio.sleep(0.1)
|
||||
@@ -0,0 +1,221 @@
|
||||
"""Tests for _format_conversation_context and _build_query_message."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.sdk.service import (
|
||||
_build_query_message,
|
||||
_format_conversation_context,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _format_conversation_context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_format_empty_list():
|
||||
assert _format_conversation_context([]) is None
|
||||
|
||||
|
||||
def test_format_none_content_messages():
|
||||
msgs = [ChatMessage(role="user", content=None)]
|
||||
assert _format_conversation_context(msgs) is None
|
||||
|
||||
|
||||
def test_format_user_message():
|
||||
msgs = [ChatMessage(role="user", content="hello")]
|
||||
result = _format_conversation_context(msgs)
|
||||
assert result is not None
|
||||
assert "User: hello" in result
|
||||
assert result.startswith("<conversation_history>")
|
||||
assert result.endswith("</conversation_history>")
|
||||
|
||||
|
||||
def test_format_assistant_text():
|
||||
msgs = [ChatMessage(role="assistant", content="hi there")]
|
||||
result = _format_conversation_context(msgs)
|
||||
assert result is not None
|
||||
assert "You responded: hi there" in result
|
||||
|
||||
|
||||
def test_format_assistant_tool_calls():
|
||||
msgs = [
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=None,
|
||||
tool_calls=[{"function": {"name": "search", "arguments": '{"q": "test"}'}}],
|
||||
)
|
||||
]
|
||||
result = _format_conversation_context(msgs)
|
||||
assert result is not None
|
||||
assert 'You called tool: search({"q": "test"})' in result
|
||||
|
||||
|
||||
def test_format_tool_result():
|
||||
msgs = [ChatMessage(role="tool", content='{"result": "ok"}')]
|
||||
result = _format_conversation_context(msgs)
|
||||
assert result is not None
|
||||
assert 'Tool result: {"result": "ok"}' in result
|
||||
|
||||
|
||||
def test_format_tool_result_none_content():
|
||||
msgs = [ChatMessage(role="tool", content=None)]
|
||||
result = _format_conversation_context(msgs)
|
||||
assert result is not None
|
||||
assert "Tool result: " in result
|
||||
|
||||
|
||||
def test_format_full_conversation():
|
||||
msgs = [
|
||||
ChatMessage(role="user", content="find agents"),
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="I'll search for agents.",
|
||||
tool_calls=[
|
||||
{"function": {"name": "find_agents", "arguments": '{"q": "test"}'}}
|
||||
],
|
||||
),
|
||||
ChatMessage(role="tool", content='[{"id": "1", "name": "Agent1"}]'),
|
||||
ChatMessage(role="assistant", content="Found Agent1."),
|
||||
]
|
||||
result = _format_conversation_context(msgs)
|
||||
assert result is not None
|
||||
assert "User: find agents" in result
|
||||
assert "You responded: I'll search for agents." in result
|
||||
assert "You called tool: find_agents" in result
|
||||
assert "Tool result:" in result
|
||||
assert "You responded: Found Agent1." in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_query_message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session(messages: list[ChatMessage]) -> ChatSession:
|
||||
"""Build a minimal ChatSession with the given messages."""
|
||||
now = datetime.now(UTC)
|
||||
return ChatSession(
|
||||
session_id="test-session",
|
||||
user_id="user-1",
|
||||
messages=messages,
|
||||
title="test",
|
||||
usage=[],
|
||||
started_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_resume_up_to_date():
|
||||
"""With --resume and transcript covers all messages, return raw message."""
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="hello"),
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
ChatMessage(role="user", content="what's new?"),
|
||||
]
|
||||
)
|
||||
result = await _build_query_message(
|
||||
"what's new?",
|
||||
session,
|
||||
use_resume=True,
|
||||
transcript_msg_count=2,
|
||||
session_id="test-session",
|
||||
)
|
||||
# transcript_msg_count == msg_count - 1, so no gap
|
||||
assert result == "what's new?"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_resume_stale_transcript():
|
||||
"""With --resume and stale transcript, gap context is prepended."""
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="turn 1"),
|
||||
ChatMessage(role="assistant", content="reply 1"),
|
||||
ChatMessage(role="user", content="turn 2"),
|
||||
ChatMessage(role="assistant", content="reply 2"),
|
||||
ChatMessage(role="user", content="turn 3"),
|
||||
]
|
||||
)
|
||||
result = await _build_query_message(
|
||||
"turn 3",
|
||||
session,
|
||||
use_resume=True,
|
||||
transcript_msg_count=2,
|
||||
session_id="test-session",
|
||||
)
|
||||
assert "<conversation_history>" in result
|
||||
assert "turn 2" in result
|
||||
assert "reply 2" in result
|
||||
assert "Now, the user says:\nturn 3" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_resume_zero_msg_count():
|
||||
"""With --resume but transcript_msg_count=0, return raw message."""
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="hello"),
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
ChatMessage(role="user", content="new msg"),
|
||||
]
|
||||
)
|
||||
result = await _build_query_message(
|
||||
"new msg",
|
||||
session,
|
||||
use_resume=True,
|
||||
transcript_msg_count=0,
|
||||
session_id="test-session",
|
||||
)
|
||||
assert result == "new msg"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_no_resume_single_message():
|
||||
"""Without --resume and only 1 message, return raw message."""
|
||||
session = _make_session([ChatMessage(role="user", content="first")])
|
||||
result = await _build_query_message(
|
||||
"first",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=0,
|
||||
session_id="test-session",
|
||||
)
|
||||
assert result == "first"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_no_resume_multi_message(monkeypatch):
|
||||
"""Without --resume and multiple messages, compress and prepend."""
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="older question"),
|
||||
ChatMessage(role="assistant", content="older answer"),
|
||||
ChatMessage(role="user", content="new question"),
|
||||
]
|
||||
)
|
||||
|
||||
# Mock _compress_conversation_history to return the messages as-is
|
||||
async def _mock_compress(sess):
|
||||
return sess.messages[:-1]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_conversation_history",
|
||||
_mock_compress,
|
||||
)
|
||||
|
||||
result = await _build_query_message(
|
||||
"new question",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=0,
|
||||
session_id="test-session",
|
||||
)
|
||||
assert "<conversation_history>" in result
|
||||
assert "older question" in result
|
||||
assert "older answer" in result
|
||||
assert "Now, the user says:\nnew question" in result
|
||||
@@ -47,19 +47,20 @@ class SDKResponseAdapter:
|
||||
text blocks, tool calls, and message lifecycle.
|
||||
"""
|
||||
|
||||
def __init__(self, message_id: str | None = None):
|
||||
def __init__(self, message_id: str | None = None, session_id: str | None = None):
|
||||
self.message_id = message_id or str(uuid.uuid4())
|
||||
self.session_id = session_id
|
||||
self.text_block_id = str(uuid.uuid4())
|
||||
self.has_started_text = False
|
||||
self.has_ended_text = False
|
||||
self.current_tool_calls: dict[str, dict[str, str]] = {}
|
||||
self.resolved_tool_calls: set[str] = set()
|
||||
self.task_id: str | None = None
|
||||
self.step_open = False
|
||||
|
||||
def set_task_id(self, task_id: str) -> None:
|
||||
"""Set the task ID for reconnection support."""
|
||||
self.task_id = task_id
|
||||
@property
|
||||
def has_unresolved_tool_calls(self) -> bool:
|
||||
"""True when there are tool calls that haven't received output yet."""
|
||||
return bool(self.current_tool_calls.keys() - self.resolved_tool_calls)
|
||||
|
||||
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
|
||||
"""Convert a single SDK message to Vercel AI SDK format."""
|
||||
@@ -68,7 +69,7 @@ class SDKResponseAdapter:
|
||||
if isinstance(sdk_message, SystemMessage):
|
||||
if sdk_message.subtype == "init":
|
||||
responses.append(
|
||||
StreamStart(messageId=self.message_id, taskId=self.task_id)
|
||||
StreamStart(messageId=self.message_id, sessionId=self.session_id)
|
||||
)
|
||||
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
|
||||
responses.append(StreamStartStep())
|
||||
@@ -77,7 +78,12 @@ class SDKResponseAdapter:
|
||||
elif isinstance(sdk_message, AssistantMessage):
|
||||
# Flush any SDK built-in tool calls that didn't get a UserMessage
|
||||
# result (e.g. WebSearch, Read handled internally by the CLI).
|
||||
self._flush_unresolved_tool_calls(responses)
|
||||
# BUT skip flush when this AssistantMessage is a parallel tool
|
||||
# continuation (contains only ToolUseBlocks) — the prior tools
|
||||
# are still executing concurrently and haven't finished yet.
|
||||
is_tool_only = all(isinstance(b, ToolUseBlock) for b in sdk_message.content)
|
||||
if not is_tool_only:
|
||||
self._flush_unresolved_tool_calls(responses)
|
||||
|
||||
# After tool results, the SDK sends a new AssistantMessage for the
|
||||
# next LLM turn. Open a new step if the previous one was closed.
|
||||
@@ -118,8 +124,24 @@ class SDKResponseAdapter:
|
||||
blocks = content if isinstance(content, list) else []
|
||||
resolved_in_blocks: set[str] = set()
|
||||
|
||||
sid = (self.session_id or "?")[:12]
|
||||
parent_id_preview = getattr(sdk_message, "parent_tool_use_id", None)
|
||||
logger.info(
|
||||
"[SDK] [%s] UserMessage: %d blocks, content_type=%s, "
|
||||
"parent_tool_use_id=%s",
|
||||
sid,
|
||||
len(blocks),
|
||||
type(content).__name__,
|
||||
parent_id_preview[:12] if parent_id_preview else "None",
|
||||
)
|
||||
|
||||
for block in blocks:
|
||||
if isinstance(block, ToolResultBlock) and block.tool_use_id:
|
||||
# Skip if already resolved (e.g. by flush) — the real
|
||||
# result supersedes the empty flush, but re-emitting
|
||||
# would confuse the frontend's state machine.
|
||||
if block.tool_use_id in self.resolved_tool_calls:
|
||||
continue
|
||||
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
|
||||
@@ -144,7 +166,11 @@ class SDKResponseAdapter:
|
||||
# Handle SDK built-in tool results carried via parent_tool_use_id
|
||||
# instead of (or in addition to) ToolResultBlock content.
|
||||
parent_id = sdk_message.parent_tool_use_id
|
||||
if parent_id and parent_id not in resolved_in_blocks:
|
||||
if (
|
||||
parent_id
|
||||
and parent_id not in resolved_in_blocks
|
||||
and parent_id not in self.resolved_tool_calls
|
||||
):
|
||||
tool_info = self.current_tool_calls.get(parent_id, {})
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
|
||||
@@ -228,11 +254,28 @@ class SDKResponseAdapter:
|
||||
output, which we pop and emit here before the next ``AssistantMessage``
|
||||
starts.
|
||||
"""
|
||||
unresolved = [
|
||||
(tid, info.get("name", "unknown"))
|
||||
for tid, info in self.current_tool_calls.items()
|
||||
if tid not in self.resolved_tool_calls
|
||||
]
|
||||
sid = (self.session_id or "?")[:12]
|
||||
if not unresolved:
|
||||
logger.info(
|
||||
"[SDK] [%s] Flush called but all %d tool(s) already resolved",
|
||||
sid,
|
||||
len(self.current_tool_calls),
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"[SDK] [%s] Flushing %d unresolved tool call(s): %s",
|
||||
sid,
|
||||
len(unresolved),
|
||||
", ".join(f"{name}({tid[:12]})" for tid, name in unresolved),
|
||||
)
|
||||
|
||||
flushed = False
|
||||
for tool_id, tool_info in self.current_tool_calls.items():
|
||||
if tool_id in self.resolved_tool_calls:
|
||||
continue
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
for tool_id, tool_name in unresolved:
|
||||
output = pop_pending_tool_output(tool_name)
|
||||
if output is not None:
|
||||
responses.append(
|
||||
@@ -245,9 +288,12 @@ class SDKResponseAdapter:
|
||||
)
|
||||
self.resolved_tool_calls.add(tool_id)
|
||||
flushed = True
|
||||
logger.debug(
|
||||
f"Flushed pending output for built-in tool {tool_name} "
|
||||
f"(call {tool_id})"
|
||||
logger.info(
|
||||
"[SDK] [%s] Flushed stashed output for %s " "(call %s, %d chars)",
|
||||
sid,
|
||||
tool_name,
|
||||
tool_id[:12],
|
||||
len(output),
|
||||
)
|
||||
else:
|
||||
# No output available — emit an empty output so the frontend
|
||||
@@ -263,9 +309,14 @@ class SDKResponseAdapter:
|
||||
)
|
||||
self.resolved_tool_calls.add(tool_id)
|
||||
flushed = True
|
||||
logger.debug(
|
||||
f"Flushed empty output for unresolved tool {tool_name} "
|
||||
f"(call {tool_id})"
|
||||
logger.warning(
|
||||
"[SDK] [%s] Flushed EMPTY output for unresolved tool %s "
|
||||
"(call %s) — stash was empty (likely SDK hook race "
|
||||
"condition: PostToolUse hook hadn't completed before "
|
||||
"flush was triggered)",
|
||||
sid,
|
||||
tool_name,
|
||||
tool_id[:12],
|
||||
)
|
||||
|
||||
if flushed and self.step_open:
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""Unit tests for the SDK response adapter."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
ResultMessage,
|
||||
@@ -27,12 +30,14 @@ from backend.copilot.response_model import (
|
||||
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .tool_adapter import MCP_TOOL_PREFIX
|
||||
from .tool_adapter import _pending_tool_outputs as _pto
|
||||
from .tool_adapter import _stash_event
|
||||
from .tool_adapter import stash_pending_tool_output as _stash
|
||||
from .tool_adapter import wait_for_stash
|
||||
|
||||
|
||||
def _adapter() -> SDKResponseAdapter:
|
||||
a = SDKResponseAdapter(message_id="msg-1")
|
||||
a.set_task_id("task-1")
|
||||
return a
|
||||
return SDKResponseAdapter(message_id="msg-1", session_id="session-1")
|
||||
|
||||
|
||||
# -- SystemMessage -----------------------------------------------------------
|
||||
@@ -44,7 +49,7 @@ def test_system_init_emits_start_and_step():
|
||||
assert len(results) == 2
|
||||
assert isinstance(results[0], StreamStart)
|
||||
assert results[0].messageId == "msg-1"
|
||||
assert results[0].taskId == "task-1"
|
||||
assert results[0].sessionId == "session-1"
|
||||
assert isinstance(results[1], StreamStartStep)
|
||||
|
||||
|
||||
@@ -364,3 +369,310 @@ def test_full_conversation_flow():
|
||||
"StreamFinishStep", # step 2 closed
|
||||
"StreamFinish",
|
||||
]
|
||||
|
||||
|
||||
# -- Flush unresolved tool calls --------------------------------------------
|
||||
|
||||
|
||||
def test_flush_unresolved_at_result_message():
|
||||
"""Built-in tools (WebSearch) without UserMessage results get flushed at ResultMessage."""
|
||||
adapter = _adapter()
|
||||
all_responses: list[StreamBaseResponse] = []
|
||||
|
||||
# 1. Init
|
||||
all_responses.extend(
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
)
|
||||
# 2. Tool use (built-in tool — no MCP prefix)
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
)
|
||||
# 3. No UserMessage for this tool — go straight to ResultMessage
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=False,
|
||||
num_turns=1,
|
||||
session_id="s1",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
types = [type(r).__name__ for r in all_responses]
|
||||
assert types == [
|
||||
"StreamStart",
|
||||
"StreamStartStep",
|
||||
"StreamToolInputStart",
|
||||
"StreamToolInputAvailable",
|
||||
"StreamToolOutputAvailable", # flushed with empty output
|
||||
"StreamFinishStep", # step closed by flush
|
||||
"StreamFinish",
|
||||
]
|
||||
# The flushed output should be empty (no stash available)
|
||||
output_event = [
|
||||
r for r in all_responses if isinstance(r, StreamToolOutputAvailable)
|
||||
][0]
|
||||
assert output_event.toolCallId == "ws-1"
|
||||
assert output_event.toolName == "WebSearch"
|
||||
assert output_event.output == ""
|
||||
|
||||
|
||||
def test_flush_unresolved_at_next_assistant_message():
|
||||
"""Built-in tools get flushed when the next AssistantMessage arrives."""
|
||||
adapter = _adapter()
|
||||
all_responses: list[StreamBaseResponse] = []
|
||||
|
||||
# 1. Init
|
||||
all_responses.extend(
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
)
|
||||
# 2. Tool use (built-in — no UserMessage will come)
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
)
|
||||
# 3. Next AssistantMessage triggers flush before processing its blocks
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[TextBlock(text="Here are the results")], model="test"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
types = [type(r).__name__ for r in all_responses]
|
||||
assert types == [
|
||||
"StreamStart",
|
||||
"StreamStartStep",
|
||||
"StreamToolInputStart",
|
||||
"StreamToolInputAvailable",
|
||||
# Flush at next AssistantMessage:
|
||||
"StreamToolOutputAvailable",
|
||||
"StreamFinishStep", # step closed by flush
|
||||
# New step for continuation text:
|
||||
"StreamStartStep",
|
||||
"StreamTextStart",
|
||||
"StreamTextDelta",
|
||||
]
|
||||
|
||||
|
||||
def test_flush_with_stashed_output():
|
||||
"""Stashed output from PostToolUse hook is used when flushing."""
|
||||
adapter = _adapter()
|
||||
|
||||
# Simulate PostToolUse hook stashing output
|
||||
_pto.set({})
|
||||
_stash("WebSearch", "Search result: 5 items found")
|
||||
|
||||
all_responses: list[StreamBaseResponse] = []
|
||||
|
||||
# Tool use
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
)
|
||||
# ResultMessage triggers flush
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=False,
|
||||
num_turns=1,
|
||||
session_id="s1",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
output_events = [
|
||||
r for r in all_responses if isinstance(r, StreamToolOutputAvailable)
|
||||
]
|
||||
assert len(output_events) == 1
|
||||
assert output_events[0].output == "Search result: 5 items found"
|
||||
|
||||
# Cleanup
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# -- wait_for_stash synchronisation tests --
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_stash_signaled():
|
||||
"""wait_for_stash returns True when stash_pending_tool_output signals."""
|
||||
_pto.set({})
|
||||
event = asyncio.Event()
|
||||
_stash_event.set(event)
|
||||
|
||||
# Simulate a PostToolUse hook that stashes output after a short delay
|
||||
async def delayed_stash():
|
||||
await asyncio.sleep(0.01)
|
||||
_stash("WebSearch", "result data")
|
||||
|
||||
asyncio.create_task(delayed_stash())
|
||||
result = await wait_for_stash(timeout=1.0)
|
||||
|
||||
assert result is True
|
||||
assert _pto.get({}).get("WebSearch") == ["result data"]
|
||||
|
||||
# Cleanup
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
_stash_event.set(None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_stash_timeout():
|
||||
"""wait_for_stash returns False on timeout when no stash occurs."""
|
||||
_pto.set({})
|
||||
event = asyncio.Event()
|
||||
_stash_event.set(event)
|
||||
|
||||
result = await wait_for_stash(timeout=0.05)
|
||||
assert result is False
|
||||
|
||||
# Cleanup
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
_stash_event.set(None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_stash_already_stashed():
|
||||
"""wait_for_stash picks up a stash that happened just before the wait."""
|
||||
_pto.set({})
|
||||
event = asyncio.Event()
|
||||
_stash_event.set(event)
|
||||
|
||||
# Stash before waiting — simulates hook completing before message arrives
|
||||
_stash("Read", "file contents")
|
||||
# Event is now set; wait_for_stash detects the fast path and returns
|
||||
# immediately without timing out.
|
||||
result = await wait_for_stash(timeout=0.05)
|
||||
assert result is True
|
||||
|
||||
# But the stash itself is populated
|
||||
assert _pto.get({}).get("Read") == ["file contents"]
|
||||
|
||||
# Cleanup
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
_stash_event.set(None)
|
||||
|
||||
|
||||
# -- Parallel tool call tests --
|
||||
|
||||
|
||||
def test_parallel_tool_calls_not_flushed_prematurely():
|
||||
"""Parallel tool calls should NOT be flushed when the next AssistantMessage
|
||||
only contains ToolUseBlocks (parallel continuation)."""
|
||||
adapter = SDKResponseAdapter()
|
||||
|
||||
# Init
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
|
||||
# First AssistantMessage: tool call #1
|
||||
msg1 = AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name="WebSearch", input={"q": "foo"})],
|
||||
model="test",
|
||||
)
|
||||
r1 = adapter.convert_message(msg1)
|
||||
assert any(isinstance(r, StreamToolInputAvailable) for r in r1)
|
||||
assert adapter.has_unresolved_tool_calls
|
||||
|
||||
# Second AssistantMessage: tool call #2 (parallel continuation)
|
||||
msg2 = AssistantMessage(
|
||||
content=[ToolUseBlock(id="t2", name="WebSearch", input={"q": "bar"})],
|
||||
model="test",
|
||||
)
|
||||
r2 = adapter.convert_message(msg2)
|
||||
|
||||
# No flush should have happened — t1 should NOT have StreamToolOutputAvailable
|
||||
output_events = [r for r in r2 if isinstance(r, StreamToolOutputAvailable)]
|
||||
assert len(output_events) == 0, (
|
||||
f"Tool-only AssistantMessage should not flush prior tools, "
|
||||
f"but got {len(output_events)} output events"
|
||||
)
|
||||
|
||||
# Both t1 and t2 should still be unresolved
|
||||
assert "t1" not in adapter.resolved_tool_calls
|
||||
assert "t2" not in adapter.resolved_tool_calls
|
||||
|
||||
|
||||
def test_text_assistant_message_flushes_prior_tools():
|
||||
"""An AssistantMessage with text (new turn) should flush unresolved tools."""
|
||||
adapter = SDKResponseAdapter()
|
||||
|
||||
# Init
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
|
||||
# Tool call
|
||||
msg1 = AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name="WebSearch", input={"q": "foo"})],
|
||||
model="test",
|
||||
)
|
||||
adapter.convert_message(msg1)
|
||||
assert adapter.has_unresolved_tool_calls
|
||||
|
||||
# Text AssistantMessage (new turn after tools completed)
|
||||
msg2 = AssistantMessage(
|
||||
content=[TextBlock(text="Here are the results")],
|
||||
model="test",
|
||||
)
|
||||
r2 = adapter.convert_message(msg2)
|
||||
|
||||
# Flush SHOULD have happened — t1 gets empty output
|
||||
output_events = [r for r in r2 if isinstance(r, StreamToolOutputAvailable)]
|
||||
assert len(output_events) == 1
|
||||
assert output_events[0].toolCallId == "t1"
|
||||
assert "t1" in adapter.resolved_tool_calls
|
||||
|
||||
|
||||
def test_already_resolved_tool_skipped_in_user_message():
|
||||
"""A tool result in UserMessage should be skipped if already resolved by flush."""
|
||||
adapter = SDKResponseAdapter()
|
||||
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
|
||||
# Tool call + flush via text message
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name="WebSearch", input={})],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[TextBlock(text="Done")],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
assert "t1" in adapter.resolved_tool_calls
|
||||
|
||||
# Now UserMessage arrives with the real result — should be skipped
|
||||
user_msg = UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="real")])
|
||||
r = adapter.convert_message(user_msg)
|
||||
output_events = [r_ for r_ in r if isinstance(r_, StreamToolOutputAvailable)]
|
||||
assert (
|
||||
len(output_events) == 0
|
||||
), "Already-resolved tool should not emit duplicate output"
|
||||
|
||||
194
autogpt_platform/backend/backend/copilot/sdk/sdk_compat_test.py
Normal file
194
autogpt_platform/backend/backend/copilot/sdk/sdk_compat_test.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""SDK compatibility tests — verify the claude-agent-sdk public API surface we depend on.
|
||||
|
||||
Instead of pinning to a narrow version range, these tests verify that the
|
||||
installed SDK exposes every class, function, attribute, and method the copilot
|
||||
integration relies on. If an SDK upgrade removes or renames something these
|
||||
tests will catch it immediately.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public types & factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_sdk_exports_client_and_options():
|
||||
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
||||
|
||||
assert inspect.isclass(ClaudeSDKClient)
|
||||
assert inspect.isclass(ClaudeAgentOptions)
|
||||
|
||||
|
||||
def test_sdk_exports_message_types():
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
Message,
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
for cls in (AssistantMessage, ResultMessage, SystemMessage, UserMessage):
|
||||
assert inspect.isclass(cls), f"{cls.__name__} is not a class"
|
||||
# Message is a Union type alias, just verify it's importable
|
||||
assert Message is not None
|
||||
|
||||
|
||||
def test_sdk_exports_content_block_types():
|
||||
from claude_agent_sdk import TextBlock, ToolResultBlock, ToolUseBlock
|
||||
|
||||
for cls in (TextBlock, ToolResultBlock, ToolUseBlock):
|
||||
assert inspect.isclass(cls), f"{cls.__name__} is not a class"
|
||||
|
||||
|
||||
def test_sdk_exports_mcp_helpers():
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
assert callable(create_sdk_mcp_server)
|
||||
assert callable(tool)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ClaudeSDKClient interface
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_client_has_required_methods():
|
||||
from claude_agent_sdk import ClaudeSDKClient
|
||||
|
||||
required = ["connect", "disconnect", "query", "receive_messages"]
|
||||
for name in required:
|
||||
attr = getattr(ClaudeSDKClient, name, None)
|
||||
assert attr is not None, f"ClaudeSDKClient.{name} missing"
|
||||
assert callable(attr), f"ClaudeSDKClient.{name} is not callable"
|
||||
|
||||
|
||||
def test_client_supports_async_context_manager():
|
||||
from claude_agent_sdk import ClaudeSDKClient
|
||||
|
||||
assert hasattr(ClaudeSDKClient, "__aenter__")
|
||||
assert hasattr(ClaudeSDKClient, "__aexit__")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ClaudeAgentOptions fields
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_agent_options_accepts_required_fields():
|
||||
"""Verify ClaudeAgentOptions accepts all kwargs our code passes."""
|
||||
from claude_agent_sdk import ClaudeAgentOptions
|
||||
|
||||
opts = ClaudeAgentOptions(
|
||||
system_prompt="test",
|
||||
cwd="/tmp",
|
||||
)
|
||||
assert opts.system_prompt == "test"
|
||||
assert opts.cwd == "/tmp"
|
||||
|
||||
|
||||
def test_agent_options_accepts_all_our_fields():
|
||||
"""Comprehensive check of every field we use in service.py."""
|
||||
from claude_agent_sdk import ClaudeAgentOptions
|
||||
|
||||
fields_we_use = [
|
||||
"system_prompt",
|
||||
"mcp_servers",
|
||||
"allowed_tools",
|
||||
"disallowed_tools",
|
||||
"hooks",
|
||||
"cwd",
|
||||
"model",
|
||||
"env",
|
||||
"resume",
|
||||
"max_buffer_size",
|
||||
]
|
||||
sig = inspect.signature(ClaudeAgentOptions)
|
||||
for field in fields_we_use:
|
||||
assert field in sig.parameters, (
|
||||
f"ClaudeAgentOptions no longer accepts '{field}' — "
|
||||
f"available params: {list(sig.parameters.keys())}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message attributes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_assistant_message_has_content_and_model():
|
||||
from claude_agent_sdk import AssistantMessage, TextBlock
|
||||
|
||||
msg = AssistantMessage(content=[TextBlock(text="hi")], model="test")
|
||||
assert hasattr(msg, "content")
|
||||
assert hasattr(msg, "model")
|
||||
|
||||
|
||||
def test_result_message_has_required_attrs():
|
||||
from claude_agent_sdk import ResultMessage
|
||||
|
||||
msg = ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=False,
|
||||
num_turns=1,
|
||||
session_id="s1",
|
||||
)
|
||||
assert msg.subtype == "success"
|
||||
assert hasattr(msg, "result")
|
||||
|
||||
|
||||
def test_system_message_has_subtype_and_data():
|
||||
from claude_agent_sdk import SystemMessage
|
||||
|
||||
msg = SystemMessage(subtype="init", data={})
|
||||
assert msg.subtype == "init"
|
||||
assert msg.data == {}
|
||||
|
||||
|
||||
def test_user_message_has_parent_tool_use_id():
|
||||
from claude_agent_sdk import UserMessage
|
||||
|
||||
msg = UserMessage(content="test")
|
||||
assert hasattr(msg, "parent_tool_use_id")
|
||||
assert hasattr(msg, "tool_use_result")
|
||||
|
||||
|
||||
def test_tool_use_block_has_id_name_input():
|
||||
from claude_agent_sdk import ToolUseBlock
|
||||
|
||||
block = ToolUseBlock(id="t1", name="test", input={"key": "val"})
|
||||
assert block.id == "t1"
|
||||
assert block.name == "test"
|
||||
assert block.input == {"key": "val"}
|
||||
|
||||
|
||||
def test_tool_result_block_has_required_attrs():
|
||||
from claude_agent_sdk import ToolResultBlock
|
||||
|
||||
block = ToolResultBlock(tool_use_id="t1", content="result")
|
||||
assert block.tool_use_id == "t1"
|
||||
assert block.content == "result"
|
||||
assert hasattr(block, "is_error")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hook types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hook_event",
|
||||
["PreToolUse", "PostToolUse", "Stop"],
|
||||
)
|
||||
def test_sdk_exports_hook_event_type(hook_event: str):
|
||||
"""Verify HookEvent literal includes the events our security_hooks use."""
|
||||
from claude_agent_sdk.types import HookEvent
|
||||
|
||||
# HookEvent is a Literal type — check that our events are valid values.
|
||||
# We can't easily inspect Literal at runtime, so just verify the type exists.
|
||||
assert HookEvent is not None
|
||||
@@ -124,20 +124,20 @@ def _validate_user_isolation(
|
||||
"""Validate that tool calls respect user isolation."""
|
||||
# For workspace file tools, ensure path doesn't escape
|
||||
if "workspace" in tool_name.lower():
|
||||
# The "path" param is a cloud storage key (e.g. "/ASEAN/report.md")
|
||||
# where a leading "/" is normal. Only check for ".." traversal.
|
||||
# Filesystem paths (source_path, save_to_path) are validated inside
|
||||
# the tool itself via _validate_ephemeral_path.
|
||||
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
||||
if path:
|
||||
# Check for path traversal
|
||||
if ".." in path or path.startswith("/"):
|
||||
logger.warning(
|
||||
f"Blocked path traversal attempt: {path} by user {user_id}"
|
||||
)
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": "Path traversal not allowed",
|
||||
}
|
||||
if path and ".." in path:
|
||||
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": "Path traversal not allowed",
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
@@ -188,8 +188,19 @@ def create_security_hooks(
|
||||
|
||||
# Rate-limit Task (sub-agent) spawns per session
|
||||
if tool_name == "Task":
|
||||
task_spawn_count += 1
|
||||
if task_spawn_count > max_subtasks:
|
||||
# Block background task execution first — denied calls
|
||||
# should not consume a subtask slot.
|
||||
if tool_input.get("run_in_background"):
|
||||
logger.info(f"[SDK] Blocked background Task, user={user_id}")
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
_deny(
|
||||
"Background task execution is not supported. "
|
||||
"Run tasks in the foreground instead "
|
||||
"(remove the run_in_background parameter)."
|
||||
),
|
||||
)
|
||||
if task_spawn_count >= max_subtasks:
|
||||
logger.warning(
|
||||
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
|
||||
)
|
||||
@@ -200,6 +211,7 @@ def create_security_hooks(
|
||||
"Please continue in the main conversation."
|
||||
),
|
||||
)
|
||||
task_spawn_count += 1
|
||||
|
||||
# Strip MCP prefix for consistent validation
|
||||
is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
|
||||
@@ -234,15 +246,33 @@ def create_security_hooks(
|
||||
"""
|
||||
_ = context
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
|
||||
is_builtin = not tool_name.startswith(MCP_TOOL_PREFIX)
|
||||
logger.info(
|
||||
"[SDK] PostToolUse: %s (builtin=%s, tool_use_id=%s)",
|
||||
tool_name,
|
||||
is_builtin,
|
||||
(tool_use_id or "")[:12],
|
||||
)
|
||||
|
||||
# Stash output for SDK built-in tools so the response adapter can
|
||||
# emit StreamToolOutputAvailable even when the CLI doesn't surface
|
||||
# a separate UserMessage with ToolResultBlock content.
|
||||
if not tool_name.startswith(MCP_TOOL_PREFIX):
|
||||
if is_builtin:
|
||||
tool_response = input_data.get("tool_response")
|
||||
if tool_response is not None:
|
||||
resp_preview = str(tool_response)[:100]
|
||||
logger.info(
|
||||
"[SDK] Stashing builtin output for %s (%d chars): %s...",
|
||||
tool_name,
|
||||
len(str(tool_response)),
|
||||
resp_preview,
|
||||
)
|
||||
stash_pending_tool_output(tool_name, tool_response)
|
||||
else:
|
||||
logger.warning(
|
||||
"[SDK] PostToolUse for builtin %s but tool_response is None",
|
||||
tool_name,
|
||||
)
|
||||
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
|
||||
@@ -7,11 +7,23 @@ tool access, and dangerous input patterns.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from .security_hooks import _validate_tool_access, _validate_user_isolation
|
||||
from .service import _is_tool_error_or_denial
|
||||
|
||||
SDK_CWD = "/tmp/copilot-abc123"
|
||||
|
||||
|
||||
def _sdk_available() -> bool:
|
||||
try:
|
||||
import claude_agent_sdk # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def _is_denied(result: dict) -> bool:
|
||||
hook = result.get("hookSpecificOutput", {})
|
||||
return hook.get("permissionDecision") == "deny"
|
||||
@@ -153,11 +165,12 @@ def test_workspace_path_traversal_blocked():
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_workspace_absolute_path_blocked():
|
||||
def test_workspace_absolute_path_allowed():
|
||||
"""Workspace 'path' is a cloud storage key — leading '/' is normal."""
|
||||
result = _validate_user_isolation(
|
||||
"workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
|
||||
"workspace_read", {"path": "/ASEAN/report.md"}, user_id="user-1"
|
||||
)
|
||||
assert _is_denied(result)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_workspace_normal_path_allowed():
|
||||
@@ -188,3 +201,135 @@ def test_bash_builtin_blocked_message_clarity():
|
||||
reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"}))
|
||||
assert "[SECURITY]" in reason
|
||||
assert "cannot be bypassed" in reason
|
||||
|
||||
|
||||
# -- Task sub-agent hooks (require SDK) --------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _hooks():
|
||||
"""Create security hooks and return the PreToolUse handler."""
|
||||
from .security_hooks import create_security_hooks
|
||||
|
||||
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
|
||||
pre = hooks["PreToolUse"][0].hooks[0]
|
||||
return pre
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_background_blocked(_hooks):
|
||||
"""Task with run_in_background=true must be denied."""
|
||||
result = await _hooks(
|
||||
{"tool_name": "Task", "tool_input": {"run_in_background": True, "prompt": "x"}},
|
||||
tool_use_id=None,
|
||||
context={},
|
||||
)
|
||||
assert _is_denied(result)
|
||||
assert "foreground" in _reason(result).lower()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_foreground_allowed(_hooks):
|
||||
"""Task without run_in_background should be allowed."""
|
||||
result = await _hooks(
|
||||
{"tool_name": "Task", "tool_input": {"prompt": "do stuff"}},
|
||||
tool_use_id=None,
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_limit_enforced(_hooks):
|
||||
"""Task spawns beyond max_subtasks should be denied."""
|
||||
# First two should pass
|
||||
for _ in range(2):
|
||||
result = await _hooks(
|
||||
{"tool_name": "Task", "tool_input": {"prompt": "ok"}},
|
||||
tool_use_id=None,
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
# Third should be denied (limit=2)
|
||||
result = await _hooks(
|
||||
{"tool_name": "Task", "tool_input": {"prompt": "over limit"}},
|
||||
tool_use_id=None,
|
||||
context={},
|
||||
)
|
||||
assert _is_denied(result)
|
||||
assert "Maximum" in _reason(result)
|
||||
|
||||
|
||||
# -- _is_tool_error_or_denial ------------------------------------------------
|
||||
|
||||
|
||||
class TestIsToolErrorOrDenial:
|
||||
def test_none_content(self):
|
||||
assert _is_tool_error_or_denial(None) is False
|
||||
|
||||
def test_empty_content(self):
|
||||
assert _is_tool_error_or_denial("") is False
|
||||
|
||||
def test_benign_output(self):
|
||||
assert _is_tool_error_or_denial("All good, no issues.") is False
|
||||
|
||||
def test_security_marker(self):
|
||||
assert _is_tool_error_or_denial("[SECURITY] Tool access blocked") is True
|
||||
|
||||
def test_cannot_be_bypassed(self):
|
||||
assert _is_tool_error_or_denial("This restriction cannot be bypassed.") is True
|
||||
|
||||
def test_not_allowed(self):
|
||||
assert _is_tool_error_or_denial("Operation not allowed in sandbox") is True
|
||||
|
||||
def test_background_task_denial(self):
|
||||
assert (
|
||||
_is_tool_error_or_denial(
|
||||
"Background task execution is not supported. "
|
||||
"Run tasks in the foreground instead."
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_subtask_limit_denial(self):
|
||||
assert (
|
||||
_is_tool_error_or_denial(
|
||||
"Maximum 2 sub-tasks per session. Please continue in the main conversation."
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_denied_marker(self):
|
||||
assert (
|
||||
_is_tool_error_or_denial("Access denied: insufficient privileges") is True
|
||||
)
|
||||
|
||||
def test_blocked_marker(self):
|
||||
assert _is_tool_error_or_denial("Request blocked by security policy") is True
|
||||
|
||||
def test_failed_marker(self):
|
||||
assert _is_tool_error_or_denial("Failed to execute tool: timeout") is True
|
||||
|
||||
def test_mcp_iserror(self):
|
||||
assert _is_tool_error_or_denial('{"isError": true, "content": []}') is True
|
||||
|
||||
def test_benign_error_in_value(self):
|
||||
"""Content like '0 errors found' should not trigger — 'error' was removed."""
|
||||
assert _is_tool_error_or_denial("0 errors found") is False
|
||||
|
||||
def test_benign_permission_field(self):
|
||||
"""Schema descriptions mentioning 'permission' should not trigger."""
|
||||
assert (
|
||||
_is_tool_error_or_denial(
|
||||
'{"fields": [{"name": "permission_level", "type": "int"}]}'
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_benign_not_found_in_listing(self):
|
||||
"""File listing containing 'not found' in filenames should not trigger."""
|
||||
assert _is_tool_error_or_denial("readme.md\nfile-not-found-handler.py") is False
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,19 +2,14 @@
|
||||
|
||||
This module provides the adapter layer that converts existing BaseTool implementations
|
||||
into in-process MCP tools that can be used with the Claude Agent SDK.
|
||||
|
||||
Long-running tools (``is_long_running=True``) are delegated to the non-SDK
|
||||
background infrastructure (stream_registry, Redis persistence, SSE reconnection)
|
||||
via a callback provided by the service layer. This avoids wasteful SDK polling
|
||||
and makes results survive page refreshes.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextvars import ContextVar
|
||||
from typing import Any
|
||||
|
||||
@@ -42,25 +37,22 @@ _current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
||||
# response adapter when it builds StreamToolOutputAvailable.
|
||||
_pending_tool_outputs: ContextVar[dict[str, list[str]]] = ContextVar(
|
||||
"pending_tool_outputs", default=None # type: ignore[arg-type]
|
||||
"pending_tool_outputs",
|
||||
default=None, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Callback type for delegating long-running tools to the non-SDK infrastructure.
|
||||
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
|
||||
LongRunningCallback = Callable[
|
||||
[str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]]
|
||||
]
|
||||
|
||||
# ContextVar so the service layer can inject the callback per-request.
|
||||
_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar(
|
||||
"long_running_callback", default=None
|
||||
# Event signaled whenever stash_pending_tool_output() adds a new entry.
|
||||
# Used by the streaming loop to wait for PostToolUse hooks to complete
|
||||
# instead of sleeping an arbitrary duration. The SDK fires hooks via
|
||||
# start_soon (fire-and-forget) so the next message can arrive before
|
||||
# the hook stashes its output — this event bridges that gap.
|
||||
_stash_event: ContextVar[asyncio.Event | None] = ContextVar(
|
||||
"_stash_event", default=None
|
||||
)
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
long_running_callback: LongRunningCallback | None = None,
|
||||
) -> None:
|
||||
"""Set the execution context for tool calls.
|
||||
|
||||
@@ -70,13 +62,11 @@ def set_execution_context(
|
||||
Args:
|
||||
user_id: Current user's ID.
|
||||
session: Current chat session.
|
||||
long_running_callback: Optional callback to delegate long-running tools
|
||||
to the non-SDK background infrastructure (stream_registry + Redis).
|
||||
"""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_pending_tool_outputs.set({})
|
||||
_long_running_callback.set(long_running_callback)
|
||||
_stash_event.set(asyncio.Event())
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
@@ -134,6 +124,43 @@ def stash_pending_tool_output(tool_name: str, output: Any) -> None:
|
||||
except (TypeError, ValueError):
|
||||
text = str(output)
|
||||
pending.setdefault(tool_name, []).append(text)
|
||||
# Signal any waiters that new output is available.
|
||||
event = _stash_event.get(None)
|
||||
if event is not None:
|
||||
event.set()
|
||||
|
||||
|
||||
async def wait_for_stash(timeout: float = 0.5) -> bool:
|
||||
"""Wait for a PostToolUse hook to stash tool output.
|
||||
|
||||
The SDK fires PostToolUse hooks asynchronously via ``start_soon()`` —
|
||||
the next message (AssistantMessage/ResultMessage) can arrive before the
|
||||
hook completes and stashes its output. This function bridges that gap
|
||||
by waiting on the ``_stash_event``, which is signaled by
|
||||
:func:`stash_pending_tool_output`.
|
||||
|
||||
After the event fires, callers should ``await asyncio.sleep(0)`` to
|
||||
give any remaining concurrent hooks a chance to complete.
|
||||
|
||||
Returns ``True`` if a stash signal was received, ``False`` on timeout.
|
||||
The timeout is a safety net — normally the stash happens within
|
||||
microseconds of yielding to the event loop.
|
||||
"""
|
||||
event = _stash_event.get(None)
|
||||
if event is None:
|
||||
return False
|
||||
# Fast path: hook already completed before we got here.
|
||||
if event.is_set():
|
||||
event.clear()
|
||||
return True
|
||||
# Slow path: wait for the hook to signal.
|
||||
try:
|
||||
async with asyncio.timeout(timeout):
|
||||
await event.wait()
|
||||
event.clear()
|
||||
return True
|
||||
except TimeoutError:
|
||||
return False
|
||||
|
||||
|
||||
async def _execute_tool_sync(
|
||||
@@ -229,11 +256,6 @@ def create_tool_handler(base_tool: BaseTool):
|
||||
|
||||
This wraps the existing BaseTool._execute method to be compatible
|
||||
with the Claude Agent SDK MCP tool format.
|
||||
|
||||
Long-running tools (``is_long_running=True``) are delegated to the
|
||||
non-SDK background infrastructure via a callback set in the execution
|
||||
context. The callback persists the operation in Redis (stream_registry)
|
||||
so results survive page refreshes and pod restarts.
|
||||
"""
|
||||
|
||||
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -243,25 +265,6 @@ def create_tool_handler(base_tool: BaseTool):
|
||||
if session is None:
|
||||
return _mcp_error("No session context available")
|
||||
|
||||
# --- Long-running: delegate to non-SDK background infrastructure ---
|
||||
if base_tool.is_long_running:
|
||||
callback = _long_running_callback.get(None)
|
||||
if callback:
|
||||
try:
|
||||
return await callback(base_tool.name, args, session)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Long-running callback failed for {base_tool.name}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return _mcp_error(f"Failed to start {base_tool.name}: {e}")
|
||||
# No callback — fall through to synchronous execution
|
||||
logger.warning(
|
||||
f"[SDK] No long-running callback for {base_tool.name}, "
|
||||
f"executing synchronously (may block)"
|
||||
)
|
||||
|
||||
# --- Normal (fast) tool: execute synchronously ---
|
||||
try:
|
||||
return await _execute_tool_sync(base_tool, user_id, session, args)
|
||||
except Exception as e:
|
||||
|
||||
@@ -131,17 +131,20 @@ def read_transcript_file(transcript_path: str) -> str | None:
|
||||
content = f.read()
|
||||
|
||||
if not content.strip():
|
||||
logger.debug("[Transcript] File is empty: %s", transcript_path)
|
||||
return None
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
if len(lines) < 3:
|
||||
# Raw files with ≤2 lines are metadata-only
|
||||
# (queue-operation + file-history-snapshot, no conversation).
|
||||
return None
|
||||
|
||||
# Quick structural validation — parse first and last lines.
|
||||
json.loads(lines[0])
|
||||
json.loads(lines[-1])
|
||||
# Validate that the transcript has real conversation content
|
||||
# (not just metadata like queue-operation entries).
|
||||
if not validate_transcript(content):
|
||||
logger.debug(
|
||||
"[Transcript] No conversation content (%d lines) in %s",
|
||||
len(lines),
|
||||
transcript_path,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Read {len(lines)} lines, "
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,12 +6,7 @@ import pytest
|
||||
|
||||
from . import service as chat_service
|
||||
from .model import create_chat_session, get_chat_session, upsert_chat_session
|
||||
from .response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamTextDelta,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from .response_model import StreamError, StreamTextDelta, StreamToolOutputAvailable
|
||||
from .sdk import service as sdk_service
|
||||
from .sdk.transcript import download_transcript
|
||||
|
||||
@@ -30,7 +25,6 @@ async def test_stream_chat_completion(setup_test_user, test_user_id):
|
||||
session = await create_chat_session(test_user_id)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
assistant_message = ""
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session.session_id, "Hello, how are you?", user_id=session.user_id
|
||||
@@ -40,10 +34,9 @@ async def test_stream_chat_completion(setup_test_user, test_user_id):
|
||||
has_errors = True
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
assistant_message += chunk.delta
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
# StreamFinish is published by mark_session_completed (processor layer),
|
||||
# not by the service. The generator completing means the stream ended.
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert assistant_message, "Assistant message is empty"
|
||||
|
||||
@@ -61,7 +54,6 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
had_tool_calls = False
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session.session_id,
|
||||
@@ -71,13 +63,9 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
if isinstance(chunk, StreamToolOutputAvailable):
|
||||
had_tool_calls = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert had_tool_calls, "Tool calls did not occur"
|
||||
session = await get_chat_session(session.session_id)
|
||||
@@ -114,7 +102,6 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
||||
)
|
||||
turn1_text = ""
|
||||
turn1_errors: list[str] = []
|
||||
turn1_ended = False
|
||||
|
||||
async for chunk in sdk_service.stream_chat_completion_sdk(
|
||||
session.session_id,
|
||||
@@ -125,24 +112,27 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
||||
turn1_text += chunk.delta
|
||||
elif isinstance(chunk, StreamError):
|
||||
turn1_errors.append(chunk.errorText)
|
||||
elif isinstance(chunk, StreamFinish):
|
||||
turn1_ended = True
|
||||
|
||||
assert turn1_ended, "Turn 1 did not finish"
|
||||
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
|
||||
assert turn1_text, "Turn 1 produced no text"
|
||||
|
||||
# Wait for background upload task to complete (retry up to 5s)
|
||||
# Wait for background upload task to complete (retry up to 5s).
|
||||
# The CLI may not produce a usable transcript for very short
|
||||
# conversations (only metadata entries) — this is environment-dependent
|
||||
# (CLI version, platform). When that happens, multi-turn still works
|
||||
# via conversation compression (non-resume path), but we can't test
|
||||
# the --resume round-trip.
|
||||
transcript = None
|
||||
for _ in range(10):
|
||||
await asyncio.sleep(0.5)
|
||||
transcript = await download_transcript(test_user_id, session.session_id)
|
||||
if transcript:
|
||||
break
|
||||
assert transcript, (
|
||||
"Transcript was not uploaded to bucket after turn 1 — "
|
||||
"Stop hook may not have fired or transcript was too small"
|
||||
)
|
||||
if not transcript:
|
||||
return pytest.skip(
|
||||
"CLI did not produce a usable transcript — "
|
||||
"cannot test --resume round-trip in this environment"
|
||||
)
|
||||
logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes")
|
||||
|
||||
# Reload session for turn 2
|
||||
@@ -153,7 +143,6 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
||||
turn2_msg = "What was the special keyword I asked you to remember?"
|
||||
turn2_text = ""
|
||||
turn2_errors: list[str] = []
|
||||
turn2_ended = False
|
||||
|
||||
async for chunk in sdk_service.stream_chat_completion_sdk(
|
||||
session.session_id,
|
||||
@@ -165,10 +154,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
||||
turn2_text += chunk.delta
|
||||
elif isinstance(chunk, StreamError):
|
||||
turn2_errors.append(chunk.errorText)
|
||||
elif isinstance(chunk, StreamFinish):
|
||||
turn2_ended = True
|
||||
|
||||
assert turn2_ended, "Turn 2 did not finish"
|
||||
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
|
||||
assert turn2_text, "Turn 2 produced no text"
|
||||
assert keyword in turn2_text, (
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
401
autogpt_platform/backend/backend/copilot/test_copilot_e2e.py
Normal file
401
autogpt_platform/backend/backend/copilot/test_copilot_e2e.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""End-to-end tests for Copilot streaming with dummy implementations.
|
||||
|
||||
These tests verify the complete copilot flow using dummy implementations
|
||||
for agent generator and SDK service, allowing automated testing without
|
||||
external LLM calls.
|
||||
|
||||
Enable test mode with COPILOT_TEST_MODE=true environment variable.
|
||||
|
||||
Note: StreamFinish is NOT emitted by the dummy service — it is published
|
||||
by mark_session_completed in the processor layer. These tests only cover
|
||||
the service-level streaming output (StreamStart + StreamTextDelta).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession, upsert_chat_session
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
)
|
||||
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_test_mode():
|
||||
"""Enable test mode for all tests in this module."""
|
||||
os.environ["COPILOT_TEST_MODE"] = "true"
|
||||
yield
|
||||
os.environ.pop("COPILOT_TEST_MODE", None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dummy_streaming_basic_flow():
|
||||
"""Test that dummy streaming produces correct event sequence."""
|
||||
events = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-session-basic",
|
||||
message="Hello",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Verify we got events
|
||||
assert len(events) > 0, "Should receive events"
|
||||
|
||||
# Verify StreamStart
|
||||
start_events = [e for e in events if isinstance(e, StreamStart)]
|
||||
assert len(start_events) == 1
|
||||
assert start_events[0].messageId
|
||||
assert start_events[0].sessionId
|
||||
|
||||
# Verify StreamTextDelta events
|
||||
text_events = [e for e in events if isinstance(e, StreamTextDelta)]
|
||||
assert len(text_events) > 0
|
||||
full_text = "".join(e.delta for e in text_events)
|
||||
assert len(full_text) > 0
|
||||
|
||||
# Verify order: start before text
|
||||
start_idx = events.index(start_events[0])
|
||||
first_text_idx = events.index(text_events[0]) if text_events else -1
|
||||
if first_text_idx >= 0:
|
||||
assert start_idx < first_text_idx
|
||||
|
||||
print(f"✅ Basic flow: {len(events)} events, {len(text_events)} text deltas")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_no_timeout():
|
||||
"""Test that streaming completes within reasonable time without timeout."""
|
||||
import time
|
||||
|
||||
start_time = time.monotonic()
|
||||
event_count = 0
|
||||
|
||||
async for _event in stream_chat_completion_dummy(
|
||||
session_id="test-session-timeout",
|
||||
message="count to 10",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
event_count += 1
|
||||
|
||||
elapsed = time.monotonic() - start_time
|
||||
|
||||
# Should complete in < 5 seconds (dummy has 0.1s delays between words)
|
||||
assert elapsed < 5.0, f"Streaming took {elapsed:.1f}s, expected < 5s"
|
||||
assert event_count > 0, "Should receive events"
|
||||
|
||||
print(f"✅ No timeout: completed in {elapsed:.2f}s with {event_count} events")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_event_types():
|
||||
"""Test that all expected event types are present."""
|
||||
event_types = set()
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-session-types",
|
||||
message="test",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
event_types.add(type(event).__name__)
|
||||
|
||||
# Required event types (StreamFinish is published by processor, not service)
|
||||
assert "StreamStart" in event_types, "Missing StreamStart"
|
||||
assert "StreamTextDelta" in event_types, "Missing StreamTextDelta"
|
||||
|
||||
print(f"✅ Event types: {sorted(event_types)}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_text_content():
|
||||
"""Test that streamed text is coherent and complete."""
|
||||
text_events = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-session-content",
|
||||
message="count to 3",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
if isinstance(event, StreamTextDelta):
|
||||
text_events.append(event)
|
||||
|
||||
# Verify text deltas
|
||||
assert len(text_events) > 0, "Should have text deltas"
|
||||
|
||||
# Reconstruct full text
|
||||
full_text = "".join(e.delta for e in text_events)
|
||||
assert len(full_text) > 0, "Text should not be empty"
|
||||
assert (
|
||||
"1" in full_text or "counted" in full_text.lower()
|
||||
), "Text should contain count"
|
||||
|
||||
# Verify all deltas have IDs
|
||||
for text_event in text_events:
|
||||
assert text_event.id, "Text delta must have ID"
|
||||
assert text_event.delta, "Text delta must have content"
|
||||
|
||||
print(f"✅ Text content: '{full_text}' ({len(text_events)} deltas)")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_heartbeat_timing():
|
||||
"""Test that heartbeats are sent at correct interval during long operations."""
|
||||
# This test would need a dummy that takes longer
|
||||
# For now, just verify heartbeat structure if we receive one
|
||||
heartbeats = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-session-heartbeat",
|
||||
message="test",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
if isinstance(event, StreamHeartbeat):
|
||||
heartbeats.append(event)
|
||||
|
||||
# Dummy is fast, so we might not get heartbeats
|
||||
# But if we do, verify they're valid
|
||||
if heartbeats:
|
||||
print(f"✅ Heartbeat structure verified ({len(heartbeats)} received)")
|
||||
else:
|
||||
print("✅ No heartbeats (dummy executes quickly)")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling():
|
||||
"""Test that errors are properly formatted and sent."""
|
||||
# This would require a dummy that can trigger errors
|
||||
# For now, just verify error event structure
|
||||
|
||||
error = StreamError(errorText="Test error", code="test_error")
|
||||
assert error.errorText == "Test error"
|
||||
assert error.code == "test_error"
|
||||
assert str(error.type.value) in ["error", "error"]
|
||||
|
||||
print("✅ Error structure verified")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_sessions():
|
||||
"""Test that multiple sessions can stream concurrently."""
|
||||
|
||||
async def stream_session(session_id: str) -> int:
|
||||
count = 0
|
||||
async for _event in stream_chat_completion_dummy(
|
||||
session_id=session_id,
|
||||
message="test",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
# Run 3 concurrent sessions
|
||||
results = await asyncio.gather(
|
||||
stream_session("session-1"),
|
||||
stream_session("session-2"),
|
||||
stream_session("session-3"),
|
||||
)
|
||||
|
||||
# All should complete successfully
|
||||
assert all(count > 0 for count in results), "All sessions should produce events"
|
||||
print(f"✅ Concurrent sessions: {results} events each")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.xfail(
|
||||
reason="Event loop isolation issue with DB operations in tests - needs fixture refactoring"
|
||||
)
|
||||
async def test_session_state_persistence():
|
||||
"""Test that session state is maintained across multiple messages."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
session_id = f"test-session-{uuid4()}"
|
||||
user_id = "test-user"
|
||||
|
||||
# Create session with first message
|
||||
session = ChatSession(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
messages=[
|
||||
ChatMessage(role="user", content="Hello"),
|
||||
ChatMessage(role="assistant", content="Hi there!"),
|
||||
],
|
||||
usage=[],
|
||||
started_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
# Stream second message
|
||||
events = []
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id=session_id,
|
||||
message="How are you?",
|
||||
is_user_message=True,
|
||||
user_id=user_id,
|
||||
session=session, # Pass existing session
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Verify events were produced
|
||||
assert len(events) > 0, "Should produce events for second message"
|
||||
|
||||
print(f"✅ Session persistence: {len(events)} events for second message")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_deduplication():
|
||||
"""Test that duplicate messages are filtered out."""
|
||||
|
||||
# Simulate receiving duplicate events (e.g., from reconnection)
|
||||
events = []
|
||||
|
||||
# First stream
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-dedup-1",
|
||||
message="Hello",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Count unique message IDs in StreamStart events
|
||||
start_events = [e for e in events if isinstance(e, StreamStart)]
|
||||
message_ids = [e.messageId for e in start_events]
|
||||
|
||||
# Verify all IDs are present
|
||||
assert len(message_ids) == len(set(message_ids)), "Message IDs should be unique"
|
||||
|
||||
print(f"✅ Deduplication: {len(events)} events, all unique")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_ordering():
|
||||
"""Test that events arrive in correct order."""
|
||||
events = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-ordering",
|
||||
message="Test",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Find event indices
|
||||
start_idx = next(
|
||||
(i for i, e in enumerate(events) if isinstance(e, StreamStart)), None
|
||||
)
|
||||
text_indices = [i for i, e in enumerate(events) if isinstance(e, StreamTextDelta)]
|
||||
|
||||
# Verify ordering
|
||||
assert start_idx is not None, "Should have StreamStart"
|
||||
assert start_idx == 0, "StreamStart should be first"
|
||||
|
||||
if text_indices:
|
||||
assert all(
|
||||
start_idx < i for i in text_indices
|
||||
), "Text deltas should be after start"
|
||||
|
||||
print(f"✅ Event ordering: start({start_idx}) < text deltas")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_completeness():
|
||||
"""Test that stream includes all required event types."""
|
||||
events = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-completeness",
|
||||
message="Complete stream test",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Check for required events (StreamFinish is published by processor)
|
||||
has_start = any(isinstance(e, StreamStart) for e in events)
|
||||
has_text = any(isinstance(e, StreamTextDelta) for e in events)
|
||||
|
||||
assert has_start, "Stream must include StreamStart"
|
||||
assert has_text, "Stream must include text deltas"
|
||||
|
||||
# Verify exactly one start
|
||||
start_count = sum(1 for e in events if isinstance(e, StreamStart))
|
||||
assert start_count == 1, f"Should have exactly 1 StreamStart, got {start_count}"
|
||||
|
||||
print(
|
||||
f"✅ Completeness: 1 start, {sum(1 for e in events if isinstance(e, StreamTextDelta))} text deltas"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_delta_consistency():
|
||||
"""Test that text deltas have consistent IDs and build coherent text."""
|
||||
text_events = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-consistency",
|
||||
message="Test consistency",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
if isinstance(event, StreamTextDelta):
|
||||
text_events.append(event)
|
||||
|
||||
# Verify all text deltas have IDs
|
||||
assert all(e.id for e in text_events), "All text deltas must have IDs"
|
||||
|
||||
# Verify all deltas have the same ID (same text block)
|
||||
if text_events:
|
||||
first_id = text_events[0].id
|
||||
assert all(
|
||||
e.id == first_id for e in text_events
|
||||
), "All text deltas should share the same block ID"
|
||||
|
||||
# Verify deltas build coherent text
|
||||
full_text = "".join(e.delta for e in text_events)
|
||||
assert len(full_text) > 0, "Deltas should build non-empty text"
|
||||
assert (
|
||||
full_text == full_text.strip()
|
||||
), "Text should not have leading/trailing whitespace artifacts"
|
||||
|
||||
print(
|
||||
f"✅ Consistency: {len(text_events)} deltas with ID '{text_events[0].id if text_events else 'N/A'}', text: '{full_text}'"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests directly
|
||||
|
||||
print("Running Copilot E2E tests with dummy implementations...")
|
||||
print("=" * 60)
|
||||
|
||||
asyncio.run(test_dummy_streaming_basic_flow())
|
||||
asyncio.run(test_streaming_no_timeout())
|
||||
asyncio.run(test_streaming_event_types())
|
||||
asyncio.run(test_streaming_text_content())
|
||||
asyncio.run(test_streaming_heartbeat_timing())
|
||||
asyncio.run(test_error_handling())
|
||||
asyncio.run(test_concurrent_sessions())
|
||||
asyncio.run(test_session_state_persistence())
|
||||
asyncio.run(test_message_deduplication())
|
||||
asyncio.run(test_event_ordering())
|
||||
asyncio.run(test_stream_completeness())
|
||||
asyncio.run(test_text_delta_consistency())
|
||||
|
||||
print("=" * 60)
|
||||
print("✅ All E2E tests passed!")
|
||||
@@ -10,7 +10,6 @@ from .add_understanding import AddUnderstandingTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .check_operation_status import CheckOperationStatusTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
from .edit_agent import EditAgentTool
|
||||
@@ -47,7 +46,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"run_agent": RunAgentTool(),
|
||||
"run_block": RunBlockTool(),
|
||||
"view_agent_output": AgentOutputTool(),
|
||||
"check_operation_status": CheckOperationStatusTool(),
|
||||
"search_docs": SearchDocsTool(),
|
||||
"get_doc_page": GetDocPageTool(),
|
||||
# Web fetch for safe URL retrieval
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import UTC, datetime
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from prisma.types import ProfileCreateInput
|
||||
from pydantic import SecretStr
|
||||
|
||||
@@ -31,14 +32,16 @@ def make_session(user_id: str):
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def setup_test_data():
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session")
|
||||
async def setup_test_data(server):
|
||||
"""
|
||||
Set up test data for run_agent tests:
|
||||
1. Create a test user
|
||||
2. Create a test graph (agent input -> agent output)
|
||||
3. Create a store listing and store listing version
|
||||
4. Approve the store listing version
|
||||
|
||||
Depends on ``server`` to ensure Prisma is connected.
|
||||
"""
|
||||
# 1. Create a test user
|
||||
user_data = {
|
||||
@@ -150,14 +153,16 @@ async def setup_test_data():
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def setup_llm_test_data():
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session")
|
||||
async def setup_llm_test_data(server):
|
||||
"""
|
||||
Set up test data for LLM agent tests:
|
||||
1. Create a test user
|
||||
2. Create test OpenAI credentials for the user
|
||||
3. Create a test graph with input -> LLM block -> output
|
||||
4. Create and approve a store listing
|
||||
|
||||
Depends on ``server`` to ensure Prisma is connected.
|
||||
"""
|
||||
key = getenv("OPENAI_API_KEY")
|
||||
if not key:
|
||||
@@ -315,13 +320,15 @@ async def setup_llm_test_data():
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def setup_firecrawl_test_data():
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session")
|
||||
async def setup_firecrawl_test_data(server):
|
||||
"""
|
||||
Set up test data for Firecrawl agent tests (missing credentials scenario):
|
||||
1. Create a test user (WITHOUT Firecrawl credentials)
|
||||
2. Create a test graph with input -> Firecrawl block -> output
|
||||
3. Create and approve a store listing
|
||||
|
||||
Depends on ``server`` to ensure Prisma is connected.
|
||||
"""
|
||||
# 1. Create a test user
|
||||
user_data = {
|
||||
|
||||
@@ -19,6 +19,7 @@ from .core import (
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_library_agent_by_graph_id,
|
||||
get_library_agent_by_id,
|
||||
get_library_agents_by_ids,
|
||||
get_library_agents_for_generation,
|
||||
graph_to_json,
|
||||
json_to_graph,
|
||||
@@ -49,6 +50,7 @@ __all__ = [
|
||||
"get_all_relevant_agents_for_generation",
|
||||
"get_library_agent_by_graph_id",
|
||||
"get_library_agent_by_id",
|
||||
"get_library_agents_by_ids",
|
||||
"get_library_agents_for_generation",
|
||||
"get_user_message_for_error",
|
||||
"graph_to_json",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from backend.data.db_accessors import graph_db, library_db, store_db
|
||||
@@ -78,7 +79,7 @@ AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
|
||||
|
||||
|
||||
def _to_dict_list(
|
||||
agents: list[AgentSummary] | list[dict[str, Any]] | None,
|
||||
agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Convert typed agent summaries to plain dicts for external service calls."""
|
||||
if agents is None:
|
||||
@@ -190,6 +191,36 @@ async def get_library_agent_by_id(
|
||||
get_library_agent_by_graph_id = get_library_agent_by_id
|
||||
|
||||
|
||||
async def get_library_agents_by_ids(
|
||||
user_id: str,
|
||||
agent_ids: list[str],
|
||||
) -> list[LibraryAgentSummary]:
|
||||
"""Fetch multiple library agents by their IDs.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
agent_ids: List of agent IDs (can be graph_ids or library agent IDs)
|
||||
|
||||
Returns:
|
||||
List of LibraryAgentSummary for found agents (silently skips not found)
|
||||
"""
|
||||
agents: list[LibraryAgentSummary] = []
|
||||
for agent_id in agent_ids:
|
||||
try:
|
||||
agent = await get_library_agent_by_id(user_id, agent_id)
|
||||
if agent:
|
||||
agents.append(agent)
|
||||
logger.debug(f"Fetched library agent by ID: {agent['name']}")
|
||||
else:
|
||||
logger.warning(f"Library agent not found for ID: {agent_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agent {agent_id}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Fetched {len(agents)}/{len(agent_ids)} library agents by ID")
|
||||
return agents
|
||||
|
||||
|
||||
async def get_library_agents_for_generation(
|
||||
user_id: str,
|
||||
search_query: str | None = None,
|
||||
@@ -214,10 +245,17 @@ async def get_library_agents_for_generation(
|
||||
Returns:
|
||||
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
|
||||
"""
|
||||
search_term = search_query.strip() if search_query else None
|
||||
if search_term and len(search_term) > 100:
|
||||
raise ValueError(
|
||||
f"Search query is too long ({len(search_term)} chars, max 100). "
|
||||
f"Please use a shorter, more specific search term."
|
||||
)
|
||||
|
||||
try:
|
||||
response = await library_db().list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=search_query,
|
||||
search_term=search_term,
|
||||
page=1,
|
||||
page_size=max_results,
|
||||
include_executions=True,
|
||||
@@ -271,9 +309,16 @@ async def search_marketplace_agents_for_generation(
|
||||
Returns:
|
||||
List of LibraryAgentSummary with full input/output schemas
|
||||
"""
|
||||
search_term = search_query.strip()
|
||||
if len(search_term) > 100:
|
||||
raise ValueError(
|
||||
f"Search query is too long ({len(search_term)} chars, max 100). "
|
||||
f"Please use a shorter, more specific search term."
|
||||
)
|
||||
|
||||
try:
|
||||
response = await store_db().get_store_agents(
|
||||
search_query=search_query,
|
||||
search_query=search_term,
|
||||
page=1,
|
||||
page_size=max_results,
|
||||
)
|
||||
@@ -424,7 +469,7 @@ def extract_search_terms_from_steps(
|
||||
async def enrich_library_agents_from_steps(
|
||||
user_id: str,
|
||||
decomposition_result: DecompositionResult | dict[str, Any],
|
||||
existing_agents: list[AgentSummary] | list[dict[str, Any]],
|
||||
existing_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]],
|
||||
exclude_graph_id: str | None = None,
|
||||
include_marketplace: bool = True,
|
||||
max_additional_results: int = 10,
|
||||
@@ -448,7 +493,7 @@ async def enrich_library_agents_from_steps(
|
||||
search_terms = extract_search_terms_from_steps(decomposition_result)
|
||||
|
||||
if not search_terms:
|
||||
return existing_agents
|
||||
return list(existing_agents)
|
||||
|
||||
existing_ids: set[str] = set()
|
||||
existing_names: set[str] = set()
|
||||
@@ -511,7 +556,7 @@ async def enrich_library_agents_from_steps(
|
||||
async def decompose_goal(
|
||||
description: str,
|
||||
context: str = "",
|
||||
library_agents: list[AgentSummary] | None = None,
|
||||
library_agents: Sequence[AgentSummary] | None = None,
|
||||
) -> DecompositionResult | None:
|
||||
"""Break down a goal into steps or return clarifying questions.
|
||||
|
||||
@@ -539,22 +584,16 @@ async def decompose_goal(
|
||||
|
||||
async def generate_agent(
|
||||
instructions: DecompositionResult | dict[str, Any],
|
||||
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
library_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Generate agent JSON from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams
|
||||
completion notification)
|
||||
task_id: Task ID for async processing (enables Redis Streams persistence
|
||||
and SSE delivery)
|
||||
|
||||
Returns:
|
||||
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||
Agent JSON dict, error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
@@ -562,13 +601,9 @@ async def generate_agent(
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent")
|
||||
result = await generate_agent_external(
|
||||
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
|
||||
dict(instructions), _to_dict_list(library_agents)
|
||||
)
|
||||
|
||||
# Don't modify async response
|
||||
if result and result.get("status") == "accepted":
|
||||
return result
|
||||
|
||||
if result:
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
return result
|
||||
@@ -758,9 +793,7 @@ async def get_agent_as_json(
|
||||
async def generate_agent_patch(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[AgentSummary] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
library_agents: Sequence[AgentSummary] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Update an existing agent using natural language.
|
||||
|
||||
@@ -773,12 +806,10 @@ async def generate_agent_patch(
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||
error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
@@ -789,8 +820,6 @@ async def generate_agent_patch(
|
||||
update_request,
|
||||
current_agent,
|
||||
_to_dict_list(library_agents),
|
||||
operation_id,
|
||||
task_id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -102,10 +102,15 @@ async def generate_agent_dummy(
|
||||
instructions: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return dummy agent JSON after a simulated delay."""
|
||||
logger.info("Using dummy agent generator for generate_agent (30s delay)")
|
||||
"""Return dummy agent synchronously (blocks for 30s, returns agent JSON).
|
||||
|
||||
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
|
||||
"""
|
||||
logger.info(
|
||||
"Using dummy agent generator (sync mode): returning agent JSON after 30s"
|
||||
)
|
||||
await asyncio.sleep(30)
|
||||
return _generate_dummy_agent_json()
|
||||
|
||||
@@ -115,10 +120,16 @@ async def generate_agent_patch_dummy(
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return dummy patched agent (returns the current agent with updated description)."""
|
||||
logger.info("Using dummy agent generator for generate_agent_patch")
|
||||
"""Return dummy patched agent synchronously (blocks for 30s, returns patched agent JSON).
|
||||
|
||||
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
|
||||
"""
|
||||
logger.info(
|
||||
"Using dummy agent generator patch (sync mode): returning patched agent after 30s"
|
||||
)
|
||||
await asyncio.sleep(30)
|
||||
patched = current_agent.copy()
|
||||
patched["description"] = (
|
||||
f"{current_agent.get('description', '')} (updated: {update_request})"
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
"""External Agent Generator service client.
|
||||
|
||||
This module provides a client for communicating with the external Agent Generator
|
||||
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
|
||||
will delegate to the external service instead of using the built-in LLM-based implementation.
|
||||
microservice. All generation endpoints use async polling: submit a job (202),
|
||||
then poll GET /api/jobs/{job_id} every few seconds until the result is ready.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
@@ -25,22 +27,21 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
_dummy_mode_warned = False
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
POLL_INTERVAL_SECONDS = 10.0
|
||||
MAX_POLL_TIME_SECONDS = 1800.0 # 30 minutes
|
||||
MAX_CONSECUTIVE_POLL_ERRORS = 5
|
||||
|
||||
|
||||
def _create_error_response(
|
||||
error_message: str,
|
||||
error_type: str = "unknown",
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a standardized error response dict.
|
||||
|
||||
Args:
|
||||
error_message: Human-readable error message
|
||||
error_type: Machine-readable error type
|
||||
details: Optional additional error details
|
||||
|
||||
Returns:
|
||||
Error dict with type="error" and error details
|
||||
"""
|
||||
"""Create a standardized error response dict."""
|
||||
response: dict[str, Any] = {
|
||||
"type": "error",
|
||||
"error": error_message,
|
||||
@@ -52,14 +53,7 @@ def _create_error_response(
|
||||
|
||||
|
||||
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
|
||||
"""Classify an HTTP error into error_type and message.
|
||||
|
||||
Args:
|
||||
e: The HTTP status error
|
||||
|
||||
Returns:
|
||||
Tuple of (error_type, error_message)
|
||||
"""
|
||||
"""Classify an HTTP error into error_type and message."""
|
||||
status = e.response.status_code
|
||||
if status == 429:
|
||||
return "rate_limit", f"Agent Generator rate limited: {e}"
|
||||
@@ -72,14 +66,7 @@ def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
|
||||
|
||||
|
||||
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
|
||||
"""Classify a request error into error_type and message.
|
||||
|
||||
Args:
|
||||
e: The request error
|
||||
|
||||
Returns:
|
||||
Tuple of (error_type, error_message)
|
||||
"""
|
||||
"""Classify a request error into error_type and message."""
|
||||
error_str = str(e).lower()
|
||||
if "timeout" in error_str or "timed out" in error_str:
|
||||
return "timeout", f"Agent Generator request timed out: {e}"
|
||||
@@ -89,6 +76,10 @@ def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
|
||||
return "request_error", f"Request error calling Agent Generator: {e}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Client / settings singletons
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_client: httpx.AsyncClient | None = None
|
||||
_settings: Settings | None = None
|
||||
|
||||
@@ -136,13 +127,149 @@ def _get_client() -> httpx.AsyncClient:
|
||||
global _client
|
||||
if _client is None:
|
||||
settings = _get_settings()
|
||||
timeout = httpx.Timeout(float(settings.config.agentgenerator_timeout))
|
||||
_client = httpx.AsyncClient(
|
||||
base_url=_get_base_url(),
|
||||
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
|
||||
timeout=timeout,
|
||||
)
|
||||
return _client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core polling helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _submit_and_poll(
|
||||
endpoint: str,
|
||||
payload: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Submit a job to the agent-generator and poll until the result is ready.
|
||||
|
||||
The endpoint is expected to return 202 with ``{"job_id": "..."}`` on success.
|
||||
We then poll ``GET /api/jobs/{job_id}`` every ``POLL_INTERVAL_SECONDS``
|
||||
until the job completes or fails.
|
||||
|
||||
Returns:
|
||||
The *result* dict from a completed job, or an error dict.
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
# 1. Submit ----------------------------------------------------------------
|
||||
try:
|
||||
response = await client.post(endpoint, json=payload)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
data = response.json()
|
||||
job_id = data.get("job_id")
|
||||
if not job_id:
|
||||
return _create_error_response(
|
||||
"Agent Generator did not return a job_id", "invalid_response"
|
||||
)
|
||||
|
||||
logger.info(f"Agent Generator job submitted: {job_id} via {endpoint}")
|
||||
|
||||
# 2. Poll ------------------------------------------------------------------
|
||||
start = time.monotonic()
|
||||
consecutive_errors = 0
|
||||
while (time.monotonic() - start) < MAX_POLL_TIME_SECONDS:
|
||||
await asyncio.sleep(POLL_INTERVAL_SECONDS)
|
||||
|
||||
try:
|
||||
poll_resp = await client.get(f"/api/jobs/{job_id}")
|
||||
poll_resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 404:
|
||||
return _create_error_response(
|
||||
"Agent Generator job not found or expired", "job_not_found"
|
||||
)
|
||||
status_code = e.response.status_code
|
||||
if status_code in {429, 503, 504, 408}:
|
||||
consecutive_errors += 1
|
||||
logger.warning(
|
||||
f"Transient HTTP {status_code} polling job {job_id} "
|
||||
f"({consecutive_errors}/{MAX_CONSECUTIVE_POLL_ERRORS}): {e}"
|
||||
)
|
||||
if consecutive_errors >= MAX_CONSECUTIVE_POLL_ERRORS:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(
|
||||
f"Giving up on job {job_id} after "
|
||||
f"{MAX_CONSECUTIVE_POLL_ERRORS} consecutive poll errors: {error_msg}"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
continue
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(f"Poll error for job {job_id}: {error_msg}")
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
consecutive_errors += 1
|
||||
logger.warning(
|
||||
f"Transient poll error for job {job_id} "
|
||||
f"({consecutive_errors}/{MAX_CONSECUTIVE_POLL_ERRORS}): {e}"
|
||||
)
|
||||
if consecutive_errors >= MAX_CONSECUTIVE_POLL_ERRORS:
|
||||
error_msg = (
|
||||
f"Giving up on job {job_id} after "
|
||||
f"{MAX_CONSECUTIVE_POLL_ERRORS} consecutive poll errors: {e}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "poll_error")
|
||||
continue
|
||||
|
||||
consecutive_errors = 0
|
||||
poll_data = poll_resp.json()
|
||||
status = poll_data.get("status")
|
||||
|
||||
if status == "completed":
|
||||
logger.info(f"Agent Generator job {job_id} completed")
|
||||
result = poll_data.get("result", {})
|
||||
if not isinstance(result, dict):
|
||||
return _create_error_response(
|
||||
"Agent Generator returned invalid result payload",
|
||||
"invalid_response",
|
||||
)
|
||||
return result
|
||||
elif status == "failed":
|
||||
error_msg = poll_data.get("error", "Job failed")
|
||||
logger.error(f"Agent Generator job {job_id} failed: {error_msg}")
|
||||
return _create_error_response(error_msg, "job_failed")
|
||||
elif status in {"running", "pending", "queued"}:
|
||||
continue
|
||||
else:
|
||||
return _create_error_response(
|
||||
f"Agent Generator returned unexpected job status: {status}",
|
||||
"invalid_response",
|
||||
)
|
||||
|
||||
return _create_error_response("Agent generation timed out after polling", "timeout")
|
||||
|
||||
|
||||
def _extract_agent_json(result: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract and validate agent_json from a job result.
|
||||
|
||||
Returns the agent_json dict, or an error response if missing/invalid.
|
||||
"""
|
||||
agent_json = result.get("agent_json")
|
||||
if not isinstance(agent_json, dict):
|
||||
return _create_error_response(
|
||||
"Agent Generator returned no agent_json in result", "invalid_response"
|
||||
)
|
||||
return agent_json
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public functions — same signatures as before, now using polling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def decompose_goal_external(
|
||||
description: str,
|
||||
context: str = "",
|
||||
@@ -150,25 +277,17 @@ async def decompose_goal_external(
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to decompose a goal.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
Returns one of the following dicts (keyed by ``"type"``):
|
||||
|
||||
Returns:
|
||||
Dict with either:
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
- {"type": "unachievable_goal", ...}
|
||||
- {"type": "vague_goal", ...}
|
||||
- {"type": "error", "error": "...", "error_type": "..."} on error
|
||||
Or None on unexpected error
|
||||
* ``{"type": "instructions", "steps": [...]}``
|
||||
* ``{"type": "clarifying_questions", "questions": [...]}``
|
||||
* ``{"type": "unachievable_goal", "reason": ..., "suggested_goal": ...}``
|
||||
* ``{"type": "vague_goal", "suggested_goal": ...}``
|
||||
* ``{"type": "error", "error": ..., "error_type": ...}``
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await decompose_goal_dummy(description, context, library_agents)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
if context:
|
||||
description = f"{description}\n\nAdditional context from user:\n{context}"
|
||||
|
||||
@@ -177,236 +296,113 @@ async def decompose_goal_external(
|
||||
payload["library_agents"] = library_agents
|
||||
|
||||
try:
|
||||
response = await client.post("/api/decompose-description", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator decomposition failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Map the response to the expected format
|
||||
response_type = data.get("type")
|
||||
if response_type == "instructions":
|
||||
return {"type": "instructions", "steps": data.get("steps", [])}
|
||||
elif response_type == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
elif response_type == "unachievable_goal":
|
||||
return {
|
||||
"type": "unachievable_goal",
|
||||
"reason": data.get("reason"),
|
||||
"suggested_goal": data.get("suggested_goal"),
|
||||
}
|
||||
elif response_type == "vague_goal":
|
||||
return {
|
||||
"type": "vague_goal",
|
||||
"suggested_goal": data.get("suggested_goal"),
|
||||
}
|
||||
elif response_type == "error":
|
||||
# Pass through error from the service
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Unknown response type from external service: {response_type}"
|
||||
)
|
||||
return _create_error_response(
|
||||
f"Unknown response type from Agent Generator: {response_type}",
|
||||
"invalid_response",
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
result = await _submit_and_poll("/api/decompose-description", payload)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
# The result dict from the job is already in the expected format
|
||||
# (type, steps, questions, etc.) — just return it as-is.
|
||||
if result.get("type") == "error":
|
||||
return result
|
||||
|
||||
response_type = result.get("type")
|
||||
if response_type == "instructions":
|
||||
return {"type": "instructions", "steps": result.get("steps", [])}
|
||||
elif response_type == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": result.get("questions", []),
|
||||
}
|
||||
elif response_type == "unachievable_goal":
|
||||
return {
|
||||
"type": "unachievable_goal",
|
||||
"reason": result.get("reason"),
|
||||
"suggested_goal": result.get("suggested_goal"),
|
||||
}
|
||||
elif response_type == "vague_goal":
|
||||
return {
|
||||
"type": "vague_goal",
|
||||
"suggested_goal": result.get("suggested_goal"),
|
||||
}
|
||||
else:
|
||||
logger.error(f"Unknown response type from Agent Generator job: {response_type}")
|
||||
return _create_error_response(
|
||||
f"Unknown response type: {response_type}",
|
||||
"invalid_response",
|
||||
)
|
||||
|
||||
|
||||
async def generate_agent_external(
|
||||
instructions: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate an agent from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
|
||||
Agent JSON dict or error dict {"type": "error", ...} on error.
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await generate_agent_dummy(
|
||||
instructions, library_agents, operation_id, task_id
|
||||
)
|
||||
return await generate_agent_dummy(instructions, library_agents)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {"instructions": instructions}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
if operation_id and task_id:
|
||||
payload["operation_id"] = operation_id
|
||||
payload["task_id"] = task_id
|
||||
|
||||
try:
|
||||
response = await client.post("/api/generate-agent", json=payload)
|
||||
|
||||
# Handle 202 Accepted for async processing
|
||||
if response.status_code == 202:
|
||||
logger.info(
|
||||
f"Agent Generator accepted async request "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return {
|
||||
"status": "accepted",
|
||||
"operation_id": operation_id,
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
result = await _submit_and_poll("/api/generate-agent", payload)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
if result.get("type") == "error":
|
||||
return result
|
||||
|
||||
return _extract_agent_json(result)
|
||||
|
||||
|
||||
async def generate_agent_patch_external(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate a patch for an existing agent.
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
||||
Updated agent JSON, clarifying questions dict, or error dict.
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await generate_agent_patch_dummy(
|
||||
update_request, current_agent, library_agents, operation_id, task_id
|
||||
update_request, current_agent, library_agents
|
||||
)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {
|
||||
"update_request": update_request,
|
||||
"current_agent_json": current_agent,
|
||||
}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
if operation_id and task_id:
|
||||
payload["operation_id"] = operation_id
|
||||
payload["task_id"] = task_id
|
||||
|
||||
try:
|
||||
response = await client.post("/api/update-agent", json=payload)
|
||||
|
||||
# Handle 202 Accepted for async processing
|
||||
if response.status_code == 202:
|
||||
logger.info(
|
||||
f"Agent Generator accepted async update request "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return {
|
||||
"status": "accepted",
|
||||
"operation_id": operation_id,
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator patch generation failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Check if it's clarifying questions
|
||||
if data.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
|
||||
# Check if it's an error passed through
|
||||
if data.get("type") == "error":
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
|
||||
# Otherwise return the updated agent JSON
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
result = await _submit_and_poll("/api/update-agent", payload)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
if result.get("type") == "error":
|
||||
return result
|
||||
|
||||
if result.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": result.get("questions", []),
|
||||
}
|
||||
|
||||
return _extract_agent_json(result)
|
||||
|
||||
|
||||
async def customize_template_external(
|
||||
template_agent: dict[str, Any],
|
||||
@@ -415,81 +411,51 @@ async def customize_template_external(
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to customize a template/marketplace agent.
|
||||
|
||||
Args:
|
||||
template_agent: The template agent JSON to customize
|
||||
modification_request: Natural language description of customizations
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
|
||||
Returns:
|
||||
Customized agent JSON, clarifying questions dict, or error dict on error
|
||||
Customized agent JSON, clarifying questions dict, or error dict.
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await customize_template_dummy(
|
||||
template_agent, modification_request, context
|
||||
)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
request = modification_request
|
||||
request_text = modification_request
|
||||
if context:
|
||||
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
|
||||
request_text = (
|
||||
f"{modification_request}\n\nAdditional context from user:\n{context}"
|
||||
)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"template_agent_json": template_agent,
|
||||
"modification_request": request,
|
||||
"modification_request": request_text,
|
||||
}
|
||||
|
||||
try:
|
||||
response = await client.post("/api/template-modification", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator template customization failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Check if it's clarifying questions
|
||||
if data.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
|
||||
# Check if it's an error passed through
|
||||
if data.get("type") == "error":
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
|
||||
# Otherwise return the customized agent JSON
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
result = await _submit_and_poll("/api/template-modification", payload)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
if result.get("type") == "error":
|
||||
return result
|
||||
|
||||
if result.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": result.get("questions", []),
|
||||
}
|
||||
|
||||
return _extract_agent_json(result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-generation endpoints (still synchronous — quick responses)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||
"""Get available blocks from the external service.
|
||||
|
||||
Returns:
|
||||
List of block info dicts or None on error
|
||||
"""
|
||||
"""Get available blocks from the external service."""
|
||||
if _is_dummy_mode():
|
||||
return await get_blocks_dummy()
|
||||
|
||||
@@ -518,11 +484,7 @@ async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||
|
||||
|
||||
async def health_check() -> bool:
|
||||
"""Check if the external service is healthy.
|
||||
|
||||
Returns:
|
||||
True if healthy, False otherwise
|
||||
"""
|
||||
"""Check if the external service is healthy."""
|
||||
if not is_external_service_configured():
|
||||
return False
|
||||
|
||||
|
||||
@@ -36,16 +36,6 @@ class BaseTool:
|
||||
"""Whether this tool requires authentication."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
"""Whether this tool is long-running and should execute in background.
|
||||
|
||||
Long-running tools (like agent generation) are executed via background
|
||||
tasks to survive SSE disconnections. The result is persisted to chat
|
||||
history and visible when the user refreshes.
|
||||
"""
|
||||
return False
|
||||
|
||||
def as_openai_tool(self) -> ChatCompletionToolParam:
|
||||
"""Convert to OpenAI tool format."""
|
||||
return ChatCompletionToolParam(
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
"""CheckOperationStatusTool — query the status of a long-running operation."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OperationStatusResponse(ToolResponseBase):
|
||||
"""Response for check_operation_status tool."""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STATUS
|
||||
task_id: str
|
||||
operation_id: str
|
||||
status: str # "running", "completed", "failed"
|
||||
tool_name: str | None = None
|
||||
message: str = ""
|
||||
|
||||
|
||||
class CheckOperationStatusTool(BaseTool):
|
||||
"""Check the status of a long-running operation (create_agent, edit_agent, etc.).
|
||||
|
||||
The CoPilot uses this tool to report back to the user whether an
|
||||
operation that was started earlier has completed, failed, or is still
|
||||
running.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "check_operation_status"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Check the current status of a long-running operation such as "
|
||||
"create_agent or edit_agent. Accepts either an operation_id or "
|
||||
"task_id from a previous operation_started response. "
|
||||
"Returns the current status: running, completed, or failed."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"operation_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The operation_id from an operation_started response."
|
||||
),
|
||||
},
|
||||
"task_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The task_id from an operation_started response. "
|
||||
"Used as fallback if operation_id is not provided."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
from backend.copilot import stream_registry
|
||||
|
||||
operation_id = (kwargs.get("operation_id") or "").strip()
|
||||
task_id = (kwargs.get("task_id") or "").strip()
|
||||
|
||||
if not operation_id and not task_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide an operation_id or task_id.",
|
||||
error="missing_parameter",
|
||||
)
|
||||
|
||||
task = None
|
||||
if operation_id:
|
||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
||||
if task is None and task_id:
|
||||
task = await stream_registry.get_task(task_id)
|
||||
|
||||
if task is None:
|
||||
# Task not in Redis — it may have already expired (TTL).
|
||||
# Check conversation history for the result instead.
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Operation not found — it may have already completed and "
|
||||
"expired from the status tracker. Check the conversation "
|
||||
"history for the result."
|
||||
),
|
||||
error="not_found",
|
||||
)
|
||||
|
||||
status_messages = {
|
||||
"running": (
|
||||
f"The {task.tool_name or 'operation'} is still running. "
|
||||
"Please wait for it to complete."
|
||||
),
|
||||
"completed": (
|
||||
f"The {task.tool_name or 'operation'} has completed successfully."
|
||||
),
|
||||
"failed": f"The {task.tool_name or 'operation'} has failed.",
|
||||
}
|
||||
|
||||
return OperationStatusResponse(
|
||||
task_id=task.task_id,
|
||||
operation_id=task.operation_id,
|
||||
status=task.status,
|
||||
tool_name=task.tool_name,
|
||||
message=status_messages.get(task.status, f"Status: {task.status}"),
|
||||
)
|
||||
@@ -10,7 +10,6 @@ from .agent_generator import (
|
||||
decompose_goal,
|
||||
enrich_library_agents_from_steps,
|
||||
generate_agent,
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_user_message_for_error,
|
||||
save_agent_to_library,
|
||||
)
|
||||
@@ -18,7 +17,6 @@ from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AsyncProcessingResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
@@ -40,17 +38,16 @@ class CreateAgentTool(BaseTool):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Create a new agent workflow from a natural language description. "
|
||||
"First generates a preview, then saves to library if save=true."
|
||||
"First generates a preview, then saves to library if save=true. "
|
||||
"\n\nIMPORTANT: Before calling this tool, search for relevant existing agents "
|
||||
"using find_library_agent that could be used as building blocks. "
|
||||
"Pass their IDs in the library_agent_ids parameter so the generator can compose them."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
@@ -70,6 +67,15 @@ class CreateAgentTool(BaseTool):
|
||||
"Include any preferences or constraints mentioned by the user."
|
||||
),
|
||||
},
|
||||
"library_agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"List of library agent IDs to use as building blocks. "
|
||||
"Search for relevant agents using find_library_agent first, "
|
||||
"then pass their IDs here so they can be composed into the new agent."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
@@ -97,12 +103,14 @@ class CreateAgentTool(BaseTool):
|
||||
"""
|
||||
description = kwargs.get("description", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] START - description_len={len(description)}, "
|
||||
f"library_agent_ids={library_agent_ids}, save={save}, user_id={user_id}, session_id={session_id}"
|
||||
)
|
||||
|
||||
if not description:
|
||||
return ErrorResponse(
|
||||
@@ -111,25 +119,34 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Fetch library agents by IDs if provided
|
||||
library_agents = None
|
||||
if user_id:
|
||||
if user_id and library_agent_ids:
|
||||
try:
|
||||
library_agents = await get_all_relevant_agents_for_generation(
|
||||
from .agent_generator import get_library_agents_by_ids
|
||||
|
||||
library_agents = await get_library_agents_by_ids(
|
||||
user_id=user_id,
|
||||
search_query=description,
|
||||
include_marketplace=True,
|
||||
agent_ids=library_agent_ids,
|
||||
)
|
||||
logger.debug(
|
||||
f"Found {len(library_agents)} relevant agents for sub-agent composition"
|
||||
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents: {e}")
|
||||
logger.warning(f"Failed to fetch library agents by IDs: {e}")
|
||||
|
||||
try:
|
||||
decomposition_result = await decompose_goal(
|
||||
description, context, library_agents
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] DECOMPOSE - type={decomposition_result.get('type') if decomposition_result else None}, "
|
||||
f"session_id={session_id}"
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
logger.error(
|
||||
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured, session_id={session_id}"
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent generation is not available. "
|
||||
@@ -230,10 +247,17 @@ class CreateAgentTool(BaseTool):
|
||||
agent_json = await generate_agent(
|
||||
decomposition_result,
|
||||
library_agents,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] GENERATE - "
|
||||
f"success={agent_json is not None}, "
|
||||
f"is_error={isinstance(agent_json, dict) and agent_json.get('type') == 'error'}, "
|
||||
f"session_id={session_id}"
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
logger.error(
|
||||
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured during generation, session_id={session_id}"
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent generation is not available. "
|
||||
@@ -276,25 +300,20 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if Agent Generator accepted for async processing
|
||||
if agent_json.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Agent generation delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return AsyncProcessingResponse(
|
||||
message="Agent generation started. You'll be notified when it's complete.",
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", "Generated Agent")
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
link_count = len(agent_json.get("links", []))
|
||||
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] AGENT_JSON - name={agent_name}, "
|
||||
f"nodes={node_count}, links={link_count}, save={save}, session_id={session_id}"
|
||||
)
|
||||
|
||||
if not save:
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] RETURN - AgentPreviewResponse, session_id={session_id}"
|
||||
)
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
|
||||
@@ -320,6 +339,13 @@ class CreateAgentTool(BaseTool):
|
||||
agent_json, user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] SAVED - graph_id={created_graph.id}, "
|
||||
f"library_agent_id={library_agent.id}, session_id={session_id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] RETURN - AgentSavedResponse, session_id={session_id}"
|
||||
)
|
||||
return AgentSavedResponse(
|
||||
message=f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
agent_id=created_graph.id,
|
||||
@@ -330,6 +356,12 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[AGENT_CREATE_DEBUG] ERROR - save_failed: {str(e)}, session_id={session_id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] RETURN - ErrorResponse (save_failed), session_id={session_id}"
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the agent: {str(e)}",
|
||||
error="save_failed",
|
||||
|
||||
@@ -43,11 +43,6 @@ async def test_vague_goal_returns_suggested_goal_response(tool, session):
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
@@ -78,11 +73,6 @@ async def test_unachievable_goal_returns_suggested_goal_response(tool, session):
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
@@ -120,11 +110,6 @@ async def test_clarifying_questions_returns_clarification_needed_response(
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
|
||||
@@ -46,10 +46,6 @@ class CustomizeAgentTool(BaseTool):
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
||||
@@ -9,7 +9,6 @@ from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_user_message_for_error,
|
||||
save_agent_to_library,
|
||||
)
|
||||
@@ -17,7 +16,6 @@ from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AsyncProcessingResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
@@ -38,17 +36,16 @@ class EditAgentTool(BaseTool):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit an existing agent from the user's library using natural language. "
|
||||
"Generates updates to the agent while preserving unchanged parts."
|
||||
"Generates updates to the agent while preserving unchanged parts. "
|
||||
"\n\nIMPORTANT: Before calling this tool, if the changes involve adding new "
|
||||
"functionality, search for relevant existing agents using find_library_agent "
|
||||
"that could be used as building blocks. Pass their IDs in library_agent_ids."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
@@ -74,6 +71,15 @@ class EditAgentTool(BaseTool):
|
||||
"Additional context or answers to previous clarifying questions."
|
||||
),
|
||||
},
|
||||
"library_agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"List of library agent IDs to use as building blocks for the changes. "
|
||||
"If adding new functionality, search for relevant agents using "
|
||||
"find_library_agent first, then pass their IDs here."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
@@ -102,13 +108,10 @@ class EditAgentTool(BaseTool):
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
changes = kwargs.get("changes", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the agent ID to edit.",
|
||||
@@ -132,21 +135,25 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Fetch library agents by IDs if provided
|
||||
library_agents = None
|
||||
if user_id:
|
||||
if user_id and library_agent_ids:
|
||||
try:
|
||||
from .agent_generator import get_library_agents_by_ids
|
||||
|
||||
graph_id = current_agent.get("id")
|
||||
library_agents = await get_all_relevant_agents_for_generation(
|
||||
# Filter out the current agent being edited
|
||||
filtered_ids = [id for id in library_agent_ids if id != graph_id]
|
||||
|
||||
library_agents = await get_library_agents_by_ids(
|
||||
user_id=user_id,
|
||||
search_query=changes,
|
||||
exclude_graph_id=graph_id,
|
||||
include_marketplace=True,
|
||||
agent_ids=filtered_ids,
|
||||
)
|
||||
logger.debug(
|
||||
f"Found {len(library_agents)} relevant agents for sub-agent composition"
|
||||
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents: {e}")
|
||||
logger.warning(f"Failed to fetch library agents by IDs: {e}")
|
||||
|
||||
update_request = changes
|
||||
if context:
|
||||
@@ -157,8 +164,6 @@ class EditAgentTool(BaseTool):
|
||||
update_request,
|
||||
current_agent,
|
||||
library_agents,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
@@ -178,19 +183,6 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if Agent Generator accepted for async processing
|
||||
if result.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Agent edit delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return AsyncProcessingResponse(
|
||||
message="Agent edit started. You'll be notified when it's complete.",
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
|
||||
@@ -366,12 +366,15 @@ class TestFindBlockFiltering:
|
||||
return_value=(search_results, len(search_results))
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.find_block.search",
|
||||
return_value=mock_search_db,
|
||||
), patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
side_effect=lambda bid: mock_blocks.get(bid),
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.search",
|
||||
return_value=mock_search_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
side_effect=lambda bid: mock_blocks.get(bid),
|
||||
),
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
|
||||
@@ -36,8 +36,6 @@ class ResponseType(str, Enum):
|
||||
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
|
||||
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
|
||||
# Long-running operation types
|
||||
OPERATION_STARTED = "operation_started"
|
||||
OPERATION_PENDING = "operation_pending"
|
||||
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||
# Input validation
|
||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||
@@ -45,8 +43,6 @@ class ResponseType(str, Enum):
|
||||
WEB_FETCH = "web_fetch"
|
||||
# Code execution
|
||||
BASH_EXEC = "bash_exec"
|
||||
# Operation status check
|
||||
OPERATION_STATUS = "operation_status"
|
||||
# Feature request types
|
||||
FEATURE_REQUEST_SEARCH = "feature_request_search"
|
||||
FEATURE_REQUEST_CREATED = "feature_request_created"
|
||||
@@ -420,34 +416,6 @@ class BlockOutputResponse(ToolResponseBase):
|
||||
|
||||
|
||||
# Long-running operation models
|
||||
class OperationStartedResponse(ToolResponseBase):
|
||||
"""Response when a long-running operation has been started in the background.
|
||||
|
||||
This is returned immediately to the client while the operation continues
|
||||
to execute. The user can close the tab and check back later.
|
||||
|
||||
The task_id can be used to reconnect to the SSE stream via
|
||||
GET /chat/tasks/{task_id}/stream?last_idx=0
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
operation_id: str
|
||||
tool_name: str
|
||||
task_id: str | None = None # For SSE reconnection
|
||||
|
||||
|
||||
class OperationPendingResponse(ToolResponseBase):
|
||||
"""Response stored in chat history while a long-running operation is executing.
|
||||
|
||||
This is persisted to the database so users see a pending state when they
|
||||
refresh before the operation completes.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_PENDING
|
||||
operation_id: str
|
||||
tool_name: str
|
||||
|
||||
|
||||
class OperationInProgressResponse(ToolResponseBase):
|
||||
"""Response when an operation is already in progress.
|
||||
|
||||
@@ -459,23 +427,6 @@ class OperationInProgressResponse(ToolResponseBase):
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class AsyncProcessingResponse(ToolResponseBase):
|
||||
"""Response when an operation has been delegated to async processing.
|
||||
|
||||
This is returned by tools when the external service accepts the request
|
||||
for async processing (HTTP 202 Accepted). The Redis Streams completion
|
||||
consumer will handle the result when the external service completes.
|
||||
|
||||
The status field is specifically "accepted" to allow the long-running tool
|
||||
handler to detect this response and skip LLM continuation.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
status: str = "accepted" # Must be "accepted" for detection
|
||||
operation_id: str | None = None
|
||||
task_id: str | None = None
|
||||
|
||||
|
||||
class WebFetchResponse(ToolResponseBase):
|
||||
"""Response for web_fetch tool."""
|
||||
|
||||
|
||||
@@ -160,9 +160,10 @@ class RunBlockTool(BaseTool):
|
||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
matched_credentials, missing_credentials = (
|
||||
await self._resolve_block_credentials(user_id, block, input_data)
|
||||
)
|
||||
(
|
||||
matched_credentials,
|
||||
missing_credentials,
|
||||
) = await self._resolve_block_credentials(user_id, block, input_data)
|
||||
|
||||
# Get block schemas for details/validation
|
||||
try:
|
||||
|
||||
@@ -13,6 +13,7 @@ import logging
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import signal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -245,6 +246,7 @@ async def run_sandboxed(
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
env=safe_env,
|
||||
start_new_session=True, # Own process group for clean kill
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -255,7 +257,18 @@ async def run_sandboxed(
|
||||
stderr = stderr_bytes.decode("utf-8", errors="replace")
|
||||
return stdout, stderr, proc.returncode or 0, False
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
# Kill entire process group (bwrap + all children).
|
||||
# proc.kill() alone only kills the bwrap parent, leaving
|
||||
# children running until they finish naturally.
|
||||
try:
|
||||
os.killpg(proc.pid, signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
pass # Already exited
|
||||
except OSError as kill_err:
|
||||
logger.warning(
|
||||
"Failed to kill process group %d: %s", proc.pid, kill_err
|
||||
)
|
||||
# Always reap the subprocess regardless of killpg outcome.
|
||||
await proc.communicate()
|
||||
return "", f"Execution timed out after {timeout}s", -1, True
|
||||
|
||||
|
||||
@@ -2,11 +2,13 @@
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.sandbox import make_session_path
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
@@ -18,6 +20,151 @@ from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_write_content(
|
||||
content_text: str | None,
|
||||
content_b64: str | None,
|
||||
source_path: str | None,
|
||||
session_id: str,
|
||||
) -> bytes | ErrorResponse:
|
||||
"""Resolve file content from exactly one of three input sources.
|
||||
|
||||
Returns the raw bytes on success, or an ``ErrorResponse`` on validation
|
||||
failure (wrong number of sources, invalid path, file not found, etc.).
|
||||
"""
|
||||
# Normalise empty strings to None so counting and dispatch stay in sync.
|
||||
if content_text is not None and content_text == "":
|
||||
content_text = None
|
||||
if content_b64 is not None and content_b64 == "":
|
||||
content_b64 = None
|
||||
if source_path is not None and source_path == "":
|
||||
source_path = None
|
||||
|
||||
sources_provided = sum(
|
||||
x is not None for x in [content_text, content_b64, source_path]
|
||||
)
|
||||
if sources_provided == 0:
|
||||
return ErrorResponse(
|
||||
message="Please provide one of: content, content_base64, or source_path",
|
||||
session_id=session_id,
|
||||
)
|
||||
if sources_provided > 1:
|
||||
return ErrorResponse(
|
||||
message="Provide only one of: content, content_base64, or source_path",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if source_path is not None:
|
||||
validated = _validate_ephemeral_path(
|
||||
source_path, param_name="source_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated, ErrorResponse):
|
||||
return validated
|
||||
try:
|
||||
with open(validated, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=f"Source file not found: {source_path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read source file: {e}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if content_b64 is not None:
|
||||
try:
|
||||
return base64.b64decode(content_b64)
|
||||
except Exception:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Invalid base64 encoding in content_base64. "
|
||||
"Please encode the file content with standard base64, "
|
||||
"or use the 'content' parameter for plain text, "
|
||||
"or 'source_path' to copy from the working directory."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
assert content_text is not None
|
||||
return content_text.encode("utf-8")
|
||||
|
||||
|
||||
def _validate_ephemeral_path(
|
||||
path: str, *, param_name: str, session_id: str
|
||||
) -> ErrorResponse | str:
|
||||
"""Validate that *path* is inside the session's ephemeral directory.
|
||||
|
||||
Uses the session-specific directory (``make_session_path(session_id)``)
|
||||
rather than the bare prefix, so ``/tmp/copilot-evil/...`` is rejected.
|
||||
|
||||
Returns the resolved real path on success, or an ``ErrorResponse`` when the
|
||||
path escapes the session directory.
|
||||
"""
|
||||
session_dir = os.path.realpath(make_session_path(session_id)) + os.sep
|
||||
real = os.path.realpath(path)
|
||||
if not real.startswith(session_dir):
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"{param_name} must be within the ephemeral working "
|
||||
f"directory ({make_session_path(session_id)})"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
return real
|
||||
|
||||
|
||||
_TEXT_MIME_PREFIXES = (
|
||||
"text/",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"application/javascript",
|
||||
"application/x-python",
|
||||
"application/x-sh",
|
||||
)
|
||||
|
||||
_IMAGE_MIME_TYPES = {"image/png", "image/jpeg", "image/gif", "image/webp"}
|
||||
|
||||
|
||||
def _is_text_mime(mime_type: str) -> bool:
|
||||
return any(mime_type.startswith(t) for t in _TEXT_MIME_PREFIXES)
|
||||
|
||||
|
||||
async def _get_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped WorkspaceManager."""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
|
||||
async def _resolve_file(
|
||||
manager: WorkspaceManager,
|
||||
file_id: str | None,
|
||||
path: str | None,
|
||||
session_id: str,
|
||||
) -> tuple[str, Any] | ErrorResponse:
|
||||
"""Resolve a file by file_id or path.
|
||||
|
||||
Returns ``(target_file_id, file_info)`` on success, or an
|
||||
``ErrorResponse`` if the file was not found.
|
||||
"""
|
||||
if file_id:
|
||||
file_info = await manager.get_file_info(file_id)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {file_id}", session_id=session_id
|
||||
)
|
||||
return file_id, file_info
|
||||
|
||||
assert path is not None
|
||||
file_info = await manager.get_file_info_by_path(path)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found at path: {path}", session_id=session_id
|
||||
)
|
||||
return file_info.id, file_info
|
||||
|
||||
|
||||
class WorkspaceFileInfoData(BaseModel):
|
||||
"""Data model for workspace file information (not a response itself)."""
|
||||
|
||||
@@ -68,6 +215,8 @@ class WorkspaceWriteResponse(ToolResponseBase):
|
||||
name: str
|
||||
path: str
|
||||
size_bytes: int
|
||||
source: str | None = None # "content", "base64", or "copied from <path>"
|
||||
content_preview: str | None = None # First 200 chars for text files
|
||||
|
||||
|
||||
class WorkspaceDeleteResponse(ToolResponseBase):
|
||||
@@ -136,11 +285,9 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
message="Authentication required", session_id=session_id
|
||||
)
|
||||
|
||||
path_prefix: Optional[str] = kwargs.get("path_prefix")
|
||||
@@ -148,20 +295,13 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
files = await manager.list_files(
|
||||
path=path_prefix,
|
||||
limit=limit,
|
||||
include_all_sessions=include_all_sessions,
|
||||
path=path_prefix, limit=limit, include_all_sessions=include_all_sessions
|
||||
)
|
||||
total = await manager.get_file_count(
|
||||
path=path_prefix,
|
||||
include_all_sessions=include_all_sessions,
|
||||
path=path_prefix, include_all_sessions=include_all_sessions
|
||||
)
|
||||
|
||||
file_infos = [
|
||||
WorkspaceFileInfoData(
|
||||
file_id=f.id,
|
||||
@@ -172,19 +312,27 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
)
|
||||
for f in files
|
||||
]
|
||||
scope = "all sessions" if include_all_sessions else "current session"
|
||||
total_size = sum(f.size_bytes for f in file_infos)
|
||||
|
||||
# Build a human-readable summary so the agent can relay details.
|
||||
lines = [f"Found {len(files)} file(s) in workspace ({scope}):"]
|
||||
for f in file_infos:
|
||||
lines.append(f" - {f.path} ({f.size_bytes:,} bytes, {f.mime_type})")
|
||||
if total > len(files):
|
||||
lines.append(f" ... and {total - len(files)} more")
|
||||
lines.append(f"Total size: {total_size:,} bytes")
|
||||
|
||||
scope_msg = "all sessions" if include_all_sessions else "current session"
|
||||
return WorkspaceFileListResponse(
|
||||
files=file_infos,
|
||||
total_count=total,
|
||||
message=f"Found {len(files)} files in workspace ({scope_msg})",
|
||||
message="\n".join(lines),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing workspace files: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to list workspace files: {str(e)}",
|
||||
message=f"Failed to list workspace files: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -193,10 +341,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
class ReadWorkspaceFileTool(BaseTool):
|
||||
"""Tool for reading file content from workspace."""
|
||||
|
||||
# Size threshold for returning full content vs metadata+URL
|
||||
# Files larger than this return metadata with download URL to prevent context bloat
|
||||
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
|
||||
# Preview size for text files
|
||||
PREVIEW_SIZE = 500
|
||||
|
||||
@property
|
||||
@@ -212,6 +357,8 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"For small text files, returns content directly. "
|
||||
"For large or binary files, returns metadata and a download URL. "
|
||||
"Optionally use 'save_to_path' to copy the file to the ephemeral "
|
||||
"working directory for processing with bash_exec or SDK tools. "
|
||||
"Paths are scoped to the current session by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
@@ -232,6 +379,15 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
"Scoped to current session by default."
|
||||
),
|
||||
},
|
||||
"save_to_path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"If provided, save the file to this path in the ephemeral "
|
||||
"working directory (e.g., '/tmp/copilot-.../data.csv') "
|
||||
"so it can be processed with bash_exec or SDK tools. "
|
||||
"The file content is still returned in the response."
|
||||
),
|
||||
},
|
||||
"force_download_url": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
@@ -247,18 +403,6 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
def _is_text_mime_type(self, mime_type: str) -> bool:
|
||||
"""Check if the MIME type is a text-based type."""
|
||||
text_types = [
|
||||
"text/",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"application/javascript",
|
||||
"application/x-python",
|
||||
"application/x-sh",
|
||||
]
|
||||
return any(mime_type.startswith(t) for t in text_types)
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
@@ -266,117 +410,112 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
message="Authentication required", session_id=session_id
|
||||
)
|
||||
|
||||
file_id: Optional[str] = kwargs.get("file_id")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
save_to_path: Optional[str] = kwargs.get("save_to_path")
|
||||
force_download_url: bool = kwargs.get("force_download_url", False)
|
||||
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide either file_id or path",
|
||||
session_id=session_id,
|
||||
message="Please provide either file_id or path", session_id=session_id
|
||||
)
|
||||
|
||||
# Validate and resolve save_to_path (use sanitized real path).
|
||||
if save_to_path:
|
||||
validated_save = _validate_ephemeral_path(
|
||||
save_to_path, param_name="save_to_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated_save, ErrorResponse):
|
||||
return validated_save
|
||||
save_to_path = validated_save
|
||||
|
||||
try:
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
target_file_id, file_info = resolved
|
||||
|
||||
# Get file info
|
||||
if file_id:
|
||||
file_info = await manager.get_file_info(file_id)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {file_id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_id
|
||||
else:
|
||||
# path is guaranteed to be non-None here due to the check above
|
||||
assert path is not None
|
||||
file_info = await manager.get_file_info_by_path(path)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found at path: {path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_info.id
|
||||
# If save_to_path, read + save; cache bytes for possible inline reuse.
|
||||
cached_content: bytes | None = None
|
||||
if save_to_path:
|
||||
cached_content = await manager.read_file_by_id(target_file_id)
|
||||
dir_path = os.path.dirname(save_to_path)
|
||||
if dir_path:
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
with open(save_to_path, "wb") as f:
|
||||
f.write(cached_content)
|
||||
|
||||
# Decide whether to return inline content or metadata+URL
|
||||
is_small_file = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
|
||||
is_text_file = self._is_text_mime_type(file_info.mime_type)
|
||||
|
||||
# Return inline content for small text/image files (unless force_download_url)
|
||||
is_image_file = file_info.mime_type in {
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
}
|
||||
if (
|
||||
is_small_file
|
||||
and (is_text_file or is_image_file)
|
||||
and not force_download_url
|
||||
):
|
||||
content = await manager.read_file_by_id(target_file_id)
|
||||
content_b64 = base64.b64encode(content).decode("utf-8")
|
||||
is_small = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
|
||||
is_text = _is_text_mime(file_info.mime_type)
|
||||
is_image = file_info.mime_type in _IMAGE_MIME_TYPES
|
||||
|
||||
# Inline content for small text/image files
|
||||
if is_small and (is_text or is_image) and not force_download_url:
|
||||
content = cached_content or await manager.read_file_by_id(
|
||||
target_file_id
|
||||
)
|
||||
msg = (
|
||||
f"Read {file_info.name} from workspace:{file_info.path} "
|
||||
f"({file_info.size_bytes:,} bytes, {file_info.mime_type})"
|
||||
)
|
||||
if save_to_path:
|
||||
msg += f" — also saved to {save_to_path}"
|
||||
return WorkspaceFileContentResponse(
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type=file_info.mime_type,
|
||||
content_base64=content_b64,
|
||||
message=f"Successfully read file: {file_info.name}",
|
||||
content_base64=base64.b64encode(content).decode("utf-8"),
|
||||
message=msg,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Return metadata + workspace:// reference for large or binary files
|
||||
# This prevents context bloat (100KB file = ~133KB as base64)
|
||||
# Use workspace:// format so frontend urlTransform can add proxy prefix
|
||||
download_url = f"workspace://{target_file_id}"
|
||||
|
||||
# Generate preview for text files
|
||||
# Metadata + download URL for large/binary files
|
||||
preview: str | None = None
|
||||
if is_text_file:
|
||||
if is_text:
|
||||
try:
|
||||
content = await manager.read_file_by_id(target_file_id)
|
||||
preview_text = content[: self.PREVIEW_SIZE].decode(
|
||||
"utf-8", errors="replace"
|
||||
raw = cached_content or await manager.read_file_by_id(
|
||||
target_file_id
|
||||
)
|
||||
if len(content) > self.PREVIEW_SIZE:
|
||||
preview_text += "..."
|
||||
preview = preview_text
|
||||
preview = raw[: self.PREVIEW_SIZE].decode("utf-8", errors="replace")
|
||||
if len(raw) > self.PREVIEW_SIZE:
|
||||
preview += "..."
|
||||
except Exception:
|
||||
pass # Preview is optional
|
||||
pass
|
||||
|
||||
msg = (
|
||||
f"File: {file_info.name} at workspace:{file_info.path} "
|
||||
f"({file_info.size_bytes:,} bytes, {file_info.mime_type})"
|
||||
)
|
||||
if save_to_path:
|
||||
msg += f" — saved to {save_to_path}"
|
||||
else:
|
||||
msg += (
|
||||
" — use read_workspace_file with this file_id to retrieve content"
|
||||
)
|
||||
return WorkspaceFileMetadataResponse(
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type=file_info.mime_type,
|
||||
size_bytes=file_info.size_bytes,
|
||||
download_url=download_url,
|
||||
download_url=f"workspace://{target_file_id}",
|
||||
preview=preview,
|
||||
message=f"File: {file_info.name} ({file_info.size_bytes} bytes). Use download_url to retrieve content.",
|
||||
message=msg,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except FileNotFoundError as e:
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
return ErrorResponse(message=str(e), session_id=session_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read workspace file: {str(e)}",
|
||||
message=f"Failed to read workspace file: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -395,7 +534,9 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
"Write or create a file in the user's persistent workspace (cloud storage). "
|
||||
"These files survive across sessions. "
|
||||
"For ephemeral session files, use the SDK Write tool instead. "
|
||||
"Provide the content as a base64-encoded string. "
|
||||
"Provide content as plain text via 'content', OR base64-encoded via "
|
||||
"'content_base64', OR copy a file from the ephemeral working directory "
|
||||
"via 'source_path'. Exactly one of these three is required. "
|
||||
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
||||
"Files are saved to the current session's folder by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
@@ -410,9 +551,30 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
"type": "string",
|
||||
"description": "Name for the file (e.g., 'report.pdf')",
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Plain text content to write. Use this for text files "
|
||||
"(code, configs, documents, etc.). "
|
||||
"Mutually exclusive with content_base64 and source_path."
|
||||
),
|
||||
},
|
||||
"content_base64": {
|
||||
"type": "string",
|
||||
"description": "Base64-encoded file content",
|
||||
"description": (
|
||||
"Base64-encoded file content. Use this for binary files "
|
||||
"(images, PDFs, etc.). "
|
||||
"Mutually exclusive with content and source_path."
|
||||
),
|
||||
},
|
||||
"source_path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Path to a file in the ephemeral working directory to "
|
||||
"copy to workspace (e.g., '/tmp/copilot-.../output.csv'). "
|
||||
"Use this to persist files created by bash_exec or SDK Write. "
|
||||
"Mutually exclusive with content and content_base64."
|
||||
),
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
@@ -434,7 +596,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
"description": "Whether to overwrite if file exists at path (default: false)",
|
||||
},
|
||||
},
|
||||
"required": ["filename", "content_base64"],
|
||||
"required": ["filename"],
|
||||
}
|
||||
|
||||
@property
|
||||
@@ -448,82 +610,92 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
message="Authentication required", session_id=session_id
|
||||
)
|
||||
|
||||
filename: str = kwargs.get("filename", "")
|
||||
content_b64: str = kwargs.get("content_base64", "")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
mime_type: Optional[str] = kwargs.get("mime_type")
|
||||
overwrite: bool = kwargs.get("overwrite", False)
|
||||
|
||||
if not filename:
|
||||
return ErrorResponse(
|
||||
message="Please provide a filename",
|
||||
session_id=session_id,
|
||||
message="Please provide a filename", session_id=session_id
|
||||
)
|
||||
|
||||
if not content_b64:
|
||||
return ErrorResponse(
|
||||
message="Please provide content_base64",
|
||||
session_id=session_id,
|
||||
)
|
||||
source_path_arg: str | None = kwargs.get("source_path")
|
||||
content_text: str | None = kwargs.get("content")
|
||||
content_b64: str | None = kwargs.get("content_base64")
|
||||
|
||||
# Decode content
|
||||
try:
|
||||
content = base64.b64decode(content_b64)
|
||||
except Exception:
|
||||
return ErrorResponse(
|
||||
message="Invalid base64-encoded content",
|
||||
session_id=session_id,
|
||||
)
|
||||
resolved = _resolve_write_content(
|
||||
content_text,
|
||||
content_b64,
|
||||
source_path_arg,
|
||||
session_id,
|
||||
)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
content: bytes = resolved
|
||||
|
||||
# Check size
|
||||
max_file_size = Config().max_file_size_mb * 1024 * 1024
|
||||
if len(content) > max_file_size:
|
||||
max_size = Config().max_file_size_mb * 1024 * 1024
|
||||
if len(content) > max_size:
|
||||
return ErrorResponse(
|
||||
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Virus scan
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
file_record = await manager.write_file(
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
rec = await manager.write_file(
|
||||
content=content,
|
||||
filename=filename,
|
||||
path=path,
|
||||
mime_type=mime_type,
|
||||
overwrite=overwrite,
|
||||
path=kwargs.get("path"),
|
||||
mime_type=kwargs.get("mime_type"),
|
||||
overwrite=kwargs.get("overwrite", False),
|
||||
)
|
||||
|
||||
# Build informative source label and message.
|
||||
if source_path_arg:
|
||||
source = f"copied from {source_path_arg}"
|
||||
msg = (
|
||||
f"Copied {source_path_arg} → workspace:{rec.path} "
|
||||
f"({rec.size_bytes:,} bytes)"
|
||||
)
|
||||
elif content_b64:
|
||||
source = "base64"
|
||||
msg = (
|
||||
f"Wrote {rec.name} to workspace ({rec.size_bytes:,} bytes, "
|
||||
f"decoded from base64)"
|
||||
)
|
||||
else:
|
||||
source = "content"
|
||||
msg = f"Wrote {rec.name} to workspace ({rec.size_bytes:,} bytes)"
|
||||
|
||||
# Include a short preview for text content.
|
||||
preview: str | None = None
|
||||
if _is_text_mime(rec.mime_type):
|
||||
try:
|
||||
preview = content[:200].decode("utf-8", errors="replace")
|
||||
if len(content) > 200:
|
||||
preview += "..."
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return WorkspaceWriteResponse(
|
||||
file_id=file_record.id,
|
||||
name=file_record.name,
|
||||
path=file_record.path,
|
||||
size_bytes=file_record.size_bytes,
|
||||
message=f"Successfully wrote file: {file_record.name}",
|
||||
file_id=rec.id,
|
||||
name=rec.name,
|
||||
path=rec.path,
|
||||
size_bytes=rec.size_bytes,
|
||||
source=source,
|
||||
content_preview=preview,
|
||||
message=msg,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
return ErrorResponse(message=str(e), session_id=session_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to write workspace file: {str(e)}",
|
||||
message=f"Failed to write workspace file: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -576,61 +748,42 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
message="Authentication required", session_id=session_id
|
||||
)
|
||||
|
||||
file_id: Optional[str] = kwargs.get("file_id")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide either file_id or path",
|
||||
session_id=session_id,
|
||||
message="Please provide either file_id or path", session_id=session_id
|
||||
)
|
||||
|
||||
try:
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
target_file_id, file_info = resolved
|
||||
|
||||
# Determine the file_id to delete
|
||||
target_file_id: str
|
||||
if file_id:
|
||||
target_file_id = file_id
|
||||
else:
|
||||
# path is guaranteed to be non-None here due to the check above
|
||||
assert path is not None
|
||||
file_info = await manager.get_file_info_by_path(path)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found at path: {path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_info.id
|
||||
|
||||
success = await manager.delete_file(target_file_id)
|
||||
|
||||
if not success:
|
||||
if not await manager.delete_file(target_file_id):
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {target_file_id}",
|
||||
session_id=session_id,
|
||||
message=f"File not found: {target_file_id}", session_id=session_id
|
||||
)
|
||||
|
||||
return WorkspaceDeleteResponse(
|
||||
file_id=target_file_id,
|
||||
success=True,
|
||||
message="File deleted successfully",
|
||||
message=(
|
||||
f"Deleted {file_info.name} from workspace:{file_info.path} "
|
||||
f"({file_info.size_bytes:,} bytes)"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to delete workspace file: {str(e)}",
|
||||
message=f"Failed to delete workspace file: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,267 @@
|
||||
"""Tests for workspace file tool helpers and path validation."""
|
||||
|
||||
import base64
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools._test_data import make_session, setup_test_data
|
||||
from backend.copilot.tools.workspace_files import (
|
||||
DeleteWorkspaceFileTool,
|
||||
ListWorkspaceFilesTool,
|
||||
ReadWorkspaceFileTool,
|
||||
WorkspaceDeleteResponse,
|
||||
WorkspaceFileListResponse,
|
||||
WorkspaceWriteResponse,
|
||||
WriteWorkspaceFileTool,
|
||||
_resolve_write_content,
|
||||
_validate_ephemeral_path,
|
||||
)
|
||||
|
||||
# Re-export so pytest discovers the session-scoped fixture
|
||||
setup_test_data = setup_test_data
|
||||
|
||||
# We need to mock make_session_path to return a known temp dir for tests.
|
||||
# The real one uses WORKSPACE_PREFIX = "/tmp/copilot-"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ephemeral_dir(tmp_path, monkeypatch):
|
||||
"""Create a temp dir that acts as the ephemeral session directory."""
|
||||
session_dir = tmp_path / "copilot-test-session"
|
||||
session_dir.mkdir()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.workspace_files.make_session_path",
|
||||
lambda session_id: str(session_dir),
|
||||
)
|
||||
return session_dir
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_ephemeral_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateEphemeralPath:
|
||||
def test_valid_path(self, ephemeral_dir):
|
||||
target = ephemeral_dir / "file.txt"
|
||||
target.touch()
|
||||
result = _validate_ephemeral_path(
|
||||
str(target), param_name="test", session_id="s1"
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
assert result == os.path.realpath(str(target))
|
||||
|
||||
def test_path_traversal_rejected(self, ephemeral_dir):
|
||||
evil_path = str(ephemeral_dir / ".." / "etc" / "passwd")
|
||||
result = _validate_ephemeral_path(evil_path, param_name="test", session_id="s1")
|
||||
# Should return ErrorResponse
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_different_session_rejected(self, ephemeral_dir, tmp_path):
|
||||
other_dir = tmp_path / "copilot-evil-session"
|
||||
other_dir.mkdir()
|
||||
target = other_dir / "steal.txt"
|
||||
target.touch()
|
||||
result = _validate_ephemeral_path(
|
||||
str(target), param_name="test", session_id="s1"
|
||||
)
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_symlink_escape_rejected(self, ephemeral_dir, tmp_path):
|
||||
"""Symlink inside session dir pointing outside should be rejected."""
|
||||
outside_file = tmp_path / "secret.txt"
|
||||
outside_file.write_text("secret")
|
||||
symlink = ephemeral_dir / "link.txt"
|
||||
symlink.symlink_to(outside_file)
|
||||
result = _validate_ephemeral_path(
|
||||
str(symlink), param_name="test", session_id="s1"
|
||||
)
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_nested_path_valid(self, ephemeral_dir):
|
||||
nested = ephemeral_dir / "subdir" / "deep"
|
||||
nested.mkdir(parents=True)
|
||||
target = nested / "data.csv"
|
||||
target.touch()
|
||||
result = _validate_ephemeral_path(
|
||||
str(target), param_name="test", session_id="s1"
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_write_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveWriteContent:
|
||||
def test_no_sources_returns_error(self):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
result = _resolve_write_content(None, None, None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_multiple_sources_returns_error(self):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
result = _resolve_write_content("text", "b64data", None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_plain_text_content(self):
|
||||
result = _resolve_write_content("hello world", None, None, "s1")
|
||||
assert result == b"hello world"
|
||||
|
||||
def test_base64_content(self):
|
||||
raw = b"binary data"
|
||||
b64 = base64.b64encode(raw).decode()
|
||||
result = _resolve_write_content(None, b64, None, "s1")
|
||||
assert result == raw
|
||||
|
||||
def test_invalid_base64_returns_error(self):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
result = _resolve_write_content(None, "not-valid-b64!!!", None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "base64" in result.message.lower()
|
||||
|
||||
def test_source_path(self, ephemeral_dir):
|
||||
target = ephemeral_dir / "input.txt"
|
||||
target.write_bytes(b"file content")
|
||||
result = _resolve_write_content(None, None, str(target), "s1")
|
||||
assert result == b"file content"
|
||||
|
||||
def test_source_path_not_found(self, ephemeral_dir):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
missing = str(ephemeral_dir / "nope.txt")
|
||||
result = _resolve_write_content(None, None, missing, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_source_path_outside_ephemeral(self, ephemeral_dir, tmp_path):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
outside = tmp_path / "outside.txt"
|
||||
outside.write_text("nope")
|
||||
result = _resolve_write_content(None, None, str(outside), "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_empty_string_sources_treated_as_none(self):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
# All empty strings → same as no sources
|
||||
result = _resolve_write_content("", "", "", "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_empty_string_source_path_with_text(self):
|
||||
# source_path="" should be normalised to None, so only content counts
|
||||
result = _resolve_write_content("hello", "", "", "s1")
|
||||
assert result == b"hello"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E2E: workspace file tool round-trip (write → list → read → delete)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_workspace_file_round_trip(setup_test_data):
|
||||
"""E2E: write a file, list it, read it back (with save_to_path), then delete it."""
|
||||
user = setup_test_data["user"]
|
||||
session = make_session(user.id)
|
||||
session_id = session.session_id
|
||||
|
||||
# ---- Write ----
|
||||
write_tool = WriteWorkspaceFileTool()
|
||||
write_resp = await write_tool._execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
filename="test_round_trip.txt",
|
||||
content="Hello from e2e test!",
|
||||
)
|
||||
assert isinstance(write_resp, WorkspaceWriteResponse), write_resp.message
|
||||
file_id = write_resp.file_id
|
||||
|
||||
# ---- List ----
|
||||
list_tool = ListWorkspaceFilesTool()
|
||||
list_resp = await list_tool._execute(user_id=user.id, session=session)
|
||||
assert isinstance(list_resp, WorkspaceFileListResponse), list_resp.message
|
||||
assert any(f.file_id == file_id for f in list_resp.files)
|
||||
|
||||
# ---- Read (inline) ----
|
||||
read_tool = ReadWorkspaceFileTool()
|
||||
read_resp = await read_tool._execute(
|
||||
user_id=user.id, session=session, file_id=file_id
|
||||
)
|
||||
from backend.copilot.tools.workspace_files import WorkspaceFileContentResponse
|
||||
|
||||
assert isinstance(read_resp, WorkspaceFileContentResponse), read_resp.message
|
||||
decoded = base64.b64decode(read_resp.content_base64).decode()
|
||||
assert decoded == "Hello from e2e test!"
|
||||
|
||||
# ---- Read with save_to_path ----
|
||||
from backend.copilot.tools.sandbox import make_session_path
|
||||
|
||||
ephemeral_dir = make_session_path(session_id)
|
||||
os.makedirs(ephemeral_dir, exist_ok=True)
|
||||
save_path = os.path.join(ephemeral_dir, "saved_copy.txt")
|
||||
|
||||
read_resp2 = await read_tool._execute(
|
||||
user_id=user.id, session=session, file_id=file_id, save_to_path=save_path
|
||||
)
|
||||
assert not isinstance(read_resp2, type(None))
|
||||
assert os.path.exists(save_path)
|
||||
with open(save_path) as f:
|
||||
assert f.read() == "Hello from e2e test!"
|
||||
|
||||
# ---- Delete ----
|
||||
delete_tool = DeleteWorkspaceFileTool()
|
||||
del_resp = await delete_tool._execute(
|
||||
user_id=user.id, session=session, file_id=file_id
|
||||
)
|
||||
assert isinstance(del_resp, WorkspaceDeleteResponse), del_resp.message
|
||||
assert del_resp.success is True
|
||||
|
||||
# Verify file is gone
|
||||
list_resp2 = await list_tool._execute(user_id=user.id, session=session)
|
||||
assert isinstance(list_resp2, WorkspaceFileListResponse)
|
||||
assert not any(f.file_id == file_id for f in list_resp2.files)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_write_workspace_file_source_path(setup_test_data):
|
||||
"""E2E: write a file from ephemeral source_path to workspace."""
|
||||
user = setup_test_data["user"]
|
||||
session = make_session(user.id)
|
||||
session_id = session.session_id
|
||||
|
||||
# Create a file in the ephemeral dir
|
||||
from backend.copilot.tools.sandbox import make_session_path
|
||||
|
||||
ephemeral_dir = make_session_path(session_id)
|
||||
os.makedirs(ephemeral_dir, exist_ok=True)
|
||||
source = os.path.join(ephemeral_dir, "generated_output.csv")
|
||||
with open(source, "w") as f:
|
||||
f.write("col1,col2\n1,2\n")
|
||||
|
||||
write_tool = WriteWorkspaceFileTool()
|
||||
write_resp = await write_tool._execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
filename="output.csv",
|
||||
source_path=source,
|
||||
)
|
||||
assert isinstance(write_resp, WorkspaceWriteResponse), write_resp.message
|
||||
|
||||
# Clean up
|
||||
delete_tool = DeleteWorkspaceFileTool()
|
||||
await delete_tool._execute(
|
||||
user_id=user.id, session=session, file_id=write_resp.file_id
|
||||
)
|
||||
@@ -303,7 +303,7 @@ class DatabaseManager(AppService):
|
||||
get_user_chat_sessions = _(chat_db.get_user_chat_sessions)
|
||||
get_user_session_count = _(chat_db.get_user_session_count)
|
||||
delete_chat_session = _(chat_db.delete_chat_session)
|
||||
get_chat_session_message_count = _(chat_db.get_chat_session_message_count)
|
||||
get_next_sequence = _(chat_db.get_next_sequence)
|
||||
update_tool_message_content = _(chat_db.update_tool_message_content)
|
||||
|
||||
|
||||
@@ -473,5 +473,5 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
get_user_chat_sessions = d.get_user_chat_sessions
|
||||
get_user_session_count = d.get_user_session_count
|
||||
delete_chat_session = d.delete_chat_session
|
||||
get_chat_session_message_count = d.get_chat_session_message_count
|
||||
get_next_sequence = d.get_next_sequence
|
||||
update_tool_message_content = d.update_tool_message_content
|
||||
|
||||
426
autogpt_platform/backend/backend/data/tally.py
Normal file
426
autogpt_platform/backend/backend/data/tally.py
Normal file
@@ -0,0 +1,426 @@
|
||||
"""Tally form integration: cache submissions, match by email, extract business understanding."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
get_business_understanding,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
from backend.util.request import Requests
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TALLY_API_BASE = "https://api.tally.so"
|
||||
_settings = Settings()
|
||||
TALLY_FORM_ID = _settings.secrets.tally_form_id
|
||||
|
||||
# Redis key templates
|
||||
_EMAIL_INDEX_KEY = "tally:form:{form_id}:email_index"
|
||||
_QUESTIONS_KEY = "tally:form:{form_id}:questions"
|
||||
_LAST_FETCH_KEY = "tally:form:{form_id}:last_fetch"
|
||||
|
||||
# TTLs — keep aligned so last_fetch never outlives the index
|
||||
_INDEX_TTL = 3600 # 1 hour
|
||||
_LAST_FETCH_TTL = 3600 # 1 hour (same as index)
|
||||
|
||||
# Pagination
|
||||
_PAGE_LIMIT = 500
|
||||
_MAX_PAGES = 100
|
||||
|
||||
# LLM extraction timeout (seconds)
|
||||
_LLM_TIMEOUT = 30
|
||||
|
||||
|
||||
def _mask_email(email: str) -> str:
|
||||
"""Mask an email for safe logging: 'alice@example.com' -> 'a***e@example.com'."""
|
||||
try:
|
||||
local, domain = email.rsplit("@", 1)
|
||||
if len(local) <= 2:
|
||||
masked_local = local[0] + "***"
|
||||
else:
|
||||
masked_local = local[0] + "***" + local[-1]
|
||||
return f"{masked_local}@{domain}"
|
||||
except (ValueError, IndexError):
|
||||
return "***"
|
||||
|
||||
|
||||
async def _fetch_tally_page(
|
||||
client: Requests,
|
||||
form_id: str,
|
||||
page: int,
|
||||
limit: int = _PAGE_LIMIT,
|
||||
start_date: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Fetch a single page of submissions from the Tally API."""
|
||||
url = f"{TALLY_API_BASE}/forms/{form_id}/submissions?page={page}&limit={limit}"
|
||||
if start_date:
|
||||
url += f"&startDate={start_date}"
|
||||
|
||||
response = await client.get(url)
|
||||
return response.json()
|
||||
|
||||
|
||||
def _make_tally_client(api_key: str) -> Requests:
|
||||
"""Create a Requests client configured for the Tally API."""
|
||||
return Requests(
|
||||
trusted_origins=[TALLY_API_BASE],
|
||||
raise_for_status=True,
|
||||
extra_headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _fetch_all_submissions(
|
||||
client: Requests,
|
||||
form_id: str,
|
||||
start_date: Optional[str] = None,
|
||||
max_pages: int = _MAX_PAGES,
|
||||
) -> tuple[list[dict], list[dict]]:
|
||||
"""Paginate through all Tally submissions. Returns (questions, submissions)."""
|
||||
|
||||
questions: list[dict] = []
|
||||
all_submissions: list[dict] = []
|
||||
page = 1
|
||||
|
||||
while True:
|
||||
data = await _fetch_tally_page(client, form_id, page, start_date=start_date)
|
||||
|
||||
if page == 1:
|
||||
questions = data.get("questions", [])
|
||||
|
||||
submissions = data.get("submissions", [])
|
||||
all_submissions.extend(submissions)
|
||||
|
||||
# Tally API uses `hasMore` for pagination
|
||||
has_more = data.get("hasMore", False)
|
||||
if not has_more:
|
||||
break
|
||||
if page >= max_pages:
|
||||
total = data.get("totalNumberOfSubmissionsPerFilter", {}).get("all", "?")
|
||||
logger.warning(
|
||||
f"Tally: hit max page cap ({max_pages}) for form {form_id}, "
|
||||
f"fetched {len(all_submissions)} of {total} total submissions"
|
||||
)
|
||||
break
|
||||
page += 1
|
||||
|
||||
return questions, all_submissions
|
||||
|
||||
|
||||
def _build_email_index(
|
||||
submissions: list[dict], questions: list[dict]
|
||||
) -> dict[str, dict]:
|
||||
"""Build an {email -> submission_data} index from submissions.
|
||||
|
||||
Scans question titles for email/contact fields to find the email answer.
|
||||
"""
|
||||
# Find question IDs that are likely email fields
|
||||
email_question_ids: list[str] = []
|
||||
for q in questions:
|
||||
label = (q.get("label") or q.get("title") or q.get("name") or "").lower()
|
||||
q_type = (q.get("type") or "").lower()
|
||||
if q_type in ("input_email", "email"):
|
||||
email_question_ids.append(q["id"])
|
||||
elif any(kw in label for kw in ("email", "e-mail", "contact")):
|
||||
email_question_ids.append(q["id"])
|
||||
|
||||
index: dict[str, dict] = {}
|
||||
for sub in submissions:
|
||||
email = _extract_email_from_submission(sub, email_question_ids)
|
||||
if email:
|
||||
index[email.lower()] = {
|
||||
"responses": sub.get("responses", sub.get("fields", [])),
|
||||
"submitted_at": sub.get("submittedAt", sub.get("createdAt", "")),
|
||||
"questions": sub.get("questions", []),
|
||||
}
|
||||
return index
|
||||
|
||||
|
||||
def _extract_email_from_submission(
|
||||
submission: dict, email_question_ids: list[str]
|
||||
) -> Optional[str]:
|
||||
"""Extract email address from a submission by checking respondentEmail, then field responses."""
|
||||
# Try respondent email first (Tally often includes this)
|
||||
respondent_email = submission.get("respondentEmail")
|
||||
if respondent_email:
|
||||
return respondent_email
|
||||
|
||||
# Search through responses/fields for matching question IDs
|
||||
responses = submission.get("responses", submission.get("fields", []))
|
||||
if isinstance(responses, list):
|
||||
for resp in responses:
|
||||
q_id = resp.get("questionId") or resp.get("key") or resp.get("id")
|
||||
if q_id in email_question_ids:
|
||||
value = resp.get("value") or resp.get("answer")
|
||||
if isinstance(value, str) and "@" in value:
|
||||
return value
|
||||
elif isinstance(responses, dict):
|
||||
for q_id in email_question_ids:
|
||||
value = responses.get(q_id)
|
||||
if isinstance(value, str) and "@" in value:
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _get_cached_index(
|
||||
form_id: str,
|
||||
) -> tuple[Optional[dict], Optional[list]]:
|
||||
"""Return (email_index, questions) from Redis, or (None, None) on cache miss."""
|
||||
redis = await get_redis_async()
|
||||
index_key = _EMAIL_INDEX_KEY.format(form_id=form_id)
|
||||
questions_key = _QUESTIONS_KEY.format(form_id=form_id)
|
||||
|
||||
raw_index = await redis.get(index_key)
|
||||
raw_questions = await redis.get(questions_key)
|
||||
|
||||
if raw_index and raw_questions:
|
||||
return json.loads(raw_index), json.loads(raw_questions)
|
||||
return None, None
|
||||
|
||||
|
||||
async def _refresh_cache(form_id: str) -> tuple[dict, list]:
|
||||
"""Refresh the Tally submission cache. Uses incremental fetch when possible.
|
||||
|
||||
Returns (email_index, questions).
|
||||
"""
|
||||
settings = Settings()
|
||||
client = _make_tally_client(settings.secrets.tally_api_key)
|
||||
|
||||
redis = await get_redis_async()
|
||||
last_fetch_key = _LAST_FETCH_KEY.format(form_id=form_id)
|
||||
index_key = _EMAIL_INDEX_KEY.format(form_id=form_id)
|
||||
questions_key = _QUESTIONS_KEY.format(form_id=form_id)
|
||||
|
||||
last_fetch = await redis.get(last_fetch_key)
|
||||
|
||||
if last_fetch:
|
||||
# Try to load existing index for incremental merge
|
||||
raw_existing = await redis.get(index_key)
|
||||
|
||||
if raw_existing is None:
|
||||
# Index expired but last_fetch still present — fall back to full fetch
|
||||
logger.info("Tally: last_fetch present but index missing, doing full fetch")
|
||||
questions, submissions = await _fetch_all_submissions(client, form_id)
|
||||
email_index = _build_email_index(submissions, questions)
|
||||
else:
|
||||
# Incremental fetch: only get new submissions since last fetch
|
||||
logger.info(f"Tally incremental fetch since {last_fetch}")
|
||||
questions, new_submissions = await _fetch_all_submissions(
|
||||
client, form_id, start_date=last_fetch
|
||||
)
|
||||
|
||||
existing_index: dict[str, dict] = json.loads(raw_existing)
|
||||
|
||||
if not questions:
|
||||
raw_q = await redis.get(questions_key)
|
||||
if raw_q:
|
||||
questions = json.loads(raw_q)
|
||||
|
||||
new_index = _build_email_index(new_submissions, questions)
|
||||
existing_index.update(new_index)
|
||||
email_index = existing_index
|
||||
else:
|
||||
# Full initial fetch
|
||||
logger.info("Tally full initial fetch")
|
||||
questions, submissions = await _fetch_all_submissions(client, form_id)
|
||||
email_index = _build_email_index(submissions, questions)
|
||||
|
||||
# Store in Redis
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
await redis.setex(index_key, _INDEX_TTL, json.dumps(email_index))
|
||||
await redis.setex(questions_key, _INDEX_TTL, json.dumps(questions))
|
||||
await redis.setex(last_fetch_key, _LAST_FETCH_TTL, now)
|
||||
|
||||
logger.info(f"Tally cache refreshed: {len(email_index)} emails indexed")
|
||||
return email_index, questions
|
||||
|
||||
|
||||
async def find_submission_by_email(
|
||||
form_id: str, email: str
|
||||
) -> Optional[tuple[dict, list]]:
|
||||
"""Look up a Tally submission by email. Uses cache when available.
|
||||
|
||||
Returns (submission_data, questions) or None.
|
||||
"""
|
||||
email_lower = email.lower()
|
||||
|
||||
# Try cache first
|
||||
email_index, questions = await _get_cached_index(form_id)
|
||||
if email_index is not None and questions is not None:
|
||||
sub = email_index.get(email_lower)
|
||||
if sub is not None:
|
||||
return sub, questions
|
||||
return None
|
||||
|
||||
# Cache miss - refresh
|
||||
email_index, questions = await _refresh_cache(form_id)
|
||||
sub = email_index.get(email_lower)
|
||||
if sub is not None:
|
||||
return sub, questions
|
||||
return None
|
||||
|
||||
|
||||
def format_submission_for_llm(submission: dict, questions: list[dict]) -> str:
|
||||
"""Format a submission as readable Q&A text for LLM consumption."""
|
||||
# Build question ID -> title lookup
|
||||
q_titles: dict[str, str] = {}
|
||||
for q in questions:
|
||||
q_id = q.get("id", "")
|
||||
title = q.get("label") or q.get("title") or q.get("name") or f"Question {q_id}"
|
||||
q_titles[q_id] = title
|
||||
|
||||
lines: list[str] = []
|
||||
responses = submission.get("responses", [])
|
||||
|
||||
if isinstance(responses, list):
|
||||
for resp in responses:
|
||||
q_id = resp.get("questionId") or resp.get("key") or resp.get("id") or ""
|
||||
title = q_titles.get(q_id, f"Question {q_id}")
|
||||
value = resp.get("value") or resp.get("answer") or ""
|
||||
lines.append(f"Q: {title}\nA: {_format_answer(value)}")
|
||||
elif isinstance(responses, dict):
|
||||
for q_id, value in responses.items():
|
||||
title = q_titles.get(q_id, f"Question {q_id}")
|
||||
lines.append(f"Q: {title}\nA: {_format_answer(value)}")
|
||||
|
||||
return "\n\n".join(lines)
|
||||
|
||||
|
||||
def _format_answer(value: object) -> str:
|
||||
"""Convert an answer value (str, list, dict, None) to a human-readable string."""
|
||||
if value is None:
|
||||
return "(no answer)"
|
||||
if isinstance(value, list):
|
||||
return ", ".join(str(v) for v in value)
|
||||
if isinstance(value, dict):
|
||||
parts = [f"{k}: {v}" for k, v in value.items() if v]
|
||||
return "; ".join(parts) if parts else "(no answer)"
|
||||
return str(value)
|
||||
|
||||
|
||||
_EXTRACTION_PROMPT = """\
|
||||
You are a business analyst. Given the following form submission data, extract structured business understanding information.
|
||||
|
||||
Return a JSON object with ONLY the fields that can be confidently extracted. Use null for fields that cannot be determined.
|
||||
|
||||
Fields:
|
||||
- user_name (string): the person's name
|
||||
- job_title (string): their job title
|
||||
- business_name (string): company/business name
|
||||
- industry (string): industry or sector
|
||||
- business_size (string): company size e.g. "1-10", "11-50", "51-200"
|
||||
- user_role (string): their role context e.g. "decision maker", "implementer"
|
||||
- key_workflows (list of strings): key business workflows
|
||||
- daily_activities (list of strings): daily activities performed
|
||||
- pain_points (list of strings): current pain points
|
||||
- bottlenecks (list of strings): process bottlenecks
|
||||
- manual_tasks (list of strings): manual/repetitive tasks
|
||||
- automation_goals (list of strings): desired automation goals
|
||||
- current_software (list of strings): software/tools currently used
|
||||
- existing_automation (list of strings): existing automations
|
||||
- additional_notes (string): any additional context
|
||||
|
||||
Form data:
|
||||
"""
|
||||
|
||||
_EXTRACTION_SUFFIX = "\n\nReturn ONLY valid JSON."
|
||||
|
||||
|
||||
async def extract_business_understanding(
|
||||
formatted_text: str,
|
||||
) -> BusinessUnderstandingInput:
|
||||
"""Use an LLM to extract structured business understanding from form text.
|
||||
|
||||
Raises on timeout or unparseable response so the caller can handle it.
|
||||
"""
|
||||
settings = Settings()
|
||||
api_key = settings.secrets.open_router_api_key
|
||||
client = AsyncOpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1")
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
client.chat.completions.create(
|
||||
model="openai/gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{_EXTRACTION_PROMPT}{formatted_text}{_EXTRACTION_SUFFIX}",
|
||||
}
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.0,
|
||||
),
|
||||
timeout=_LLM_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Tally: LLM extraction timed out")
|
||||
raise
|
||||
|
||||
raw = response.choices[0].message.content or "{}"
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Tally: LLM returned invalid JSON, skipping extraction")
|
||||
raise
|
||||
|
||||
# Filter out null values before constructing
|
||||
cleaned = {k: v for k, v in data.items() if v is not None}
|
||||
return BusinessUnderstandingInput(**cleaned)
|
||||
|
||||
|
||||
async def populate_understanding_from_tally(user_id: str, email: str) -> None:
|
||||
"""Main orchestrator: check Tally for a matching submission and populate understanding.
|
||||
|
||||
Fire-and-forget safe — all exceptions are caught and logged.
|
||||
"""
|
||||
try:
|
||||
# Check if understanding already exists (idempotency)
|
||||
existing = await get_business_understanding(user_id)
|
||||
if existing is not None:
|
||||
logger.debug(
|
||||
f"Tally: user {user_id} already has business understanding, skipping"
|
||||
)
|
||||
return
|
||||
|
||||
# Check API key is configured
|
||||
settings = Settings()
|
||||
if not settings.secrets.tally_api_key:
|
||||
logger.debug("Tally: no API key configured, skipping")
|
||||
return
|
||||
|
||||
# Look up submission by email
|
||||
masked = _mask_email(email)
|
||||
result = await find_submission_by_email(TALLY_FORM_ID, email)
|
||||
if result is None:
|
||||
logger.debug(f"Tally: no submission found for {masked}")
|
||||
return
|
||||
|
||||
submission, questions = result
|
||||
logger.info(f"Tally: found submission for {masked}, extracting understanding")
|
||||
|
||||
# Format and extract
|
||||
formatted = format_submission_for_llm(submission, questions)
|
||||
if not formatted.strip():
|
||||
logger.warning("Tally: formatted submission was empty, skipping")
|
||||
return
|
||||
|
||||
understanding_input = await extract_business_understanding(formatted)
|
||||
|
||||
# Upsert into database
|
||||
await upsert_business_understanding(user_id, understanding_input)
|
||||
logger.info(f"Tally: successfully populated understanding for user {user_id}")
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Tally: error populating understanding for user {user_id}")
|
||||
589
autogpt_platform/backend/backend/data/tally_test.py
Normal file
589
autogpt_platform/backend/backend/data/tally_test.py
Normal file
@@ -0,0 +1,589 @@
|
||||
"""Tests for backend.data.tally module."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.tally import (
|
||||
_EXTRACTION_PROMPT,
|
||||
_EXTRACTION_SUFFIX,
|
||||
_build_email_index,
|
||||
_format_answer,
|
||||
_make_tally_client,
|
||||
_mask_email,
|
||||
_refresh_cache,
|
||||
extract_business_understanding,
|
||||
find_submission_by_email,
|
||||
format_submission_for_llm,
|
||||
populate_understanding_from_tally,
|
||||
)
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||
|
||||
SAMPLE_QUESTIONS = [
|
||||
{"id": "q1", "label": "What is your name?", "type": "INPUT_TEXT"},
|
||||
{"id": "q2", "label": "Email address", "type": "INPUT_EMAIL"},
|
||||
{"id": "q3", "label": "Company name", "type": "INPUT_TEXT"},
|
||||
{"id": "q4", "label": "Industry", "type": "INPUT_TEXT"},
|
||||
]
|
||||
|
||||
SAMPLE_SUBMISSIONS = [
|
||||
{
|
||||
"respondentEmail": None,
|
||||
"responses": [
|
||||
{"questionId": "q1", "value": "Alice Smith"},
|
||||
{"questionId": "q2", "value": "alice@example.com"},
|
||||
{"questionId": "q3", "value": "Acme Corp"},
|
||||
{"questionId": "q4", "value": "Technology"},
|
||||
],
|
||||
"submittedAt": "2025-01-15T10:00:00Z",
|
||||
},
|
||||
{
|
||||
"respondentEmail": "bob@example.com",
|
||||
"responses": [
|
||||
{"questionId": "q1", "value": "Bob Jones"},
|
||||
{"questionId": "q2", "value": "bob@example.com"},
|
||||
{"questionId": "q3", "value": "Bob's Burgers"},
|
||||
{"questionId": "q4", "value": "Food"},
|
||||
],
|
||||
"submittedAt": "2025-01-16T10:00:00Z",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# ── _build_email_index ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_build_email_index():
|
||||
index = _build_email_index(SAMPLE_SUBMISSIONS, SAMPLE_QUESTIONS)
|
||||
assert "alice@example.com" in index
|
||||
assert "bob@example.com" in index
|
||||
assert len(index) == 2
|
||||
|
||||
|
||||
def test_build_email_index_case_insensitive():
|
||||
submissions = [
|
||||
{
|
||||
"respondentEmail": None,
|
||||
"responses": [
|
||||
{"questionId": "q2", "value": "Alice@Example.COM"},
|
||||
],
|
||||
"submittedAt": "2025-01-15T10:00:00Z",
|
||||
},
|
||||
]
|
||||
index = _build_email_index(submissions, SAMPLE_QUESTIONS)
|
||||
assert "alice@example.com" in index
|
||||
assert "Alice@Example.COM" not in index
|
||||
|
||||
|
||||
def test_build_email_index_empty():
|
||||
index = _build_email_index([], SAMPLE_QUESTIONS)
|
||||
assert index == {}
|
||||
|
||||
|
||||
def test_build_email_index_no_email_field():
|
||||
questions = [{"id": "q1", "label": "Name", "type": "INPUT_TEXT"}]
|
||||
submissions = [
|
||||
{
|
||||
"responses": [{"questionId": "q1", "value": "Alice"}],
|
||||
"submittedAt": "2025-01-15T10:00:00Z",
|
||||
}
|
||||
]
|
||||
index = _build_email_index(submissions, questions)
|
||||
assert index == {}
|
||||
|
||||
|
||||
def test_build_email_index_respondent_email():
|
||||
"""respondentEmail takes precedence over field scanning."""
|
||||
submissions = [
|
||||
{
|
||||
"respondentEmail": "direct@example.com",
|
||||
"responses": [
|
||||
{"questionId": "q2", "value": "field@example.com"},
|
||||
],
|
||||
"submittedAt": "2025-01-15T10:00:00Z",
|
||||
}
|
||||
]
|
||||
index = _build_email_index(submissions, SAMPLE_QUESTIONS)
|
||||
assert "direct@example.com" in index
|
||||
assert "field@example.com" not in index
|
||||
|
||||
|
||||
# ── format_submission_for_llm ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_format_submission_for_llm():
|
||||
submission = {
|
||||
"responses": [
|
||||
{"questionId": "q1", "value": "Alice Smith"},
|
||||
{"questionId": "q3", "value": "Acme Corp"},
|
||||
],
|
||||
}
|
||||
result = format_submission_for_llm(submission, SAMPLE_QUESTIONS)
|
||||
assert "Q: What is your name?" in result
|
||||
assert "A: Alice Smith" in result
|
||||
assert "Q: Company name" in result
|
||||
assert "A: Acme Corp" in result
|
||||
|
||||
|
||||
def test_format_submission_for_llm_dict_responses():
|
||||
submission = {
|
||||
"responses": {
|
||||
"q1": "Alice Smith",
|
||||
"q3": "Acme Corp",
|
||||
},
|
||||
}
|
||||
result = format_submission_for_llm(submission, SAMPLE_QUESTIONS)
|
||||
assert "A: Alice Smith" in result
|
||||
assert "A: Acme Corp" in result
|
||||
|
||||
|
||||
def test_format_answer_types():
|
||||
assert _format_answer(None) == "(no answer)"
|
||||
assert _format_answer("hello") == "hello"
|
||||
assert _format_answer(["a", "b"]) == "a, b"
|
||||
assert _format_answer({"key": "val"}) == "key: val"
|
||||
assert _format_answer(42) == "42"
|
||||
|
||||
|
||||
# ── find_submission_by_email ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_submission_by_email_cache_hit():
|
||||
cached_index = {
|
||||
"alice@example.com": {"responses": [], "submitted_at": "2025-01-15"},
|
||||
}
|
||||
cached_questions = SAMPLE_QUESTIONS
|
||||
|
||||
with patch(
|
||||
"backend.data.tally._get_cached_index",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(cached_index, cached_questions),
|
||||
) as mock_cache:
|
||||
result = await find_submission_by_email("form123", "alice@example.com")
|
||||
|
||||
mock_cache.assert_awaited_once_with("form123")
|
||||
assert result is not None
|
||||
sub, questions = result
|
||||
assert sub["submitted_at"] == "2025-01-15"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_submission_by_email_cache_miss():
|
||||
refreshed_index = {
|
||||
"alice@example.com": {"responses": [], "submitted_at": "2025-01-15"},
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.tally._get_cached_index",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(None, None),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally._refresh_cache",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(refreshed_index, SAMPLE_QUESTIONS),
|
||||
) as mock_refresh,
|
||||
):
|
||||
result = await find_submission_by_email("form123", "alice@example.com")
|
||||
|
||||
mock_refresh.assert_awaited_once_with("form123")
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_submission_by_email_no_match():
|
||||
cached_index = {
|
||||
"alice@example.com": {"responses": [], "submitted_at": "2025-01-15"},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"backend.data.tally._get_cached_index",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(cached_index, SAMPLE_QUESTIONS),
|
||||
):
|
||||
result = await find_submission_by_email("form123", "unknown@example.com")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ── populate_understanding_from_tally ─────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_understanding_skips_existing():
|
||||
"""If user already has understanding, skip entirely."""
|
||||
mock_understanding = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.tally.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_understanding,
|
||||
) as mock_get,
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_find,
|
||||
):
|
||||
await populate_understanding_from_tally("user-1", "test@example.com")
|
||||
|
||||
mock_get.assert_awaited_once_with("user-1")
|
||||
mock_find.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_understanding_skips_no_api_key():
|
||||
"""If no Tally API key, skip gracefully."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.secrets.tally_api_key = ""
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.tally.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_find,
|
||||
):
|
||||
await populate_understanding_from_tally("user-1", "test@example.com")
|
||||
|
||||
mock_find.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_understanding_handles_errors():
|
||||
"""Must never raise, even on unexpected errors."""
|
||||
with patch(
|
||||
"backend.data.tally.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("DB down"),
|
||||
):
|
||||
# Should not raise
|
||||
await populate_understanding_from_tally("user-1", "test@example.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_understanding_full_flow():
|
||||
"""Happy path: no existing understanding, finds submission, extracts, upserts."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.secrets.tally_api_key = "test-key"
|
||||
|
||||
submission = {
|
||||
"responses": [
|
||||
{"questionId": "q1", "value": "Alice"},
|
||||
{"questionId": "q3", "value": "Acme"},
|
||||
],
|
||||
}
|
||||
mock_input = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.tally.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(submission, SAMPLE_QUESTIONS),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_input,
|
||||
) as mock_extract,
|
||||
patch(
|
||||
"backend.data.tally.upsert_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_upsert,
|
||||
):
|
||||
await populate_understanding_from_tally("user-1", "alice@example.com")
|
||||
|
||||
mock_extract.assert_awaited_once()
|
||||
mock_upsert.assert_awaited_once_with("user-1", mock_input)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_understanding_handles_llm_timeout():
|
||||
"""LLM timeout is caught and doesn't raise."""
|
||||
import asyncio
|
||||
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.secrets.tally_api_key = "test-key"
|
||||
|
||||
submission = {
|
||||
"responses": [{"questionId": "q1", "value": "Alice"}],
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.tally.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(submission, SAMPLE_QUESTIONS),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=asyncio.TimeoutError(),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.upsert_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_upsert,
|
||||
):
|
||||
await populate_understanding_from_tally("user-1", "alice@example.com")
|
||||
|
||||
mock_upsert.assert_not_awaited()
|
||||
|
||||
|
||||
# ── _mask_email ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_mask_email():
|
||||
assert _mask_email("alice@example.com") == "a***e@example.com"
|
||||
assert _mask_email("ab@example.com") == "a***@example.com"
|
||||
assert _mask_email("a@example.com") == "a***@example.com"
|
||||
|
||||
|
||||
def test_mask_email_invalid():
|
||||
assert _mask_email("no-at-sign") == "***"
|
||||
|
||||
|
||||
# ── Prompt construction (curly-brace safety) ─────────────────────────────────
|
||||
|
||||
|
||||
def test_extraction_prompt_safe_with_curly_braces():
|
||||
"""User content with curly braces must not break prompt construction.
|
||||
|
||||
Previously _EXTRACTION_PROMPT.format(submission_text=...) would raise
|
||||
KeyError/ValueError if the user text contained { or }.
|
||||
"""
|
||||
text_with_braces = "Q: What tools do you use?\nA: We use {Slack} and {{Jira}}"
|
||||
# This must not raise — the old .format() call would fail here
|
||||
prompt = f"{_EXTRACTION_PROMPT}{text_with_braces}{_EXTRACTION_SUFFIX}"
|
||||
assert text_with_braces in prompt
|
||||
assert prompt.startswith("You are a business analyst.")
|
||||
assert prompt.endswith("Return ONLY valid JSON.")
|
||||
|
||||
|
||||
def test_extraction_prompt_no_format_placeholders():
|
||||
"""_EXTRACTION_PROMPT must not contain Python format placeholders."""
|
||||
assert "{submission_text}" not in _EXTRACTION_PROMPT
|
||||
# Ensure no stray single-brace placeholders
|
||||
# (double braces {{ are fine — they're literal in format strings)
|
||||
import re
|
||||
|
||||
single_braces = re.findall(r"(?<!\{)\{[^{].*?\}(?!\})", _EXTRACTION_PROMPT)
|
||||
assert single_braces == [], f"Found format placeholders: {single_braces}"
|
||||
|
||||
|
||||
# ── extract_business_understanding ────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_success():
|
||||
"""Happy path: LLM returns valid JSON that maps to BusinessUnderstandingInput."""
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{
|
||||
"user_name": "Alice",
|
||||
"business_name": "Acme Corp",
|
||||
"industry": "Technology",
|
||||
"pain_points": ["manual reporting"],
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
|
||||
result = await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
assert result.user_name == "Alice"
|
||||
assert result.business_name == "Acme Corp"
|
||||
assert result.industry == "Technology"
|
||||
assert result.pain_points == ["manual reporting"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_filters_nulls():
|
||||
"""Null values from LLM should be excluded from the result."""
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{"user_name": "Alice", "business_name": None, "industry": None}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
|
||||
result = await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
assert result.user_name == "Alice"
|
||||
assert result.business_name is None
|
||||
assert result.industry is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_invalid_json():
|
||||
"""Invalid JSON from LLM should raise JSONDecodeError."""
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = "not valid json {"
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with (
|
||||
patch("backend.data.tally.AsyncOpenAI", return_value=mock_client),
|
||||
pytest.raises(json.JSONDecodeError),
|
||||
):
|
||||
await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_timeout():
|
||||
"""LLM timeout should propagate as asyncio.TimeoutError."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.side_effect = asyncio.TimeoutError()
|
||||
|
||||
with (
|
||||
patch("backend.data.tally.AsyncOpenAI", return_value=mock_client),
|
||||
patch("backend.data.tally._LLM_TIMEOUT", 0.001),
|
||||
pytest.raises(asyncio.TimeoutError),
|
||||
):
|
||||
await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
|
||||
# ── _refresh_cache ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_cache_full_fetch():
|
||||
"""First fetch (no last_fetch in Redis) should do a full fetch and store in Redis."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.secrets.tally_api_key = "test-key"
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get.return_value = None # No last_fetch, no cached index
|
||||
|
||||
questions = SAMPLE_QUESTIONS
|
||||
submissions = SAMPLE_SUBMISSIONS
|
||||
|
||||
with (
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_redis,
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally._fetch_all_submissions",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(questions, submissions),
|
||||
) as mock_fetch,
|
||||
):
|
||||
index, returned_questions = await _refresh_cache("form123")
|
||||
|
||||
mock_fetch.assert_awaited_once()
|
||||
assert "alice@example.com" in index
|
||||
assert "bob@example.com" in index
|
||||
assert returned_questions == questions
|
||||
# Verify Redis setex was called for index, questions, and last_fetch
|
||||
assert mock_redis.setex.await_count == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_cache_incremental_fetch():
|
||||
"""When last_fetch and index both exist, should do incremental fetch and merge."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.secrets.tally_api_key = "test-key"
|
||||
|
||||
existing_index = {
|
||||
"old@example.com": {"responses": [], "submitted_at": "2025-01-01"}
|
||||
}
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
def mock_get(key):
|
||||
if "last_fetch" in key:
|
||||
return "2025-01-14T00:00:00Z"
|
||||
if "email_index" in key:
|
||||
return json.dumps(existing_index)
|
||||
if "questions" in key:
|
||||
return json.dumps(SAMPLE_QUESTIONS)
|
||||
return None
|
||||
|
||||
mock_redis.get.side_effect = mock_get
|
||||
|
||||
new_submissions = [SAMPLE_SUBMISSIONS[0]] # Just Alice
|
||||
|
||||
with (
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_redis,
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally._fetch_all_submissions",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(SAMPLE_QUESTIONS, new_submissions),
|
||||
),
|
||||
):
|
||||
index, _ = await _refresh_cache("form123")
|
||||
|
||||
# Should contain both old and new entries
|
||||
assert "old@example.com" in index
|
||||
assert "alice@example.com" in index
|
||||
|
||||
|
||||
# ── _make_tally_client ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_make_tally_client_returns_configured_client():
|
||||
"""_make_tally_client should create a Requests client with auth headers."""
|
||||
client = _make_tally_client("test-api-key")
|
||||
assert client.extra_headers is not None
|
||||
assert client.extra_headers.get("Authorization") == "Bearer test-api-key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_tally_page_uses_provided_client():
|
||||
"""_fetch_tally_page should use the passed client, not create its own."""
|
||||
from backend.data.tally import _fetch_tally_page
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"submissions": [], "questions": []}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
result = await _fetch_tally_page(mock_client, "form123", page=1)
|
||||
|
||||
mock_client.get.assert_awaited_once()
|
||||
call_url = mock_client.get.call_args[0][0]
|
||||
assert "form123" in call_url
|
||||
assert "page=1" in call_url
|
||||
assert result == {"submissions": [], "questions": []}
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Redis-based distributed locking for cluster coordination."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
@@ -7,6 +8,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis import Redis
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -126,3 +128,124 @@ class ClusterLock:
|
||||
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = 0.0
|
||||
|
||||
|
||||
class AsyncClusterLock:
|
||||
"""Async Redis-based distributed lock for preventing duplicate execution."""
|
||||
|
||||
def __init__(
|
||||
self, redis: "AsyncRedis", key: str, owner_id: str, timeout: int = 300
|
||||
):
|
||||
self.redis = redis
|
||||
self.key = key
|
||||
self.owner_id = owner_id
|
||||
self.timeout = timeout
|
||||
self._last_refresh = 0.0
|
||||
self._refresh_lock = asyncio.Lock()
|
||||
|
||||
async def try_acquire(self) -> str | None:
|
||||
"""Try to acquire the lock.
|
||||
|
||||
Returns:
|
||||
- owner_id (self.owner_id) if successfully acquired
|
||||
- different owner_id if someone else holds the lock
|
||||
- None if Redis is unavailable or other error
|
||||
"""
|
||||
try:
|
||||
success = await self.redis.set(
|
||||
self.key, self.owner_id, nx=True, ex=self.timeout
|
||||
)
|
||||
if success:
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = time.time()
|
||||
return self.owner_id # Successfully acquired
|
||||
|
||||
# Failed to acquire, get current owner
|
||||
current_value = await self.redis.get(self.key)
|
||||
if current_value:
|
||||
current_owner = (
|
||||
current_value.decode("utf-8")
|
||||
if isinstance(current_value, bytes)
|
||||
else str(current_value)
|
||||
)
|
||||
return current_owner
|
||||
|
||||
# Key doesn't exist but we failed to set it - race condition or Redis issue
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AsyncClusterLock.try_acquire failed for key {self.key}: {e}")
|
||||
return None
|
||||
|
||||
async def refresh(self) -> bool:
|
||||
"""Refresh lock TTL if we still own it.
|
||||
|
||||
Rate limited to at most once every timeout/10 seconds (minimum 1 second).
|
||||
During rate limiting, still verifies lock existence but skips TTL extension.
|
||||
Setting _last_refresh to 0 bypasses rate limiting for testing.
|
||||
|
||||
Async-safe: uses asyncio.Lock to protect _last_refresh access.
|
||||
"""
|
||||
# Calculate refresh interval: max(timeout // 10, 1)
|
||||
refresh_interval = max(self.timeout // 10, 1)
|
||||
current_time = time.time()
|
||||
|
||||
# Check if we're within the rate limit period (async-safe read)
|
||||
# _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
|
||||
async with self._refresh_lock:
|
||||
last_refresh = self._last_refresh
|
||||
is_rate_limited = (
|
||||
last_refresh > 0 and (current_time - last_refresh) < refresh_interval
|
||||
)
|
||||
|
||||
try:
|
||||
# Always verify lock existence, even during rate limiting
|
||||
current_value = await self.redis.get(self.key)
|
||||
if not current_value:
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
stored_owner = (
|
||||
current_value.decode("utf-8")
|
||||
if isinstance(current_value, bytes)
|
||||
else str(current_value)
|
||||
)
|
||||
if stored_owner != self.owner_id:
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
# If rate limited, return True but don't update TTL or timestamp
|
||||
if is_rate_limited:
|
||||
return True
|
||||
|
||||
# Perform actual refresh
|
||||
if await self.redis.expire(self.key, self.timeout):
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = current_time
|
||||
return True
|
||||
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AsyncClusterLock.refresh failed for key {self.key}: {e}")
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
async def release(self):
|
||||
"""Release the lock."""
|
||||
async with self._refresh_lock:
|
||||
if self._last_refresh == 0:
|
||||
return
|
||||
|
||||
try:
|
||||
await self.redis.delete(self.key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = 0.0
|
||||
|
||||
@@ -535,14 +535,18 @@ async def _summarize_messages_llm(
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"Create a detailed summary of the conversation so far. "
|
||||
"Create a factual summary of the conversation so far. "
|
||||
"This summary will be used as context when continuing the conversation.\n\n"
|
||||
"CRITICAL: Only include information that is EXPLICITLY present in the "
|
||||
"conversation. Do NOT fabricate, infer, or invent any details. "
|
||||
"If a section has no relevant content in the conversation, skip it entirely.\n\n"
|
||||
"Before writing the summary, analyze each message chronologically to identify:\n"
|
||||
"- User requests and their explicit goals\n"
|
||||
"- Your approach and key decisions made\n"
|
||||
"- Actions taken and key decisions made\n"
|
||||
"- Technical specifics (file names, tool outputs, function signatures)\n"
|
||||
"- Errors encountered and resolutions applied\n\n"
|
||||
"You MUST include ALL of the following sections:\n\n"
|
||||
"Include ONLY the sections below that have relevant content "
|
||||
"(skip sections with nothing to report):\n\n"
|
||||
"## 1. Primary Request and Intent\n"
|
||||
"The user's explicit goals and what they are trying to accomplish.\n\n"
|
||||
"## 2. Key Technical Concepts\n"
|
||||
@@ -550,19 +554,14 @@ async def _summarize_messages_llm(
|
||||
"## 3. Files and Resources Involved\n"
|
||||
"Specific files examined or modified, with relevant snippets and identifiers.\n\n"
|
||||
"## 4. Errors and Fixes\n"
|
||||
"Problems encountered, error messages, and their resolutions. "
|
||||
"Include any user feedback on fixes.\n\n"
|
||||
"## 5. Problem Solving\n"
|
||||
"Issues that have been resolved and how they were addressed.\n\n"
|
||||
"## 6. All User Messages\n"
|
||||
"A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n"
|
||||
"## 7. Pending Tasks\n"
|
||||
"Problems encountered, error messages, and their resolutions.\n\n"
|
||||
"## 5. All User Messages\n"
|
||||
"A complete list of all user inputs (excluding tool outputs) "
|
||||
"to preserve their exact requests.\n\n"
|
||||
"## 6. Pending Tasks\n"
|
||||
"Work items the user explicitly requested that have not yet been completed.\n\n"
|
||||
"## 8. Current Work\n"
|
||||
"Precise description of what was being worked on most recently, including relevant context.\n\n"
|
||||
"## 9. Next Steps\n"
|
||||
"What should happen next, aligned with the user's most recent requests. "
|
||||
"Include verbatim quotes of recent instructions if relevant."
|
||||
"## 7. Current State\n"
|
||||
"What was happening most recently in the conversation."
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
|
||||
|
||||
@@ -372,8 +372,8 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The port for the Agent Generator service",
|
||||
)
|
||||
agentgenerator_timeout: int = Field(
|
||||
default=600,
|
||||
description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)",
|
||||
default=30,
|
||||
description="The timeout in seconds for individual Agent Generator HTTP requests (submit and poll)",
|
||||
)
|
||||
agentgenerator_use_dummy: bool = Field(
|
||||
default=False,
|
||||
@@ -691,6 +691,15 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
|
||||
screenshotone_api_key: str = Field(default="", description="ScreenshotOne API Key")
|
||||
|
||||
tally_api_key: str = Field(
|
||||
default="",
|
||||
description="Tally API key for form submission lookup on signup",
|
||||
)
|
||||
tally_form_id: str = Field(
|
||||
default="npGe0q",
|
||||
description="Tally form ID for signup business understanding form",
|
||||
)
|
||||
|
||||
apollo_api_key: str = Field(default="", description="Apollo API Key")
|
||||
smartlead_api_key: str = Field(default="", description="SmartLead API Key")
|
||||
zerobounce_api_key: str = Field(default="", description="ZeroBounce API Key")
|
||||
|
||||
14
autogpt_platform/backend/poetry.lock
generated
14
autogpt_platform/backend/poetry.lock
generated
@@ -899,17 +899,17 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "claude-agent-sdk"
|
||||
version = "0.1.35"
|
||||
version = "0.1.39"
|
||||
description = "Python SDK for Claude Code"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "claude_agent_sdk-0.1.35-py3-none-macosx_11_0_arm64.whl", hash = "sha256:df67f4deade77b16a9678b3a626c176498e40417f33b04beda9628287f375591"},
|
||||
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:14963944f55ded7c8ed518feebfa5b4284aa6dd8d81aeff2e5b21a962ce65097"},
|
||||
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:84344dcc535d179c1fc8a11c6f34c37c3b583447bdf09d869effb26514fd7a65"},
|
||||
{file = "claude_agent_sdk-0.1.35-py3-none-win_amd64.whl", hash = "sha256:1b3d54b47448c93f6f372acd4d1757f047c3c1e8ef5804be7a1e3e53e2c79a5f"},
|
||||
{file = "claude_agent_sdk-0.1.35.tar.gz", hash = "sha256:0f98e2b3c71ca85abfc042e7a35c648df88e87fda41c52e6779ef7b038dcbb52"},
|
||||
{file = "claude_agent_sdk-0.1.39-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6ed6a79781f545b761b9fe467bc5ae213a103c9d3f0fe7a9dad3c01790ed58fa"},
|
||||
{file = "claude_agent_sdk-0.1.39-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:0c03b5a3772eaec42e29ea39240c7d24b760358082f2e36336db9e71dde3dda4"},
|
||||
{file = "claude_agent_sdk-0.1.39-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:d2665c9e87b6ffece590bcdd6eb9def47cde4809b0d2f66e0a61a719189be7c9"},
|
||||
{file = "claude_agent_sdk-0.1.39-py3-none-win_amd64.whl", hash = "sha256:d03324daf7076be79d2dd05944559aabf4cc11c98d3a574b992a442a7c7a26d6"},
|
||||
{file = "claude_agent_sdk-0.1.39.tar.gz", hash = "sha256:dcf0ebd5a638c9a7d9f3af7640932a9212b2705b7056e4f08bd3968a865b4268"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -8530,4 +8530,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "55e095de555482f0fe47de7695f390fe93e7bcf739b31c391b2e5e3c3d938ae3"
|
||||
content-hash = "3ef62836d8321b9a3b8e897dade8dc6ca9022fd9468c53f384b0871b521ab343"
|
||||
|
||||
@@ -16,7 +16,7 @@ anthropic = "^0.79.0"
|
||||
apscheduler = "^3.11.1"
|
||||
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
||||
bleach = { extras = ["css"], version = "^6.2.0" }
|
||||
claude-agent-sdk = "^0.1.0"
|
||||
claude-agent-sdk = "^0.1.39" # see copilot/sdk/sdk_compat_test.py for capability checks
|
||||
click = "^8.2.0"
|
||||
cryptography = "^46.0"
|
||||
discord-py = "^2.5.2"
|
||||
|
||||
@@ -109,7 +109,7 @@ class TestGenerateAgent:
|
||||
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||
result = await core.generate_agent(instructions)
|
||||
|
||||
mock_external.assert_called_once_with(instructions, None, None, None)
|
||||
mock_external.assert_called_once_with(instructions, None)
|
||||
assert result is not None
|
||||
assert result["name"] == "Test Agent"
|
||||
assert "id" in result
|
||||
@@ -173,9 +173,7 @@ class TestGenerateAgentPatch:
|
||||
current_agent = {"nodes": [], "links": []}
|
||||
result = await core.generate_agent_patch("Add a node", current_agent)
|
||||
|
||||
mock_external.assert_called_once_with(
|
||||
"Add a node", current_agent, None, None, None
|
||||
)
|
||||
mock_external.assert_called_once_with("Add a node", current_agent, None)
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Tests for the Agent Generator external service client.
|
||||
|
||||
This test suite verifies the external Agent Generator service integration,
|
||||
including service detection, API calls, and error handling.
|
||||
including service detection, async polling, and error handling.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
@@ -49,6 +49,292 @@ class TestServiceConfiguration:
|
||||
assert url == "http://agent-generator.local:8000"
|
||||
|
||||
|
||||
class TestSubmitAndPoll:
|
||||
"""Test the _submit_and_poll helper that handles async job polling."""
|
||||
|
||||
def setup_method(self):
|
||||
service._settings = None
|
||||
service._client = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_submit_and_poll(self):
|
||||
"""Test normal submit -> poll -> completed flow."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"job_id": "job-123", "status": "accepted"}
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
poll_resp = MagicMock()
|
||||
poll_resp.json.return_value = {
|
||||
"job_id": "job-123",
|
||||
"status": "completed",
|
||||
"result": {"type": "instructions", "steps": ["Step 1"]},
|
||||
}
|
||||
poll_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
mock_client.get.return_value = poll_resp
|
||||
|
||||
with (
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
patch("asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
result = await service._submit_and_poll("/api/test", {"key": "value"})
|
||||
|
||||
assert result == {"type": "instructions", "steps": ["Step 1"]}
|
||||
mock_client.post.assert_called_once_with("/api/test", json={"key": "value"})
|
||||
mock_client.get.assert_called_once_with("/api/jobs/job-123")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_returns_failed_job(self):
|
||||
"""Test submit -> poll -> failed flow."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"job_id": "job-456", "status": "accepted"}
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
poll_resp = MagicMock()
|
||||
poll_resp.json.return_value = {
|
||||
"job_id": "job-456",
|
||||
"status": "failed",
|
||||
"error": "Generation failed",
|
||||
}
|
||||
poll_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
mock_client.get.return_value = poll_resp
|
||||
|
||||
with (
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
patch("asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result["type"] == "error"
|
||||
assert result["error_type"] == "job_failed"
|
||||
assert "Generation failed" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_http_error(self):
|
||||
"""Test HTTP error during job submission."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.HTTPStatusError(
|
||||
"Server error", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result["type"] == "error"
|
||||
assert result["error_type"] == "http_error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_connection_error(self):
|
||||
"""Test connection error during job submission."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result["type"] == "error"
|
||||
assert result["error_type"] == "connection_error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_job_id_in_submit_response(self):
|
||||
"""Test submit response missing job_id."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"status": "accepted"} # no job_id
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result["type"] == "error"
|
||||
assert result["error_type"] == "invalid_response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_retries_on_transient_network_error(self):
|
||||
"""Test that transient network errors during polling are retried."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"job_id": "job-789"}
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
ok_poll_resp = MagicMock()
|
||||
ok_poll_resp.json.return_value = {
|
||||
"job_id": "job-789",
|
||||
"status": "completed",
|
||||
"result": {"data": "ok"},
|
||||
}
|
||||
ok_poll_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
# First poll fails with transient error, second succeeds
|
||||
mock_client.get.side_effect = [
|
||||
httpx.RequestError("transient"),
|
||||
ok_poll_resp,
|
||||
]
|
||||
|
||||
with (
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
patch("asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result == {"data": "ok"}
|
||||
assert mock_client.get.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_returns_404_for_expired_job(self):
|
||||
"""Test that 404 during polling returns job_not_found error."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"job_id": "job-expired"}
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_404_response = MagicMock()
|
||||
mock_404_response.status_code = 404
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
mock_client.get.side_effect = httpx.HTTPStatusError(
|
||||
"Not Found", request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
patch("asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result["type"] == "error"
|
||||
assert result["error_type"] == "job_not_found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_retries_on_transient_http_status(self):
|
||||
"""Test that transient HTTP status codes (429, 503, etc.) are retried."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"job_id": "job-transient"}
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_429_response = MagicMock()
|
||||
mock_429_response.status_code = 429
|
||||
|
||||
ok_poll_resp = MagicMock()
|
||||
ok_poll_resp.json.return_value = {
|
||||
"job_id": "job-transient",
|
||||
"status": "completed",
|
||||
"result": {"data": "recovered"},
|
||||
}
|
||||
ok_poll_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
mock_client.get.side_effect = [
|
||||
httpx.HTTPStatusError(
|
||||
"Too Many Requests", request=MagicMock(), response=mock_429_response
|
||||
),
|
||||
ok_poll_resp,
|
||||
]
|
||||
|
||||
with (
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
patch("asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result == {"data": "recovered"}
|
||||
assert mock_client.get.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_does_not_retry_non_transient_http_status(self):
|
||||
"""Test that non-transient HTTP status codes (e.g. 500) fail immediately."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"job_id": "job-500"}
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_500_response = MagicMock()
|
||||
mock_500_response.status_code = 500
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
mock_client.get.side_effect = httpx.HTTPStatusError(
|
||||
"Internal Server Error", request=MagicMock(), response=mock_500_response
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
patch("asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result["type"] == "error"
|
||||
assert result["error_type"] == "http_error"
|
||||
assert mock_client.get.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_timeout(self):
|
||||
"""Test that polling times out after MAX_POLL_TIME_SECONDS."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"job_id": "job-slow"}
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
running_resp = MagicMock()
|
||||
running_resp.json.return_value = {"job_id": "job-slow", "status": "running"}
|
||||
running_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
mock_client.get.return_value = running_resp
|
||||
|
||||
# Simulate time passing: first call returns 0.0 (start), then jumps past limit
|
||||
monotonic_values = iter([0.0, 0.0, 100.0])
|
||||
|
||||
with (
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
patch.object(service, "MAX_POLL_TIME_SECONDS", 50.0),
|
||||
patch.object(service, "POLL_INTERVAL_SECONDS", 0.01),
|
||||
patch("asyncio.sleep", new_callable=AsyncMock),
|
||||
patch("backend.copilot.tools.agent_generator.service.time") as mock_time,
|
||||
):
|
||||
mock_time.monotonic.side_effect = monotonic_values
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result["type"] == "error"
|
||||
assert result["error_type"] == "timeout"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_gives_up_after_consecutive_transient_errors(self):
|
||||
"""Test that polling gives up after MAX_CONSECUTIVE_POLL_ERRORS."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"job_id": "job-flaky"}
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
mock_client.get.side_effect = httpx.RequestError("network down")
|
||||
|
||||
# Ensure monotonic always returns 0 so timeout doesn't kick in
|
||||
with (
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
patch.object(service, "MAX_POLL_TIME_SECONDS", 9999.0),
|
||||
patch.object(service, "POLL_INTERVAL_SECONDS", 0.01),
|
||||
patch("asyncio.sleep", new_callable=AsyncMock),
|
||||
patch("backend.copilot.tools.agent_generator.service.time") as mock_time,
|
||||
):
|
||||
mock_time.monotonic.return_value = 0.0
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result["type"] == "error"
|
||||
assert result["error_type"] == "poll_error"
|
||||
assert mock_client.get.call_count == service.MAX_CONSECUTIVE_POLL_ERRORS
|
||||
|
||||
|
||||
class TestDecomposeGoalExternal:
|
||||
"""Test decompose_goal_external function."""
|
||||
|
||||
@@ -60,40 +346,37 @@ class TestDecomposeGoalExternal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_returns_instructions(self):
|
||||
"""Test successful decomposition returning instructions."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1", "Step 2"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1", "Step 2"],
|
||||
}
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result == {"type": "instructions", "steps": ["Step 1", "Step 2"]}
|
||||
mock_client.post.assert_called_once_with(
|
||||
"/api/decompose-description", json={"description": "Build a chatbot"}
|
||||
mock_poll.assert_called_once_with(
|
||||
"/api/decompose-description",
|
||||
{"description": "Build a chatbot"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_returns_clarifying_questions(self):
|
||||
"""Test decomposition returning clarifying questions."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "clarifying_questions",
|
||||
"questions": ["What platform?", "What language?"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {
|
||||
"type": "clarifying_questions",
|
||||
"questions": ["What platform?", "What language?"],
|
||||
}
|
||||
result = await service.decompose_goal_external("Build something")
|
||||
|
||||
assert result == {
|
||||
@@ -104,18 +387,13 @@ class TestDecomposeGoalExternal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_with_context(self):
|
||||
"""Test decomposition with additional context enriched into description."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {"type": "instructions", "steps": ["Step 1"]}
|
||||
await service.decompose_goal_external(
|
||||
"Build a chatbot", context="Use Python"
|
||||
)
|
||||
@@ -123,27 +401,25 @@ class TestDecomposeGoalExternal:
|
||||
expected_description = (
|
||||
"Build a chatbot\n\nAdditional context from user:\nUse Python"
|
||||
)
|
||||
mock_client.post.assert_called_once_with(
|
||||
mock_poll.assert_called_once_with(
|
||||
"/api/decompose-description",
|
||||
json={"description": expected_description},
|
||||
{"description": expected_description},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_returns_unachievable_goal(self):
|
||||
"""Test decomposition returning unachievable goal response."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "unachievable_goal",
|
||||
"reason": "Cannot do X",
|
||||
"suggested_goal": "Try Y instead",
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {
|
||||
"type": "unachievable_goal",
|
||||
"reason": "Cannot do X",
|
||||
"suggested_goal": "Try Y instead",
|
||||
}
|
||||
result = await service.decompose_goal_external("Do something impossible")
|
||||
|
||||
assert result == {
|
||||
@@ -153,58 +429,40 @@ class TestDecomposeGoalExternal:
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_handles_http_error(self):
|
||||
"""Test decomposition handles HTTP errors gracefully."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.HTTPStatusError(
|
||||
"Server error", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
async def test_decompose_goal_handles_poll_error(self):
|
||||
"""Test that errors from _submit_and_poll are passed through."""
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {
|
||||
"type": "error",
|
||||
"error": "HTTP error calling Agent Generator: Server error",
|
||||
"error_type": "http_error",
|
||||
}
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result is not None
|
||||
assert result.get("type") == "error"
|
||||
assert result.get("error_type") == "http_error"
|
||||
assert "Server error" in result.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_handles_request_error(self):
|
||||
"""Test decomposition handles request errors gracefully."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
async def test_decompose_goal_handles_unexpected_exception(self):
|
||||
"""Test that unexpected exceptions are caught and returned as errors."""
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.side_effect = RuntimeError("unexpected")
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result is not None
|
||||
assert result.get("type") == "error"
|
||||
assert result.get("error_type") == "connection_error"
|
||||
assert "Connection failed" in result.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_handles_service_error(self):
|
||||
"""Test decomposition handles service returning error."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": False,
|
||||
"error": "Internal error",
|
||||
"error_type": "internal_error",
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result is not None
|
||||
assert result.get("type") == "error"
|
||||
assert result.get("error") == "Internal error"
|
||||
assert result.get("error_type") == "internal_error"
|
||||
assert result.get("error_type") == "unexpected_error"
|
||||
|
||||
|
||||
class TestGenerateAgentExternal:
|
||||
@@ -223,39 +481,59 @@ class TestGenerateAgentExternal:
|
||||
"nodes": [],
|
||||
"links": [],
|
||||
}
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"agent_json": agent_json,
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {"success": True, "agent_json": agent_json}
|
||||
|
||||
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||
result = await service.generate_agent_external(instructions)
|
||||
|
||||
assert result == agent_json
|
||||
mock_client.post.assert_called_once_with(
|
||||
"/api/generate-agent", json={"instructions": instructions}
|
||||
mock_poll.assert_called_once_with(
|
||||
"/api/generate-agent",
|
||||
{"instructions": instructions},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_handles_error(self):
|
||||
"""Test agent generation handles errors gracefully."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {
|
||||
"type": "error",
|
||||
"error": "Connection failed",
|
||||
"error_type": "connection_error",
|
||||
}
|
||||
result = await service.generate_agent_external({"steps": []})
|
||||
|
||||
assert result is not None
|
||||
assert result.get("type") == "error"
|
||||
assert result.get("error_type") == "connection_error"
|
||||
assert "Connection failed" in result.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_missing_agent_json(self):
|
||||
"""Test that missing agent_json in result returns an error."""
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {"success": True}
|
||||
result = await service.generate_agent_external({"steps": ["Step 1"]})
|
||||
|
||||
assert result is not None
|
||||
assert result.get("type") == "error"
|
||||
assert result.get("error_type") == "invalid_response"
|
||||
|
||||
|
||||
class TestGenerateAgentPatchExternal:
|
||||
@@ -274,27 +552,24 @@ class TestGenerateAgentPatchExternal:
|
||||
"nodes": [{"id": "1", "block_id": "test"}],
|
||||
"links": [],
|
||||
}
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"agent_json": updated_agent,
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {"success": True, "agent_json": updated_agent}
|
||||
|
||||
current_agent = {"name": "Old Agent", "nodes": [], "links": []}
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
current_agent = {"name": "Old Agent", "nodes": [], "links": []}
|
||||
result = await service.generate_agent_patch_external(
|
||||
"Add a new node", current_agent
|
||||
)
|
||||
|
||||
assert result == updated_agent
|
||||
mock_client.post.assert_called_once_with(
|
||||
mock_poll.assert_called_once_with(
|
||||
"/api/update-agent",
|
||||
json={
|
||||
{
|
||||
"update_request": "Add a new node",
|
||||
"current_agent_json": current_agent,
|
||||
},
|
||||
@@ -303,18 +578,16 @@ class TestGenerateAgentPatchExternal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_patch_returns_clarifying_questions(self):
|
||||
"""Test patch generation returning clarifying questions."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "clarifying_questions",
|
||||
"questions": ["What type of node?"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {
|
||||
"type": "clarifying_questions",
|
||||
"questions": ["What type of node?"],
|
||||
}
|
||||
result = await service.generate_agent_patch_external(
|
||||
"Add something", {"nodes": []}
|
||||
)
|
||||
@@ -355,9 +628,12 @@ class TestHealthCheck:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch.object(service, "is_external_service_configured", return_value=True):
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.health_check()
|
||||
with (
|
||||
patch.object(service, "is_external_service_configured", return_value=True),
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
):
|
||||
result = await service.health_check()
|
||||
|
||||
assert result is True
|
||||
mock_client.get.assert_called_once_with("/health")
|
||||
@@ -375,9 +651,12 @@ class TestHealthCheck:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch.object(service, "is_external_service_configured", return_value=True):
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.health_check()
|
||||
with (
|
||||
patch.object(service, "is_external_service_configured", return_value=True),
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
):
|
||||
result = await service.health_check()
|
||||
|
||||
assert result is False
|
||||
|
||||
@@ -387,9 +666,12 @@ class TestHealthCheck:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "is_external_service_configured", return_value=True):
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.health_check()
|
||||
with (
|
||||
patch.object(service, "is_external_service_configured", return_value=True),
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
):
|
||||
result = await service.health_check()
|
||||
|
||||
assert result is False
|
||||
|
||||
@@ -419,7 +701,10 @@ class TestGetBlocksExternal:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
):
|
||||
result = await service.get_blocks_external()
|
||||
|
||||
assert result == blocks
|
||||
@@ -431,7 +716,10 @@ class TestGetBlocksExternal:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
):
|
||||
result = await service.get_blocks_external()
|
||||
|
||||
assert result is None
|
||||
@@ -459,26 +747,22 @@ class TestLibraryAgentsPassthrough:
|
||||
},
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {"type": "instructions", "steps": ["Step 1"]}
|
||||
await service.decompose_goal_external(
|
||||
"Send an email",
|
||||
library_agents=library_agents,
|
||||
)
|
||||
|
||||
# Verify library_agents was passed in the payload
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[1]["json"]["library_agents"] == library_agents
|
||||
call_args = mock_poll.call_args
|
||||
payload = call_args[0][1]
|
||||
assert payload["library_agents"] == library_agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_passes_library_agents(self):
|
||||
@@ -494,25 +778,24 @@ class TestLibraryAgentsPassthrough:
|
||||
},
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"agent_json": {"name": "Test Agent", "nodes": []},
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {
|
||||
"agent_json": {"name": "Test Agent", "nodes": []},
|
||||
}
|
||||
await service.generate_agent_external(
|
||||
{"steps": ["Step 1"]},
|
||||
library_agents=library_agents,
|
||||
)
|
||||
|
||||
# Verify library_agents was passed in the payload
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[1]["json"]["library_agents"] == library_agents
|
||||
call_args = mock_poll.call_args
|
||||
payload = call_args[0][1]
|
||||
assert payload["library_agents"] == library_agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_patch_passes_library_agents(self):
|
||||
@@ -528,17 +811,15 @@ class TestLibraryAgentsPassthrough:
|
||||
},
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"agent_json": {"name": "Updated Agent", "nodes": []},
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {
|
||||
"agent_json": {"name": "Updated Agent", "nodes": []},
|
||||
}
|
||||
await service.generate_agent_patch_external(
|
||||
"Add error handling",
|
||||
{"name": "Original Agent", "nodes": []},
|
||||
@@ -546,29 +827,26 @@ class TestLibraryAgentsPassthrough:
|
||||
)
|
||||
|
||||
# Verify library_agents was passed in the payload
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[1]["json"]["library_agents"] == library_agents
|
||||
call_args = mock_poll.call_args
|
||||
payload = call_args[0][1]
|
||||
assert payload["library_agents"] == library_agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_without_library_agents(self):
|
||||
"""Test that decompose goal works without library_agents."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {"type": "instructions", "steps": ["Step 1"]}
|
||||
await service.decompose_goal_external("Build a workflow")
|
||||
|
||||
# Verify library_agents was NOT passed when not provided
|
||||
call_args = mock_client.post.call_args
|
||||
assert "library_agents" not in call_args[1]["json"]
|
||||
call_args = mock_poll.call_args
|
||||
payload = call_args[0][1]
|
||||
assert "library_agents" not in payload
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,349 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration test for the requeue fix implementation.
|
||||
Tests actual RabbitMQ behavior to verify that republishing sends messages to back of queue.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from threading import Event
|
||||
from typing import List
|
||||
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.utils import create_execution_queue_config
|
||||
|
||||
|
||||
class QueueOrderTester:
|
||||
"""Helper class to test message ordering in RabbitMQ using a dedicated test queue."""
|
||||
|
||||
def __init__(self):
|
||||
self.received_messages: List[dict] = []
|
||||
self.stop_consuming = Event()
|
||||
self.queue_client = SyncRabbitMQ(create_execution_queue_config())
|
||||
self.queue_client.connect()
|
||||
|
||||
# Use a dedicated test queue name to avoid conflicts
|
||||
self.test_queue_name = "test_requeue_ordering"
|
||||
self.test_exchange = "test_exchange"
|
||||
self.test_routing_key = "test.requeue"
|
||||
|
||||
def setup_queue(self):
|
||||
"""Set up a dedicated test queue for testing."""
|
||||
channel = self.queue_client.get_channel()
|
||||
|
||||
# Declare test exchange
|
||||
channel.exchange_declare(
|
||||
exchange=self.test_exchange, exchange_type="direct", durable=True
|
||||
)
|
||||
|
||||
# Declare test queue
|
||||
channel.queue_declare(
|
||||
queue=self.test_queue_name, durable=True, auto_delete=False
|
||||
)
|
||||
|
||||
# Bind queue to exchange
|
||||
channel.queue_bind(
|
||||
exchange=self.test_exchange,
|
||||
queue=self.test_queue_name,
|
||||
routing_key=self.test_routing_key,
|
||||
)
|
||||
|
||||
# Purge the queue to start fresh
|
||||
channel.queue_purge(self.test_queue_name)
|
||||
print(f"✅ Test queue {self.test_queue_name} setup and purged")
|
||||
|
||||
def create_test_message(self, message_id: str, user_id: str = "test-user") -> str:
|
||||
"""Create a test graph execution message."""
|
||||
return json.dumps(
|
||||
{
|
||||
"graph_exec_id": f"exec-{message_id}",
|
||||
"graph_id": f"graph-{message_id}",
|
||||
"user_id": user_id,
|
||||
"execution_context": {"timezone": "UTC"},
|
||||
"nodes_input_masks": {},
|
||||
"starting_nodes_input": [],
|
||||
}
|
||||
)
|
||||
|
||||
def publish_message(self, message: str):
|
||||
"""Publish a message to the test queue."""
|
||||
channel = self.queue_client.get_channel()
|
||||
channel.basic_publish(
|
||||
exchange=self.test_exchange,
|
||||
routing_key=self.test_routing_key,
|
||||
body=message,
|
||||
)
|
||||
|
||||
def consume_messages(self, max_messages: int = 10, timeout: float = 5.0):
|
||||
"""Consume messages and track their order."""
|
||||
|
||||
def callback(ch, method, properties, body):
|
||||
try:
|
||||
message_data = json.loads(body.decode())
|
||||
self.received_messages.append(message_data)
|
||||
ch.basic_ack(delivery_tag=method.delivery_tag)
|
||||
|
||||
if len(self.received_messages) >= max_messages:
|
||||
self.stop_consuming.set()
|
||||
except Exception as e:
|
||||
print(f"Error processing message: {e}")
|
||||
ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
|
||||
|
||||
# Use synchronous consumption with blocking
|
||||
channel = self.queue_client.get_channel()
|
||||
|
||||
# Check if there are messages in the queue first
|
||||
method_frame, header_frame, body = channel.basic_get(
|
||||
queue=self.test_queue_name, auto_ack=False
|
||||
)
|
||||
if method_frame:
|
||||
# There are messages, set up consumer
|
||||
channel.basic_nack(
|
||||
delivery_tag=method_frame.delivery_tag, requeue=True
|
||||
) # Put message back
|
||||
|
||||
# Set up consumer
|
||||
channel.basic_consume(
|
||||
queue=self.test_queue_name,
|
||||
on_message_callback=callback,
|
||||
)
|
||||
|
||||
# Consume with timeout
|
||||
start_time = time.time()
|
||||
while (
|
||||
not self.stop_consuming.is_set()
|
||||
and (time.time() - start_time) < timeout
|
||||
and len(self.received_messages) < max_messages
|
||||
):
|
||||
try:
|
||||
channel.connection.process_data_events(time_limit=0.1)
|
||||
except Exception as e:
|
||||
print(f"Error during consumption: {e}")
|
||||
break
|
||||
|
||||
# Cancel the consumer
|
||||
try:
|
||||
channel.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# No messages in queue - this might be expected for some tests
|
||||
pass
|
||||
|
||||
return self.received_messages
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up test resources."""
|
||||
try:
|
||||
channel = self.queue_client.get_channel()
|
||||
channel.queue_delete(queue=self.test_queue_name)
|
||||
channel.exchange_delete(exchange=self.test_exchange)
|
||||
print(f"✅ Test queue {self.test_queue_name} cleaned up")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Cleanup issue: {e}")
|
||||
|
||||
|
||||
def test_queue_ordering_behavior():
|
||||
"""
|
||||
Integration test to verify that our republishing method sends messages to back of queue.
|
||||
This tests the actual fix for the rate limiting queue blocking issue.
|
||||
"""
|
||||
tester = QueueOrderTester()
|
||||
|
||||
try:
|
||||
tester.setup_queue()
|
||||
|
||||
print("🧪 Testing actual RabbitMQ queue ordering behavior...")
|
||||
|
||||
# Test 1: Normal FIFO behavior
|
||||
print("1. Testing normal FIFO queue behavior")
|
||||
|
||||
# Publish messages in order: A, B, C
|
||||
msg_a = tester.create_test_message("A")
|
||||
msg_b = tester.create_test_message("B")
|
||||
msg_c = tester.create_test_message("C")
|
||||
|
||||
tester.publish_message(msg_a)
|
||||
tester.publish_message(msg_b)
|
||||
tester.publish_message(msg_c)
|
||||
|
||||
# Consume and verify FIFO order: A, B, C
|
||||
tester.received_messages = []
|
||||
tester.stop_consuming.clear()
|
||||
messages = tester.consume_messages(max_messages=3)
|
||||
|
||||
assert len(messages) == 3, f"Expected 3 messages, got {len(messages)}"
|
||||
assert (
|
||||
messages[0]["graph_exec_id"] == "exec-A"
|
||||
), f"First message should be A, got {messages[0]['graph_exec_id']}"
|
||||
assert (
|
||||
messages[1]["graph_exec_id"] == "exec-B"
|
||||
), f"Second message should be B, got {messages[1]['graph_exec_id']}"
|
||||
assert (
|
||||
messages[2]["graph_exec_id"] == "exec-C"
|
||||
), f"Third message should be C, got {messages[2]['graph_exec_id']}"
|
||||
|
||||
print("✅ FIFO order confirmed: A -> B -> C")
|
||||
|
||||
# Test 2: Rate limiting simulation - the key test!
|
||||
print("2. Testing rate limiting fix scenario")
|
||||
|
||||
# Simulate the scenario where user1 is rate limited
|
||||
user1_msg = tester.create_test_message("RATE-LIMITED", "user1")
|
||||
user2_msg1 = tester.create_test_message("USER2-1", "user2")
|
||||
user2_msg2 = tester.create_test_message("USER2-2", "user2")
|
||||
|
||||
# Initially publish user1 message (gets consumed, then rate limited on retry)
|
||||
tester.publish_message(user1_msg)
|
||||
|
||||
# Other users publish their messages
|
||||
tester.publish_message(user2_msg1)
|
||||
tester.publish_message(user2_msg2)
|
||||
|
||||
# Now simulate: user1 message gets "requeued" using our new republishing method
|
||||
# This is what happens in manager.py when requeue_by_republishing=True
|
||||
tester.publish_message(user1_msg) # Goes to back via our method
|
||||
|
||||
# Expected order: RATE-LIMITED, USER2-1, USER2-2, RATE-LIMITED (republished to back)
|
||||
# This shows that user2 messages get processed instead of being blocked
|
||||
tester.received_messages = []
|
||||
tester.stop_consuming.clear()
|
||||
messages = tester.consume_messages(max_messages=4)
|
||||
|
||||
assert len(messages) == 4, f"Expected 4 messages, got {len(messages)}"
|
||||
|
||||
# The key verification: user2 messages are NOT blocked by user1's rate-limited message
|
||||
user2_messages = [msg for msg in messages if msg["user_id"] == "user2"]
|
||||
assert len(user2_messages) == 2, "Both user2 messages should be processed"
|
||||
assert user2_messages[0]["graph_exec_id"] == "exec-USER2-1"
|
||||
assert user2_messages[1]["graph_exec_id"] == "exec-USER2-2"
|
||||
|
||||
print("✅ Rate limiting fix confirmed: user2 executions NOT blocked by user1")
|
||||
|
||||
# Test 3: Verify our method behaves like going to back of queue
|
||||
print("3. Testing republishing sends messages to back")
|
||||
|
||||
# Start with message X in queue
|
||||
msg_x = tester.create_test_message("X")
|
||||
tester.publish_message(msg_x)
|
||||
|
||||
# Add message Y
|
||||
msg_y = tester.create_test_message("Y")
|
||||
tester.publish_message(msg_y)
|
||||
|
||||
# Republish X (simulates requeue using our method)
|
||||
tester.publish_message(msg_x)
|
||||
|
||||
# Expected: X, Y, X (X was republished to back)
|
||||
tester.received_messages = []
|
||||
tester.stop_consuming.clear()
|
||||
messages = tester.consume_messages(max_messages=3)
|
||||
|
||||
assert len(messages) == 3
|
||||
# Y should come before the republished X
|
||||
y_index = next(
|
||||
i for i, msg in enumerate(messages) if msg["graph_exec_id"] == "exec-Y"
|
||||
)
|
||||
republished_x_index = next(
|
||||
i
|
||||
for i, msg in enumerate(messages[1:], 1)
|
||||
if msg["graph_exec_id"] == "exec-X"
|
||||
)
|
||||
|
||||
assert (
|
||||
y_index < republished_x_index
|
||||
), f"Y should come before republished X, but got order: {[m['graph_exec_id'] for m in messages]}"
|
||||
|
||||
print("✅ Republishing confirmed: messages go to back of queue")
|
||||
|
||||
print("🎉 All integration tests passed!")
|
||||
print("🎉 Our republishing method works correctly with real RabbitMQ")
|
||||
print("🎉 Queue blocking issue is fixed!")
|
||||
|
||||
finally:
|
||||
tester.cleanup()
|
||||
|
||||
|
||||
def test_traditional_requeue_behavior():
|
||||
"""
|
||||
Test that traditional requeue (basic_nack with requeue=True) sends messages to FRONT of queue.
|
||||
This validates our hypothesis about why queue blocking occurs.
|
||||
"""
|
||||
tester = QueueOrderTester()
|
||||
|
||||
try:
|
||||
tester.setup_queue()
|
||||
print("🧪 Testing traditional requeue behavior (basic_nack with requeue=True)")
|
||||
|
||||
# Step 1: Publish message A
|
||||
msg_a = tester.create_test_message("A")
|
||||
tester.publish_message(msg_a)
|
||||
|
||||
# Step 2: Publish message B
|
||||
msg_b = tester.create_test_message("B")
|
||||
tester.publish_message(msg_b)
|
||||
|
||||
# Step 3: Consume message A and requeue it using traditional method
|
||||
channel = tester.queue_client.get_channel()
|
||||
method_frame, header_frame, body = channel.basic_get(
|
||||
queue=tester.test_queue_name, auto_ack=False
|
||||
)
|
||||
|
||||
assert method_frame is not None, "Should have received message A"
|
||||
consumed_msg = json.loads(body.decode())
|
||||
assert (
|
||||
consumed_msg["graph_exec_id"] == "exec-A"
|
||||
), f"Should have consumed message A, got {consumed_msg['graph_exec_id']}"
|
||||
|
||||
# Traditional requeue: basic_nack with requeue=True (sends to FRONT)
|
||||
channel.basic_nack(delivery_tag=method_frame.delivery_tag, requeue=True)
|
||||
print(f"🔄 Traditional requeue (to FRONT): {consumed_msg['graph_exec_id']}")
|
||||
|
||||
# Step 4: Consume all messages using basic_get for reliability
|
||||
received_messages = []
|
||||
|
||||
# Get first message
|
||||
method_frame, header_frame, body = channel.basic_get(
|
||||
queue=tester.test_queue_name, auto_ack=True
|
||||
)
|
||||
if method_frame:
|
||||
msg = json.loads(body.decode())
|
||||
received_messages.append(msg)
|
||||
|
||||
# Get second message
|
||||
method_frame, header_frame, body = channel.basic_get(
|
||||
queue=tester.test_queue_name, auto_ack=True
|
||||
)
|
||||
if method_frame:
|
||||
msg = json.loads(body.decode())
|
||||
received_messages.append(msg)
|
||||
|
||||
# CRITICAL ASSERTION: Traditional requeue should put A at FRONT
|
||||
# Expected order: A (requeued to front), B
|
||||
assert (
|
||||
len(received_messages) == 2
|
||||
), f"Expected 2 messages, got {len(received_messages)}"
|
||||
|
||||
first_msg = received_messages[0]["graph_exec_id"]
|
||||
second_msg = received_messages[1]["graph_exec_id"]
|
||||
|
||||
# This is the critical test: requeued message A should come BEFORE B
|
||||
assert (
|
||||
first_msg == "exec-A"
|
||||
), f"Traditional requeue should put A at FRONT, but first message was: {first_msg}"
|
||||
assert (
|
||||
second_msg == "exec-B"
|
||||
), f"B should come after requeued A, but second message was: {second_msg}"
|
||||
|
||||
print(
|
||||
"✅ HYPOTHESIS CONFIRMED: Traditional requeue sends messages to FRONT of queue"
|
||||
)
|
||||
print(f" Order: {first_msg} (requeued to front) → {second_msg}")
|
||||
print(" This explains why rate-limited messages block other users!")
|
||||
|
||||
finally:
|
||||
tester.cleanup()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_queue_ordering_behavior()
|
||||
@@ -1,100 +1,30 @@
|
||||
// import { Separator } from "@/components/__legacy__/ui/separator";
|
||||
import { cn } from "@/lib/utils";
|
||||
import React, { memo } from "react";
|
||||
import { BlockMenu } from "./NewBlockMenu/BlockMenu/BlockMenu";
|
||||
import { useNewControlPanel } from "./useNewControlPanel";
|
||||
// import { NewSaveControl } from "../SaveControl/NewSaveControl";
|
||||
import { GraphExecutionID } from "@/lib/autogpt-server-api";
|
||||
// import { ControlPanelButton } from "../ControlPanelButton";
|
||||
// import { GraphSearchMenu } from "../GraphMenu/GraphMenu";
|
||||
import { Separator } from "@/components/__legacy__/ui/separator";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { CustomNode } from "../FlowEditor/nodes/CustomNode/CustomNode";
|
||||
import { NewSaveControl } from "./NewSaveControl/NewSaveControl";
|
||||
import { UndoRedoButtons } from "./UndoRedoButtons";
|
||||
|
||||
export type Control = {
|
||||
icon: React.ReactNode;
|
||||
label: string;
|
||||
disabled?: boolean;
|
||||
onClick: () => void;
|
||||
};
|
||||
export const NewControlPanel = memo(() => {
|
||||
useNewControlPanel({});
|
||||
|
||||
export type NewControlPanelProps = {
|
||||
flowExecutionID?: GraphExecutionID | undefined;
|
||||
visualizeBeads?: "no" | "static" | "animate";
|
||||
pinSavePopover?: boolean;
|
||||
pinBlocksPopover?: boolean;
|
||||
nodes?: CustomNode[];
|
||||
onNodeSelect?: (nodeId: string) => void;
|
||||
onNodeHover?: (nodeId: string) => void;
|
||||
};
|
||||
export const NewControlPanel = memo(
|
||||
({
|
||||
flowExecutionID: _flowExecutionID,
|
||||
visualizeBeads: _visualizeBeads,
|
||||
pinSavePopover: _pinSavePopover,
|
||||
pinBlocksPopover: _pinBlocksPopover,
|
||||
nodes: _nodes,
|
||||
onNodeSelect: _onNodeSelect,
|
||||
onNodeHover: _onNodeHover,
|
||||
}: NewControlPanelProps) => {
|
||||
const _isGraphSearchEnabled = useGetFlag(Flag.GRAPH_SEARCH);
|
||||
|
||||
const {
|
||||
// agentDescription,
|
||||
// setAgentDescription,
|
||||
// saveAgent,
|
||||
// agentName,
|
||||
// setAgentName,
|
||||
// savedAgent,
|
||||
// isSaving,
|
||||
// isRunning,
|
||||
// isStopping,
|
||||
} = useNewControlPanel({});
|
||||
|
||||
return (
|
||||
<section
|
||||
className={cn(
|
||||
"absolute left-4 top-10 z-10 overflow-hidden rounded-[1rem] border-none bg-white p-0 shadow-[0_1px_5px_0_rgba(0,0,0,0.1)]",
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-col items-center justify-center rounded-[1rem] p-0">
|
||||
<BlockMenu />
|
||||
{/* <Separator className="text-[#E1E1E1]" />
|
||||
{isGraphSearchEnabled && (
|
||||
<>
|
||||
<GraphSearchMenu
|
||||
nodes={nodes}
|
||||
blockMenuSelected={blockMenuSelected}
|
||||
setBlockMenuSelected={setBlockMenuSelected}
|
||||
onNodeSelect={onNodeSelect}
|
||||
onNodeHover={onNodeHover}
|
||||
/>
|
||||
<Separator className="text-[#E1E1E1]" />
|
||||
</>
|
||||
)}
|
||||
{controls.map((control, index) => (
|
||||
<ControlPanelButton
|
||||
key={index}
|
||||
onClick={() => control.onClick()}
|
||||
data-id={`control-button-${index}`}
|
||||
data-testid={`blocks-control-${control.label.toLowerCase()}-button`}
|
||||
disabled={control.disabled || false}
|
||||
className="rounded-none"
|
||||
>
|
||||
{control.icon}
|
||||
</ControlPanelButton>
|
||||
))} */}
|
||||
<Separator className="text-[#E1E1E1]" />
|
||||
<NewSaveControl />
|
||||
<Separator className="text-[#E1E1E1]" />
|
||||
<UndoRedoButtons />
|
||||
</div>
|
||||
</section>
|
||||
);
|
||||
},
|
||||
);
|
||||
return (
|
||||
<section
|
||||
className={cn(
|
||||
"absolute left-4 top-10 z-10 overflow-hidden rounded-[1rem] border-none bg-white p-0 shadow-[0_1px_5px_0_rgba(0,0,0,0.1)]",
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-col items-center justify-center rounded-[1rem] p-0">
|
||||
<BlockMenu />
|
||||
<Separator className="text-[#E1E1E1]" />
|
||||
<NewSaveControl />
|
||||
<Separator className="text-[#E1E1E1]" />
|
||||
<UndoRedoButtons />
|
||||
</div>
|
||||
</section>
|
||||
);
|
||||
});
|
||||
|
||||
export default NewControlPanel;
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { CustomNode } from "@/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode";
|
||||
import { CustomNode } from "../../../FlowEditor/nodes/CustomNode/CustomNode";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useGraphSearch } from "../GraphMenuSearchBar/useGraphMenuSearchBar";
|
||||
import { CustomNode } from "@/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode";
|
||||
import { CustomNode } from "../../../FlowEditor/nodes/CustomNode/CustomNode";
|
||||
|
||||
interface UseGraphMenuProps {
|
||||
nodes: CustomNode[];
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import React from "react";
|
||||
import { Separator } from "@/components/__legacy__/ui/separator";
|
||||
import { ScrollArea } from "@/components/__legacy__/ui/scroll-area";
|
||||
import { beautifyString, getPrimaryCategoryColor } from "@/lib/utils";
|
||||
import { beautifyString, categoryColorMap } from "@/lib/utils";
|
||||
import { SearchableNode } from "../GraphMenuSearchBar/useGraphMenuSearchBar";
|
||||
import { TextRenderer } from "@/components/__legacy__/ui/render";
|
||||
import {
|
||||
@@ -73,14 +73,12 @@ export const GraphSearchContent: React.FC<GraphSearchContentProps> = ({
|
||||
}
|
||||
|
||||
const nodeTitle =
|
||||
node.data?.metadata?.customized_name ||
|
||||
beautifyString(node.data?.blockType || "").replace(
|
||||
/ Block$/,
|
||||
"",
|
||||
);
|
||||
const nodeType = beautifyString(
|
||||
node.data?.blockType || "",
|
||||
).replace(/ Block$/, "");
|
||||
(node.data?.metadata?.customized_name as string) ||
|
||||
beautifyString(node.data?.title || "").replace(/ Block$/, "");
|
||||
const nodeType = beautifyString(node.data?.title || "").replace(
|
||||
/ Block$/,
|
||||
"",
|
||||
);
|
||||
|
||||
return (
|
||||
<TooltipProvider key={node.id}>
|
||||
@@ -100,7 +98,13 @@ export const GraphSearchContent: React.FC<GraphSearchContentProps> = ({
|
||||
onMouseLeave={() => onNodeHover?.(null)}
|
||||
>
|
||||
<div
|
||||
className={`h-full w-3 rounded-l-[7px] ${getPrimaryCategoryColor(node.data?.categories)}`}
|
||||
className={`h-full w-3 rounded-l-[7px] ${
|
||||
(node.data?.categories?.[0]?.category &&
|
||||
categoryColorMap[
|
||||
node.data.categories[0].category
|
||||
]) ||
|
||||
"bg-gray-300 dark:bg-slate-700"
|
||||
}`}
|
||||
/>
|
||||
<div className="mx-3 flex flex-1 items-center justify-between">
|
||||
<div className="mr-2 min-w-0">
|
||||
@@ -129,9 +133,10 @@ export const GraphSearchContent: React.FC<GraphSearchContentProps> = ({
|
||||
<div className="font-semibold">
|
||||
Node Type: {nodeType}
|
||||
</div>
|
||||
{node.data?.metadata?.customized_name && (
|
||||
{!!node.data?.metadata?.customized_name && (
|
||||
<div className="text-xs text-gray-500">
|
||||
Custom Name: {node.data.metadata.customized_name}
|
||||
Custom Name:{" "}
|
||||
{String(node.data.metadata.customized_name)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useState, useMemo, useDeferredValue } from "react";
|
||||
import { CustomNode } from "@/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode";
|
||||
import { CustomNode } from "../../../FlowEditor/nodes/CustomNode/CustomNode";
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
import jaro from "jaro-winkler";
|
||||
|
||||
@@ -67,10 +67,10 @@ function calculateNodeScore(
|
||||
const nodeTitle = (node.data?.title || "").toLowerCase(); // This includes the ID
|
||||
const nodeId = (node.id || "").toLowerCase();
|
||||
const nodeDescription = (node.data?.description || "").toLowerCase();
|
||||
const blockType = (node.data?.blockType || "").toLowerCase();
|
||||
const blockType = (node.data?.title || "").toLowerCase();
|
||||
const beautifiedBlockType = beautifyString(blockType).toLowerCase();
|
||||
const customizedName = (
|
||||
node.data?.metadata?.customized_name || ""
|
||||
const customizedName = String(
|
||||
node.data?.metadata?.customized_name || "",
|
||||
).toLowerCase();
|
||||
|
||||
// Get input and output names with defensive checks
|
||||
|
||||
@@ -1,54 +1,18 @@
|
||||
import { GraphID } from "@/lib/autogpt-server-api";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { useState } from "react";
|
||||
|
||||
export interface NewControlPanelProps {
|
||||
// flowExecutionID: GraphExecutionID | undefined;
|
||||
visualizeBeads?: "no" | "static" | "animate";
|
||||
}
|
||||
|
||||
export const useNewControlPanel = ({
|
||||
// flowExecutionID,
|
||||
visualizeBeads: _visualizeBeads,
|
||||
}: NewControlPanelProps) => {
|
||||
const [blockMenuSelected, setBlockMenuSelected] = useState<
|
||||
"save" | "block" | "search" | ""
|
||||
>("");
|
||||
const query = useSearchParams();
|
||||
const _graphVersion = query.get("flowVersion");
|
||||
const _graphVersionParsed = _graphVersion
|
||||
? parseInt(_graphVersion)
|
||||
: undefined;
|
||||
|
||||
const _flowID = (query.get("flowID") as GraphID | null) ?? undefined;
|
||||
// const {
|
||||
// agentDescription,
|
||||
// setAgentDescription,
|
||||
// saveAgent,
|
||||
// agentName,
|
||||
// setAgentName,
|
||||
// savedAgent,
|
||||
// isSaving,
|
||||
// isRunning,
|
||||
// isStopping,
|
||||
// } = useAgentGraph(
|
||||
// flowID,
|
||||
// graphVersion,
|
||||
// flowExecutionID,
|
||||
// visualizeBeads !== "no",
|
||||
// );
|
||||
|
||||
return {
|
||||
blockMenuSelected,
|
||||
setBlockMenuSelected,
|
||||
// agentDescription,
|
||||
// setAgentDescription,
|
||||
// saveAgent,
|
||||
// agentName,
|
||||
// setAgentName,
|
||||
// savedAgent,
|
||||
// isSaving,
|
||||
// isRunning,
|
||||
// isStopping,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,443 +0,0 @@
|
||||
import React, { useCallback, useMemo, useState, useDeferredValue } from "react";
|
||||
import { Card, CardContent, CardHeader } from "@/components/__legacy__/ui/card";
|
||||
import { Label } from "@/components/__legacy__/ui/label";
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import { Input } from "@/components/__legacy__/ui/input";
|
||||
import { TextRenderer } from "@/components/__legacy__/ui/render";
|
||||
import { ScrollArea } from "@/components/__legacy__/ui/scroll-area";
|
||||
import { CustomNode } from "@/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode";
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/__legacy__/ui/popover";
|
||||
import {
|
||||
Block,
|
||||
BlockIORootSchema,
|
||||
BlockUIType,
|
||||
GraphInputSchema,
|
||||
GraphOutputSchema,
|
||||
SpecialBlockID,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
import { MagnifyingGlassIcon, PlusIcon } from "@radix-ui/react-icons";
|
||||
import { IconToyBrick } from "@/components/__legacy__/ui/icons";
|
||||
import { getPrimaryCategoryColor } from "@/lib/utils";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { GraphMeta } from "@/lib/autogpt-server-api";
|
||||
import jaro from "jaro-winkler";
|
||||
import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
|
||||
type _Block = Omit<Block, "inputSchema" | "outputSchema"> & {
|
||||
uiKey?: string;
|
||||
inputSchema: BlockIORootSchema | GraphInputSchema;
|
||||
outputSchema: BlockIORootSchema | GraphOutputSchema;
|
||||
hardcodedValues?: Record<string, any>;
|
||||
_cached?: {
|
||||
blockName: string;
|
||||
beautifiedName: string;
|
||||
description: string;
|
||||
};
|
||||
};
|
||||
|
||||
// Hook to preprocess blocks with cached expensive operations
|
||||
const useSearchableBlocks = (blocks: _Block[]): _Block[] => {
|
||||
return useMemo(
|
||||
() =>
|
||||
blocks.map((block) => {
|
||||
if (!block._cached) {
|
||||
block._cached = {
|
||||
blockName: block.name.toLowerCase(),
|
||||
beautifiedName: beautifyString(block.name).toLowerCase(),
|
||||
description: block.description.toLowerCase(),
|
||||
};
|
||||
}
|
||||
return block;
|
||||
}),
|
||||
[blocks],
|
||||
);
|
||||
};
|
||||
|
||||
interface BlocksControlProps {
|
||||
blocks: _Block[];
|
||||
addBlock: (
|
||||
id: string,
|
||||
name: string,
|
||||
hardcodedValues: Record<string, any>,
|
||||
) => void;
|
||||
pinBlocksPopover: boolean;
|
||||
flows: GraphMeta[];
|
||||
nodes: CustomNode[];
|
||||
}
|
||||
|
||||
/**
|
||||
* A React functional component that displays a control for managing blocks.
|
||||
*
|
||||
* @component
|
||||
* @param {Object} BlocksControlProps - The properties for the BlocksControl component.
|
||||
* @param {Block[]} BlocksControlProps.blocks - An array of blocks to be displayed and filtered.
|
||||
* @param {(id: string, name: string) => void} BlocksControlProps.addBlock - A function to call when a block is added.
|
||||
* @returns The rendered BlocksControl component.
|
||||
*/
|
||||
export function BlocksControl({
|
||||
blocks: _blocks,
|
||||
addBlock,
|
||||
pinBlocksPopover,
|
||||
flows,
|
||||
nodes,
|
||||
}: BlocksControlProps) {
|
||||
const [searchQuery, setSearchQuery] = useState("");
|
||||
const deferredSearchQuery = useDeferredValue(searchQuery);
|
||||
const [selectedCategory, setSelectedCategory] = useState<string | null>(null);
|
||||
|
||||
const blocks = useSearchableBlocks(_blocks);
|
||||
|
||||
const graphHasWebhookNodes = nodes.some((n) =>
|
||||
[BlockUIType.WEBHOOK, BlockUIType.WEBHOOK_MANUAL].includes(n.data.uiType),
|
||||
);
|
||||
const graphHasInputNodes = nodes.some(
|
||||
(n) => n.data.uiType == BlockUIType.INPUT,
|
||||
);
|
||||
|
||||
const filteredAvailableBlocks = useMemo(() => {
|
||||
const blockList = blocks
|
||||
.filter((b) => b.uiType !== BlockUIType.AGENT)
|
||||
.sort((a, b) => a.name.localeCompare(b.name));
|
||||
|
||||
// Agent blocks are created from GraphMeta which doesn't include schemas.
|
||||
// Schemas will be fetched on-demand when the block is actually added.
|
||||
const agentBlockList = flows
|
||||
.map((flow): _Block => {
|
||||
return {
|
||||
id: SpecialBlockID.AGENT,
|
||||
name: flow.name,
|
||||
description:
|
||||
`Ver.${flow.version}` +
|
||||
(flow.description ? ` | ${flow.description}` : ""),
|
||||
categories: [{ category: "AGENT", description: "" }],
|
||||
// Empty schemas - will be populated when block is added
|
||||
inputSchema: { type: "object", properties: {} },
|
||||
outputSchema: { type: "object", properties: {} },
|
||||
staticOutput: false,
|
||||
uiType: BlockUIType.AGENT,
|
||||
costs: [],
|
||||
uiKey: flow.id,
|
||||
hardcodedValues: {
|
||||
graph_id: flow.id,
|
||||
graph_version: flow.version,
|
||||
// Schemas will be fetched on-demand when block is added
|
||||
},
|
||||
};
|
||||
})
|
||||
.map(
|
||||
(agentBlock): _Block => ({
|
||||
...agentBlock,
|
||||
_cached: {
|
||||
blockName: agentBlock.name.toLowerCase(),
|
||||
beautifiedName: beautifyString(agentBlock.name).toLowerCase(),
|
||||
description: agentBlock.description.toLowerCase(),
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
return blockList
|
||||
.concat(agentBlockList)
|
||||
.map((block) => ({
|
||||
block,
|
||||
score: blockScoreForQuery(block, deferredSearchQuery),
|
||||
}))
|
||||
.filter(
|
||||
({ block, score }) =>
|
||||
score > 0 &&
|
||||
(!selectedCategory ||
|
||||
block.categories.some((cat) => cat.category === selectedCategory)),
|
||||
)
|
||||
.sort((a, b) => b.score - a.score)
|
||||
.map(({ block }) => ({
|
||||
...block,
|
||||
notAvailable:
|
||||
(block.uiType == BlockUIType.WEBHOOK &&
|
||||
graphHasWebhookNodes &&
|
||||
"Agents can only have one webhook-triggered block") ||
|
||||
(block.uiType == BlockUIType.WEBHOOK &&
|
||||
graphHasInputNodes &&
|
||||
"Webhook-triggered blocks can't be used together with input blocks") ||
|
||||
(block.uiType == BlockUIType.INPUT &&
|
||||
graphHasWebhookNodes &&
|
||||
"Input blocks can't be used together with a webhook-triggered block") ||
|
||||
null,
|
||||
}));
|
||||
}, [
|
||||
blocks,
|
||||
flows,
|
||||
selectedCategory,
|
||||
deferredSearchQuery,
|
||||
graphHasInputNodes,
|
||||
graphHasWebhookNodes,
|
||||
]);
|
||||
|
||||
const resetFilters = useCallback(() => {
|
||||
setSearchQuery("");
|
||||
setSelectedCategory(null);
|
||||
}, []);
|
||||
|
||||
// Handler to add a block, fetching graph data on-demand for agent blocks
|
||||
const handleAddBlock = useCallback(
|
||||
async (block: _Block & { notAvailable: string | null }) => {
|
||||
if (block.notAvailable) return;
|
||||
|
||||
// For agent blocks, fetch the full graph to get schemas
|
||||
if (block.uiType === BlockUIType.AGENT && block.hardcodedValues) {
|
||||
const graphID = block.hardcodedValues.graph_id as string;
|
||||
const graphVersion = block.hardcodedValues.graph_version as number;
|
||||
const graphData = okData(
|
||||
await getV1GetSpecificGraph(graphID, { version: graphVersion }),
|
||||
);
|
||||
|
||||
if (graphData) {
|
||||
addBlock(block.id, block.name, {
|
||||
...block.hardcodedValues,
|
||||
input_schema: graphData.input_schema,
|
||||
output_schema: graphData.output_schema,
|
||||
});
|
||||
} else {
|
||||
// Fallback: add without schemas (will be incomplete)
|
||||
console.error("Failed to fetch graph data for agent block");
|
||||
addBlock(block.id, block.name, block.hardcodedValues || {});
|
||||
}
|
||||
} else {
|
||||
addBlock(block.id, block.name, block.hardcodedValues || {});
|
||||
}
|
||||
},
|
||||
[addBlock],
|
||||
);
|
||||
|
||||
// Extract unique categories from blocks
|
||||
const categories = useMemo(() => {
|
||||
return Array.from(
|
||||
new Set([
|
||||
null,
|
||||
...blocks
|
||||
.flatMap((block) => block.categories.map((cat) => cat.category))
|
||||
.sort(),
|
||||
]),
|
||||
);
|
||||
}, [blocks]);
|
||||
|
||||
return (
|
||||
<Popover
|
||||
open={pinBlocksPopover ? true : undefined}
|
||||
onOpenChange={(open) => open || resetFilters()}
|
||||
>
|
||||
<Tooltip delayDuration={500}>
|
||||
<TooltipTrigger asChild>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
data-id="blocks-control-popover-trigger"
|
||||
data-testid="blocks-control-blocks-button"
|
||||
name="Blocks"
|
||||
className="dark:hover:bg-slate-800"
|
||||
>
|
||||
<IconToyBrick />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="right">Blocks</TooltipContent>
|
||||
</Tooltip>
|
||||
<PopoverContent
|
||||
side="right"
|
||||
sideOffset={22}
|
||||
align="start"
|
||||
className="absolute -top-3 w-[17rem] rounded-xl border-none p-0 shadow-none md:w-[30rem]"
|
||||
data-id="blocks-control-popover-content"
|
||||
>
|
||||
<Card className="p-3 pb-0 dark:bg-slate-900">
|
||||
<CardHeader className="flex flex-col gap-x-8 gap-y-1 p-3 px-2">
|
||||
<div className="items-center justify-between">
|
||||
<Label
|
||||
htmlFor="search-blocks"
|
||||
className="whitespace-nowrap text-base font-bold text-black dark:text-white 2xl:text-xl"
|
||||
data-id="blocks-control-label"
|
||||
data-testid="blocks-control-blocks-label"
|
||||
>
|
||||
Blocks
|
||||
</Label>
|
||||
</div>
|
||||
<div className="relative flex items-center">
|
||||
<MagnifyingGlassIcon className="absolute m-2 h-5 w-5 text-gray-500 dark:text-gray-400" />
|
||||
<Input
|
||||
id="search-blocks"
|
||||
type="text"
|
||||
placeholder="Search blocks"
|
||||
value={searchQuery}
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
className="rounded-lg px-8 py-5 dark:bg-slate-800 dark:text-white"
|
||||
data-id="blocks-control-search-input"
|
||||
autoComplete="off"
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
className="mt-2 flex flex-wrap gap-2"
|
||||
data-testid="blocks-categories-list"
|
||||
>
|
||||
{categories.map((category) => {
|
||||
const color = getPrimaryCategoryColor([
|
||||
{ category: category || "All", description: "" },
|
||||
]);
|
||||
const colorClass =
|
||||
selectedCategory === category ? `${color}` : "";
|
||||
return (
|
||||
<div
|
||||
key={category}
|
||||
data-testid="blocks-category"
|
||||
role="button"
|
||||
className={`cursor-pointer rounded-xl border px-2 py-2 text-xs font-medium dark:border-slate-700 dark:text-white ${colorClass}`}
|
||||
onClick={() =>
|
||||
setSelectedCategory(
|
||||
selectedCategory === category ? null : category,
|
||||
)
|
||||
}
|
||||
>
|
||||
{beautifyString((category || "All").toLowerCase())}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</CardHeader>
|
||||
<CardContent className="overflow-scroll border-t border-t-gray-200 p-0 dark:border-t-slate-700">
|
||||
<ScrollArea
|
||||
className="h-[60vh] w-full"
|
||||
data-id="blocks-control-scroll-area"
|
||||
>
|
||||
{filteredAvailableBlocks.map((block) => (
|
||||
<Card
|
||||
key={block.uiKey || block.id}
|
||||
className={`m-2 my-4 flex h-20 shadow-none dark:border-slate-700 dark:bg-slate-800 dark:text-slate-100 dark:hover:bg-slate-700 ${
|
||||
block.notAvailable
|
||||
? "cursor-not-allowed opacity-50"
|
||||
: "cursor-move hover:shadow-lg"
|
||||
}`}
|
||||
data-id={`block-card-${block.id}`}
|
||||
draggable={!block.notAvailable}
|
||||
onDragStart={(e) => {
|
||||
if (block.notAvailable) return;
|
||||
e.dataTransfer.effectAllowed = "copy";
|
||||
e.dataTransfer.setData(
|
||||
"application/reactflow",
|
||||
JSON.stringify({
|
||||
blockId: block.id,
|
||||
blockName: block.name,
|
||||
hardcodedValues: block?.hardcodedValues || {},
|
||||
}),
|
||||
);
|
||||
}}
|
||||
onClick={() => handleAddBlock(block)}
|
||||
title={block.notAvailable ?? undefined}
|
||||
>
|
||||
<div
|
||||
className={`-ml-px h-full w-3 rounded-l-xl ${getPrimaryCategoryColor(block.categories)}`}
|
||||
></div>
|
||||
|
||||
<div className="mx-3 flex flex-1 items-center justify-between">
|
||||
<div className="mr-2 min-w-0">
|
||||
<span
|
||||
className="block truncate pb-1 text-sm font-semibold dark:text-white"
|
||||
data-id={`block-name-${block.id}`}
|
||||
data-type={block.uiType}
|
||||
data-testid={`block-name-${block.id}`}
|
||||
>
|
||||
<TextRenderer
|
||||
value={beautifyString(block.name).replace(
|
||||
/ Block$/,
|
||||
"",
|
||||
)}
|
||||
truncateLengthLimit={45}
|
||||
/>
|
||||
</span>
|
||||
<span
|
||||
className="block break-all text-xs font-normal text-gray-500 dark:text-gray-400"
|
||||
data-testid={`block-description-${block.id}`}
|
||||
>
|
||||
<TextRenderer
|
||||
value={block.description}
|
||||
truncateLengthLimit={165}
|
||||
/>
|
||||
</span>
|
||||
</div>
|
||||
<div
|
||||
className="flex flex-shrink-0 items-center gap-1"
|
||||
data-id={`block-tooltip-${block.id}`}
|
||||
data-testid={`block-add`}
|
||||
>
|
||||
<PlusIcon className="h-6 w-6 rounded-lg bg-gray-200 stroke-black stroke-[0.5px] p-1 dark:bg-gray-700 dark:stroke-white" />
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
))}
|
||||
</ScrollArea>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluates how well a block matches the search query and returns a relevance score.
|
||||
* The scoring algorithm works as follows:
|
||||
* - Returns 1 if no query (all blocks match equally)
|
||||
* - Normalized query for case-insensitive matching
|
||||
* - Returns 3 for exact substring matches in block name (highest priority)
|
||||
* - Returns 2 when all query words appear in the block name (regardless of order)
|
||||
* - Returns 1.X for blocks with names similar to query using Jaro-Winkler distance (X is similarity score)
|
||||
* - Returns 0.5 when all query words appear in the block description (lowest priority)
|
||||
* - Returns 0 for no match
|
||||
*
|
||||
* Higher scores will appear first in search results.
|
||||
*/
|
||||
function blockScoreForQuery(block: _Block, query: string): number {
|
||||
if (!query) return 1;
|
||||
const normalizedQuery = query.toLowerCase().trim();
|
||||
const queryWords = normalizedQuery.split(/\s+/);
|
||||
|
||||
// Use cached values for performance
|
||||
const { blockName, beautifiedName, description } = block._cached!;
|
||||
|
||||
// 1. Exact match in name (highest priority)
|
||||
if (
|
||||
blockName.includes(normalizedQuery) ||
|
||||
beautifiedName.includes(normalizedQuery)
|
||||
) {
|
||||
return 3;
|
||||
}
|
||||
|
||||
// 2. All query words in name (regardless of order)
|
||||
const allWordsInName = queryWords.every(
|
||||
(word) => blockName.includes(word) || beautifiedName.includes(word),
|
||||
);
|
||||
if (allWordsInName) return 2;
|
||||
|
||||
// 3. Similarity with name (Jaro-Winkler)
|
||||
const similarityThreshold = 0.65;
|
||||
const nameSimilarity = jaro(blockName, normalizedQuery);
|
||||
const beautifiedSimilarity = jaro(beautifiedName, normalizedQuery);
|
||||
const maxSimilarity = Math.max(nameSimilarity, beautifiedSimilarity);
|
||||
if (maxSimilarity > similarityThreshold) {
|
||||
return 1 + maxSimilarity; // Score between 1 and 2
|
||||
}
|
||||
|
||||
// 4. All query words in description (lower priority)
|
||||
const allWordsInDescription = queryWords.every((word) =>
|
||||
description.includes(word),
|
||||
);
|
||||
if (allWordsInDescription) return 0.5;
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
import React from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import { LogOut } from "lucide-react";
|
||||
import { ClockIcon, WarningIcon } from "@phosphor-icons/react";
|
||||
import { IconPlay, IconSquare } from "@/components/__legacy__/ui/icons";
|
||||
|
||||
interface Props {
|
||||
onClickAgentOutputs?: () => void;
|
||||
onClickRunAgent?: () => void;
|
||||
onClickStopRun: () => void;
|
||||
onClickScheduleButton?: () => void;
|
||||
isRunning: boolean;
|
||||
isDisabled: boolean;
|
||||
className?: string;
|
||||
resolutionModeActive?: boolean;
|
||||
}
|
||||
|
||||
export const BuildActionBar: React.FC<Props> = ({
|
||||
onClickAgentOutputs,
|
||||
onClickRunAgent,
|
||||
onClickStopRun,
|
||||
onClickScheduleButton,
|
||||
isRunning,
|
||||
isDisabled,
|
||||
className,
|
||||
resolutionModeActive = false,
|
||||
}) => {
|
||||
const buttonClasses =
|
||||
"flex items-center gap-2 text-sm font-medium md:text-lg";
|
||||
|
||||
// Show resolution mode message instead of action buttons
|
||||
if (resolutionModeActive) {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex w-fit select-none items-center justify-center p-4",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-3 rounded-lg border border-amber-300 bg-amber-50 px-4 py-3 dark:border-amber-700 dark:bg-amber-900/30">
|
||||
<WarningIcon className="size-5 text-amber-600 dark:text-amber-400" />
|
||||
<span className="text-sm font-medium text-amber-800 dark:text-amber-200">
|
||||
Remove incompatible connections to continue
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex w-fit select-none items-center justify-center p-4",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="flex gap-1 md:gap-4">
|
||||
{onClickAgentOutputs && (
|
||||
<Button
|
||||
className={buttonClasses}
|
||||
variant="outline"
|
||||
size="primary"
|
||||
onClick={onClickAgentOutputs}
|
||||
title="View agent outputs"
|
||||
>
|
||||
<LogOut className="hidden size-5 md:flex" /> Agent Outputs
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{!isRunning ? (
|
||||
<Button
|
||||
className={cn(
|
||||
buttonClasses,
|
||||
onClickRunAgent && isDisabled
|
||||
? "cursor-default opacity-50 hover:bg-accent"
|
||||
: "",
|
||||
)}
|
||||
variant="accent"
|
||||
size="primary"
|
||||
onClick={onClickRunAgent}
|
||||
disabled={!onClickRunAgent}
|
||||
title="Run the agent"
|
||||
aria-label="Run the agent"
|
||||
data-testid="primary-action-run-agent"
|
||||
data-tutorial-id="primary-action-run-agent"
|
||||
>
|
||||
<IconPlay /> Run
|
||||
</Button>
|
||||
) : (
|
||||
<Button
|
||||
className={buttonClasses}
|
||||
variant="destructive"
|
||||
size="primary"
|
||||
onClick={onClickStopRun}
|
||||
title="Stop the agent"
|
||||
data-id="primary-action-stop-agent"
|
||||
>
|
||||
<IconSquare /> Stop
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{onClickScheduleButton && (
|
||||
<Button
|
||||
className={buttonClasses}
|
||||
variant="outline"
|
||||
size="primary"
|
||||
onClick={onClickScheduleButton}
|
||||
title="Set up a run schedule for the agent"
|
||||
data-id="primary-action-schedule-agent"
|
||||
>
|
||||
<ClockIcon className="hidden h-5 w-5 md:flex" />
|
||||
Schedule Run
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -1,33 +0,0 @@
|
||||
import {
|
||||
BaseEdge,
|
||||
ConnectionLineComponentProps,
|
||||
Node,
|
||||
getBezierPath,
|
||||
Position,
|
||||
} from "@xyflow/react";
|
||||
|
||||
export default function ConnectionLine<NodeType extends Node>({
|
||||
fromPosition,
|
||||
fromHandle,
|
||||
fromX,
|
||||
fromY,
|
||||
toPosition,
|
||||
toX,
|
||||
toY,
|
||||
}: ConnectionLineComponentProps<NodeType>) {
|
||||
const sourceX =
|
||||
fromPosition === Position.Right
|
||||
? fromX + ((fromHandle?.width ?? 0) / 2 - 5)
|
||||
: fromX - ((fromHandle?.width ?? 0) / 2 - 5);
|
||||
|
||||
const [path] = getBezierPath({
|
||||
sourceX: sourceX,
|
||||
sourceY: fromY,
|
||||
sourcePosition: fromPosition,
|
||||
targetX: toX,
|
||||
targetY: toY,
|
||||
targetPosition: toPosition,
|
||||
});
|
||||
|
||||
return <BaseEdge path={path} style={{ strokeWidth: 2, stroke: "#555" }} />;
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
import { Card, CardContent } from "@/components/__legacy__/ui/card";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import { Separator } from "@/components/__legacy__/ui/separator";
|
||||
import { cn } from "@/lib/utils";
|
||||
import React from "react";
|
||||
|
||||
/**
|
||||
* Represents a control element for the ControlPanel Component.
|
||||
* @type {Object} Control
|
||||
* @property {React.ReactNode} icon - The icon of the control from lucide-react https://lucide.dev/icons/
|
||||
* @property {string} label - The label of the control, to be leveraged by ToolTip.
|
||||
* @property {onclick} onClick - The function to be executed when the control is clicked.
|
||||
*/
|
||||
export type Control = {
|
||||
icon: React.ReactNode;
|
||||
label: string;
|
||||
disabled?: boolean;
|
||||
onClick: () => void;
|
||||
};
|
||||
|
||||
interface ControlPanelProps {
|
||||
controls: Control[];
|
||||
topChildren?: React.ReactNode;
|
||||
botChildren?: React.ReactNode;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* ControlPanel component displays a panel with controls as icons.tsx with the ability to take in children.
|
||||
* @param {Object} ControlPanelProps - The properties of the control panel component.
|
||||
* @param {Array} ControlPanelProps.controls - An array of control objects representing actions to be preformed.
|
||||
* @param {Array} ControlPanelProps.children - The child components of the control panel.
|
||||
* @param {string} ControlPanelProps.className - Additional CSS class names for the control panel.
|
||||
* @returns The rendered control panel component.
|
||||
*/
|
||||
export const ControlPanel = ({
|
||||
controls,
|
||||
topChildren,
|
||||
botChildren,
|
||||
className,
|
||||
}: ControlPanelProps) => {
|
||||
return (
|
||||
<Card className={cn("m-4 mt-24 w-14 dark:bg-slate-900", className)}>
|
||||
<CardContent className="p-0">
|
||||
<div className="flex flex-col items-center gap-3 rounded-xl py-3">
|
||||
{topChildren}
|
||||
<Separator className="dark:bg-slate-700" />
|
||||
{controls.map((control, index) => (
|
||||
<Tooltip key={index} delayDuration={500}>
|
||||
<TooltipTrigger asChild>
|
||||
<div>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={() => control.onClick()}
|
||||
data-id={`control-button-${index}`}
|
||||
data-testid={`blocks-control-${control.label.toLowerCase()}-button`}
|
||||
disabled={control.disabled || false}
|
||||
className="dark:bg-slate-900 dark:text-slate-100 dark:hover:bg-slate-800"
|
||||
>
|
||||
{control.icon}
|
||||
<span className="sr-only">{control.label}</span>
|
||||
</Button>
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent
|
||||
side="right"
|
||||
className="dark:bg-slate-800 dark:text-slate-100"
|
||||
>
|
||||
{control.label}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
))}
|
||||
<Separator className="dark:bg-slate-700" />
|
||||
{botChildren}
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
};
|
||||
export default ControlPanel;
|
||||
@@ -1,240 +0,0 @@
|
||||
import React, {
|
||||
useCallback,
|
||||
useContext,
|
||||
useEffect,
|
||||
useState,
|
||||
useRef,
|
||||
} from "react";
|
||||
import {
|
||||
BaseEdge,
|
||||
EdgeLabelRenderer,
|
||||
EdgeProps,
|
||||
useReactFlow,
|
||||
XYPosition,
|
||||
Edge,
|
||||
Node,
|
||||
} from "@xyflow/react";
|
||||
import "./customedge.css";
|
||||
import { X } from "lucide-react";
|
||||
import { BuilderContext } from "../Flow/Flow";
|
||||
import { NodeExecutionResult } from "@/lib/autogpt-server-api";
|
||||
import { useCustomEdge } from "./useCustomEdge";
|
||||
|
||||
export type CustomEdgeData = {
|
||||
edgeColor: string;
|
||||
sourcePos?: XYPosition;
|
||||
isStatic?: boolean;
|
||||
beadUp: number;
|
||||
beadDown: number;
|
||||
beadData?: Map<string, NodeExecutionResult["status"]>;
|
||||
};
|
||||
|
||||
type Bead = {
|
||||
t: number;
|
||||
targetT: number;
|
||||
startTime: number;
|
||||
};
|
||||
|
||||
export type CustomEdge = Edge<CustomEdgeData, "custom">;
|
||||
|
||||
export function CustomEdge({
|
||||
id,
|
||||
data,
|
||||
selected,
|
||||
sourceX,
|
||||
sourceY,
|
||||
targetX,
|
||||
targetY,
|
||||
markerEnd,
|
||||
}: EdgeProps<CustomEdge>) {
|
||||
const [beads, setBeads] = useState<{
|
||||
beads: Bead[];
|
||||
created: number;
|
||||
destroyed: number;
|
||||
}>({ beads: [], created: 0, destroyed: 0 });
|
||||
const beadsRef = useRef(beads);
|
||||
const { svgPath, length, getPointForT, getTForDistance } = useCustomEdge(
|
||||
sourceX - 5,
|
||||
sourceY - 5,
|
||||
targetX + 3,
|
||||
targetY - 5,
|
||||
);
|
||||
const { deleteElements } = useReactFlow<Node, CustomEdge>();
|
||||
const builderContext = useContext(BuilderContext);
|
||||
const { visualizeBeads } = builderContext ?? {
|
||||
visualizeBeads: "no",
|
||||
};
|
||||
|
||||
// Check if this edge is broken (during resolution mode)
|
||||
const isBroken =
|
||||
builderContext?.resolutionMode?.active &&
|
||||
builderContext?.resolutionMode?.brokenEdgeIds?.includes(id);
|
||||
|
||||
const onEdgeRemoveClick = () => {
|
||||
deleteElements({ edges: [{ id }] });
|
||||
};
|
||||
|
||||
const animationDuration = 500; // Duration in milliseconds for bead to travel the curve
|
||||
const beadDiameter = 12;
|
||||
const deltaTime = 16;
|
||||
|
||||
const setTargetPositions = useCallback(
|
||||
(beads: Bead[]) => {
|
||||
const distanceBetween = Math.min(
|
||||
(length - beadDiameter) / (beads.length + 1),
|
||||
beadDiameter,
|
||||
);
|
||||
|
||||
return beads.map((bead, index) => {
|
||||
const distanceFromEnd = beadDiameter * 1.35;
|
||||
const targetPosition = distanceBetween * index + distanceFromEnd;
|
||||
const t = getTForDistance(-targetPosition);
|
||||
|
||||
return {
|
||||
...bead,
|
||||
t: visualizeBeads === "animate" ? bead.t : t,
|
||||
targetT: t,
|
||||
} as Bead;
|
||||
});
|
||||
},
|
||||
[getTForDistance, length, visualizeBeads],
|
||||
);
|
||||
|
||||
beadsRef.current = beads;
|
||||
useEffect(() => {
|
||||
const beadUp: number = data?.beadUp ?? 0;
|
||||
const beadDown: number = data?.beadDown ?? 0;
|
||||
|
||||
if (
|
||||
beadUp === 0 &&
|
||||
beadDown === 0 &&
|
||||
(beads.created > 0 || beads.destroyed > 0)
|
||||
) {
|
||||
setBeads({ beads: [], created: 0, destroyed: 0 });
|
||||
return;
|
||||
}
|
||||
|
||||
// Add beads
|
||||
if (beadUp > beads.created) {
|
||||
setBeads(({ beads, created, destroyed }) => {
|
||||
const newBeads = [];
|
||||
for (let i = 0; i < beadUp - created; i++) {
|
||||
newBeads.push({ t: 0, targetT: 0, startTime: Date.now() });
|
||||
}
|
||||
|
||||
const b = setTargetPositions([...beads, ...newBeads]);
|
||||
return { beads: b, created: beadUp, destroyed };
|
||||
});
|
||||
}
|
||||
|
||||
// Animate and remove beads
|
||||
const interval = setInterval(
|
||||
({ current: beads }) => {
|
||||
// If there are no beads visible or moving, stop re-rendering
|
||||
if (
|
||||
(beadUp === beads.created && beads.created === beads.destroyed) ||
|
||||
beads.beads.every((bead) => bead.t >= bead.targetT)
|
||||
) {
|
||||
clearInterval(interval);
|
||||
return;
|
||||
}
|
||||
|
||||
setBeads(({ beads, created, destroyed }) => {
|
||||
let destroyedCount = 0;
|
||||
|
||||
const newBeads = beads
|
||||
.map((bead) => {
|
||||
const progressIncrement = deltaTime / animationDuration;
|
||||
const t = Math.min(
|
||||
bead.t + bead.targetT * progressIncrement,
|
||||
bead.targetT,
|
||||
);
|
||||
|
||||
return { ...bead, t };
|
||||
})
|
||||
.filter((bead, index) => {
|
||||
const removeCount = beadDown - destroyed;
|
||||
if (bead.t >= bead.targetT && index < removeCount) {
|
||||
destroyedCount++;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
return {
|
||||
beads: setTargetPositions(newBeads),
|
||||
created,
|
||||
destroyed: destroyed + destroyedCount,
|
||||
};
|
||||
});
|
||||
},
|
||||
deltaTime,
|
||||
beadsRef,
|
||||
);
|
||||
|
||||
return () => clearInterval(interval);
|
||||
}, [data?.beadUp, data?.beadDown, setTargetPositions, visualizeBeads]);
|
||||
|
||||
const middle = getPointForT(0.5);
|
||||
|
||||
// Determine edge color - red for broken edges
|
||||
const baseColor = data?.edgeColor ?? "#555555";
|
||||
const edgeColor = isBroken ? "#ef4444" : baseColor;
|
||||
// Add opacity to hex color (99 = 60% opacity, 80 = 50% opacity)
|
||||
const strokeColor = isBroken
|
||||
? `${edgeColor}99`
|
||||
: selected
|
||||
? edgeColor
|
||||
: `${edgeColor}80`;
|
||||
|
||||
return (
|
||||
<>
|
||||
<BaseEdge
|
||||
path={svgPath}
|
||||
markerEnd={markerEnd}
|
||||
style={{
|
||||
stroke: strokeColor,
|
||||
strokeWidth: data?.isStatic ? 2.5 : 2,
|
||||
strokeDasharray: data?.isStatic ? "5 3" : undefined,
|
||||
}}
|
||||
className="data-sentry-unmask transition-all duration-200"
|
||||
/>
|
||||
<path
|
||||
d={svgPath}
|
||||
fill="none"
|
||||
strokeOpacity={0}
|
||||
strokeWidth={20}
|
||||
className="data-sentry-unmask react-flow__edge-interaction"
|
||||
/>
|
||||
<EdgeLabelRenderer>
|
||||
<div
|
||||
style={{
|
||||
position: "absolute",
|
||||
transform: `translate(-50%, -50%) translate(${middle.x}px,${middle.y}px)`,
|
||||
pointerEvents: "all",
|
||||
}}
|
||||
className="edge-label-renderer"
|
||||
>
|
||||
<button
|
||||
className="edge-label-button opacity-0 transition-opacity duration-200 hover:opacity-100"
|
||||
onClick={onEdgeRemoveClick}
|
||||
>
|
||||
<X className="size-4" />
|
||||
</button>
|
||||
</div>
|
||||
</EdgeLabelRenderer>
|
||||
{beads.beads.map((bead, index) => {
|
||||
const pos = getPointForT(bead.t);
|
||||
return (
|
||||
<circle
|
||||
key={index}
|
||||
cx={pos.x}
|
||||
cy={pos.y}
|
||||
r={beadDiameter / 2} // Bead radius
|
||||
fill={data?.edgeColor ?? "#555555"}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
.edge-label-renderer {
|
||||
position: absolute;
|
||||
pointer-events: all;
|
||||
}
|
||||
|
||||
.edge-label-button {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
background: #eee;
|
||||
border: 1px solid #fff;
|
||||
cursor: pointer;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
padding: 0;
|
||||
color: #555;
|
||||
opacity: 0;
|
||||
transition:
|
||||
opacity 0.2s ease-in-out,
|
||||
background-color 0.2s ease-in-out;
|
||||
}
|
||||
|
||||
.edge-label-button.visible {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.edge-label-button:hover {
|
||||
box-shadow: 0 0 6px 2px rgba(0, 0, 0, 0.08);
|
||||
background: #f0f0f0;
|
||||
}
|
||||
|
||||
.edge-label-button svg {
|
||||
width: 14px;
|
||||
height: 14px;
|
||||
}
|
||||
|
||||
.react-flow__edge-interaction {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.react-flow__edges > svg:has(> g.selected) {
|
||||
z-index: 10 !important;
|
||||
}
|
||||
|
||||
.react-flow__edgelabel-renderer {
|
||||
z-index: 11 !important;
|
||||
}
|
||||
@@ -1,157 +0,0 @@
|
||||
import { useCallback, useMemo } from "react";
|
||||
|
||||
type XYPosition = {
|
||||
x: number;
|
||||
y: number;
|
||||
};
|
||||
|
||||
export type BezierPath = {
|
||||
sourcePosition: XYPosition;
|
||||
control1: XYPosition;
|
||||
control2: XYPosition;
|
||||
targetPosition: XYPosition;
|
||||
};
|
||||
|
||||
export function useCustomEdge(
|
||||
sourceX: number,
|
||||
sourceY: number,
|
||||
targetX: number,
|
||||
targetY: number,
|
||||
) {
|
||||
const path: BezierPath = useMemo(() => {
|
||||
const xDifference = Math.abs(sourceX - targetX);
|
||||
const yDifference = Math.abs(sourceY - targetY);
|
||||
const xControlDistance =
|
||||
sourceX < targetX ? 64 : Math.max(xDifference / 2, 64);
|
||||
const yControlDistance = yDifference < 128 && sourceX > targetX ? -64 : 0;
|
||||
|
||||
return {
|
||||
sourcePosition: { x: sourceX, y: sourceY },
|
||||
control1: {
|
||||
x: sourceX + xControlDistance,
|
||||
y: sourceY + yControlDistance,
|
||||
},
|
||||
control2: {
|
||||
x: targetX - xControlDistance,
|
||||
y: targetY + yControlDistance,
|
||||
},
|
||||
targetPosition: { x: targetX, y: targetY },
|
||||
};
|
||||
}, [sourceX, sourceY, targetX, targetY]);
|
||||
|
||||
const svgPath = useMemo(
|
||||
() =>
|
||||
`M ${path.sourcePosition.x} ${path.sourcePosition.y} ` +
|
||||
`C ${path.control1.x} ${path.control1.y} ${path.control2.x} ${path.control2.y} ` +
|
||||
`${path.targetPosition.x}, ${path.targetPosition.y}`,
|
||||
[path],
|
||||
);
|
||||
|
||||
const getPointForT = useCallback(
|
||||
(t: number) => {
|
||||
// Bezier formula: (1-t)^3 * p0 + 3*(1-t)^2*t*p1 + 3*(1-t)*t^2*p2 + t^3*p3
|
||||
const x =
|
||||
Math.pow(1 - t, 3) * path.sourcePosition.x +
|
||||
3 * Math.pow(1 - t, 2) * t * path.control1.x +
|
||||
3 * (1 - t) * Math.pow(t, 2) * path.control2.x +
|
||||
Math.pow(t, 3) * path.targetPosition.x;
|
||||
|
||||
const y =
|
||||
Math.pow(1 - t, 3) * path.sourcePosition.y +
|
||||
3 * Math.pow(1 - t, 2) * t * path.control1.y +
|
||||
3 * (1 - t) * Math.pow(t, 2) * path.control2.y +
|
||||
Math.pow(t, 3) * path.targetPosition.y;
|
||||
|
||||
return { x, y };
|
||||
},
|
||||
[path],
|
||||
);
|
||||
|
||||
const getArcLength = useCallback(
|
||||
(t: number, samples: number = 100) => {
|
||||
let length = 0;
|
||||
let prevPoint = getPointForT(0);
|
||||
|
||||
for (let i = 1; i <= samples; i++) {
|
||||
const currT = (i / samples) * t;
|
||||
const currPoint = getPointForT(currT);
|
||||
length += Math.sqrt(
|
||||
Math.pow(currPoint.x - prevPoint.x, 2) +
|
||||
Math.pow(currPoint.y - prevPoint.y, 2),
|
||||
);
|
||||
prevPoint = currPoint;
|
||||
}
|
||||
|
||||
return length;
|
||||
},
|
||||
[getPointForT],
|
||||
);
|
||||
|
||||
const length = useMemo(() => {
|
||||
return getArcLength(1);
|
||||
}, [getArcLength]);
|
||||
|
||||
const getBezierDerivative = useCallback(
|
||||
(t: number) => {
|
||||
const mt = 1 - t;
|
||||
const x =
|
||||
3 *
|
||||
(mt * mt * (path.control1.x - path.sourcePosition.x) +
|
||||
2 * mt * t * (path.control2.x - path.control1.x) +
|
||||
t * t * (path.targetPosition.x - path.control2.x));
|
||||
const y =
|
||||
3 *
|
||||
(mt * mt * (path.control1.y - path.sourcePosition.y) +
|
||||
2 * mt * t * (path.control2.y - path.control1.y) +
|
||||
t * t * (path.targetPosition.y - path.control2.y));
|
||||
return { x, y };
|
||||
},
|
||||
[path],
|
||||
);
|
||||
|
||||
const getTForDistance = useCallback(
|
||||
(distance: number, epsilon: number = 0.0001) => {
|
||||
if (distance < 0) {
|
||||
distance = length + distance; // If distance is negative, calculate from the end of the curve
|
||||
}
|
||||
|
||||
let t = distance / getArcLength(1);
|
||||
let prevT = 0;
|
||||
|
||||
while (Math.abs(t - prevT) > epsilon) {
|
||||
prevT = t;
|
||||
const length = getArcLength(t);
|
||||
const derivative = Math.sqrt(
|
||||
Math.pow(getBezierDerivative(t).x, 2) +
|
||||
Math.pow(getBezierDerivative(t).y, 2),
|
||||
);
|
||||
t -= (length - distance) / derivative;
|
||||
t = Math.max(0, Math.min(1, t)); // Clamp t between 0 and 1
|
||||
}
|
||||
|
||||
return t;
|
||||
},
|
||||
[getArcLength, getBezierDerivative, length],
|
||||
);
|
||||
|
||||
const getPointAtDistance = useCallback(
|
||||
(distance: number) => {
|
||||
if (distance < 0) {
|
||||
distance = length + distance; // If distance is negative, calculate from the end of the curve
|
||||
}
|
||||
|
||||
const t = getTForDistance(distance);
|
||||
return getPointForT(t);
|
||||
},
|
||||
[getTForDistance, getPointForT, length],
|
||||
);
|
||||
|
||||
return {
|
||||
path,
|
||||
svgPath,
|
||||
length,
|
||||
getPointForT,
|
||||
getTForDistance,
|
||||
getPointAtDistance,
|
||||
};
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,244 +0,0 @@
|
||||
import React from "react";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/__legacy__/ui/dialog";
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import { AlertTriangle, XCircle, PlusCircle } from "lucide-react";
|
||||
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
|
||||
|
||||
interface IncompatibilityDialogProps {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
onConfirm: () => void;
|
||||
currentVersion: number;
|
||||
latestVersion: number;
|
||||
agentName: string;
|
||||
incompatibilities: IncompatibilityInfo;
|
||||
}
|
||||
|
||||
export const IncompatibilityDialog: React.FC<IncompatibilityDialogProps> = ({
|
||||
isOpen,
|
||||
onClose,
|
||||
onConfirm,
|
||||
currentVersion,
|
||||
latestVersion,
|
||||
agentName,
|
||||
incompatibilities,
|
||||
}) => {
|
||||
const hasMissingInputs = incompatibilities.missingInputs.length > 0;
|
||||
const hasMissingOutputs = incompatibilities.missingOutputs.length > 0;
|
||||
const hasNewInputs = incompatibilities.newInputs.length > 0;
|
||||
const hasNewOutputs = incompatibilities.newOutputs.length > 0;
|
||||
const hasNewRequired = incompatibilities.newRequiredInputs.length > 0;
|
||||
const hasTypeMismatches = incompatibilities.inputTypeMismatches.length > 0;
|
||||
|
||||
const hasInputChanges = hasMissingInputs || hasNewInputs;
|
||||
const hasOutputChanges = hasMissingOutputs || hasNewOutputs;
|
||||
|
||||
return (
|
||||
<Dialog open={isOpen} onOpenChange={(open) => !open && onClose()}>
|
||||
<DialogContent className="max-w-lg">
|
||||
<DialogHeader>
|
||||
<DialogTitle className="flex items-center gap-2">
|
||||
<AlertTriangle className="h-5 w-5 text-amber-500" />
|
||||
Incompatible Update
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
Updating <strong>{beautifyString(agentName)}</strong> from v
|
||||
{currentVersion} to v{latestVersion} will break some connections.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="space-y-4 py-2">
|
||||
{/* Input changes - two column layout */}
|
||||
{hasInputChanges && (
|
||||
<TwoColumnSection
|
||||
title="Input Changes"
|
||||
leftIcon={<XCircle className="h-4 w-4 text-red-500" />}
|
||||
leftTitle="Removed"
|
||||
leftItems={incompatibilities.missingInputs}
|
||||
rightIcon={<PlusCircle className="h-4 w-4 text-green-500" />}
|
||||
rightTitle="Added"
|
||||
rightItems={incompatibilities.newInputs}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Output changes - two column layout */}
|
||||
{hasOutputChanges && (
|
||||
<TwoColumnSection
|
||||
title="Output Changes"
|
||||
leftIcon={<XCircle className="h-4 w-4 text-red-500" />}
|
||||
leftTitle="Removed"
|
||||
leftItems={incompatibilities.missingOutputs}
|
||||
rightIcon={<PlusCircle className="h-4 w-4 text-green-500" />}
|
||||
rightTitle="Added"
|
||||
rightItems={incompatibilities.newOutputs}
|
||||
/>
|
||||
)}
|
||||
|
||||
{hasTypeMismatches && (
|
||||
<SingleColumnSection
|
||||
icon={<XCircle className="h-4 w-4 text-red-500" />}
|
||||
title="Type Changed"
|
||||
description="These connected inputs have a different type:"
|
||||
items={incompatibilities.inputTypeMismatches.map(
|
||||
(m) => `${m.name} (${m.oldType} → ${m.newType})`,
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{hasNewRequired && (
|
||||
<SingleColumnSection
|
||||
icon={<PlusCircle className="h-4 w-4 text-amber-500" />}
|
||||
title="New Required Inputs"
|
||||
description="These inputs are now required:"
|
||||
items={incompatibilities.newRequiredInputs}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Alert variant="warning">
|
||||
<AlertDescription>
|
||||
If you proceed, you'll need to remove the broken connections
|
||||
before you can save or run your agent.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<DialogFooter className="gap-2 sm:gap-0">
|
||||
<Button variant="outline" onClick={onClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="destructive"
|
||||
onClick={onConfirm}
|
||||
className="bg-amber-600 hover:bg-amber-700"
|
||||
>
|
||||
Update Anyway
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
interface TwoColumnSectionProps {
|
||||
title: string;
|
||||
leftIcon: React.ReactNode;
|
||||
leftTitle: string;
|
||||
leftItems: string[];
|
||||
rightIcon: React.ReactNode;
|
||||
rightTitle: string;
|
||||
rightItems: string[];
|
||||
}
|
||||
|
||||
const TwoColumnSection: React.FC<TwoColumnSectionProps> = ({
|
||||
title,
|
||||
leftIcon,
|
||||
leftTitle,
|
||||
leftItems,
|
||||
rightIcon,
|
||||
rightTitle,
|
||||
rightItems,
|
||||
}) => (
|
||||
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
|
||||
<span className="font-medium">{title}</span>
|
||||
<div className="mt-2 grid grid-cols-2 items-start gap-4">
|
||||
{/* Left column - Breaking changes */}
|
||||
<div className="min-w-0">
|
||||
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
|
||||
{leftIcon}
|
||||
<span>{leftTitle}</span>
|
||||
</div>
|
||||
<ul className="mt-1.5 space-y-1">
|
||||
{leftItems.length > 0 ? (
|
||||
leftItems.map((item) => (
|
||||
<li
|
||||
key={item}
|
||||
className="text-sm text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<code className="rounded bg-red-50 px-1 py-0.5 font-mono text-xs text-red-700 dark:bg-red-900/30 dark:text-red-300">
|
||||
{item}
|
||||
</code>
|
||||
</li>
|
||||
))
|
||||
) : (
|
||||
<li className="text-sm italic text-gray-400 dark:text-gray-500">
|
||||
None
|
||||
</li>
|
||||
)}
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
{/* Right column - Possible solutions */}
|
||||
<div className="min-w-0">
|
||||
<div className="flex items-center gap-1.5 text-sm text-gray-500 dark:text-gray-400">
|
||||
{rightIcon}
|
||||
<span>{rightTitle}</span>
|
||||
</div>
|
||||
<ul className="mt-1.5 space-y-1">
|
||||
{rightItems.length > 0 ? (
|
||||
rightItems.map((item) => (
|
||||
<li
|
||||
key={item}
|
||||
className="text-sm text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<code className="rounded bg-green-50 px-1 py-0.5 font-mono text-xs text-green-700 dark:bg-green-900/30 dark:text-green-300">
|
||||
{item}
|
||||
</code>
|
||||
</li>
|
||||
))
|
||||
) : (
|
||||
<li className="text-sm italic text-gray-400 dark:text-gray-500">
|
||||
None
|
||||
</li>
|
||||
)}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
interface SingleColumnSectionProps {
|
||||
icon: React.ReactNode;
|
||||
title: string;
|
||||
description: string;
|
||||
items: string[];
|
||||
}
|
||||
|
||||
const SingleColumnSection: React.FC<SingleColumnSectionProps> = ({
|
||||
icon,
|
||||
title,
|
||||
description,
|
||||
items,
|
||||
}) => (
|
||||
<div className="rounded-md border border-gray-200 p-3 dark:border-gray-700">
|
||||
<div className="flex items-center gap-2">
|
||||
{icon}
|
||||
<span className="font-medium">{title}</span>
|
||||
</div>
|
||||
<p className="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
||||
{description}
|
||||
</p>
|
||||
<ul className="mt-2 space-y-1">
|
||||
{items.map((item) => (
|
||||
<li
|
||||
key={item}
|
||||
className="ml-4 list-disc text-sm text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<code className="rounded bg-gray-100 px-1 py-0.5 font-mono text-xs dark:bg-gray-800">
|
||||
{item}
|
||||
</code>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
);
|
||||
|
||||
export default IncompatibilityDialog;
|
||||
@@ -1,130 +0,0 @@
|
||||
import React from "react";
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import { ArrowUp, AlertTriangle, Info } from "lucide-react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface SubAgentUpdateBarProps {
|
||||
currentVersion: number;
|
||||
latestVersion: number;
|
||||
isCompatible: boolean;
|
||||
incompatibilities: IncompatibilityInfo | null;
|
||||
onUpdate: () => void;
|
||||
isInResolutionMode?: boolean;
|
||||
}
|
||||
|
||||
export const SubAgentUpdateBar: React.FC<SubAgentUpdateBarProps> = ({
|
||||
currentVersion,
|
||||
latestVersion,
|
||||
isCompatible,
|
||||
incompatibilities,
|
||||
onUpdate,
|
||||
isInResolutionMode = false,
|
||||
}) => {
|
||||
if (isInResolutionMode) {
|
||||
return <ResolutionModeBar incompatibilities={incompatibilities} />;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-between gap-2 rounded-t-lg bg-blue-50 px-3 py-2 dark:bg-blue-900/30">
|
||||
<div className="flex items-center gap-2">
|
||||
<ArrowUp className="h-4 w-4 text-blue-600 dark:text-blue-400" />
|
||||
<span className="text-sm text-blue-700 dark:text-blue-300">
|
||||
Update available (v{currentVersion} → v{latestVersion})
|
||||
</span>
|
||||
{!isCompatible && (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<AlertTriangle className="h-4 w-4 text-amber-500" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className="max-w-xs">
|
||||
<p className="font-medium">Incompatible changes detected</p>
|
||||
<p className="text-xs text-gray-400">
|
||||
Click Update to see details
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
</div>
|
||||
<Button
|
||||
size="sm"
|
||||
variant={isCompatible ? "default" : "outline"}
|
||||
onClick={onUpdate}
|
||||
className={cn(
|
||||
"h-7 text-xs",
|
||||
!isCompatible && "border-amber-500 text-amber-600 hover:bg-amber-50",
|
||||
)}
|
||||
>
|
||||
Update
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
interface ResolutionModeBarProps {
|
||||
incompatibilities: IncompatibilityInfo | null;
|
||||
}
|
||||
|
||||
const ResolutionModeBar: React.FC<ResolutionModeBarProps> = ({
|
||||
incompatibilities,
|
||||
}) => {
|
||||
const formatIncompatibilities = () => {
|
||||
if (!incompatibilities) return "No incompatibilities";
|
||||
|
||||
const items: string[] = [];
|
||||
|
||||
if (incompatibilities.missingInputs.length > 0) {
|
||||
items.push(
|
||||
`Missing inputs: ${incompatibilities.missingInputs.join(", ")}`,
|
||||
);
|
||||
}
|
||||
if (incompatibilities.missingOutputs.length > 0) {
|
||||
items.push(
|
||||
`Missing outputs: ${incompatibilities.missingOutputs.join(", ")}`,
|
||||
);
|
||||
}
|
||||
if (incompatibilities.newRequiredInputs.length > 0) {
|
||||
items.push(
|
||||
`New required inputs: ${incompatibilities.newRequiredInputs.join(", ")}`,
|
||||
);
|
||||
}
|
||||
if (incompatibilities.inputTypeMismatches.length > 0) {
|
||||
const mismatches = incompatibilities.inputTypeMismatches
|
||||
.map((m) => `${m.name} (${m.oldType} → ${m.newType})`)
|
||||
.join(", ");
|
||||
items.push(`Type changed: ${mismatches}`);
|
||||
}
|
||||
|
||||
return items.join("\n");
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-between gap-2 rounded-t-lg bg-amber-50 px-3 py-2 dark:bg-amber-900/30">
|
||||
<div className="flex items-center gap-2">
|
||||
<AlertTriangle className="h-4 w-4 text-amber-600 dark:text-amber-400" />
|
||||
<span className="text-sm text-amber-700 dark:text-amber-300">
|
||||
Remove incompatible connections
|
||||
</span>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Info className="h-4 w-4 cursor-help text-amber-500" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className="max-w-sm whitespace-pre-line">
|
||||
<p className="font-medium">Incompatible changes:</p>
|
||||
<p className="mt-1 text-xs">{formatIncompatibilities()}</p>
|
||||
<p className="mt-2 text-xs text-gray-400">
|
||||
Delete the red connections to continue
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default SubAgentUpdateBar;
|
||||
@@ -1,131 +0,0 @@
|
||||
.custom-node {
|
||||
color: #000000;
|
||||
box-sizing: border-box;
|
||||
transition: border-color 0.3s ease-in-out;
|
||||
}
|
||||
|
||||
.custom-node .custom-switch {
|
||||
padding: 0.5rem 1.25rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
}
|
||||
|
||||
.error-message {
|
||||
color: #d9534f;
|
||||
font-size: 13px;
|
||||
padding-left: 0.5rem;
|
||||
}
|
||||
|
||||
/* Existing styles */
|
||||
.handle-container {
|
||||
display: flex;
|
||||
position: relative;
|
||||
margin-bottom: 0px;
|
||||
padding: 5px;
|
||||
min-height: 44px;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
.react-flow__handle {
|
||||
background: transparent;
|
||||
width: auto;
|
||||
height: auto;
|
||||
border: 0;
|
||||
position: relative;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
.border-error {
|
||||
border: 1px solid #d9534f;
|
||||
}
|
||||
|
||||
.select-input {
|
||||
width: 100%;
|
||||
padding: 5px;
|
||||
border-radius: 4px;
|
||||
border: 1px solid #000;
|
||||
background: #fff;
|
||||
color: #000;
|
||||
}
|
||||
|
||||
.radio-label {
|
||||
display: block;
|
||||
margin: 5px 0;
|
||||
color: #000;
|
||||
}
|
||||
|
||||
.number-input {
|
||||
width: 100%;
|
||||
padding: 5px;
|
||||
border-radius: 4px;
|
||||
background: #fff;
|
||||
color: #000;
|
||||
}
|
||||
|
||||
.array-item-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-bottom: 5px;
|
||||
}
|
||||
|
||||
.array-item-input {
|
||||
flex-grow: 1;
|
||||
padding: 5px;
|
||||
border-radius: 4px;
|
||||
border: 1px solid #000;
|
||||
background: #fff;
|
||||
color: #000;
|
||||
}
|
||||
|
||||
.array-item-remove {
|
||||
background: #d9534f;
|
||||
border: none;
|
||||
color: white;
|
||||
cursor: pointer;
|
||||
margin-left: 5px;
|
||||
border-radius: 4px;
|
||||
padding: 5px 10px;
|
||||
}
|
||||
|
||||
.array-item-add {
|
||||
background: #5bc0de;
|
||||
border: none;
|
||||
color: white;
|
||||
cursor: pointer;
|
||||
border-radius: 4px;
|
||||
padding: 5px 10px;
|
||||
margin-top: 5px;
|
||||
}
|
||||
|
||||
.error-message {
|
||||
color: #d9534f;
|
||||
font-size: 13px;
|
||||
margin-top: 5px;
|
||||
margin-left: 5px;
|
||||
}
|
||||
|
||||
/* Styles for node states */
|
||||
.completed {
|
||||
border-color: #27ae60; /* Green border for completed nodes */
|
||||
}
|
||||
|
||||
.running {
|
||||
border-color: #f39c12; /* Orange border for running nodes */
|
||||
}
|
||||
|
||||
.failed {
|
||||
border-color: #c0392b; /* Red border for failed nodes */
|
||||
}
|
||||
|
||||
.incomplete {
|
||||
border-color: #9f14ab; /* Pink border for incomplete nodes */
|
||||
}
|
||||
|
||||
.queued {
|
||||
border-color: #25e6e6; /* Cyan border for queued nodes */
|
||||
}
|
||||
|
||||
.custom-switch {
|
||||
padding-left: 2px;
|
||||
}
|
||||
@@ -1,166 +0,0 @@
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
import { Clipboard, Maximize2 } from "lucide-react";
|
||||
import React, { useMemo, useState } from "react";
|
||||
import { Button } from "../../../../../components/__legacy__/ui/button";
|
||||
import { ContentRenderer } from "../../../../../components/__legacy__/ui/render";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "../../../../../components/__legacy__/ui/table";
|
||||
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
|
||||
import {
|
||||
globalRegistry,
|
||||
OutputItem,
|
||||
} from "@/components/contextual/OutputRenderers";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { useToast } from "../../../../../components/molecules/Toast/use-toast";
|
||||
import ExpandableOutputDialog from "./ExpandableOutputDialog";
|
||||
|
||||
type DataTableProps = {
|
||||
title?: string;
|
||||
truncateLongData?: boolean;
|
||||
data: { [key: string]: Array<any> };
|
||||
};
|
||||
|
||||
export default function DataTable({
|
||||
title,
|
||||
truncateLongData,
|
||||
data,
|
||||
}: DataTableProps) {
|
||||
const { toast } = useToast();
|
||||
const enableEnhancedOutputHandling = useGetFlag(
|
||||
Flag.ENABLE_ENHANCED_OUTPUT_HANDLING,
|
||||
);
|
||||
const [expandedDialog, setExpandedDialog] = useState<{
|
||||
isOpen: boolean;
|
||||
execId: string;
|
||||
pinName: string;
|
||||
data: any[];
|
||||
} | null>(null);
|
||||
|
||||
// Prepare renderers for each item when enhanced mode is enabled
|
||||
const getItemRenderer = useMemo(() => {
|
||||
if (!enableEnhancedOutputHandling) return null;
|
||||
return (item: unknown) => {
|
||||
const metadata: OutputMetadata = {};
|
||||
return globalRegistry.getRenderer(item, metadata);
|
||||
};
|
||||
}, [enableEnhancedOutputHandling]);
|
||||
|
||||
const copyData = (pin: string, data: string) => {
|
||||
navigator.clipboard.writeText(data).then(() => {
|
||||
toast({
|
||||
title: `"${pin}" data copied to clipboard!`,
|
||||
duration: 2000,
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
const openExpandedView = (pinName: string, pinData: any[]) => {
|
||||
setExpandedDialog({
|
||||
isOpen: true,
|
||||
execId: title || "Unknown Execution",
|
||||
pinName,
|
||||
data: pinData,
|
||||
});
|
||||
};
|
||||
|
||||
const closeExpandedView = () => {
|
||||
setExpandedDialog(null);
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
{title && <strong className="mt-2 flex justify-center">{title}</strong>}
|
||||
<Table className="cursor-default select-text">
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Pin</TableHead>
|
||||
<TableHead>Data</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{Object.entries(data).map(([key, value]) => (
|
||||
<TableRow className="group" key={key}>
|
||||
<TableCell className="cursor-text">
|
||||
{beautifyString(key)}
|
||||
</TableCell>
|
||||
<TableCell className="cursor-text">
|
||||
<div className="flex min-h-9 items-center whitespace-pre-wrap">
|
||||
<div className="absolute right-1 top-auto m-1 hidden gap-1 group-hover:flex">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={() => openExpandedView(key, value)}
|
||||
title="Expand Full View"
|
||||
>
|
||||
<Maximize2 size={18} />
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={() =>
|
||||
copyData(
|
||||
beautifyString(key),
|
||||
value
|
||||
.map((i) =>
|
||||
typeof i === "object"
|
||||
? JSON.stringify(i, null, 2)
|
||||
: String(i),
|
||||
)
|
||||
.join(", "),
|
||||
)
|
||||
}
|
||||
title="Copy Data"
|
||||
>
|
||||
<Clipboard size={18} />
|
||||
</Button>
|
||||
</div>
|
||||
{value.map((item, index) => {
|
||||
const renderer = getItemRenderer?.(item);
|
||||
if (enableEnhancedOutputHandling && renderer) {
|
||||
const metadata: OutputMetadata = {};
|
||||
return (
|
||||
<React.Fragment key={index}>
|
||||
<OutputItem
|
||||
value={item}
|
||||
metadata={metadata}
|
||||
renderer={renderer}
|
||||
/>
|
||||
{index < value.length - 1 && ", "}
|
||||
</React.Fragment>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<React.Fragment key={index}>
|
||||
<ContentRenderer
|
||||
value={item}
|
||||
truncateLongData={truncateLongData}
|
||||
/>
|
||||
{index < value.length - 1 && ", "}
|
||||
</React.Fragment>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
{expandedDialog && (
|
||||
<ExpandableOutputDialog
|
||||
isOpen={expandedDialog.isOpen}
|
||||
onClose={closeExpandedView}
|
||||
execId={expandedDialog.execId}
|
||||
pinName={expandedDialog.pinName}
|
||||
data={expandedDialog.data}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,269 +0,0 @@
|
||||
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
|
||||
import {
|
||||
globalRegistry,
|
||||
OutputActions,
|
||||
OutputItem,
|
||||
} from "@/components/contextual/OutputRenderers";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { Clipboard, Maximize2 } from "lucide-react";
|
||||
import React, { FC, useMemo, useState } from "react";
|
||||
import { Button } from "../../../../../components/__legacy__/ui/button";
|
||||
import { ContentRenderer } from "../../../../../components/__legacy__/ui/render";
|
||||
import { ScrollArea } from "../../../../../components/__legacy__/ui/scroll-area";
|
||||
import { Separator } from "../../../../../components/__legacy__/ui/separator";
|
||||
import { Switch } from "../../../../../components/atoms/Switch/Switch";
|
||||
import { useToast } from "../../../../../components/molecules/Toast/use-toast";
|
||||
|
||||
interface ExpandableOutputDialogProps {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
execId: string;
|
||||
pinName: string;
|
||||
data: any[];
|
||||
}
|
||||
|
||||
const ExpandableOutputDialog: FC<ExpandableOutputDialogProps> = ({
|
||||
isOpen,
|
||||
onClose,
|
||||
execId,
|
||||
pinName,
|
||||
data,
|
||||
}) => {
|
||||
const { toast } = useToast();
|
||||
const enableEnhancedOutputHandling = useGetFlag(
|
||||
Flag.ENABLE_ENHANCED_OUTPUT_HANDLING,
|
||||
);
|
||||
const [useEnhancedRenderer, setUseEnhancedRenderer] = useState(false);
|
||||
|
||||
// Prepare items for the enhanced renderer system
|
||||
const outputItems = useMemo(() => {
|
||||
if (!data || !useEnhancedRenderer) return [];
|
||||
|
||||
const items: Array<{
|
||||
key: string;
|
||||
label: string;
|
||||
value: unknown;
|
||||
metadata?: OutputMetadata;
|
||||
renderer: any;
|
||||
}> = [];
|
||||
|
||||
data.forEach((value, index) => {
|
||||
const metadata: OutputMetadata = {};
|
||||
|
||||
// Extract metadata from the value if it's an object
|
||||
if (
|
||||
typeof value === "object" &&
|
||||
value !== null &&
|
||||
!React.isValidElement(value)
|
||||
) {
|
||||
const objValue = value as any;
|
||||
if (objValue.type) metadata.type = objValue.type;
|
||||
if (objValue.mimeType) metadata.mimeType = objValue.mimeType;
|
||||
if (objValue.filename) metadata.filename = objValue.filename;
|
||||
if (objValue.language) metadata.language = objValue.language;
|
||||
}
|
||||
|
||||
const renderer = globalRegistry.getRenderer(value, metadata);
|
||||
if (renderer) {
|
||||
items.push({
|
||||
key: `item-${index}`,
|
||||
label: index === 0 ? beautifyString(pinName) : "",
|
||||
value,
|
||||
metadata,
|
||||
renderer,
|
||||
});
|
||||
} else {
|
||||
// Fallback to text renderer
|
||||
const textRenderer = globalRegistry
|
||||
.getAllRenderers()
|
||||
.find((r) => r.name === "TextRenderer");
|
||||
if (textRenderer) {
|
||||
items.push({
|
||||
key: `item-${index}`,
|
||||
label: index === 0 ? beautifyString(pinName) : "",
|
||||
value:
|
||||
typeof value === "string"
|
||||
? value
|
||||
: JSON.stringify(value, null, 2),
|
||||
metadata,
|
||||
renderer: textRenderer,
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return items;
|
||||
}, [data, useEnhancedRenderer, pinName]);
|
||||
|
||||
const copyData = () => {
|
||||
const formattedData = data
|
||||
.map((item) =>
|
||||
typeof item === "object" ? JSON.stringify(item, null, 2) : String(item),
|
||||
)
|
||||
.join("\n\n");
|
||||
|
||||
navigator.clipboard.writeText(formattedData).then(() => {
|
||||
toast({
|
||||
title: `"${beautifyString(pinName)}" data copied to clipboard!`,
|
||||
duration: 2000,
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title={
|
||||
<div className="flex items-center justify-between pr-8">
|
||||
<div className="flex items-center gap-2">
|
||||
<Maximize2 size={20} />
|
||||
Full Output Preview
|
||||
</div>
|
||||
{enableEnhancedOutputHandling && (
|
||||
<div className="flex items-center gap-3">
|
||||
<label
|
||||
htmlFor="enhanced-rendering-toggle"
|
||||
className="cursor-pointer select-none text-sm font-normal text-gray-600"
|
||||
>
|
||||
Enhanced Rendering
|
||||
</label>
|
||||
<Switch
|
||||
id="enhanced-rendering-toggle"
|
||||
checked={useEnhancedRenderer}
|
||||
onCheckedChange={setUseEnhancedRenderer}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
}
|
||||
controlled={{
|
||||
isOpen,
|
||||
set: (open) => {
|
||||
if (!open) onClose();
|
||||
},
|
||||
}}
|
||||
onClose={onClose}
|
||||
styling={{
|
||||
maxWidth: "56rem",
|
||||
width: "90vw",
|
||||
height: "90vh",
|
||||
}}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="flex h-full flex-col">
|
||||
<div className="pb-4">
|
||||
<p className="text-sm text-zinc-600">
|
||||
Execution ID: <span className="font-mono text-xs">{execId}</span>
|
||||
<br />
|
||||
Pin:{" "}
|
||||
<span className="font-semibold">{beautifyString(pinName)}</span>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-1 flex-col overflow-hidden">
|
||||
{useEnhancedRenderer && outputItems.length > 0 && (
|
||||
<div className="border-b px-4 py-2">
|
||||
<OutputActions
|
||||
items={outputItems.map((item) => ({
|
||||
value: item.value,
|
||||
metadata: item.metadata,
|
||||
renderer: item.renderer,
|
||||
}))}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<ScrollArea className="h-full">
|
||||
<div className="p-4">
|
||||
{data.length > 0 ? (
|
||||
useEnhancedRenderer ? (
|
||||
<div className="space-y-4">
|
||||
{outputItems.map((item) => (
|
||||
<OutputItem
|
||||
key={item.key}
|
||||
value={item.value}
|
||||
metadata={item.metadata}
|
||||
renderer={item.renderer}
|
||||
label={item.label}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-4">
|
||||
{data.map((item, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className="rounded-lg border bg-gray-50 p-4"
|
||||
>
|
||||
<div className="mb-2 flex items-center justify-between">
|
||||
<span className="text-sm font-medium text-gray-600">
|
||||
Item {index + 1} of {data.length}
|
||||
</span>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
const itemData =
|
||||
typeof item === "object"
|
||||
? JSON.stringify(item, null, 2)
|
||||
: String(item);
|
||||
navigator.clipboard
|
||||
.writeText(itemData)
|
||||
.then(() => {
|
||||
toast({
|
||||
title: `Item ${index + 1} copied to clipboard!`,
|
||||
duration: 2000,
|
||||
});
|
||||
});
|
||||
}}
|
||||
className="flex items-center gap-1"
|
||||
>
|
||||
<Clipboard size={14} />
|
||||
Copy Item
|
||||
</Button>
|
||||
</div>
|
||||
<Separator className="mb-3" />
|
||||
<div className="whitespace-pre-wrap break-words font-mono text-sm">
|
||||
<ContentRenderer
|
||||
value={item}
|
||||
truncateLongData={false}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
) : (
|
||||
<div className="py-8 text-center text-gray-500">
|
||||
No data available
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</div>
|
||||
|
||||
<Dialog.Footer className="flex justify-between">
|
||||
<div className="text-sm text-gray-600">
|
||||
{data.length} item{data.length !== 1 ? "s" : ""} total
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
{!useEnhancedRenderer && (
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={copyData}
|
||||
className="flex items-center gap-1"
|
||||
>
|
||||
<Clipboard size={16} />
|
||||
Copy All
|
||||
</Button>
|
||||
)}
|
||||
<Button onClick={onClose}>Close</Button>
|
||||
</div>
|
||||
</Dialog.Footer>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
export default ExpandableOutputDialog;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,103 +0,0 @@
|
||||
/* flow.css or index.css */
|
||||
|
||||
body {
|
||||
font-family:
|
||||
-apple-system, BlinkMacSystemFont, "Segoe UI", "Roboto", "Oxygen", "Ubuntu",
|
||||
"Cantarell", "Fira Sans", "Droid Sans", "Helvetica Neue", sans-serif;
|
||||
}
|
||||
|
||||
code {
|
||||
font-family:
|
||||
source-code-pro, Menlo, Monaco, Consolas, "Courier New", monospace;
|
||||
}
|
||||
|
||||
.modal {
|
||||
position: absolute;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
right: auto;
|
||||
bottom: auto;
|
||||
margin-right: -50%;
|
||||
transform: translate(-50%, -50%);
|
||||
background: #ffffff;
|
||||
padding: 20px;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 4px;
|
||||
color: #000000;
|
||||
}
|
||||
|
||||
.overlay {
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-color: rgba(0, 0, 0, 0.75);
|
||||
}
|
||||
|
||||
.modal h2 {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
.modal button {
|
||||
margin-right: 10px;
|
||||
}
|
||||
|
||||
.modal form {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.modal form div {
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
|
||||
.sidebar {
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: -600px;
|
||||
width: 350px;
|
||||
height: calc(100vh - 68px); /* Full height minus top offset */
|
||||
background-color: #ffffff;
|
||||
color: #000000;
|
||||
padding: 20px;
|
||||
transition: left 0.3s ease;
|
||||
z-index: 1000;
|
||||
overflow-y: auto;
|
||||
margin-top: 68px; /* Margin to push content below the top fixed area */
|
||||
}
|
||||
|
||||
.sidebar.open {
|
||||
left: 0;
|
||||
}
|
||||
|
||||
.sidebar h3 {
|
||||
margin: 0 0 10px;
|
||||
}
|
||||
|
||||
.sidebar input {
|
||||
margin: 0 0 10px;
|
||||
}
|
||||
|
||||
.sidebarNodeRowStyle {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
background-color: #e2e2e2;
|
||||
padding: 10px;
|
||||
margin-bottom: 10px;
|
||||
border-radius: 10px;
|
||||
cursor: grab;
|
||||
}
|
||||
|
||||
.sidebarNodeRowStyle.dragging {
|
||||
opacity: 0.5;
|
||||
}
|
||||
|
||||
.flow-container {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: 100vw;
|
||||
height: 100vh;
|
||||
}
|
||||
@@ -1,82 +0,0 @@
|
||||
import React from "react";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/__legacy__/ui/popover";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { MagnifyingGlassIcon } from "@radix-ui/react-icons";
|
||||
import { CustomNode } from "@/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode";
|
||||
import { GraphSearchContent } from "../NewControlPanel/NewSearchGraph/GraphMenuContent/GraphContent";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { useGraphMenu } from "../NewControlPanel/NewSearchGraph/GraphMenu/useGraphMenu";
|
||||
|
||||
interface GraphSearchControlProps {
|
||||
nodes: CustomNode[];
|
||||
onNodeSelect: (nodeId: string) => void;
|
||||
onNodeHover?: (nodeId: string | null) => void;
|
||||
}
|
||||
|
||||
export function GraphSearchControl({
|
||||
nodes,
|
||||
onNodeSelect,
|
||||
onNodeHover,
|
||||
}: GraphSearchControlProps) {
|
||||
// Use the same hook as GraphSearchMenu for consistency
|
||||
const {
|
||||
open,
|
||||
searchQuery,
|
||||
setSearchQuery,
|
||||
filteredNodes,
|
||||
handleNodeSelect,
|
||||
handleOpenChange,
|
||||
} = useGraphMenu({
|
||||
nodes,
|
||||
blockMenuSelected: "", // We don't need to track this in the old control panel
|
||||
setBlockMenuSelected: () => {}, // Not needed in this context
|
||||
onNodeSelect,
|
||||
});
|
||||
|
||||
return (
|
||||
<Popover open={open} onOpenChange={handleOpenChange}>
|
||||
<Tooltip delayDuration={500}>
|
||||
<TooltipTrigger asChild>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
data-id="graph-search-control-trigger"
|
||||
data-testid="graph-search-control-button"
|
||||
name="Search"
|
||||
className="dark:hover:bg-slate-800"
|
||||
>
|
||||
<MagnifyingGlassIcon className="h-5 w-5" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="right">Search Graph</TooltipContent>
|
||||
</Tooltip>
|
||||
|
||||
<PopoverContent
|
||||
side="right"
|
||||
sideOffset={22}
|
||||
align="start"
|
||||
alignOffset={-50} // Offset upward to align with control panel top
|
||||
className="absolute -top-3 w-[17rem] rounded-xl border-none p-0 shadow-none md:w-[30rem]"
|
||||
data-id="graph-search-popover-content"
|
||||
>
|
||||
<GraphSearchContent
|
||||
searchQuery={searchQuery}
|
||||
onSearchChange={setSearchQuery}
|
||||
filteredNodes={filteredNodes}
|
||||
onNodeSelect={handleNodeSelect}
|
||||
onNodeHover={onNodeHover}
|
||||
/>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
import React, { FC, useEffect, useState } from "react";
|
||||
import { Button } from "../../../../../components/__legacy__/ui/button";
|
||||
import { Textarea } from "../../../../../components/__legacy__/ui/textarea";
|
||||
import { Maximize2, Minimize2, Clipboard } from "lucide-react";
|
||||
import { createPortal } from "react-dom";
|
||||
import { toast } from "../../../../../components/molecules/Toast/use-toast";
|
||||
|
||||
interface ModalProps {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
onSave: (value: string) => void;
|
||||
title?: string;
|
||||
defaultValue: string;
|
||||
}
|
||||
|
||||
const InputModalComponent: FC<ModalProps> = ({
|
||||
isOpen,
|
||||
onClose,
|
||||
onSave,
|
||||
title,
|
||||
defaultValue,
|
||||
}) => {
|
||||
const [tempValue, setTempValue] = useState(defaultValue);
|
||||
const [isMaximized, setIsMaximized] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (isOpen) {
|
||||
setTempValue(defaultValue);
|
||||
setIsMaximized(false);
|
||||
}
|
||||
}, [isOpen, defaultValue]);
|
||||
|
||||
const handleSave = () => {
|
||||
onSave(tempValue);
|
||||
onClose();
|
||||
};
|
||||
|
||||
const toggleSize = () => {
|
||||
setIsMaximized(!isMaximized);
|
||||
};
|
||||
|
||||
const copyValue = () => {
|
||||
navigator.clipboard.writeText(tempValue).then(() => {
|
||||
toast({
|
||||
title: "Input value copied to clipboard!",
|
||||
duration: 2000,
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
if (!isOpen) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const modalContent = (
|
||||
<div
|
||||
id="modal-content"
|
||||
className={`fixed rounded-lg border-[1.5px] bg-white p-5 ${
|
||||
isMaximized ? "inset-[128px] flex flex-col" : `w-[90%] max-w-[800px]`
|
||||
}`}
|
||||
>
|
||||
<h2 className="mb-4 text-center text-lg font-semibold">
|
||||
{title || "Enter input text"}
|
||||
</h2>
|
||||
<div className="nowheel relative flex-grow">
|
||||
<Textarea
|
||||
className="h-full min-h-[200px] w-full resize-none"
|
||||
value={tempValue}
|
||||
onChange={(e) => setTempValue(e.target.value)}
|
||||
/>
|
||||
<div className="absolute bottom-2 right-2 flex space-x-2">
|
||||
<Button onClick={copyValue} size="icon" variant="outline">
|
||||
<Clipboard size={18} />
|
||||
</Button>
|
||||
<Button onClick={toggleSize} size="icon" variant="outline">
|
||||
{isMaximized ? <Minimize2 size={18} /> : <Maximize2 size={18} />}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-4 flex justify-end space-x-2">
|
||||
<Button onClick={onClose} variant="outline">
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={handleSave}>Save</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
{isMaximized ? (
|
||||
createPortal(
|
||||
<div className="fixed inset-0 flex items-center justify-center bg-white bg-opacity-60">
|
||||
{modalContent}
|
||||
</div>,
|
||||
document.body,
|
||||
)
|
||||
) : (
|
||||
<div className="nodrag fixed inset-0 flex items-center justify-center bg-white bg-opacity-60">
|
||||
{modalContent}
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default InputModalComponent;
|
||||
@@ -1,163 +0,0 @@
|
||||
import { BlockIOSubSchema } from "@/lib/autogpt-server-api/types";
|
||||
import {
|
||||
cn,
|
||||
beautifyString,
|
||||
getTypeBgColor,
|
||||
getTypeTextColor,
|
||||
getEffectiveType,
|
||||
} from "@/lib/utils";
|
||||
import { FC, memo, useCallback } from "react";
|
||||
import { Handle, Position } from "@xyflow/react";
|
||||
import { InformationTooltip } from "@/components/molecules/InformationTooltip/InformationTooltip";
|
||||
|
||||
type HandleProps = {
|
||||
keyName: string;
|
||||
schema: BlockIOSubSchema;
|
||||
isConnected: boolean;
|
||||
isRequired?: boolean;
|
||||
side: "left" | "right";
|
||||
title?: string;
|
||||
className?: string;
|
||||
isBroken?: boolean;
|
||||
};
|
||||
|
||||
// Move the constant out of the component to avoid re-creation on every render.
|
||||
const TYPE_NAME: Record<string, string> = {
|
||||
string: "text",
|
||||
number: "number",
|
||||
integer: "integer",
|
||||
boolean: "true/false",
|
||||
object: "object",
|
||||
array: "list",
|
||||
null: "null",
|
||||
};
|
||||
|
||||
// Extract and memoize the Dot component so that it doesn't re-render unnecessarily.
|
||||
const Dot: FC<{ isConnected: boolean; type?: string; isBroken?: boolean }> =
|
||||
memo(({ isConnected, type, isBroken }) => {
|
||||
const color = isBroken
|
||||
? "border-red-500 bg-red-100 dark:bg-red-900/30"
|
||||
: isConnected
|
||||
? getTypeBgColor(type || "any")
|
||||
: "border-gray-300 dark:border-gray-600";
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"m-1 h-4 w-4 rounded-full border-2 bg-white transition-colors duration-100 group-hover:bg-gray-300 dark:bg-slate-800 dark:group-hover:bg-gray-700",
|
||||
color,
|
||||
isBroken && "opacity-50",
|
||||
)}
|
||||
/>
|
||||
);
|
||||
});
|
||||
Dot.displayName = "Dot";
|
||||
|
||||
const NodeHandle: FC<HandleProps> = ({
|
||||
keyName,
|
||||
schema,
|
||||
isConnected,
|
||||
isRequired,
|
||||
side,
|
||||
title,
|
||||
className,
|
||||
isBroken = false,
|
||||
}) => {
|
||||
// Extract effective type from schema (handles anyOf/oneOf/allOf wrappers)
|
||||
const effectiveType = getEffectiveType(schema);
|
||||
|
||||
const typeClass = `text-sm ${getTypeTextColor(effectiveType || "any")} ${
|
||||
side === "left" ? "text-left" : "text-right"
|
||||
}`;
|
||||
|
||||
const label = (
|
||||
<div className={cn("flex flex-grow flex-row", isBroken && "opacity-50")}>
|
||||
<span
|
||||
className={cn(
|
||||
"data-sentry-unmask text-m green flex items-end pr-2 text-gray-900 dark:text-gray-100",
|
||||
className,
|
||||
isBroken && "text-red-500 line-through",
|
||||
)}
|
||||
>
|
||||
{title || schema.title || beautifyString(keyName.toLowerCase())}
|
||||
{isRequired ? "*" : ""}
|
||||
</span>
|
||||
<span
|
||||
className={cn(
|
||||
`${typeClass} data-sentry-unmask flex items-end`,
|
||||
isBroken && "text-red-400",
|
||||
)}
|
||||
>
|
||||
({TYPE_NAME[effectiveType as keyof typeof TYPE_NAME] || "any"})
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
|
||||
// Use a native HTML onContextMenu handler instead of wrapping a large node with a Radix ContextMenu trigger.
|
||||
const handleContextMenu = useCallback(
|
||||
(e: React.MouseEvent<HTMLDivElement>) => {
|
||||
e.preventDefault();
|
||||
// Optionally, you can trigger a custom, lightweight context menu here.
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
if (side === "left") {
|
||||
return (
|
||||
<div
|
||||
key={keyName}
|
||||
className={cn("handle-container", isBroken && "pointer-events-none")}
|
||||
onContextMenu={handleContextMenu}
|
||||
>
|
||||
<Handle
|
||||
type="target"
|
||||
data-testid={`input-handle-${keyName}`}
|
||||
position={Position.Left}
|
||||
id={keyName}
|
||||
className={cn("group -ml-[38px]", isBroken && "cursor-not-allowed")}
|
||||
isConnectable={!isBroken}
|
||||
>
|
||||
<div className="pointer-events-none flex items-center">
|
||||
<Dot
|
||||
isConnected={isConnected}
|
||||
type={effectiveType}
|
||||
isBroken={isBroken}
|
||||
/>
|
||||
{label}
|
||||
</div>
|
||||
</Handle>
|
||||
<InformationTooltip description={schema.description} />
|
||||
</div>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<div
|
||||
key={keyName}
|
||||
className={cn(
|
||||
"handle-container justify-end",
|
||||
isBroken && "pointer-events-none",
|
||||
)}
|
||||
onContextMenu={handleContextMenu}
|
||||
>
|
||||
<Handle
|
||||
type="source"
|
||||
data-testid={`output-handle-${keyName}`}
|
||||
position={Position.Right}
|
||||
id={keyName}
|
||||
className={cn("group -mr-[38px]", isBroken && "cursor-not-allowed")}
|
||||
isConnectable={!isBroken}
|
||||
>
|
||||
<div className="pointer-events-none flex items-center">
|
||||
{label}
|
||||
<Dot
|
||||
isConnected={isConnected}
|
||||
type={effectiveType}
|
||||
isBroken={isBroken}
|
||||
/>
|
||||
</div>
|
||||
</Handle>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
export default memo(NodeHandle);
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,158 +0,0 @@
|
||||
import React, { useContext, useMemo, useState } from "react";
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import { Maximize2 } from "lucide-react";
|
||||
import * as Separator from "@radix-ui/react-separator";
|
||||
import { ContentRenderer } from "@/components/__legacy__/ui/render";
|
||||
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
|
||||
import {
|
||||
globalRegistry,
|
||||
OutputItem,
|
||||
} from "@/components/contextual/OutputRenderers";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
|
||||
import { BuilderContext } from "./Flow/Flow";
|
||||
import ExpandableOutputDialog from "./ExpandableOutputDialog";
|
||||
|
||||
type NodeOutputsProps = {
|
||||
title?: string;
|
||||
truncateLongData?: boolean;
|
||||
data: { [key: string]: Array<any> };
|
||||
};
|
||||
|
||||
export default function NodeOutputs({
|
||||
title,
|
||||
truncateLongData,
|
||||
data,
|
||||
}: NodeOutputsProps) {
|
||||
const builderContext = useContext(BuilderContext);
|
||||
const enableEnhancedOutputHandling = useGetFlag(
|
||||
Flag.ENABLE_ENHANCED_OUTPUT_HANDLING,
|
||||
);
|
||||
|
||||
const [expandedDialog, setExpandedDialog] = useState<{
|
||||
isOpen: boolean;
|
||||
execId: string;
|
||||
pinName: string;
|
||||
data: any[];
|
||||
} | null>(null);
|
||||
|
||||
if (!builderContext) {
|
||||
throw new Error(
|
||||
"BuilderContext consumer must be inside FlowEditor component",
|
||||
);
|
||||
}
|
||||
|
||||
const { getNodeTitle } = builderContext;
|
||||
|
||||
// Prepare renderers for each item when enhanced mode is enabled
|
||||
const getItemRenderer = useMemo(() => {
|
||||
if (!enableEnhancedOutputHandling) return null;
|
||||
return (item: unknown) => {
|
||||
const metadata: OutputMetadata = {};
|
||||
return globalRegistry.getRenderer(item, metadata);
|
||||
};
|
||||
}, [enableEnhancedOutputHandling]);
|
||||
|
||||
const getBeautifiedPinName = (pin: string) => {
|
||||
if (!pin.startsWith("tools_^_")) {
|
||||
return beautifyString(pin);
|
||||
}
|
||||
// Special handling for tool pins: replace node ID with node title
|
||||
const toolNodeID = pin.slice(8).split("_~_")[0]; // tools_^_{node_id}_~_{field}
|
||||
const toolNodeTitle = getNodeTitle(toolNodeID);
|
||||
return toolNodeTitle
|
||||
? beautifyString(pin.replace(toolNodeID, toolNodeTitle))
|
||||
: beautifyString(pin);
|
||||
};
|
||||
|
||||
const openExpandedView = (pinName: string, pinData: any[]) => {
|
||||
setExpandedDialog({
|
||||
isOpen: true,
|
||||
execId: title || "Node Output",
|
||||
pinName,
|
||||
data: pinData,
|
||||
});
|
||||
};
|
||||
|
||||
const closeExpandedView = () => {
|
||||
setExpandedDialog(null);
|
||||
};
|
||||
return (
|
||||
<div className="m-4 space-y-4">
|
||||
{title && <strong className="mt-2flex">{title}</strong>}
|
||||
{Object.entries(data).map(([pin, dataArray]) => (
|
||||
<div key={pin} className="group">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center">
|
||||
<strong className="mr-2">Pin:</strong>
|
||||
<span>{getBeautifiedPinName(pin)}</span>
|
||||
</div>
|
||||
{(truncateLongData || dataArray.length > 10) && (
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => openExpandedView(pin, dataArray)}
|
||||
className="hidden items-center gap-1 group-hover:flex"
|
||||
title="Expand Full View"
|
||||
>
|
||||
<Maximize2 size={14} />
|
||||
Expand
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
<div className="mt-2">
|
||||
<strong className="mr-2">Data:</strong>
|
||||
<div className="mt-1">
|
||||
{dataArray.slice(0, 10).map((item, index) => {
|
||||
const renderer = getItemRenderer?.(item);
|
||||
if (enableEnhancedOutputHandling && renderer) {
|
||||
const metadata: OutputMetadata = {};
|
||||
return (
|
||||
<React.Fragment key={index}>
|
||||
<OutputItem
|
||||
value={item}
|
||||
metadata={metadata}
|
||||
renderer={renderer}
|
||||
/>
|
||||
{index < Math.min(dataArray.length, 10) - 1 && ", "}
|
||||
</React.Fragment>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<React.Fragment key={index}>
|
||||
<ContentRenderer
|
||||
value={item}
|
||||
truncateLongData={truncateLongData}
|
||||
/>
|
||||
{index < Math.min(dataArray.length, 10) - 1 && ", "}
|
||||
</React.Fragment>
|
||||
);
|
||||
})}
|
||||
{dataArray.length > 10 && (
|
||||
<span style={{ color: "#888" }}>
|
||||
<br />
|
||||
<b>⋮</b>
|
||||
<br />
|
||||
<span>and {dataArray.length - 10} more</span>
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<Separator.Root className="my-4 h-[1px] bg-gray-300" />
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
|
||||
{expandedDialog && (
|
||||
<ExpandableOutputDialog
|
||||
isOpen={expandedDialog.isOpen}
|
||||
onClose={closeExpandedView}
|
||||
execId={expandedDialog.execId}
|
||||
pinName={expandedDialog.pinName}
|
||||
data={expandedDialog.data}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,205 +0,0 @@
|
||||
import { FC, useCallback, useEffect, useState } from "react";
|
||||
|
||||
import NodeHandle from "@/app/(platform)/build/components/legacy-builder/NodeHandle";
|
||||
import type {
|
||||
BlockIOTableSubSchema,
|
||||
TableCellValue,
|
||||
TableRow,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import type { ConnectedEdge } from "./CustomNode/CustomNode";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { PlusIcon, XIcon } from "@phosphor-icons/react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
|
||||
interface NodeTableInputProps {
|
||||
/** Unique identifier for the node in the builder graph */
|
||||
nodeId: string;
|
||||
/** Key identifier for this specific input field within the node */
|
||||
selfKey: string;
|
||||
/** Schema definition for the table structure */
|
||||
schema: BlockIOTableSubSchema;
|
||||
/** Column headers for the table */
|
||||
headers: string[];
|
||||
/** Initial row data for the table */
|
||||
rows?: TableRow[];
|
||||
/** Validation errors mapped by field key */
|
||||
errors: { [key: string]: string | undefined };
|
||||
/** Graph connections between nodes in the builder */
|
||||
connections: ConnectedEdge[];
|
||||
/** Callback when table data changes */
|
||||
handleInputChange: (key: string, value: TableRow[]) => void;
|
||||
/** Callback when input field is clicked (for builder selection) */
|
||||
handleInputClick: (key: string) => void;
|
||||
/** Additional CSS classes */
|
||||
className?: string;
|
||||
/** Display name for the input field */
|
||||
displayName?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Table input component for the workflow builder interface.
|
||||
*
|
||||
* This component is specifically designed for use in the agent builder where users
|
||||
* design workflows with connected nodes. It includes graph connection capabilities
|
||||
* via NodeHandle and is tightly integrated with the builder's state management.
|
||||
*
|
||||
* @warning Do NOT use this component in runtime/execution contexts (like RunAgentInputs).
|
||||
* For runtime table inputs, use a simpler implementation without builder-specific features.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <NodeTableInput
|
||||
* nodeId="node-123"
|
||||
* selfKey="table_data"
|
||||
* schema={tableSchema}
|
||||
* headers={["Name", "Value"]}
|
||||
* rows={existingData}
|
||||
* connections={graphConnections}
|
||||
* handleInputChange={handleChange}
|
||||
* handleInputClick={handleClick}
|
||||
* errors={{}}
|
||||
* />
|
||||
* ```
|
||||
*
|
||||
* @see Used exclusively in: `/app/(platform)/build/components/legacy-builder/NodeInputs.tsx`
|
||||
*/
|
||||
export const NodeTableInput: FC<NodeTableInputProps> = ({
|
||||
nodeId,
|
||||
selfKey,
|
||||
schema,
|
||||
headers,
|
||||
rows = [],
|
||||
errors,
|
||||
connections,
|
||||
handleInputChange,
|
||||
handleInputClick: _handleInputClick,
|
||||
className,
|
||||
displayName,
|
||||
}) => {
|
||||
const [tableData, setTableData] = useState<TableRow[]>(rows);
|
||||
|
||||
// Sync with parent state when rows change
|
||||
useEffect(() => {
|
||||
setTableData(rows);
|
||||
}, [rows]);
|
||||
|
||||
const isConnected = (key: string) =>
|
||||
connections.some((c) => c.targetHandle === key && c.target === nodeId);
|
||||
|
||||
const updateTableData = useCallback(
|
||||
(newData: TableRow[]) => {
|
||||
setTableData(newData);
|
||||
handleInputChange(selfKey, newData);
|
||||
},
|
||||
[selfKey, handleInputChange],
|
||||
);
|
||||
|
||||
const updateCell = (
|
||||
rowIndex: number,
|
||||
header: string,
|
||||
value: TableCellValue,
|
||||
) => {
|
||||
const newData = [...tableData];
|
||||
if (!newData[rowIndex]) {
|
||||
newData[rowIndex] = {};
|
||||
}
|
||||
newData[rowIndex][header] = value;
|
||||
updateTableData(newData);
|
||||
};
|
||||
|
||||
const addRow = () => {
|
||||
if (!headers || headers.length === 0) {
|
||||
return;
|
||||
}
|
||||
const newRow: TableRow = {};
|
||||
headers.forEach((header) => {
|
||||
newRow[header] = "";
|
||||
});
|
||||
updateTableData([...tableData, newRow]);
|
||||
};
|
||||
|
||||
const removeRow = (index: number) => {
|
||||
const newData = tableData.filter((_, i) => i !== index);
|
||||
updateTableData(newData);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={cn("w-full space-y-2", className)}>
|
||||
<NodeHandle
|
||||
title={displayName || selfKey}
|
||||
keyName={selfKey}
|
||||
schema={schema}
|
||||
isConnected={isConnected(selfKey)}
|
||||
isRequired={false}
|
||||
side="left"
|
||||
/>
|
||||
|
||||
{!isConnected(selfKey) && (
|
||||
<div className="nodrag overflow-x-auto">
|
||||
<table className="w-full border-collapse">
|
||||
<thead>
|
||||
<tr>
|
||||
{headers.map((header, index) => (
|
||||
<th
|
||||
key={index}
|
||||
className="border border-gray-300 bg-gray-100 px-2 py-1 text-left text-sm font-medium dark:border-gray-600 dark:bg-gray-800"
|
||||
>
|
||||
{header}
|
||||
</th>
|
||||
))}
|
||||
<th className="w-10"></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{tableData.map((row, rowIndex) => (
|
||||
<tr key={rowIndex}>
|
||||
{headers.map((header, colIndex) => (
|
||||
<td
|
||||
key={colIndex}
|
||||
className="border border-gray-300 p-1 dark:border-gray-600"
|
||||
>
|
||||
<Input
|
||||
id={`${selfKey}-${rowIndex}-${header}`}
|
||||
label={header}
|
||||
type="text"
|
||||
value={String(row[header] || "")}
|
||||
onChange={(e) =>
|
||||
updateCell(rowIndex, header, e.target.value)
|
||||
}
|
||||
className="h-8 w-full"
|
||||
placeholder={`Enter ${header}`}
|
||||
/>
|
||||
</td>
|
||||
))}
|
||||
<td className="p-1">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={() => removeRow(rowIndex)}
|
||||
className="h-8 w-8 p-0"
|
||||
>
|
||||
<XIcon />
|
||||
</Button>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<Button
|
||||
className="mt-2 bg-gray-200 font-normal text-black hover:text-white dark:bg-gray-700 dark:text-white dark:hover:bg-gray-600"
|
||||
onClick={addRow}
|
||||
size="small"
|
||||
>
|
||||
<PlusIcon className="mr-2" /> Add Row
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{errors[selfKey] && (
|
||||
<span className="text-sm text-red-500">{errors[selfKey]}</span>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -1,311 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import React, { useEffect, useState, useRef } from "react";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
|
||||
import type { GraphID } from "@/lib/autogpt-server-api/types";
|
||||
import { askOtto } from "@/app/(platform)/build/actions";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { environment } from "@/services/environment";
|
||||
|
||||
interface Message {
|
||||
type: "user" | "assistant";
|
||||
content: string;
|
||||
}
|
||||
|
||||
export default function OttoChatWidget({
|
||||
graphID,
|
||||
className,
|
||||
}: {
|
||||
graphID?: GraphID;
|
||||
className?: string;
|
||||
}): React.ReactNode {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const [messages, setMessages] = useState<Message[]>([]);
|
||||
const [inputValue, setInputValue] = useState("");
|
||||
const [isProcessing, setIsProcessing] = useState(false);
|
||||
const [includeGraphData, setIncludeGraphData] = useState(false);
|
||||
const messagesEndRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
// Add welcome message when component mounts
|
||||
if (messages.length === 0) {
|
||||
setMessages([
|
||||
{
|
||||
type: "assistant",
|
||||
content: "Hello, I am Otto! Ask me anything about AutoGPT!",
|
||||
},
|
||||
]);
|
||||
}
|
||||
}, [messages.length]);
|
||||
|
||||
useEffect(() => {
|
||||
// Scroll to bottom whenever messages change
|
||||
messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
|
||||
}, [messages]);
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
if (!inputValue.trim() || isProcessing) return;
|
||||
|
||||
const userMessage = inputValue.trim();
|
||||
setInputValue("");
|
||||
setIsProcessing(true);
|
||||
|
||||
// Add user message to chat
|
||||
setMessages((prev) => [...prev, { type: "user", content: userMessage }]);
|
||||
|
||||
// Add temporary processing message
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{ type: "assistant", content: "Processing your question..." },
|
||||
]);
|
||||
|
||||
const conversationHistory = messages.reduce<
|
||||
{ query: string; response: string }[]
|
||||
>((acc, msg, i, arr) => {
|
||||
if (
|
||||
msg.type === "user" &&
|
||||
i + 1 < arr.length &&
|
||||
arr[i + 1].type === "assistant" &&
|
||||
arr[i + 1].content !== "Processing your question..."
|
||||
) {
|
||||
acc.push({
|
||||
query: msg.content,
|
||||
response: arr[i + 1].content,
|
||||
});
|
||||
}
|
||||
return acc;
|
||||
}, []);
|
||||
|
||||
try {
|
||||
const data = await askOtto(
|
||||
userMessage,
|
||||
conversationHistory,
|
||||
includeGraphData,
|
||||
graphID,
|
||||
);
|
||||
|
||||
// Check if the response contains an error
|
||||
if ("error" in data && data.error === true) {
|
||||
// Handle different error types
|
||||
let errorMessage =
|
||||
"Sorry, there was an error processing your message. Please try again.";
|
||||
|
||||
if (data.answer === "Authentication required") {
|
||||
errorMessage = "Please sign in to use the chat feature.";
|
||||
} else if (data.answer === "Failed to connect to Otto service") {
|
||||
errorMessage =
|
||||
"Otto service is currently unavailable. Please try again later.";
|
||||
} else if (data.answer.includes("timed out")) {
|
||||
errorMessage = "Request timed out. Please try again later.";
|
||||
}
|
||||
|
||||
// Remove processing message and add error message
|
||||
setMessages((prev) => [
|
||||
...prev.slice(0, -1),
|
||||
{ type: "assistant", content: errorMessage },
|
||||
]);
|
||||
} else {
|
||||
// Remove processing message and add actual response
|
||||
setMessages((prev) => [
|
||||
...prev.slice(0, -1),
|
||||
{ type: "assistant", content: data.answer },
|
||||
]);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Unexpected error in chat widget:", error);
|
||||
setMessages((prev) => [
|
||||
...prev.slice(0, -1),
|
||||
{
|
||||
type: "assistant",
|
||||
content:
|
||||
"An unexpected error occurred. Please refresh the page and try again.",
|
||||
},
|
||||
]);
|
||||
} finally {
|
||||
setIsProcessing(false);
|
||||
setIncludeGraphData(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Don't render the chat widget if we're not on the build page or in local mode
|
||||
if (environment.isLocal()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!isOpen) {
|
||||
return (
|
||||
<div className={className}>
|
||||
<button
|
||||
onClick={() => setIsOpen(true)}
|
||||
className="inline-flex h-14 w-14 items-center justify-center whitespace-nowrap rounded-2xl bg-[rgba(65,65,64,1)] text-neutral-50 shadow transition-colors hover:bg-neutral-900/90 focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-neutral-950 disabled:pointer-events-none disabled:opacity-50 dark:bg-neutral-50 dark:text-neutral-900 dark:hover:bg-neutral-50/90 dark:focus-visible:ring-neutral-300"
|
||||
aria-label="Open chat widget"
|
||||
>
|
||||
<svg
|
||||
viewBox="0 0 24 24"
|
||||
className="h-6 w-6"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
fill="none"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
>
|
||||
<path d="M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex h-[600px] w-[600px] flex-col rounded-lg border bg-background shadow-xl",
|
||||
className,
|
||||
"z-40",
|
||||
)}
|
||||
>
|
||||
{/* Header */}
|
||||
<div className="flex items-center justify-between border-b p-4">
|
||||
<h2 className="font-semibold">Otto Assistant</h2>
|
||||
<button
|
||||
onClick={() => setIsOpen(false)}
|
||||
className="text-muted-foreground transition-colors hover:text-foreground"
|
||||
aria-label="Close chat"
|
||||
>
|
||||
<svg
|
||||
viewBox="0 0 24 24"
|
||||
className="h-5 w-5"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
fill="none"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
>
|
||||
<line x1="18" y1="6" x2="6" y2="18" />
|
||||
<line x1="6" y1="6" x2="18" y2="18" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Messages */}
|
||||
<div className="flex-1 space-y-4 overflow-y-auto p-4">
|
||||
{messages.map((message, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className={`flex ${message.type === "user" ? "justify-end" : "justify-start"}`}
|
||||
>
|
||||
<div
|
||||
className={`max-w-[80%] rounded-lg p-3 ${
|
||||
message.type === "user"
|
||||
? "ml-4 bg-black text-white"
|
||||
: "mr-4 bg-[#8b5cf6] text-white"
|
||||
}`}
|
||||
>
|
||||
{message.type === "user" ? (
|
||||
message.content
|
||||
) : (
|
||||
<ReactMarkdown
|
||||
className="prose prose-sm dark:prose-invert max-w-none"
|
||||
components={{
|
||||
p: ({ children }) => (
|
||||
<p className="mb-2 last:mb-0">{children}</p>
|
||||
),
|
||||
code(props) {
|
||||
const { children, className, node: _, ...rest } = props;
|
||||
const match = /language-(\w+)/.exec(className || "");
|
||||
return match ? (
|
||||
<pre className="overflow-x-auto rounded-md bg-muted-foreground/20 p-3">
|
||||
<code className="font-mono text-sm" {...rest}>
|
||||
{children}
|
||||
</code>
|
||||
</pre>
|
||||
) : (
|
||||
<code
|
||||
className="rounded-md bg-muted-foreground/20 px-1 py-0.5 font-mono text-sm"
|
||||
{...rest}
|
||||
>
|
||||
{children}
|
||||
</code>
|
||||
);
|
||||
},
|
||||
ul: ({ children }) => (
|
||||
<ul className="mb-2 list-disc pl-4 last:mb-0">
|
||||
{children}
|
||||
</ul>
|
||||
),
|
||||
ol: ({ children }) => (
|
||||
<ol className="mb-2 list-decimal pl-4 last:mb-0">
|
||||
{children}
|
||||
</ol>
|
||||
),
|
||||
li: ({ children }) => (
|
||||
<li className="mb-1 last:mb-0">{children}</li>
|
||||
),
|
||||
}}
|
||||
>
|
||||
{message.content}
|
||||
</ReactMarkdown>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
<div ref={messagesEndRef} />
|
||||
</div>
|
||||
|
||||
{/* Input */}
|
||||
<form onSubmit={handleSubmit} className="border-t p-4">
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className="flex gap-2">
|
||||
<input
|
||||
type="text"
|
||||
value={inputValue}
|
||||
onChange={(e) => setInputValue(e.target.value)}
|
||||
placeholder="Type your message..."
|
||||
className="flex-1 rounded-md border bg-background px-3 py-2 focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
disabled={isProcessing}
|
||||
/>
|
||||
<button
|
||||
type="submit"
|
||||
disabled={isProcessing}
|
||||
className="rounded-md bg-primary px-4 py-2 text-primary-foreground transition-colors hover:bg-primary/90 disabled:opacity-50"
|
||||
>
|
||||
Send
|
||||
</button>
|
||||
</div>
|
||||
{graphID && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setIncludeGraphData((prev) => !prev);
|
||||
}}
|
||||
className={`flex items-center gap-2 rounded border px-2 py-1.5 text-sm transition-all duration-200 ${
|
||||
includeGraphData
|
||||
? "border-primary/30 bg-primary/10 text-primary hover:shadow-[0_0_10px_3px_rgba(139,92,246,0.3)]"
|
||||
: "border-transparent bg-muted text-muted-foreground hover:bg-muted/80 hover:shadow-[0_0_10px_3px_rgba(139,92,246,0.15)]"
|
||||
}`}
|
||||
>
|
||||
<svg
|
||||
viewBox="0 0 24 24"
|
||||
className="h-4 w-4"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
fill="none"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
>
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
|
||||
<circle cx="8.5" cy="8.5" r="1.5" />
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
{includeGraphData
|
||||
? "Graph data will be included"
|
||||
: "Include graph data"}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
import React, { FC } from "react";
|
||||
import { Button } from "../../../../../components/__legacy__/ui/button";
|
||||
import { NodeExecutionResult } from "@/lib/autogpt-server-api/types";
|
||||
import DataTable from "./DataTable";
|
||||
import { Separator } from "@/components/__legacy__/ui/separator";
|
||||
|
||||
interface OutputModalProps {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
executionResults: {
|
||||
execId: string;
|
||||
data: NodeExecutionResult["output_data"];
|
||||
}[];
|
||||
}
|
||||
|
||||
const OutputModalComponent: FC<OutputModalProps> = ({
|
||||
isOpen,
|
||||
onClose,
|
||||
executionResults,
|
||||
}) => {
|
||||
if (!isOpen) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="nodrag nowheel fixed inset-0 flex items-center justify-center bg-white bg-opacity-60">
|
||||
<div className="w-[500px] max-w-[90%] rounded-lg border-[1.5px] bg-white p-5">
|
||||
<strong>Output Data History</strong>
|
||||
<div className="my-2 max-h-[384px] flex-grow overflow-y-auto rounded-md p-2">
|
||||
{executionResults.map((data, i) => (
|
||||
<>
|
||||
<DataTable
|
||||
key={i}
|
||||
title={data.execId}
|
||||
data={data.data}
|
||||
truncateLongData={true}
|
||||
/>
|
||||
<Separator />
|
||||
</>
|
||||
))}
|
||||
</div>
|
||||
<div className="mt-2.5 flex justify-end gap-2.5">
|
||||
<Button onClick={onClose}>Close</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default OutputModalComponent;
|
||||
@@ -1,96 +0,0 @@
|
||||
import { useCallback } from "react";
|
||||
|
||||
import { AgentRunDraftView } from "@/app/(platform)/build/components/legacy-builder/agent-run-draft-view";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import type {
|
||||
CredentialsMetaInput,
|
||||
Graph,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
|
||||
interface RunInputDialogProps {
|
||||
isOpen: boolean;
|
||||
doClose: () => void;
|
||||
graph: Graph;
|
||||
doRun?: (
|
||||
inputs: Record<string, any>,
|
||||
credentialsInputs: Record<string, CredentialsMetaInput>,
|
||||
) => Promise<void> | void;
|
||||
doCreateSchedule?: (
|
||||
cronExpression: string,
|
||||
scheduleName: string,
|
||||
inputs: Record<string, any>,
|
||||
credentialsInputs: Record<string, CredentialsMetaInput>,
|
||||
) => Promise<void> | void;
|
||||
}
|
||||
|
||||
export function RunnerInputDialog({
|
||||
isOpen,
|
||||
doClose,
|
||||
graph,
|
||||
doRun,
|
||||
doCreateSchedule,
|
||||
}: RunInputDialogProps) {
|
||||
const handleRun = useCallback(
|
||||
doRun
|
||||
? async (
|
||||
inputs: Record<string, any>,
|
||||
credentials_inputs: Record<string, CredentialsMetaInput>,
|
||||
) => {
|
||||
await doRun(inputs, credentials_inputs);
|
||||
doClose();
|
||||
}
|
||||
: async () => {},
|
||||
[doRun, doClose],
|
||||
);
|
||||
|
||||
const handleSchedule = useCallback(
|
||||
doCreateSchedule
|
||||
? async (
|
||||
cronExpression: string,
|
||||
scheduleName: string,
|
||||
inputs: Record<string, any>,
|
||||
credentialsInputs: Record<string, CredentialsMetaInput>,
|
||||
) => {
|
||||
await doCreateSchedule(
|
||||
cronExpression,
|
||||
scheduleName,
|
||||
inputs,
|
||||
credentialsInputs,
|
||||
);
|
||||
doClose();
|
||||
}
|
||||
: async () => {},
|
||||
[doCreateSchedule, doClose],
|
||||
);
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Run your agent"
|
||||
controlled={{
|
||||
isOpen,
|
||||
set: (open) => {
|
||||
if (!open) doClose();
|
||||
},
|
||||
}}
|
||||
onClose={doClose}
|
||||
styling={{
|
||||
maxWidth: "56rem",
|
||||
width: "90vw",
|
||||
}}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="flex flex-col p-10">
|
||||
<p className="mt-2 text-sm text-zinc-600">{graph.name}</p>
|
||||
<AgentRunDraftView
|
||||
className="p-0"
|
||||
graph={graph}
|
||||
doRun={doRun ? handleRun : undefined}
|
||||
onRun={doRun ? undefined : doClose}
|
||||
doCreateSchedule={doCreateSchedule ? handleSchedule : undefined}
|
||||
onCreateSchedule={doCreateSchedule ? undefined : doClose}
|
||||
/>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,156 +0,0 @@
|
||||
import React from "react";
|
||||
import {
|
||||
Sheet,
|
||||
SheetContent,
|
||||
SheetHeader,
|
||||
SheetTitle,
|
||||
SheetDescription,
|
||||
} from "@/components/__legacy__/ui/sheet";
|
||||
import { ScrollArea } from "@/components/__legacy__/ui/scroll-area";
|
||||
import { Label } from "@/components/__legacy__/ui/label";
|
||||
import { Textarea } from "@/components/__legacy__/ui/textarea";
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import { Clipboard } from "lucide-react";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
|
||||
export type OutputNodeInfo = {
|
||||
metadata: {
|
||||
name: string;
|
||||
description: string;
|
||||
};
|
||||
result?: any;
|
||||
};
|
||||
|
||||
interface OutputModalProps {
|
||||
isOpen: boolean;
|
||||
doClose: () => void;
|
||||
outputs: OutputNodeInfo[];
|
||||
graphExecutionError?: string | null;
|
||||
}
|
||||
|
||||
const formatOutput = (output: any): string => {
|
||||
if (typeof output === "object") {
|
||||
try {
|
||||
if (
|
||||
Array.isArray(output) &&
|
||||
output.every((item) => typeof item === "string")
|
||||
) {
|
||||
return output.join("\n").replace(/\\n/g, "\n");
|
||||
}
|
||||
return JSON.stringify(output, null, 2);
|
||||
} catch (error) {
|
||||
return `Error formatting output: ${(error as Error).message}`;
|
||||
}
|
||||
}
|
||||
if (typeof output === "string") {
|
||||
return output.replace(/\\n/g, "\n");
|
||||
}
|
||||
return String(output);
|
||||
};
|
||||
|
||||
export function RunnerOutputUI({
|
||||
isOpen,
|
||||
doClose,
|
||||
outputs,
|
||||
graphExecutionError,
|
||||
}: OutputModalProps) {
|
||||
const { toast } = useToast();
|
||||
|
||||
const copyOutput = (name: string, output: any) => {
|
||||
const formattedOutput = formatOutput(output);
|
||||
navigator.clipboard.writeText(formattedOutput).then(() => {
|
||||
toast({
|
||||
title: `"${name}" output copied to clipboard!`,
|
||||
duration: 2000,
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
const adjustTextareaHeight = (textarea: HTMLTextAreaElement) => {
|
||||
textarea.style.height = "auto";
|
||||
textarea.style.height = `${textarea.scrollHeight}px`;
|
||||
};
|
||||
|
||||
return (
|
||||
<Sheet open={isOpen} onOpenChange={doClose}>
|
||||
<SheetContent
|
||||
side="right"
|
||||
className="flex h-full w-full flex-col overflow-hidden sm:max-w-[600px]"
|
||||
>
|
||||
<SheetHeader className="px-2 py-2">
|
||||
<SheetTitle className="text-xl">Run Outputs</SheetTitle>
|
||||
<SheetDescription className="mt-1 text-sm">
|
||||
View the outputs from your agent run.
|
||||
</SheetDescription>
|
||||
</SheetHeader>
|
||||
<div className="flex-grow overflow-y-auto px-2 py-2">
|
||||
<ScrollArea className="h-full overflow-auto pr-4">
|
||||
<div className="space-y-4">
|
||||
{graphExecutionError && (
|
||||
<div className="rounded-md border border-red-200 bg-red-50 p-3 dark:border-red-800 dark:bg-red-900/20">
|
||||
<p className="text-sm text-red-800 dark:text-red-200">
|
||||
<strong>Error:</strong> {graphExecutionError}
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
{outputs && outputs.length > 0 ? (
|
||||
outputs.map((output, i) => (
|
||||
<div key={i} className="space-y-1">
|
||||
<Label className="text-base font-semibold">
|
||||
{output.metadata.name || "Unnamed Output"}
|
||||
</Label>
|
||||
|
||||
{output.metadata.description && (
|
||||
<Label className="block text-sm text-gray-600">
|
||||
{output.metadata.description}
|
||||
</Label>
|
||||
)}
|
||||
|
||||
<div className="group relative rounded-md bg-gray-100 p-2">
|
||||
<Button
|
||||
className="absolute right-1 top-1 z-10 m-1 hidden p-2 group-hover:block"
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={() =>
|
||||
copyOutput(
|
||||
output.metadata.name || "Unnamed Output",
|
||||
output.result,
|
||||
)
|
||||
}
|
||||
title="Copy Output"
|
||||
>
|
||||
<Clipboard size={18} />
|
||||
</Button>
|
||||
<Textarea
|
||||
readOnly
|
||||
value={formatOutput(output.result ?? "No output yet")}
|
||||
className="w-full resize-none whitespace-pre-wrap break-words border-none bg-transparent text-sm"
|
||||
style={{
|
||||
height: "auto",
|
||||
minHeight: "2.5rem",
|
||||
maxHeight: "400px",
|
||||
}}
|
||||
ref={(el) => {
|
||||
if (el) {
|
||||
adjustTextareaHeight(el);
|
||||
if (el.scrollHeight > 400) {
|
||||
el.style.height = "400px";
|
||||
}
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
<p>No output blocks available.</p>
|
||||
)}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</div>
|
||||
</SheetContent>
|
||||
</Sheet>
|
||||
);
|
||||
}
|
||||
|
||||
export default RunnerOutputUI;
|
||||
@@ -1,117 +0,0 @@
|
||||
import React, {
|
||||
useState,
|
||||
forwardRef,
|
||||
useImperativeHandle,
|
||||
useMemo,
|
||||
} from "react";
|
||||
import { Node } from "@xyflow/react";
|
||||
import { CustomNodeData } from "@/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode";
|
||||
import {
|
||||
BlockUIType,
|
||||
CredentialsMetaInput,
|
||||
Graph,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import RunnerOutputUI, { OutputNodeInfo } from "./RunnerOutputUI";
|
||||
import { RunnerInputDialog } from "./RunnerInputUI";
|
||||
|
||||
interface RunnerUIWrapperProps {
|
||||
graph: Graph;
|
||||
nodes: Node<CustomNodeData>[];
|
||||
graphExecutionError?: string | null;
|
||||
saveAndRun: (
|
||||
inputs: Record<string, any>,
|
||||
credentialsInputs: Record<string, CredentialsMetaInput>,
|
||||
) => void;
|
||||
createRunSchedule: (
|
||||
cronExpression: string,
|
||||
scheduleName: string,
|
||||
inputs: Record<string, any>,
|
||||
credentialsInputs: Record<string, CredentialsMetaInput>,
|
||||
) => Promise<void>;
|
||||
}
|
||||
|
||||
export interface RunnerUIWrapperRef {
|
||||
openRunInputDialog: () => void;
|
||||
openRunnerOutput: () => void;
|
||||
runOrOpenInput: () => void;
|
||||
}
|
||||
|
||||
const RunnerUIWrapper = forwardRef<RunnerUIWrapperRef, RunnerUIWrapperProps>(
|
||||
(
|
||||
{ graph, nodes, graphExecutionError, saveAndRun, createRunSchedule },
|
||||
ref,
|
||||
) => {
|
||||
const [isRunInputDialogOpen, setIsRunInputDialogOpen] = useState(false);
|
||||
const [isRunnerOutputOpen, setIsRunnerOutputOpen] = useState(false);
|
||||
|
||||
const graphInputs = graph.input_schema.properties;
|
||||
|
||||
const graphOutputs = useMemo((): OutputNodeInfo[] => {
|
||||
const outputNodes = nodes.filter(
|
||||
(node) => node.data.uiType === BlockUIType.OUTPUT,
|
||||
);
|
||||
|
||||
return outputNodes.map(
|
||||
(node) =>
|
||||
({
|
||||
metadata: {
|
||||
name: node.data.hardcodedValues.name || "Output",
|
||||
description:
|
||||
node.data.hardcodedValues.description ||
|
||||
"Output from the agent",
|
||||
},
|
||||
result:
|
||||
(node.data.executionResults as any)
|
||||
?.map((result: any) => result?.data?.output)
|
||||
.join("\n--\n") || "No output yet",
|
||||
}) satisfies OutputNodeInfo,
|
||||
);
|
||||
}, [nodes]);
|
||||
|
||||
const openRunInputDialog = () => setIsRunInputDialogOpen(true);
|
||||
const openRunnerOutput = () => setIsRunnerOutputOpen(true);
|
||||
|
||||
const runOrOpenInput = () => {
|
||||
if (
|
||||
Object.keys(graphInputs).length > 0 ||
|
||||
Object.keys(graph.credentials_input_schema.properties).length > 0
|
||||
) {
|
||||
openRunInputDialog();
|
||||
} else {
|
||||
saveAndRun({}, {});
|
||||
}
|
||||
};
|
||||
|
||||
useImperativeHandle(
|
||||
ref,
|
||||
() =>
|
||||
({
|
||||
openRunInputDialog,
|
||||
openRunnerOutput,
|
||||
runOrOpenInput,
|
||||
}) satisfies RunnerUIWrapperRef,
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<RunnerInputDialog
|
||||
isOpen={isRunInputDialogOpen}
|
||||
doClose={() => setIsRunInputDialogOpen(false)}
|
||||
graph={graph}
|
||||
doRun={saveAndRun}
|
||||
doCreateSchedule={createRunSchedule}
|
||||
/>
|
||||
<RunnerOutputUI
|
||||
isOpen={isRunnerOutputOpen}
|
||||
doClose={() => setIsRunnerOutputOpen(false)}
|
||||
outputs={graphOutputs}
|
||||
graphExecutionError={graphExecutionError}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
RunnerUIWrapper.displayName = "RunnerUIWrapper";
|
||||
|
||||
export default RunnerUIWrapper;
|
||||
@@ -1,217 +0,0 @@
|
||||
import React, { useEffect, useState } from "react";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/__legacy__/ui/popover";
|
||||
import { Card, CardContent, CardFooter } from "@/components/__legacy__/ui/card";
|
||||
import { Input } from "@/components/__legacy__/ui/input";
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import { GraphMeta } from "@/lib/autogpt-server-api";
|
||||
import { Label } from "@/components/__legacy__/ui/label";
|
||||
import { IconSave } from "@/components/__legacy__/ui/icons";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { getGetV2ListMySubmissionsQueryKey } from "@/app/api/__generated__/endpoints/store/store";
|
||||
import { CronExpressionDialog } from "@/components/contextual/CronScheduler/cron-scheduler-dialog";
|
||||
import { humanizeCronExpression } from "@/lib/cron-expression-utils";
|
||||
import { CalendarClockIcon } from "lucide-react";
|
||||
|
||||
interface SaveControlProps {
|
||||
agentMeta: GraphMeta | null;
|
||||
agentName: string;
|
||||
agentDescription: string;
|
||||
agentRecommendedScheduleCron: string;
|
||||
canSave: boolean;
|
||||
onSave: () => Promise<void>;
|
||||
onNameChange: (name: string) => void;
|
||||
onDescriptionChange: (description: string) => void;
|
||||
onRecommendedScheduleCronChange: (cron: string) => void;
|
||||
pinSavePopover: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* A SaveControl component to be used within the ControlPanel. It allows the user to save the agent.
|
||||
* @param {Object} SaveControlProps - The properties of the SaveControl component.
|
||||
* @param {GraphMeta | null} SaveControlProps.agentMeta - The agent's metadata, or null if creating a new agent.
|
||||
* @param {string} SaveControlProps.agentName - The agent's name.
|
||||
* @param {string} SaveControlProps.agentDescription - The agent's description.
|
||||
* @param {boolean} SaveControlProps.canSave - Whether the button to save the agent should be enabled.
|
||||
* @param {() => void} SaveControlProps.onSave - Function to save the agent.
|
||||
* @param {(name: string) => void} SaveControlProps.onNameChange - Function to handle name changes.
|
||||
* @param {(description: string) => void} SaveControlProps.onDescriptionChange - Function to handle description changes.
|
||||
* @returns The SaveControl component.
|
||||
*/
|
||||
export const SaveControl = ({
|
||||
agentMeta,
|
||||
canSave,
|
||||
onSave,
|
||||
agentName,
|
||||
onNameChange,
|
||||
agentDescription,
|
||||
onDescriptionChange,
|
||||
agentRecommendedScheduleCron,
|
||||
onRecommendedScheduleCronChange,
|
||||
pinSavePopover,
|
||||
}: SaveControlProps) => {
|
||||
/**
|
||||
* Note for improvement:
|
||||
* At the moment we are leveraging onDescriptionChange and onNameChange to handle the changes in the description and name of the agent.
|
||||
* We should migrate this to be handled with form controls and a form library.
|
||||
*/
|
||||
|
||||
const { toast } = useToast();
|
||||
const queryClient = useQueryClient();
|
||||
const [cronScheduleDialogOpen, setCronScheduleDialogOpen] = useState(false);
|
||||
|
||||
const handleScheduleChange = (cronExpression: string) => {
|
||||
onRecommendedScheduleCronChange(cronExpression);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const handleKeyDown = async (event: KeyboardEvent) => {
|
||||
if ((event.ctrlKey || event.metaKey) && event.key === "s") {
|
||||
event.preventDefault(); // Stop the browser default action
|
||||
await onSave(); // Call your save function
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListMySubmissionsQueryKey(),
|
||||
});
|
||||
toast({
|
||||
duration: 2000,
|
||||
title: "All changes saved successfully!",
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener("keydown", handleKeyDown);
|
||||
|
||||
return () => {
|
||||
window.removeEventListener("keydown", handleKeyDown);
|
||||
};
|
||||
}, [onSave, toast]);
|
||||
|
||||
return (
|
||||
<Popover open={pinSavePopover ? true : undefined}>
|
||||
<Tooltip delayDuration={500}>
|
||||
<TooltipTrigger asChild>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
data-id="save-control-popover-trigger"
|
||||
data-testid="blocks-control-save-button"
|
||||
name="Save"
|
||||
>
|
||||
<IconSave className="dark:text-gray-300" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="right">Save</TooltipContent>
|
||||
</Tooltip>
|
||||
<PopoverContent
|
||||
side="right"
|
||||
sideOffset={15}
|
||||
align="start"
|
||||
data-id="save-control-popover-content"
|
||||
className="w-96 max-w-[400px]"
|
||||
>
|
||||
<Card className="border-none shadow-none dark:bg-slate-900">
|
||||
<CardContent className="p-4">
|
||||
<div className="space-y-3">
|
||||
<div>
|
||||
<Label htmlFor="name" className="dark:text-gray-300">
|
||||
Name
|
||||
</Label>
|
||||
<Input
|
||||
id="name"
|
||||
placeholder="Enter your agent name"
|
||||
value={agentName}
|
||||
onChange={(e) => onNameChange(e.target.value)}
|
||||
data-id="save-control-name-input"
|
||||
data-testid="save-control-name-input"
|
||||
maxLength={100}
|
||||
className="mt-1"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label htmlFor="description" className="dark:text-gray-300">
|
||||
Description
|
||||
</Label>
|
||||
<Input
|
||||
id="description"
|
||||
placeholder="Your agent description"
|
||||
value={agentDescription}
|
||||
onChange={(e) => onDescriptionChange(e.target.value)}
|
||||
data-id="save-control-description-input"
|
||||
data-testid="save-control-description-input"
|
||||
maxLength={500}
|
||||
className="mt-1"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label className="dark:text-gray-300">
|
||||
Recommended Schedule
|
||||
</Label>
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={() => setCronScheduleDialogOpen(true)}
|
||||
className="mt-1 w-full min-w-0 justify-start text-sm"
|
||||
data-id="save-control-recommended-schedule-button"
|
||||
data-testid="save-control-recommended-schedule-button"
|
||||
>
|
||||
<CalendarClockIcon className="mr-2 h-4 w-4 flex-shrink-0" />
|
||||
<span className="min-w-0 flex-1 truncate">
|
||||
{agentRecommendedScheduleCron
|
||||
? humanizeCronExpression(agentRecommendedScheduleCron)
|
||||
: "Set schedule"}
|
||||
</span>
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{agentMeta?.version && (
|
||||
<div>
|
||||
<Label htmlFor="version" className="dark:text-gray-300">
|
||||
Version
|
||||
</Label>
|
||||
<Input
|
||||
id="version"
|
||||
placeholder="Version"
|
||||
value={agentMeta?.version || "-"}
|
||||
disabled
|
||||
data-testid="save-control-version-output"
|
||||
className="mt-1"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</CardContent>
|
||||
<CardFooter className="flex flex-col items-stretch gap-2">
|
||||
<Button
|
||||
className="w-full dark:bg-slate-700 dark:text-slate-100 dark:hover:bg-slate-800"
|
||||
onClick={onSave}
|
||||
data-id="save-control-save-agent"
|
||||
data-testid="save-control-save-agent-button"
|
||||
disabled={!canSave}
|
||||
>
|
||||
Save Agent
|
||||
</Button>
|
||||
</CardFooter>
|
||||
</Card>
|
||||
</PopoverContent>
|
||||
<CronExpressionDialog
|
||||
open={cronScheduleDialogOpen}
|
||||
setOpen={setCronScheduleDialogOpen}
|
||||
onSubmit={handleScheduleChange}
|
||||
defaultCronExpression={agentRecommendedScheduleCron}
|
||||
title="Recommended Schedule"
|
||||
/>
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
@@ -1,95 +0,0 @@
|
||||
import { CustomNodeData } from "./CustomNode/CustomNode";
|
||||
import { CustomEdgeData } from "./CustomEdge/CustomEdge";
|
||||
import { Edge } from "@xyflow/react";
|
||||
|
||||
type ActionType =
|
||||
| "ADD_NODE"
|
||||
| "DELETE_NODE"
|
||||
| "ADD_EDGE"
|
||||
| "DELETE_EDGE"
|
||||
| "UPDATE_NODE"
|
||||
| "MOVE_NODE"
|
||||
| "UPDATE_INPUT"
|
||||
| "UPDATE_NODE_POSITION";
|
||||
|
||||
type AddNodePayload = { node: CustomNodeData };
|
||||
type DeleteNodePayload = { nodeId: string };
|
||||
type AddEdgePayload = { edge: Edge<CustomEdgeData> };
|
||||
type DeleteEdgePayload = { edgeId: string };
|
||||
type UpdateNodePayload = { nodeId: string; newData: Partial<CustomNodeData> };
|
||||
type MoveNodePayload = { nodeId: string; position: { x: number; y: number } };
|
||||
type UpdateInputPayload = {
|
||||
nodeId: string;
|
||||
oldValues: { [key: string]: any };
|
||||
newValues: { [key: string]: any };
|
||||
};
|
||||
type UpdateNodePositionPayload = {
|
||||
nodeId: string;
|
||||
oldPosition: { x: number; y: number };
|
||||
newPosition: { x: number; y: number };
|
||||
};
|
||||
|
||||
type ActionPayload =
|
||||
| AddNodePayload
|
||||
| DeleteNodePayload
|
||||
| AddEdgePayload
|
||||
| DeleteEdgePayload
|
||||
| UpdateNodePayload
|
||||
| MoveNodePayload
|
||||
| UpdateInputPayload
|
||||
| UpdateNodePositionPayload;
|
||||
|
||||
type Action = {
|
||||
type: ActionType;
|
||||
payload: ActionPayload;
|
||||
undo: () => void;
|
||||
redo: () => void;
|
||||
};
|
||||
|
||||
class History {
|
||||
private past: Action[] = [];
|
||||
private future: Action[] = [];
|
||||
|
||||
push(action: Action) {
|
||||
this.past.push(action);
|
||||
this.future = [];
|
||||
}
|
||||
|
||||
undo() {
|
||||
const action = this.past.pop();
|
||||
if (action) {
|
||||
action.undo();
|
||||
this.future.push(action);
|
||||
}
|
||||
}
|
||||
|
||||
redo() {
|
||||
const action = this.future.pop();
|
||||
if (action) {
|
||||
action.redo();
|
||||
this.past.push(action);
|
||||
}
|
||||
}
|
||||
|
||||
canUndo(): boolean {
|
||||
return this.past.length > 0;
|
||||
}
|
||||
|
||||
canRedo(): boolean {
|
||||
return this.future.length > 0;
|
||||
}
|
||||
|
||||
clear() {
|
||||
this.past = [];
|
||||
this.future = [];
|
||||
}
|
||||
|
||||
getHistoryState() {
|
||||
return {
|
||||
past: [...this.past],
|
||||
future: [...this.future],
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export const history = new History();
|
||||
@@ -1,569 +0,0 @@
|
||||
import Shepherd from "shepherd.js";
|
||||
import "shepherd.js/dist/css/shepherd.css";
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import { analytics } from "@/services/analytics";
|
||||
|
||||
export const startTutorial = (
|
||||
emptyNodeList: (forceEmpty: boolean) => boolean,
|
||||
setPinBlocksPopover: (value: boolean) => void,
|
||||
setPinSavePopover: (value: boolean) => void,
|
||||
) => {
|
||||
const tour = new Shepherd.Tour({
|
||||
useModalOverlay: true,
|
||||
defaultStepOptions: {
|
||||
cancelIcon: { enabled: true },
|
||||
scrollTo: { behavior: "smooth", block: "center" },
|
||||
},
|
||||
});
|
||||
|
||||
// CSS classes for disabling and highlighting blocks
|
||||
const disableClass = "disable-blocks";
|
||||
const highlightClass = "highlight-block";
|
||||
let isConnecting = false;
|
||||
|
||||
// Helper function to disable all blocks except the target block
|
||||
const disableOtherBlocks = (targetBlockSelector: string) => {
|
||||
document.querySelectorAll('[data-id^="block-card-"]').forEach((block) => {
|
||||
block.classList.toggle(disableClass, !block.matches(targetBlockSelector));
|
||||
block.classList.toggle(
|
||||
highlightClass,
|
||||
block.matches(targetBlockSelector),
|
||||
);
|
||||
});
|
||||
};
|
||||
|
||||
// Helper function to enable all blocks
|
||||
const enableAllBlocks = () => {
|
||||
document.querySelectorAll('[data-id^="block-card-"]').forEach((block) => {
|
||||
block.classList.remove(disableClass, highlightClass);
|
||||
});
|
||||
};
|
||||
|
||||
// Inject CSS for disabling and highlighting blocks
|
||||
const injectStyles = () => {
|
||||
const style = document.createElement("style");
|
||||
style.textContent = `
|
||||
.${disableClass} {
|
||||
pointer-events: none;
|
||||
opacity: 0.5;
|
||||
}
|
||||
.${highlightClass} {
|
||||
background-color: #ffeb3b;
|
||||
border: 2px solid #fbc02d;
|
||||
transition: background-color 0.3s, border-color 0.3s;
|
||||
}
|
||||
`;
|
||||
document.head.appendChild(style);
|
||||
};
|
||||
|
||||
// Helper function to check if an element is present in the DOM
|
||||
const waitForElement = (selector: string): Promise<void> => {
|
||||
return new Promise((resolve) => {
|
||||
const checkElement = () => {
|
||||
if (document.querySelector(selector)) {
|
||||
resolve();
|
||||
} else {
|
||||
setTimeout(checkElement, 10);
|
||||
}
|
||||
};
|
||||
checkElement();
|
||||
});
|
||||
};
|
||||
|
||||
// Function to detect the correct connection and advance the tour
|
||||
const detectConnection = () => {
|
||||
const checkForConnection = () => {
|
||||
const correctConnection = document.querySelector(
|
||||
'[data-testid^="rf__edge-"]',
|
||||
);
|
||||
if (correctConnection) {
|
||||
tour.show("press-run-again");
|
||||
} else {
|
||||
setTimeout(checkForConnection, 100);
|
||||
}
|
||||
};
|
||||
|
||||
checkForConnection();
|
||||
};
|
||||
|
||||
// Define state management functions to handle connection state
|
||||
function startConnecting() {
|
||||
isConnecting = true;
|
||||
}
|
||||
|
||||
function stopConnecting() {
|
||||
isConnecting = false;
|
||||
}
|
||||
|
||||
// Reset connection state when revisiting the step
|
||||
function resetConnectionState() {
|
||||
stopConnecting();
|
||||
}
|
||||
|
||||
// Event handlers for mouse down and up to manage connection state
|
||||
function handleMouseDown() {
|
||||
startConnecting();
|
||||
setTimeout(() => {
|
||||
if (isConnecting) {
|
||||
tour.next();
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
// Event handler for mouse up to check if the connection was successful
|
||||
function handleMouseUp(event: { target: any }) {
|
||||
const target = event.target;
|
||||
const validConnectionPoint = document.querySelector(
|
||||
'[data-testid^="rf__node-"]:nth-child(2) [data-id$="-a-target"]',
|
||||
);
|
||||
|
||||
if (validConnectionPoint && !validConnectionPoint.contains(target)) {
|
||||
setTimeout(() => {
|
||||
if (!document.querySelector('[data-testid^="rf__edge-"]')) {
|
||||
stopConnecting();
|
||||
tour.show("connect-blocks-output");
|
||||
}
|
||||
}, 200);
|
||||
} else {
|
||||
stopConnecting();
|
||||
}
|
||||
}
|
||||
|
||||
// Define the fitViewToScreen function
|
||||
const fitViewToScreen = () => {
|
||||
const fitViewButton = document.querySelector(
|
||||
".react-flow__controls-fitview",
|
||||
) as HTMLButtonElement;
|
||||
if (fitViewButton) {
|
||||
fitViewButton.click();
|
||||
}
|
||||
};
|
||||
|
||||
injectStyles();
|
||||
|
||||
const warningText = emptyNodeList(false)
|
||||
? ""
|
||||
: "<br/><br/><b>Caution: Clicking next will start a tutorial and will clear the current flow.</b>";
|
||||
|
||||
tour.addStep({
|
||||
id: "starting-step",
|
||||
title: "Welcome to the Tutorial",
|
||||
text: `This is the AutoGPT builder! ${warningText}`,
|
||||
buttons: [
|
||||
{
|
||||
text: "Skip Tutorial",
|
||||
action: () => {
|
||||
tour.cancel(); // Ends the tour
|
||||
storage.set(Key.SHEPHERD_TOUR, "skipped"); // Set the tutorial as skipped in local storage
|
||||
},
|
||||
classes: "shepherd-button-secondary", // Optionally add a class for styling the skip button differently
|
||||
},
|
||||
{
|
||||
text: "Next",
|
||||
action: () => {
|
||||
emptyNodeList(true);
|
||||
tour.next();
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "open-block-step",
|
||||
title: "Open Blocks Menu",
|
||||
text: "Please click the block button to open the blocks menu.",
|
||||
attachTo: {
|
||||
element: '[data-id="blocks-control-popover-trigger"]',
|
||||
on: "right",
|
||||
},
|
||||
advanceOn: {
|
||||
selector: '[data-id="blocks-control-popover-trigger"]',
|
||||
event: "click",
|
||||
},
|
||||
buttons: [],
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "scroll-block-menu",
|
||||
title: "Scroll Down or Search",
|
||||
text: 'Scroll down or search in the blocks menu for the "Calculator Block" and press the block to add it.',
|
||||
attachTo: {
|
||||
element: '[data-id="blocks-control-popover-content"]',
|
||||
on: "right",
|
||||
},
|
||||
buttons: [],
|
||||
beforeShowPromise: () =>
|
||||
waitForElement('[data-id="blocks-control-popover-content"]').then(() => {
|
||||
disableOtherBlocks(
|
||||
'[data-id="block-card-b1ab9b19-67a6-406d-abf5-2dba76d00c79"]',
|
||||
);
|
||||
}),
|
||||
advanceOn: {
|
||||
selector: '[data-id="block-card-b1ab9b19-67a6-406d-abf5-2dba76d00c79"]',
|
||||
event: "click",
|
||||
},
|
||||
when: {
|
||||
show: () => setPinBlocksPopover(true),
|
||||
hide: enableAllBlocks,
|
||||
},
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "focus-new-block",
|
||||
title: "New Block",
|
||||
text: "This is the Calculator Block! Let's go over how it works.",
|
||||
attachTo: { element: `[data-id="custom-node-1"]`, on: "left" },
|
||||
beforeShowPromise: () => waitForElement('[data-id="custom-node-1"]'),
|
||||
buttons: [
|
||||
{
|
||||
text: "Next",
|
||||
action: tour.next,
|
||||
},
|
||||
],
|
||||
when: {
|
||||
show: () => {
|
||||
setPinBlocksPopover(false);
|
||||
setTimeout(() => {
|
||||
fitViewToScreen();
|
||||
}, 100);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "input-to-block",
|
||||
title: "Input to the Block",
|
||||
text: "This is the input pin for the block. You can input the output of other blocks here; this block takes numbers as input.",
|
||||
attachTo: { element: '[data-nodeid="1"]', on: "left" },
|
||||
buttons: [
|
||||
{
|
||||
text: "Back",
|
||||
action: tour.back,
|
||||
},
|
||||
{
|
||||
text: "Next",
|
||||
action: tour.next,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "output-from-block",
|
||||
title: "Output from the Block",
|
||||
text: "This is the output pin for the block. You can connect this to another block to pass the output along.",
|
||||
attachTo: { element: '[data-handlepos="right"]', on: "right" },
|
||||
buttons: [
|
||||
{
|
||||
text: "Back",
|
||||
action: tour.back,
|
||||
},
|
||||
{
|
||||
text: "Next",
|
||||
action: tour.next,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "select-operation-and-input",
|
||||
title: "Select Operation and Input Numbers",
|
||||
text: "Select any mathematical operation you'd like to perform, and enter numbers in both input fields.",
|
||||
attachTo: { element: '[data-id="input-handles"]', on: "right" },
|
||||
buttons: [
|
||||
{
|
||||
text: "Back",
|
||||
action: tour.back,
|
||||
},
|
||||
{
|
||||
text: "Next",
|
||||
action: tour.next,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "press-initial-save-button",
|
||||
title: "Press Save",
|
||||
text: "First we need to save the flow before we can run it!",
|
||||
attachTo: {
|
||||
element: '[data-id="save-control-popover-trigger"]',
|
||||
on: "left",
|
||||
},
|
||||
advanceOn: {
|
||||
selector: '[data-id="save-control-popover-trigger"]',
|
||||
event: "click",
|
||||
},
|
||||
buttons: [
|
||||
{
|
||||
text: "Back",
|
||||
action: tour.back,
|
||||
},
|
||||
],
|
||||
when: {
|
||||
hide: () => setPinSavePopover(true),
|
||||
},
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "save-agent-details",
|
||||
title: "Save the Agent",
|
||||
text: "Enter a name for your agent, add an optional description, and then click 'Save agent' to save your flow.",
|
||||
attachTo: {
|
||||
element: '[data-id="save-control-popover-content"]',
|
||||
on: "top",
|
||||
},
|
||||
buttons: [],
|
||||
beforeShowPromise: () =>
|
||||
waitForElement('[data-id="save-control-popover-content"]'),
|
||||
advanceOn: {
|
||||
selector: '[data-id="save-control-save-agent"]',
|
||||
event: "click",
|
||||
},
|
||||
when: {
|
||||
hide: () => setPinSavePopover(false),
|
||||
},
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "press-run",
|
||||
title: "Press Run",
|
||||
text: "Start your first flow by pressing the Run button!",
|
||||
attachTo: {
|
||||
element: '[data-tutorial-id="primary-action-run-agent"]',
|
||||
on: "top",
|
||||
},
|
||||
advanceOn: {
|
||||
selector: '[data-tutorial-id="primary-action-run-agent"]',
|
||||
event: "click",
|
||||
},
|
||||
buttons: [],
|
||||
beforeShowPromise: () =>
|
||||
waitForElement('[data-tutorial-id="primary-action-run-agent"]'),
|
||||
when: {
|
||||
hide: () => {
|
||||
setTimeout(() => {
|
||||
fitViewToScreen();
|
||||
}, 500);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "wait-for-processing",
|
||||
title: "Processing",
|
||||
text: "Let's wait for the block to finish being processed...",
|
||||
attachTo: {
|
||||
element: '[data-id^="badge-"][data-id$="-QUEUED"]',
|
||||
on: "bottom",
|
||||
},
|
||||
buttons: [],
|
||||
beforeShowPromise: () =>
|
||||
waitForElement('[data-id^="badge-"][data-id$="-QUEUED"]').then(
|
||||
fitViewToScreen,
|
||||
),
|
||||
when: {
|
||||
show: () => {
|
||||
waitForElement('[data-id^="badge-"][data-id$="-COMPLETED"]').then(
|
||||
() => {
|
||||
tour.next();
|
||||
},
|
||||
);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "check-output",
|
||||
title: "Check the Output",
|
||||
text: "Check here to see the output of the block after running the flow.",
|
||||
attachTo: { element: '[data-id="latest-output"]', on: "top" },
|
||||
beforeShowPromise: () =>
|
||||
new Promise((resolve) => {
|
||||
setTimeout(() => {
|
||||
waitForElement('[data-id="latest-output"]').then(resolve);
|
||||
}, 100);
|
||||
}),
|
||||
buttons: [
|
||||
{
|
||||
text: "Next",
|
||||
action: tour.next,
|
||||
},
|
||||
],
|
||||
when: {
|
||||
show: () => {
|
||||
fitViewToScreen();
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "copy-paste-block",
|
||||
title: "Copy and Paste the Block",
|
||||
text: "Let’s duplicate this block. Click and hold the block with your mouse, then press Ctrl+C (Cmd+C on Mac) to copy and Ctrl+V (Cmd+V on Mac) to paste.",
|
||||
attachTo: { element: '[data-testid^="rf__node-"]', on: "top" },
|
||||
buttons: [
|
||||
{
|
||||
text: "Back",
|
||||
action: tour.back,
|
||||
},
|
||||
],
|
||||
when: {
|
||||
show: () => {
|
||||
fitViewToScreen();
|
||||
waitForElement('[data-testid^="rf__node-"]:nth-child(2)').then(() => {
|
||||
tour.next();
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "focus-second-block",
|
||||
title: "Focus on the New Block",
|
||||
text: "This is your copied Calculator Block. Now, let’s move it to the side of the first block.",
|
||||
attachTo: { element: '[data-testid^="rf__node-"]:nth-child(2)', on: "top" },
|
||||
beforeShowPromise: () =>
|
||||
waitForElement('[data-testid^="rf__node-"]:nth-child(2)'),
|
||||
buttons: [
|
||||
{
|
||||
text: "Next",
|
||||
action: tour.next,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "connect-blocks-output",
|
||||
title: "Connect the Blocks: Output",
|
||||
text: "Now, let's connect the output of the first Calculator Block to the input of the second Calculator Block. Drag from the output pin of the first block to the input pin (A) of the second block.",
|
||||
attachTo: {
|
||||
element:
|
||||
'[data-testid^="rf__node-"]:first-child [data-id$="-result-source"]',
|
||||
on: "bottom",
|
||||
},
|
||||
|
||||
buttons: [
|
||||
{
|
||||
text: "Back",
|
||||
action: tour.back,
|
||||
},
|
||||
],
|
||||
beforeShowPromise: () => {
|
||||
return waitForElement(
|
||||
'[data-testid^="rf__node-"]:first-child [data-id$="-result-source"]',
|
||||
);
|
||||
},
|
||||
when: {
|
||||
show: () => {
|
||||
fitViewToScreen();
|
||||
resetConnectionState(); // Reset state when revisiting this step
|
||||
tour.modal.show();
|
||||
const outputPin = document.querySelector(
|
||||
'[data-testid^="rf__node-"]:first-child [data-id$="-result-source"]',
|
||||
);
|
||||
if (outputPin) {
|
||||
outputPin.addEventListener("mousedown", handleMouseDown);
|
||||
}
|
||||
},
|
||||
hide: () => {
|
||||
const outputPin = document.querySelector(
|
||||
'[data-testid^="rf__node-"]:first-child [data-id$="-result-source"]',
|
||||
);
|
||||
if (outputPin) {
|
||||
outputPin.removeEventListener("mousedown", handleMouseDown);
|
||||
}
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "connect-blocks-input",
|
||||
title: "Connect the Blocks: Input",
|
||||
text: "Now, connect the output to the input pin of the second block (A).",
|
||||
attachTo: {
|
||||
element: '[data-testid^="rf__node-"]:nth-child(2) [data-id$="-a-target"]',
|
||||
on: "top",
|
||||
},
|
||||
buttons: [],
|
||||
beforeShowPromise: () => {
|
||||
return waitForElement(
|
||||
'[data-testid^="rf__node-"]:nth-child(2) [data-id$="-a-target"]',
|
||||
).then(() => {
|
||||
detectConnection();
|
||||
});
|
||||
},
|
||||
when: {
|
||||
show: () => {
|
||||
tour.modal.show();
|
||||
document.addEventListener("mouseup", handleMouseUp, true);
|
||||
},
|
||||
hide: () => {
|
||||
tour.modal.hide();
|
||||
document.removeEventListener("mouseup", handleMouseUp, true);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "press-run-again",
|
||||
title: "Press Run Again",
|
||||
text: "Now, press the Run button again to execute the flow with the new Calculator Block added!",
|
||||
attachTo: {
|
||||
element: '[data-tutorial-id="primary-action-run-agent"]',
|
||||
on: "top",
|
||||
},
|
||||
advanceOn: {
|
||||
selector: '[data-tutorial-id="primary-action-run-agent"]',
|
||||
event: "click",
|
||||
},
|
||||
buttons: [],
|
||||
beforeShowPromise: () =>
|
||||
waitForElement('[data-tutorial-id="primary-action-run-agent"]'),
|
||||
when: {
|
||||
hide: () => {
|
||||
setTimeout(() => {
|
||||
fitViewToScreen();
|
||||
}, 500);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
tour.addStep({
|
||||
id: "congratulations",
|
||||
title: "Congratulations!",
|
||||
text: "You have successfully created your first flow. Watch for the outputs in the blocks!",
|
||||
beforeShowPromise: () => waitForElement('[data-id="latest-output"]'),
|
||||
when: {
|
||||
show: () => tour.modal.hide(),
|
||||
},
|
||||
buttons: [
|
||||
{
|
||||
text: "Finish",
|
||||
action: tour.complete,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
// Unpin blocks and save menu when the tour is completed or canceled
|
||||
tour.on("complete", () => {
|
||||
setPinBlocksPopover(false);
|
||||
setPinSavePopover(false);
|
||||
storage.set(Key.SHEPHERD_TOUR, "completed"); // Optionally mark the tutorial as completed
|
||||
});
|
||||
|
||||
for (const step of tour.steps) {
|
||||
step.on("show", () => {
|
||||
"use client";
|
||||
console.debug("sendTutorialStep");
|
||||
|
||||
analytics.sendGAEvent("event", "tutorial_step_shown", { value: step.id });
|
||||
});
|
||||
}
|
||||
|
||||
tour.on("cancel", () => {
|
||||
setPinBlocksPopover(false);
|
||||
setPinSavePopover(false);
|
||||
storage.set(Key.SHEPHERD_TOUR, "canceled"); // Optionally mark the tutorial as canceled
|
||||
});
|
||||
|
||||
tour.start();
|
||||
};
|
||||
@@ -1,142 +0,0 @@
|
||||
import { useCallback } from "react";
|
||||
import { Node, Edge, useReactFlow } from "@xyflow/react";
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import { ConnectedEdge } from "./CustomNode/CustomNode";
|
||||
|
||||
interface CopyableData {
|
||||
nodes: Node[];
|
||||
edges: Edge[];
|
||||
}
|
||||
|
||||
export function useCopyPaste(getNextNodeId: () => string) {
|
||||
const { setNodes, addEdges, getNodes, getEdges, getViewport } =
|
||||
useReactFlow();
|
||||
|
||||
const handleCopyPaste = useCallback(
|
||||
(event: KeyboardEvent) => {
|
||||
if (event.ctrlKey || event.metaKey) {
|
||||
if (event.key === "c" || event.key === "C") {
|
||||
const selectedNodes = getNodes().filter((node) => node.selected);
|
||||
const selectedNodeIds = new Set(selectedNodes.map((node) => node.id));
|
||||
|
||||
// Only copy edges where both source and target nodes are selected
|
||||
const selectedEdges = getEdges().filter(
|
||||
(edge) =>
|
||||
edge.selected &&
|
||||
selectedNodeIds.has(edge.source) &&
|
||||
selectedNodeIds.has(edge.target),
|
||||
);
|
||||
|
||||
const copiedData: CopyableData = {
|
||||
nodes: selectedNodes.map((node) => ({
|
||||
...node,
|
||||
data: {
|
||||
...node.data,
|
||||
connections: node.data.connections || [], // Preserve connections
|
||||
},
|
||||
})),
|
||||
edges: selectedEdges,
|
||||
};
|
||||
|
||||
storage.set(Key.COPIED_FLOW_DATA, JSON.stringify(copiedData));
|
||||
}
|
||||
if (event.key === "v" || event.key === "V") {
|
||||
const copiedDataString = storage.get(Key.COPIED_FLOW_DATA);
|
||||
if (copiedDataString) {
|
||||
const copiedData = JSON.parse(copiedDataString) as CopyableData;
|
||||
const oldToNewIdMap: Record<string, string> = {};
|
||||
|
||||
// Get fresh viewport values at paste time to ensure correct positioning
|
||||
const { x, y, zoom } = getViewport();
|
||||
const viewportCenter = {
|
||||
x: (window.innerWidth / 2 - x) / zoom,
|
||||
y: (window.innerHeight / 2 - y) / zoom,
|
||||
};
|
||||
|
||||
let minX = Infinity,
|
||||
minY = Infinity,
|
||||
maxX = -Infinity,
|
||||
maxY = -Infinity;
|
||||
copiedData.nodes.forEach((node: Node) => {
|
||||
minX = Math.min(minX, node.position.x);
|
||||
minY = Math.min(minY, node.position.y);
|
||||
maxX = Math.max(maxX, node.position.x);
|
||||
maxY = Math.max(maxY, node.position.y);
|
||||
});
|
||||
|
||||
const offsetX = viewportCenter.x - (minX + maxX) / 2;
|
||||
const offsetY = viewportCenter.y - (minY + maxY) / 2;
|
||||
|
||||
const pastedNodes = copiedData.nodes.map((node: Node) => {
|
||||
const newNodeId = getNextNodeId();
|
||||
oldToNewIdMap[node.id] = newNodeId;
|
||||
return {
|
||||
...node,
|
||||
id: newNodeId, // Generate unique ID for the pasted node
|
||||
selected: true, // Select the pasted nodes so they're visible
|
||||
position: {
|
||||
x: node.position.x + offsetX,
|
||||
y: node.position.y + offsetY,
|
||||
},
|
||||
data: {
|
||||
...node.data,
|
||||
backend_id: undefined, // Clear backend_id so the new node.id is used when saving
|
||||
connections: node.data.connections || [], // Preserve connections
|
||||
status: undefined,
|
||||
executionResults: undefined,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
const pastedEdges = copiedData.edges.map((edge) => {
|
||||
const newSourceId = oldToNewIdMap[edge.source] ?? edge.source;
|
||||
const newTargetId = oldToNewIdMap[edge.target] ?? edge.target;
|
||||
return {
|
||||
...edge,
|
||||
id: `${newSourceId}_${edge.sourceHandle}_${newTargetId}_${edge.targetHandle}_${Date.now()}`,
|
||||
source: newSourceId,
|
||||
target: newTargetId,
|
||||
};
|
||||
});
|
||||
|
||||
setNodes((existingNodes) => [
|
||||
...existingNodes.map((node) => ({ ...node, selected: false })),
|
||||
...pastedNodes,
|
||||
]);
|
||||
addEdges(pastedEdges);
|
||||
|
||||
setNodes((nodes) => {
|
||||
return nodes.map((node) => {
|
||||
const nodeConnections = getEdges()
|
||||
.filter(
|
||||
(edge: Edge) =>
|
||||
edge.source === node.id || edge.target === node.id,
|
||||
)
|
||||
.map(
|
||||
(edge: Edge): ConnectedEdge => ({
|
||||
id: edge.id,
|
||||
source: edge.source,
|
||||
target: edge.target,
|
||||
sourceHandle: edge.sourceHandle!,
|
||||
targetHandle: edge.targetHandle!,
|
||||
}),
|
||||
);
|
||||
|
||||
return {
|
||||
...node,
|
||||
data: {
|
||||
...node.data,
|
||||
connections: nodeConnections,
|
||||
},
|
||||
};
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
[setNodes, addEdges, getNodes, getEdges, getNextNodeId, getViewport],
|
||||
);
|
||||
|
||||
return handleCopyPaste;
|
||||
}
|
||||
@@ -27,6 +27,7 @@ export function CopilotPage() {
|
||||
createSession,
|
||||
onSend,
|
||||
isLoadingSession,
|
||||
isSessionError,
|
||||
isCreatingSession,
|
||||
isUserLoading,
|
||||
isLoggedIn,
|
||||
@@ -71,6 +72,7 @@ export function CopilotPage() {
|
||||
error={error}
|
||||
sessionId={sessionId}
|
||||
isLoadingSession={isLoadingSession}
|
||||
isSessionError={isSessionError}
|
||||
isCreatingSession={isCreatingSession}
|
||||
isReconnecting={isReconnecting}
|
||||
onCreateSession={createSession}
|
||||
|
||||
@@ -13,6 +13,7 @@ export interface ChatContainerProps {
|
||||
error: Error | undefined;
|
||||
sessionId: string | null;
|
||||
isLoadingSession: boolean;
|
||||
isSessionError?: boolean;
|
||||
isCreatingSession: boolean;
|
||||
/** True when backend has an active stream but we haven't reconnected yet. */
|
||||
isReconnecting?: boolean;
|
||||
@@ -27,6 +28,7 @@ export const ChatContainer = ({
|
||||
error,
|
||||
sessionId,
|
||||
isLoadingSession,
|
||||
isSessionError,
|
||||
isCreatingSession,
|
||||
isReconnecting,
|
||||
onCreateSession,
|
||||
@@ -34,7 +36,12 @@ export const ChatContainer = ({
|
||||
onStop,
|
||||
headerSlot,
|
||||
}: ChatContainerProps) => {
|
||||
const isBusy = status === "streaming" || !!isReconnecting;
|
||||
const isBusy =
|
||||
status === "streaming" ||
|
||||
status === "submitted" ||
|
||||
!!isReconnecting ||
|
||||
isLoadingSession ||
|
||||
!!isSessionError;
|
||||
const inputLayoutId = "copilot-2-chat-input";
|
||||
|
||||
return (
|
||||
|
||||
@@ -10,9 +10,8 @@ import {
|
||||
MessageResponse,
|
||||
} from "@/components/ai-elements/message";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { CreateAgentTool } from "../../tools/CreateAgent/CreateAgent";
|
||||
import { EditAgentTool } from "../../tools/EditAgent/EditAgent";
|
||||
import {
|
||||
@@ -129,7 +128,6 @@ export const ChatMessagesContainer = ({
|
||||
headerSlot,
|
||||
}: ChatMessagesContainerProps) => {
|
||||
const [thinkingPhrase, setThinkingPhrase] = useState(getRandomPhrase);
|
||||
const lastToastTimeRef = useRef(0);
|
||||
|
||||
useEffect(() => {
|
||||
if (status === "submitted") {
|
||||
@@ -137,20 +135,6 @@ export const ChatMessagesContainer = ({
|
||||
}
|
||||
}, [status]);
|
||||
|
||||
// Show a toast when a new error occurs, debounced to avoid spam
|
||||
useEffect(() => {
|
||||
if (!error) return;
|
||||
const now = Date.now();
|
||||
if (now - lastToastTimeRef.current < 3_000) return;
|
||||
lastToastTimeRef.current = now;
|
||||
toast({
|
||||
variant: "destructive",
|
||||
title: "Something went wrong",
|
||||
description:
|
||||
"The assistant encountered an error. Please try sending your message again.",
|
||||
});
|
||||
}, [error]);
|
||||
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
const lastAssistantHasVisibleContent =
|
||||
lastMessage?.role === "assistant" &&
|
||||
@@ -314,13 +298,15 @@ export const ChatMessagesContainer = ({
|
||||
</Message>
|
||||
)}
|
||||
{error && (
|
||||
<div className="rounded-lg bg-red-50 p-4 text-sm text-red-700">
|
||||
<p className="font-medium">Something went wrong</p>
|
||||
<p className="mt-1 text-red-600">
|
||||
<details className="rounded-lg bg-red-50 p-4 text-sm text-red-700">
|
||||
<summary className="cursor-pointer font-medium">
|
||||
The assistant encountered an error. Please try sending your
|
||||
message again.
|
||||
</p>
|
||||
</div>
|
||||
</summary>
|
||||
<pre className="mt-2 max-h-40 overflow-auto whitespace-pre-wrap break-words text-xs text-red-600">
|
||||
{error instanceof Error ? error.message : String(error)}
|
||||
</pre>
|
||||
</details>
|
||||
)}
|
||||
</ConversationContent>
|
||||
<ConversationScrollButton />
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user