mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-24 03:00:28 -05:00
Compare commits
67 Commits
fix/agent-
...
fix/copilo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f95ced2b7b | ||
|
|
4e9d3aae48 | ||
|
|
e073665d26 | ||
|
|
76a924fdef | ||
|
|
b30138664d | ||
|
|
f10615c915 | ||
|
|
f6973ec3b2 | ||
|
|
5195a77ed8 | ||
|
|
6fe5b9942e | ||
|
|
da43c214be | ||
|
|
2406e21162 | ||
|
|
c388985c75 | ||
|
|
392d1c3e48 | ||
|
|
e5bdc7b8a9 | ||
|
|
8c75b5018f | ||
|
|
bf8c2eb5c0 | ||
|
|
db281f1bd0 | ||
|
|
2663ac2377 | ||
|
|
f9596b63c7 | ||
|
|
4088b2ba63 | ||
|
|
004c2d4201 | ||
|
|
f4ccdc3e9e | ||
|
|
ef42b17e3b | ||
|
|
df14b031b2 | ||
|
|
dd047d65a8 | ||
|
|
d864ab6908 | ||
|
|
5751bf35a4 | ||
|
|
b02445e755 | ||
|
|
0b3539ca89 | ||
|
|
1d042c4f10 | ||
|
|
082389d77f | ||
|
|
a18ffd0b21 | ||
|
|
e44154ba3d | ||
|
|
046a04ec19 | ||
|
|
77de834b33 | ||
|
|
6d728a0f9c | ||
|
|
7b10edb820 | ||
|
|
3c4fcf7972 | ||
|
|
d9bce8b702 | ||
|
|
171d7cfd06 | ||
|
|
942ceeb19b | ||
|
|
f4f607a865 | ||
|
|
492eb5edd6 | ||
|
|
b5c2d49260 | ||
|
|
dd10e1b339 | ||
|
|
5cca95b78c | ||
|
|
82e265db7b | ||
|
|
ad19fd7e2d | ||
|
|
dc4a994d5c | ||
|
|
3aca0d6e3b | ||
|
|
20dea80fc9 | ||
|
|
bbfcfef20b | ||
|
|
9f53e82831 | ||
|
|
285b10dab8 | ||
|
|
351e404f46 | ||
|
|
818f0daca0 | ||
|
|
48c2dcbd7f | ||
|
|
9229bc6751 | ||
|
|
8e63a1d38b | ||
|
|
73ec753b86 | ||
|
|
e781a87570 | ||
|
|
f2c9471f3e | ||
|
|
195ebf2a1a | ||
|
|
b3f34b3319 | ||
|
|
d5115de6dd | ||
|
|
e40c8c70ce | ||
|
|
9cdcd6793f |
@@ -2,21 +2,17 @@
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid as uuid_module
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
from uuid import uuid4
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.completion_handler import (
|
||||
process_operation_failure,
|
||||
process_operation_success,
|
||||
)
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_task
|
||||
from backend.copilot.model import (
|
||||
@@ -46,9 +42,6 @@ from backend.copilot.tools.models import (
|
||||
InputValidationErrorResponse,
|
||||
NeedLoginResponse,
|
||||
NoResultsResponse,
|
||||
OperationInProgressResponse,
|
||||
OperationPendingResponse,
|
||||
OperationStartedResponse,
|
||||
SetupRequirementsResponse,
|
||||
SuggestedGoalResponse,
|
||||
UnderstandingUpdatedResponse,
|
||||
@@ -99,10 +92,8 @@ class CreateSessionResponse(BaseModel):
|
||||
class ActiveStreamInfo(BaseModel):
|
||||
"""Information about an active stream for reconnection."""
|
||||
|
||||
task_id: str
|
||||
turn_id: str
|
||||
last_message_id: str # Redis Stream message ID for resumption
|
||||
operation_id: str # Operation ID for completion tracking
|
||||
tool_name: str # Name of the tool being executed
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
@@ -136,18 +127,9 @@ class CancelTaskResponse(BaseModel):
|
||||
"""Response model for the cancel task endpoint."""
|
||||
|
||||
cancelled: bool
|
||||
task_id: str | None = None
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class OperationCompleteRequest(BaseModel):
|
||||
"""Request model for external completion webhook."""
|
||||
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@@ -270,7 +252,7 @@ async def get_session(
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
If there's an active stream for this session, returns the task_id for reconnection.
|
||||
If there's an active stream for this session, returns active_stream info for reconnection.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
@@ -288,28 +270,21 @@ async def get_session(
|
||||
|
||||
# Check if there's an active stream for this session
|
||||
active_stream_info = None
|
||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
|
||||
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
|
||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||
)
|
||||
if active_task:
|
||||
# Filter out the in-progress assistant message from the session response.
|
||||
# The client will receive the complete assistant response through the SSE
|
||||
# stream replay instead, preventing duplicate content.
|
||||
if messages and messages[-1].get("role") == "assistant":
|
||||
messages = messages[:-1]
|
||||
|
||||
# Use "0-0" as last_message_id to replay the stream from the beginning.
|
||||
# Since we filtered out the cached assistant message, the client needs
|
||||
# the full stream to reconstruct the response.
|
||||
if active_session:
|
||||
# Keep the assistant message (including tool_calls) so the frontend can
|
||||
# render the correct tool UI (e.g. CreateAgent with mini game).
|
||||
# convertChatSessionToUiMessages handles isComplete=false by setting
|
||||
# tool parts without output to state "input-available".
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
task_id=active_task.task_id,
|
||||
last_message_id="0-0",
|
||||
operation_id=active_task.operation_id,
|
||||
tool_name=active_task.tool_name,
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
@@ -338,39 +313,32 @@ async def cancel_session_task(
|
||||
"""
|
||||
await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
active_task, _ = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
)
|
||||
if not active_task:
|
||||
active_session, _ = await stream_registry.get_active_session(session_id, user_id)
|
||||
if not active_session:
|
||||
return CancelTaskResponse(cancelled=False, reason="no_active_task")
|
||||
|
||||
task_id = active_task.task_id
|
||||
await enqueue_cancel_task(task_id)
|
||||
logger.info(
|
||||
f"[CANCEL] Published cancel for task ...{task_id[-8:]} "
|
||||
f"session ...{session_id[-8:]}"
|
||||
)
|
||||
await enqueue_cancel_task(session_id)
|
||||
logger.info(f"[CANCEL] Published cancel for session ...{session_id[-8:]}")
|
||||
|
||||
# Poll until the executor confirms the task is no longer running.
|
||||
# Keep max_wait below typical reverse-proxy read timeouts.
|
||||
poll_interval = 0.5
|
||||
max_wait = 5.0
|
||||
waited = 0.0
|
||||
while waited < max_wait:
|
||||
await asyncio.sleep(poll_interval)
|
||||
waited += poll_interval
|
||||
task = await stream_registry.get_task(task_id)
|
||||
task = await stream_registry.get_session(session_id)
|
||||
if task is None or task.status != "running":
|
||||
logger.info(
|
||||
f"[CANCEL] Task ...{task_id[-8:]} confirmed stopped "
|
||||
f"[CANCEL] Session ...{session_id[-8:]} confirmed stopped "
|
||||
f"(status={task.status if task else 'gone'}) after {waited:.1f}s"
|
||||
)
|
||||
return CancelTaskResponse(cancelled=True, task_id=task_id)
|
||||
return CancelTaskResponse(cancelled=True)
|
||||
|
||||
logger.warning(f"[CANCEL] Task ...{task_id[-8:]} not confirmed after {max_wait}s")
|
||||
return CancelTaskResponse(
|
||||
cancelled=True, task_id=task_id, reason="cancel_published_not_confirmed"
|
||||
logger.warning(
|
||||
f"[CANCEL] Session ...{session_id[-8:]} not confirmed after {max_wait}s"
|
||||
)
|
||||
return CancelTaskResponse(cancelled=True, reason="cancel_published_not_confirmed")
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -390,16 +358,15 @@ async def stream_chat_post(
|
||||
- Tool execution results
|
||||
|
||||
The AI generation runs in a background task that continues even if the client disconnects.
|
||||
All chunks are written to Redis for reconnection support. If the client disconnects,
|
||||
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
|
||||
All chunks are written to a per-turn Redis stream for reconnection support. If the client
|
||||
disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
|
||||
containing the task_id for reconnection.
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
@@ -446,21 +413,19 @@ async def stream_chat_post(
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
task_id = str(uuid_module.uuid4())
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
log_meta["task_id"] = task_id
|
||||
turn_id = str(uuid4())
|
||||
log_meta["turn_id"] = turn_id
|
||||
|
||||
task_create_start = time.perf_counter()
|
||||
await stream_registry.create_task(
|
||||
task_id=task_id,
|
||||
await stream_registry.create_session_task(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id="chat_stream", # Not a tool call, but needed for the model
|
||||
tool_call_id="chat_stream",
|
||||
tool_name="chat",
|
||||
operation_id=operation_id,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
|
||||
f"[TIMING] create_session_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
@@ -469,12 +434,14 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
|
||||
subscribe_from_id = "0-0"
|
||||
|
||||
await enqueue_copilot_task(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
operation_id=operation_id,
|
||||
message=request.message,
|
||||
turn_id=turn_id,
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
)
|
||||
@@ -491,7 +458,7 @@ async def stream_chat_post(
|
||||
|
||||
event_gen_start = time_module.perf_counter()
|
||||
logger.info(
|
||||
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
|
||||
f"[TIMING] event_generator STARTED, turn={turn_id}, session={session_id}, "
|
||||
f"user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
@@ -499,11 +466,12 @@ async def stream_chat_post(
|
||||
first_chunk_yielded = False
|
||||
chunks_yielded = 0
|
||||
try:
|
||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
# Subscribe from the position we captured before enqueuing
|
||||
# This avoids replaying old messages while catching all new ones
|
||||
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
last_message_id="0-0", # Get all messages from the beginning
|
||||
last_message_id=subscribe_from_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
@@ -586,19 +554,19 @@ async def stream_chat_post(
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
task_id, subscriber_queue
|
||||
await stream_registry.unsubscribe_from_session(
|
||||
session_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||
f"Error unsubscribing from session {session_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
total_time = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
|
||||
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||
f"turn={turn_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
@@ -645,17 +613,22 @@ async def resume_session_stream(
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
active_task, _last_id = await stream_registry.get_active_task_for_session(
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
|
||||
if not active_task:
|
||||
if not active_session:
|
||||
return Response(status_code=204)
|
||||
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=active_task.task_id,
|
||||
# Subscribe from the beginning ("0-0") to replay all chunks for this turn.
|
||||
# This is necessary because hydrated messages filter out incomplete tool calls
|
||||
# to avoid "No tool invocation found" errors. The resume stream delivers
|
||||
# those tool calls fresh with proper SDK state.
|
||||
# The AI SDK's deduplication will handle any duplicate chunks.
|
||||
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
last_message_id="0-0", # Full replay so useChat rebuilds the message
|
||||
last_message_id="0-0",
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
@@ -691,12 +664,12 @@ async def resume_session_stream(
|
||||
logger.error(f"Error in resume stream for session {session_id}: {e}")
|
||||
finally:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
active_task.task_id, subscriber_queue
|
||||
await stream_registry.unsubscribe_from_session(
|
||||
session_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
|
||||
f"Error unsubscribing from session {active_session.session_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info(
|
||||
@@ -747,229 +720,6 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Task Streaming (SSE Reconnection) ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}/stream",
|
||||
)
|
||||
async def stream_task(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
last_message_id: str = Query(
|
||||
default="0-0",
|
||||
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Reconnect to a long-running task's SSE stream.
|
||||
|
||||
When a long-running operation (like agent generation) starts, the client
|
||||
receives a task_id. If the connection drops, the client can reconnect
|
||||
using this endpoint to resume receiving updates.
|
||||
|
||||
Args:
|
||||
task_id: The task ID from the operation_started response.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
|
||||
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
|
||||
"""
|
||||
# Check task existence and expiry before subscribing
|
||||
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
|
||||
|
||||
if error_code == "TASK_EXPIRED":
|
||||
raise HTTPException(
|
||||
status_code=410,
|
||||
detail={
|
||||
"code": "TASK_EXPIRED",
|
||||
"message": "This operation has expired. Please try again.",
|
||||
},
|
||||
)
|
||||
|
||||
if error_code == "TASK_NOT_FOUND":
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "TASK_NOT_FOUND",
|
||||
"message": f"Task {task_id} not found.",
|
||||
},
|
||||
)
|
||||
|
||||
# Validate ownership if task has an owner
|
||||
if task and task.user_id and user_id != task.user_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"code": "ACCESS_DENIED",
|
||||
"message": "You do not have access to this task.",
|
||||
},
|
||||
)
|
||||
|
||||
# Get subscriber queue from stream registry
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "TASK_NOT_FOUND",
|
||||
"message": f"Task {task_id} not found or access denied.",
|
||||
},
|
||||
)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Wait for next chunk with timeout for heartbeats
|
||||
chunk = await asyncio.wait_for(
|
||||
subscriber_queue.get(), timeout=heartbeat_interval
|
||||
)
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# Send heartbeat to keep connection alive
|
||||
yield StreamHeartbeat().to_sse()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}",
|
||||
)
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
) -> dict:
|
||||
"""
|
||||
Get the status of a long-running task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to check.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
|
||||
Returns:
|
||||
dict: Task status including task_id, status, tool_name, and operation_id.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If task_id is not found or user doesn't have access.
|
||||
"""
|
||||
task = await stream_registry.get_task(task_id)
|
||||
|
||||
if task is None:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
# Validate ownership - if task has an owner, requester must match
|
||||
if task.user_id and user_id != task.user_id:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
return {
|
||||
"task_id": task.task_id,
|
||||
"session_id": task.session_id,
|
||||
"status": task.status,
|
||||
"tool_name": task.tool_name,
|
||||
"operation_id": task.operation_id,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# ========== External Completion Webhook ==========
|
||||
|
||||
|
||||
@router.post(
|
||||
"/operations/{operation_id}/complete",
|
||||
status_code=200,
|
||||
)
|
||||
async def complete_operation(
|
||||
operation_id: str,
|
||||
request: OperationCompleteRequest,
|
||||
x_api_key: str | None = Header(default=None),
|
||||
) -> dict:
|
||||
"""
|
||||
External completion webhook for long-running operations.
|
||||
|
||||
Called by Agent Generator (or other services) when an operation completes.
|
||||
This triggers the stream registry to publish completion and continue LLM generation.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID to complete.
|
||||
request: Completion payload with success status and result/error.
|
||||
x_api_key: Internal API key for authentication.
|
||||
|
||||
Returns:
|
||||
dict: Status of the completion.
|
||||
|
||||
Raises:
|
||||
HTTPException: If API key is invalid or operation not found.
|
||||
"""
|
||||
# Validate internal API key - reject if not configured or invalid
|
||||
if not config.internal_api_key:
|
||||
logger.error(
|
||||
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Webhook not available: internal API key not configured",
|
||||
)
|
||||
if x_api_key != config.internal_api_key:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
# Find task by operation_id
|
||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Operation {operation_id} not found",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Received completion webhook for operation {operation_id} "
|
||||
f"(task_id={task.task_id}, success={request.success})"
|
||||
)
|
||||
|
||||
if request.success:
|
||||
await process_operation_success(task, request.result)
|
||||
else:
|
||||
await process_operation_failure(task, request.error)
|
||||
|
||||
return {"status": "ok", "task_id": task.task_id}
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
@@ -1050,9 +800,6 @@ ToolResponseUnion = (
|
||||
| BlockOutputResponse
|
||||
| DocSearchResultsResponse
|
||||
| DocPageResponse
|
||||
| OperationStartedResponse
|
||||
| OperationPendingResponse
|
||||
| OperationInProgressResponse
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -42,10 +42,6 @@ import backend.integrations.webhooks.utils
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.copilot.completion_consumer import (
|
||||
start_completion_consumer,
|
||||
stop_completion_consumer,
|
||||
)
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
@@ -123,21 +119,9 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
# Start chat completion consumer for Redis Streams notifications
|
||||
try:
|
||||
await start_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not start chat completion consumer: {e}")
|
||||
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
# Stop chat completion consumer
|
||||
try:
|
||||
await stop_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error stopping chat completion consumer: {e}")
|
||||
|
||||
try:
|
||||
await shutdown_cloud_storage_handler()
|
||||
except Exception as e:
|
||||
|
||||
@@ -24,7 +24,7 @@ def run_processes(*processes: "AppProcess", **kwargs):
|
||||
# Run the last process in the foreground.
|
||||
processes[-1].start(background=False, **kwargs)
|
||||
finally:
|
||||
for process in processes:
|
||||
for process in reversed(processes):
|
||||
try:
|
||||
process.stop()
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,349 +0,0 @@
|
||||
"""Redis Streams consumer for operation completion messages.
|
||||
|
||||
This module provides a consumer (ChatCompletionConsumer) that listens for
|
||||
completion notifications (OperationCompleteMessage) from external services
|
||||
(like Agent Generator) and triggers the appropriate stream registry and
|
||||
chat service updates via process_operation_success/process_operation_failure.
|
||||
|
||||
Why Redis Streams instead of RabbitMQ?
|
||||
--------------------------------------
|
||||
While the project typically uses RabbitMQ for async task queues (e.g., execution
|
||||
queue), Redis Streams was chosen for chat completion notifications because:
|
||||
|
||||
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
|
||||
Streams (via stream_registry) for message persistence and replay. Using Redis
|
||||
Streams for completion notifications keeps all chat streaming infrastructure
|
||||
in one system, simplifying operations and reducing cross-system coordination.
|
||||
|
||||
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
|
||||
allowing consumers to replay missed messages after reconnection. This aligns
|
||||
with the SSE reconnection pattern where clients can resume from last_message_id.
|
||||
|
||||
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
|
||||
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
|
||||
recovering from dead consumers - ideal for the completion callback pattern.
|
||||
|
||||
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
|
||||
stream_registry) provides lower latency than an additional RabbitMQ hop.
|
||||
|
||||
5. **Atomicity with Task State**: Completion processing often needs to update
|
||||
task metadata stored in Redis. Keeping both in Redis enables simpler
|
||||
transactional semantics without distributed coordination.
|
||||
|
||||
The consumer uses Redis Streams with consumer groups for reliable message
|
||||
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
|
||||
stale pending messages from dead consumers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from pydantic import BaseModel
|
||||
from redis.exceptions import ResponseError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class OperationCompleteMessage(BaseModel):
|
||||
"""Message format for operation completion notifications."""
|
||||
|
||||
operation_id: str
|
||||
task_id: str
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionConsumer:
|
||||
"""Consumer for chat operation completion messages from Redis Streams.
|
||||
|
||||
Database operations are handled through the chat_db() accessor, which
|
||||
routes through DatabaseManager RPC when Prisma is not directly connected.
|
||||
|
||||
Uses Redis consumer groups to allow multiple platform pods to consume
|
||||
messages reliably with automatic redelivery on failure.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._consumer_task: asyncio.Task | None = None
|
||||
self._running = False
|
||||
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the completion consumer."""
|
||||
if self._running:
|
||||
logger.warning("Completion consumer already running")
|
||||
return
|
||||
|
||||
# Create consumer group if it doesn't exist
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.xgroup_create(
|
||||
config.stream_completion_name,
|
||||
config.stream_consumer_group,
|
||||
id="0",
|
||||
mkstream=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Created consumer group '{config.stream_consumer_group}' "
|
||||
f"on stream '{config.stream_completion_name}'"
|
||||
)
|
||||
except ResponseError as e:
|
||||
if "BUSYGROUP" in str(e):
|
||||
logger.debug(
|
||||
f"Consumer group '{config.stream_consumer_group}' already exists"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
self._running = True
|
||||
self._consumer_task = asyncio.create_task(self._consume_messages())
|
||||
logger.info(
|
||||
f"Chat completion consumer started (consumer: {self._consumer_name})"
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the completion consumer."""
|
||||
self._running = False
|
||||
|
||||
if self._consumer_task:
|
||||
self._consumer_task.cancel()
|
||||
try:
|
||||
await self._consumer_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._consumer_task = None
|
||||
|
||||
logger.info("Chat completion consumer stopped")
|
||||
|
||||
async def _consume_messages(self) -> None:
|
||||
"""Main message consumption loop with retry logic."""
|
||||
max_retries = 10
|
||||
retry_delay = 5 # seconds
|
||||
retry_count = 0
|
||||
block_timeout = 5000 # milliseconds
|
||||
|
||||
while self._running and retry_count < max_retries:
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
|
||||
# Reset retry count on successful connection
|
||||
retry_count = 0
|
||||
|
||||
while self._running:
|
||||
# First, claim any stale pending messages from dead consumers
|
||||
# Redis does NOT auto-redeliver pending messages; we must explicitly
|
||||
# claim them using XAUTOCLAIM
|
||||
try:
|
||||
claimed_result = await redis.xautoclaim(
|
||||
name=config.stream_completion_name,
|
||||
groupname=config.stream_consumer_group,
|
||||
consumername=self._consumer_name,
|
||||
min_idle_time=config.stream_claim_min_idle_ms,
|
||||
start_id="0-0",
|
||||
count=10,
|
||||
)
|
||||
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
|
||||
if claimed_result and len(claimed_result) >= 2:
|
||||
claimed_entries = claimed_result[1]
|
||||
if claimed_entries:
|
||||
logger.info(
|
||||
f"Claimed {len(claimed_entries)} stale pending messages"
|
||||
)
|
||||
for entry_id, data in claimed_entries:
|
||||
if not self._running:
|
||||
return
|
||||
await self._process_entry(redis, entry_id, data)
|
||||
except Exception as e:
|
||||
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
|
||||
|
||||
# Read new messages from the stream
|
||||
messages = await redis.xreadgroup(
|
||||
groupname=config.stream_consumer_group,
|
||||
consumername=self._consumer_name,
|
||||
streams={config.stream_completion_name: ">"},
|
||||
block=block_timeout,
|
||||
count=10,
|
||||
)
|
||||
|
||||
if not messages:
|
||||
continue
|
||||
|
||||
for stream_name, entries in messages:
|
||||
for entry_id, data in entries:
|
||||
if not self._running:
|
||||
return
|
||||
await self._process_entry(redis, entry_id, data)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Consumer cancelled")
|
||||
return
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
logger.error(
|
||||
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
if self._running and retry_count < max_retries:
|
||||
await asyncio.sleep(retry_delay)
|
||||
else:
|
||||
logger.error("Max retries reached, stopping consumer")
|
||||
return
|
||||
|
||||
async def _process_entry(
|
||||
self, redis: Any, entry_id: str, data: dict[str, Any]
|
||||
) -> None:
|
||||
"""Process a single stream entry and acknowledge it on success.
|
||||
|
||||
Args:
|
||||
redis: Redis client connection
|
||||
entry_id: The stream entry ID
|
||||
data: The entry data dict
|
||||
"""
|
||||
try:
|
||||
# Handle the message
|
||||
message_data = data.get("data")
|
||||
if message_data:
|
||||
await self._handle_message(
|
||||
message_data.encode()
|
||||
if isinstance(message_data, str)
|
||||
else message_data
|
||||
)
|
||||
|
||||
# Acknowledge the message after successful processing
|
||||
await redis.xack(
|
||||
config.stream_completion_name,
|
||||
config.stream_consumer_group,
|
||||
entry_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing completion message {entry_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Message remains in pending state and will be claimed by
|
||||
# XAUTOCLAIM after min_idle_time expires
|
||||
|
||||
async def _handle_message(self, body: bytes) -> None:
|
||||
"""Handle a completion message."""
|
||||
try:
|
||||
data = orjson.loads(body)
|
||||
message = OperationCompleteMessage(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse completion message: {e}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Received completion for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id}, success={message.success})"
|
||||
)
|
||||
|
||||
# Find task in registry
|
||||
task = await stream_registry.find_task_by_operation_id(message.operation_id)
|
||||
if task is None:
|
||||
task = await stream_registry.get_task(message.task_id)
|
||||
|
||||
if task is None:
|
||||
logger.warning(
|
||||
f"[COMPLETION] Task not found for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id})"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Found task: task_id={task.task_id}, "
|
||||
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
|
||||
)
|
||||
|
||||
# Guard against empty task fields
|
||||
if not task.task_id or not task.session_id or not task.tool_call_id:
|
||||
logger.error(
|
||||
f"[COMPLETION] Task has empty critical fields! "
|
||||
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
|
||||
f"tool_call_id={task.tool_call_id!r}"
|
||||
)
|
||||
return
|
||||
|
||||
if message.success:
|
||||
await self._handle_success(task, message)
|
||||
else:
|
||||
await self._handle_failure(task, message)
|
||||
|
||||
async def _handle_success(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle successful operation completion."""
|
||||
await process_operation_success(task, message.result)
|
||||
|
||||
async def _handle_failure(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle failed operation completion."""
|
||||
await process_operation_failure(task, message.error)
|
||||
|
||||
|
||||
# Module-level consumer instance
|
||||
_consumer: ChatCompletionConsumer | None = None
|
||||
|
||||
|
||||
async def start_completion_consumer() -> None:
|
||||
"""Start the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer is None:
|
||||
_consumer = ChatCompletionConsumer()
|
||||
await _consumer.start()
|
||||
|
||||
|
||||
async def stop_completion_consumer() -> None:
|
||||
"""Stop the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer:
|
||||
await _consumer.stop()
|
||||
_consumer = None
|
||||
|
||||
|
||||
async def publish_operation_complete(
|
||||
operation_id: str,
|
||||
task_id: str,
|
||||
success: bool,
|
||||
result: dict | str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""Publish an operation completion message to Redis Streams.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID that completed.
|
||||
task_id: The task ID associated with the operation.
|
||||
success: Whether the operation succeeded.
|
||||
result: The result data (for success).
|
||||
error: The error message (for failure).
|
||||
"""
|
||||
message = OperationCompleteMessage(
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
success=success,
|
||||
result=result,
|
||||
error=error,
|
||||
)
|
||||
|
||||
redis = await get_redis_async()
|
||||
await redis.xadd(
|
||||
config.stream_completion_name,
|
||||
{"data": message.model_dump_json()},
|
||||
maxlen=config.stream_max_length,
|
||||
)
|
||||
logger.info(f"Published completion for operation {operation_id}")
|
||||
@@ -1,329 +0,0 @@
|
||||
"""Shared completion handling for operation success and failure.
|
||||
|
||||
This module provides common logic for handling operation completion from both:
|
||||
- The Redis Streams consumer (completion_consumer.py)
|
||||
- The HTTP webhook endpoint (routes.py)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from backend.data.db_accessors import chat_db
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .response_model import StreamError, StreamToolOutputAvailable
|
||||
from .tools.models import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools that produce agent_json that needs to be saved to library
|
||||
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
|
||||
|
||||
# Keys that should be stripped from agent_json when returning in error responses
|
||||
SENSITIVE_KEYS = frozenset(
|
||||
{
|
||||
"api_key",
|
||||
"apikey",
|
||||
"api_secret",
|
||||
"password",
|
||||
"secret",
|
||||
"credentials",
|
||||
"credential",
|
||||
"token",
|
||||
"access_token",
|
||||
"refresh_token",
|
||||
"private_key",
|
||||
"privatekey",
|
||||
"auth",
|
||||
"authorization",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_agent_json(obj: Any) -> Any:
|
||||
"""Recursively sanitize agent_json by removing sensitive keys.
|
||||
|
||||
Args:
|
||||
obj: The object to sanitize (dict, list, or primitive)
|
||||
|
||||
Returns:
|
||||
Sanitized copy with sensitive keys removed/redacted
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
return {
|
||||
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
|
||||
for k, v in obj.items()
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
return [_sanitize_agent_json(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
class ToolMessageUpdateError(Exception):
|
||||
"""Raised when updating a tool message in the database fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
async def _update_tool_message(
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
content: str,
|
||||
) -> None:
|
||||
"""Update tool message in database using the chat_db accessor.
|
||||
|
||||
Routes through DatabaseManager RPC when Prisma is not directly
|
||||
connected (e.g. in the CoPilot Executor microservice).
|
||||
|
||||
Args:
|
||||
session_id: The session ID
|
||||
tool_call_id: The tool call ID to update
|
||||
content: The new content for the message
|
||||
|
||||
Raises:
|
||||
ToolMessageUpdateError: If the database update fails.
|
||||
"""
|
||||
try:
|
||||
updated = await chat_db().update_tool_message_content(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
new_content=content,
|
||||
)
|
||||
if not updated:
|
||||
raise ToolMessageUpdateError(
|
||||
f"No message found with tool_call_id="
|
||||
f"{tool_call_id} in session {session_id}"
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to update tool message: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise ToolMessageUpdateError(
|
||||
f"Failed to update tool message for tool call #{tool_call_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
|
||||
"""Serialize result to JSON string with sensible defaults.
|
||||
|
||||
Args:
|
||||
result: The result to serialize. Can be a dict, list, string,
|
||||
number, boolean, or None.
|
||||
|
||||
Returns:
|
||||
JSON string representation of the result. Returns '{"status": "completed"}'
|
||||
only when result is explicitly None.
|
||||
"""
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
if result is None:
|
||||
return '{"status": "completed"}'
|
||||
return orjson.dumps(result).decode("utf-8")
|
||||
|
||||
|
||||
async def _save_agent_from_result(
|
||||
result: dict[str, Any],
|
||||
user_id: str | None,
|
||||
tool_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Save agent to library if result contains agent_json.
|
||||
|
||||
Args:
|
||||
result: The result dict that may contain agent_json
|
||||
user_id: The user ID to save the agent for
|
||||
tool_name: The tool name (create_agent or edit_agent)
|
||||
|
||||
Returns:
|
||||
Updated result dict with saved agent details, or original result if no agent_json
|
||||
"""
|
||||
if not user_id:
|
||||
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
|
||||
return result
|
||||
|
||||
agent_json = result.get("agent_json")
|
||||
if not agent_json:
|
||||
logger.warning(
|
||||
f"[COMPLETION] {tool_name} completed but no agent_json in result"
|
||||
)
|
||||
return result
|
||||
|
||||
try:
|
||||
from .tools.agent_generator import save_agent_to_library
|
||||
|
||||
is_update = tool_name == "edit_agent"
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id, is_update=is_update
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
|
||||
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
|
||||
)
|
||||
|
||||
# Return a response similar to AgentSavedResponse
|
||||
return {
|
||||
"type": "agent_saved",
|
||||
"message": f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
"agent_id": created_graph.id,
|
||||
"agent_name": created_graph.name,
|
||||
"library_agent_id": library_agent.id,
|
||||
"library_agent_link": f"/library/agents/{library_agent.id}",
|
||||
"agent_page_link": f"/build?flowID={created_graph.id}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to save agent to library: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Return error but don't fail the whole operation
|
||||
# Sanitize agent_json to remove sensitive keys before returning
|
||||
return {
|
||||
"type": "error",
|
||||
"message": f"Agent was generated but failed to save: {str(e)}",
|
||||
"error": str(e),
|
||||
"agent_json": _sanitize_agent_json(agent_json),
|
||||
}
|
||||
|
||||
|
||||
async def process_operation_success(
|
||||
task: stream_registry.ActiveTask,
|
||||
result: dict | str | None,
|
||||
) -> None:
|
||||
"""Handle successful operation completion.
|
||||
|
||||
Publishes the result to the stream registry, updates the database,
|
||||
generates LLM continuation, and marks the task as completed.
|
||||
|
||||
Args:
|
||||
task: The active task that completed
|
||||
result: The result data from the operation
|
||||
|
||||
Raises:
|
||||
ToolMessageUpdateError: If the database update fails. The task
|
||||
will be marked as failed instead of completed.
|
||||
"""
|
||||
# For agent generation tools, save the agent to library
|
||||
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
|
||||
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
|
||||
|
||||
# Serialize result for output (only substitute default when result is exactly None)
|
||||
result_output = result if result is not None else {"status": "completed"}
|
||||
output_str = (
|
||||
result_output
|
||||
if isinstance(result_output, str)
|
||||
else orjson.dumps(result_output).decode("utf-8")
|
||||
)
|
||||
|
||||
# Publish result to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=task.tool_call_id,
|
||||
toolName=task.tool_name,
|
||||
output=output_str,
|
||||
success=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Update pending operation in database
|
||||
# If this fails, we must not continue to mark the task as completed
|
||||
result_str = serialize_result(result)
|
||||
try:
|
||||
await _update_tool_message(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
content=result_str,
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
# DB update failed - mark task as failed to avoid inconsistent state
|
||||
logger.error(
|
||||
f"[COMPLETION] DB update failed for task {task.task_id}, "
|
||||
"marking as failed instead of completed"
|
||||
)
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText="Failed to save operation result to database"),
|
||||
)
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
raise
|
||||
|
||||
# Generate LLM continuation with streaming
|
||||
try:
|
||||
await chat_service._generate_llm_continuation_with_streaming(
|
||||
session_id=task.session_id,
|
||||
user_id=task.user_id,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to generate LLM continuation: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Mark task as completed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
|
||||
)
|
||||
|
||||
|
||||
async def process_operation_failure(
|
||||
task: stream_registry.ActiveTask,
|
||||
error: str | None,
|
||||
) -> None:
|
||||
"""Handle failed operation completion.
|
||||
|
||||
Publishes the error to the stream registry, updates the database
|
||||
with the error response, and marks the task as failed.
|
||||
|
||||
Args:
|
||||
task: The active task that failed
|
||||
error: The error message from the operation
|
||||
"""
|
||||
error_msg = error or "Operation failed"
|
||||
|
||||
# Publish error to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText=error_msg),
|
||||
)
|
||||
|
||||
# Update pending operation with error
|
||||
# If this fails, we still continue to mark the task as failed
|
||||
error_response = ErrorResponse(
|
||||
message=error_msg,
|
||||
error=error,
|
||||
)
|
||||
try:
|
||||
await _update_tool_message(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
content=error_response.model_dump_json(),
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
# DB update failed - log but continue with cleanup
|
||||
logger.error(
|
||||
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
|
||||
"continuing with cleanup"
|
||||
)
|
||||
|
||||
# Mark task as failed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")
|
||||
@@ -27,7 +27,6 @@ class ChatConfig(BaseSettings):
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# Streaming Configuration
|
||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
description="Max retries for fallback path (SDK handles retries internally)",
|
||||
@@ -37,52 +36,29 @@ class ChatConfig(BaseSettings):
|
||||
default=30, description="Maximum number of agent schedules"
|
||||
)
|
||||
|
||||
# Long-running operation configuration
|
||||
long_running_operation_ttl: int = Field(
|
||||
default=600,
|
||||
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
||||
)
|
||||
|
||||
# Stream registry configuration for SSE reconnection
|
||||
stream_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL in seconds for stream data in Redis (1 hour)",
|
||||
)
|
||||
stream_lock_ttl: int = Field(
|
||||
default=120,
|
||||
description="TTL in seconds for stream lock (2 minutes). Short timeout allows "
|
||||
"reconnection after refresh/crash without long waits.",
|
||||
)
|
||||
stream_max_length: int = Field(
|
||||
default=10000,
|
||||
description="Maximum number of messages to store per stream",
|
||||
)
|
||||
|
||||
# Redis Streams configuration for completion consumer
|
||||
stream_completion_name: str = Field(
|
||||
default="chat:completions",
|
||||
description="Redis Stream name for operation completions",
|
||||
)
|
||||
stream_consumer_group: str = Field(
|
||||
default="chat_consumers",
|
||||
description="Consumer group name for completion stream",
|
||||
)
|
||||
stream_claim_min_idle_ms: int = Field(
|
||||
default=60000,
|
||||
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
|
||||
)
|
||||
|
||||
# Redis key prefixes for stream registry
|
||||
task_meta_prefix: str = Field(
|
||||
session_meta_prefix: str = Field(
|
||||
default="chat:task:meta:",
|
||||
description="Prefix for task metadata hash keys",
|
||||
description="Prefix for session metadata hash keys",
|
||||
)
|
||||
task_stream_prefix: str = Field(
|
||||
turn_stream_prefix: str = Field(
|
||||
default="chat:stream:",
|
||||
description="Prefix for task message stream keys",
|
||||
)
|
||||
task_op_prefix: str = Field(
|
||||
default="chat:task:op:",
|
||||
description="Prefix for operation ID to task ID mapping keys",
|
||||
)
|
||||
internal_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
|
||||
description="Prefix for turn message stream keys",
|
||||
)
|
||||
|
||||
# Langfuse Prompt Management Configuration
|
||||
@@ -154,14 +130,6 @@ class ChatConfig(BaseSettings):
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
return v
|
||||
|
||||
@field_validator("internal_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_internal_api_key(cls, v):
|
||||
"""Get internal API key from environment if not provided."""
|
||||
if v is None:
|
||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||
return v
|
||||
|
||||
@field_validator("use_claude_agent_sdk", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_agent_sdk(cls, v):
|
||||
|
||||
@@ -3,8 +3,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import (
|
||||
@@ -92,10 +93,9 @@ async def add_chat_message(
|
||||
function_call: dict[str, Any] | None = None,
|
||||
) -> ChatMessage:
|
||||
"""Add a message to a chat session."""
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput directly
|
||||
# because Prisma's TypedDict validation rejects optional fields set to None.
|
||||
# We only include fields that have values, then cast at the end.
|
||||
data: dict[str, Any] = {
|
||||
# Build ChatMessageCreateInput with only non-None values
|
||||
# (Prisma TypedDict rejects optional fields set to None)
|
||||
data: ChatMessageCreateInput = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": role,
|
||||
"sequence": sequence,
|
||||
@@ -123,7 +123,7 @@ async def add_chat_message(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
),
|
||||
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
|
||||
PrismaChatMessage.prisma().create(data=data),
|
||||
)
|
||||
return ChatMessage.from_db(message)
|
||||
|
||||
@@ -132,58 +132,93 @@ async def add_chat_messages_batch(
|
||||
session_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
start_sequence: int,
|
||||
) -> list[ChatMessage]:
|
||||
) -> int:
|
||||
"""Add multiple messages to a chat session in a batch.
|
||||
|
||||
Uses a transaction for atomicity - if any message creation fails,
|
||||
the entire batch is rolled back.
|
||||
Uses collision detection with retry: tries to create messages starting
|
||||
at start_sequence. If a unique constraint violation occurs (e.g., the
|
||||
streaming loop and long-running callback race), queries the latest
|
||||
sequence and retries with the correct offset. This avoids unnecessary
|
||||
upserts and DB queries in the common case (no collision).
|
||||
|
||||
Returns:
|
||||
Next sequence number for the next message to be inserted. This equals
|
||||
start_sequence + len(messages) and allows callers to update their
|
||||
counters even when collision detection adjusts start_sequence.
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
# No messages to add - return current count
|
||||
return start_sequence
|
||||
|
||||
created_messages = []
|
||||
max_retries = 5
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Single timestamp for all messages and session update
|
||||
now = datetime.now(UTC)
|
||||
|
||||
async with db.transaction() as tx:
|
||||
for i, msg in enumerate(messages):
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput
|
||||
# directly because Prisma's TypedDict validation rejects optional fields
|
||||
# set to None. We only include fields that have values, then cast.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
async with db.transaction() as tx:
|
||||
# Build all message data
|
||||
messages_data = []
|
||||
for i, msg in enumerate(messages):
|
||||
# Build ChatMessageCreateInput with only non-None values
|
||||
# (Prisma TypedDict rejects optional fields set to None)
|
||||
# Note: create_many doesn't support nested creates, use sessionId directly
|
||||
data: ChatMessageCreateInput = {
|
||||
"sessionId": session_id,
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
"createdAt": now,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma(tx).create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
created_messages.append(created)
|
||||
messages_data.append(data)
|
||||
|
||||
# Update session's updatedAt timestamp within the same transaction.
|
||||
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
|
||||
# separately via update_chat_session() after streaming completes.
|
||||
await PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
# Run create_many and session update in parallel within transaction
|
||||
# Both use the same timestamp for consistency
|
||||
await asyncio.gather(
|
||||
PrismaChatMessage.prisma(tx).create_many(data=messages_data),
|
||||
PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": now},
|
||||
),
|
||||
)
|
||||
|
||||
return [ChatMessage.from_db(m) for m in created_messages]
|
||||
# Return next sequence number for counter sync
|
||||
return start_sequence + len(messages)
|
||||
|
||||
except UniqueViolationError:
|
||||
if attempt < max_retries - 1:
|
||||
# Collision detected - query MAX(sequence)+1 and retry with correct offset
|
||||
logger.info(
|
||||
f"Collision detected for session {session_id} at sequence "
|
||||
f"{start_sequence}, querying DB for latest sequence"
|
||||
)
|
||||
start_sequence = await get_next_sequence(session_id)
|
||||
logger.info(
|
||||
f"Retrying batch insert with start_sequence={start_sequence}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# Max retries exceeded - propagate error
|
||||
raise
|
||||
|
||||
# Should never reach here due to raise in exception handler
|
||||
raise RuntimeError(f"Failed to insert messages after {max_retries} attempts")
|
||||
|
||||
|
||||
async def get_user_chat_sessions(
|
||||
@@ -237,10 +272,20 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
return False
|
||||
|
||||
|
||||
async def get_chat_session_message_count(session_id: str) -> int:
|
||||
"""Get the number of messages in a chat session."""
|
||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
||||
return count
|
||||
async def get_next_sequence(session_id: str) -> int:
|
||||
"""Get the next sequence number for a new message in this session.
|
||||
|
||||
Uses MAX(sequence) + 1 for robustness. Returns 0 if no messages exist.
|
||||
More robust than COUNT(*) because it's immune to deleted messages.
|
||||
|
||||
Optimized to select only the sequence column using raw SQL.
|
||||
The unique index on (sessionId, sequence) makes this query fast.
|
||||
"""
|
||||
results = await db.query_raw_with_schema(
|
||||
'SELECT "sequence" FROM {schema_prefix}"ChatMessage" WHERE "sessionId" = $1 ORDER BY "sequence" DESC LIMIT 1',
|
||||
session_id,
|
||||
)
|
||||
return 0 if not results else results[0]["sequence"] + 1
|
||||
|
||||
|
||||
async def update_tool_message_content(
|
||||
|
||||
@@ -181,13 +181,13 @@ class CoPilotExecutor(AppProcess):
|
||||
self._executor.shutdown(wait=False)
|
||||
|
||||
# Release any remaining locks
|
||||
for task_id, lock in list(self._task_locks.items()):
|
||||
for session_id, lock in list(self._task_locks.items()):
|
||||
try:
|
||||
lock.release()
|
||||
logger.info(f"[cleanup {pid}] Released lock for {task_id}")
|
||||
logger.info(f"[cleanup {pid}] Released lock for {session_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[cleanup {pid}] Failed to release lock for {task_id}: {e}"
|
||||
f"[cleanup {pid}] Failed to release lock for {session_id}: {e}"
|
||||
)
|
||||
|
||||
logger.info(f"[cleanup {pid}] Graceful shutdown completed")
|
||||
@@ -267,20 +267,20 @@ class CoPilotExecutor(AppProcess):
|
||||
):
|
||||
"""Handle cancel message from FANOUT exchange."""
|
||||
request = CancelCoPilotEvent.model_validate_json(body)
|
||||
task_id = request.task_id
|
||||
if not task_id:
|
||||
logger.warning("Cancel message missing 'task_id'")
|
||||
session_id = request.session_id
|
||||
if not session_id:
|
||||
logger.warning("Cancel message missing 'session_id'")
|
||||
return
|
||||
if task_id not in self.active_tasks:
|
||||
logger.debug(f"Cancel received for {task_id} but not active")
|
||||
if session_id not in self.active_tasks:
|
||||
logger.debug(f"Cancel received for {session_id} but not active")
|
||||
return
|
||||
|
||||
_, cancel_event = self.active_tasks[task_id]
|
||||
logger.info(f"Received cancel for {task_id}")
|
||||
_, cancel_event = self.active_tasks[session_id]
|
||||
logger.info(f"Received cancel for {session_id}")
|
||||
if not cancel_event.is_set():
|
||||
cancel_event.set()
|
||||
else:
|
||||
logger.debug(f"Cancel already set for {task_id}")
|
||||
logger.debug(f"Cancel already set for {session_id}")
|
||||
|
||||
def _handle_run_message(
|
||||
self,
|
||||
@@ -352,12 +352,12 @@ class CoPilotExecutor(AppProcess):
|
||||
ack_message(reject=True, requeue=False)
|
||||
return
|
||||
|
||||
task_id = entry.task_id
|
||||
session_id = entry.session_id
|
||||
|
||||
# Check for local duplicate - task is already running on this executor
|
||||
if task_id in self.active_tasks:
|
||||
# Check for local duplicate - session is already running on this executor
|
||||
if session_id in self.active_tasks:
|
||||
logger.warning(
|
||||
f"Task {task_id} already running locally, rejecting duplicate"
|
||||
f"Session {session_id} already running locally, rejecting duplicate"
|
||||
)
|
||||
ack_message(reject=True, requeue=False)
|
||||
return
|
||||
@@ -365,53 +365,53 @@ class CoPilotExecutor(AppProcess):
|
||||
# Try to acquire cluster-wide lock
|
||||
cluster_lock = ClusterLock(
|
||||
redis=redis.get_redis(),
|
||||
key=f"copilot:task:{task_id}:lock",
|
||||
key=f"copilot:session:{session_id}:lock",
|
||||
owner_id=self.executor_id,
|
||||
timeout=settings.config.cluster_lock_timeout,
|
||||
)
|
||||
current_owner = cluster_lock.try_acquire()
|
||||
if current_owner != self.executor_id:
|
||||
if current_owner is not None:
|
||||
logger.warning(f"Task {task_id} already running on pod {current_owner}")
|
||||
logger.warning(
|
||||
f"Session {session_id} already running on pod {current_owner}"
|
||||
)
|
||||
ack_message(reject=True, requeue=False)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Could not acquire lock for {task_id} - Redis unavailable"
|
||||
f"Could not acquire lock for {session_id} - Redis unavailable"
|
||||
)
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
# Execute the task
|
||||
try:
|
||||
self._task_locks[task_id] = cluster_lock
|
||||
self._task_locks[session_id] = cluster_lock
|
||||
|
||||
logger.info(
|
||||
f"Acquired cluster lock for {task_id}, executor_id={self.executor_id}"
|
||||
f"Acquired cluster lock for {session_id}, "
|
||||
f"executor_id={self.executor_id}"
|
||||
)
|
||||
|
||||
cancel_event = threading.Event()
|
||||
future = self.executor.submit(
|
||||
execute_copilot_task, entry, cancel_event, cluster_lock
|
||||
)
|
||||
self.active_tasks[task_id] = (future, cancel_event)
|
||||
self.active_tasks[session_id] = (future, cancel_event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to setup execution for {task_id}: {e}")
|
||||
logger.warning(f"Failed to setup execution for {session_id}: {e}")
|
||||
cluster_lock.release()
|
||||
if task_id in self._task_locks:
|
||||
del self._task_locks[task_id]
|
||||
if session_id in self._task_locks:
|
||||
del self._task_locks[session_id]
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
self._update_metrics()
|
||||
|
||||
def on_run_done(f: Future):
|
||||
logger.info(f"Run completed for {task_id}")
|
||||
logger.info(f"Run completed for {session_id}")
|
||||
try:
|
||||
if exec_error := f.exception():
|
||||
logger.error(f"Execution for {task_id} failed: {exec_error}")
|
||||
# Don't requeue failed tasks - they've been marked as failed
|
||||
# in the stream registry. Requeuing would cause infinite retries
|
||||
# for deterministic failures.
|
||||
logger.error(f"Execution for {session_id} failed: {exec_error}")
|
||||
ack_message(reject=True, requeue=False)
|
||||
else:
|
||||
ack_message(reject=False, requeue=False)
|
||||
@@ -419,10 +419,10 @@ class CoPilotExecutor(AppProcess):
|
||||
logger.exception(f"Error in run completion callback: {e}")
|
||||
finally:
|
||||
# Release the cluster lock
|
||||
if task_id in self._task_locks:
|
||||
logger.info(f"Releasing cluster lock for {task_id}")
|
||||
self._task_locks[task_id].release()
|
||||
del self._task_locks[task_id]
|
||||
if session_id in self._task_locks:
|
||||
logger.info(f"Releasing cluster lock for {session_id}")
|
||||
self._task_locks[session_id].release()
|
||||
del self._task_locks[session_id]
|
||||
self._cleanup_completed_tasks()
|
||||
|
||||
future.add_done_callback(on_run_done)
|
||||
@@ -433,11 +433,11 @@ class CoPilotExecutor(AppProcess):
|
||||
"""Remove completed futures from active_tasks and update metrics."""
|
||||
completed_tasks = []
|
||||
with self._active_tasks_lock:
|
||||
for task_id, (future, _) in list(self.active_tasks.items()):
|
||||
for session_id, (future, _) in list(self.active_tasks.items()):
|
||||
if future.done():
|
||||
completed_tasks.append(task_id)
|
||||
self.active_tasks.pop(task_id, None)
|
||||
logger.info(f"Cleaned up completed task {task_id}")
|
||||
completed_tasks.append(session_id)
|
||||
self.active_tasks.pop(session_id, None)
|
||||
logger.info(f"Cleaned up completed session {session_id}")
|
||||
|
||||
self._update_metrics()
|
||||
return completed_tasks
|
||||
|
||||
@@ -12,7 +12,7 @@ import time
|
||||
from backend.copilot import service as copilot_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep
|
||||
from backend.copilot.response_model import StreamFinish, StreamFinishStep
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
from backend.util.decorator import error_logged
|
||||
@@ -151,7 +151,6 @@ class CoPilotProcessor:
|
||||
"""
|
||||
log = CoPilotLogMetadata(
|
||||
logging.getLogger(__name__),
|
||||
task_id=entry.task_id,
|
||||
session_id=entry.session_id,
|
||||
user_id=entry.user_id,
|
||||
)
|
||||
@@ -240,48 +239,49 @@ class CoPilotProcessor:
|
||||
if cancel.is_set():
|
||||
log.info("Cancelled during streaming")
|
||||
await stream_registry.publish_chunk(
|
||||
entry.task_id, StreamError(errorText="Operation cancelled")
|
||||
entry.turn_id, StreamFinishStep()
|
||||
)
|
||||
await stream_registry.publish_chunk(
|
||||
entry.task_id, StreamFinishStep()
|
||||
)
|
||||
await stream_registry.publish_chunk(entry.task_id, StreamFinish())
|
||||
await stream_registry.mark_task_completed(
|
||||
entry.task_id, status="failed"
|
||||
await stream_registry.mark_session_completed(
|
||||
entry.session_id,
|
||||
error_message="Operation cancelled",
|
||||
)
|
||||
return
|
||||
|
||||
# Refresh cluster lock periodically
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_refresh >= refresh_interval:
|
||||
cluster_lock.refresh()
|
||||
last_refresh = current_time
|
||||
|
||||
# Publish chunk to stream registry
|
||||
await stream_registry.publish_chunk(entry.task_id, chunk)
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
|
||||
# Mark task as completed
|
||||
await stream_registry.mark_task_completed(entry.task_id, status="completed")
|
||||
try:
|
||||
await stream_registry.publish_chunk(entry.turn_id, chunk)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Error publishing chunk {type(chunk).__name__}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
await stream_registry.mark_session_completed(entry.session_id)
|
||||
log.info("Task completed successfully")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
log.info("Task cancelled")
|
||||
await stream_registry.mark_task_completed(entry.task_id, status="failed")
|
||||
await stream_registry.mark_session_completed(
|
||||
entry.session_id, error_message="Task was cancelled"
|
||||
)
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Task failed: {e}")
|
||||
await self._mark_task_failed(entry.task_id, str(e))
|
||||
try:
|
||||
await stream_registry.publish_chunk(entry.turn_id, StreamFinishStep())
|
||||
await stream_registry.mark_session_completed(
|
||||
entry.session_id, error_message=str(e)
|
||||
)
|
||||
except Exception as mark_err:
|
||||
logger.error(
|
||||
f"Failed to mark session {entry.session_id} as failed: {mark_err}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def _mark_task_failed(self, task_id: str, error_message: str):
|
||||
"""Mark a task as failed and publish error to stream registry."""
|
||||
try:
|
||||
await stream_registry.publish_chunk(
|
||||
task_id, StreamError(errorText=error_message)
|
||||
)
|
||||
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to mark task {task_id} as failed: {e}")
|
||||
|
||||
@@ -28,7 +28,7 @@ class CoPilotLogMetadata(TruncatedLogger):
|
||||
Args:
|
||||
logger: The underlying logger instance
|
||||
max_length: Maximum log message length before truncation
|
||||
**kwargs: Metadata key-value pairs (e.g., task_id="abc", session_id="xyz")
|
||||
**kwargs: Metadata key-value pairs (e.g., session_id="xyz", turn_id="abc")
|
||||
These are added to json_fields in cloud mode, or to the prefix in local mode.
|
||||
"""
|
||||
|
||||
@@ -135,18 +135,15 @@ class CoPilotExecutionEntry(BaseModel):
|
||||
This model represents a chat generation task to be processed by the executor.
|
||||
"""
|
||||
|
||||
task_id: str
|
||||
"""Unique identifier for this task (used for stream registry)"""
|
||||
|
||||
session_id: str
|
||||
"""Chat session ID"""
|
||||
"""Chat session ID (also used for dedup/locking)"""
|
||||
|
||||
turn_id: str = ""
|
||||
"""Per-turn UUID for Redis stream isolation"""
|
||||
|
||||
user_id: str | None
|
||||
"""User ID (may be None for anonymous users)"""
|
||||
|
||||
operation_id: str
|
||||
"""Operation ID for webhook callbacks and completion tracking"""
|
||||
|
||||
message: str
|
||||
"""User's message to process"""
|
||||
|
||||
@@ -160,40 +157,37 @@ class CoPilotExecutionEntry(BaseModel):
|
||||
class CancelCoPilotEvent(BaseModel):
|
||||
"""Event to cancel a CoPilot operation."""
|
||||
|
||||
task_id: str
|
||||
"""Task ID to cancel"""
|
||||
session_id: str
|
||||
"""Session ID to cancel"""
|
||||
|
||||
|
||||
# ============ Queue Publishing Helpers ============ #
|
||||
|
||||
|
||||
async def enqueue_copilot_task(
|
||||
task_id: str,
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
operation_id: str,
|
||||
message: str,
|
||||
turn_id: str = "",
|
||||
is_user_message: bool = True,
|
||||
context: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
Args:
|
||||
task_id: Unique identifier for this task (used for stream registry)
|
||||
session_id: Chat session ID
|
||||
session_id: Chat session ID (also used for dedup/locking)
|
||||
user_id: User ID (may be None for anonymous users)
|
||||
operation_id: Operation ID for webhook callbacks and completion tracking
|
||||
message: User's message to process
|
||||
turn_id: Per-turn UUID for Redis stream isolation
|
||||
is_user_message: Whether the message is from the user (vs system/assistant)
|
||||
context: Optional context for the message (e.g., {url: str, content: str})
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
entry = CoPilotExecutionEntry(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
user_id=user_id,
|
||||
operation_id=operation_id,
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
context=context,
|
||||
@@ -207,15 +201,15 @@ async def enqueue_copilot_task(
|
||||
)
|
||||
|
||||
|
||||
async def enqueue_cancel_task(task_id: str) -> None:
|
||||
"""Publish a cancel request for a running CoPilot task.
|
||||
async def enqueue_cancel_task(session_id: str) -> None:
|
||||
"""Publish a cancel request for a running CoPilot session.
|
||||
|
||||
Sends a ``CancelCoPilotEvent`` to the FANOUT exchange so all executor
|
||||
pods receive the cancellation signal.
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
event = CancelCoPilotEvent(task_id=task_id)
|
||||
event = CancelCoPilotEvent(session_id=session_id)
|
||||
queue_client = await get_async_copilot_queue()
|
||||
await queue_client.publish_message(
|
||||
routing_key="", # FANOUT ignores routing key
|
||||
|
||||
@@ -434,8 +434,6 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
|
||||
async def upsert_chat_session(
|
||||
session: ChatSession,
|
||||
*,
|
||||
existing_message_count: int | None = None,
|
||||
) -> ChatSession:
|
||||
"""Update a chat session in both cache and database.
|
||||
|
||||
@@ -443,12 +441,6 @@ async def upsert_chat_session(
|
||||
operations (e.g., background title update and main stream handler)
|
||||
attempt to upsert the same session simultaneously.
|
||||
|
||||
Args:
|
||||
existing_message_count: If provided, skip the DB query to count
|
||||
existing messages. The caller is responsible for tracking this
|
||||
accurately. Useful for incremental saves in a streaming loop
|
||||
where the caller already knows how many messages are persisted.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If the database write fails. The cache is still updated
|
||||
as a best-effort optimization, but the error is propagated to ensure
|
||||
@@ -459,11 +451,8 @@ async def upsert_chat_session(
|
||||
lock = await _get_session_lock(session.session_id)
|
||||
|
||||
async with lock:
|
||||
# Get existing message count from DB for incremental saves
|
||||
if existing_message_count is None:
|
||||
existing_message_count = await chat_db().get_chat_session_message_count(
|
||||
session.session_id
|
||||
)
|
||||
# Always query DB for existing message count to ensure consistency
|
||||
existing_message_count = await chat_db().get_next_sequence(session.session_id)
|
||||
|
||||
db_error: Exception | None = None
|
||||
|
||||
@@ -587,9 +576,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
session.messages.append(message)
|
||||
existing_message_count = await chat_db().get_chat_session_message_count(
|
||||
session_id
|
||||
)
|
||||
existing_message_count = await chat_db().get_next_sequence(session_id)
|
||||
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
|
||||
@@ -331,3 +331,96 @@ def test_to_openai_messages_merges_split_assistants():
|
||||
tc_list = merged.get("tool_calls")
|
||||
assert tc_list is not None and len(list(tc_list)) == 1
|
||||
assert list(tc_list)[0]["id"] == "tc1"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Concurrent save collision detection #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_concurrent_saves_collision_detection(setup_test_user, test_user_id):
|
||||
"""Test that concurrent saves from streaming loop and callback handle collisions correctly.
|
||||
|
||||
Simulates the race condition where:
|
||||
1. Streaming loop starts with saved_msg_count=5
|
||||
2. Long-running callback appends message #5 and saves
|
||||
3. Streaming loop tries to save with stale count=5
|
||||
|
||||
The collision detection should handle this gracefully.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# Create a session with initial messages
|
||||
session = ChatSession.new(user_id=test_user_id)
|
||||
for i in range(3):
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="user" if i % 2 == 0 else "assistant", content=f"Message {i}"
|
||||
)
|
||||
)
|
||||
|
||||
# Save initial messages
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Simulate streaming loop and callback saving concurrently
|
||||
async def streaming_loop_save():
|
||||
"""Simulates streaming loop saving messages."""
|
||||
# Add 2 messages
|
||||
session.messages.append(ChatMessage(role="user", content="Streaming message 1"))
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content="Streaming message 2")
|
||||
)
|
||||
|
||||
# Wait a bit to let callback potentially save first
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Save (will query DB for existing count)
|
||||
return await upsert_chat_session(session)
|
||||
|
||||
async def callback_save():
|
||||
"""Simulates long-running callback saving a message."""
|
||||
# Add 1 message
|
||||
session.messages.append(
|
||||
ChatMessage(role="tool", content="Callback result", tool_call_id="tc1")
|
||||
)
|
||||
|
||||
# Save immediately (will query DB for existing count)
|
||||
return await upsert_chat_session(session)
|
||||
|
||||
# Run both saves concurrently - one will hit collision detection
|
||||
results = await asyncio.gather(streaming_loop_save(), callback_save())
|
||||
|
||||
# Both should succeed
|
||||
assert all(r is not None for r in results)
|
||||
|
||||
# Reload session from DB to verify
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
redis_key = f"chat:session:{session.session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key) # Clear cache to force DB load
|
||||
|
||||
loaded_session = await get_chat_session(session.session_id, test_user_id)
|
||||
assert loaded_session is not None
|
||||
|
||||
# Should have all 6 messages (3 initial + 2 streaming + 1 callback)
|
||||
assert len(loaded_session.messages) == 6
|
||||
|
||||
# Verify no duplicate sequences
|
||||
sequences = []
|
||||
for i, msg in enumerate(loaded_session.messages):
|
||||
# Messages should have sequential sequence numbers starting from 0
|
||||
sequences.append(i)
|
||||
|
||||
# All sequences should be unique and sequential
|
||||
assert sequences == list(range(6))
|
||||
|
||||
# Verify message content is preserved
|
||||
contents = [m.content for m in loaded_session.messages]
|
||||
assert "Message 0" in contents
|
||||
assert "Message 1" in contents
|
||||
assert "Message 2" in contents
|
||||
assert "Streaming message 1" in contents
|
||||
assert "Streaming message 2" in contents
|
||||
assert "Callback result" in contents
|
||||
|
||||
@@ -14,7 +14,6 @@ import pytest
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls_run_concurrently():
|
||||
"""Multiple tool calls should complete in ~max(delays), not sum(delays)."""
|
||||
# Import here to allow module-level mocking if needed
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
@@ -32,7 +31,6 @@ async def test_parallel_tool_calls_run_concurrently():
|
||||
for i in range(n_tools)
|
||||
]
|
||||
|
||||
# Minimal session mock
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
@@ -42,7 +40,7 @@ async def test_parallel_tool_calls_run_concurrently():
|
||||
|
||||
original_yield = None
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"],
|
||||
toolName=tc_list[idx]["function"]["name"],
|
||||
@@ -101,7 +99,7 @@ async def test_single_tool_call_works():
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
|
||||
yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")
|
||||
|
||||
@@ -144,7 +142,7 @@ async def test_retryable_error_propagates():
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
if idx == 1:
|
||||
raise KeyError("bad")
|
||||
from backend.copilot.response_model import StreamToolInputAvailable
|
||||
@@ -175,8 +173,8 @@ async def test_retryable_error_propagates():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_lock_shared():
|
||||
"""All parallel tools should receive the same lock instance."""
|
||||
async def test_session_shared_across_parallel_tools():
|
||||
"""All parallel tools should receive the same session instance."""
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
@@ -199,10 +197,10 @@ async def test_session_lock_shared():
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
observed_locks = []
|
||||
observed_sessions = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
observed_locks.append(lock)
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
observed_sessions.append(sess)
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
|
||||
)
|
||||
@@ -222,9 +220,8 @@ async def test_session_lock_shared():
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
assert len(observed_locks) == 3
|
||||
assert observed_locks[0] is observed_locks[1] is observed_locks[2]
|
||||
assert isinstance(observed_locks[0], asyncio.Lock)
|
||||
assert len(observed_sessions) == 3
|
||||
assert observed_sessions[0] is observed_sessions[1] is observed_sessions[2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -251,7 +248,7 @@ async def test_cancellation_cleans_up():
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
|
||||
)
|
||||
|
||||
@@ -5,6 +5,8 @@ This module implements the AI SDK UI Stream Protocol (v1) for streaming chat res
|
||||
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
@@ -12,6 +14,8 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from backend.util.json import dumps as json_dumps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of streaming responses following AI SDK protocol."""
|
||||
@@ -47,7 +51,8 @@ class StreamBaseResponse(BaseModel):
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format."""
|
||||
return f"data: {self.model_dump_json()}\n\n"
|
||||
json_str = self.model_dump_json(exclude_none=True)
|
||||
return f"data: {json_str}\n\n"
|
||||
|
||||
|
||||
# ========== Message Lifecycle ==========
|
||||
@@ -58,15 +63,13 @@ class StreamStart(StreamBaseResponse):
|
||||
|
||||
type: ResponseType = ResponseType.START
|
||||
messageId: str = Field(..., description="Unique message ID")
|
||||
taskId: str | None = Field(
|
||||
sessionId: str | None = Field(
|
||||
default=None,
|
||||
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
||||
description="Session ID for SSE reconnection.",
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, excluding non-protocol fields like taskId."""
|
||||
import json
|
||||
|
||||
"""Convert to SSE format, excluding non-protocol fields like sessionId."""
|
||||
data: dict[str, Any] = {
|
||||
"type": self.type.value,
|
||||
"messageId": self.messageId,
|
||||
@@ -163,8 +166,6 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, excluding non-spec fields."""
|
||||
import json
|
||||
|
||||
data = {
|
||||
"type": self.type.value,
|
||||
"toolCallId": self.toolCallId,
|
||||
|
||||
65
autogpt_platform/backend/backend/copilot/sdk/dummy.py
Normal file
65
autogpt_platform/backend/backend/copilot/sdk/dummy.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Dummy SDK service for testing copilot streaming.
|
||||
|
||||
Returns mock streaming responses without calling Claude Agent SDK.
|
||||
Enable via COPILOT_TEST_MODE=true environment variable.
|
||||
|
||||
WARNING: This is for testing only. Do not use in production.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from ..model import ChatSession
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamFinish,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def stream_chat_completion_dummy(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
tool_call_response: str | None = None,
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
retry_count: int = 0,
|
||||
session: ChatSession | None = None,
|
||||
context: dict[str, str] | None = None,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream dummy chat completion for testing.
|
||||
|
||||
Returns a simple streaming response with text deltas to test:
|
||||
- Streaming infrastructure works
|
||||
- No timeout occurs
|
||||
- Text arrives in chunks
|
||||
- StreamFinish is sent
|
||||
"""
|
||||
logger.warning(
|
||||
f"[TEST MODE] Using dummy copilot streaming for session {session_id}"
|
||||
)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
# Start the stream
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
# Simulate streaming text response with delays
|
||||
dummy_response = "I counted: 1... 2... 3. All done!"
|
||||
words = dummy_response.split()
|
||||
|
||||
for i, word in enumerate(words):
|
||||
# Add space except for last word
|
||||
text = word if i == len(words) - 1 else f"{word} "
|
||||
yield StreamTextDelta(id=text_block_id, delta=text)
|
||||
# Small delay to simulate real streaming
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Finish the stream
|
||||
yield StreamFinish()
|
||||
@@ -55,13 +55,8 @@ class SDKResponseAdapter:
|
||||
self.has_ended_text = False
|
||||
self.current_tool_calls: dict[str, dict[str, str]] = {}
|
||||
self.resolved_tool_calls: set[str] = set()
|
||||
self.task_id: str | None = None
|
||||
self.step_open = False
|
||||
|
||||
def set_task_id(self, task_id: str) -> None:
|
||||
"""Set the task ID for reconnection support."""
|
||||
self.task_id = task_id
|
||||
|
||||
@property
|
||||
def has_unresolved_tool_calls(self) -> bool:
|
||||
"""True when there are tool calls that haven't received output yet."""
|
||||
@@ -74,7 +69,7 @@ class SDKResponseAdapter:
|
||||
if isinstance(sdk_message, SystemMessage):
|
||||
if sdk_message.subtype == "init":
|
||||
responses.append(
|
||||
StreamStart(messageId=self.message_id, taskId=self.task_id)
|
||||
StreamStart(messageId=self.message_id, sessionId=self.session_id)
|
||||
)
|
||||
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
|
||||
responses.append(StreamStartStep())
|
||||
|
||||
@@ -37,9 +37,7 @@ from .tool_adapter import wait_for_stash
|
||||
|
||||
|
||||
def _adapter() -> SDKResponseAdapter:
|
||||
a = SDKResponseAdapter(message_id="msg-1")
|
||||
a.set_task_id("task-1")
|
||||
return a
|
||||
return SDKResponseAdapter(message_id="msg-1", session_id="session-1")
|
||||
|
||||
|
||||
# -- SystemMessage -----------------------------------------------------------
|
||||
@@ -51,7 +49,7 @@ def test_system_init_emits_start_and_step():
|
||||
assert len(results) == 2
|
||||
assert isinstance(results[0], StreamStart)
|
||||
assert results[0].messageId == "msg-1"
|
||||
assert results[0].taskId == "task-1"
|
||||
assert results[0].sessionId == "session-1"
|
||||
assert isinstance(results[1], StreamStartStep)
|
||||
|
||||
|
||||
|
||||
@@ -7,11 +7,12 @@ import os
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .. import stream_registry
|
||||
from ..config import ChatConfig
|
||||
from ..model import (
|
||||
ChatMessage,
|
||||
@@ -31,12 +32,7 @@ from ..response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from ..service import (
|
||||
_build_system_prompt,
|
||||
_execute_long_running_tool_with_streaming,
|
||||
_generate_session_title,
|
||||
)
|
||||
from ..tools.models import OperationPendingResponse, OperationStartedResponse
|
||||
from ..service import _build_system_prompt, _generate_session_title
|
||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
from ..tracking import track_user_message
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
@@ -44,7 +40,6 @@ from .security_hooks import create_security_hooks
|
||||
from .tool_adapter import (
|
||||
COPILOT_TOOL_NAMES,
|
||||
SDK_DISALLOWED_TOOLS,
|
||||
LongRunningCallback,
|
||||
create_copilot_mcp_server,
|
||||
set_execution_context,
|
||||
wait_for_stash,
|
||||
@@ -61,6 +56,7 @@ from .transcript import (
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
# Set to hold background tasks to prevent garbage collection
|
||||
_background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
@@ -81,7 +77,8 @@ class CapturedTranscript:
|
||||
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
|
||||
|
||||
# Heartbeat interval — keep SSE alive through proxies/LBs during tool execution.
|
||||
_HEARTBEAT_INTERVAL = 15.0 # seconds
|
||||
# IMPORTANT: Must be less than frontend timeout (12s in useCopilotPage.ts)
|
||||
_HEARTBEAT_INTERVAL = 10.0 # seconds
|
||||
|
||||
# Appended to the system prompt to inform the agent about available tools.
|
||||
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
|
||||
@@ -132,120 +129,7 @@ is delivered to the user via a background stream.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
|
||||
def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
|
||||
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
|
||||
|
||||
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
|
||||
existing background infrastructure: stream_registry (Redis Streams),
|
||||
database persistence, and SSE reconnection. This means results survive
|
||||
page refreshes / pod restarts, and the frontend shows the proper loading
|
||||
widget with progress updates.
|
||||
|
||||
The returned callback matches the ``LongRunningCallback`` signature:
|
||||
``(tool_name, args, session) -> MCP response dict``.
|
||||
"""
|
||||
|
||||
async def _callback(
|
||||
tool_name: str, args: dict[str, Any], session: ChatSession
|
||||
) -> dict[str, Any]:
|
||||
operation_id = str(uuid.uuid4())
|
||||
task_id = str(uuid.uuid4())
|
||||
tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
||||
session_id = session.session_id
|
||||
|
||||
# --- Build user-friendly messages (matches non-SDK service) ---
|
||||
if tool_name == "create_agent":
|
||||
desc = args.get("description", "")
|
||||
desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc
|
||||
pending_msg = (
|
||||
f"Creating your agent: {desc_preview}"
|
||||
if desc_preview
|
||||
else "Creating agent... This may take a few minutes."
|
||||
)
|
||||
started_msg = (
|
||||
"Agent creation started. You can close this tab - "
|
||||
"check your library in a few minutes."
|
||||
)
|
||||
elif tool_name == "edit_agent":
|
||||
changes = args.get("changes", "")
|
||||
changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
|
||||
pending_msg = (
|
||||
f"Editing agent: {changes_preview}"
|
||||
if changes_preview
|
||||
else "Editing agent... This may take a few minutes."
|
||||
)
|
||||
started_msg = (
|
||||
"Agent edit started. You can close this tab - "
|
||||
"check your library in a few minutes."
|
||||
)
|
||||
else:
|
||||
pending_msg = f"Running {tool_name}... This may take a few minutes."
|
||||
started_msg = (
|
||||
f"{tool_name} started. You can close this tab - "
|
||||
"check back in a few minutes."
|
||||
)
|
||||
|
||||
# --- Register task in Redis for SSE reconnection ---
|
||||
await stream_registry.create_task(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
operation_id=operation_id,
|
||||
)
|
||||
|
||||
# --- Save OperationPendingResponse to chat history ---
|
||||
pending_message = ChatMessage(
|
||||
role="tool",
|
||||
content=OperationPendingResponse(
|
||||
message=pending_msg,
|
||||
operation_id=operation_id,
|
||||
tool_name=tool_name,
|
||||
).model_dump_json(),
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
session.messages.append(pending_message)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
# --- Spawn background task (reuses non-SDK infrastructure) ---
|
||||
bg_task = asyncio.create_task(
|
||||
_execute_long_running_tool_with_streaming(
|
||||
tool_name=tool_name,
|
||||
parameters=args,
|
||||
tool_call_id=tool_call_id,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
_background_tasks.add(bg_task)
|
||||
bg_task.add_done_callback(_background_tasks.discard)
|
||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||
|
||||
logger.info(
|
||||
f"[SDK] Long-running tool {tool_name} delegated to background "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
|
||||
# --- Return OperationStartedResponse as MCP tool result ---
|
||||
# This flows through SDK → response adapter → frontend, triggering
|
||||
# the loading widget with SSE reconnection support.
|
||||
started_json = OperationStartedResponse(
|
||||
message=started_msg,
|
||||
operation_id=operation_id,
|
||||
tool_name=tool_name,
|
||||
task_id=task_id,
|
||||
).model_dump_json()
|
||||
|
||||
return {
|
||||
"content": [{"type": "text", "text": started_json}],
|
||||
"isError": False,
|
||||
}
|
||||
|
||||
return _callback
|
||||
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
|
||||
|
||||
|
||||
def _resolve_sdk_model() -> str | None:
|
||||
@@ -527,6 +411,9 @@ async def stream_chat_completion_sdk(
|
||||
f"Session {session_id} not found. Please create a new session first."
|
||||
)
|
||||
|
||||
# Type narrowing: session is guaranteed ChatSession after the check above
|
||||
session = cast(ChatSession, session)
|
||||
|
||||
# Append the new message to the session if it's not already there
|
||||
new_message_role = "user" if is_user_message else "assistant"
|
||||
if message and (
|
||||
@@ -563,9 +450,31 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
system_prompt += _SDK_TOOL_SUPPLEMENT
|
||||
message_id = str(uuid.uuid4())
|
||||
task_id = str(uuid.uuid4())
|
||||
stream_id = str(uuid.uuid4())
|
||||
|
||||
yield StreamStart(messageId=message_id, taskId=task_id)
|
||||
# Acquire stream lock to prevent concurrent streams to the same session
|
||||
lock = AsyncClusterLock(
|
||||
redis=await get_redis_async(),
|
||||
key=f"{STREAM_LOCK_PREFIX}{session_id}",
|
||||
owner_id=stream_id,
|
||||
timeout=config.stream_lock_ttl,
|
||||
)
|
||||
|
||||
lock_owner = await lock.try_acquire()
|
||||
if lock_owner != stream_id:
|
||||
# Another stream is active
|
||||
logger.warning(
|
||||
f"[SDK] Session {session_id} already has an active stream: {lock_owner}"
|
||||
)
|
||||
yield StreamError(
|
||||
errorText="Another stream is already active for this session. "
|
||||
"Please wait or stop it.",
|
||||
code="stream_already_active",
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
stream_completed = False
|
||||
# Initialise variables before the try so the finally block can
|
||||
@@ -581,11 +490,7 @@ async def stream_chat_completion_sdk(
|
||||
sdk_cwd = _make_sdk_cwd(session_id)
|
||||
os.makedirs(sdk_cwd, exist_ok=True)
|
||||
|
||||
set_execution_context(
|
||||
user_id,
|
||||
session,
|
||||
long_running_callback=_build_long_running_callback(user_id),
|
||||
)
|
||||
set_execution_context(user_id, session)
|
||||
try:
|
||||
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
||||
|
||||
@@ -677,7 +582,6 @@ async def stream_chat_completion_sdk(
|
||||
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type]
|
||||
|
||||
adapter = SDKResponseAdapter(message_id=message_id, session_id=session_id)
|
||||
adapter.set_task_id(task_id)
|
||||
|
||||
async with ClaudeSDKClient(options=options) as client:
|
||||
current_message = message or ""
|
||||
@@ -702,8 +606,7 @@ async def stream_chat_completion_sdk(
|
||||
session_id,
|
||||
)
|
||||
logger.info(
|
||||
"[SDK] [%s] Sending query — resume=%s, "
|
||||
"total_msgs=%d, query_len=%d",
|
||||
"[SDK] [%s] Sending query — resume=%s, total_msgs=%d, query_len=%d",
|
||||
session_id[:12],
|
||||
use_resume,
|
||||
len(session.messages),
|
||||
@@ -715,9 +618,6 @@ async def stream_chat_completion_sdk(
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
has_appended_assistant = False
|
||||
has_tool_results = False
|
||||
# Track persisted message count to skip DB count queries
|
||||
# on incremental saves. Initial save happened at line 545.
|
||||
saved_msg_count = len(session.messages)
|
||||
|
||||
# Use an explicit async iterator with non-cancelling heartbeats.
|
||||
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
|
||||
@@ -744,6 +644,8 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
if not done:
|
||||
# Timeout — emit heartbeat but keep the task alive
|
||||
# Also refresh lock TTL to keep it alive
|
||||
await lock.refresh()
|
||||
yield StreamHeartbeat()
|
||||
continue
|
||||
|
||||
@@ -753,8 +655,7 @@ async def stream_chat_completion_sdk(
|
||||
sdk_msg = done.pop().result()
|
||||
except StopAsyncIteration:
|
||||
logger.info(
|
||||
"[SDK] [%s] Stream ended normally "
|
||||
"(StopAsyncIteration)",
|
||||
"[SDK] [%s] Stream ended normally (StopAsyncIteration)",
|
||||
session_id[:12],
|
||||
)
|
||||
break
|
||||
@@ -891,21 +792,6 @@ async def stream_chat_completion_sdk(
|
||||
if not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
# Save before tool execution starts so the
|
||||
# pending tool call is visible on refresh /
|
||||
# other devices.
|
||||
try:
|
||||
await upsert_chat_session(
|
||||
session,
|
||||
existing_message_count=saved_msg_count,
|
||||
)
|
||||
saved_msg_count = len(session.messages)
|
||||
except Exception as save_err:
|
||||
logger.warning(
|
||||
"[SDK] [%s] Incremental save " "failed: %s",
|
||||
session_id[:12],
|
||||
save_err,
|
||||
)
|
||||
|
||||
elif isinstance(response, StreamToolOutputAvailable):
|
||||
session.messages.append(
|
||||
@@ -920,20 +806,6 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
)
|
||||
has_tool_results = True
|
||||
# Save after tool completes so the result is
|
||||
# visible on refresh / other devices.
|
||||
try:
|
||||
await upsert_chat_session(
|
||||
session,
|
||||
existing_message_count=saved_msg_count,
|
||||
)
|
||||
saved_msg_count = len(session.messages)
|
||||
except Exception as save_err:
|
||||
logger.warning(
|
||||
"[SDK] [%s] Incremental save " "failed: %s",
|
||||
session_id[:12],
|
||||
save_err,
|
||||
)
|
||||
|
||||
elif isinstance(response, StreamFinish):
|
||||
stream_completed = True
|
||||
@@ -943,8 +815,7 @@ async def stream_chat_completion_sdk(
|
||||
# server shutdown). Log and let the safety-net / finally
|
||||
# blocks handle cleanup.
|
||||
logger.warning(
|
||||
"[SDK] [%s] Streaming loop cancelled "
|
||||
"(asyncio.CancelledError)",
|
||||
"[SDK] [%s] Streaming loop cancelled (asyncio.CancelledError)",
|
||||
session_id[:12],
|
||||
)
|
||||
raise
|
||||
@@ -1024,7 +895,7 @@ async def stream_chat_completion_sdk(
|
||||
elif captured_transcript.path:
|
||||
raw_transcript = read_transcript_file(captured_transcript.path)
|
||||
logger.debug(
|
||||
"[SDK] Transcript source: stop hook (%s), " "read result: %s",
|
||||
"[SDK] Transcript source: stop hook (%s), read result: %s",
|
||||
captured_transcript.path,
|
||||
f"{len(raw_transcript)}B" if raw_transcript else "None",
|
||||
)
|
||||
@@ -1059,7 +930,7 @@ async def stream_chat_completion_sdk(
|
||||
"to use the OpenAI-compatible fallback."
|
||||
)
|
||||
|
||||
await asyncio.shield(upsert_chat_session(session))
|
||||
session = cast(ChatSession, await asyncio.shield(upsert_chat_session(session)))
|
||||
logger.info(
|
||||
"[SDK] [%s] Session saved with %d messages",
|
||||
session_id[:12],
|
||||
@@ -1069,17 +940,31 @@ async def stream_chat_completion_sdk(
|
||||
yield StreamFinish()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Client disconnect / server shutdown — log but re-raise so
|
||||
# the framework can clean up. The finally block still runs
|
||||
# for transcript upload.
|
||||
# Client disconnect / server shutdown — save session before re-raising
|
||||
# so accumulated messages aren't lost.
|
||||
logger.warning("[SDK] [%s] Session cancelled (CancelledError)", session_id[:12])
|
||||
if session:
|
||||
try:
|
||||
await asyncio.shield(upsert_chat_session(session))
|
||||
logger.info(
|
||||
"[SDK] [%s] Session saved on cancel (%d messages)",
|
||||
session_id[:12],
|
||||
len(session.messages),
|
||||
)
|
||||
except Exception as save_err:
|
||||
logger.error(
|
||||
"[SDK] [%s] Failed to save session on cancel: %s",
|
||||
session_id[:12],
|
||||
save_err,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[SDK] Error: {e}", exc_info=True)
|
||||
try:
|
||||
await asyncio.shield(upsert_chat_session(session))
|
||||
except Exception as save_err:
|
||||
logger.error(f"[SDK] Failed to save session on error: {save_err}")
|
||||
if session:
|
||||
try:
|
||||
await asyncio.shield(upsert_chat_session(session))
|
||||
except Exception as save_err:
|
||||
logger.error(f"[SDK] Failed to save session on error: {save_err}")
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="sdk_error",
|
||||
@@ -1101,7 +986,7 @@ async def stream_chat_completion_sdk(
|
||||
if not raw_transcript and use_resume and resume_file:
|
||||
raw_transcript = read_transcript_file(resume_file)
|
||||
|
||||
if raw_transcript:
|
||||
if raw_transcript and session is not None:
|
||||
await asyncio.shield(
|
||||
_try_upload_transcript(
|
||||
user_id,
|
||||
@@ -1121,6 +1006,9 @@ async def stream_chat_completion_sdk(
|
||||
if sdk_cwd:
|
||||
_cleanup_sdk_tool_results(sdk_cwd)
|
||||
|
||||
# Release stream lock to allow new streams for this session
|
||||
await lock.release()
|
||||
|
||||
|
||||
async def _try_upload_transcript(
|
||||
user_id: str,
|
||||
|
||||
@@ -2,11 +2,6 @@
|
||||
|
||||
This module provides the adapter layer that converts existing BaseTool implementations
|
||||
into in-process MCP tools that can be used with the Claude Agent SDK.
|
||||
|
||||
Long-running tools (``is_long_running=True``) are delegated to the non-SDK
|
||||
background infrastructure (stream_registry, Redis persistence, SSE reconnection)
|
||||
via a callback provided by the service layer. This avoids wasteful SDK polling
|
||||
and makes results survive page refreshes.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -15,7 +10,6 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextvars import ContextVar
|
||||
from typing import Any
|
||||
|
||||
@@ -43,7 +37,8 @@ _current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
||||
# response adapter when it builds StreamToolOutputAvailable.
|
||||
_pending_tool_outputs: ContextVar[dict[str, list[str]]] = ContextVar(
|
||||
"pending_tool_outputs", default=None # type: ignore[arg-type]
|
||||
"pending_tool_outputs",
|
||||
default=None, # type: ignore[arg-type]
|
||||
)
|
||||
# Event signaled whenever stash_pending_tool_output() adds a new entry.
|
||||
# Used by the streaming loop to wait for PostToolUse hooks to complete
|
||||
@@ -54,22 +49,10 @@ _stash_event: ContextVar[asyncio.Event | None] = ContextVar(
|
||||
"_stash_event", default=None
|
||||
)
|
||||
|
||||
# Callback type for delegating long-running tools to the non-SDK infrastructure.
|
||||
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
|
||||
LongRunningCallback = Callable[
|
||||
[str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]]
|
||||
]
|
||||
|
||||
# ContextVar so the service layer can inject the callback per-request.
|
||||
_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar(
|
||||
"long_running_callback", default=None
|
||||
)
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
long_running_callback: LongRunningCallback | None = None,
|
||||
) -> None:
|
||||
"""Set the execution context for tool calls.
|
||||
|
||||
@@ -79,14 +62,11 @@ def set_execution_context(
|
||||
Args:
|
||||
user_id: Current user's ID.
|
||||
session: Current chat session.
|
||||
long_running_callback: Optional callback to delegate long-running tools
|
||||
to the non-SDK background infrastructure (stream_registry + Redis).
|
||||
"""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_pending_tool_outputs.set({})
|
||||
_stash_event.set(asyncio.Event())
|
||||
_long_running_callback.set(long_running_callback)
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
@@ -276,11 +256,6 @@ def create_tool_handler(base_tool: BaseTool):
|
||||
|
||||
This wraps the existing BaseTool._execute method to be compatible
|
||||
with the Claude Agent SDK MCP tool format.
|
||||
|
||||
Long-running tools (``is_long_running=True``) are delegated to the
|
||||
non-SDK background infrastructure via a callback set in the execution
|
||||
context. The callback persists the operation in Redis (stream_registry)
|
||||
so results survive page refreshes and pod restarts.
|
||||
"""
|
||||
|
||||
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -290,25 +265,6 @@ def create_tool_handler(base_tool: BaseTool):
|
||||
if session is None:
|
||||
return _mcp_error("No session context available")
|
||||
|
||||
# --- Long-running: delegate to non-SDK background infrastructure ---
|
||||
if base_tool.is_long_running:
|
||||
callback = _long_running_callback.get(None)
|
||||
if callback:
|
||||
try:
|
||||
return await callback(base_tool.name, args, session)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Long-running callback failed for {base_tool.name}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return _mcp_error(f"Failed to start {base_tool.name}: {e}")
|
||||
# No callback — fall through to synchronous execution
|
||||
logger.warning(
|
||||
f"[SDK] No long-running callback for {base_tool.name}, "
|
||||
f"executing synchronously (may block)"
|
||||
)
|
||||
|
||||
# --- Normal (fast) tool: execute synchronously ---
|
||||
try:
|
||||
return await _execute_tool_sync(base_tool, user_id, session, args)
|
||||
except Exception as e:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
420
autogpt_platform/backend/backend/copilot/test_copilot_e2e.py
Normal file
420
autogpt_platform/backend/backend/copilot/test_copilot_e2e.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""End-to-end tests for Copilot streaming with dummy implementations.
|
||||
|
||||
These tests verify the complete copilot flow using dummy implementations
|
||||
for agent generator and SDK service, allowing automated testing without
|
||||
external LLM calls.
|
||||
|
||||
Enable test mode with COPILOT_TEST_MODE=true environment variable.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession, upsert_chat_session
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
)
|
||||
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_test_mode():
|
||||
"""Enable test mode for all tests in this module."""
|
||||
os.environ["COPILOT_TEST_MODE"] = "true"
|
||||
yield
|
||||
os.environ.pop("COPILOT_TEST_MODE", None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dummy_streaming_basic_flow():
|
||||
"""Test that dummy streaming produces correct event sequence."""
|
||||
events = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-session-basic",
|
||||
message="Hello",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Verify we got events
|
||||
assert len(events) > 0, "Should receive events"
|
||||
|
||||
# Verify StreamStart
|
||||
start_events = [e for e in events if isinstance(e, StreamStart)]
|
||||
assert len(start_events) == 1
|
||||
assert start_events[0].messageId
|
||||
assert start_events[0].sessionId
|
||||
|
||||
# Verify StreamTextDelta events
|
||||
text_events = [e for e in events if isinstance(e, StreamTextDelta)]
|
||||
assert len(text_events) > 0
|
||||
full_text = "".join(e.delta for e in text_events)
|
||||
assert len(full_text) > 0
|
||||
|
||||
# Verify StreamFinish
|
||||
finish_events = [e for e in events if isinstance(e, StreamFinish)]
|
||||
assert len(finish_events) == 1
|
||||
|
||||
# Verify order: start before text before finish
|
||||
start_idx = events.index(start_events[0])
|
||||
finish_idx = events.index(finish_events[0])
|
||||
first_text_idx = events.index(text_events[0]) if text_events else -1
|
||||
if first_text_idx >= 0:
|
||||
assert start_idx < first_text_idx < finish_idx
|
||||
|
||||
print(f"✅ Basic flow: {len(events)} events, {len(text_events)} text deltas")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_no_timeout():
|
||||
"""Test that streaming completes within reasonable time without timeout."""
|
||||
import time
|
||||
|
||||
start_time = time.monotonic()
|
||||
event_count = 0
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-session-timeout",
|
||||
message="count to 10",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
event_count += 1
|
||||
|
||||
elapsed = time.monotonic() - start_time
|
||||
|
||||
# Should complete in < 5 seconds (dummy has 0.1s delays between words)
|
||||
assert elapsed < 5.0, f"Streaming took {elapsed:.1f}s, expected < 5s"
|
||||
assert event_count > 0, "Should receive events"
|
||||
|
||||
print(f"✅ No timeout: completed in {elapsed:.2f}s with {event_count} events")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_event_types():
|
||||
"""Test that all expected event types are present."""
|
||||
event_types = set()
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-session-types",
|
||||
message="test",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
event_types.add(type(event).__name__)
|
||||
|
||||
# Required event types
|
||||
assert "StreamStart" in event_types, "Missing StreamStart"
|
||||
assert "StreamTextDelta" in event_types, "Missing StreamTextDelta"
|
||||
assert "StreamFinish" in event_types, "Missing StreamFinish"
|
||||
|
||||
print(f"✅ Event types: {sorted(event_types)}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_text_content():
|
||||
"""Test that streamed text is coherent and complete."""
|
||||
text_events = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-session-content",
|
||||
message="count to 3",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
if isinstance(event, StreamTextDelta):
|
||||
text_events.append(event)
|
||||
|
||||
# Verify text deltas
|
||||
assert len(text_events) > 0, "Should have text deltas"
|
||||
|
||||
# Reconstruct full text
|
||||
full_text = "".join(e.delta for e in text_events)
|
||||
assert len(full_text) > 0, "Text should not be empty"
|
||||
assert (
|
||||
"1" in full_text or "counted" in full_text.lower()
|
||||
), "Text should contain count"
|
||||
|
||||
# Verify all deltas have IDs
|
||||
for text_event in text_events:
|
||||
assert text_event.id, "Text delta must have ID"
|
||||
assert text_event.delta, "Text delta must have content"
|
||||
|
||||
print(f"✅ Text content: '{full_text}' ({len(text_events)} deltas)")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_heartbeat_timing():
|
||||
"""Test that heartbeats are sent at correct interval during long operations."""
|
||||
# This test would need a dummy that takes longer
|
||||
# For now, just verify heartbeat structure if we receive one
|
||||
heartbeats = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-session-heartbeat",
|
||||
message="test",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
if isinstance(event, StreamHeartbeat):
|
||||
heartbeats.append(event)
|
||||
|
||||
# Dummy is fast, so we might not get heartbeats
|
||||
# But if we do, verify they're valid
|
||||
if heartbeats:
|
||||
print(f"✅ Heartbeat structure verified ({len(heartbeats)} received)")
|
||||
else:
|
||||
print("✅ No heartbeats (dummy executes quickly)")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling():
|
||||
"""Test that errors are properly formatted and sent."""
|
||||
# This would require a dummy that can trigger errors
|
||||
# For now, just verify error event structure
|
||||
|
||||
error = StreamError(errorText="Test error", code="test_error")
|
||||
assert error.errorText == "Test error"
|
||||
assert error.code == "test_error"
|
||||
assert str(error.type.value) in ["error", "error"]
|
||||
|
||||
print("✅ Error structure verified")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_sessions():
|
||||
"""Test that multiple sessions can stream concurrently."""
|
||||
|
||||
async def stream_session(session_id: str) -> int:
|
||||
count = 0
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id=session_id,
|
||||
message="test",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
# Run 3 concurrent sessions
|
||||
results = await asyncio.gather(
|
||||
stream_session("session-1"),
|
||||
stream_session("session-2"),
|
||||
stream_session("session-3"),
|
||||
)
|
||||
|
||||
# All should complete successfully
|
||||
assert all(count > 0 for count in results), "All sessions should produce events"
|
||||
print(f"✅ Concurrent sessions: {results} events each")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.xfail(
|
||||
reason="Event loop isolation issue with DB operations in tests - needs fixture refactoring"
|
||||
)
|
||||
async def test_session_state_persistence():
|
||||
"""Test that session state is maintained across multiple messages."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
session_id = f"test-session-{uuid4()}"
|
||||
user_id = "test-user"
|
||||
|
||||
# Create session with first message
|
||||
session = ChatSession(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
messages=[
|
||||
ChatMessage(role="user", content="Hello"),
|
||||
ChatMessage(role="assistant", content="Hi there!"),
|
||||
],
|
||||
usage=[],
|
||||
started_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
# Stream second message
|
||||
events = []
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id=session_id,
|
||||
message="How are you?",
|
||||
is_user_message=True,
|
||||
user_id=user_id,
|
||||
session=session, # Pass existing session
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Verify events were produced
|
||||
assert len(events) > 0, "Should produce events for second message"
|
||||
|
||||
# Verify we got a complete response
|
||||
finish_events = [e for e in events if isinstance(e, StreamFinish)]
|
||||
assert len(finish_events) == 1, "Should have StreamFinish"
|
||||
|
||||
print(f"✅ Session persistence: {len(events)} events for second message")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_deduplication():
|
||||
"""Test that duplicate messages are filtered out."""
|
||||
|
||||
# Simulate receiving duplicate events (e.g., from reconnection)
|
||||
events = []
|
||||
|
||||
# First stream
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-dedup-1",
|
||||
message="Hello",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
events.append(event)
|
||||
if isinstance(event, StreamFinish):
|
||||
break
|
||||
|
||||
# Count unique message IDs in StreamStart events
|
||||
start_events = [e for e in events if isinstance(e, StreamStart)]
|
||||
message_ids = [e.messageId for e in start_events]
|
||||
|
||||
# Verify all IDs are present
|
||||
assert len(message_ids) == len(set(message_ids)), "Message IDs should be unique"
|
||||
|
||||
print(f"✅ Deduplication: {len(events)} events, all unique")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_ordering():
|
||||
"""Test that events arrive in correct order."""
|
||||
events = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-ordering",
|
||||
message="Test",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Find event indices
|
||||
start_idx = next(
|
||||
(i for i, e in enumerate(events) if isinstance(e, StreamStart)), None
|
||||
)
|
||||
text_indices = [i for i, e in enumerate(events) if isinstance(e, StreamTextDelta)]
|
||||
finish_idx = next(
|
||||
(i for i, e in enumerate(events) if isinstance(e, StreamFinish)), None
|
||||
)
|
||||
|
||||
# Verify ordering
|
||||
assert start_idx is not None, "Should have StreamStart"
|
||||
assert finish_idx is not None, "Should have StreamFinish"
|
||||
assert start_idx == 0, "StreamStart should be first"
|
||||
assert finish_idx == len(events) - 1, "StreamFinish should be last"
|
||||
|
||||
if text_indices:
|
||||
assert all(
|
||||
start_idx < i < finish_idx for i in text_indices
|
||||
), "Text deltas should be between start and finish"
|
||||
|
||||
print(f"✅ Event ordering: start({start_idx}) < text < finish({finish_idx})")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_completeness():
|
||||
"""Test that stream includes all required event types."""
|
||||
events = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-completeness",
|
||||
message="Complete stream test",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Check for required events
|
||||
has_start = any(isinstance(e, StreamStart) for e in events)
|
||||
has_text = any(isinstance(e, StreamTextDelta) for e in events)
|
||||
has_finish = any(isinstance(e, StreamFinish) for e in events)
|
||||
|
||||
assert has_start, "Stream must include StreamStart"
|
||||
assert has_text, "Stream must include text deltas"
|
||||
assert has_finish, "Stream must include StreamFinish"
|
||||
|
||||
# Verify exactly one start and one finish
|
||||
start_count = sum(1 for e in events if isinstance(e, StreamStart))
|
||||
finish_count = sum(1 for e in events if isinstance(e, StreamFinish))
|
||||
|
||||
assert start_count == 1, f"Should have exactly 1 StreamStart, got {start_count}"
|
||||
assert finish_count == 1, f"Should have exactly 1 StreamFinish, got {finish_count}"
|
||||
|
||||
print(
|
||||
f"✅ Completeness: 1 start, {sum(1 for e in events if isinstance(e, StreamTextDelta))} text, 1 finish"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_delta_consistency():
|
||||
"""Test that text deltas have consistent IDs and build coherent text."""
|
||||
text_events = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-consistency",
|
||||
message="Test consistency",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
if isinstance(event, StreamTextDelta):
|
||||
text_events.append(event)
|
||||
|
||||
# Verify all text deltas have IDs
|
||||
assert all(e.id for e in text_events), "All text deltas must have IDs"
|
||||
|
||||
# Verify all deltas have the same ID (same text block)
|
||||
if text_events:
|
||||
first_id = text_events[0].id
|
||||
assert all(
|
||||
e.id == first_id for e in text_events
|
||||
), "All text deltas should share the same block ID"
|
||||
|
||||
# Verify deltas build coherent text
|
||||
full_text = "".join(e.delta for e in text_events)
|
||||
assert len(full_text) > 0, "Deltas should build non-empty text"
|
||||
assert (
|
||||
full_text == full_text.strip()
|
||||
), "Text should not have leading/trailing whitespace artifacts"
|
||||
|
||||
print(
|
||||
f"✅ Consistency: {len(text_events)} deltas with ID '{text_events[0].id if text_events else 'N/A'}', text: '{full_text}'"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests directly
|
||||
|
||||
print("Running Copilot E2E tests with dummy implementations...")
|
||||
print("=" * 60)
|
||||
|
||||
asyncio.run(test_dummy_streaming_basic_flow())
|
||||
asyncio.run(test_streaming_no_timeout())
|
||||
asyncio.run(test_streaming_event_types())
|
||||
asyncio.run(test_streaming_text_content())
|
||||
asyncio.run(test_streaming_heartbeat_timing())
|
||||
asyncio.run(test_error_handling())
|
||||
asyncio.run(test_concurrent_sessions())
|
||||
asyncio.run(test_session_state_persistence())
|
||||
asyncio.run(test_message_deduplication())
|
||||
asyncio.run(test_event_ordering())
|
||||
asyncio.run(test_stream_completeness())
|
||||
asyncio.run(test_text_delta_consistency())
|
||||
|
||||
print("=" * 60)
|
||||
print("✅ All E2E tests passed!")
|
||||
@@ -10,7 +10,6 @@ from .add_understanding import AddUnderstandingTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .check_operation_status import CheckOperationStatusTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
from .edit_agent import EditAgentTool
|
||||
@@ -47,7 +46,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"run_agent": RunAgentTool(),
|
||||
"run_block": RunBlockTool(),
|
||||
"view_agent_output": AgentOutputTool(),
|
||||
"check_operation_status": CheckOperationStatusTool(),
|
||||
"search_docs": SearchDocsTool(),
|
||||
"get_doc_page": GetDocPageTool(),
|
||||
# Web fetch for safe URL retrieval
|
||||
|
||||
@@ -19,6 +19,7 @@ from .core import (
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_library_agent_by_graph_id,
|
||||
get_library_agent_by_id,
|
||||
get_library_agents_by_ids,
|
||||
get_library_agents_for_generation,
|
||||
graph_to_json,
|
||||
json_to_graph,
|
||||
@@ -49,6 +50,7 @@ __all__ = [
|
||||
"get_all_relevant_agents_for_generation",
|
||||
"get_library_agent_by_graph_id",
|
||||
"get_library_agent_by_id",
|
||||
"get_library_agents_by_ids",
|
||||
"get_library_agents_for_generation",
|
||||
"get_user_message_for_error",
|
||||
"graph_to_json",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from backend.data.db_accessors import graph_db, library_db, store_db
|
||||
@@ -78,7 +79,7 @@ AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
|
||||
|
||||
|
||||
def _to_dict_list(
|
||||
agents: list[AgentSummary] | list[dict[str, Any]] | None,
|
||||
agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Convert typed agent summaries to plain dicts for external service calls."""
|
||||
if agents is None:
|
||||
@@ -190,6 +191,36 @@ async def get_library_agent_by_id(
|
||||
get_library_agent_by_graph_id = get_library_agent_by_id
|
||||
|
||||
|
||||
async def get_library_agents_by_ids(
|
||||
user_id: str,
|
||||
agent_ids: list[str],
|
||||
) -> list[LibraryAgentSummary]:
|
||||
"""Fetch multiple library agents by their IDs.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
agent_ids: List of agent IDs (can be graph_ids or library agent IDs)
|
||||
|
||||
Returns:
|
||||
List of LibraryAgentSummary for found agents (silently skips not found)
|
||||
"""
|
||||
agents: list[LibraryAgentSummary] = []
|
||||
for agent_id in agent_ids:
|
||||
try:
|
||||
agent = await get_library_agent_by_id(user_id, agent_id)
|
||||
if agent:
|
||||
agents.append(agent)
|
||||
logger.debug(f"Fetched library agent by ID: {agent['name']}")
|
||||
else:
|
||||
logger.warning(f"Library agent not found for ID: {agent_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agent {agent_id}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Fetched {len(agents)}/{len(agent_ids)} library agents by ID")
|
||||
return agents
|
||||
|
||||
|
||||
async def get_library_agents_for_generation(
|
||||
user_id: str,
|
||||
search_query: str | None = None,
|
||||
@@ -214,10 +245,17 @@ async def get_library_agents_for_generation(
|
||||
Returns:
|
||||
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
|
||||
"""
|
||||
search_term = search_query.strip() if search_query else None
|
||||
if search_term and len(search_term) > 100:
|
||||
raise ValueError(
|
||||
f"Search query is too long ({len(search_term)} chars, max 100). "
|
||||
f"Please use a shorter, more specific search term."
|
||||
)
|
||||
|
||||
try:
|
||||
response = await library_db().list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=search_query,
|
||||
search_term=search_term,
|
||||
page=1,
|
||||
page_size=max_results,
|
||||
include_executions=True,
|
||||
@@ -271,9 +309,16 @@ async def search_marketplace_agents_for_generation(
|
||||
Returns:
|
||||
List of LibraryAgentSummary with full input/output schemas
|
||||
"""
|
||||
search_term = search_query.strip()
|
||||
if len(search_term) > 100:
|
||||
raise ValueError(
|
||||
f"Search query is too long ({len(search_term)} chars, max 100). "
|
||||
f"Please use a shorter, more specific search term."
|
||||
)
|
||||
|
||||
try:
|
||||
response = await store_db().get_store_agents(
|
||||
search_query=search_query,
|
||||
search_query=search_term,
|
||||
page=1,
|
||||
page_size=max_results,
|
||||
)
|
||||
@@ -424,7 +469,7 @@ def extract_search_terms_from_steps(
|
||||
async def enrich_library_agents_from_steps(
|
||||
user_id: str,
|
||||
decomposition_result: DecompositionResult | dict[str, Any],
|
||||
existing_agents: list[AgentSummary] | list[dict[str, Any]],
|
||||
existing_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]],
|
||||
exclude_graph_id: str | None = None,
|
||||
include_marketplace: bool = True,
|
||||
max_additional_results: int = 10,
|
||||
@@ -448,7 +493,7 @@ async def enrich_library_agents_from_steps(
|
||||
search_terms = extract_search_terms_from_steps(decomposition_result)
|
||||
|
||||
if not search_terms:
|
||||
return existing_agents
|
||||
return list(existing_agents)
|
||||
|
||||
existing_ids: set[str] = set()
|
||||
existing_names: set[str] = set()
|
||||
@@ -511,7 +556,7 @@ async def enrich_library_agents_from_steps(
|
||||
async def decompose_goal(
|
||||
description: str,
|
||||
context: str = "",
|
||||
library_agents: list[AgentSummary] | None = None,
|
||||
library_agents: Sequence[AgentSummary] | None = None,
|
||||
) -> DecompositionResult | None:
|
||||
"""Break down a goal into steps or return clarifying questions.
|
||||
|
||||
@@ -539,22 +584,16 @@ async def decompose_goal(
|
||||
|
||||
async def generate_agent(
|
||||
instructions: DecompositionResult | dict[str, Any],
|
||||
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
library_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Generate agent JSON from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams
|
||||
completion notification)
|
||||
task_id: Task ID for async processing (enables Redis Streams persistence
|
||||
and SSE delivery)
|
||||
|
||||
Returns:
|
||||
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||
Agent JSON dict, error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
@@ -562,13 +601,9 @@ async def generate_agent(
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent")
|
||||
result = await generate_agent_external(
|
||||
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
|
||||
dict(instructions), _to_dict_list(library_agents)
|
||||
)
|
||||
|
||||
# Don't modify async response
|
||||
if result and result.get("status") == "accepted":
|
||||
return result
|
||||
|
||||
if result:
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
return result
|
||||
@@ -758,9 +793,7 @@ async def get_agent_as_json(
|
||||
async def generate_agent_patch(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[AgentSummary] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
library_agents: Sequence[AgentSummary] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Update an existing agent using natural language.
|
||||
|
||||
@@ -773,12 +806,10 @@ async def generate_agent_patch(
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||
error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
@@ -789,8 +820,6 @@ async def generate_agent_patch(
|
||||
update_request,
|
||||
current_agent,
|
||||
_to_dict_list(library_agents),
|
||||
operation_id,
|
||||
task_id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -102,10 +102,15 @@ async def generate_agent_dummy(
|
||||
instructions: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return dummy agent JSON after a simulated delay."""
|
||||
logger.info("Using dummy agent generator for generate_agent (30s delay)")
|
||||
"""Return dummy agent synchronously (blocks for 30s, returns agent JSON).
|
||||
|
||||
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
|
||||
"""
|
||||
logger.info(
|
||||
"Using dummy agent generator (sync mode): returning agent JSON after 30s"
|
||||
)
|
||||
await asyncio.sleep(30)
|
||||
return _generate_dummy_agent_json()
|
||||
|
||||
@@ -115,10 +120,16 @@ async def generate_agent_patch_dummy(
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return dummy patched agent (returns the current agent with updated description)."""
|
||||
logger.info("Using dummy agent generator for generate_agent_patch")
|
||||
"""Return dummy patched agent synchronously (blocks for 30s, returns patched agent JSON).
|
||||
|
||||
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
|
||||
"""
|
||||
logger.info(
|
||||
"Using dummy agent generator patch (sync mode): returning patched agent after 30s"
|
||||
)
|
||||
await asyncio.sleep(30)
|
||||
patched = current_agent.copy()
|
||||
patched["description"] = (
|
||||
f"{current_agent.get('description', '')} (updated: {update_request})"
|
||||
|
||||
@@ -242,24 +242,18 @@ async def decompose_goal_external(
|
||||
async def generate_agent_external(
|
||||
instructions: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate an agent from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
|
||||
Agent JSON dict or error dict {"type": "error", ...} on error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await generate_agent_dummy(
|
||||
instructions, library_agents, operation_id, task_id
|
||||
)
|
||||
return await generate_agent_dummy(instructions, library_agents)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
@@ -267,25 +261,9 @@ async def generate_agent_external(
|
||||
payload: dict[str, Any] = {"instructions": instructions}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
if operation_id and task_id:
|
||||
payload["operation_id"] = operation_id
|
||||
payload["task_id"] = task_id
|
||||
|
||||
try:
|
||||
response = await client.post("/api/generate-agent", json=payload)
|
||||
|
||||
# Handle 202 Accepted for async processing
|
||||
if response.status_code == 202:
|
||||
logger.info(
|
||||
f"Agent Generator accepted async request "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return {
|
||||
"status": "accepted",
|
||||
"operation_id": operation_id,
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
@@ -317,8 +295,6 @@ async def generate_agent_patch_external(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate a patch for an existing agent.
|
||||
|
||||
@@ -327,14 +303,14 @@ async def generate_agent_patch_external(
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||
session_id: Session ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await generate_agent_patch_dummy(
|
||||
update_request, current_agent, library_agents, operation_id, task_id
|
||||
update_request, current_agent, library_agents
|
||||
)
|
||||
|
||||
client = _get_client()
|
||||
@@ -346,25 +322,9 @@ async def generate_agent_patch_external(
|
||||
}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
if operation_id and task_id:
|
||||
payload["operation_id"] = operation_id
|
||||
payload["task_id"] = task_id
|
||||
|
||||
try:
|
||||
response = await client.post("/api/update-agent", json=payload)
|
||||
|
||||
# Handle 202 Accepted for async processing
|
||||
if response.status_code == 202:
|
||||
logger.info(
|
||||
f"Agent Generator accepted async update request "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return {
|
||||
"status": "accepted",
|
||||
"operation_id": operation_id,
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
@@ -419,6 +379,8 @@ async def customize_template_external(
|
||||
template_agent: The template agent JSON to customize
|
||||
modification_request: Natural language description of customizations
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
session_id: Session ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Customized agent JSON, clarifying questions dict, or error dict on error
|
||||
|
||||
@@ -36,16 +36,6 @@ class BaseTool:
|
||||
"""Whether this tool requires authentication."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
"""Whether this tool is long-running and should execute in background.
|
||||
|
||||
Long-running tools (like agent generation) are executed via background
|
||||
tasks to survive SSE disconnections. The result is persisted to chat
|
||||
history and visible when the user refreshes.
|
||||
"""
|
||||
return False
|
||||
|
||||
def as_openai_tool(self) -> ChatCompletionToolParam:
|
||||
"""Convert to OpenAI tool format."""
|
||||
return ChatCompletionToolParam(
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
"""CheckOperationStatusTool — query the status of a long-running operation."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OperationStatusResponse(ToolResponseBase):
|
||||
"""Response for check_operation_status tool."""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STATUS
|
||||
task_id: str
|
||||
operation_id: str
|
||||
status: str # "running", "completed", "failed"
|
||||
tool_name: str | None = None
|
||||
message: str = ""
|
||||
|
||||
|
||||
class CheckOperationStatusTool(BaseTool):
|
||||
"""Check the status of a long-running operation (create_agent, edit_agent, etc.).
|
||||
|
||||
The CoPilot uses this tool to report back to the user whether an
|
||||
operation that was started earlier has completed, failed, or is still
|
||||
running.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "check_operation_status"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Check the current status of a long-running operation such as "
|
||||
"create_agent or edit_agent. Accepts either an operation_id or "
|
||||
"task_id from a previous operation_started response. "
|
||||
"Returns the current status: running, completed, or failed."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"operation_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The operation_id from an operation_started response."
|
||||
),
|
||||
},
|
||||
"task_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The task_id from an operation_started response. "
|
||||
"Used as fallback if operation_id is not provided."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
from backend.copilot import stream_registry
|
||||
|
||||
operation_id = (kwargs.get("operation_id") or "").strip()
|
||||
task_id = (kwargs.get("task_id") or "").strip()
|
||||
|
||||
if not operation_id and not task_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide an operation_id or task_id.",
|
||||
error="missing_parameter",
|
||||
)
|
||||
|
||||
task = None
|
||||
if operation_id:
|
||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
||||
if task is None and task_id:
|
||||
task = await stream_registry.get_task(task_id)
|
||||
|
||||
if task is None:
|
||||
# Task not in Redis — it may have already expired (TTL).
|
||||
# Check conversation history for the result instead.
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Operation not found — it may have already completed and "
|
||||
"expired from the status tracker. Check the conversation "
|
||||
"history for the result."
|
||||
),
|
||||
error="not_found",
|
||||
)
|
||||
|
||||
status_messages = {
|
||||
"running": (
|
||||
f"The {task.tool_name or 'operation'} is still running. "
|
||||
"Please wait for it to complete."
|
||||
),
|
||||
"completed": (
|
||||
f"The {task.tool_name or 'operation'} has completed successfully."
|
||||
),
|
||||
"failed": f"The {task.tool_name or 'operation'} has failed.",
|
||||
}
|
||||
|
||||
return OperationStatusResponse(
|
||||
task_id=task.task_id,
|
||||
operation_id=task.operation_id,
|
||||
status=task.status,
|
||||
tool_name=task.tool_name,
|
||||
message=status_messages.get(task.status, f"Status: {task.status}"),
|
||||
)
|
||||
@@ -10,7 +10,6 @@ from .agent_generator import (
|
||||
decompose_goal,
|
||||
enrich_library_agents_from_steps,
|
||||
generate_agent,
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_user_message_for_error,
|
||||
save_agent_to_library,
|
||||
)
|
||||
@@ -18,7 +17,6 @@ from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AsyncProcessingResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
@@ -40,17 +38,16 @@ class CreateAgentTool(BaseTool):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Create a new agent workflow from a natural language description. "
|
||||
"First generates a preview, then saves to library if save=true."
|
||||
"First generates a preview, then saves to library if save=true. "
|
||||
"\n\nIMPORTANT: Before calling this tool, search for relevant existing agents "
|
||||
"using find_library_agent that could be used as building blocks. "
|
||||
"Pass their IDs in the library_agent_ids parameter so the generator can compose them."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
@@ -70,6 +67,15 @@ class CreateAgentTool(BaseTool):
|
||||
"Include any preferences or constraints mentioned by the user."
|
||||
),
|
||||
},
|
||||
"library_agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"List of library agent IDs to use as building blocks. "
|
||||
"Search for relevant agents using find_library_agent first, "
|
||||
"then pass their IDs here so they can be composed into the new agent."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
@@ -97,12 +103,14 @@ class CreateAgentTool(BaseTool):
|
||||
"""
|
||||
description = kwargs.get("description", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] START - description_len={len(description)}, "
|
||||
f"library_agent_ids={library_agent_ids}, save={save}, user_id={user_id}, session_id={session_id}"
|
||||
)
|
||||
|
||||
if not description:
|
||||
return ErrorResponse(
|
||||
@@ -111,25 +119,34 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Fetch library agents by IDs if provided
|
||||
library_agents = None
|
||||
if user_id:
|
||||
if user_id and library_agent_ids:
|
||||
try:
|
||||
library_agents = await get_all_relevant_agents_for_generation(
|
||||
from .agent_generator import get_library_agents_by_ids
|
||||
|
||||
library_agents = await get_library_agents_by_ids(
|
||||
user_id=user_id,
|
||||
search_query=description,
|
||||
include_marketplace=True,
|
||||
agent_ids=library_agent_ids,
|
||||
)
|
||||
logger.debug(
|
||||
f"Found {len(library_agents)} relevant agents for sub-agent composition"
|
||||
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents: {e}")
|
||||
logger.warning(f"Failed to fetch library agents by IDs: {e}")
|
||||
|
||||
try:
|
||||
decomposition_result = await decompose_goal(
|
||||
description, context, library_agents
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] DECOMPOSE - type={decomposition_result.get('type') if decomposition_result else None}, "
|
||||
f"session_id={session_id}"
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
logger.error(
|
||||
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured, session_id={session_id}"
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent generation is not available. "
|
||||
@@ -230,10 +247,17 @@ class CreateAgentTool(BaseTool):
|
||||
agent_json = await generate_agent(
|
||||
decomposition_result,
|
||||
library_agents,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] GENERATE - "
|
||||
f"success={agent_json is not None}, "
|
||||
f"is_error={isinstance(agent_json, dict) and agent_json.get('type') == 'error'}, "
|
||||
f"session_id={session_id}"
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
logger.error(
|
||||
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured during generation, session_id={session_id}"
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent generation is not available. "
|
||||
@@ -276,25 +300,20 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if Agent Generator accepted for async processing
|
||||
if agent_json.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Agent generation delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return AsyncProcessingResponse(
|
||||
message="Agent generation started. You'll be notified when it's complete.",
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", "Generated Agent")
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
link_count = len(agent_json.get("links", []))
|
||||
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] AGENT_JSON - name={agent_name}, "
|
||||
f"nodes={node_count}, links={link_count}, save={save}, session_id={session_id}"
|
||||
)
|
||||
|
||||
if not save:
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] RETURN - AgentPreviewResponse, session_id={session_id}"
|
||||
)
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
|
||||
@@ -320,6 +339,13 @@ class CreateAgentTool(BaseTool):
|
||||
agent_json, user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] SAVED - graph_id={created_graph.id}, "
|
||||
f"library_agent_id={library_agent.id}, session_id={session_id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] RETURN - AgentSavedResponse, session_id={session_id}"
|
||||
)
|
||||
return AgentSavedResponse(
|
||||
message=f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
agent_id=created_graph.id,
|
||||
@@ -330,6 +356,12 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[AGENT_CREATE_DEBUG] ERROR - save_failed: {str(e)}, session_id={session_id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] RETURN - ErrorResponse (save_failed), session_id={session_id}"
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the agent: {str(e)}",
|
||||
error="save_failed",
|
||||
|
||||
@@ -43,11 +43,6 @@ async def test_vague_goal_returns_suggested_goal_response(tool, session):
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
@@ -78,11 +73,6 @@ async def test_unachievable_goal_returns_suggested_goal_response(tool, session):
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
@@ -120,11 +110,6 @@ async def test_clarifying_questions_returns_clarification_needed_response(
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
|
||||
@@ -46,10 +46,6 @@ class CustomizeAgentTool(BaseTool):
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
||||
@@ -9,7 +9,6 @@ from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_user_message_for_error,
|
||||
save_agent_to_library,
|
||||
)
|
||||
@@ -17,7 +16,6 @@ from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AsyncProcessingResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
@@ -38,17 +36,16 @@ class EditAgentTool(BaseTool):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit an existing agent from the user's library using natural language. "
|
||||
"Generates updates to the agent while preserving unchanged parts."
|
||||
"Generates updates to the agent while preserving unchanged parts. "
|
||||
"\n\nIMPORTANT: Before calling this tool, if the changes involve adding new "
|
||||
"functionality, search for relevant existing agents using find_library_agent "
|
||||
"that could be used as building blocks. Pass their IDs in library_agent_ids."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
@@ -74,6 +71,15 @@ class EditAgentTool(BaseTool):
|
||||
"Additional context or answers to previous clarifying questions."
|
||||
),
|
||||
},
|
||||
"library_agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"List of library agent IDs to use as building blocks for the changes. "
|
||||
"If adding new functionality, search for relevant agents using "
|
||||
"find_library_agent first, then pass their IDs here."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
@@ -102,13 +108,10 @@ class EditAgentTool(BaseTool):
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
changes = kwargs.get("changes", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the agent ID to edit.",
|
||||
@@ -132,21 +135,25 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Fetch library agents by IDs if provided
|
||||
library_agents = None
|
||||
if user_id:
|
||||
if user_id and library_agent_ids:
|
||||
try:
|
||||
from .agent_generator import get_library_agents_by_ids
|
||||
|
||||
graph_id = current_agent.get("id")
|
||||
library_agents = await get_all_relevant_agents_for_generation(
|
||||
# Filter out the current agent being edited
|
||||
filtered_ids = [id for id in library_agent_ids if id != graph_id]
|
||||
|
||||
library_agents = await get_library_agents_by_ids(
|
||||
user_id=user_id,
|
||||
search_query=changes,
|
||||
exclude_graph_id=graph_id,
|
||||
include_marketplace=True,
|
||||
agent_ids=filtered_ids,
|
||||
)
|
||||
logger.debug(
|
||||
f"Found {len(library_agents)} relevant agents for sub-agent composition"
|
||||
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents: {e}")
|
||||
logger.warning(f"Failed to fetch library agents by IDs: {e}")
|
||||
|
||||
update_request = changes
|
||||
if context:
|
||||
@@ -157,8 +164,6 @@ class EditAgentTool(BaseTool):
|
||||
update_request,
|
||||
current_agent,
|
||||
library_agents,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
@@ -178,19 +183,6 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if Agent Generator accepted for async processing
|
||||
if result.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Agent edit delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return AsyncProcessingResponse(
|
||||
message="Agent edit started. You'll be notified when it's complete.",
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
|
||||
@@ -366,12 +366,15 @@ class TestFindBlockFiltering:
|
||||
return_value=(search_results, len(search_results))
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.find_block.search",
|
||||
return_value=mock_search_db,
|
||||
), patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
side_effect=lambda bid: mock_blocks.get(bid),
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.search",
|
||||
return_value=mock_search_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
side_effect=lambda bid: mock_blocks.get(bid),
|
||||
),
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
|
||||
@@ -36,8 +36,6 @@ class ResponseType(str, Enum):
|
||||
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
|
||||
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
|
||||
# Long-running operation types
|
||||
OPERATION_STARTED = "operation_started"
|
||||
OPERATION_PENDING = "operation_pending"
|
||||
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||
# Input validation
|
||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||
@@ -45,8 +43,6 @@ class ResponseType(str, Enum):
|
||||
WEB_FETCH = "web_fetch"
|
||||
# Code execution
|
||||
BASH_EXEC = "bash_exec"
|
||||
# Operation status check
|
||||
OPERATION_STATUS = "operation_status"
|
||||
# Feature request types
|
||||
FEATURE_REQUEST_SEARCH = "feature_request_search"
|
||||
FEATURE_REQUEST_CREATED = "feature_request_created"
|
||||
@@ -420,34 +416,6 @@ class BlockOutputResponse(ToolResponseBase):
|
||||
|
||||
|
||||
# Long-running operation models
|
||||
class OperationStartedResponse(ToolResponseBase):
|
||||
"""Response when a long-running operation has been started in the background.
|
||||
|
||||
This is returned immediately to the client while the operation continues
|
||||
to execute. The user can close the tab and check back later.
|
||||
|
||||
The task_id can be used to reconnect to the SSE stream via
|
||||
GET /chat/tasks/{task_id}/stream?last_idx=0
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
operation_id: str
|
||||
tool_name: str
|
||||
task_id: str | None = None # For SSE reconnection
|
||||
|
||||
|
||||
class OperationPendingResponse(ToolResponseBase):
|
||||
"""Response stored in chat history while a long-running operation is executing.
|
||||
|
||||
This is persisted to the database so users see a pending state when they
|
||||
refresh before the operation completes.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_PENDING
|
||||
operation_id: str
|
||||
tool_name: str
|
||||
|
||||
|
||||
class OperationInProgressResponse(ToolResponseBase):
|
||||
"""Response when an operation is already in progress.
|
||||
|
||||
@@ -459,23 +427,6 @@ class OperationInProgressResponse(ToolResponseBase):
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class AsyncProcessingResponse(ToolResponseBase):
|
||||
"""Response when an operation has been delegated to async processing.
|
||||
|
||||
This is returned by tools when the external service accepts the request
|
||||
for async processing (HTTP 202 Accepted). The Redis Streams completion
|
||||
consumer will handle the result when the external service completes.
|
||||
|
||||
The status field is specifically "accepted" to allow the long-running tool
|
||||
handler to detect this response and skip LLM continuation.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
status: str = "accepted" # Must be "accepted" for detection
|
||||
operation_id: str | None = None
|
||||
task_id: str | None = None
|
||||
|
||||
|
||||
class WebFetchResponse(ToolResponseBase):
|
||||
"""Response for web_fetch tool."""
|
||||
|
||||
|
||||
@@ -160,9 +160,10 @@ class RunBlockTool(BaseTool):
|
||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
matched_credentials, missing_credentials = (
|
||||
await self._resolve_block_credentials(user_id, block, input_data)
|
||||
)
|
||||
(
|
||||
matched_credentials,
|
||||
missing_credentials,
|
||||
) = await self._resolve_block_credentials(user_id, block, input_data)
|
||||
|
||||
# Get block schemas for details/validation
|
||||
try:
|
||||
|
||||
@@ -303,7 +303,7 @@ class DatabaseManager(AppService):
|
||||
get_user_chat_sessions = _(chat_db.get_user_chat_sessions)
|
||||
get_user_session_count = _(chat_db.get_user_session_count)
|
||||
delete_chat_session = _(chat_db.delete_chat_session)
|
||||
get_chat_session_message_count = _(chat_db.get_chat_session_message_count)
|
||||
get_next_sequence = _(chat_db.get_next_sequence)
|
||||
update_tool_message_content = _(chat_db.update_tool_message_content)
|
||||
|
||||
|
||||
@@ -473,5 +473,5 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
get_user_chat_sessions = d.get_user_chat_sessions
|
||||
get_user_session_count = d.get_user_session_count
|
||||
delete_chat_session = d.delete_chat_session
|
||||
get_chat_session_message_count = d.get_chat_session_message_count
|
||||
get_next_sequence = d.get_next_sequence
|
||||
update_tool_message_content = d.update_tool_message_content
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Redis-based distributed locking for cluster coordination."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
@@ -7,6 +8,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis import Redis
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -126,3 +128,124 @@ class ClusterLock:
|
||||
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = 0.0
|
||||
|
||||
|
||||
class AsyncClusterLock:
|
||||
"""Async Redis-based distributed lock for preventing duplicate execution."""
|
||||
|
||||
def __init__(
|
||||
self, redis: "AsyncRedis", key: str, owner_id: str, timeout: int = 300
|
||||
):
|
||||
self.redis = redis
|
||||
self.key = key
|
||||
self.owner_id = owner_id
|
||||
self.timeout = timeout
|
||||
self._last_refresh = 0.0
|
||||
self._refresh_lock = asyncio.Lock()
|
||||
|
||||
async def try_acquire(self) -> str | None:
|
||||
"""Try to acquire the lock.
|
||||
|
||||
Returns:
|
||||
- owner_id (self.owner_id) if successfully acquired
|
||||
- different owner_id if someone else holds the lock
|
||||
- None if Redis is unavailable or other error
|
||||
"""
|
||||
try:
|
||||
success = await self.redis.set(
|
||||
self.key, self.owner_id, nx=True, ex=self.timeout
|
||||
)
|
||||
if success:
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = time.time()
|
||||
return self.owner_id # Successfully acquired
|
||||
|
||||
# Failed to acquire, get current owner
|
||||
current_value = await self.redis.get(self.key)
|
||||
if current_value:
|
||||
current_owner = (
|
||||
current_value.decode("utf-8")
|
||||
if isinstance(current_value, bytes)
|
||||
else str(current_value)
|
||||
)
|
||||
return current_owner
|
||||
|
||||
# Key doesn't exist but we failed to set it - race condition or Redis issue
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AsyncClusterLock.try_acquire failed for key {self.key}: {e}")
|
||||
return None
|
||||
|
||||
async def refresh(self) -> bool:
|
||||
"""Refresh lock TTL if we still own it.
|
||||
|
||||
Rate limited to at most once every timeout/10 seconds (minimum 1 second).
|
||||
During rate limiting, still verifies lock existence but skips TTL extension.
|
||||
Setting _last_refresh to 0 bypasses rate limiting for testing.
|
||||
|
||||
Async-safe: uses asyncio.Lock to protect _last_refresh access.
|
||||
"""
|
||||
# Calculate refresh interval: max(timeout // 10, 1)
|
||||
refresh_interval = max(self.timeout // 10, 1)
|
||||
current_time = time.time()
|
||||
|
||||
# Check if we're within the rate limit period (async-safe read)
|
||||
# _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
|
||||
async with self._refresh_lock:
|
||||
last_refresh = self._last_refresh
|
||||
is_rate_limited = (
|
||||
last_refresh > 0 and (current_time - last_refresh) < refresh_interval
|
||||
)
|
||||
|
||||
try:
|
||||
# Always verify lock existence, even during rate limiting
|
||||
current_value = await self.redis.get(self.key)
|
||||
if not current_value:
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
stored_owner = (
|
||||
current_value.decode("utf-8")
|
||||
if isinstance(current_value, bytes)
|
||||
else str(current_value)
|
||||
)
|
||||
if stored_owner != self.owner_id:
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
# If rate limited, return True but don't update TTL or timestamp
|
||||
if is_rate_limited:
|
||||
return True
|
||||
|
||||
# Perform actual refresh
|
||||
if await self.redis.expire(self.key, self.timeout):
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = current_time
|
||||
return True
|
||||
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AsyncClusterLock.refresh failed for key {self.key}: {e}")
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
async def release(self):
|
||||
"""Release the lock."""
|
||||
async with self._refresh_lock:
|
||||
if self._last_refresh == 0:
|
||||
return
|
||||
|
||||
try:
|
||||
await self.redis.delete(self.key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = 0.0
|
||||
|
||||
@@ -372,7 +372,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The port for the Agent Generator service",
|
||||
)
|
||||
agentgenerator_timeout: int = Field(
|
||||
default=600,
|
||||
default=1800,
|
||||
description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)",
|
||||
)
|
||||
agentgenerator_use_dummy: bool = Field(
|
||||
|
||||
@@ -109,7 +109,7 @@ class TestGenerateAgent:
|
||||
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||
result = await core.generate_agent(instructions)
|
||||
|
||||
mock_external.assert_called_once_with(instructions, None, None, None)
|
||||
mock_external.assert_called_once_with(instructions, None)
|
||||
assert result is not None
|
||||
assert result["name"] == "Test Agent"
|
||||
assert "id" in result
|
||||
@@ -173,9 +173,7 @@ class TestGenerateAgentPatch:
|
||||
current_agent = {"nodes": [], "links": []}
|
||||
result = await core.generate_agent_patch("Add a node", current_agent)
|
||||
|
||||
mock_external.assert_called_once_with(
|
||||
"Add a node", current_agent, None, None, None
|
||||
)
|
||||
mock_external.assert_called_once_with("Add a node", current_agent, None)
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -1,349 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration test for the requeue fix implementation.
|
||||
Tests actual RabbitMQ behavior to verify that republishing sends messages to back of queue.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from threading import Event
|
||||
from typing import List
|
||||
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.utils import create_execution_queue_config
|
||||
|
||||
|
||||
class QueueOrderTester:
|
||||
"""Helper class to test message ordering in RabbitMQ using a dedicated test queue."""
|
||||
|
||||
def __init__(self):
|
||||
self.received_messages: List[dict] = []
|
||||
self.stop_consuming = Event()
|
||||
self.queue_client = SyncRabbitMQ(create_execution_queue_config())
|
||||
self.queue_client.connect()
|
||||
|
||||
# Use a dedicated test queue name to avoid conflicts
|
||||
self.test_queue_name = "test_requeue_ordering"
|
||||
self.test_exchange = "test_exchange"
|
||||
self.test_routing_key = "test.requeue"
|
||||
|
||||
def setup_queue(self):
|
||||
"""Set up a dedicated test queue for testing."""
|
||||
channel = self.queue_client.get_channel()
|
||||
|
||||
# Declare test exchange
|
||||
channel.exchange_declare(
|
||||
exchange=self.test_exchange, exchange_type="direct", durable=True
|
||||
)
|
||||
|
||||
# Declare test queue
|
||||
channel.queue_declare(
|
||||
queue=self.test_queue_name, durable=True, auto_delete=False
|
||||
)
|
||||
|
||||
# Bind queue to exchange
|
||||
channel.queue_bind(
|
||||
exchange=self.test_exchange,
|
||||
queue=self.test_queue_name,
|
||||
routing_key=self.test_routing_key,
|
||||
)
|
||||
|
||||
# Purge the queue to start fresh
|
||||
channel.queue_purge(self.test_queue_name)
|
||||
print(f"✅ Test queue {self.test_queue_name} setup and purged")
|
||||
|
||||
def create_test_message(self, message_id: str, user_id: str = "test-user") -> str:
|
||||
"""Create a test graph execution message."""
|
||||
return json.dumps(
|
||||
{
|
||||
"graph_exec_id": f"exec-{message_id}",
|
||||
"graph_id": f"graph-{message_id}",
|
||||
"user_id": user_id,
|
||||
"execution_context": {"timezone": "UTC"},
|
||||
"nodes_input_masks": {},
|
||||
"starting_nodes_input": [],
|
||||
}
|
||||
)
|
||||
|
||||
def publish_message(self, message: str):
|
||||
"""Publish a message to the test queue."""
|
||||
channel = self.queue_client.get_channel()
|
||||
channel.basic_publish(
|
||||
exchange=self.test_exchange,
|
||||
routing_key=self.test_routing_key,
|
||||
body=message,
|
||||
)
|
||||
|
||||
def consume_messages(self, max_messages: int = 10, timeout: float = 5.0):
|
||||
"""Consume messages and track their order."""
|
||||
|
||||
def callback(ch, method, properties, body):
|
||||
try:
|
||||
message_data = json.loads(body.decode())
|
||||
self.received_messages.append(message_data)
|
||||
ch.basic_ack(delivery_tag=method.delivery_tag)
|
||||
|
||||
if len(self.received_messages) >= max_messages:
|
||||
self.stop_consuming.set()
|
||||
except Exception as e:
|
||||
print(f"Error processing message: {e}")
|
||||
ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
|
||||
|
||||
# Use synchronous consumption with blocking
|
||||
channel = self.queue_client.get_channel()
|
||||
|
||||
# Check if there are messages in the queue first
|
||||
method_frame, header_frame, body = channel.basic_get(
|
||||
queue=self.test_queue_name, auto_ack=False
|
||||
)
|
||||
if method_frame:
|
||||
# There are messages, set up consumer
|
||||
channel.basic_nack(
|
||||
delivery_tag=method_frame.delivery_tag, requeue=True
|
||||
) # Put message back
|
||||
|
||||
# Set up consumer
|
||||
channel.basic_consume(
|
||||
queue=self.test_queue_name,
|
||||
on_message_callback=callback,
|
||||
)
|
||||
|
||||
# Consume with timeout
|
||||
start_time = time.time()
|
||||
while (
|
||||
not self.stop_consuming.is_set()
|
||||
and (time.time() - start_time) < timeout
|
||||
and len(self.received_messages) < max_messages
|
||||
):
|
||||
try:
|
||||
channel.connection.process_data_events(time_limit=0.1)
|
||||
except Exception as e:
|
||||
print(f"Error during consumption: {e}")
|
||||
break
|
||||
|
||||
# Cancel the consumer
|
||||
try:
|
||||
channel.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# No messages in queue - this might be expected for some tests
|
||||
pass
|
||||
|
||||
return self.received_messages
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up test resources."""
|
||||
try:
|
||||
channel = self.queue_client.get_channel()
|
||||
channel.queue_delete(queue=self.test_queue_name)
|
||||
channel.exchange_delete(exchange=self.test_exchange)
|
||||
print(f"✅ Test queue {self.test_queue_name} cleaned up")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Cleanup issue: {e}")
|
||||
|
||||
|
||||
def test_queue_ordering_behavior():
|
||||
"""
|
||||
Integration test to verify that our republishing method sends messages to back of queue.
|
||||
This tests the actual fix for the rate limiting queue blocking issue.
|
||||
"""
|
||||
tester = QueueOrderTester()
|
||||
|
||||
try:
|
||||
tester.setup_queue()
|
||||
|
||||
print("🧪 Testing actual RabbitMQ queue ordering behavior...")
|
||||
|
||||
# Test 1: Normal FIFO behavior
|
||||
print("1. Testing normal FIFO queue behavior")
|
||||
|
||||
# Publish messages in order: A, B, C
|
||||
msg_a = tester.create_test_message("A")
|
||||
msg_b = tester.create_test_message("B")
|
||||
msg_c = tester.create_test_message("C")
|
||||
|
||||
tester.publish_message(msg_a)
|
||||
tester.publish_message(msg_b)
|
||||
tester.publish_message(msg_c)
|
||||
|
||||
# Consume and verify FIFO order: A, B, C
|
||||
tester.received_messages = []
|
||||
tester.stop_consuming.clear()
|
||||
messages = tester.consume_messages(max_messages=3)
|
||||
|
||||
assert len(messages) == 3, f"Expected 3 messages, got {len(messages)}"
|
||||
assert (
|
||||
messages[0]["graph_exec_id"] == "exec-A"
|
||||
), f"First message should be A, got {messages[0]['graph_exec_id']}"
|
||||
assert (
|
||||
messages[1]["graph_exec_id"] == "exec-B"
|
||||
), f"Second message should be B, got {messages[1]['graph_exec_id']}"
|
||||
assert (
|
||||
messages[2]["graph_exec_id"] == "exec-C"
|
||||
), f"Third message should be C, got {messages[2]['graph_exec_id']}"
|
||||
|
||||
print("✅ FIFO order confirmed: A -> B -> C")
|
||||
|
||||
# Test 2: Rate limiting simulation - the key test!
|
||||
print("2. Testing rate limiting fix scenario")
|
||||
|
||||
# Simulate the scenario where user1 is rate limited
|
||||
user1_msg = tester.create_test_message("RATE-LIMITED", "user1")
|
||||
user2_msg1 = tester.create_test_message("USER2-1", "user2")
|
||||
user2_msg2 = tester.create_test_message("USER2-2", "user2")
|
||||
|
||||
# Initially publish user1 message (gets consumed, then rate limited on retry)
|
||||
tester.publish_message(user1_msg)
|
||||
|
||||
# Other users publish their messages
|
||||
tester.publish_message(user2_msg1)
|
||||
tester.publish_message(user2_msg2)
|
||||
|
||||
# Now simulate: user1 message gets "requeued" using our new republishing method
|
||||
# This is what happens in manager.py when requeue_by_republishing=True
|
||||
tester.publish_message(user1_msg) # Goes to back via our method
|
||||
|
||||
# Expected order: RATE-LIMITED, USER2-1, USER2-2, RATE-LIMITED (republished to back)
|
||||
# This shows that user2 messages get processed instead of being blocked
|
||||
tester.received_messages = []
|
||||
tester.stop_consuming.clear()
|
||||
messages = tester.consume_messages(max_messages=4)
|
||||
|
||||
assert len(messages) == 4, f"Expected 4 messages, got {len(messages)}"
|
||||
|
||||
# The key verification: user2 messages are NOT blocked by user1's rate-limited message
|
||||
user2_messages = [msg for msg in messages if msg["user_id"] == "user2"]
|
||||
assert len(user2_messages) == 2, "Both user2 messages should be processed"
|
||||
assert user2_messages[0]["graph_exec_id"] == "exec-USER2-1"
|
||||
assert user2_messages[1]["graph_exec_id"] == "exec-USER2-2"
|
||||
|
||||
print("✅ Rate limiting fix confirmed: user2 executions NOT blocked by user1")
|
||||
|
||||
# Test 3: Verify our method behaves like going to back of queue
|
||||
print("3. Testing republishing sends messages to back")
|
||||
|
||||
# Start with message X in queue
|
||||
msg_x = tester.create_test_message("X")
|
||||
tester.publish_message(msg_x)
|
||||
|
||||
# Add message Y
|
||||
msg_y = tester.create_test_message("Y")
|
||||
tester.publish_message(msg_y)
|
||||
|
||||
# Republish X (simulates requeue using our method)
|
||||
tester.publish_message(msg_x)
|
||||
|
||||
# Expected: X, Y, X (X was republished to back)
|
||||
tester.received_messages = []
|
||||
tester.stop_consuming.clear()
|
||||
messages = tester.consume_messages(max_messages=3)
|
||||
|
||||
assert len(messages) == 3
|
||||
# Y should come before the republished X
|
||||
y_index = next(
|
||||
i for i, msg in enumerate(messages) if msg["graph_exec_id"] == "exec-Y"
|
||||
)
|
||||
republished_x_index = next(
|
||||
i
|
||||
for i, msg in enumerate(messages[1:], 1)
|
||||
if msg["graph_exec_id"] == "exec-X"
|
||||
)
|
||||
|
||||
assert (
|
||||
y_index < republished_x_index
|
||||
), f"Y should come before republished X, but got order: {[m['graph_exec_id'] for m in messages]}"
|
||||
|
||||
print("✅ Republishing confirmed: messages go to back of queue")
|
||||
|
||||
print("🎉 All integration tests passed!")
|
||||
print("🎉 Our republishing method works correctly with real RabbitMQ")
|
||||
print("🎉 Queue blocking issue is fixed!")
|
||||
|
||||
finally:
|
||||
tester.cleanup()
|
||||
|
||||
|
||||
def test_traditional_requeue_behavior():
|
||||
"""
|
||||
Test that traditional requeue (basic_nack with requeue=True) sends messages to FRONT of queue.
|
||||
This validates our hypothesis about why queue blocking occurs.
|
||||
"""
|
||||
tester = QueueOrderTester()
|
||||
|
||||
try:
|
||||
tester.setup_queue()
|
||||
print("🧪 Testing traditional requeue behavior (basic_nack with requeue=True)")
|
||||
|
||||
# Step 1: Publish message A
|
||||
msg_a = tester.create_test_message("A")
|
||||
tester.publish_message(msg_a)
|
||||
|
||||
# Step 2: Publish message B
|
||||
msg_b = tester.create_test_message("B")
|
||||
tester.publish_message(msg_b)
|
||||
|
||||
# Step 3: Consume message A and requeue it using traditional method
|
||||
channel = tester.queue_client.get_channel()
|
||||
method_frame, header_frame, body = channel.basic_get(
|
||||
queue=tester.test_queue_name, auto_ack=False
|
||||
)
|
||||
|
||||
assert method_frame is not None, "Should have received message A"
|
||||
consumed_msg = json.loads(body.decode())
|
||||
assert (
|
||||
consumed_msg["graph_exec_id"] == "exec-A"
|
||||
), f"Should have consumed message A, got {consumed_msg['graph_exec_id']}"
|
||||
|
||||
# Traditional requeue: basic_nack with requeue=True (sends to FRONT)
|
||||
channel.basic_nack(delivery_tag=method_frame.delivery_tag, requeue=True)
|
||||
print(f"🔄 Traditional requeue (to FRONT): {consumed_msg['graph_exec_id']}")
|
||||
|
||||
# Step 4: Consume all messages using basic_get for reliability
|
||||
received_messages = []
|
||||
|
||||
# Get first message
|
||||
method_frame, header_frame, body = channel.basic_get(
|
||||
queue=tester.test_queue_name, auto_ack=True
|
||||
)
|
||||
if method_frame:
|
||||
msg = json.loads(body.decode())
|
||||
received_messages.append(msg)
|
||||
|
||||
# Get second message
|
||||
method_frame, header_frame, body = channel.basic_get(
|
||||
queue=tester.test_queue_name, auto_ack=True
|
||||
)
|
||||
if method_frame:
|
||||
msg = json.loads(body.decode())
|
||||
received_messages.append(msg)
|
||||
|
||||
# CRITICAL ASSERTION: Traditional requeue should put A at FRONT
|
||||
# Expected order: A (requeued to front), B
|
||||
assert (
|
||||
len(received_messages) == 2
|
||||
), f"Expected 2 messages, got {len(received_messages)}"
|
||||
|
||||
first_msg = received_messages[0]["graph_exec_id"]
|
||||
second_msg = received_messages[1]["graph_exec_id"]
|
||||
|
||||
# This is the critical test: requeued message A should come BEFORE B
|
||||
assert (
|
||||
first_msg == "exec-A"
|
||||
), f"Traditional requeue should put A at FRONT, but first message was: {first_msg}"
|
||||
assert (
|
||||
second_msg == "exec-B"
|
||||
), f"B should come after requeued A, but second message was: {second_msg}"
|
||||
|
||||
print(
|
||||
"✅ HYPOTHESIS CONFIRMED: Traditional requeue sends messages to FRONT of queue"
|
||||
)
|
||||
print(f" Order: {first_msg} (requeued to front) → {second_msg}")
|
||||
print(" This explains why rate-limited messages block other users!")
|
||||
|
||||
finally:
|
||||
tester.cleanup()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_queue_ordering_behavior()
|
||||
@@ -27,6 +27,7 @@ export function CopilotPage() {
|
||||
createSession,
|
||||
onSend,
|
||||
isLoadingSession,
|
||||
isSessionError,
|
||||
isCreatingSession,
|
||||
isUserLoading,
|
||||
isLoggedIn,
|
||||
@@ -71,6 +72,7 @@ export function CopilotPage() {
|
||||
error={error}
|
||||
sessionId={sessionId}
|
||||
isLoadingSession={isLoadingSession}
|
||||
isSessionError={isSessionError}
|
||||
isCreatingSession={isCreatingSession}
|
||||
isReconnecting={isReconnecting}
|
||||
onCreateSession={createSession}
|
||||
|
||||
@@ -13,6 +13,7 @@ export interface ChatContainerProps {
|
||||
error: Error | undefined;
|
||||
sessionId: string | null;
|
||||
isLoadingSession: boolean;
|
||||
isSessionError?: boolean;
|
||||
isCreatingSession: boolean;
|
||||
/** True when backend has an active stream but we haven't reconnected yet. */
|
||||
isReconnecting?: boolean;
|
||||
@@ -27,6 +28,7 @@ export const ChatContainer = ({
|
||||
error,
|
||||
sessionId,
|
||||
isLoadingSession,
|
||||
isSessionError,
|
||||
isCreatingSession,
|
||||
isReconnecting,
|
||||
onCreateSession,
|
||||
@@ -34,7 +36,12 @@ export const ChatContainer = ({
|
||||
onStop,
|
||||
headerSlot,
|
||||
}: ChatContainerProps) => {
|
||||
const isBusy = status === "streaming" || !!isReconnecting;
|
||||
const isBusy =
|
||||
status === "streaming" ||
|
||||
status === "submitted" ||
|
||||
!!isReconnecting ||
|
||||
isLoadingSession ||
|
||||
!!isSessionError;
|
||||
const inputLayoutId = "copilot-2-chat-input";
|
||||
|
||||
return (
|
||||
|
||||
@@ -10,9 +10,8 @@ import {
|
||||
MessageResponse,
|
||||
} from "@/components/ai-elements/message";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { CreateAgentTool } from "../../tools/CreateAgent/CreateAgent";
|
||||
import { EditAgentTool } from "../../tools/EditAgent/EditAgent";
|
||||
import {
|
||||
@@ -129,7 +128,6 @@ export const ChatMessagesContainer = ({
|
||||
headerSlot,
|
||||
}: ChatMessagesContainerProps) => {
|
||||
const [thinkingPhrase, setThinkingPhrase] = useState(getRandomPhrase);
|
||||
const lastToastTimeRef = useRef(0);
|
||||
|
||||
useEffect(() => {
|
||||
if (status === "submitted") {
|
||||
@@ -137,20 +135,6 @@ export const ChatMessagesContainer = ({
|
||||
}
|
||||
}, [status]);
|
||||
|
||||
// Show a toast when a new error occurs, debounced to avoid spam
|
||||
useEffect(() => {
|
||||
if (!error) return;
|
||||
const now = Date.now();
|
||||
if (now - lastToastTimeRef.current < 3_000) return;
|
||||
lastToastTimeRef.current = now;
|
||||
toast({
|
||||
variant: "destructive",
|
||||
title: "Something went wrong",
|
||||
description:
|
||||
"The assistant encountered an error. Please try sending your message again.",
|
||||
});
|
||||
}, [error]);
|
||||
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
const lastAssistantHasVisibleContent =
|
||||
lastMessage?.role === "assistant" &&
|
||||
@@ -314,13 +298,15 @@ export const ChatMessagesContainer = ({
|
||||
</Message>
|
||||
)}
|
||||
{error && (
|
||||
<div className="rounded-lg bg-red-50 p-4 text-sm text-red-700">
|
||||
<p className="font-medium">Something went wrong</p>
|
||||
<p className="mt-1 text-red-600">
|
||||
<details className="rounded-lg bg-red-50 p-4 text-sm text-red-700">
|
||||
<summary className="cursor-pointer font-medium">
|
||||
The assistant encountered an error. Please try sending your
|
||||
message again.
|
||||
</p>
|
||||
</div>
|
||||
</summary>
|
||||
<pre className="mt-2 max-h-40 overflow-auto whitespace-pre-wrap break-words text-xs text-red-600">
|
||||
{error instanceof Error ? error.message : String(error)}
|
||||
</pre>
|
||||
</details>
|
||||
)}
|
||||
</ConversationContent>
|
||||
<ConversationScrollButton />
|
||||
|
||||
|
Before Width: | Height: | Size: 8.0 KiB After Width: | Height: | Size: 8.0 KiB |
@@ -116,12 +116,10 @@ export function convertChatSessionMessagesToUiMessages(
|
||||
output: "",
|
||||
});
|
||||
} else {
|
||||
parts.push({
|
||||
type: `tool-${toolName}`,
|
||||
toolCallId,
|
||||
state: "input-available",
|
||||
input,
|
||||
});
|
||||
// Active stream exists: Skip incomplete tool calls during hydration.
|
||||
// The resume stream will deliver them fresh with proper SDK state.
|
||||
// This prevents "No tool invocation found" errors on page refresh.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
|
||||
/**
|
||||
* Hook that returns a progress value that starts fast and slows down,
|
||||
* asymptotically approaching but never reaching the max value.
|
||||
*
|
||||
* Uses a half-life formula: progress = max * (1 - 0.5^(time/halfLife))
|
||||
* This creates a "loading bar" effect where:
|
||||
* - 50% is reached at halfLifeSeconds
|
||||
* - 75% is reached at 2 * halfLifeSeconds
|
||||
* - 87.5% is reached at 3 * halfLifeSeconds
|
||||
*
|
||||
* @param isActive - Whether the progress should be animating
|
||||
* @param halfLifeSeconds - Time in seconds to reach 50% progress (default: 30)
|
||||
* @param maxProgress - Maximum progress value to approach (default: 100)
|
||||
* @param intervalMs - Update interval in milliseconds (default: 100)
|
||||
* @returns Current progress value (0–maxProgress)
|
||||
*/
|
||||
export function useAsymptoticProgress(
|
||||
isActive: boolean,
|
||||
halfLifeSeconds = 30,
|
||||
maxProgress = 100,
|
||||
intervalMs = 100,
|
||||
) {
|
||||
const [progress, setProgress] = useState(0);
|
||||
const elapsedTimeRef = useRef(0);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isActive) {
|
||||
setProgress(0);
|
||||
elapsedTimeRef.current = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
const interval = setInterval(() => {
|
||||
elapsedTimeRef.current += intervalMs / 1000;
|
||||
const newProgress =
|
||||
maxProgress *
|
||||
(1 - Math.pow(0.5, elapsedTimeRef.current / halfLifeSeconds));
|
||||
setProgress(newProgress);
|
||||
}, intervalMs);
|
||||
|
||||
return () => clearInterval(interval);
|
||||
}, [isActive, halfLifeSeconds, maxProgress, intervalMs]);
|
||||
|
||||
return progress;
|
||||
}
|
||||
@@ -1,126 +0,0 @@
|
||||
import { getGetV2GetSessionQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import type { UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { useCallback, useEffect, useRef } from "react";
|
||||
import { convertChatSessionMessagesToUiMessages } from "../helpers/convertChatSessionToUiMessages";
|
||||
|
||||
const OPERATING_TYPES = new Set([
|
||||
"operation_started",
|
||||
"operation_pending",
|
||||
"operation_in_progress",
|
||||
]);
|
||||
|
||||
const POLL_INTERVAL_MS = 1_500;
|
||||
|
||||
/**
|
||||
* Detects whether any message contains a tool part whose output indicates
|
||||
* a long-running operation is still in progress.
|
||||
*/
|
||||
function hasOperatingTool(
|
||||
messages: UIMessage<unknown, UIDataTypes, UITools>[],
|
||||
) {
|
||||
for (const msg of messages) {
|
||||
for (const part of msg.parts) {
|
||||
if (!part.type.startsWith("tool-")) continue;
|
||||
const toolPart = part as { output?: unknown };
|
||||
if (!toolPart.output) continue;
|
||||
const output =
|
||||
typeof toolPart.output === "string"
|
||||
? safeParse(toolPart.output)
|
||||
: toolPart.output;
|
||||
if (
|
||||
output &&
|
||||
typeof output === "object" &&
|
||||
"type" in output &&
|
||||
OPERATING_TYPES.has((output as { type: string }).type)
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function safeParse(value: string): unknown {
|
||||
try {
|
||||
return JSON.parse(value);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Polls the session endpoint while any tool is in an "operating" state
|
||||
* (operation_started / operation_pending / operation_in_progress).
|
||||
*
|
||||
* When the session data shows the tool output has changed (e.g. to
|
||||
* agent_saved), it calls `setMessages` with the updated messages.
|
||||
*/
|
||||
export function useLongRunningToolPolling(
|
||||
sessionId: string | null,
|
||||
messages: UIMessage<unknown, UIDataTypes, UITools>[],
|
||||
setMessages: (
|
||||
updater: (
|
||||
prev: UIMessage<unknown, UIDataTypes, UITools>[],
|
||||
) => UIMessage<unknown, UIDataTypes, UITools>[],
|
||||
) => void,
|
||||
) {
|
||||
const queryClient = useQueryClient();
|
||||
const intervalRef = useRef<ReturnType<typeof setInterval> | null>(null);
|
||||
|
||||
const stopPolling = useCallback(() => {
|
||||
if (intervalRef.current) {
|
||||
clearInterval(intervalRef.current);
|
||||
intervalRef.current = null;
|
||||
}
|
||||
}, []);
|
||||
|
||||
const poll = useCallback(async () => {
|
||||
if (!sessionId) return;
|
||||
|
||||
// Invalidate the query cache so the next fetch gets fresh data
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||
});
|
||||
|
||||
// Fetch fresh session data
|
||||
const data = queryClient.getQueryData<{
|
||||
status: number;
|
||||
data: { messages?: unknown[] };
|
||||
}>(getGetV2GetSessionQueryKey(sessionId));
|
||||
|
||||
if (data?.status !== 200 || !data.data.messages) return;
|
||||
|
||||
const freshMessages = convertChatSessionMessagesToUiMessages(
|
||||
sessionId,
|
||||
data.data.messages,
|
||||
);
|
||||
|
||||
if (!freshMessages || freshMessages.length === 0) return;
|
||||
|
||||
// Update when the long-running tool completed
|
||||
if (!hasOperatingTool(freshMessages)) {
|
||||
setMessages(() => freshMessages);
|
||||
stopPolling();
|
||||
}
|
||||
}, [sessionId, queryClient, setMessages, stopPolling]);
|
||||
|
||||
useEffect(() => {
|
||||
const shouldPoll = hasOperatingTool(messages);
|
||||
|
||||
// Always clear any previous interval first so we never leak timers
|
||||
// when the effect re-runs due to dependency changes (e.g. messages
|
||||
// updating as the LLM streams text after the tool call).
|
||||
stopPolling();
|
||||
|
||||
if (shouldPoll && sessionId) {
|
||||
intervalRef.current = setInterval(() => {
|
||||
poll();
|
||||
}, POLL_INTERVAL_MS);
|
||||
}
|
||||
|
||||
return () => {
|
||||
stopPolling();
|
||||
};
|
||||
}, [messages, sessionId, poll, stopPolling]);
|
||||
}
|
||||
@@ -11,6 +11,11 @@ import {
|
||||
MessageResponse,
|
||||
} from "@/components/ai-elements/message";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import {
|
||||
CredentialsProvidersContext,
|
||||
type CredentialsProviderData,
|
||||
type CredentialsProvidersContextType,
|
||||
} from "@/providers/agent-credentials/credentials-provider";
|
||||
import { CopilotChatActionsProvider } from "../components/CopilotChatActionsProvider/CopilotChatActionsProvider";
|
||||
import { CreateAgentTool } from "../tools/CreateAgent/CreateAgent";
|
||||
import { EditAgentTool } from "../tools/EditAgent/EditAgent";
|
||||
@@ -97,6 +102,65 @@ function uid() {
|
||||
return `sg-${++_id}`;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock credential providers for setup-requirements demos
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const noop = () => Promise.reject(new Error("Styleguide mock"));
|
||||
|
||||
function makeMockProvider(
|
||||
provider: string,
|
||||
providerName: string,
|
||||
savedCredentials: CredentialsProviderData["savedCredentials"] = [],
|
||||
): CredentialsProviderData {
|
||||
return {
|
||||
provider,
|
||||
providerName,
|
||||
savedCredentials,
|
||||
isSystemProvider: false,
|
||||
oAuthCallback: noop as CredentialsProviderData["oAuthCallback"],
|
||||
mcpOAuthCallback: noop as CredentialsProviderData["mcpOAuthCallback"],
|
||||
createAPIKeyCredentials:
|
||||
noop as CredentialsProviderData["createAPIKeyCredentials"],
|
||||
createUserPasswordCredentials:
|
||||
noop as CredentialsProviderData["createUserPasswordCredentials"],
|
||||
createHostScopedCredentials:
|
||||
noop as CredentialsProviderData["createHostScopedCredentials"],
|
||||
deleteCredentials: noop as CredentialsProviderData["deleteCredentials"],
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider context where the user already has saved credentials
|
||||
* so the credential picker shows a selection list.
|
||||
*/
|
||||
const MOCK_PROVIDERS_WITH_CREDENTIALS: CredentialsProvidersContextType = {
|
||||
google: makeMockProvider("google", "Google", [
|
||||
{
|
||||
id: "cred-google-1",
|
||||
provider: "google",
|
||||
type: "oauth2",
|
||||
title: "work@company.com",
|
||||
scopes: ["email", "calendar"],
|
||||
},
|
||||
{
|
||||
id: "cred-google-2",
|
||||
provider: "google",
|
||||
type: "oauth2",
|
||||
title: "personal@gmail.com",
|
||||
scopes: ["email", "calendar"],
|
||||
},
|
||||
]),
|
||||
};
|
||||
|
||||
/**
|
||||
* Provider context where the user has NO saved credentials,
|
||||
* so the credential picker shows an "add new" flow.
|
||||
*/
|
||||
const MOCK_PROVIDERS_WITHOUT_CREDENTIALS: CredentialsProvidersContextType = {
|
||||
openweathermap: makeMockProvider("openweathermap", "OpenWeatherMap"),
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Page
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -554,45 +618,80 @@ export default function StyleguidePage() {
|
||||
/>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Output available (setup requirements)">
|
||||
<RunBlockTool
|
||||
part={{
|
||||
type: "tool-run_block",
|
||||
toolCallId: uid(),
|
||||
state: "output-available",
|
||||
input: { block_id: "weather-block-123" },
|
||||
output: {
|
||||
type: ResponseType.setup_requirements,
|
||||
message:
|
||||
"This block requires API credentials to run. Please configure them below.",
|
||||
setup_info: {
|
||||
agent_name: "Weather Agent",
|
||||
requirements: {
|
||||
inputs: [
|
||||
{
|
||||
name: "city",
|
||||
title: "City",
|
||||
type: "string",
|
||||
required: true,
|
||||
description: "The city to get weather for",
|
||||
},
|
||||
],
|
||||
},
|
||||
user_readiness: {
|
||||
missing_credentials: {
|
||||
openweathermap: {
|
||||
provider: "openweathermap",
|
||||
credentials_type: "api_key",
|
||||
title: "OpenWeatherMap API Key",
|
||||
description:
|
||||
"Required to access weather data. Get your key at openweathermap.org",
|
||||
<SubSection label="Setup requirements — no credentials (add new)">
|
||||
<CredentialsProvidersContext.Provider
|
||||
value={MOCK_PROVIDERS_WITHOUT_CREDENTIALS}
|
||||
>
|
||||
<RunBlockTool
|
||||
part={{
|
||||
type: "tool-run_block",
|
||||
toolCallId: uid(),
|
||||
state: "output-available",
|
||||
input: { block_id: "weather-block-123" },
|
||||
output: {
|
||||
type: ResponseType.setup_requirements,
|
||||
message:
|
||||
"This block requires API credentials to run. Please configure them below.",
|
||||
setup_info: {
|
||||
agent_id: "agent-weather-1",
|
||||
agent_name: "Weather Agent",
|
||||
requirements: {
|
||||
inputs: [
|
||||
{
|
||||
name: "city",
|
||||
title: "City",
|
||||
type: "string",
|
||||
required: true,
|
||||
description: "The city to get weather for",
|
||||
},
|
||||
],
|
||||
},
|
||||
user_readiness: {
|
||||
missing_credentials: {
|
||||
openweathermap_key: {
|
||||
provider: "openweathermap",
|
||||
types: ["api_key"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
/>
|
||||
}}
|
||||
/>
|
||||
</CredentialsProvidersContext.Provider>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Setup requirements — has credentials (pick from list)">
|
||||
<CredentialsProvidersContext.Provider
|
||||
value={MOCK_PROVIDERS_WITH_CREDENTIALS}
|
||||
>
|
||||
<RunBlockTool
|
||||
part={{
|
||||
type: "tool-run_block",
|
||||
toolCallId: uid(),
|
||||
state: "output-available",
|
||||
input: { block_id: "calendar-block-456" },
|
||||
output: {
|
||||
type: ResponseType.setup_requirements,
|
||||
message:
|
||||
"This block requires Google credentials. Pick an account below or connect a new one.",
|
||||
setup_info: {
|
||||
agent_id: "agent-calendar-1",
|
||||
agent_name: "Calendar Agent",
|
||||
user_readiness: {
|
||||
missing_credentials: {
|
||||
google_oauth: {
|
||||
provider: "google",
|
||||
types: ["oauth2"],
|
||||
scopes: ["email", "calendar"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
/>
|
||||
</CredentialsProvidersContext.Provider>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Output available (error)">
|
||||
@@ -849,34 +948,71 @@ export default function StyleguidePage() {
|
||||
/>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Output available (setup requirements)">
|
||||
<RunAgentTool
|
||||
part={{
|
||||
type: "tool-run_agent",
|
||||
toolCallId: uid(),
|
||||
state: "output-available",
|
||||
input: { username_agent_slug: "creator/my-agent" },
|
||||
output: {
|
||||
type: ResponseType.setup_requirements,
|
||||
message: "This agent requires additional setup.",
|
||||
setup_info: {
|
||||
agent_name: "YouTube Summarizer",
|
||||
requirements: {},
|
||||
user_readiness: {
|
||||
missing_credentials: {
|
||||
youtube_api: {
|
||||
provider: "youtube",
|
||||
credentials_type: "api_key",
|
||||
title: "YouTube Data API Key",
|
||||
description:
|
||||
"Required to access YouTube video data.",
|
||||
<SubSection label="Setup requirements — no credentials (add new)">
|
||||
<CredentialsProvidersContext.Provider
|
||||
value={MOCK_PROVIDERS_WITHOUT_CREDENTIALS}
|
||||
>
|
||||
<RunAgentTool
|
||||
part={{
|
||||
type: "tool-run_agent",
|
||||
toolCallId: uid(),
|
||||
state: "output-available",
|
||||
input: { username_agent_slug: "creator/weather-agent" },
|
||||
output: {
|
||||
type: ResponseType.setup_requirements,
|
||||
message:
|
||||
"This agent requires an API key. Add your credentials below.",
|
||||
setup_info: {
|
||||
agent_id: "agent-weather-1",
|
||||
agent_name: "Weather Agent",
|
||||
requirements: {},
|
||||
user_readiness: {
|
||||
missing_credentials: {
|
||||
openweathermap_key: {
|
||||
provider: "openweathermap",
|
||||
types: ["api_key"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
/>
|
||||
}}
|
||||
/>
|
||||
</CredentialsProvidersContext.Provider>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Setup requirements — has credentials (pick from list)">
|
||||
<CredentialsProvidersContext.Provider
|
||||
value={MOCK_PROVIDERS_WITH_CREDENTIALS}
|
||||
>
|
||||
<RunAgentTool
|
||||
part={{
|
||||
type: "tool-run_agent",
|
||||
toolCallId: uid(),
|
||||
state: "output-available",
|
||||
input: { username_agent_slug: "creator/calendar-agent" },
|
||||
output: {
|
||||
type: ResponseType.setup_requirements,
|
||||
message:
|
||||
"This agent needs Google credentials. Pick an account or connect a new one.",
|
||||
setup_info: {
|
||||
agent_id: "agent-calendar-1",
|
||||
agent_name: "Google Calendar Agent",
|
||||
requirements: {},
|
||||
user_readiness: {
|
||||
missing_credentials: {
|
||||
google_oauth: {
|
||||
provider: "google",
|
||||
types: ["oauth2"],
|
||||
scopes: ["email", "calendar"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
/>
|
||||
</CredentialsProvidersContext.Provider>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Output available (need login)">
|
||||
@@ -984,56 +1120,6 @@ export default function StyleguidePage() {
|
||||
/>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Output available (operation started)">
|
||||
<CreateAgentTool
|
||||
part={{
|
||||
type: "tool-create_agent",
|
||||
toolCallId: uid(),
|
||||
state: "output-available",
|
||||
output: {
|
||||
type: ResponseType.operation_started,
|
||||
operation_id: "op-create-123",
|
||||
tool_name: "create_agent",
|
||||
message:
|
||||
"Agent creation has been started. This may take a moment.",
|
||||
},
|
||||
}}
|
||||
/>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Output available (operation pending)">
|
||||
<CreateAgentTool
|
||||
part={{
|
||||
type: "tool-create_agent",
|
||||
toolCallId: uid(),
|
||||
state: "output-available",
|
||||
output: {
|
||||
type: ResponseType.operation_pending,
|
||||
operation_id: "op-create-123",
|
||||
tool_name: "create_agent",
|
||||
message:
|
||||
"Agent creation is queued and will begin shortly.",
|
||||
},
|
||||
}}
|
||||
/>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Output available (operation in progress)">
|
||||
<CreateAgentTool
|
||||
part={{
|
||||
type: "tool-create_agent",
|
||||
toolCallId: uid(),
|
||||
state: "output-available",
|
||||
output: {
|
||||
type: ResponseType.operation_in_progress,
|
||||
tool_call_id: "tc-456",
|
||||
message:
|
||||
"An agent creation operation is already in progress. Please wait for it to finish.",
|
||||
},
|
||||
}}
|
||||
/>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Output available (agent preview)">
|
||||
<CreateAgentTool
|
||||
part={{
|
||||
@@ -1156,22 +1242,6 @@ export default function StyleguidePage() {
|
||||
/>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Output available (operation started)">
|
||||
<EditAgentTool
|
||||
part={{
|
||||
type: "tool-edit_agent",
|
||||
toolCallId: uid(),
|
||||
state: "output-available",
|
||||
output: {
|
||||
type: ResponseType.operation_started,
|
||||
operation_id: "op-edit-456",
|
||||
tool_name: "edit_agent",
|
||||
message: "Agent editing has started.",
|
||||
},
|
||||
}}
|
||||
/>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Output available (agent preview)">
|
||||
<EditAgentTool
|
||||
part={{
|
||||
|
||||
@@ -24,8 +24,8 @@ import {
|
||||
ClarificationQuestionsCard,
|
||||
ClarifyingQuestion,
|
||||
} from "./components/ClarificationQuestionsCard";
|
||||
import sparklesImg from "./components/MiniGame/assets/sparkles.png";
|
||||
import { MiniGame } from "./components/MiniGame/MiniGame";
|
||||
import sparklesImg from "../../components/MiniGame/assets/sparkles.png";
|
||||
import { MiniGame } from "../../components/MiniGame/MiniGame";
|
||||
import { SuggestedGoalCard } from "./components/SuggestedGoalCard";
|
||||
import {
|
||||
AccordionIcon,
|
||||
@@ -36,9 +36,6 @@ import {
|
||||
isAgentSavedOutput,
|
||||
isClarificationNeededOutput,
|
||||
isErrorOutput,
|
||||
isOperationInProgressOutput,
|
||||
isOperationPendingOutput,
|
||||
isOperationStartedOutput,
|
||||
isSuggestedGoalOutput,
|
||||
ToolIcon,
|
||||
truncateText,
|
||||
@@ -57,9 +54,18 @@ interface Props {
|
||||
part: CreateAgentToolPart;
|
||||
}
|
||||
|
||||
function getAccordionMeta(output: CreateAgentToolOutput) {
|
||||
function getAccordionMeta(output: CreateAgentToolOutput | null) {
|
||||
const icon = <AccordionIcon />;
|
||||
|
||||
if (!output) {
|
||||
return {
|
||||
icon,
|
||||
title:
|
||||
"Creating agent, this may take a few minutes. Play while you wait.",
|
||||
expanded: true,
|
||||
};
|
||||
}
|
||||
|
||||
if (isAgentSavedOutput(output)) {
|
||||
return { icon, title: output.agent_name, expanded: true };
|
||||
}
|
||||
@@ -86,18 +92,6 @@ function getAccordionMeta(output: CreateAgentToolOutput) {
|
||||
expanded: true,
|
||||
};
|
||||
}
|
||||
if (
|
||||
isOperationStartedOutput(output) ||
|
||||
isOperationPendingOutput(output) ||
|
||||
isOperationInProgressOutput(output)
|
||||
) {
|
||||
return {
|
||||
icon,
|
||||
title:
|
||||
"Creating agent, this may take a few minutes. Play while you wait.",
|
||||
expanded: true,
|
||||
};
|
||||
}
|
||||
return {
|
||||
icon: (
|
||||
<WarningDiamondIcon size={32} weight="light" className="text-red-500" />
|
||||
@@ -119,23 +113,11 @@ export function CreateAgentTool({ part }: Props) {
|
||||
const isError =
|
||||
part.state === "output-error" || (!!output && isErrorOutput(output));
|
||||
|
||||
const isOperating =
|
||||
!!output &&
|
||||
(isOperationStartedOutput(output) ||
|
||||
isOperationPendingOutput(output) ||
|
||||
isOperationInProgressOutput(output));
|
||||
const isOperating = !output;
|
||||
|
||||
const hasExpandableContent =
|
||||
part.state === "output-available" &&
|
||||
!!output &&
|
||||
(isOperationStartedOutput(output) ||
|
||||
isOperationPendingOutput(output) ||
|
||||
isOperationInProgressOutput(output) ||
|
||||
isAgentPreviewOutput(output) ||
|
||||
isAgentSavedOutput(output) ||
|
||||
isClarificationNeededOutput(output) ||
|
||||
isSuggestedGoalOutput(output) ||
|
||||
isErrorOutput(output));
|
||||
// Show accordion for operating state and successful outputs, but not for errors
|
||||
// (errors are shown inline so they get replaced when retrying)
|
||||
const hasExpandableContent = !isError;
|
||||
|
||||
function handleUseSuggestedGoal(goal: string) {
|
||||
onSend(`Please create an agent with this goal: ${goal}`);
|
||||
@@ -161,15 +143,66 @@ export function CreateAgentTool({ part }: Props) {
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<ToolIcon isStreaming={isStreaming} isError={isError} />
|
||||
<MorphingTextAnimation
|
||||
text={text}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
{isOperating && (
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<ToolIcon isStreaming={isStreaming} isError={isError} />
|
||||
<MorphingTextAnimation
|
||||
text={text}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{hasExpandableContent && output && (
|
||||
{isError && output && isErrorOutput(output) && (
|
||||
<div className="space-y-3 rounded-lg border border-red-200 bg-red-50 p-4">
|
||||
<div className="flex items-start gap-2">
|
||||
<WarningDiamondIcon
|
||||
size={20}
|
||||
weight="regular"
|
||||
className="mt-0.5 shrink-0 text-red-500"
|
||||
/>
|
||||
<div className="flex-1 space-y-2">
|
||||
<Text variant="body-medium" className="text-red-900">
|
||||
{output.message ||
|
||||
"Failed to generate the agent. Please try again."}
|
||||
</Text>
|
||||
{output.error && (
|
||||
<details className="text-xs text-red-700">
|
||||
<summary className="cursor-pointer font-medium">
|
||||
Technical details
|
||||
</summary>
|
||||
<pre className="mt-2 max-h-40 overflow-auto whitespace-pre-wrap break-words rounded bg-red-100 p-2">
|
||||
{formatMaybeJson(output.error)}
|
||||
</pre>
|
||||
</details>
|
||||
)}
|
||||
{output.details && (
|
||||
<pre className="max-h-40 overflow-auto whitespace-pre-wrap break-words rounded bg-red-100 p-2 text-xs text-red-700">
|
||||
{formatMaybeJson(output.details)}
|
||||
</pre>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={() => onSend("Please try creating the agent again.")}
|
||||
>
|
||||
Try again
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={() => onSend("Can you help me simplify this goal?")}
|
||||
>
|
||||
Simplify goal
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{hasExpandableContent && (
|
||||
<ToolAccordion {...getAccordionMeta(output)}>
|
||||
{isOperating && (
|
||||
<ContentGrid>
|
||||
@@ -180,7 +213,7 @@ export function CreateAgentTool({ part }: Props) {
|
||||
</ContentGrid>
|
||||
)}
|
||||
|
||||
{isAgentSavedOutput(output) && (
|
||||
{output && isAgentSavedOutput(output) && (
|
||||
<div className="rounded-xl border border-border/60 bg-card p-4 shadow-sm">
|
||||
<div className="flex items-baseline gap-2">
|
||||
<Image
|
||||
@@ -226,7 +259,7 @@ export function CreateAgentTool({ part }: Props) {
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isAgentPreviewOutput(output) && (
|
||||
{output && isAgentPreviewOutput(output) && (
|
||||
<ContentGrid>
|
||||
<ContentMessage>{output.message}</ContentMessage>
|
||||
{output.description?.trim() && (
|
||||
@@ -240,7 +273,7 @@ export function CreateAgentTool({ part }: Props) {
|
||||
</ContentGrid>
|
||||
)}
|
||||
|
||||
{isClarificationNeededOutput(output) && (
|
||||
{output && isClarificationNeededOutput(output) && (
|
||||
<ClarificationQuestionsCard
|
||||
questions={(output.questions ?? []).map((q) => {
|
||||
const item: ClarifyingQuestion = {
|
||||
@@ -259,7 +292,7 @@ export function CreateAgentTool({ part }: Props) {
|
||||
/>
|
||||
)}
|
||||
|
||||
{isSuggestedGoalOutput(output) && (
|
||||
{output && isSuggestedGoalOutput(output) && (
|
||||
<SuggestedGoalCard
|
||||
message={output.message}
|
||||
suggestedGoal={output.suggested_goal}
|
||||
@@ -268,38 +301,6 @@ export function CreateAgentTool({ part }: Props) {
|
||||
onUseSuggestedGoal={handleUseSuggestedGoal}
|
||||
/>
|
||||
)}
|
||||
|
||||
{isErrorOutput(output) && (
|
||||
<ContentGrid>
|
||||
<ContentMessage>{output.message}</ContentMessage>
|
||||
{output.error && (
|
||||
<ContentCodeBlock>
|
||||
{formatMaybeJson(output.error)}
|
||||
</ContentCodeBlock>
|
||||
)}
|
||||
{output.details && (
|
||||
<ContentCodeBlock>
|
||||
{formatMaybeJson(output.details)}
|
||||
</ContentCodeBlock>
|
||||
)}
|
||||
<div className="flex gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={() => onSend("Please try creating the agent again.")}
|
||||
>
|
||||
Try again
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={() => onSend("Can you help me simplify this goal?")}
|
||||
>
|
||||
Simplify goal
|
||||
</Button>
|
||||
</div>
|
||||
</ContentGrid>
|
||||
)}
|
||||
</ToolAccordion>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -2,9 +2,6 @@ import type { AgentPreviewResponse } from "@/app/api/__generated__/models/agentP
|
||||
import type { AgentSavedResponse } from "@/app/api/__generated__/models/agentSavedResponse";
|
||||
import type { ClarificationNeededResponse } from "@/app/api/__generated__/models/clarificationNeededResponse";
|
||||
import type { ErrorResponse } from "@/app/api/__generated__/models/errorResponse";
|
||||
import type { OperationInProgressResponse } from "@/app/api/__generated__/models/operationInProgressResponse";
|
||||
import type { OperationPendingResponse } from "@/app/api/__generated__/models/operationPendingResponse";
|
||||
import type { OperationStartedResponse } from "@/app/api/__generated__/models/operationStartedResponse";
|
||||
import { ResponseType } from "@/app/api/__generated__/models/responseType";
|
||||
import type { SuggestedGoalResponse } from "@/app/api/__generated__/models/suggestedGoalResponse";
|
||||
import {
|
||||
@@ -16,9 +13,6 @@ import type { ToolUIPart } from "ai";
|
||||
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
|
||||
|
||||
export type CreateAgentToolOutput =
|
||||
| OperationStartedResponse
|
||||
| OperationPendingResponse
|
||||
| OperationInProgressResponse
|
||||
| AgentPreviewResponse
|
||||
| AgentSavedResponse
|
||||
| ClarificationNeededResponse
|
||||
@@ -39,9 +33,6 @@ function parseOutput(output: unknown): CreateAgentToolOutput | null {
|
||||
if (typeof output === "object") {
|
||||
const type = (output as { type?: unknown }).type;
|
||||
if (
|
||||
type === ResponseType.operation_started ||
|
||||
type === ResponseType.operation_pending ||
|
||||
type === ResponseType.operation_in_progress ||
|
||||
type === ResponseType.agent_preview ||
|
||||
type === ResponseType.agent_saved ||
|
||||
type === ResponseType.clarification_needed ||
|
||||
@@ -50,9 +41,6 @@ function parseOutput(output: unknown): CreateAgentToolOutput | null {
|
||||
) {
|
||||
return output as CreateAgentToolOutput;
|
||||
}
|
||||
if ("operation_id" in output && "tool_name" in output)
|
||||
return output as OperationStartedResponse | OperationPendingResponse;
|
||||
if ("tool_call_id" in output) return output as OperationInProgressResponse;
|
||||
if ("agent_json" in output && "agent_name" in output)
|
||||
return output as AgentPreviewResponse;
|
||||
if ("agent_id" in output && "library_agent_id" in output)
|
||||
@@ -72,30 +60,6 @@ export function getCreateAgentToolOutput(
|
||||
return parseOutput((part as { output?: unknown }).output);
|
||||
}
|
||||
|
||||
export function isOperationStartedOutput(
|
||||
output: CreateAgentToolOutput,
|
||||
): output is OperationStartedResponse {
|
||||
return (
|
||||
output.type === ResponseType.operation_started ||
|
||||
("operation_id" in output && "tool_name" in output)
|
||||
);
|
||||
}
|
||||
|
||||
export function isOperationPendingOutput(
|
||||
output: CreateAgentToolOutput,
|
||||
): output is OperationPendingResponse {
|
||||
return output.type === ResponseType.operation_pending;
|
||||
}
|
||||
|
||||
export function isOperationInProgressOutput(
|
||||
output: CreateAgentToolOutput,
|
||||
): output is OperationInProgressResponse {
|
||||
return (
|
||||
output.type === ResponseType.operation_in_progress ||
|
||||
"tool_call_id" in output
|
||||
);
|
||||
}
|
||||
|
||||
export function isAgentPreviewOutput(
|
||||
output: CreateAgentToolOutput,
|
||||
): output is AgentPreviewResponse {
|
||||
@@ -144,10 +108,6 @@ export function getAnimationText(part: {
|
||||
case "output-available": {
|
||||
const output = parseOutput(part.output);
|
||||
if (!output) return "Creating a new agent";
|
||||
if (isOperationStartedOutput(output)) return "Agent creation started";
|
||||
if (isOperationPendingOutput(output)) return "Agent creation in progress";
|
||||
if (isOperationInProgressOutput(output))
|
||||
return "Agent creation already in progress";
|
||||
if (isAgentSavedOutput(output)) return `Saved ${output.agent_name}`;
|
||||
if (isAgentPreviewOutput(output)) return `Preview "${output.agent_name}"`;
|
||||
if (isClarificationNeededOutput(output)) return "Needs clarification";
|
||||
|
||||
@@ -1,20 +1,27 @@
|
||||
"use client";
|
||||
|
||||
import { WarningDiamondIcon } from "@phosphor-icons/react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import {
|
||||
BookOpenIcon,
|
||||
PencilSimpleIcon,
|
||||
WarningDiamondIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import Image from "next/image";
|
||||
import NextLink from "next/link";
|
||||
import { useCopilotChatActions } from "../../components/CopilotChatActionsProvider/useCopilotChatActions";
|
||||
import sparklesImg from "../../components/MiniGame/assets/sparkles.png";
|
||||
import { MiniGame } from "../../components/MiniGame/MiniGame";
|
||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
||||
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
|
||||
import {
|
||||
ContentCardDescription,
|
||||
ContentCodeBlock,
|
||||
ContentGrid,
|
||||
ContentHint,
|
||||
ContentLink,
|
||||
ContentMessage,
|
||||
} from "../../components/ToolAccordion/AccordionContent";
|
||||
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
|
||||
import { MiniGame } from "../CreateAgent/components/MiniGame/MiniGame";
|
||||
import {
|
||||
ClarificationQuestionsCard,
|
||||
ClarifyingQuestion,
|
||||
@@ -28,9 +35,6 @@ import {
|
||||
isAgentSavedOutput,
|
||||
isClarificationNeededOutput,
|
||||
isErrorOutput,
|
||||
isOperationInProgressOutput,
|
||||
isOperationPendingOutput,
|
||||
isOperationStartedOutput,
|
||||
ToolIcon,
|
||||
truncateText,
|
||||
type EditAgentToolOutput,
|
||||
@@ -48,7 +52,7 @@ interface Props {
|
||||
part: EditAgentToolPart;
|
||||
}
|
||||
|
||||
function getAccordionMeta(output: EditAgentToolOutput): {
|
||||
function getAccordionMeta(output: EditAgentToolOutput | null): {
|
||||
icon: React.ReactNode;
|
||||
title: string;
|
||||
titleClassName?: string;
|
||||
@@ -57,8 +61,16 @@ function getAccordionMeta(output: EditAgentToolOutput): {
|
||||
} {
|
||||
const icon = <AccordionIcon />;
|
||||
|
||||
if (!output) {
|
||||
return {
|
||||
icon,
|
||||
title: "Editing agent, this may take a few minutes. Play while you wait.",
|
||||
expanded: true,
|
||||
};
|
||||
}
|
||||
|
||||
if (isAgentSavedOutput(output)) {
|
||||
return { icon, title: output.agent_name };
|
||||
return { icon, title: output.agent_name, expanded: true };
|
||||
}
|
||||
if (isAgentPreviewOutput(output)) {
|
||||
return {
|
||||
@@ -75,17 +87,6 @@ function getAccordionMeta(output: EditAgentToolOutput): {
|
||||
description: `${questions.length} question${questions.length === 1 ? "" : "s"}`,
|
||||
};
|
||||
}
|
||||
if (
|
||||
isOperationStartedOutput(output) ||
|
||||
isOperationPendingOutput(output) ||
|
||||
isOperationInProgressOutput(output)
|
||||
) {
|
||||
return {
|
||||
icon: <OrbitLoader size={32} />,
|
||||
title: "Editing agent, this may take a few minutes. Play while you wait.",
|
||||
expanded: true,
|
||||
};
|
||||
}
|
||||
return {
|
||||
icon: (
|
||||
<WarningDiamondIcon size={32} weight="light" className="text-red-500" />
|
||||
@@ -104,21 +105,12 @@ export function EditAgentTool({ part }: Props) {
|
||||
const output = getEditAgentToolOutput(part);
|
||||
const isError =
|
||||
part.state === "output-error" || (!!output && isErrorOutput(output));
|
||||
const isOperating =
|
||||
!!output &&
|
||||
(isOperationStartedOutput(output) ||
|
||||
isOperationPendingOutput(output) ||
|
||||
isOperationInProgressOutput(output));
|
||||
const hasExpandableContent =
|
||||
part.state === "output-available" &&
|
||||
!!output &&
|
||||
(isOperationStartedOutput(output) ||
|
||||
isOperationPendingOutput(output) ||
|
||||
isOperationInProgressOutput(output) ||
|
||||
isAgentPreviewOutput(output) ||
|
||||
isAgentSavedOutput(output) ||
|
||||
isClarificationNeededOutput(output) ||
|
||||
isErrorOutput(output));
|
||||
|
||||
const isOperating = !output;
|
||||
|
||||
// Show accordion for operating state and successful outputs, but not for errors
|
||||
// (errors are shown inline so they get replaced when retrying)
|
||||
const hasExpandableContent = !isError;
|
||||
|
||||
function handleClarificationAnswers(answers: Record<string, string>) {
|
||||
const questions =
|
||||
@@ -140,15 +132,57 @@ export function EditAgentTool({ part }: Props) {
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<ToolIcon isStreaming={isStreaming} isError={isError} />
|
||||
<MorphingTextAnimation
|
||||
text={text}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
{isOperating && (
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<ToolIcon isStreaming={isStreaming} isError={isError} />
|
||||
<MorphingTextAnimation
|
||||
text={text}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{hasExpandableContent && output && (
|
||||
{isError && output && isErrorOutput(output) && (
|
||||
<div className="space-y-3 rounded-lg border border-red-200 bg-red-50 p-4">
|
||||
<div className="flex items-start gap-2">
|
||||
<WarningDiamondIcon
|
||||
size={20}
|
||||
weight="regular"
|
||||
className="mt-0.5 shrink-0 text-red-500"
|
||||
/>
|
||||
<div className="flex-1 space-y-2">
|
||||
<Text variant="body-medium" className="text-red-900">
|
||||
{output.message ||
|
||||
"Failed to edit the agent. Please try again."}
|
||||
</Text>
|
||||
{output.error && (
|
||||
<details className="text-xs text-red-700">
|
||||
<summary className="cursor-pointer font-medium">
|
||||
Technical details
|
||||
</summary>
|
||||
<pre className="mt-2 max-h-40 overflow-auto whitespace-pre-wrap break-words rounded bg-red-100 p-2">
|
||||
{formatMaybeJson(output.error)}
|
||||
</pre>
|
||||
</details>
|
||||
)}
|
||||
{output.details && (
|
||||
<pre className="max-h-40 overflow-auto whitespace-pre-wrap break-words rounded bg-red-100 p-2 text-xs text-red-700">
|
||||
{formatMaybeJson(output.details)}
|
||||
</pre>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={() => onSend("Please try editing the agent again.")}
|
||||
>
|
||||
Try again
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{hasExpandableContent && (
|
||||
<ToolAccordion {...getAccordionMeta(output)}>
|
||||
{isOperating && (
|
||||
<ContentGrid>
|
||||
@@ -159,27 +193,53 @@ export function EditAgentTool({ part }: Props) {
|
||||
</ContentGrid>
|
||||
)}
|
||||
|
||||
{isAgentSavedOutput(output) && (
|
||||
<ContentGrid>
|
||||
<ContentMessage>{output.message}</ContentMessage>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
<ContentLink href={output.library_agent_link}>
|
||||
Open in library
|
||||
</ContentLink>
|
||||
<ContentLink href={output.agent_page_link}>
|
||||
Open in builder
|
||||
</ContentLink>
|
||||
{output && isAgentSavedOutput(output) && (
|
||||
<div className="rounded-xl border border-border/60 bg-card p-4 shadow-sm">
|
||||
<div className="flex items-baseline gap-2">
|
||||
<Image
|
||||
src={sparklesImg}
|
||||
alt="sparkles"
|
||||
width={24}
|
||||
height={24}
|
||||
className="relative top-1"
|
||||
/>
|
||||
<Text
|
||||
variant="body-medium"
|
||||
className="mb-2 text-[16px] text-black"
|
||||
>
|
||||
Agent{" "}
|
||||
<span className="text-violet-600">{output.agent_name}</span>{" "}
|
||||
has been updated!
|
||||
</Text>
|
||||
</div>
|
||||
<ContentCodeBlock>
|
||||
{truncateText(
|
||||
formatMaybeJson({ agent_id: output.agent_id }),
|
||||
800,
|
||||
)}
|
||||
</ContentCodeBlock>
|
||||
</ContentGrid>
|
||||
<div className="mt-3 flex flex-wrap gap-4">
|
||||
<Button variant="outline" size="small">
|
||||
<NextLink
|
||||
href={output.library_agent_link}
|
||||
className="inline-flex items-center gap-1.5"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
<BookOpenIcon size={14} weight="regular" />
|
||||
Open in library
|
||||
</NextLink>
|
||||
</Button>
|
||||
<Button variant="outline" size="small">
|
||||
<NextLink
|
||||
href={output.agent_page_link}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="inline-flex items-center gap-1.5"
|
||||
>
|
||||
<PencilSimpleIcon size={14} weight="regular" />
|
||||
Open in builder
|
||||
</NextLink>
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isAgentPreviewOutput(output) && (
|
||||
{output && isAgentPreviewOutput(output) && (
|
||||
<ContentGrid>
|
||||
<ContentMessage>{output.message}</ContentMessage>
|
||||
{output.description?.trim() && (
|
||||
@@ -193,7 +253,7 @@ export function EditAgentTool({ part }: Props) {
|
||||
</ContentGrid>
|
||||
)}
|
||||
|
||||
{isClarificationNeededOutput(output) && (
|
||||
{output && isClarificationNeededOutput(output) && (
|
||||
<ClarificationQuestionsCard
|
||||
questions={(output.questions ?? []).map((q) => {
|
||||
const item: ClarifyingQuestion = {
|
||||
@@ -211,22 +271,6 @@ export function EditAgentTool({ part }: Props) {
|
||||
onSubmitAnswers={handleClarificationAnswers}
|
||||
/>
|
||||
)}
|
||||
|
||||
{isErrorOutput(output) && (
|
||||
<ContentGrid>
|
||||
<ContentMessage>{output.message}</ContentMessage>
|
||||
{output.error && (
|
||||
<ContentCodeBlock>
|
||||
{formatMaybeJson(output.error)}
|
||||
</ContentCodeBlock>
|
||||
)}
|
||||
{output.details && (
|
||||
<ContentCodeBlock>
|
||||
{formatMaybeJson(output.details)}
|
||||
</ContentCodeBlock>
|
||||
)}
|
||||
</ContentGrid>
|
||||
)}
|
||||
</ToolAccordion>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -2,9 +2,6 @@ import type { AgentPreviewResponse } from "@/app/api/__generated__/models/agentP
|
||||
import type { AgentSavedResponse } from "@/app/api/__generated__/models/agentSavedResponse";
|
||||
import type { ClarificationNeededResponse } from "@/app/api/__generated__/models/clarificationNeededResponse";
|
||||
import type { ErrorResponse } from "@/app/api/__generated__/models/errorResponse";
|
||||
import type { OperationInProgressResponse } from "@/app/api/__generated__/models/operationInProgressResponse";
|
||||
import type { OperationPendingResponse } from "@/app/api/__generated__/models/operationPendingResponse";
|
||||
import type { OperationStartedResponse } from "@/app/api/__generated__/models/operationStartedResponse";
|
||||
import { ResponseType } from "@/app/api/__generated__/models/responseType";
|
||||
import {
|
||||
NotePencilIcon,
|
||||
@@ -15,9 +12,6 @@ import type { ToolUIPart } from "ai";
|
||||
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
|
||||
|
||||
export type EditAgentToolOutput =
|
||||
| OperationStartedResponse
|
||||
| OperationPendingResponse
|
||||
| OperationInProgressResponse
|
||||
| AgentPreviewResponse
|
||||
| AgentSavedResponse
|
||||
| ClarificationNeededResponse
|
||||
@@ -37,9 +31,6 @@ function parseOutput(output: unknown): EditAgentToolOutput | null {
|
||||
if (typeof output === "object") {
|
||||
const type = (output as { type?: unknown }).type;
|
||||
if (
|
||||
type === ResponseType.operation_started ||
|
||||
type === ResponseType.operation_pending ||
|
||||
type === ResponseType.operation_in_progress ||
|
||||
type === ResponseType.agent_preview ||
|
||||
type === ResponseType.agent_saved ||
|
||||
type === ResponseType.clarification_needed ||
|
||||
@@ -47,9 +38,6 @@ function parseOutput(output: unknown): EditAgentToolOutput | null {
|
||||
) {
|
||||
return output as EditAgentToolOutput;
|
||||
}
|
||||
if ("operation_id" in output && "tool_name" in output)
|
||||
return output as OperationStartedResponse | OperationPendingResponse;
|
||||
if ("tool_call_id" in output) return output as OperationInProgressResponse;
|
||||
if ("agent_json" in output && "agent_name" in output)
|
||||
return output as AgentPreviewResponse;
|
||||
if ("agent_id" in output && "library_agent_id" in output)
|
||||
@@ -68,30 +56,6 @@ export function getEditAgentToolOutput(
|
||||
return parseOutput((part as { output?: unknown }).output);
|
||||
}
|
||||
|
||||
export function isOperationStartedOutput(
|
||||
output: EditAgentToolOutput,
|
||||
): output is OperationStartedResponse {
|
||||
return (
|
||||
output.type === ResponseType.operation_started ||
|
||||
("operation_id" in output && "tool_name" in output)
|
||||
);
|
||||
}
|
||||
|
||||
export function isOperationPendingOutput(
|
||||
output: EditAgentToolOutput,
|
||||
): output is OperationPendingResponse {
|
||||
return output.type === ResponseType.operation_pending;
|
||||
}
|
||||
|
||||
export function isOperationInProgressOutput(
|
||||
output: EditAgentToolOutput,
|
||||
): output is OperationInProgressResponse {
|
||||
return (
|
||||
output.type === ResponseType.operation_in_progress ||
|
||||
"tool_call_id" in output
|
||||
);
|
||||
}
|
||||
|
||||
export function isAgentPreviewOutput(
|
||||
output: EditAgentToolOutput,
|
||||
): output is AgentPreviewResponse {
|
||||
@@ -132,10 +96,6 @@ export function getAnimationText(part: {
|
||||
case "output-available": {
|
||||
const output = parseOutput(part.output);
|
||||
if (!output) return "Editing the agent";
|
||||
if (isOperationStartedOutput(output)) return "Agent update started";
|
||||
if (isOperationPendingOutput(output)) return "Agent update in progress";
|
||||
if (isOperationInProgressOutput(output))
|
||||
return "Agent update already in progress";
|
||||
if (isAgentSavedOutput(output)) return `Saved "${output.agent_name}"`;
|
||||
if (isAgentPreviewOutput(output)) return `Preview "${output.agent_name}"`;
|
||||
if (isClarificationNeededOutput(output)) return "Needs clarification";
|
||||
|
||||
@@ -686,17 +686,20 @@ export function GenericTool({ part }: Props) {
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<ToolIcon
|
||||
category={category}
|
||||
isStreaming={isStreaming}
|
||||
isError={isError}
|
||||
/>
|
||||
<MorphingTextAnimation
|
||||
text={text}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
{/* Only show loading text when NOT showing accordion */}
|
||||
{!showAccordion && (
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<ToolIcon
|
||||
category={category}
|
||||
isStreaming={isStreaming}
|
||||
isError={isError}
|
||||
/>
|
||||
<MorphingTextAnimation
|
||||
text={text}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{showAccordion && accordionData ? (
|
||||
<ToolAccordion
|
||||
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
ContentHint,
|
||||
ContentMessage,
|
||||
} from "../../components/ToolAccordion/AccordionContent";
|
||||
import { MiniGame } from "../CreateAgent/components/MiniGame/MiniGame";
|
||||
import { MiniGame } from "../../components/MiniGame/MiniGame";
|
||||
import {
|
||||
getAccordionMeta,
|
||||
getAnimationText,
|
||||
@@ -47,24 +47,42 @@ export function RunAgentTool({ part }: Props) {
|
||||
const isError =
|
||||
part.state === "output-error" ||
|
||||
(!!output && isRunAgentErrorOutput(output));
|
||||
const isOutputAvailable = part.state === "output-available" && !!output;
|
||||
|
||||
const setupRequirementsOutput =
|
||||
isOutputAvailable && isRunAgentSetupRequirementsOutput(output)
|
||||
? output
|
||||
: null;
|
||||
|
||||
const agentDetailsOutput =
|
||||
isOutputAvailable && isRunAgentAgentDetailsOutput(output) ? output : null;
|
||||
|
||||
const needLoginOutput =
|
||||
isOutputAvailable && isRunAgentNeedLoginOutput(output) ? output : null;
|
||||
|
||||
const hasExpandableContent =
|
||||
part.state === "output-available" &&
|
||||
!!output &&
|
||||
(isRunAgentExecutionStartedOutput(output) ||
|
||||
isRunAgentAgentDetailsOutput(output) ||
|
||||
isRunAgentSetupRequirementsOutput(output) ||
|
||||
isRunAgentNeedLoginOutput(output) ||
|
||||
isRunAgentErrorOutput(output));
|
||||
isOutputAvailable &&
|
||||
!setupRequirementsOutput &&
|
||||
!agentDetailsOutput &&
|
||||
!needLoginOutput &&
|
||||
(isRunAgentExecutionStartedOutput(output) || isRunAgentErrorOutput(output));
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<ToolIcon isStreaming={isStreaming} isError={isError} />
|
||||
<MorphingTextAnimation
|
||||
text={text}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
{/* Only show loading text when NOT showing accordion or other content */}
|
||||
{!isStreaming &&
|
||||
!setupRequirementsOutput &&
|
||||
!agentDetailsOutput &&
|
||||
!needLoginOutput &&
|
||||
!hasExpandableContent && (
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<ToolIcon isStreaming={isStreaming} isError={isError} />
|
||||
<MorphingTextAnimation
|
||||
text={text}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isStreaming && !output && (
|
||||
<ToolAccordion
|
||||
@@ -81,24 +99,30 @@ export function RunAgentTool({ part }: Props) {
|
||||
</ToolAccordion>
|
||||
)}
|
||||
|
||||
{setupRequirementsOutput && (
|
||||
<div className="mt-2">
|
||||
<SetupRequirementsCard output={setupRequirementsOutput} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{agentDetailsOutput && (
|
||||
<div className="mt-2">
|
||||
<AgentDetailsCard output={agentDetailsOutput} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{needLoginOutput && (
|
||||
<div className="mt-2">
|
||||
<ContentMessage>{needLoginOutput.message}</ContentMessage>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{hasExpandableContent && output && (
|
||||
<ToolAccordion {...getAccordionMeta(output)}>
|
||||
{isRunAgentExecutionStartedOutput(output) && (
|
||||
<ExecutionStartedCard output={output} />
|
||||
)}
|
||||
|
||||
{isRunAgentAgentDetailsOutput(output) && (
|
||||
<AgentDetailsCard output={output} />
|
||||
)}
|
||||
|
||||
{isRunAgentSetupRequirementsOutput(output) && (
|
||||
<SetupRequirementsCard output={output} />
|
||||
)}
|
||||
|
||||
{isRunAgentNeedLoginOutput(output) && (
|
||||
<ContentMessage>{output.message}</ContentMessage>
|
||||
)}
|
||||
|
||||
{isRunAgentErrorOutput(output) && <ErrorCard output={output} />}
|
||||
</ToolAccordion>
|
||||
)}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { CredentialsGroupedView } from "@/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { CredentialsMetaInput } from "@/lib/autogpt-server-api/types";
|
||||
import type { SetupRequirementsResponse } from "@/app/api/__generated__/models/setupRequirementsResponse";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { CredentialsGroupedView } from "@/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView";
|
||||
import type { CredentialsMetaInput } from "@/lib/autogpt-server-api/types";
|
||||
import { useState } from "react";
|
||||
import { useCopilotChatActions } from "../../../../components/CopilotChatActionsProvider/useCopilotChatActions";
|
||||
import {
|
||||
ContentBadge,
|
||||
@@ -38,40 +39,40 @@ export function SetupRequirementsCard({ output }: Props) {
|
||||
setInputCredentials((prev) => ({ ...prev, [key]: value }));
|
||||
}
|
||||
|
||||
const isAllComplete =
|
||||
credentialFields.length > 0 &&
|
||||
const needsCredentials = credentialFields.length > 0;
|
||||
const isAllCredentialsComplete =
|
||||
needsCredentials &&
|
||||
[...requiredCredentials].every((key) => !!inputCredentials[key]);
|
||||
|
||||
const canProceed =
|
||||
!hasSent && (!needsCredentials || isAllCredentialsComplete);
|
||||
|
||||
function handleProceed() {
|
||||
setHasSent(true);
|
||||
onSend(
|
||||
"I've configured the required credentials. Please check if everything is ready and proceed with running the agent.",
|
||||
);
|
||||
const message = needsCredentials
|
||||
? "I've configured the required credentials. Please check if everything is ready and proceed with running the agent."
|
||||
: "Please proceed with running the agent.";
|
||||
onSend(message);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="grid gap-2">
|
||||
<ContentMessage>{output.message}</ContentMessage>
|
||||
|
||||
{credentialFields.length > 0 && (
|
||||
{needsCredentials && (
|
||||
<div className="rounded-2xl border bg-background p-3">
|
||||
<CredentialsGroupedView
|
||||
credentialFields={credentialFields}
|
||||
requiredCredentials={requiredCredentials}
|
||||
inputCredentials={inputCredentials}
|
||||
inputValues={{}}
|
||||
onCredentialChange={handleCredentialChange}
|
||||
/>
|
||||
{isAllComplete && !hasSent && (
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
className="mt-3 w-full"
|
||||
onClick={handleProceed}
|
||||
>
|
||||
Proceed
|
||||
</Button>
|
||||
)}
|
||||
<Text variant="small" className="w-fit border-b text-zinc-500">
|
||||
Agent credentials
|
||||
</Text>
|
||||
<div className="mt-6">
|
||||
<CredentialsGroupedView
|
||||
credentialFields={credentialFields}
|
||||
requiredCredentials={requiredCredentials}
|
||||
inputCredentials={inputCredentials}
|
||||
inputValues={{}}
|
||||
onCredentialChange={handleCredentialChange}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -100,6 +101,18 @@ export function SetupRequirementsCard({ output }: Props) {
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{(needsCredentials || expectedInputs.length > 0) && (
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
className="mt-4 w-fit"
|
||||
disabled={!canProceed}
|
||||
onClick={handleProceed}
|
||||
>
|
||||
Proceed
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -39,12 +39,19 @@ export function RunBlockTool({ part }: Props) {
|
||||
const isError =
|
||||
part.state === "output-error" ||
|
||||
(!!output && isRunBlockErrorOutput(output));
|
||||
const setupRequirementsOutput =
|
||||
part.state === "output-available" &&
|
||||
output &&
|
||||
isRunBlockSetupRequirementsOutput(output)
|
||||
? output
|
||||
: null;
|
||||
|
||||
const hasExpandableContent =
|
||||
part.state === "output-available" &&
|
||||
!!output &&
|
||||
!setupRequirementsOutput &&
|
||||
(isRunBlockBlockOutput(output) ||
|
||||
isRunBlockDetailsOutput(output) ||
|
||||
isRunBlockSetupRequirementsOutput(output) ||
|
||||
isRunBlockErrorOutput(output));
|
||||
|
||||
return (
|
||||
@@ -57,6 +64,12 @@ export function RunBlockTool({ part }: Props) {
|
||||
/>
|
||||
</div>
|
||||
|
||||
{setupRequirementsOutput && (
|
||||
<div className="mt-2">
|
||||
<SetupRequirementsCard output={setupRequirementsOutput} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{hasExpandableContent && output && (
|
||||
<ToolAccordion {...getAccordionMeta(output)}>
|
||||
{isRunBlockBlockOutput(output) && <BlockOutputCard output={output} />}
|
||||
@@ -65,10 +78,6 @@ export function RunBlockTool({ part }: Props) {
|
||||
<BlockDetailsCard output={output} />
|
||||
)}
|
||||
|
||||
{isRunBlockSetupRequirementsOutput(output) && (
|
||||
<SetupRequirementsCard output={output} />
|
||||
)}
|
||||
|
||||
{isRunBlockErrorOutput(output) && <ErrorCard output={output} />}
|
||||
</ToolAccordion>
|
||||
)}
|
||||
|
||||
@@ -6,15 +6,9 @@ import { Text } from "@/components/atoms/Text/Text";
|
||||
import { CredentialsGroupedView } from "@/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView";
|
||||
import { FormRenderer } from "@/components/renderers/InputRenderer/FormRenderer";
|
||||
import type { CredentialsMetaInput } from "@/lib/autogpt-server-api/types";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { useState } from "react";
|
||||
import { useCopilotChatActions } from "../../../../components/CopilotChatActionsProvider/useCopilotChatActions";
|
||||
import {
|
||||
ContentBadge,
|
||||
ContentCardDescription,
|
||||
ContentCardTitle,
|
||||
ContentMessage,
|
||||
} from "../../../../components/ToolAccordion/AccordionContent";
|
||||
import { ContentMessage } from "../../../../components/ToolAccordion/AccordionContent";
|
||||
import {
|
||||
buildExpectedInputsSchema,
|
||||
coerceCredentialFields,
|
||||
@@ -31,10 +25,8 @@ export function SetupRequirementsCard({ output }: Props) {
|
||||
const [inputCredentials, setInputCredentials] = useState<
|
||||
Record<string, CredentialsMetaInput | undefined>
|
||||
>({});
|
||||
const [hasSentCredentials, setHasSentCredentials] = useState(false);
|
||||
|
||||
const [showInputForm, setShowInputForm] = useState(false);
|
||||
const [inputValues, setInputValues] = useState<Record<string, unknown>>({});
|
||||
const [hasSent, setHasSent] = useState(false);
|
||||
|
||||
const { credentialFields, requiredCredentials } = coerceCredentialFields(
|
||||
output.setup_info.user_readiness?.missing_credentials,
|
||||
@@ -50,27 +42,49 @@ export function SetupRequirementsCard({ output }: Props) {
|
||||
setInputCredentials((prev) => ({ ...prev, [key]: value }));
|
||||
}
|
||||
|
||||
const needsCredentials = credentialFields.length > 0;
|
||||
const isAllCredentialsComplete =
|
||||
credentialFields.length > 0 &&
|
||||
needsCredentials &&
|
||||
[...requiredCredentials].every((key) => !!inputCredentials[key]);
|
||||
|
||||
function handleProceedCredentials() {
|
||||
setHasSentCredentials(true);
|
||||
onSend(
|
||||
"I've configured the required credentials. Please re-run the block now.",
|
||||
);
|
||||
}
|
||||
const needsInputs = inputSchema !== null;
|
||||
const requiredInputNames = expectedInputs
|
||||
.filter((i) => i.required)
|
||||
.map((i) => i.name);
|
||||
const isAllInputsComplete =
|
||||
needsInputs &&
|
||||
requiredInputNames.every((name) => {
|
||||
const v = inputValues[name];
|
||||
return v !== undefined && v !== null && v !== "";
|
||||
});
|
||||
|
||||
function handleRunWithInputs() {
|
||||
const nonEmpty = Object.fromEntries(
|
||||
Object.entries(inputValues).filter(
|
||||
([, v]) => v !== undefined && v !== null && v !== "",
|
||||
),
|
||||
);
|
||||
onSend(
|
||||
`Run the block with these inputs: ${JSON.stringify(nonEmpty, null, 2)}`,
|
||||
);
|
||||
setShowInputForm(false);
|
||||
const canRun =
|
||||
!hasSent &&
|
||||
(!needsCredentials || isAllCredentialsComplete) &&
|
||||
(!needsInputs || isAllInputsComplete);
|
||||
|
||||
function handleRun() {
|
||||
setHasSent(true);
|
||||
|
||||
const parts: string[] = [];
|
||||
if (needsCredentials) {
|
||||
parts.push("I've configured the required credentials.");
|
||||
}
|
||||
|
||||
if (needsInputs) {
|
||||
const nonEmpty = Object.fromEntries(
|
||||
Object.entries(inputValues).filter(
|
||||
([, v]) => v !== undefined && v !== null && v !== "",
|
||||
),
|
||||
);
|
||||
parts.push(
|
||||
`Run the block with these inputs: ${JSON.stringify(nonEmpty, null, 2)}`,
|
||||
);
|
||||
} else {
|
||||
parts.push("Please re-run the block now.");
|
||||
}
|
||||
|
||||
onSend(parts.join(" "));
|
||||
setInputValues({});
|
||||
}
|
||||
|
||||
@@ -78,119 +92,54 @@ export function SetupRequirementsCard({ output }: Props) {
|
||||
<div className="grid gap-2">
|
||||
<ContentMessage>{output.message}</ContentMessage>
|
||||
|
||||
{credentialFields.length > 0 && (
|
||||
{needsCredentials && (
|
||||
<div className="rounded-2xl border bg-background p-3">
|
||||
<CredentialsGroupedView
|
||||
credentialFields={credentialFields}
|
||||
requiredCredentials={requiredCredentials}
|
||||
inputCredentials={inputCredentials}
|
||||
inputValues={{}}
|
||||
onCredentialChange={handleCredentialChange}
|
||||
/>
|
||||
{isAllCredentialsComplete && !hasSentCredentials && (
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
className="mt-3 w-full"
|
||||
onClick={handleProceedCredentials}
|
||||
>
|
||||
Proceed
|
||||
</Button>
|
||||
)}
|
||||
<Text variant="small" className="w-fit border-b text-zinc-500">
|
||||
Block credentials
|
||||
</Text>
|
||||
<div className="mt-6">
|
||||
<CredentialsGroupedView
|
||||
credentialFields={credentialFields}
|
||||
requiredCredentials={requiredCredentials}
|
||||
inputCredentials={inputCredentials}
|
||||
inputValues={{}}
|
||||
onCredentialChange={handleCredentialChange}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{inputSchema && (
|
||||
<div className="flex gap-2 pt-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="w-fit"
|
||||
onClick={() => setShowInputForm((prev) => !prev)}
|
||||
>
|
||||
{showInputForm ? "Hide inputs" : "Fill in inputs"}
|
||||
</Button>
|
||||
<div className="rounded-2xl border bg-background p-3 pt-4">
|
||||
<Text variant="small" className="w-fit border-b text-zinc-500">
|
||||
Block inputs
|
||||
</Text>
|
||||
<FormRenderer
|
||||
jsonSchema={inputSchema}
|
||||
className="mb-3 mt-3"
|
||||
handleChange={(v) => setInputValues(v.formData ?? {})}
|
||||
uiSchema={{
|
||||
"ui:submitButtonOptions": { norender: true },
|
||||
}}
|
||||
initialValues={inputValues}
|
||||
formContext={{
|
||||
showHandles: false,
|
||||
size: "small",
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<AnimatePresence initial={false}>
|
||||
{showInputForm && inputSchema && (
|
||||
<motion.div
|
||||
initial={{ height: 0, opacity: 0, filter: "blur(6px)" }}
|
||||
animate={{ height: "auto", opacity: 1, filter: "blur(0px)" }}
|
||||
exit={{ height: 0, opacity: 0, filter: "blur(6px)" }}
|
||||
transition={{
|
||||
height: { type: "spring", bounce: 0.15, duration: 0.5 },
|
||||
opacity: { duration: 0.25 },
|
||||
filter: { duration: 0.2 },
|
||||
}}
|
||||
className="overflow-hidden"
|
||||
style={{ willChange: "height, opacity, filter" }}
|
||||
>
|
||||
<div className="rounded-2xl border bg-background p-3 pt-4">
|
||||
<Text variant="body-medium">Block inputs</Text>
|
||||
<FormRenderer
|
||||
jsonSchema={inputSchema}
|
||||
handleChange={(v) => setInputValues(v.formData ?? {})}
|
||||
uiSchema={{
|
||||
"ui:submitButtonOptions": { norender: true },
|
||||
}}
|
||||
initialValues={inputValues}
|
||||
formContext={{
|
||||
showHandles: false,
|
||||
size: "small",
|
||||
}}
|
||||
/>
|
||||
<div className="-mt-8 flex gap-2">
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
className="w-fit"
|
||||
onClick={handleRunWithInputs}
|
||||
>
|
||||
Run
|
||||
</Button>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
className="w-fit"
|
||||
onClick={() => {
|
||||
setShowInputForm(false);
|
||||
setInputValues({});
|
||||
}}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
|
||||
{expectedInputs.length > 0 && !inputSchema && (
|
||||
<div className="rounded-2xl border bg-background p-3">
|
||||
<ContentCardTitle className="text-xs">
|
||||
Expected inputs
|
||||
</ContentCardTitle>
|
||||
<div className="mt-2 grid gap-2">
|
||||
{expectedInputs.map((input) => (
|
||||
<div key={input.name} className="rounded-xl border p-2">
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<ContentCardTitle className="text-xs">
|
||||
{input.title}
|
||||
</ContentCardTitle>
|
||||
<ContentBadge>
|
||||
{input.required ? "Required" : "Optional"}
|
||||
</ContentBadge>
|
||||
</div>
|
||||
<ContentCardDescription className="mt-1">
|
||||
{input.name} • {input.type}
|
||||
{input.description ? ` \u2022 ${input.description}` : ""}
|
||||
</ContentCardDescription>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
{(needsCredentials || needsInputs) && (
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
className="w-fit"
|
||||
disabled={!canRun}
|
||||
onClick={handleRun}
|
||||
>
|
||||
Proceed
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -115,6 +115,7 @@ export function useChatSession() {
|
||||
hydratedMessages,
|
||||
hasActiveStream,
|
||||
isLoadingSession: sessionQuery.isLoading,
|
||||
isSessionError: sessionQuery.isError,
|
||||
createSession,
|
||||
isCreatingSession,
|
||||
};
|
||||
|
||||
@@ -14,7 +14,6 @@ import { DefaultChatTransport } from "ai";
|
||||
import type { UIMessage } from "ai";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useChatSession } from "./useChatSession";
|
||||
import { useLongRunningToolPolling } from "./hooks/useLongRunningToolPolling";
|
||||
|
||||
const STREAM_START_TIMEOUT_MS = 12_000;
|
||||
|
||||
@@ -36,6 +35,46 @@ function resolveInProgressTools(
|
||||
}));
|
||||
}
|
||||
|
||||
/** Build a fingerprint from a message's role + text/tool content for cross-boundary dedup. */
|
||||
function messageFingerprint(msg: UIMessage): string {
|
||||
const fragments = msg.parts.map((p) => {
|
||||
if ("text" in p && typeof p.text === "string") return p.text;
|
||||
if ("toolCallId" in p && typeof p.toolCallId === "string")
|
||||
return `tool:${p.toolCallId}`;
|
||||
return "";
|
||||
});
|
||||
return `${msg.role}::${fragments.join("\n")}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Deduplicate messages by ID *and* by content fingerprint.
|
||||
* ID-based dedup catches duplicates within the same source (e.g. two
|
||||
* identical stream events). Fingerprint-based dedup catches duplicates
|
||||
* across the hydration/stream boundary where IDs differ (synthetic
|
||||
* `${sessionId}-${index}` vs AI SDK nanoid).
|
||||
*
|
||||
* NOTE: Fingerprint dedup only applies to assistant messages, not user messages.
|
||||
* Users should be able to send the same message multiple times.
|
||||
*/
|
||||
function deduplicateMessages(messages: UIMessage[]): UIMessage[] {
|
||||
const seenIds = new Set<string>();
|
||||
const seenFingerprints = new Set<string>();
|
||||
return messages.filter((msg) => {
|
||||
if (seenIds.has(msg.id)) return false;
|
||||
seenIds.add(msg.id);
|
||||
|
||||
// Only apply fingerprint deduplication to assistant messages
|
||||
// User messages should allow duplicates (same text sent multiple times)
|
||||
if (msg.role === "assistant") {
|
||||
const fp = messageFingerprint(msg);
|
||||
if (fp !== "::" && seenFingerprints.has(fp)) return false;
|
||||
seenFingerprints.add(fp);
|
||||
}
|
||||
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
export function useCopilotPage() {
|
||||
const { isUserLoading, isLoggedIn } = useSupabase();
|
||||
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
|
||||
@@ -52,6 +91,7 @@ export function useCopilotPage() {
|
||||
hydratedMessages,
|
||||
hasActiveStream,
|
||||
isLoadingSession,
|
||||
isSessionError,
|
||||
createSession,
|
||||
isCreatingSession,
|
||||
} = useChatSession();
|
||||
@@ -114,7 +154,7 @@ export function useCopilotPage() {
|
||||
);
|
||||
|
||||
const {
|
||||
messages,
|
||||
messages: rawMessages,
|
||||
sendMessage,
|
||||
stop: sdkStop,
|
||||
status,
|
||||
@@ -129,6 +169,12 @@ export function useCopilotPage() {
|
||||
// call resumeStream() manually after hydration + active_stream detection.
|
||||
});
|
||||
|
||||
// Deduplicate messages continuously to prevent duplicates when resuming streams
|
||||
const messages = useMemo(
|
||||
() => deduplicateMessages(rawMessages),
|
||||
[rawMessages],
|
||||
);
|
||||
|
||||
// Wrap AI SDK's stop() to also cancel the backend executor task.
|
||||
// sdkStop() aborts the SSE fetch instantly (UI feedback), then we fire
|
||||
// the cancel API to actually stop the executor and wait for confirmation.
|
||||
@@ -184,14 +230,14 @@ export function useCopilotPage() {
|
||||
if (status === "streaming" || status === "submitted") return;
|
||||
setMessages((prev) => {
|
||||
if (prev.length >= hydratedMessages.length) return prev;
|
||||
return hydratedMessages;
|
||||
// Deduplicate to handle rare cases where duplicate streams might occur
|
||||
return deduplicateMessages(hydratedMessages);
|
||||
});
|
||||
}, [hydratedMessages, setMessages, status]);
|
||||
|
||||
// Ref: tracks whether we've already resumed for a given session.
|
||||
// Reset when the stream ends so re-resume is possible if the backend
|
||||
// task is still running (SSE dropped but executor didn't finish).
|
||||
const hasResumedRef = useRef<string | null>(null);
|
||||
// Format: Map<sessionId, hasResumed>
|
||||
const hasResumedRef = useRef<Map<string, boolean>>(new Map());
|
||||
|
||||
// When the stream ends (or drops), invalidate the session cache so the
|
||||
// next hydration fetches fresh messages from the backend. Without this,
|
||||
@@ -208,29 +254,27 @@ export function useCopilotPage() {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||
});
|
||||
// Allow re-resume if the backend task is still running.
|
||||
hasResumedRef.current = null;
|
||||
}
|
||||
}, [status, sessionId, queryClient]);
|
||||
|
||||
// Resume an active stream AFTER hydration completes.
|
||||
// The backend returns active_stream info when a task is still running.
|
||||
// We wait for hydration so the AI SDK has the conversation history
|
||||
// before the resumed stream appends the in-progress assistant message.
|
||||
// IMPORTANT: Only runs when page loads with existing active stream (reconnection).
|
||||
// Does NOT run when new streams start during active conversation.
|
||||
useEffect(() => {
|
||||
if (!hasActiveStream || !sessionId) return;
|
||||
if (!sessionId) return;
|
||||
if (!hasActiveStream) return;
|
||||
if (!hydratedMessages || hydratedMessages.length === 0) return;
|
||||
if (status === "streaming" || status === "submitted") return;
|
||||
// Only resume once per session to avoid re-triggering after stream ends
|
||||
if (hasResumedRef.current === sessionId) return;
|
||||
hasResumedRef.current = sessionId;
|
||||
resumeStream();
|
||||
}, [hasActiveStream, sessionId, hydratedMessages, status, resumeStream]);
|
||||
|
||||
// Poll session endpoint when a long-running tool (create_agent, edit_agent)
|
||||
// is in progress. When the backend completes, the session data will contain
|
||||
// the final tool output — this hook detects the change and updates messages.
|
||||
useLongRunningToolPolling(sessionId, messages, setMessages);
|
||||
// Never resume if currently streaming
|
||||
if (status === "streaming" || status === "submitted") return;
|
||||
|
||||
// Only resume once per session
|
||||
if (hasResumedRef.current.get(sessionId)) return;
|
||||
|
||||
// Mark as resumed immediately to prevent race conditions
|
||||
hasResumedRef.current.set(sessionId, true);
|
||||
resumeStream();
|
||||
}, [sessionId, hasActiveStream, hydratedMessages, status, resumeStream]);
|
||||
|
||||
// Clear messages when session is null
|
||||
useEffect(() => {
|
||||
@@ -321,6 +365,7 @@ export function useCopilotPage() {
|
||||
stop,
|
||||
isReconnecting,
|
||||
isLoadingSession,
|
||||
isSessionError,
|
||||
isCreatingSession,
|
||||
isUserLoading,
|
||||
isLoggedIn,
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
import { environment } from "@/services/environment";
|
||||
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
|
||||
import { NextRequest } from "next/server";
|
||||
import { normalizeSSEStream, SSE_HEADERS } from "../../../sse-helpers";
|
||||
|
||||
export async function GET(
|
||||
request: NextRequest,
|
||||
{ params }: { params: Promise<{ taskId: string }> },
|
||||
) {
|
||||
const { taskId } = await params;
|
||||
const searchParams = request.nextUrl.searchParams;
|
||||
const lastMessageId = searchParams.get("last_message_id") || "0-0";
|
||||
|
||||
try {
|
||||
const token = await getServerAuthToken();
|
||||
|
||||
const backendUrl = environment.getAGPTServerBaseUrl();
|
||||
const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl);
|
||||
streamUrl.searchParams.set("last_message_id", lastMessageId);
|
||||
|
||||
const headers: Record<string, string> = {
|
||||
Accept: "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
Connection: "keep-alive",
|
||||
};
|
||||
|
||||
if (token) {
|
||||
headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
const response = await fetch(streamUrl.toString(), {
|
||||
method: "GET",
|
||||
headers,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
return new Response(error, {
|
||||
status: response.status,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
return new Response(null, { status: 204 });
|
||||
}
|
||||
|
||||
return new Response(normalizeSSEStream(response.body), {
|
||||
headers: SSE_HEADERS,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Task stream proxy error:", error);
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
error: "Failed to connect to task stream",
|
||||
detail: error instanceof Error ? error.message : String(error),
|
||||
}),
|
||||
{
|
||||
status: 500,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -961,63 +961,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/operations/{operation_id}/complete": {
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Complete Operation",
|
||||
"description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.",
|
||||
"operationId": "postV2CompleteOperation",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "operation_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Operation Id" }
|
||||
},
|
||||
{
|
||||
"name": "x-api-key",
|
||||
"in": "header",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "X-Api-Key"
|
||||
}
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/OperationCompleteRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"title": "Response Postv2Completeoperation"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/schema/tool-responses": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
@@ -1057,12 +1000,7 @@
|
||||
{ "$ref": "#/components/schemas/BlockDetailsResponse" },
|
||||
{ "$ref": "#/components/schemas/BlockOutputResponse" },
|
||||
{ "$ref": "#/components/schemas/DocSearchResultsResponse" },
|
||||
{ "$ref": "#/components/schemas/DocPageResponse" },
|
||||
{ "$ref": "#/components/schemas/OperationStartedResponse" },
|
||||
{ "$ref": "#/components/schemas/OperationPendingResponse" },
|
||||
{
|
||||
"$ref": "#/components/schemas/OperationInProgressResponse"
|
||||
}
|
||||
{ "$ref": "#/components/schemas/DocPageResponse" }
|
||||
],
|
||||
"title": "Response Getv2[Dummy] Tool Response Type Export For Codegen"
|
||||
}
|
||||
@@ -1185,7 +1123,7 @@
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Session",
|
||||
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.",
|
||||
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns active_stream info for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.",
|
||||
"operationId": "getV2GetSession",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
@@ -1337,7 +1275,7 @@
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Stream Chat Post",
|
||||
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. First chunk is a \"start\" event\n containing the task_id for reconnection.",
|
||||
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to a per-turn Redis stream for reconnection support. If the client\ndisconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
|
||||
"operationId": "postV2StreamChatPost",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
@@ -1375,94 +1313,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/tasks/{task_id}": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Task Status",
|
||||
"description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
|
||||
"operationId": "getV2GetTaskStatus",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "task_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Task Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"title": "Response Getv2Gettaskstatus"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/tasks/{task_id}/stream": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Stream Task",
|
||||
"description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.",
|
||||
"operationId": "getV2StreamTask",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "task_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Task Id" }
|
||||
},
|
||||
{
|
||||
"name": "last_message_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
||||
"default": "0-0",
|
||||
"title": "Last Message Id"
|
||||
},
|
||||
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay."
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": { "application/json": { "schema": {} } }
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/credits": {
|
||||
"get": {
|
||||
"tags": ["v1", "credits"],
|
||||
@@ -6562,13 +6412,11 @@
|
||||
},
|
||||
"ActiveStreamInfo": {
|
||||
"properties": {
|
||||
"task_id": { "type": "string", "title": "Task Id" },
|
||||
"last_message_id": { "type": "string", "title": "Last Message Id" },
|
||||
"operation_id": { "type": "string", "title": "Operation Id" },
|
||||
"tool_name": { "type": "string", "title": "Tool Name" }
|
||||
"turn_id": { "type": "string", "title": "Turn Id" },
|
||||
"last_message_id": { "type": "string", "title": "Last Message Id" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["task_id", "last_message_id", "operation_id", "tool_name"],
|
||||
"required": ["turn_id", "last_message_id"],
|
||||
"title": "ActiveStreamInfo",
|
||||
"description": "Information about an active stream for reconnection."
|
||||
},
|
||||
@@ -7578,10 +7426,6 @@
|
||||
"CancelTaskResponse": {
|
||||
"properties": {
|
||||
"cancelled": { "type": "boolean", "title": "Cancelled" },
|
||||
"task_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Task Id"
|
||||
},
|
||||
"reason": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Reason"
|
||||
@@ -10107,87 +9951,6 @@
|
||||
],
|
||||
"title": "OnboardingStep"
|
||||
},
|
||||
"OperationCompleteRequest": {
|
||||
"properties": {
|
||||
"success": { "type": "boolean", "title": "Success" },
|
||||
"result": {
|
||||
"anyOf": [
|
||||
{ "additionalProperties": true, "type": "object" },
|
||||
{ "type": "string" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Result"
|
||||
},
|
||||
"error": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Error"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["success"],
|
||||
"title": "OperationCompleteRequest",
|
||||
"description": "Request model for external completion webhook."
|
||||
},
|
||||
"OperationInProgressResponse": {
|
||||
"properties": {
|
||||
"type": {
|
||||
"$ref": "#/components/schemas/ResponseType",
|
||||
"default": "operation_in_progress"
|
||||
},
|
||||
"message": { "type": "string", "title": "Message" },
|
||||
"session_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Session Id"
|
||||
},
|
||||
"tool_call_id": { "type": "string", "title": "Tool Call Id" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["message", "tool_call_id"],
|
||||
"title": "OperationInProgressResponse",
|
||||
"description": "Response when an operation is already in progress.\n\nReturned for idempotency when the same tool_call_id is requested again\nwhile the background task is still running."
|
||||
},
|
||||
"OperationPendingResponse": {
|
||||
"properties": {
|
||||
"type": {
|
||||
"$ref": "#/components/schemas/ResponseType",
|
||||
"default": "operation_pending"
|
||||
},
|
||||
"message": { "type": "string", "title": "Message" },
|
||||
"session_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Session Id"
|
||||
},
|
||||
"operation_id": { "type": "string", "title": "Operation Id" },
|
||||
"tool_name": { "type": "string", "title": "Tool Name" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["message", "operation_id", "tool_name"],
|
||||
"title": "OperationPendingResponse",
|
||||
"description": "Response stored in chat history while a long-running operation is executing.\n\nThis is persisted to the database so users see a pending state when they\nrefresh before the operation completes."
|
||||
},
|
||||
"OperationStartedResponse": {
|
||||
"properties": {
|
||||
"type": {
|
||||
"$ref": "#/components/schemas/ResponseType",
|
||||
"default": "operation_started"
|
||||
},
|
||||
"message": { "type": "string", "title": "Message" },
|
||||
"session_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Session Id"
|
||||
},
|
||||
"operation_id": { "type": "string", "title": "Operation Id" },
|
||||
"tool_name": { "type": "string", "title": "Tool Name" },
|
||||
"task_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Task Id"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["message", "operation_id", "tool_name"],
|
||||
"title": "OperationStartedResponse",
|
||||
"description": "Response when a long-running operation has been started in the background.\n\nThis is returned immediately to the client while the operation continues\nto execute. The user can close the tab and check back later.\n\nThe task_id can be used to reconnect to the SSE stream via\nGET /chat/tasks/{task_id}/stream?last_idx=0"
|
||||
},
|
||||
"Pagination": {
|
||||
"properties": {
|
||||
"total_items": {
|
||||
@@ -10844,13 +10607,10 @@
|
||||
"workspace_file_metadata",
|
||||
"workspace_file_written",
|
||||
"workspace_file_deleted",
|
||||
"operation_started",
|
||||
"operation_pending",
|
||||
"operation_in_progress",
|
||||
"input_validation_error",
|
||||
"web_fetch",
|
||||
"bash_exec",
|
||||
"operation_status",
|
||||
"feature_request_search",
|
||||
"feature_request_created",
|
||||
"suggested_goal"
|
||||
|
||||
@@ -119,7 +119,7 @@ export function CredentialsFlatView({
|
||||
) : (
|
||||
!readOnly && (
|
||||
<Button
|
||||
variant="secondary"
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={onAddCredential}
|
||||
className="w-fit"
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { cn } from "@/lib/utils";
|
||||
import { RJSFSchema } from "@rjsf/utils";
|
||||
import { preprocessInputSchema } from "./utils/input-schema-pre-processor";
|
||||
import { useMemo } from "react";
|
||||
import { customValidator } from "./utils/custom-validator";
|
||||
import Form from "./registry";
|
||||
import { ExtendedFormContextType } from "./types";
|
||||
import { customValidator } from "./utils/custom-validator";
|
||||
import { generateUiSchemaForCustomFields } from "./utils/generate-ui-schema";
|
||||
import { preprocessInputSchema } from "./utils/input-schema-pre-processor";
|
||||
|
||||
type FormRendererProps = {
|
||||
jsonSchema: RJSFSchema;
|
||||
@@ -12,15 +13,17 @@ type FormRendererProps = {
|
||||
uiSchema: any;
|
||||
initialValues: any;
|
||||
formContext: ExtendedFormContextType;
|
||||
className?: string;
|
||||
};
|
||||
|
||||
export const FormRenderer = ({
|
||||
export function FormRenderer({
|
||||
jsonSchema,
|
||||
handleChange,
|
||||
uiSchema,
|
||||
initialValues,
|
||||
formContext,
|
||||
}: FormRendererProps) => {
|
||||
className,
|
||||
}: FormRendererProps) {
|
||||
const preprocessedSchema = useMemo(() => {
|
||||
return preprocessInputSchema(jsonSchema);
|
||||
}, [jsonSchema]);
|
||||
@@ -31,7 +34,10 @@ export const FormRenderer = ({
|
||||
}, [preprocessedSchema, uiSchema]);
|
||||
|
||||
return (
|
||||
<div className={"mb-6 mt-4"} data-tutorial-id="input-handles">
|
||||
<div
|
||||
className={cn("mb-6 mt-4", className)}
|
||||
data-tutorial-id="input-handles"
|
||||
>
|
||||
<Form
|
||||
formContext={formContext}
|
||||
idPrefix="agpt"
|
||||
@@ -45,4 +51,4 @@ export const FormRenderer = ({
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -218,6 +218,17 @@ If you initially installed Docker with Hyper-V, you **don’t need to reinstall*
|
||||
|
||||
For more details, refer to [Docker's official documentation](https://docs.docker.com/desktop/windows/wsl/).
|
||||
|
||||
### ⚠️ Podman Not Supported
|
||||
|
||||
AutoGPT requires **Docker** (Docker Desktop or Docker Engine). **Podman and podman-compose are not supported** and may cause path resolution issues, particularly on Windows.
|
||||
|
||||
If you see errors like:
|
||||
```text
|
||||
Error: the specified Containerfile or Dockerfile does not exist, ..\..\autogpt_platform\backend\Dockerfile
|
||||
```
|
||||
|
||||
This indicates you're using Podman instead of Docker. Please install [Docker Desktop](https://docs.docker.com/desktop/) and use `docker compose` instead of `podman-compose`.
|
||||
|
||||
|
||||
## Development
|
||||
|
||||
|
||||
Reference in New Issue
Block a user