Compare commits

..

4 Commits

Author SHA1 Message Date
Bently
5c01eb4fc8 Merge branch 'dev' into docs/deployment-env-variables 2026-02-23 16:43:20 +00:00
Bently
2d7431bde6 Merge branch 'dev' into docs/deployment-env-variables 2026-02-19 17:49:09 +00:00
Bentlybro
e934df3c0c fix: address code review feedback
- Add 'text' language identifier to code blocks (MD040)
- Add VAULT_ENC_KEY generation command (openssl rand -hex 16)
- Fix DB_HOST default to 'localhost' (not 'db')
- Add info box clarifying port numbers are internal Docker ports
- Update OAuth callback URL to not include port by default
- Clarify Docker service names are internal container DNS
2026-02-16 12:10:09 +00:00
Bentlybro
8d557d33e1 docs: add deployment environment variables guide
Closes #10961, Closes OPEN-2715

Documents all environment variables that must be configured when deploying
AutoGPT to a new server:

- Quick reference table of critical URLs that must change
- Configuration file locations and loading order
- Security keys that must be regenerated (with generation commands)
- Database, Redis, RabbitMQ configuration
- Default ports for all services
- OAuth callback URLs for all supported providers
- Full deployment checklist
- Docker vs external services guidance
2026-02-16 11:59:34 +00:00
55 changed files with 4470 additions and 1522 deletions

View File

@@ -2,17 +2,21 @@
import asyncio
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
from typing import Annotated
from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.completion_handler import (
process_operation_failure,
process_operation_success,
)
from backend.copilot.config import ChatConfig
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_task
from backend.copilot.model import (
@@ -42,6 +46,9 @@ from backend.copilot.tools.models import (
InputValidationErrorResponse,
NeedLoginResponse,
NoResultsResponse,
OperationInProgressResponse,
OperationPendingResponse,
OperationStartedResponse,
SetupRequirementsResponse,
SuggestedGoalResponse,
UnderstandingUpdatedResponse,
@@ -92,8 +99,10 @@ class CreateSessionResponse(BaseModel):
class ActiveStreamInfo(BaseModel):
"""Information about an active stream for reconnection."""
turn_id: str
task_id: str
last_message_id: str # Redis Stream message ID for resumption
operation_id: str # Operation ID for completion tracking
tool_name: str # Name of the tool being executed
class SessionDetailResponse(BaseModel):
@@ -127,9 +136,18 @@ class CancelTaskResponse(BaseModel):
"""Response model for the cancel task endpoint."""
cancelled: bool
task_id: str | None = None
reason: str | None = None
class OperationCompleteRequest(BaseModel):
"""Request model for external completion webhook."""
success: bool
result: dict | str | None = None
error: str | None = None
# ========== Routes ==========
@@ -252,7 +270,7 @@ async def get_session(
Retrieve the details of a specific chat session.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns active_stream info for reconnection.
If there's an active stream for this session, returns the task_id for reconnection.
Args:
session_id: The unique identifier for the desired chat session.
@@ -270,21 +288,28 @@ async def get_session(
# Check if there's an active stream for this session
active_stream_info = None
active_session, last_message_id = await stream_registry.get_active_session(
active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_session:
# Keep the assistant message (including tool_calls) so the frontend can
# render the correct tool UI (e.g. CreateAgent with mini game).
# convertChatSessionToUiMessages handles isComplete=false by setting
# tool parts without output to state "input-available".
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
# stream replay instead, preventing duplicate content.
if messages and messages[-1].get("role") == "assistant":
messages = messages[:-1]
# Use "0-0" as last_message_id to replay the stream from the beginning.
# Since we filtered out the cached assistant message, the client needs
# the full stream to reconstruct the response.
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
task_id=active_task.task_id,
last_message_id="0-0",
operation_id=active_task.operation_id,
tool_name=active_task.tool_name,
)
return SessionDetailResponse(
@@ -313,32 +338,39 @@ async def cancel_session_task(
"""
await _validate_and_get_session(session_id, user_id)
active_session, _ = await stream_registry.get_active_session(session_id, user_id)
if not active_session:
active_task, _ = await stream_registry.get_active_task_for_session(
session_id, user_id
)
if not active_task:
return CancelTaskResponse(cancelled=False, reason="no_active_task")
await enqueue_cancel_task(session_id)
logger.info(f"[CANCEL] Published cancel for session ...{session_id[-8:]}")
task_id = active_task.task_id
await enqueue_cancel_task(task_id)
logger.info(
f"[CANCEL] Published cancel for task ...{task_id[-8:]} "
f"session ...{session_id[-8:]}"
)
# Poll until the executor confirms the task is no longer running.
# Keep max_wait below typical reverse-proxy read timeouts.
poll_interval = 0.5
max_wait = 5.0
waited = 0.0
while waited < max_wait:
await asyncio.sleep(poll_interval)
waited += poll_interval
task = await stream_registry.get_session(session_id)
task = await stream_registry.get_task(task_id)
if task is None or task.status != "running":
logger.info(
f"[CANCEL] Session ...{session_id[-8:]} confirmed stopped "
f"[CANCEL] Task ...{task_id[-8:]} confirmed stopped "
f"(status={task.status if task else 'gone'}) after {waited:.1f}s"
)
return CancelTaskResponse(cancelled=True)
return CancelTaskResponse(cancelled=True, task_id=task_id)
logger.warning(
f"[CANCEL] Session ...{session_id[-8:]} not confirmed after {max_wait}s"
logger.warning(f"[CANCEL] Task ...{task_id[-8:]} not confirmed after {max_wait}s")
return CancelTaskResponse(
cancelled=True, task_id=task_id, reason="cancel_published_not_confirmed"
)
return CancelTaskResponse(cancelled=True, reason="cancel_published_not_confirmed")
@router.post(
@@ -358,15 +390,16 @@ async def stream_chat_post(
- Tool execution results
The AI generation runs in a background task that continues even if the client disconnects.
All chunks are written to a per-turn Redis stream for reconnection support. If the client
disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
All chunks are written to Redis for reconnection support. If the client disconnects,
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks.
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
containing the task_id for reconnection.
"""
import asyncio
@@ -413,19 +446,21 @@ async def stream_chat_post(
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
turn_id = str(uuid4())
log_meta["turn_id"] = turn_id
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
log_meta["task_id"] = task_id
task_create_start = time.perf_counter()
await stream_registry.create_session_task(
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream",
tool_call_id="chat_stream", # Not a tool call, but needed for the model
tool_name="chat",
turn_id=turn_id,
operation_id=operation_id,
)
logger.info(
f"[TIMING] create_session_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
@@ -434,14 +469,12 @@ async def stream_chat_post(
},
)
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
subscribe_from_id = "0-0"
await enqueue_copilot_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
operation_id=operation_id,
message=request.message,
turn_id=turn_id,
is_user_message=request.is_user_message,
context=request.context,
)
@@ -458,7 +491,7 @@ async def stream_chat_post(
event_gen_start = time_module.perf_counter()
logger.info(
f"[TIMING] event_generator STARTED, turn={turn_id}, session={session_id}, "
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
f"user={user_id}",
extra={"json_fields": log_meta},
)
@@ -466,12 +499,11 @@ async def stream_chat_post(
first_chunk_yielded = False
chunks_yielded = 0
try:
# Subscribe from the position we captured before enqueuing
# This avoids replaying old messages while catching all new ones
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
# Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=subscribe_from_id,
last_message_id="0-0", # Get all messages from the beginning
)
if subscriber_queue is None:
@@ -554,19 +586,19 @@ async def stream_chat_post(
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_session(
session_id, subscriber_queue
await stream_registry.unsubscribe_from_task(
task_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from session {session_id}: {unsub_err}",
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
f"turn={turn_id}, session={session_id}, n_chunks={chunks_yielded}",
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
@@ -613,22 +645,17 @@ async def resume_session_stream(
"""
import asyncio
active_session, last_message_id = await stream_registry.get_active_session(
active_task, _last_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
if not active_session:
if not active_task:
return Response(status_code=204)
# Subscribe from the beginning ("0-0") to replay all chunks for this turn.
# This is necessary because hydrated messages filter out incomplete tool calls
# to avoid "No tool invocation found" errors. The resume stream delivers
# those tool calls fresh with proper SDK state.
# The AI SDK's deduplication will handle any duplicate chunks.
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=active_task.task_id,
user_id=user_id,
last_message_id="0-0",
last_message_id="0-0", # Full replay so useChat rebuilds the message
)
if subscriber_queue is None:
@@ -664,12 +691,12 @@ async def resume_session_stream(
logger.error(f"Error in resume stream for session {session_id}: {e}")
finally:
try:
await stream_registry.unsubscribe_from_session(
session_id, subscriber_queue
await stream_registry.unsubscribe_from_task(
active_task.task_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from session {active_session.session_id}: {unsub_err}",
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
exc_info=True,
)
logger.info(
@@ -720,6 +747,229 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Task Streaming (SSE Reconnection) ==========
@router.get(
"/tasks/{task_id}/stream",
)
async def stream_task(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
last_message_id: str = Query(
default="0-0",
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
),
):
"""
Reconnect to a long-running task's SSE stream.
When a long-running operation (like agent generation) starts, the client
receives a task_id. If the connection drops, the client can reconnect
using this endpoint to resume receiving updates.
Args:
task_id: The task ID from the operation_started response.
user_id: Authenticated user ID for ownership validation.
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
Returns:
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
Raises:
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
"""
# Check task existence and expiry before subscribing
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
if error_code == "TASK_EXPIRED":
raise HTTPException(
status_code=410,
detail={
"code": "TASK_EXPIRED",
"message": "This operation has expired. Please try again.",
},
)
if error_code == "TASK_NOT_FOUND":
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found.",
},
)
# Validate ownership if task has an owner
if task and task.user_id and user_id != task.user_id:
raise HTTPException(
status_code=403,
detail={
"code": "ACCESS_DENIED",
"message": "You do not have access to this task.",
},
)
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=last_message_id,
)
if subscriber_queue is None:
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found or access denied.",
},
)
async def event_generator() -> AsyncGenerator[str, None]:
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:
try:
# Wait for next chunk with timeout for heartbeats
chunk = await asyncio.wait_for(
subscriber_queue.get(), timeout=heartbeat_interval
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except Exception as e:
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
finally:
# Unsubscribe when client disconnects or stream ends
try:
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
)
@router.get(
"/tasks/{task_id}",
)
async def get_task_status(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
) -> dict:
"""
Get the status of a long-running task.
Args:
task_id: The task ID to check.
user_id: Authenticated user ID for ownership validation.
Returns:
dict: Task status including task_id, status, tool_name, and operation_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
task = await stream_registry.get_task(task_id)
if task is None:
raise NotFoundError(f"Task {task_id} not found.")
# Validate ownership - if task has an owner, requester must match
if task.user_id and user_id != task.user_id:
raise NotFoundError(f"Task {task_id} not found.")
return {
"task_id": task.task_id,
"session_id": task.session_id,
"status": task.status,
"tool_name": task.tool_name,
"operation_id": task.operation_id,
"created_at": task.created_at.isoformat(),
}
# ========== External Completion Webhook ==========
@router.post(
"/operations/{operation_id}/complete",
status_code=200,
)
async def complete_operation(
operation_id: str,
request: OperationCompleteRequest,
x_api_key: str | None = Header(default=None),
) -> dict:
"""
External completion webhook for long-running operations.
Called by Agent Generator (or other services) when an operation completes.
This triggers the stream registry to publish completion and continue LLM generation.
Args:
operation_id: The operation ID to complete.
request: Completion payload with success status and result/error.
x_api_key: Internal API key for authentication.
Returns:
dict: Status of the completion.
Raises:
HTTPException: If API key is invalid or operation not found.
"""
# Validate internal API key - reject if not configured or invalid
if not config.internal_api_key:
logger.error(
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
)
raise HTTPException(
status_code=503,
detail="Webhook not available: internal API key not configured",
)
if x_api_key != config.internal_api_key:
raise HTTPException(status_code=401, detail="Invalid API key")
# Find task by operation_id
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Operation {operation_id} not found",
)
logger.info(
f"Received completion webhook for operation {operation_id} "
f"(task_id={task.task_id}, success={request.success})"
)
if request.success:
await process_operation_success(task, request.result)
else:
await process_operation_failure(task, request.error)
return {"status": "ok", "task_id": task.task_id}
# ========== Configuration ==========
@@ -800,6 +1050,9 @@ ToolResponseUnion = (
| BlockOutputResponse
| DocSearchResultsResponse
| DocPageResponse
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
)

View File

@@ -42,6 +42,10 @@ import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.copilot.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
)
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
@@ -119,9 +123,21 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
# Start chat completion consumer for Redis Streams notifications
try:
await start_completion_consumer()
except Exception as e:
logger.warning(f"Could not start chat completion consumer: {e}")
with launch_darkly_context():
yield
# Stop chat completion consumer
try:
await stop_completion_consumer()
except Exception as e:
logger.warning(f"Error stopping chat completion consumer: {e}")
try:
await shutdown_cloud_storage_handler()
except Exception as e:

View File

@@ -24,7 +24,7 @@ def run_processes(*processes: "AppProcess", **kwargs):
# Run the last process in the foreground.
processes[-1].start(background=False, **kwargs)
finally:
for process in reversed(processes):
for process in processes:
try:
process.stop()
except Exception as e:

View File

@@ -0,0 +1,349 @@
"""Redis Streams consumer for operation completion messages.
This module provides a consumer (ChatCompletionConsumer) that listens for
completion notifications (OperationCompleteMessage) from external services
(like Agent Generator) and triggers the appropriate stream registry and
chat service updates via process_operation_success/process_operation_failure.
Why Redis Streams instead of RabbitMQ?
--------------------------------------
While the project typically uses RabbitMQ for async task queues (e.g., execution
queue), Redis Streams was chosen for chat completion notifications because:
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
Streams (via stream_registry) for message persistence and replay. Using Redis
Streams for completion notifications keeps all chat streaming infrastructure
in one system, simplifying operations and reducing cross-system coordination.
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
allowing consumers to replay missed messages after reconnection. This aligns
with the SSE reconnection pattern where clients can resume from last_message_id.
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
recovering from dead consumers - ideal for the completion callback pattern.
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
stream_registry) provides lower latency than an additional RabbitMQ hop.
5. **Atomicity with Task State**: Completion processing often needs to update
task metadata stored in Redis. Keeping both in Redis enables simpler
transactional semantics without distributed coordination.
The consumer uses Redis Streams with consumer groups for reliable message
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
stale pending messages from dead consumers.
"""
import asyncio
import logging
import uuid
from typing import Any
import orjson
from pydantic import BaseModel
from redis.exceptions import ResponseError
from backend.data.redis_client import get_redis_async
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
class OperationCompleteMessage(BaseModel):
"""Message format for operation completion notifications."""
operation_id: str
task_id: str
success: bool
result: dict | str | None = None
error: str | None = None
class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from Redis Streams.
Database operations are handled through the chat_db() accessor, which
routes through DatabaseManager RPC when Prisma is not directly connected.
Uses Redis consumer groups to allow multiple platform pods to consume
messages reliably with automatic redelivery on failure.
"""
def __init__(self):
self._consumer_task: asyncio.Task | None = None
self._running = False
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
async def start(self) -> None:
"""Start the completion consumer."""
if self._running:
logger.warning("Completion consumer already running")
return
# Create consumer group if it doesn't exist
try:
redis = await get_redis_async()
await redis.xgroup_create(
config.stream_completion_name,
config.stream_consumer_group,
id="0",
mkstream=True,
)
logger.info(
f"Created consumer group '{config.stream_consumer_group}' "
f"on stream '{config.stream_completion_name}'"
)
except ResponseError as e:
if "BUSYGROUP" in str(e):
logger.debug(
f"Consumer group '{config.stream_consumer_group}' already exists"
)
else:
raise
self._running = True
self._consumer_task = asyncio.create_task(self._consume_messages())
logger.info(
f"Chat completion consumer started (consumer: {self._consumer_name})"
)
async def stop(self) -> None:
"""Stop the completion consumer."""
self._running = False
if self._consumer_task:
self._consumer_task.cancel()
try:
await self._consumer_task
except asyncio.CancelledError:
pass
self._consumer_task = None
logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None:
"""Main message consumption loop with retry logic."""
max_retries = 10
retry_delay = 5 # seconds
retry_count = 0
block_timeout = 5000 # milliseconds
while self._running and retry_count < max_retries:
try:
redis = await get_redis_async()
# Reset retry count on successful connection
retry_count = 0
while self._running:
# First, claim any stale pending messages from dead consumers
# Redis does NOT auto-redeliver pending messages; we must explicitly
# claim them using XAUTOCLAIM
try:
claimed_result = await redis.xautoclaim(
name=config.stream_completion_name,
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
min_idle_time=config.stream_claim_min_idle_ms,
start_id="0-0",
count=10,
)
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
if claimed_result and len(claimed_result) >= 2:
claimed_entries = claimed_result[1]
if claimed_entries:
logger.info(
f"Claimed {len(claimed_entries)} stale pending messages"
)
for entry_id, data in claimed_entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except Exception as e:
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
# Read new messages from the stream
messages = await redis.xreadgroup(
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
streams={config.stream_completion_name: ">"},
block=block_timeout,
count=10,
)
if not messages:
continue
for stream_name, entries in messages:
for entry_id, data in entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except asyncio.CancelledError:
logger.info("Consumer cancelled")
return
except Exception as e:
retry_count += 1
logger.error(
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
exc_info=True,
)
if self._running and retry_count < max_retries:
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached, stopping consumer")
return
async def _process_entry(
self, redis: Any, entry_id: str, data: dict[str, Any]
) -> None:
"""Process a single stream entry and acknowledge it on success.
Args:
redis: Redis client connection
entry_id: The stream entry ID
data: The entry data dict
"""
try:
# Handle the message
message_data = data.get("data")
if message_data:
await self._handle_message(
message_data.encode()
if isinstance(message_data, str)
else message_data
)
# Acknowledge the message after successful processing
await redis.xack(
config.stream_completion_name,
config.stream_consumer_group,
entry_id,
)
except Exception as e:
logger.error(
f"Error processing completion message {entry_id}: {e}",
exc_info=True,
)
# Message remains in pending state and will be claimed by
# XAUTOCLAIM after min_idle_time expires
async def _handle_message(self, body: bytes) -> None:
"""Handle a completion message."""
try:
data = orjson.loads(body)
message = OperationCompleteMessage(**data)
except Exception as e:
logger.error(f"Failed to parse completion message: {e}")
return
logger.info(
f"[COMPLETION] Received completion for operation {message.operation_id} "
f"(task_id={message.task_id}, success={message.success})"
)
# Find task in registry
task = await stream_registry.find_task_by_operation_id(message.operation_id)
if task is None:
task = await stream_registry.get_task(message.task_id)
if task is None:
logger.warning(
f"[COMPLETION] Task not found for operation {message.operation_id} "
f"(task_id={message.task_id})"
)
return
logger.info(
f"[COMPLETION] Found task: task_id={task.task_id}, "
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
)
# Guard against empty task fields
if not task.task_id or not task.session_id or not task.tool_call_id:
logger.error(
f"[COMPLETION] Task has empty critical fields! "
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
f"tool_call_id={task.tool_call_id!r}"
)
return
if message.success:
await self._handle_success(task, message)
else:
await self._handle_failure(task, message)
async def _handle_success(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle successful operation completion."""
await process_operation_success(task, message.result)
async def _handle_failure(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle failed operation completion."""
await process_operation_failure(task, message.error)
# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None
async def start_completion_consumer() -> None:
"""Start the global completion consumer."""
global _consumer
if _consumer is None:
_consumer = ChatCompletionConsumer()
await _consumer.start()
async def stop_completion_consumer() -> None:
"""Stop the global completion consumer."""
global _consumer
if _consumer:
await _consumer.stop()
_consumer = None
async def publish_operation_complete(
operation_id: str,
task_id: str,
success: bool,
result: dict | str | None = None,
error: str | None = None,
) -> None:
"""Publish an operation completion message to Redis Streams.
Args:
operation_id: The operation ID that completed.
task_id: The task ID associated with the operation.
success: Whether the operation succeeded.
result: The result data (for success).
error: The error message (for failure).
"""
message = OperationCompleteMessage(
operation_id=operation_id,
task_id=task_id,
success=success,
result=result,
error=error,
)
redis = await get_redis_async()
await redis.xadd(
config.stream_completion_name,
{"data": message.model_dump_json()},
maxlen=config.stream_max_length,
)
logger.info(f"Published completion for operation {operation_id}")

View File

@@ -0,0 +1,329 @@
"""Shared completion handling for operation success and failure.
This module provides common logic for handling operation completion from both:
- The Redis Streams consumer (completion_consumer.py)
- The HTTP webhook endpoint (routes.py)
"""
import logging
from typing import Any
import orjson
from backend.data.db_accessors import chat_db
from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamToolOutputAvailable
from .tools.models import ErrorResponse
logger = logging.getLogger(__name__)
# Tools that produce agent_json that needs to be saved to library
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
# Keys that should be stripped from agent_json when returning in error responses
SENSITIVE_KEYS = frozenset(
{
"api_key",
"apikey",
"api_secret",
"password",
"secret",
"credentials",
"credential",
"token",
"access_token",
"refresh_token",
"private_key",
"privatekey",
"auth",
"authorization",
}
)
def _sanitize_agent_json(obj: Any) -> Any:
"""Recursively sanitize agent_json by removing sensitive keys.
Args:
obj: The object to sanitize (dict, list, or primitive)
Returns:
Sanitized copy with sensitive keys removed/redacted
"""
if isinstance(obj, dict):
return {
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
for k, v in obj.items()
}
elif isinstance(obj, list):
return [_sanitize_agent_json(item) for item in obj]
else:
return obj
class ToolMessageUpdateError(Exception):
"""Raised when updating a tool message in the database fails."""
pass
async def _update_tool_message(
session_id: str,
tool_call_id: str,
content: str,
) -> None:
"""Update tool message in database using the chat_db accessor.
Routes through DatabaseManager RPC when Prisma is not directly
connected (e.g. in the CoPilot Executor microservice).
Args:
session_id: The session ID
tool_call_id: The tool call ID to update
content: The new content for the message
Raises:
ToolMessageUpdateError: If the database update fails.
"""
try:
updated = await chat_db().update_tool_message_content(
session_id=session_id,
tool_call_id=tool_call_id,
new_content=content,
)
if not updated:
raise ToolMessageUpdateError(
f"No message found with tool_call_id="
f"{tool_call_id} in session {session_id}"
)
except ToolMessageUpdateError:
raise
except Exception as e:
logger.error(
f"[COMPLETION] Failed to update tool message: {e}",
exc_info=True,
)
raise ToolMessageUpdateError(
f"Failed to update tool message for tool call #{tool_call_id}: {e}"
) from e
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
"""Serialize result to JSON string with sensible defaults.
Args:
result: The result to serialize. Can be a dict, list, string,
number, boolean, or None.
Returns:
JSON string representation of the result. Returns '{"status": "completed"}'
only when result is explicitly None.
"""
if isinstance(result, str):
return result
if result is None:
return '{"status": "completed"}'
return orjson.dumps(result).decode("utf-8")
async def _save_agent_from_result(
result: dict[str, Any],
user_id: str | None,
tool_name: str,
) -> dict[str, Any]:
"""Save agent to library if result contains agent_json.
Args:
result: The result dict that may contain agent_json
user_id: The user ID to save the agent for
tool_name: The tool name (create_agent or edit_agent)
Returns:
Updated result dict with saved agent details, or original result if no agent_json
"""
if not user_id:
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
return result
agent_json = result.get("agent_json")
if not agent_json:
logger.warning(
f"[COMPLETION] {tool_name} completed but no agent_json in result"
)
return result
try:
from .tools.agent_generator import save_agent_to_library
is_update = tool_name == "edit_agent"
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, is_update=is_update
)
logger.info(
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
)
# Return a response similar to AgentSavedResponse
return {
"type": "agent_saved",
"message": f"Agent '{created_graph.name}' has been saved to your library!",
"agent_id": created_graph.id,
"agent_name": created_graph.name,
"library_agent_id": library_agent.id,
"library_agent_link": f"/library/agents/{library_agent.id}",
"agent_page_link": f"/build?flowID={created_graph.id}",
}
except Exception as e:
logger.error(
f"[COMPLETION] Failed to save agent to library: {e}",
exc_info=True,
)
# Return error but don't fail the whole operation
# Sanitize agent_json to remove sensitive keys before returning
return {
"type": "error",
"message": f"Agent was generated but failed to save: {str(e)}",
"error": str(e),
"agent_json": _sanitize_agent_json(agent_json),
}
async def process_operation_success(
task: stream_registry.ActiveTask,
result: dict | str | None,
) -> None:
"""Handle successful operation completion.
Publishes the result to the stream registry, updates the database,
generates LLM continuation, and marks the task as completed.
Args:
task: The active task that completed
result: The result data from the operation
Raises:
ToolMessageUpdateError: If the database update fails. The task
will be marked as failed instead of completed.
"""
# For agent generation tools, save the agent to library
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
# Serialize result for output (only substitute default when result is exactly None)
result_output = result if result is not None else {"status": "completed"}
output_str = (
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
)
# Publish result to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=output_str,
success=True,
),
)
# Update pending operation in database
# If this fails, we must not continue to mark the task as completed
result_str = serialize_result(result)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=result_str,
)
except ToolMessageUpdateError:
# DB update failed - mark task as failed to avoid inconsistent state
logger.error(
f"[COMPLETION] DB update failed for task {task.task_id}, "
"marking as failed instead of completed"
)
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText="Failed to save operation result to database"),
)
await stream_registry.mark_task_completed(task.task_id, status="failed")
raise
# Generate LLM continuation with streaming
try:
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to generate LLM continuation: {e}",
exc_info=True,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
)
async def process_operation_failure(
task: stream_registry.ActiveTask,
error: str | None,
) -> None:
"""Handle failed operation completion.
Publishes the error to the stream registry, updates the database
with the error response, and marks the task as failed.
Args:
task: The active task that failed
error: The error message from the operation
"""
error_msg = error or "Operation failed"
# Publish error to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
# Update pending operation with error
# If this fails, we still continue to mark the task as failed
error_response = ErrorResponse(
message=error_msg,
error=error,
)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=error_response.model_dump_json(),
)
except ToolMessageUpdateError:
# DB update failed - log but continue with cleanup
logger.error(
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
"continuing with cleanup"
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")

View File

@@ -36,6 +36,14 @@ class ChatConfig(BaseSettings):
default=30, description="Maximum number of agent schedules"
)
# Long-running operation configuration
long_running_operation_ttl: int = Field(
default=3600,
description="TTL in seconds for long-running operation deduplication lock "
"(1 hour, matches stream_ttl). Prevents duplicate operations if pod dies. "
"For longer operations, the stream_registry heartbeat keeps them alive.",
)
# Stream registry configuration for SSE reconnection
stream_ttl: int = Field(
default=3600,
@@ -51,14 +59,36 @@ class ChatConfig(BaseSettings):
description="Maximum number of messages to store per stream",
)
# Redis key prefixes for stream registry
session_meta_prefix: str = Field(
default="chat:task:meta:",
description="Prefix for session metadata hash keys",
# Redis Streams configuration for completion consumer
stream_completion_name: str = Field(
default="chat:completions",
description="Redis Stream name for operation completions",
)
turn_stream_prefix: str = Field(
stream_consumer_group: str = Field(
default="chat_consumers",
description="Consumer group name for completion stream",
)
stream_claim_min_idle_ms: int = Field(
default=60000,
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
)
# Redis key prefixes for stream registry
task_meta_prefix: str = Field(
default="chat:task:meta:",
description="Prefix for task metadata hash keys",
)
task_stream_prefix: str = Field(
default="chat:stream:",
description="Prefix for turn message stream keys",
description="Prefix for task message stream keys",
)
task_op_prefix: str = Field(
default="chat:task:op:",
description="Prefix for operation ID to task ID mapping keys",
)
internal_api_key: str | None = Field(
default=None,
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
)
# Langfuse Prompt Management Configuration
@@ -130,6 +160,14 @@ class ChatConfig(BaseSettings):
v = "https://openrouter.ai/api/v1"
return v
@field_validator("internal_api_key", mode="before")
@classmethod
def get_internal_api_key(cls, v):
"""Get internal API key from environment if not provided."""
if v is None:
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
@field_validator("use_claude_agent_sdk", mode="before")
@classmethod
def get_use_claude_agent_sdk(cls, v):

View File

@@ -181,13 +181,13 @@ class CoPilotExecutor(AppProcess):
self._executor.shutdown(wait=False)
# Release any remaining locks
for session_id, lock in list(self._task_locks.items()):
for task_id, lock in list(self._task_locks.items()):
try:
lock.release()
logger.info(f"[cleanup {pid}] Released lock for {session_id}")
logger.info(f"[cleanup {pid}] Released lock for {task_id}")
except Exception as e:
logger.error(
f"[cleanup {pid}] Failed to release lock for {session_id}: {e}"
f"[cleanup {pid}] Failed to release lock for {task_id}: {e}"
)
logger.info(f"[cleanup {pid}] Graceful shutdown completed")
@@ -267,20 +267,20 @@ class CoPilotExecutor(AppProcess):
):
"""Handle cancel message from FANOUT exchange."""
request = CancelCoPilotEvent.model_validate_json(body)
session_id = request.session_id
if not session_id:
logger.warning("Cancel message missing 'session_id'")
task_id = request.task_id
if not task_id:
logger.warning("Cancel message missing 'task_id'")
return
if session_id not in self.active_tasks:
logger.debug(f"Cancel received for {session_id} but not active")
if task_id not in self.active_tasks:
logger.debug(f"Cancel received for {task_id} but not active")
return
_, cancel_event = self.active_tasks[session_id]
logger.info(f"Received cancel for {session_id}")
_, cancel_event = self.active_tasks[task_id]
logger.info(f"Received cancel for {task_id}")
if not cancel_event.is_set():
cancel_event.set()
else:
logger.debug(f"Cancel already set for {session_id}")
logger.debug(f"Cancel already set for {task_id}")
def _handle_run_message(
self,
@@ -352,12 +352,12 @@ class CoPilotExecutor(AppProcess):
ack_message(reject=True, requeue=False)
return
session_id = entry.session_id
task_id = entry.task_id
# Check for local duplicate - session is already running on this executor
if session_id in self.active_tasks:
# Check for local duplicate - task is already running on this executor
if task_id in self.active_tasks:
logger.warning(
f"Session {session_id} already running locally, rejecting duplicate"
f"Task {task_id} already running locally, rejecting duplicate"
)
ack_message(reject=True, requeue=False)
return
@@ -365,53 +365,53 @@ class CoPilotExecutor(AppProcess):
# Try to acquire cluster-wide lock
cluster_lock = ClusterLock(
redis=redis.get_redis(),
key=f"copilot:session:{session_id}:lock",
key=f"copilot:task:{task_id}:lock",
owner_id=self.executor_id,
timeout=settings.config.cluster_lock_timeout,
)
current_owner = cluster_lock.try_acquire()
if current_owner != self.executor_id:
if current_owner is not None:
logger.warning(
f"Session {session_id} already running on pod {current_owner}"
)
logger.warning(f"Task {task_id} already running on pod {current_owner}")
ack_message(reject=True, requeue=False)
else:
logger.warning(
f"Could not acquire lock for {session_id} - Redis unavailable"
f"Could not acquire lock for {task_id} - Redis unavailable"
)
ack_message(reject=True, requeue=True)
return
# Execute the task
try:
self._task_locks[session_id] = cluster_lock
self._task_locks[task_id] = cluster_lock
logger.info(
f"Acquired cluster lock for {session_id}, "
f"executor_id={self.executor_id}"
f"Acquired cluster lock for {task_id}, executor_id={self.executor_id}"
)
cancel_event = threading.Event()
future = self.executor.submit(
execute_copilot_task, entry, cancel_event, cluster_lock
)
self.active_tasks[session_id] = (future, cancel_event)
self.active_tasks[task_id] = (future, cancel_event)
except Exception as e:
logger.warning(f"Failed to setup execution for {session_id}: {e}")
logger.warning(f"Failed to setup execution for {task_id}: {e}")
cluster_lock.release()
if session_id in self._task_locks:
del self._task_locks[session_id]
if task_id in self._task_locks:
del self._task_locks[task_id]
ack_message(reject=True, requeue=True)
return
self._update_metrics()
def on_run_done(f: Future):
logger.info(f"Run completed for {session_id}")
logger.info(f"Run completed for {task_id}")
try:
if exec_error := f.exception():
logger.error(f"Execution for {session_id} failed: {exec_error}")
logger.error(f"Execution for {task_id} failed: {exec_error}")
# Don't requeue failed tasks - they've been marked as failed
# in the stream registry. Requeuing would cause infinite retries
# for deterministic failures.
ack_message(reject=True, requeue=False)
else:
ack_message(reject=False, requeue=False)
@@ -419,10 +419,10 @@ class CoPilotExecutor(AppProcess):
logger.exception(f"Error in run completion callback: {e}")
finally:
# Release the cluster lock
if session_id in self._task_locks:
logger.info(f"Releasing cluster lock for {session_id}")
self._task_locks[session_id].release()
del self._task_locks[session_id]
if task_id in self._task_locks:
logger.info(f"Releasing cluster lock for {task_id}")
self._task_locks[task_id].release()
del self._task_locks[task_id]
self._cleanup_completed_tasks()
future.add_done_callback(on_run_done)
@@ -433,11 +433,11 @@ class CoPilotExecutor(AppProcess):
"""Remove completed futures from active_tasks and update metrics."""
completed_tasks = []
with self._active_tasks_lock:
for session_id, (future, _) in list(self.active_tasks.items()):
for task_id, (future, _) in list(self.active_tasks.items()):
if future.done():
completed_tasks.append(session_id)
self.active_tasks.pop(session_id, None)
logger.info(f"Cleaned up completed session {session_id}")
completed_tasks.append(task_id)
self.active_tasks.pop(task_id, None)
logger.info(f"Cleaned up completed task {task_id}")
self._update_metrics()
return completed_tasks

View File

@@ -12,7 +12,7 @@ import time
from backend.copilot import service as copilot_service
from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig
from backend.copilot.response_model import StreamFinish, StreamFinishStep
from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep
from backend.copilot.sdk import service as sdk_service
from backend.executor.cluster_lock import ClusterLock
from backend.util.decorator import error_logged
@@ -151,6 +151,7 @@ class CoPilotProcessor:
"""
log = CoPilotLogMetadata(
logging.getLogger(__name__),
task_id=entry.task_id,
session_id=entry.session_id,
user_id=entry.user_id,
)
@@ -239,49 +240,52 @@ class CoPilotProcessor:
if cancel.is_set():
log.info("Cancelled during streaming")
await stream_registry.publish_chunk(
entry.turn_id, StreamFinishStep()
entry.task_id, StreamError(errorText="Operation cancelled")
)
await stream_registry.mark_session_completed(
entry.session_id,
error_message="Operation cancelled",
await stream_registry.publish_chunk(
entry.task_id, StreamFinishStep()
)
await stream_registry.publish_chunk(entry.task_id, StreamFinish())
await stream_registry.mark_task_completed(
entry.task_id, status="failed"
)
return
# Refresh cluster lock periodically
current_time = time.monotonic()
if current_time - last_refresh >= refresh_interval:
cluster_lock.refresh()
last_refresh = current_time
if isinstance(chunk, StreamFinish):
break
# Publish chunk to stream registry
await stream_registry.publish_chunk(entry.task_id, chunk)
try:
await stream_registry.publish_chunk(entry.turn_id, chunk)
except Exception as e:
log.error(
f"Error publishing chunk {type(chunk).__name__}: {e}",
exc_info=True,
)
await stream_registry.mark_session_completed(entry.session_id)
# Mark task as completed
await stream_registry.mark_task_completed(entry.task_id, status="completed")
log.info("Task completed successfully")
except asyncio.CancelledError:
log.info("Task cancelled")
await stream_registry.mark_session_completed(
entry.session_id, error_message="Task was cancelled"
await stream_registry.mark_task_completed(
entry.task_id,
status="failed",
error_message="Task was cancelled",
)
raise
except Exception as e:
log.error(f"Task failed: {e}")
try:
await stream_registry.publish_chunk(entry.turn_id, StreamFinishStep())
await stream_registry.mark_session_completed(
entry.session_id, error_message=str(e)
)
except Exception as mark_err:
logger.error(
f"Failed to mark session {entry.session_id} as failed: {mark_err}"
)
await self._mark_task_failed(entry.task_id, str(e))
raise
async def _mark_task_failed(self, task_id: str, error_message: str):
"""Mark a task as failed and publish error to stream registry."""
try:
await stream_registry.publish_chunk(
task_id, StreamError(errorText=error_message)
)
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish())
await stream_registry.mark_task_completed(task_id, status="failed")
except Exception as e:
logger.error(f"Failed to mark task {task_id} as failed: {e}")

View File

@@ -28,7 +28,7 @@ class CoPilotLogMetadata(TruncatedLogger):
Args:
logger: The underlying logger instance
max_length: Maximum log message length before truncation
**kwargs: Metadata key-value pairs (e.g., session_id="xyz", turn_id="abc")
**kwargs: Metadata key-value pairs (e.g., task_id="abc", session_id="xyz")
These are added to json_fields in cloud mode, or to the prefix in local mode.
"""
@@ -135,15 +135,18 @@ class CoPilotExecutionEntry(BaseModel):
This model represents a chat generation task to be processed by the executor.
"""
session_id: str
"""Chat session ID (also used for dedup/locking)"""
task_id: str
"""Unique identifier for this task (used for stream registry)"""
turn_id: str = ""
"""Per-turn UUID for Redis stream isolation"""
session_id: str
"""Chat session ID"""
user_id: str | None
"""User ID (may be None for anonymous users)"""
operation_id: str
"""Operation ID for webhook callbacks and completion tracking"""
message: str
"""User's message to process"""
@@ -157,37 +160,40 @@ class CoPilotExecutionEntry(BaseModel):
class CancelCoPilotEvent(BaseModel):
"""Event to cancel a CoPilot operation."""
session_id: str
"""Session ID to cancel"""
task_id: str
"""Task ID to cancel"""
# ============ Queue Publishing Helpers ============ #
async def enqueue_copilot_task(
task_id: str,
session_id: str,
user_id: str | None,
operation_id: str,
message: str,
turn_id: str = "",
is_user_message: bool = True,
context: dict[str, str] | None = None,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.
Args:
session_id: Chat session ID (also used for dedup/locking)
task_id: Unique identifier for this task (used for stream registry)
session_id: Chat session ID
user_id: User ID (may be None for anonymous users)
operation_id: Operation ID for webhook callbacks and completion tracking
message: User's message to process
turn_id: Per-turn UUID for Redis stream isolation
is_user_message: Whether the message is from the user (vs system/assistant)
context: Optional context for the message (e.g., {url: str, content: str})
"""
from backend.util.clients import get_async_copilot_queue
entry = CoPilotExecutionEntry(
task_id=task_id,
session_id=session_id,
turn_id=turn_id,
user_id=user_id,
operation_id=operation_id,
message=message,
is_user_message=is_user_message,
context=context,
@@ -201,15 +207,15 @@ async def enqueue_copilot_task(
)
async def enqueue_cancel_task(session_id: str) -> None:
"""Publish a cancel request for a running CoPilot session.
async def enqueue_cancel_task(task_id: str) -> None:
"""Publish a cancel request for a running CoPilot task.
Sends a ``CancelCoPilotEvent`` to the FANOUT exchange so all executor
pods receive the cancellation signal.
"""
from backend.util.clients import get_async_copilot_queue
event = CancelCoPilotEvent(session_id=session_id)
event = CancelCoPilotEvent(task_id=task_id)
queue_client = await get_async_copilot_queue()
await queue_client.publish_message(
routing_key="", # FANOUT ignores routing key

View File

@@ -14,6 +14,7 @@ import pytest
@pytest.mark.asyncio
async def test_parallel_tool_calls_run_concurrently():
"""Multiple tool calls should complete in ~max(delays), not sum(delays)."""
# Import here to allow module-level mocking if needed
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
@@ -31,6 +32,7 @@ async def test_parallel_tool_calls_run_concurrently():
for i in range(n_tools)
]
# Minimal session mock
class FakeSession:
session_id = "test"
user_id = "test"
@@ -40,7 +42,7 @@ async def test_parallel_tool_calls_run_concurrently():
original_yield = None
async def fake_yield(tc_list, idx, sess):
async def fake_yield(tc_list, idx, sess, lock=None):
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"],
toolName=tc_list[idx]["function"]["name"],
@@ -99,7 +101,7 @@ async def test_single_tool_call_works():
def __init__(self):
self.messages = []
async def fake_yield(tc_list, idx, sess):
async def fake_yield(tc_list, idx, sess, lock=None):
yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")
@@ -142,7 +144,7 @@ async def test_retryable_error_propagates():
def __init__(self):
self.messages = []
async def fake_yield(tc_list, idx, sess):
async def fake_yield(tc_list, idx, sess, lock=None):
if idx == 1:
raise KeyError("bad")
from backend.copilot.response_model import StreamToolInputAvailable
@@ -173,8 +175,8 @@ async def test_retryable_error_propagates():
@pytest.mark.asyncio
async def test_session_shared_across_parallel_tools():
"""All parallel tools should receive the same session instance."""
async def test_session_lock_shared():
"""All parallel tools should receive the same lock instance."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
@@ -197,10 +199,10 @@ async def test_session_shared_across_parallel_tools():
def __init__(self):
self.messages = []
observed_sessions = []
observed_locks = []
async def fake_yield(tc_list, idx, sess):
observed_sessions.append(sess)
async def fake_yield(tc_list, idx, sess, lock=None):
observed_locks.append(lock)
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
)
@@ -220,8 +222,9 @@ async def test_session_shared_across_parallel_tools():
finally:
svc._yield_tool_call = orig
assert len(observed_sessions) == 3
assert observed_sessions[0] is observed_sessions[1] is observed_sessions[2]
assert len(observed_locks) == 3
assert observed_locks[0] is observed_locks[1] is observed_locks[2]
assert isinstance(observed_locks[0], asyncio.Lock)
@pytest.mark.asyncio
@@ -248,7 +251,7 @@ async def test_cancellation_cleans_up():
started = asyncio.Event()
async def fake_yield(tc_list, idx, sess):
async def fake_yield(tc_list, idx, sess, lock=None):
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
)

View File

@@ -5,8 +5,6 @@ This module implements the AI SDK UI Stream Protocol (v1) for streaming chat res
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
"""
import json
import logging
from enum import Enum
from typing import Any
@@ -14,8 +12,6 @@ from pydantic import BaseModel, Field
from backend.util.json import dumps as json_dumps
logger = logging.getLogger(__name__)
class ResponseType(str, Enum):
"""Types of streaming responses following AI SDK protocol."""
@@ -51,8 +47,7 @@ class StreamBaseResponse(BaseModel):
def to_sse(self) -> str:
"""Convert to SSE format."""
json_str = self.model_dump_json(exclude_none=True)
return f"data: {json_str}\n\n"
return f"data: {self.model_dump_json()}\n\n"
# ========== Message Lifecycle ==========
@@ -63,13 +58,15 @@ class StreamStart(StreamBaseResponse):
type: ResponseType = ResponseType.START
messageId: str = Field(..., description="Unique message ID")
sessionId: str | None = Field(
taskId: str | None = Field(
default=None,
description="Session ID for SSE reconnection.",
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-protocol fields like sessionId."""
"""Convert to SSE format, excluding non-protocol fields like taskId."""
import json
data: dict[str, Any] = {
"type": self.type.value,
"messageId": self.messageId,
@@ -166,6 +163,8 @@ class StreamToolOutputAvailable(StreamBaseResponse):
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-spec fields."""
import json
data = {
"type": self.type.value,
"toolCallId": self.toolCallId,

View File

@@ -1,65 +0,0 @@
"""Dummy SDK service for testing copilot streaming.
Returns mock streaming responses without calling Claude Agent SDK.
Enable via COPILOT_TEST_MODE=true environment variable.
WARNING: This is for testing only. Do not use in production.
"""
import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
from ..model import ChatSession
from ..response_model import (
StreamBaseResponse,
StreamFinish,
StreamStart,
StreamTextDelta,
)
logger = logging.getLogger(__name__)
async def stream_chat_completion_dummy(
session_id: str,
message: str | None = None,
tool_call_response: str | None = None,
is_user_message: bool = True,
user_id: str | None = None,
retry_count: int = 0,
session: ChatSession | None = None,
context: dict[str, str] | None = None,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream dummy chat completion for testing.
Returns a simple streaming response with text deltas to test:
- Streaming infrastructure works
- No timeout occurs
- Text arrives in chunks
- StreamFinish is sent
"""
logger.warning(
f"[TEST MODE] Using dummy copilot streaming for session {session_id}"
)
message_id = str(uuid.uuid4())
text_block_id = str(uuid.uuid4())
# Start the stream
yield StreamStart(messageId=message_id, sessionId=session_id)
# Simulate streaming text response with delays
dummy_response = "I counted: 1... 2... 3. All done!"
words = dummy_response.split()
for i, word in enumerate(words):
# Add space except for last word
text = word if i == len(words) - 1 else f"{word} "
yield StreamTextDelta(id=text_block_id, delta=text)
# Small delay to simulate real streaming
await asyncio.sleep(0.1)
# Finish the stream
yield StreamFinish()

View File

@@ -55,8 +55,13 @@ class SDKResponseAdapter:
self.has_ended_text = False
self.current_tool_calls: dict[str, dict[str, str]] = {}
self.resolved_tool_calls: set[str] = set()
self.task_id: str | None = None
self.step_open = False
def set_task_id(self, task_id: str) -> None:
"""Set the task ID for reconnection support."""
self.task_id = task_id
@property
def has_unresolved_tool_calls(self) -> bool:
"""True when there are tool calls that haven't received output yet."""
@@ -69,7 +74,7 @@ class SDKResponseAdapter:
if isinstance(sdk_message, SystemMessage):
if sdk_message.subtype == "init":
responses.append(
StreamStart(messageId=self.message_id, sessionId=self.session_id)
StreamStart(messageId=self.message_id, taskId=self.task_id)
)
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
responses.append(StreamStartStep())

View File

@@ -37,7 +37,9 @@ from .tool_adapter import wait_for_stash
def _adapter() -> SDKResponseAdapter:
return SDKResponseAdapter(message_id="msg-1", session_id="session-1")
a = SDKResponseAdapter(message_id="msg-1")
a.set_task_id("task-1")
return a
# -- SystemMessage -----------------------------------------------------------
@@ -49,7 +51,7 @@ def test_system_init_emits_start_and_step():
assert len(results) == 2
assert isinstance(results[0], StreamStart)
assert results[0].messageId == "msg-1"
assert results[0].sessionId == "session-1"
assert results[0].taskId == "task-1"
assert isinstance(results[1], StreamStartStep)

View File

@@ -13,6 +13,7 @@ from backend.data.redis_client import get_redis_async
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import NotFoundError
from .. import stream_registry
from ..config import ChatConfig
from ..model import (
ChatMessage,
@@ -32,7 +33,12 @@ from ..response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from ..service import _build_system_prompt, _generate_session_title
from ..service import (
_build_system_prompt,
_execute_long_running_tool_with_streaming,
_generate_session_title,
)
from ..tools.models import OperationPendingResponse, OperationStartedResponse
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from .response_adapter import SDKResponseAdapter
@@ -40,6 +46,7 @@ from .security_hooks import create_security_hooks
from .tool_adapter import (
COPILOT_TOOL_NAMES,
SDK_DISALLOWED_TOOLS,
LongRunningCallback,
create_copilot_mcp_server,
set_execution_context,
wait_for_stash,
@@ -77,8 +84,7 @@ class CapturedTranscript:
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
# Heartbeat interval — keep SSE alive through proxies/LBs during tool execution.
# IMPORTANT: Must be less than frontend timeout (12s in useCopilotPage.ts)
_HEARTBEAT_INTERVAL = 10.0 # seconds
_HEARTBEAT_INTERVAL = 15.0 # seconds
# Appended to the system prompt to inform the agent about available tools.
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
@@ -132,6 +138,127 @@ is delivered to the user via a background stream.
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
def _build_long_running_callback(
user_id: str | None,
) -> LongRunningCallback:
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
existing background infrastructure: stream_registry (Redis Streams),
database persistence, and SSE reconnection. This means results survive
page refreshes / pod restarts, and the frontend shows the proper loading
widget with progress updates.
Args:
user_id: User ID for the session
The returned callback matches the ``LongRunningCallback`` signature:
``(tool_name, args, session) -> MCP response dict``.
"""
async def _callback(
tool_name: str, args: dict[str, Any], session: ChatSession
) -> dict[str, Any]:
operation_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}"
session_id = session.session_id
# --- Build user-friendly messages (matches non-SDK service) ---
if tool_name == "create_agent":
desc = args.get("description", "")
desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc
pending_msg = (
f"Creating your agent: {desc_preview}"
if desc_preview
else "Creating agent... This may take a few minutes."
)
started_msg = (
"Agent creation started. You can close this tab - "
"check your library in a few minutes."
)
elif tool_name == "edit_agent":
changes = args.get("changes", "")
changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
pending_msg = (
f"Editing agent: {changes_preview}"
if changes_preview
else "Editing agent... This may take a few minutes."
)
started_msg = (
"Agent edit started. You can close this tab - "
"check your library in a few minutes."
)
else:
pending_msg = f"Running {tool_name}... This may take a few minutes."
started_msg = (
f"{tool_name} started. You can close this tab - "
"check back in a few minutes."
)
# --- Register task in Redis for SSE reconnection ---
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# --- Save OperationPendingResponse to chat history ---
pending_message = ChatMessage(
role="tool",
content=OperationPendingResponse(
message=pending_msg,
operation_id=operation_id,
tool_name=tool_name,
).model_dump_json(),
tool_call_id=tool_call_id,
)
session.messages.append(pending_message)
# Collision detection happens in add_chat_messages_batch (db.py)
session = await upsert_chat_session(session)
# --- Spawn background task (reuses non-SDK infrastructure) ---
bg_task = asyncio.create_task(
_execute_long_running_tool_with_streaming(
tool_name=tool_name,
parameters=args,
tool_call_id=tool_call_id,
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
user_id=user_id,
)
)
_background_tasks.add(bg_task)
bg_task.add_done_callback(_background_tasks.discard)
await stream_registry.set_task_asyncio_task(task_id, bg_task)
logger.info(
f"[SDK] Long-running tool {tool_name} delegated to background "
f"(operation_id={operation_id}, task_id={task_id})"
)
# --- Return OperationStartedResponse as MCP tool result ---
# This flows through SDK → response adapter → frontend, triggering
# the loading widget with SSE reconnection support.
started_json = OperationStartedResponse(
message=started_msg,
operation_id=operation_id,
tool_name=tool_name,
task_id=task_id,
).model_dump_json()
return {
"content": [{"type": "text", "text": started_json}],
"isError": False,
}
return _callback
def _resolve_sdk_model() -> str | None:
"""Resolve the model name for the Claude Agent SDK CLI.
@@ -450,7 +577,8 @@ async def stream_chat_completion_sdk(
)
system_prompt += _SDK_TOOL_SUPPLEMENT
message_id = str(uuid.uuid4())
stream_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
stream_id = task_id # Use task_id as unique stream identifier
# Acquire stream lock to prevent concurrent streams to the same session
lock = AsyncClusterLock(
@@ -474,7 +602,7 @@ async def stream_chat_completion_sdk(
yield StreamFinish()
return
yield StreamStart(messageId=message_id, sessionId=session_id)
yield StreamStart(messageId=message_id, taskId=task_id)
stream_completed = False
# Initialise variables before the try so the finally block can
@@ -490,7 +618,11 @@ async def stream_chat_completion_sdk(
sdk_cwd = _make_sdk_cwd(session_id)
os.makedirs(sdk_cwd, exist_ok=True)
set_execution_context(user_id, session)
set_execution_context(
user_id,
session,
long_running_callback=_build_long_running_callback(user_id),
)
try:
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
@@ -582,6 +714,7 @@ async def stream_chat_completion_sdk(
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type]
adapter = SDKResponseAdapter(message_id=message_id, session_id=session_id)
adapter.set_task_id(task_id)
async with ClaudeSDKClient(options=options) as client:
current_message = message or ""
@@ -606,7 +739,8 @@ async def stream_chat_completion_sdk(
session_id,
)
logger.info(
"[SDK] [%s] Sending query — resume=%s, total_msgs=%d, query_len=%d",
"[SDK] [%s] Sending query — resume=%s, "
"total_msgs=%d, query_len=%d",
session_id[:12],
use_resume,
len(session.messages),
@@ -655,7 +789,8 @@ async def stream_chat_completion_sdk(
sdk_msg = done.pop().result()
except StopAsyncIteration:
logger.info(
"[SDK] [%s] Stream ended normally (StopAsyncIteration)",
"[SDK] [%s] Stream ended normally "
"(StopAsyncIteration)",
session_id[:12],
)
break
@@ -792,6 +927,18 @@ async def stream_chat_completion_sdk(
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
# Save before tool execution starts so the
# pending tool call is visible on refresh /
# other devices. Collision detection happens
# in add_chat_messages_batch (db.py).
try:
session = await upsert_chat_session(session)
except Exception as save_err:
logger.warning(
"[SDK] [%s] Incremental save " "failed: %s",
session_id[:12],
save_err,
)
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
@@ -806,6 +953,17 @@ async def stream_chat_completion_sdk(
)
)
has_tool_results = True
# Save after tool completes so the result is
# visible on refresh / other devices.
# Collision detection happens in add_chat_messages_batch (db.py).
try:
session = await upsert_chat_session(session)
except Exception as save_err:
logger.warning(
"[SDK] [%s] Incremental save " "failed: %s",
session_id[:12],
save_err,
)
elif isinstance(response, StreamFinish):
stream_completed = True
@@ -815,7 +973,8 @@ async def stream_chat_completion_sdk(
# server shutdown). Log and let the safety-net / finally
# blocks handle cleanup.
logger.warning(
"[SDK] [%s] Streaming loop cancelled (asyncio.CancelledError)",
"[SDK] [%s] Streaming loop cancelled "
"(asyncio.CancelledError)",
session_id[:12],
)
raise
@@ -895,7 +1054,7 @@ async def stream_chat_completion_sdk(
elif captured_transcript.path:
raw_transcript = read_transcript_file(captured_transcript.path)
logger.debug(
"[SDK] Transcript source: stop hook (%s), read result: %s",
"[SDK] Transcript source: stop hook (%s), " "read result: %s",
captured_transcript.path,
f"{len(raw_transcript)}B" if raw_transcript else "None",
)
@@ -940,23 +1099,10 @@ async def stream_chat_completion_sdk(
yield StreamFinish()
except asyncio.CancelledError:
# Client disconnect / server shutdown — save session before re-raising
# so accumulated messages aren't lost.
# Client disconnect / server shutdown — log but re-raise so
# the framework can clean up. The finally block still runs
# for transcript upload.
logger.warning("[SDK] [%s] Session cancelled (CancelledError)", session_id[:12])
if session:
try:
await asyncio.shield(upsert_chat_session(session))
logger.info(
"[SDK] [%s] Session saved on cancel (%d messages)",
session_id[:12],
len(session.messages),
)
except Exception as save_err:
logger.error(
"[SDK] [%s] Failed to save session on cancel: %s",
session_id[:12],
save_err,
)
raise
except Exception as e:
logger.error(f"[SDK] Error: {e}", exc_info=True)

View File

@@ -2,6 +2,11 @@
This module provides the adapter layer that converts existing BaseTool implementations
into in-process MCP tools that can be used with the Claude Agent SDK.
Long-running tools (``is_long_running=True``) are delegated to the non-SDK
background infrastructure (stream_registry, Redis persistence, SSE reconnection)
via a callback provided by the service layer. This avoids wasteful SDK polling
and makes results survive page refreshes.
"""
import asyncio
@@ -10,6 +15,7 @@ import json
import logging
import os
import uuid
from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from typing import Any
@@ -37,8 +43,7 @@ _current_session: ContextVar[ChatSession | None] = ContextVar(
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, list[str]]] = ContextVar(
"pending_tool_outputs",
default=None, # type: ignore[arg-type]
"pending_tool_outputs", default=None # type: ignore[arg-type]
)
# Event signaled whenever stash_pending_tool_output() adds a new entry.
# Used by the streaming loop to wait for PostToolUse hooks to complete
@@ -49,10 +54,22 @@ _stash_event: ContextVar[asyncio.Event | None] = ContextVar(
"_stash_event", default=None
)
# Callback type for delegating long-running tools to the non-SDK infrastructure.
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
LongRunningCallback = Callable[
[str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]]
]
# ContextVar so the service layer can inject the callback per-request.
_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar(
"long_running_callback", default=None
)
def set_execution_context(
user_id: str | None,
session: ChatSession,
long_running_callback: LongRunningCallback | None = None,
) -> None:
"""Set the execution context for tool calls.
@@ -62,11 +79,14 @@ def set_execution_context(
Args:
user_id: Current user's ID.
session: Current chat session.
long_running_callback: Optional callback to delegate long-running tools
to the non-SDK background infrastructure (stream_registry + Redis).
"""
_current_user_id.set(user_id)
_current_session.set(session)
_pending_tool_outputs.set({})
_stash_event.set(asyncio.Event())
_long_running_callback.set(long_running_callback)
def get_execution_context() -> tuple[str | None, ChatSession | None]:
@@ -256,6 +276,11 @@ def create_tool_handler(base_tool: BaseTool):
This wraps the existing BaseTool._execute method to be compatible
with the Claude Agent SDK MCP tool format.
Long-running tools (``is_long_running=True``) are delegated to the
non-SDK background infrastructure via a callback set in the execution
context. The callback persists the operation in Redis (stream_registry)
so results survive page refreshes and pod restarts.
"""
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
@@ -265,6 +290,25 @@ def create_tool_handler(base_tool: BaseTool):
if session is None:
return _mcp_error("No session context available")
# --- Long-running: delegate to non-SDK background infrastructure ---
if base_tool.is_long_running:
callback = _long_running_callback.get(None)
if callback:
try:
return await callback(base_tool.name, args, session)
except Exception as e:
logger.error(
f"Long-running callback failed for {base_tool.name}: {e}",
exc_info=True,
)
return _mcp_error(f"Failed to start {base_tool.name}: {e}")
# No callback — fall through to synchronous execution
logger.warning(
f"[SDK] No long-running callback for {base_tool.name}, "
f"executing synchronously (may block)"
)
# --- Normal (fast) tool: execute synchronously ---
try:
return await _execute_tool_sync(base_tool, user_id, session, args)
except Exception as e:

File diff suppressed because it is too large Load Diff

View File

@@ -1,12 +1,12 @@
"""Stream registry for managing reconnectable SSE streams.
This module provides a registry for tracking active streaming sessions and their
This module provides a registry for tracking active streaming tasks and their
messages. It uses Redis for all state management (no in-memory state), making
pods stateless and horizontally scalable.
Architecture:
- Redis Stream: Persists all messages for replay and real-time delivery
- Redis Hash: Session metadata (status, session_id, etc.)
- Redis Hash: Task metadata (status, session_id, etc.)
Subscribers:
1. Replay missed messages from Redis Stream (XREAD)
@@ -16,7 +16,6 @@ Subscribers:
import asyncio
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal
@@ -26,25 +25,17 @@ import orjson
from backend.data.redis_client import get_redis_async
from .config import ChatConfig
from .executor.utils import COPILOT_CONSUMER_TIMEOUT_SECONDS
from .response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamHeartbeat,
StreamTextDelta,
StreamTextStart,
)
from .response_model import StreamBaseResponse, StreamError, StreamFinish
logger = logging.getLogger(__name__)
config = ChatConfig()
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_sessions: dict[str, asyncio.Task] = {}
_local_tasks: dict[str, asyncio.Task] = {}
# Track listener tasks per subscriber queue for cleanup
# Maps queue id() to (session_id, asyncio.Task) for proper cleanup on unsubscribe
_listener_sessions: dict[int, tuple[str, asyncio.Task]] = {}
# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe
_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {}
# Timeout for putting chunks into subscriber queues (seconds)
# If the queue is full and doesn't drain within this time, send an overflow error
@@ -52,7 +43,7 @@ QUEUE_PUT_TIMEOUT = 5.0
# Lua script for atomic compare-and-swap status update (idempotent completion)
# Returns 1 if status was updated, 0 if already completed/failed
COMPLETE_SESSION_SCRIPT = """
COMPLETE_TASK_SCRIPT = """
local current = redis.call("HGET", KEYS[1], "status")
if current == "running" then
redis.call("HSET", KEYS[1], "status", ARGV[1])
@@ -63,92 +54,81 @@ return 0
@dataclass
class ActiveSession:
"""Represents an active streaming session (metadata only, no in-memory queues)."""
class ActiveTask:
"""Represents an active streaming task (metadata only, no in-memory queues)."""
task_id: str
session_id: str
user_id: str | None
tool_call_id: str
tool_name: str
turn_id: str = ""
blocking: bool = False # If True, HTTP request is waiting for completion
operation_id: str
status: Literal["running", "completed", "failed"] = "running"
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
asyncio_task: asyncio.Task | None = None
def _get_session_meta_key(session_id: str) -> str:
"""Get Redis key for session metadata (keyed by session_id)."""
return f"{config.session_meta_prefix}{session_id}"
def _get_task_meta_key(task_id: str) -> str:
"""Get Redis key for task metadata."""
return f"{config.task_meta_prefix}{task_id}"
def _get_turn_stream_key(turn_id: str) -> str:
"""Get Redis key for turn message stream (keyed by turn_id for per-turn isolation)."""
return f"{config.turn_stream_prefix}{turn_id}"
def _get_task_stream_key(task_id: str) -> str:
"""Get Redis key for task message stream."""
return f"{config.task_stream_prefix}{task_id}"
def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSession:
"""Parse a raw Redis hash into a typed ActiveSession.
Centralises the ``meta.get(...)`` boilerplate so callers don't repeat it.
``session_id`` is used as a fallback for ``turn_id`` when the meta hash
pre-dates the turn_id field (backward compat for in-flight sessions).
"""
return ActiveSession(
session_id=meta.get("session_id", "") or session_id,
user_id=meta.get("user_id", "") or None,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
turn_id=meta.get("turn_id", "") or session_id,
blocking=meta.get("blocking") == "1",
status=meta.get("status", "running"), # type: ignore[arg-type]
)
def _get_operation_mapping_key(operation_id: str) -> str:
"""Get Redis key for operation_id to task_id mapping."""
return f"{config.task_op_prefix}{operation_id}"
async def create_session_task(
async def create_task(
task_id: str,
session_id: str,
user_id: str | None,
tool_call_id: str,
tool_name: str,
turn_id: str = "",
blocking: bool = False,
) -> ActiveSession:
"""Create a new streaming session in Redis (keyed by session_id).
operation_id: str,
) -> ActiveTask:
"""Create a new streaming task in Redis.
Args:
session_id: Chat session ID (used as session identifier)
task_id: Unique identifier for the task
session_id: Chat session ID
user_id: User ID (may be None for anonymous)
tool_call_id: Tool call ID from the LLM
tool_name: Name of the tool being executed
turn_id: Unique per-turn UUID for stream isolation
blocking: If True, HTTP request is waiting for completion
operation_id: Operation ID for webhook callbacks
Returns:
The created ActiveSession instance (metadata only)
The created ActiveTask instance (metadata only)
"""
import time
start_time = time.perf_counter()
# Build log metadata for structured logging
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"session_id": session_id,
}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] create_session_task STARTED, session={session_id}, user={user_id}, turn_id={turn_id}",
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)
# Create session
session = ActiveSession(
task = ActiveTask(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
turn_id=turn_id,
blocking=blocking,
operation_id=operation_id,
)
# Store metadata in Redis
@@ -160,21 +140,21 @@ async def create_session_task(
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
)
meta_key = _get_session_meta_key(session_id)
# No need to delete old stream — each turn_id is a fresh UUID
meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)
hset_start = time.perf_counter()
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
"task_id": task_id,
"session_id": session_id,
"user_id": user_id or "",
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"turn_id": turn_id,
"blocking": "1" if blocking else "0",
"status": session.status,
"created_at": session.created_at.isoformat(),
"operation_id": operation_id,
"status": task.status,
"created_at": task.created_at.isoformat(),
},
)
hset_time = (time.perf_counter() - hset_start) * 1000
@@ -185,17 +165,20 @@ async def create_session_task(
await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl)
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] create_session_task COMPLETED in {total_time:.1f}ms; session={session_id}",
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
return session
return task
async def publish_chunk(
turn_id: str,
task_id: str,
chunk: StreamBaseResponse,
) -> str:
"""Publish a chunk to Redis Stream.
@@ -203,12 +186,14 @@ async def publish_chunk(
All delivery is via Redis Streams - no in-memory state.
Args:
turn_id: Turn ID (per-turn UUID) identifying the stream
task_id: Task ID to publish to
chunk: The stream response chunk to publish
Returns:
The Redis Stream message ID
"""
import time
start_time = time.perf_counter()
chunk_type = type(chunk).__name__
chunk_json = chunk.model_dump_json()
@@ -217,13 +202,13 @@ async def publish_chunk(
# Build log metadata
log_meta = {
"component": "StreamRegistry",
"turn_id": turn_id,
"task_id": task_id,
"chunk_type": chunk_type,
}
try:
redis = await get_redis_async()
stream_key = _get_turn_stream_key(turn_id)
stream_key = _get_task_stream_key(task_id)
# Write to Redis Stream for persistence and real-time delivery
xadd_start = time.perf_counter()
@@ -235,7 +220,7 @@ async def publish_chunk(
xadd_time = (time.perf_counter() - xadd_start) * 1000
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Set TTL on stream to match session metadata TTL
# Set TTL on stream to match task metadata TTL
await redis.expire(stream_key, config.stream_ttl)
total_time = (time.perf_counter() - start_time) * 1000
@@ -274,39 +259,41 @@ async def publish_chunk(
return message_id
async def subscribe_to_session(
session_id: str,
async def subscribe_to_task(
task_id: str,
user_id: str | None,
last_message_id: str = "0-0",
) -> asyncio.Queue[StreamBaseResponse] | None:
"""Subscribe to a session's stream with replay of missed messages.
"""Subscribe to a task's stream with replay of missed messages.
This is fully stateless - uses Redis Stream for replay and pub/sub for live updates.
Args:
session_id: Session ID to subscribe to
task_id: Task ID to subscribe to
user_id: User ID for ownership validation
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
Returns:
An asyncio Queue that will receive stream chunks, or None if session not found
An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access
"""
import time
start_time = time.perf_counter()
# Build log metadata
log_meta = {"component": "StreamRegistry", "session_id": session_id}
log_meta = {"component": "StreamRegistry", "task_id": task_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] subscribe_to_session STARTED, session={session_id}, user={user_id}, last_msg={last_message_id}",
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
)
redis_start = time.perf_counter()
redis = await get_redis_async()
meta_key = _get_session_meta_key(session_id)
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
hgetall_time = (time.perf_counter() - redis_start) * 1000
logger.info(
@@ -314,69 +301,54 @@ async def subscribe_to_session(
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
)
# RACE CONDITION FIX: If session not found, retry once after small delay
# This handles the case where subscribe_to_session is called immediately
# after create_session_task but before Redis propagates the write
if not meta:
logger.warning(
"[TIMING] Session not found on first attempt, retrying after 50ms delay",
extra={"json_fields": {**log_meta}},
)
await asyncio.sleep(0.05) # 50ms
meta = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] Session still not found in Redis after retry ({elapsed:.1f}ms total)",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"reason": "session_not_found_after_retry",
}
},
)
return None
elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
"[TIMING] Session found after retry",
extra={"json_fields": {**log_meta}},
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"reason": "task_not_found",
}
},
)
return None
# Note: Redis client uses decode_responses=True, so keys are strings
session_status = meta.get("status", "")
session_user_id = meta.get("user_id", "") or None
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
log_meta["session_id"] = meta.get("session_id", "")
# Validate ownership - if session has an owner, requester must match
if session_user_id:
if user_id != session_user_id:
# Validate ownership - if task has an owner, requester must match
if task_user_id:
if user_id != task_user_id:
logger.warning(
f"[TIMING] Access denied: user {user_id} tried to access session owned by {session_user_id}",
f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}",
extra={
"json_fields": {
**log_meta,
"session_owner": session_user_id,
"task_owner": task_user_id,
"reason": "access_denied",
}
},
)
return None
session = _parse_session_meta(meta, session_id)
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
stream_key = _get_turn_stream_key(session.turn_id)
stream_key = _get_task_stream_key(task_id)
# Step 1: Replay messages from Redis Stream
xread_start = time.perf_counter()
messages = await redis.xread({stream_key: last_message_id}, block=None, count=1000)
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
xread_time = (time.perf_counter() - xread_start) * 1000
logger.info(
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={session_status}",
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
extra={
"json_fields": {
**log_meta,
"duration_ms": xread_time,
"session_status": session_status,
"task_status": task_status,
}
},
)
@@ -409,30 +381,28 @@ async def subscribe_to_session(
},
)
# Step 2: If session is still running, start stream listener for live updates
if session_status == "running":
# Step 2: If task is still running, start stream listener for live updates
if task_status == "running":
logger.info(
"[TIMING] Session still running, starting _stream_listener",
extra={"json_fields": {**log_meta, "session_status": session_status}},
"[TIMING] Task still running, starting _stream_listener",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
listener_task = asyncio.create_task(
_stream_listener(
session_id, subscriber_queue, replay_last_id, log_meta, session.turn_id
)
_stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
)
# Track listener task for cleanup on unsubscribe
_listener_sessions[id(subscriber_queue)] = (session_id, listener_task)
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
else:
# Session is completed/failed - add finish marker
# Task is completed/failed - add finish marker
logger.info(
f"[TIMING] Session already {session_status}, adding StreamFinish",
extra={"json_fields": {**log_meta, "session_status": session_status}},
f"[TIMING] Task already {task_status}, adding StreamFinish",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
await subscriber_queue.put(StreamFinish())
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] subscribe_to_session COMPLETED in {total_time:.1f}ms; session={session_id}, "
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
f"n_messages_replayed={replayed_count}",
extra={
"json_fields": {
@@ -446,11 +416,10 @@ async def subscribe_to_session(
async def _stream_listener(
session_id: str,
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str,
log_meta: dict | None = None,
turn_id: str = "",
) -> None:
"""Listen to Redis Stream for new messages using blocking XREAD.
@@ -458,20 +427,21 @@ async def _stream_listener(
when messages are published during the gap between replay and subscription.
Args:
session_id: Session ID to listen for
task_id: Task ID to listen for
subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here)
log_meta: Structured logging metadata
turn_id: Per-turn UUID for stream key resolution
"""
import time
start_time = time.perf_counter()
# Use provided log_meta or build minimal one
if log_meta is None:
log_meta = {"component": "StreamRegistry", "session_id": session_id}
log_meta = {"component": "StreamRegistry", "task_id": task_id}
logger.info(
f"[TIMING] _stream_listener STARTED, session={session_id}, last_id={last_replayed_id}",
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
)
@@ -484,17 +454,16 @@ async def _stream_listener(
try:
redis = await get_redis_async()
stream_key = _get_turn_stream_key(turn_id)
stream_key = _get_task_stream_key(task_id)
current_id = last_replayed_id
while True:
# Block for up to 5 seconds waiting for new messages
# This allows periodic checking if session is still running
# Short timeout prevents frontend timeout (12s) while waiting for heartbeats (15s)
# Block for up to 30 seconds waiting for new messages
# This allows periodic checking if task is still running
xread_start = time.perf_counter()
xread_count += 1
messages = await redis.xread(
{stream_key: current_id}, block=5000, count=100
{stream_key: current_id}, block=30000, count=100
)
xread_time = (time.perf_counter() - xread_start) * 1000
@@ -526,8 +495,8 @@ async def _stream_listener(
)
if not messages:
# Timeout - check if session is still running
meta_key = _get_session_meta_key(session_id)
# Timeout - check if task is still running
meta_key = _get_task_meta_key(task_id)
status = await redis.hget(meta_key, "status") # type: ignore[misc]
if status and status != "running":
try:
@@ -537,20 +506,9 @@ async def _stream_listener(
)
except asyncio.TimeoutError:
logger.warning(
f"Timeout delivering finish event for session {session_id}"
f"Timeout delivering finish event for task {task_id}"
)
break
# Session still running - send heartbeat to keep connection alive
# This prevents frontend timeout (12s) during long-running operations
try:
await asyncio.wait_for(
subscriber_queue.put(StreamHeartbeat()),
timeout=QUEUE_PUT_TIMEOUT,
)
except asyncio.TimeoutError:
logger.warning(
f"Timeout delivering heartbeat for session {session_id}"
)
continue
for _stream_name, stream_messages in messages:
@@ -610,7 +568,7 @@ async def _stream_listener(
except asyncio.QueueFull:
# Queue is completely stuck, nothing more we can do
logger.error(
f"Cannot deliver overflow error for session {session_id}, "
f"Cannot deliver overflow error for task {task_id}, "
"queue completely blocked"
)
@@ -618,7 +576,7 @@ async def _stream_listener(
if isinstance(chunk, StreamFinish):
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] StreamFinish received in {total_time / 1000:.1f}s; delivered={messages_delivered}",
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
@@ -666,10 +624,10 @@ async def _stream_listener(
extra={"json_fields": log_meta},
)
finally:
# Clean up listener session mapping on exit
# Clean up listener task mapping on exit
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] _stream_listener FINISHED in {total_time / 1000:.1f}s; session={session_id}, "
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
f"delivered={messages_delivered}, xread_count={xread_count}",
extra={
"json_fields": {
@@ -680,198 +638,238 @@ async def _stream_listener(
}
},
)
_listener_sessions.pop(queue_id, None)
_listener_tasks.pop(queue_id, None)
async def mark_session_completed(
session_id: str,
async def mark_task_completed(
task_id: str,
status: Literal["completed", "failed"] = "completed",
*,
error_message: str | None = None,
) -> bool:
"""Mark a session as completed, then publish StreamFinish.
This is the SINGLE place that publishes StreamFinish to the turn stream.
Services must NOT yield StreamFinish themselves — the processor intercepts
it and calls this function instead, ensuring status is set before
StreamFinish reaches the frontend.
"""Mark a task as completed and publish finish event.
This is idempotent - calling multiple times with the same task_id is safe.
Uses atomic compare-and-swap via Lua script to prevent race conditions.
Idempotent — calling multiple times is safe (returns False on no-op).
Status is updated first (source of truth), then finish event is published (best-effort).
Args:
session_id: Session ID to mark as completed
error_message: If provided, marks as "failed" and publishes a
StreamError before StreamFinish. Otherwise marks as "completed".
task_id: Task ID to mark as completed
status: Final status ("completed" or "failed")
error_message: If provided and status="failed", publish a StreamError
before StreamFinish so connected clients see why the task ended.
If not provided, no StreamError is published (caller should publish
manually if needed to avoid duplicates).
Returns:
True if session was newly marked completed, False if already completed/failed
True if task was newly marked completed, False if already completed/failed
"""
status: Literal["completed", "failed"] = "failed" if error_message else "completed"
redis = await get_redis_async()
meta_key = _get_session_meta_key(session_id)
# Resolve turn_id for publishing to the correct stream
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
turn_id = _parse_session_meta(meta, session_id).turn_id if meta else session_id
meta_key = _get_task_meta_key(task_id)
# Atomic compare-and-swap: only update if status is "running"
result = await redis.eval(COMPLETE_SESSION_SCRIPT, 1, meta_key, status) # type: ignore[misc]
# This prevents race conditions when multiple callers try to complete simultaneously
result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc]
if result == 0:
logger.debug(f"Session {session_id} already completed/failed, skipping")
logger.debug(f"Task {task_id} already completed/failed, skipping")
return False
if error_message:
# Publish error event before finish so connected clients know WHY the
# task ended. Only publish if caller provided an explicit error message
# to avoid duplicates with code paths that manually publish StreamError.
# This is best-effort — if it fails, the StreamFinish still ensures
# listeners clean up.
if status == "failed" and error_message:
try:
await publish_chunk(turn_id, StreamError(errorText=error_message))
await publish_chunk(task_id, StreamError(errorText=error_message))
except Exception as e:
logger.warning(
f"Failed to publish error event for session {session_id}: {e}"
)
logger.warning(f"Failed to publish error event for task {task_id}: {e}")
# Publish StreamFinish AFTER status is set to "completed"/"failed".
# This is the SINGLE place that publishes StreamFinish — services and
# the processor must NOT publish it themselves.
# THEN publish finish event (best-effort - listeners can detect via status polling)
try:
await publish_chunk(turn_id, StreamFinish())
await publish_chunk(task_id, StreamFinish())
except Exception as e:
logger.error(
f"Failed to publish StreamFinish for session {session_id}: {e}. "
"The _stream_listener will detect completion via status polling."
f"Failed to publish finish event for task {task_id}: {e}. "
"Listeners will detect completion via status polling."
)
# Clean up local session reference if exists
_local_sessions.pop(session_id, None)
# Clean up local task reference if exists
_local_tasks.pop(task_id, None)
return True
async def get_session(session_id: str) -> ActiveSession | None:
"""Get a session by its ID from Redis.
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
"""Find a task by its operation ID.
Used by webhook callbacks to locate the task to update.
Args:
session_id: Session ID to look up
operation_id: Operation ID to search for
Returns:
ActiveSession if found, None otherwise
ActiveTask if found, None otherwise
"""
redis = await get_redis_async()
meta_key = _get_session_meta_key(session_id)
op_key = _get_operation_mapping_key(operation_id)
task_id = await redis.get(op_key)
if not task_id:
return None
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
return await get_task(task_id_str)
async def get_task(task_id: str) -> ActiveTask | None:
"""Get a task by its ID from Redis.
Args:
task_id: Task ID to look up
Returns:
ActiveTask if found, None otherwise
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
return None
return _parse_session_meta(meta, session_id)
# Note: Redis client uses decode_responses=True, so keys/values are strings
return ActiveTask(
task_id=meta.get("task_id", ""),
session_id=meta.get("session_id", ""),
user_id=meta.get("user_id", "") or None,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status=meta.get("status", "running"), # type: ignore[arg-type]
)
async def get_session_with_expiry_info(
session_id: str,
) -> tuple[ActiveSession | None, str | None]:
"""Get a session by its ID with expiration detection.
async def get_task_with_expiry_info(
task_id: str,
) -> tuple[ActiveTask | None, str | None]:
"""Get a task by its ID with expiration detection.
Returns (session, error_code) where error_code is:
- None if session found
Returns (task, error_code) where error_code is:
- None if task found
- "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired)
- "TASK_NOT_FOUND" if neither exists
Args:
session_id: Session ID to look up
task_id: Task ID to look up
Returns:
Tuple of (ActiveSession or None, error_code or None)
Tuple of (ActiveTask or None, error_code or None)
"""
redis = await get_redis_async()
meta_key = _get_session_meta_key(session_id)
meta_key = _get_task_meta_key(task_id)
stream_key = _get_task_stream_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
# Metadata expired — we can't resolve turn_id, so check using
# session_id as a best-effort fallback for the stream key.
stream_key = _get_turn_stream_key(session_id)
# Check if stream still has data (metadata expired but stream hasn't)
stream_len = await redis.xlen(stream_key)
if stream_len > 0:
return None, "TASK_EXPIRED"
return None, "TASK_NOT_FOUND"
return _parse_session_meta(meta, session_id), None
# Note: Redis client uses decode_responses=True, so keys/values are strings
return (
ActiveTask(
task_id=meta.get("task_id", ""),
session_id=meta.get("session_id", ""),
user_id=meta.get("user_id", "") or None,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status=meta.get("status", "running"), # type: ignore[arg-type]
),
None,
)
async def get_active_session(
async def get_active_task_for_session(
session_id: str,
user_id: str | None = None,
) -> tuple[ActiveSession | None, str]:
"""Get the active (running) session, if any.
) -> tuple[ActiveTask | None, str]:
"""Get the active (running) task for a session, if any.
Direct O(1) lookup by session_id.
Scans Redis for tasks matching the session_id with status="running".
Args:
session_id: Session ID to look up
user_id: User ID for ownership validation (optional)
Returns:
Tuple of (ActiveSession if found and running, last_message_id from Redis Stream)
Tuple of (ActiveTask if found and running, last_message_id from Redis Stream)
"""
redis = await get_redis_async()
meta_key = _get_session_meta_key(session_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
return None, "0-0"
# Scan Redis for task metadata keys
cursor = 0
tasks_checked = 0
session_status = meta.get("status", "")
session_user_id = meta.get("user_id", "") or None
while True:
cursor, keys = await redis.scan(
cursor, match=f"{config.task_meta_prefix}*", count=100
)
if session_status != "running":
return None, "0-0"
for key in keys:
tasks_checked += 1
meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc]
if not meta:
continue
# Validate ownership - if session has an owner, requester must match
if session_user_id and user_id != session_user_id:
return None, "0-0"
# Note: Redis client uses decode_responses=True, so keys/values are strings
task_session_id = meta.get("session_id", "")
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
task_id = meta.get("task_id", "")
# Check if session is stale (running beyond tool timeout + buffer)
# Auto-complete it to prevent infinite polling loops
# Note: Synchronous tools can run up to COPILOT_CONSUMER_TIMEOUT_SECONDS (1 hour)
# so we add a 5-minute buffer to avoid false positives during legitimate operations
created_at_str = meta.get("created_at")
if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str)
age_seconds = (datetime.now(timezone.utc) - created_at).total_seconds()
stale_threshold = (
COPILOT_CONSUMER_TIMEOUT_SECONDS + 300
) # + 5 minutes buffer
if age_seconds > stale_threshold:
logger.warning(
f"[STALE_SESSION] Auto-completing stale session {session_id[:8]}... "
f"(running for {age_seconds:.0f}s, threshold: {stale_threshold}s)"
if task_session_id == session_id and task_status == "running":
# Validate ownership - if task has an owner, requester must match
if task_user_id and user_id != task_user_id:
continue
logger.info(
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
)
await mark_session_completed(
session_id,
error_message=f"Session timed out after {age_seconds:.0f}s",
# Get the last message ID from Redis Stream
stream_key = _get_task_stream_key(task_id)
last_id = "0-0"
try:
messages = await redis.xrevrange(stream_key, count=1)
if messages:
msg_id = messages[0][0]
last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
except Exception as e:
logger.warning(f"Failed to get last message ID: {e}")
return (
ActiveTask(
task_id=task_id,
session_id=task_session_id,
user_id=task_user_id,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status="running",
),
last_id,
)
return None, "0-0"
except (ValueError, TypeError) as e:
logger.warning(f"Failed to parse created_at: {e}")
session = _parse_session_meta(meta, session_id)
logger.info(
f"[SESSION_LOOKUP] Found running session {session_id[:8]}..., turn_id={session.turn_id[:8]}"
)
if cursor == 0:
break
# Get the last message ID from Redis Stream (keyed by turn_id)
stream_key = _get_turn_stream_key(session.turn_id)
last_id = "0-0"
try:
messages = await redis.xrevrange(stream_key, count=1)
if messages:
msg_id = messages[0][0]
last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
except Exception as e:
logger.warning(f"Failed to get last message ID: {e}")
return session, last_id
return None, "0-0"
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
@@ -891,7 +889,9 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
@@ -929,20 +929,20 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
return None
async def set_session_asyncio_task(session_id: str, asyncio_task: asyncio.Task) -> None:
"""Track the asyncio.Task for a session (local reference only).
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
"""Track the asyncio.Task for a task (local reference only).
This is just for cleanup purposes - the session state is in Redis.
This is just for cleanup purposes - the task state is in Redis.
Args:
session_id: Session ID
task_id: Task ID
asyncio_task: The asyncio Task to track
"""
_local_sessions[session_id] = asyncio_task
_local_tasks[task_id] = asyncio_task
async def unsubscribe_from_session(
session_id: str,
async def unsubscribe_from_task(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> None:
"""Clean up when a subscriber disconnects.
@@ -951,29 +951,29 @@ async def unsubscribe_from_session(
to prevent resource leaks.
Args:
session_id: Session ID
task_id: Task ID
subscriber_queue: The subscriber's queue used to look up the listener task
"""
queue_id = id(subscriber_queue)
listener_entry = _listener_sessions.pop(queue_id, None)
listener_entry = _listener_tasks.pop(queue_id, None)
if listener_entry is None:
logger.debug(
f"No listener task found for session {session_id} queue {queue_id} "
f"No listener task found for task {task_id} queue {queue_id} "
"(may have already completed)"
)
return
stored_session_id, listener_task = listener_entry
stored_task_id, listener_task = listener_entry
if stored_session_id != session_id:
if stored_task_id != task_id:
logger.warning(
f"Session ID mismatch in unsubscribe: expected {session_id}, "
f"found {stored_session_id}"
f"Task ID mismatch in unsubscribe: expected {task_id}, "
f"found {stored_task_id}"
)
if listener_task.done():
logger.debug(f"Listener task for session {session_id} already completed")
logger.debug(f"Listener task for task {task_id} already completed")
return
# Cancel the listener task
@@ -987,11 +987,9 @@ async def unsubscribe_from_session(
pass
except asyncio.TimeoutError:
logger.warning(
f"Timeout waiting for listener task cancellation for session {session_id}"
f"Timeout waiting for listener task cancellation for task {task_id}"
)
except Exception as e:
logger.error(
f"Error during listener task cancellation for session {session_id}: {e}"
)
logger.error(f"Error during listener task cancellation for task {task_id}: {e}")
logger.debug(f"Successfully unsubscribed from session {session_id}")
logger.debug(f"Successfully unsubscribed from task {task_id}")

View File

@@ -1,420 +0,0 @@
"""End-to-end tests for Copilot streaming with dummy implementations.
These tests verify the complete copilot flow using dummy implementations
for agent generator and SDK service, allowing automated testing without
external LLM calls.
Enable test mode with COPILOT_TEST_MODE=true environment variable.
"""
import asyncio
import os
from uuid import uuid4
import pytest
from backend.copilot.model import ChatMessage, ChatSession, upsert_chat_session
from backend.copilot.response_model import (
StreamError,
StreamFinish,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
)
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
@pytest.fixture(autouse=True)
def enable_test_mode():
"""Enable test mode for all tests in this module."""
os.environ["COPILOT_TEST_MODE"] = "true"
yield
os.environ.pop("COPILOT_TEST_MODE", None)
@pytest.mark.asyncio
async def test_dummy_streaming_basic_flow():
"""Test that dummy streaming produces correct event sequence."""
events = []
async for event in stream_chat_completion_dummy(
session_id="test-session-basic",
message="Hello",
is_user_message=True,
user_id="test-user",
):
events.append(event)
# Verify we got events
assert len(events) > 0, "Should receive events"
# Verify StreamStart
start_events = [e for e in events if isinstance(e, StreamStart)]
assert len(start_events) == 1
assert start_events[0].messageId
assert start_events[0].sessionId
# Verify StreamTextDelta events
text_events = [e for e in events if isinstance(e, StreamTextDelta)]
assert len(text_events) > 0
full_text = "".join(e.delta for e in text_events)
assert len(full_text) > 0
# Verify StreamFinish
finish_events = [e for e in events if isinstance(e, StreamFinish)]
assert len(finish_events) == 1
# Verify order: start before text before finish
start_idx = events.index(start_events[0])
finish_idx = events.index(finish_events[0])
first_text_idx = events.index(text_events[0]) if text_events else -1
if first_text_idx >= 0:
assert start_idx < first_text_idx < finish_idx
print(f"✅ Basic flow: {len(events)} events, {len(text_events)} text deltas")
@pytest.mark.asyncio
async def test_streaming_no_timeout():
"""Test that streaming completes within reasonable time without timeout."""
import time
start_time = time.monotonic()
event_count = 0
async for event in stream_chat_completion_dummy(
session_id="test-session-timeout",
message="count to 10",
is_user_message=True,
user_id="test-user",
):
event_count += 1
elapsed = time.monotonic() - start_time
# Should complete in < 5 seconds (dummy has 0.1s delays between words)
assert elapsed < 5.0, f"Streaming took {elapsed:.1f}s, expected < 5s"
assert event_count > 0, "Should receive events"
print(f"✅ No timeout: completed in {elapsed:.2f}s with {event_count} events")
@pytest.mark.asyncio
async def test_streaming_event_types():
"""Test that all expected event types are present."""
event_types = set()
async for event in stream_chat_completion_dummy(
session_id="test-session-types",
message="test",
is_user_message=True,
user_id="test-user",
):
event_types.add(type(event).__name__)
# Required event types
assert "StreamStart" in event_types, "Missing StreamStart"
assert "StreamTextDelta" in event_types, "Missing StreamTextDelta"
assert "StreamFinish" in event_types, "Missing StreamFinish"
print(f"✅ Event types: {sorted(event_types)}")
@pytest.mark.asyncio
async def test_streaming_text_content():
"""Test that streamed text is coherent and complete."""
text_events = []
async for event in stream_chat_completion_dummy(
session_id="test-session-content",
message="count to 3",
is_user_message=True,
user_id="test-user",
):
if isinstance(event, StreamTextDelta):
text_events.append(event)
# Verify text deltas
assert len(text_events) > 0, "Should have text deltas"
# Reconstruct full text
full_text = "".join(e.delta for e in text_events)
assert len(full_text) > 0, "Text should not be empty"
assert (
"1" in full_text or "counted" in full_text.lower()
), "Text should contain count"
# Verify all deltas have IDs
for text_event in text_events:
assert text_event.id, "Text delta must have ID"
assert text_event.delta, "Text delta must have content"
print(f"✅ Text content: '{full_text}' ({len(text_events)} deltas)")
@pytest.mark.asyncio
async def test_streaming_heartbeat_timing():
"""Test that heartbeats are sent at correct interval during long operations."""
# This test would need a dummy that takes longer
# For now, just verify heartbeat structure if we receive one
heartbeats = []
async for event in stream_chat_completion_dummy(
session_id="test-session-heartbeat",
message="test",
is_user_message=True,
user_id="test-user",
):
if isinstance(event, StreamHeartbeat):
heartbeats.append(event)
# Dummy is fast, so we might not get heartbeats
# But if we do, verify they're valid
if heartbeats:
print(f"✅ Heartbeat structure verified ({len(heartbeats)} received)")
else:
print("✅ No heartbeats (dummy executes quickly)")
@pytest.mark.asyncio
async def test_error_handling():
"""Test that errors are properly formatted and sent."""
# This would require a dummy that can trigger errors
# For now, just verify error event structure
error = StreamError(errorText="Test error", code="test_error")
assert error.errorText == "Test error"
assert error.code == "test_error"
assert str(error.type.value) in ["error", "error"]
print("✅ Error structure verified")
@pytest.mark.asyncio
async def test_concurrent_sessions():
"""Test that multiple sessions can stream concurrently."""
async def stream_session(session_id: str) -> int:
count = 0
async for event in stream_chat_completion_dummy(
session_id=session_id,
message="test",
is_user_message=True,
user_id="test-user",
):
count += 1
return count
# Run 3 concurrent sessions
results = await asyncio.gather(
stream_session("session-1"),
stream_session("session-2"),
stream_session("session-3"),
)
# All should complete successfully
assert all(count > 0 for count in results), "All sessions should produce events"
print(f"✅ Concurrent sessions: {results} events each")
@pytest.mark.asyncio
@pytest.mark.xfail(
reason="Event loop isolation issue with DB operations in tests - needs fixture refactoring"
)
async def test_session_state_persistence():
"""Test that session state is maintained across multiple messages."""
from datetime import datetime, timezone
session_id = f"test-session-{uuid4()}"
user_id = "test-user"
# Create session with first message
session = ChatSession(
session_id=session_id,
user_id=user_id,
messages=[
ChatMessage(role="user", content="Hello"),
ChatMessage(role="assistant", content="Hi there!"),
],
usage=[],
started_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
await upsert_chat_session(session)
# Stream second message
events = []
async for event in stream_chat_completion_dummy(
session_id=session_id,
message="How are you?",
is_user_message=True,
user_id=user_id,
session=session, # Pass existing session
):
events.append(event)
# Verify events were produced
assert len(events) > 0, "Should produce events for second message"
# Verify we got a complete response
finish_events = [e for e in events if isinstance(e, StreamFinish)]
assert len(finish_events) == 1, "Should have StreamFinish"
print(f"✅ Session persistence: {len(events)} events for second message")
@pytest.mark.asyncio
async def test_message_deduplication():
"""Test that duplicate messages are filtered out."""
# Simulate receiving duplicate events (e.g., from reconnection)
events = []
# First stream
async for event in stream_chat_completion_dummy(
session_id="test-dedup-1",
message="Hello",
is_user_message=True,
user_id="test-user",
):
events.append(event)
if isinstance(event, StreamFinish):
break
# Count unique message IDs in StreamStart events
start_events = [e for e in events if isinstance(e, StreamStart)]
message_ids = [e.messageId for e in start_events]
# Verify all IDs are present
assert len(message_ids) == len(set(message_ids)), "Message IDs should be unique"
print(f"✅ Deduplication: {len(events)} events, all unique")
@pytest.mark.asyncio
async def test_event_ordering():
"""Test that events arrive in correct order."""
events = []
async for event in stream_chat_completion_dummy(
session_id="test-ordering",
message="Test",
is_user_message=True,
user_id="test-user",
):
events.append(event)
# Find event indices
start_idx = next(
(i for i, e in enumerate(events) if isinstance(e, StreamStart)), None
)
text_indices = [i for i, e in enumerate(events) if isinstance(e, StreamTextDelta)]
finish_idx = next(
(i for i, e in enumerate(events) if isinstance(e, StreamFinish)), None
)
# Verify ordering
assert start_idx is not None, "Should have StreamStart"
assert finish_idx is not None, "Should have StreamFinish"
assert start_idx == 0, "StreamStart should be first"
assert finish_idx == len(events) - 1, "StreamFinish should be last"
if text_indices:
assert all(
start_idx < i < finish_idx for i in text_indices
), "Text deltas should be between start and finish"
print(f"✅ Event ordering: start({start_idx}) < text < finish({finish_idx})")
@pytest.mark.asyncio
async def test_stream_completeness():
"""Test that stream includes all required event types."""
events = []
async for event in stream_chat_completion_dummy(
session_id="test-completeness",
message="Complete stream test",
is_user_message=True,
user_id="test-user",
):
events.append(event)
# Check for required events
has_start = any(isinstance(e, StreamStart) for e in events)
has_text = any(isinstance(e, StreamTextDelta) for e in events)
has_finish = any(isinstance(e, StreamFinish) for e in events)
assert has_start, "Stream must include StreamStart"
assert has_text, "Stream must include text deltas"
assert has_finish, "Stream must include StreamFinish"
# Verify exactly one start and one finish
start_count = sum(1 for e in events if isinstance(e, StreamStart))
finish_count = sum(1 for e in events if isinstance(e, StreamFinish))
assert start_count == 1, f"Should have exactly 1 StreamStart, got {start_count}"
assert finish_count == 1, f"Should have exactly 1 StreamFinish, got {finish_count}"
print(
f"✅ Completeness: 1 start, {sum(1 for e in events if isinstance(e, StreamTextDelta))} text, 1 finish"
)
@pytest.mark.asyncio
async def test_text_delta_consistency():
"""Test that text deltas have consistent IDs and build coherent text."""
text_events = []
async for event in stream_chat_completion_dummy(
session_id="test-consistency",
message="Test consistency",
is_user_message=True,
user_id="test-user",
):
if isinstance(event, StreamTextDelta):
text_events.append(event)
# Verify all text deltas have IDs
assert all(e.id for e in text_events), "All text deltas must have IDs"
# Verify all deltas have the same ID (same text block)
if text_events:
first_id = text_events[0].id
assert all(
e.id == first_id for e in text_events
), "All text deltas should share the same block ID"
# Verify deltas build coherent text
full_text = "".join(e.delta for e in text_events)
assert len(full_text) > 0, "Deltas should build non-empty text"
assert (
full_text == full_text.strip()
), "Text should not have leading/trailing whitespace artifacts"
print(
f"✅ Consistency: {len(text_events)} deltas with ID '{text_events[0].id if text_events else 'N/A'}', text: '{full_text}'"
)
if __name__ == "__main__":
# Run tests directly
print("Running Copilot E2E tests with dummy implementations...")
print("=" * 60)
asyncio.run(test_dummy_streaming_basic_flow())
asyncio.run(test_streaming_no_timeout())
asyncio.run(test_streaming_event_types())
asyncio.run(test_streaming_text_content())
asyncio.run(test_streaming_heartbeat_timing())
asyncio.run(test_error_handling())
asyncio.run(test_concurrent_sessions())
asyncio.run(test_session_state_persistence())
asyncio.run(test_message_deduplication())
asyncio.run(test_event_ordering())
asyncio.run(test_stream_completeness())
asyncio.run(test_text_delta_consistency())
print("=" * 60)
print("✅ All E2E tests passed!")

View File

@@ -10,6 +10,7 @@ from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .check_operation_status import CheckOperationStatusTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
@@ -46,6 +47,7 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"view_agent_output": AgentOutputTool(),
"check_operation_status": CheckOperationStatusTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
# Web fetch for safe URL retrieval

View File

@@ -19,7 +19,6 @@ from .core import (
get_all_relevant_agents_for_generation,
get_library_agent_by_graph_id,
get_library_agent_by_id,
get_library_agents_by_ids,
get_library_agents_for_generation,
graph_to_json,
json_to_graph,
@@ -50,7 +49,6 @@ __all__ = [
"get_all_relevant_agents_for_generation",
"get_library_agent_by_graph_id",
"get_library_agent_by_id",
"get_library_agents_by_ids",
"get_library_agents_for_generation",
"get_user_message_for_error",
"graph_to_json",

View File

@@ -3,7 +3,6 @@
import logging
import re
import uuid
from collections.abc import Sequence
from typing import Any, NotRequired, TypedDict
from backend.data.db_accessors import graph_db, library_db, store_db
@@ -79,7 +78,7 @@ AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
def _to_dict_list(
agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None,
agents: list[AgentSummary] | list[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
"""Convert typed agent summaries to plain dicts for external service calls."""
if agents is None:
@@ -191,36 +190,6 @@ async def get_library_agent_by_id(
get_library_agent_by_graph_id = get_library_agent_by_id
async def get_library_agents_by_ids(
user_id: str,
agent_ids: list[str],
) -> list[LibraryAgentSummary]:
"""Fetch multiple library agents by their IDs.
Args:
user_id: The user ID
agent_ids: List of agent IDs (can be graph_ids or library agent IDs)
Returns:
List of LibraryAgentSummary for found agents (silently skips not found)
"""
agents: list[LibraryAgentSummary] = []
for agent_id in agent_ids:
try:
agent = await get_library_agent_by_id(user_id, agent_id)
if agent:
agents.append(agent)
logger.debug(f"Fetched library agent by ID: {agent['name']}")
else:
logger.warning(f"Library agent not found for ID: {agent_id}")
except Exception as e:
logger.warning(f"Failed to fetch library agent {agent_id}: {e}")
continue
logger.info(f"Fetched {len(agents)}/{len(agent_ids)} library agents by ID")
return agents
async def get_library_agents_for_generation(
user_id: str,
search_query: str | None = None,
@@ -245,17 +214,10 @@ async def get_library_agents_for_generation(
Returns:
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
"""
search_term = search_query.strip() if search_query else None
if search_term and len(search_term) > 100:
raise ValueError(
f"Search query is too long ({len(search_term)} chars, max 100). "
f"Please use a shorter, more specific search term."
)
try:
response = await library_db().list_library_agents(
user_id=user_id,
search_term=search_term,
search_term=search_query,
page=1,
page_size=max_results,
include_executions=True,
@@ -309,16 +271,9 @@ async def search_marketplace_agents_for_generation(
Returns:
List of LibraryAgentSummary with full input/output schemas
"""
search_term = search_query.strip()
if len(search_term) > 100:
raise ValueError(
f"Search query is too long ({len(search_term)} chars, max 100). "
f"Please use a shorter, more specific search term."
)
try:
response = await store_db().get_store_agents(
search_query=search_term,
search_query=search_query,
page=1,
page_size=max_results,
)
@@ -469,7 +424,7 @@ def extract_search_terms_from_steps(
async def enrich_library_agents_from_steps(
user_id: str,
decomposition_result: DecompositionResult | dict[str, Any],
existing_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]],
existing_agents: list[AgentSummary] | list[dict[str, Any]],
exclude_graph_id: str | None = None,
include_marketplace: bool = True,
max_additional_results: int = 10,
@@ -493,7 +448,7 @@ async def enrich_library_agents_from_steps(
search_terms = extract_search_terms_from_steps(decomposition_result)
if not search_terms:
return list(existing_agents)
return existing_agents
existing_ids: set[str] = set()
existing_names: set[str] = set()
@@ -556,7 +511,7 @@ async def enrich_library_agents_from_steps(
async def decompose_goal(
description: str,
context: str = "",
library_agents: Sequence[AgentSummary] | None = None,
library_agents: list[AgentSummary] | None = None,
) -> DecompositionResult | None:
"""Break down a goal into steps or return clarifying questions.
@@ -584,16 +539,22 @@ async def decompose_goal(
async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None = None,
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams
completion notification)
task_id: Task ID for async processing (enables Redis Streams persistence
and SSE delivery)
Returns:
Agent JSON dict, error dict {"type": "error", ...}, or None on error
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -601,9 +562,13 @@ async def generate_agent(
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents)
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
)
# Don't modify async response
if result and result.get("status") == "accepted":
return result
if result:
if isinstance(result, dict) and result.get("type") == "error":
return result
@@ -793,7 +758,9 @@ async def get_agent_as_json(
async def generate_agent_patch(
update_request: str,
current_agent: dict[str, Any],
library_agents: Sequence[AgentSummary] | None = None,
library_agents: list[AgentSummary] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
@@ -806,10 +773,12 @@ async def generate_agent_patch(
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on error
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -820,6 +789,8 @@ async def generate_agent_patch(
update_request,
current_agent,
_to_dict_list(library_agents),
operation_id,
task_id,
)

View File

@@ -102,15 +102,10 @@ async def generate_agent_dummy(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy agent synchronously (blocks for 30s, returns agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator (sync mode): returning agent JSON after 30s"
)
"""Return dummy agent JSON after a simulated delay."""
logger.info("Using dummy agent generator for generate_agent (30s delay)")
await asyncio.sleep(30)
return _generate_dummy_agent_json()
@@ -120,16 +115,10 @@ async def generate_agent_patch_dummy(
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy patched agent synchronously (blocks for 30s, returns patched agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator patch (sync mode): returning patched agent after 30s"
)
await asyncio.sleep(30)
"""Return dummy patched agent (returns the current agent with updated description)."""
logger.info("Using dummy agent generator for generate_agent_patch")
patched = current_agent.copy()
patched["description"] = (
f"{current_agent.get('description', '')} (updated: {update_request})"

View File

@@ -242,18 +242,24 @@ async def decompose_goal_external(
async def generate_agent_external(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Agent JSON dict or error dict {"type": "error", ...} on error
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
"""
if _is_dummy_mode():
return await generate_agent_dummy(instructions, library_agents)
return await generate_agent_dummy(
instructions, library_agents, operation_id, task_id
)
client = _get_client()
@@ -261,9 +267,25 @@ async def generate_agent_external(
payload: dict[str, Any] = {"instructions": instructions}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/generate-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
@@ -295,6 +317,8 @@ async def generate_agent_patch_external(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
@@ -303,14 +327,14 @@ async def generate_agent_patch_external(
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
session_id: Session ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
"""
if _is_dummy_mode():
return await generate_agent_patch_dummy(
update_request, current_agent, library_agents
update_request, current_agent, library_agents, operation_id, task_id
)
client = _get_client()
@@ -322,9 +346,25 @@ async def generate_agent_patch_external(
}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/update-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async update request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
@@ -379,8 +419,6 @@ async def customize_template_external(
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
operation_id: Operation ID for async processing (enables Redis Streams callback)
session_id: Session ID for async processing (enables Redis Streams callback)
Returns:
Customized agent JSON, clarifying questions dict, or error dict on error

View File

@@ -36,6 +36,16 @@ class BaseTool:
"""Whether this tool requires authentication."""
return False
@property
def is_long_running(self) -> bool:
"""Whether this tool is long-running and should execute in background.
Long-running tools (like agent generation) are executed via background
tasks to survive SSE disconnections. The result is persisted to chat
history and visible when the user refreshes.
"""
return False
def as_openai_tool(self) -> ChatCompletionToolParam:
"""Convert to OpenAI tool format."""
return ChatCompletionToolParam(

View File

@@ -0,0 +1,124 @@
"""CheckOperationStatusTool — query the status of a long-running operation."""
import logging
from typing import Any
from backend.copilot.model import ChatSession
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
class OperationStatusResponse(ToolResponseBase):
"""Response for check_operation_status tool."""
type: ResponseType = ResponseType.OPERATION_STATUS
task_id: str
operation_id: str
status: str # "running", "completed", "failed"
tool_name: str | None = None
message: str = ""
class CheckOperationStatusTool(BaseTool):
"""Check the status of a long-running operation (create_agent, edit_agent, etc.).
The CoPilot uses this tool to report back to the user whether an
operation that was started earlier has completed, failed, or is still
running.
"""
@property
def name(self) -> str:
return "check_operation_status"
@property
def description(self) -> str:
return (
"Check the current status of a long-running operation such as "
"create_agent or edit_agent. Accepts either an operation_id or "
"task_id from a previous operation_started response. "
"Returns the current status: running, completed, or failed."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"operation_id": {
"type": "string",
"description": (
"The operation_id from an operation_started response."
),
},
"task_id": {
"type": "string",
"description": (
"The task_id from an operation_started response. "
"Used as fallback if operation_id is not provided."
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
from backend.copilot import stream_registry
operation_id = (kwargs.get("operation_id") or "").strip()
task_id = (kwargs.get("task_id") or "").strip()
if not operation_id and not task_id:
return ErrorResponse(
message="Please provide an operation_id or task_id.",
error="missing_parameter",
)
task = None
if operation_id:
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None and task_id:
task = await stream_registry.get_task(task_id)
if task is None:
# Task not in Redis — it may have already expired (TTL).
# Check conversation history for the result instead.
return ErrorResponse(
message=(
"Operation not found — it may have already completed and "
"expired from the status tracker. Check the conversation "
"history for the result."
),
error="not_found",
)
status_messages = {
"running": (
f"The {task.tool_name or 'operation'} is still running. "
"Please wait for it to complete."
),
"completed": (
f"The {task.tool_name or 'operation'} has completed successfully."
),
"failed": f"The {task.tool_name or 'operation'} has failed.",
}
return OperationStatusResponse(
task_id=task.task_id,
operation_id=task.operation_id,
status=task.status,
tool_name=task.tool_name,
message=status_messages.get(task.status, f"Status: {task.status}"),
)

View File

@@ -10,6 +10,7 @@ from .agent_generator import (
decompose_goal,
enrich_library_agents_from_steps,
generate_agent,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
@@ -17,6 +18,7 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -38,16 +40,17 @@ class CreateAgentTool(BaseTool):
def description(self) -> str:
return (
"Create a new agent workflow from a natural language description. "
"First generates a preview, then saves to library if save=true. "
"\n\nIMPORTANT: Before calling this tool, search for relevant existing agents "
"using find_library_agent that could be used as building blocks. "
"Pass their IDs in the library_agent_ids parameter so the generator can compose them."
"First generates a preview, then saves to library if save=true."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
@@ -67,15 +70,6 @@ class CreateAgentTool(BaseTool):
"Include any preferences or constraints mentioned by the user."
),
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks. "
"Search for relevant agents using find_library_agent first, "
"then pass their IDs here so they can be composed into the new agent."
),
},
"save": {
"type": "boolean",
"description": (
@@ -103,14 +97,12 @@ class CreateAgentTool(BaseTool):
"""
description = kwargs.get("description", "").strip()
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
session_id = session.session_id if session else None
logger.info(
f"[AGENT_CREATE_DEBUG] START - description_len={len(description)}, "
f"library_agent_ids={library_agent_ids}, save={save}, user_id={user_id}, session_id={session_id}"
)
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not description:
return ErrorResponse(
@@ -119,34 +111,25 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Fetch library agents by IDs if provided
library_agents = None
if user_id and library_agent_ids:
if user_id:
try:
from .agent_generator import get_library_agents_by_ids
library_agents = await get_library_agents_by_ids(
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
agent_ids=library_agent_ids,
search_query=description,
include_marketplace=True,
)
logger.debug(
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
f"Found {len(library_agents)} relevant agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents by IDs: {e}")
logger.warning(f"Failed to fetch library agents: {e}")
try:
decomposition_result = await decompose_goal(
description, context, library_agents
)
logger.info(
f"[AGENT_CREATE_DEBUG] DECOMPOSE - type={decomposition_result.get('type') if decomposition_result else None}, "
f"session_id={session_id}"
)
except AgentGeneratorNotConfiguredError:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured, session_id={session_id}"
)
return ErrorResponse(
message=(
"Agent generation is not available. "
@@ -247,17 +230,10 @@ class CreateAgentTool(BaseTool):
agent_json = await generate_agent(
decomposition_result,
library_agents,
)
logger.info(
f"[AGENT_CREATE_DEBUG] GENERATE - "
f"success={agent_json is not None}, "
f"is_error={isinstance(agent_json, dict) and agent_json.get('type') == 'error'}, "
f"session_id={session_id}"
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured during generation, session_id={session_id}"
)
return ErrorResponse(
message=(
"Agent generation is not available. "
@@ -300,20 +276,25 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if agent_json.get("status") == "accepted":
logger.info(
f"Agent generation delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent generation started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
logger.info(
f"[AGENT_CREATE_DEBUG] AGENT_JSON - name={agent_name}, "
f"nodes={node_count}, links={link_count}, save={save}, session_id={session_id}"
)
if not save:
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - AgentPreviewResponse, session_id={session_id}"
)
return AgentPreviewResponse(
message=(
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
@@ -339,13 +320,6 @@ class CreateAgentTool(BaseTool):
agent_json, user_id
)
logger.info(
f"[AGENT_CREATE_DEBUG] SAVED - graph_id={created_graph.id}, "
f"library_agent_id={library_agent.id}, session_id={session_id}"
)
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - AgentSavedResponse, session_id={session_id}"
)
return AgentSavedResponse(
message=f"Agent '{created_graph.name}' has been saved to your library!",
agent_id=created_graph.id,
@@ -356,12 +330,6 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
except Exception as e:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - save_failed: {str(e)}, session_id={session_id}"
)
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - ErrorResponse (save_failed), session_id={session_id}"
)
return ErrorResponse(
message=f"Failed to save the agent: {str(e)}",
error="save_failed",

View File

@@ -43,6 +43,11 @@ async def test_vague_goal_returns_suggested_goal_response(tool, session):
}
with (
patch(
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=[],
),
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
@@ -73,6 +78,11 @@ async def test_unachievable_goal_returns_suggested_goal_response(tool, session):
}
with (
patch(
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=[],
),
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
@@ -110,6 +120,11 @@ async def test_clarifying_questions_returns_clarification_needed_response(
}
with (
patch(
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=[],
),
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,

View File

@@ -46,6 +46,10 @@ class CustomizeAgentTool(BaseTool):
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {

View File

@@ -9,6 +9,7 @@ from .agent_generator import (
AgentGeneratorNotConfiguredError,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
@@ -16,6 +17,7 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -36,16 +38,17 @@ class EditAgentTool(BaseTool):
def description(self) -> str:
return (
"Edit an existing agent from the user's library using natural language. "
"Generates updates to the agent while preserving unchanged parts. "
"\n\nIMPORTANT: Before calling this tool, if the changes involve adding new "
"functionality, search for relevant existing agents using find_library_agent "
"that could be used as building blocks. Pass their IDs in library_agent_ids."
"Generates updates to the agent while preserving unchanged parts."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
@@ -71,15 +74,6 @@ class EditAgentTool(BaseTool):
"Additional context or answers to previous clarifying questions."
),
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks for the changes. "
"If adding new functionality, search for relevant agents using "
"find_library_agent first, then pass their IDs here."
),
},
"save": {
"type": "boolean",
"description": (
@@ -108,10 +102,13 @@ class EditAgentTool(BaseTool):
agent_id = kwargs.get("agent_id", "").strip()
changes = kwargs.get("changes", "").strip()
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
@@ -135,25 +132,21 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Fetch library agents by IDs if provided
library_agents = None
if user_id and library_agent_ids:
if user_id:
try:
from .agent_generator import get_library_agents_by_ids
graph_id = current_agent.get("id")
# Filter out the current agent being edited
filtered_ids = [id for id in library_agent_ids if id != graph_id]
library_agents = await get_library_agents_by_ids(
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
agent_ids=filtered_ids,
search_query=changes,
exclude_graph_id=graph_id,
include_marketplace=True,
)
logger.debug(
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
f"Found {len(library_agents)} relevant agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents by IDs: {e}")
logger.warning(f"Failed to fetch library agents: {e}")
update_request = changes
if context:
@@ -164,6 +157,8 @@ class EditAgentTool(BaseTool):
update_request,
current_agent,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
@@ -183,6 +178,19 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if result.get("status") == "accepted":
logger.info(
f"Agent edit delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent edit started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")

View File

@@ -366,15 +366,12 @@ class TestFindBlockFiltering:
return_value=(search_results, len(search_results))
)
with (
patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
),
patch(
"backend.copilot.tools.find_block.get_block",
side_effect=lambda bid: mock_blocks.get(bid),
),
with patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
), patch(
"backend.copilot.tools.find_block.get_block",
side_effect=lambda bid: mock_blocks.get(bid),
):
tool = FindBlockTool()
response = await tool._execute(

View File

@@ -36,6 +36,8 @@ class ResponseType(str, Enum):
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
# Long-running operation types
OPERATION_STARTED = "operation_started"
OPERATION_PENDING = "operation_pending"
OPERATION_IN_PROGRESS = "operation_in_progress"
# Input validation
INPUT_VALIDATION_ERROR = "input_validation_error"
@@ -43,6 +45,8 @@ class ResponseType(str, Enum):
WEB_FETCH = "web_fetch"
# Code execution
BASH_EXEC = "bash_exec"
# Operation status check
OPERATION_STATUS = "operation_status"
# Feature request types
FEATURE_REQUEST_SEARCH = "feature_request_search"
FEATURE_REQUEST_CREATED = "feature_request_created"
@@ -416,6 +420,34 @@ class BlockOutputResponse(ToolResponseBase):
# Long-running operation models
class OperationStartedResponse(ToolResponseBase):
"""Response when a long-running operation has been started in the background.
This is returned immediately to the client while the operation continues
to execute. The user can close the tab and check back later.
The task_id can be used to reconnect to the SSE stream via
GET /chat/tasks/{task_id}/stream?last_idx=0
"""
type: ResponseType = ResponseType.OPERATION_STARTED
operation_id: str
tool_name: str
task_id: str | None = None # For SSE reconnection
class OperationPendingResponse(ToolResponseBase):
"""Response stored in chat history while a long-running operation is executing.
This is persisted to the database so users see a pending state when they
refresh before the operation completes.
"""
type: ResponseType = ResponseType.OPERATION_PENDING
operation_id: str
tool_name: str
class OperationInProgressResponse(ToolResponseBase):
"""Response when an operation is already in progress.
@@ -427,6 +459,23 @@ class OperationInProgressResponse(ToolResponseBase):
tool_call_id: str
class AsyncProcessingResponse(ToolResponseBase):
"""Response when an operation has been delegated to async processing.
This is returned by tools when the external service accepts the request
for async processing (HTTP 202 Accepted). The Redis Streams completion
consumer will handle the result when the external service completes.
The status field is specifically "accepted" to allow the long-running tool
handler to detect this response and skip LLM continuation.
"""
type: ResponseType = ResponseType.OPERATION_STARTED
status: str = "accepted" # Must be "accepted" for detection
operation_id: str | None = None
task_id: str | None = None
class WebFetchResponse(ToolResponseBase):
"""Response for web_fetch tool."""

View File

@@ -160,10 +160,9 @@ class RunBlockTool(BaseTool):
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
creds_manager = IntegrationCredentialsManager()
(
matched_credentials,
missing_credentials,
) = await self._resolve_block_credentials(user_id, block, input_data)
matched_credentials, missing_credentials = (
await self._resolve_block_credentials(user_id, block, input_data)
)
# Get block schemas for details/validation
try:

View File

@@ -372,7 +372,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The port for the Agent Generator service",
)
agentgenerator_timeout: int = Field(
default=1800,
default=600,
description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)",
)
agentgenerator_use_dummy: bool = Field(

View File

@@ -109,7 +109,7 @@ class TestGenerateAgent:
instructions = {"type": "instructions", "steps": ["Step 1"]}
result = await core.generate_agent(instructions)
mock_external.assert_called_once_with(instructions, None)
mock_external.assert_called_once_with(instructions, None, None, None)
assert result is not None
assert result["name"] == "Test Agent"
assert "id" in result
@@ -173,7 +173,9 @@ class TestGenerateAgentPatch:
current_agent = {"nodes": [], "links": []}
result = await core.generate_agent_patch("Add a node", current_agent)
mock_external.assert_called_once_with("Add a node", current_agent, None)
mock_external.assert_called_once_with(
"Add a node", current_agent, None, None, None
)
assert result == expected_result
@pytest.mark.asyncio

View File

@@ -0,0 +1,349 @@
#!/usr/bin/env python3
"""
Integration test for the requeue fix implementation.
Tests actual RabbitMQ behavior to verify that republishing sends messages to back of queue.
"""
import json
import time
from threading import Event
from typing import List
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.utils import create_execution_queue_config
class QueueOrderTester:
"""Helper class to test message ordering in RabbitMQ using a dedicated test queue."""
def __init__(self):
self.received_messages: List[dict] = []
self.stop_consuming = Event()
self.queue_client = SyncRabbitMQ(create_execution_queue_config())
self.queue_client.connect()
# Use a dedicated test queue name to avoid conflicts
self.test_queue_name = "test_requeue_ordering"
self.test_exchange = "test_exchange"
self.test_routing_key = "test.requeue"
def setup_queue(self):
"""Set up a dedicated test queue for testing."""
channel = self.queue_client.get_channel()
# Declare test exchange
channel.exchange_declare(
exchange=self.test_exchange, exchange_type="direct", durable=True
)
# Declare test queue
channel.queue_declare(
queue=self.test_queue_name, durable=True, auto_delete=False
)
# Bind queue to exchange
channel.queue_bind(
exchange=self.test_exchange,
queue=self.test_queue_name,
routing_key=self.test_routing_key,
)
# Purge the queue to start fresh
channel.queue_purge(self.test_queue_name)
print(f"✅ Test queue {self.test_queue_name} setup and purged")
def create_test_message(self, message_id: str, user_id: str = "test-user") -> str:
"""Create a test graph execution message."""
return json.dumps(
{
"graph_exec_id": f"exec-{message_id}",
"graph_id": f"graph-{message_id}",
"user_id": user_id,
"execution_context": {"timezone": "UTC"},
"nodes_input_masks": {},
"starting_nodes_input": [],
}
)
def publish_message(self, message: str):
"""Publish a message to the test queue."""
channel = self.queue_client.get_channel()
channel.basic_publish(
exchange=self.test_exchange,
routing_key=self.test_routing_key,
body=message,
)
def consume_messages(self, max_messages: int = 10, timeout: float = 5.0):
"""Consume messages and track their order."""
def callback(ch, method, properties, body):
try:
message_data = json.loads(body.decode())
self.received_messages.append(message_data)
ch.basic_ack(delivery_tag=method.delivery_tag)
if len(self.received_messages) >= max_messages:
self.stop_consuming.set()
except Exception as e:
print(f"Error processing message: {e}")
ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
# Use synchronous consumption with blocking
channel = self.queue_client.get_channel()
# Check if there are messages in the queue first
method_frame, header_frame, body = channel.basic_get(
queue=self.test_queue_name, auto_ack=False
)
if method_frame:
# There are messages, set up consumer
channel.basic_nack(
delivery_tag=method_frame.delivery_tag, requeue=True
) # Put message back
# Set up consumer
channel.basic_consume(
queue=self.test_queue_name,
on_message_callback=callback,
)
# Consume with timeout
start_time = time.time()
while (
not self.stop_consuming.is_set()
and (time.time() - start_time) < timeout
and len(self.received_messages) < max_messages
):
try:
channel.connection.process_data_events(time_limit=0.1)
except Exception as e:
print(f"Error during consumption: {e}")
break
# Cancel the consumer
try:
channel.cancel()
except Exception:
pass
else:
# No messages in queue - this might be expected for some tests
pass
return self.received_messages
def cleanup(self):
"""Clean up test resources."""
try:
channel = self.queue_client.get_channel()
channel.queue_delete(queue=self.test_queue_name)
channel.exchange_delete(exchange=self.test_exchange)
print(f"✅ Test queue {self.test_queue_name} cleaned up")
except Exception as e:
print(f"⚠️ Cleanup issue: {e}")
def test_queue_ordering_behavior():
"""
Integration test to verify that our republishing method sends messages to back of queue.
This tests the actual fix for the rate limiting queue blocking issue.
"""
tester = QueueOrderTester()
try:
tester.setup_queue()
print("🧪 Testing actual RabbitMQ queue ordering behavior...")
# Test 1: Normal FIFO behavior
print("1. Testing normal FIFO queue behavior")
# Publish messages in order: A, B, C
msg_a = tester.create_test_message("A")
msg_b = tester.create_test_message("B")
msg_c = tester.create_test_message("C")
tester.publish_message(msg_a)
tester.publish_message(msg_b)
tester.publish_message(msg_c)
# Consume and verify FIFO order: A, B, C
tester.received_messages = []
tester.stop_consuming.clear()
messages = tester.consume_messages(max_messages=3)
assert len(messages) == 3, f"Expected 3 messages, got {len(messages)}"
assert (
messages[0]["graph_exec_id"] == "exec-A"
), f"First message should be A, got {messages[0]['graph_exec_id']}"
assert (
messages[1]["graph_exec_id"] == "exec-B"
), f"Second message should be B, got {messages[1]['graph_exec_id']}"
assert (
messages[2]["graph_exec_id"] == "exec-C"
), f"Third message should be C, got {messages[2]['graph_exec_id']}"
print("✅ FIFO order confirmed: A -> B -> C")
# Test 2: Rate limiting simulation - the key test!
print("2. Testing rate limiting fix scenario")
# Simulate the scenario where user1 is rate limited
user1_msg = tester.create_test_message("RATE-LIMITED", "user1")
user2_msg1 = tester.create_test_message("USER2-1", "user2")
user2_msg2 = tester.create_test_message("USER2-2", "user2")
# Initially publish user1 message (gets consumed, then rate limited on retry)
tester.publish_message(user1_msg)
# Other users publish their messages
tester.publish_message(user2_msg1)
tester.publish_message(user2_msg2)
# Now simulate: user1 message gets "requeued" using our new republishing method
# This is what happens in manager.py when requeue_by_republishing=True
tester.publish_message(user1_msg) # Goes to back via our method
# Expected order: RATE-LIMITED, USER2-1, USER2-2, RATE-LIMITED (republished to back)
# This shows that user2 messages get processed instead of being blocked
tester.received_messages = []
tester.stop_consuming.clear()
messages = tester.consume_messages(max_messages=4)
assert len(messages) == 4, f"Expected 4 messages, got {len(messages)}"
# The key verification: user2 messages are NOT blocked by user1's rate-limited message
user2_messages = [msg for msg in messages if msg["user_id"] == "user2"]
assert len(user2_messages) == 2, "Both user2 messages should be processed"
assert user2_messages[0]["graph_exec_id"] == "exec-USER2-1"
assert user2_messages[1]["graph_exec_id"] == "exec-USER2-2"
print("✅ Rate limiting fix confirmed: user2 executions NOT blocked by user1")
# Test 3: Verify our method behaves like going to back of queue
print("3. Testing republishing sends messages to back")
# Start with message X in queue
msg_x = tester.create_test_message("X")
tester.publish_message(msg_x)
# Add message Y
msg_y = tester.create_test_message("Y")
tester.publish_message(msg_y)
# Republish X (simulates requeue using our method)
tester.publish_message(msg_x)
# Expected: X, Y, X (X was republished to back)
tester.received_messages = []
tester.stop_consuming.clear()
messages = tester.consume_messages(max_messages=3)
assert len(messages) == 3
# Y should come before the republished X
y_index = next(
i for i, msg in enumerate(messages) if msg["graph_exec_id"] == "exec-Y"
)
republished_x_index = next(
i
for i, msg in enumerate(messages[1:], 1)
if msg["graph_exec_id"] == "exec-X"
)
assert (
y_index < republished_x_index
), f"Y should come before republished X, but got order: {[m['graph_exec_id'] for m in messages]}"
print("✅ Republishing confirmed: messages go to back of queue")
print("🎉 All integration tests passed!")
print("🎉 Our republishing method works correctly with real RabbitMQ")
print("🎉 Queue blocking issue is fixed!")
finally:
tester.cleanup()
def test_traditional_requeue_behavior():
"""
Test that traditional requeue (basic_nack with requeue=True) sends messages to FRONT of queue.
This validates our hypothesis about why queue blocking occurs.
"""
tester = QueueOrderTester()
try:
tester.setup_queue()
print("🧪 Testing traditional requeue behavior (basic_nack with requeue=True)")
# Step 1: Publish message A
msg_a = tester.create_test_message("A")
tester.publish_message(msg_a)
# Step 2: Publish message B
msg_b = tester.create_test_message("B")
tester.publish_message(msg_b)
# Step 3: Consume message A and requeue it using traditional method
channel = tester.queue_client.get_channel()
method_frame, header_frame, body = channel.basic_get(
queue=tester.test_queue_name, auto_ack=False
)
assert method_frame is not None, "Should have received message A"
consumed_msg = json.loads(body.decode())
assert (
consumed_msg["graph_exec_id"] == "exec-A"
), f"Should have consumed message A, got {consumed_msg['graph_exec_id']}"
# Traditional requeue: basic_nack with requeue=True (sends to FRONT)
channel.basic_nack(delivery_tag=method_frame.delivery_tag, requeue=True)
print(f"🔄 Traditional requeue (to FRONT): {consumed_msg['graph_exec_id']}")
# Step 4: Consume all messages using basic_get for reliability
received_messages = []
# Get first message
method_frame, header_frame, body = channel.basic_get(
queue=tester.test_queue_name, auto_ack=True
)
if method_frame:
msg = json.loads(body.decode())
received_messages.append(msg)
# Get second message
method_frame, header_frame, body = channel.basic_get(
queue=tester.test_queue_name, auto_ack=True
)
if method_frame:
msg = json.loads(body.decode())
received_messages.append(msg)
# CRITICAL ASSERTION: Traditional requeue should put A at FRONT
# Expected order: A (requeued to front), B
assert (
len(received_messages) == 2
), f"Expected 2 messages, got {len(received_messages)}"
first_msg = received_messages[0]["graph_exec_id"]
second_msg = received_messages[1]["graph_exec_id"]
# This is the critical test: requeued message A should come BEFORE B
assert (
first_msg == "exec-A"
), f"Traditional requeue should put A at FRONT, but first message was: {first_msg}"
assert (
second_msg == "exec-B"
), f"B should come after requeued A, but second message was: {second_msg}"
print(
"✅ HYPOTHESIS CONFIRMED: Traditional requeue sends messages to FRONT of queue"
)
print(f" Order: {first_msg} (requeued to front) → {second_msg}")
print(" This explains why rate-limited messages block other users!")
finally:
tester.cleanup()
if __name__ == "__main__":
test_queue_ordering_behavior()

View File

@@ -27,7 +27,6 @@ export function CopilotPage() {
createSession,
onSend,
isLoadingSession,
isSessionError,
isCreatingSession,
isUserLoading,
isLoggedIn,
@@ -72,7 +71,6 @@ export function CopilotPage() {
error={error}
sessionId={sessionId}
isLoadingSession={isLoadingSession}
isSessionError={isSessionError}
isCreatingSession={isCreatingSession}
isReconnecting={isReconnecting}
onCreateSession={createSession}

View File

@@ -13,7 +13,6 @@ export interface ChatContainerProps {
error: Error | undefined;
sessionId: string | null;
isLoadingSession: boolean;
isSessionError?: boolean;
isCreatingSession: boolean;
/** True when backend has an active stream but we haven't reconnected yet. */
isReconnecting?: boolean;
@@ -28,7 +27,6 @@ export const ChatContainer = ({
error,
sessionId,
isLoadingSession,
isSessionError,
isCreatingSession,
isReconnecting,
onCreateSession,
@@ -36,12 +34,7 @@ export const ChatContainer = ({
onStop,
headerSlot,
}: ChatContainerProps) => {
const isBusy =
status === "streaming" ||
status === "submitted" ||
!!isReconnecting ||
isLoadingSession ||
!!isSessionError;
const isBusy = status === "streaming" || !!isReconnecting;
const inputLayoutId = "copilot-2-chat-input";
return (

View File

@@ -10,8 +10,9 @@ import {
MessageResponse,
} from "@/components/ai-elements/message";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { toast } from "@/components/molecules/Toast/use-toast";
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { useEffect, useState } from "react";
import { useEffect, useRef, useState } from "react";
import { CreateAgentTool } from "../../tools/CreateAgent/CreateAgent";
import { EditAgentTool } from "../../tools/EditAgent/EditAgent";
import {
@@ -128,6 +129,7 @@ export const ChatMessagesContainer = ({
headerSlot,
}: ChatMessagesContainerProps) => {
const [thinkingPhrase, setThinkingPhrase] = useState(getRandomPhrase);
const lastToastTimeRef = useRef(0);
useEffect(() => {
if (status === "submitted") {
@@ -135,6 +137,20 @@ export const ChatMessagesContainer = ({
}
}, [status]);
// Show a toast when a new error occurs, debounced to avoid spam
useEffect(() => {
if (!error) return;
const now = Date.now();
if (now - lastToastTimeRef.current < 3_000) return;
lastToastTimeRef.current = now;
toast({
variant: "destructive",
title: "Something went wrong",
description:
"The assistant encountered an error. Please try sending your message again.",
});
}, [error]);
const lastMessage = messages[messages.length - 1];
const lastAssistantHasVisibleContent =
lastMessage?.role === "assistant" &&
@@ -298,15 +314,13 @@ export const ChatMessagesContainer = ({
</Message>
)}
{error && (
<details className="rounded-lg bg-red-50 p-4 text-sm text-red-700">
<summary className="cursor-pointer font-medium">
<div className="rounded-lg bg-red-50 p-4 text-sm text-red-700">
<p className="font-medium">Something went wrong</p>
<p className="mt-1 text-red-600">
The assistant encountered an error. Please try sending your
message again.
</summary>
<pre className="mt-2 max-h-40 overflow-auto whitespace-pre-wrap break-words text-xs text-red-600">
{error instanceof Error ? error.message : String(error)}
</pre>
</details>
</p>
</div>
)}
</ConversationContent>
<ConversationScrollButton />

View File

@@ -116,10 +116,12 @@ export function convertChatSessionMessagesToUiMessages(
output: "",
});
} else {
// Active stream exists: Skip incomplete tool calls during hydration.
// The resume stream will deliver them fresh with proper SDK state.
// This prevents "No tool invocation found" errors on page refresh.
continue;
parts.push({
type: `tool-${toolName}`,
toolCallId,
state: "input-available",
input,
});
}
}
}

View File

@@ -0,0 +1,47 @@
import { useEffect, useRef, useState } from "react";
/**
* Hook that returns a progress value that starts fast and slows down,
* asymptotically approaching but never reaching the max value.
*
* Uses a half-life formula: progress = max * (1 - 0.5^(time/halfLife))
* This creates a "loading bar" effect where:
* - 50% is reached at halfLifeSeconds
* - 75% is reached at 2 * halfLifeSeconds
* - 87.5% is reached at 3 * halfLifeSeconds
*
* @param isActive - Whether the progress should be animating
* @param halfLifeSeconds - Time in seconds to reach 50% progress (default: 30)
* @param maxProgress - Maximum progress value to approach (default: 100)
* @param intervalMs - Update interval in milliseconds (default: 100)
* @returns Current progress value (0maxProgress)
*/
export function useAsymptoticProgress(
isActive: boolean,
halfLifeSeconds = 30,
maxProgress = 100,
intervalMs = 100,
) {
const [progress, setProgress] = useState(0);
const elapsedTimeRef = useRef(0);
useEffect(() => {
if (!isActive) {
setProgress(0);
elapsedTimeRef.current = 0;
return;
}
const interval = setInterval(() => {
elapsedTimeRef.current += intervalMs / 1000;
const newProgress =
maxProgress *
(1 - Math.pow(0.5, elapsedTimeRef.current / halfLifeSeconds));
setProgress(newProgress);
}, intervalMs);
return () => clearInterval(interval);
}, [isActive, halfLifeSeconds, maxProgress, intervalMs]);
return progress;
}

View File

@@ -0,0 +1,126 @@
import { getGetV2GetSessionQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
import { useQueryClient } from "@tanstack/react-query";
import type { UIDataTypes, UIMessage, UITools } from "ai";
import { useCallback, useEffect, useRef } from "react";
import { convertChatSessionMessagesToUiMessages } from "../helpers/convertChatSessionToUiMessages";
const OPERATING_TYPES = new Set([
"operation_started",
"operation_pending",
"operation_in_progress",
]);
const POLL_INTERVAL_MS = 1_500;
/**
* Detects whether any message contains a tool part whose output indicates
* a long-running operation is still in progress.
*/
function hasOperatingTool(
messages: UIMessage<unknown, UIDataTypes, UITools>[],
) {
for (const msg of messages) {
for (const part of msg.parts) {
if (!part.type.startsWith("tool-")) continue;
const toolPart = part as { output?: unknown };
if (!toolPart.output) continue;
const output =
typeof toolPart.output === "string"
? safeParse(toolPart.output)
: toolPart.output;
if (
output &&
typeof output === "object" &&
"type" in output &&
OPERATING_TYPES.has((output as { type: string }).type)
) {
return true;
}
}
}
return false;
}
function safeParse(value: string): unknown {
try {
return JSON.parse(value);
} catch {
return null;
}
}
/**
* Polls the session endpoint while any tool is in an "operating" state
* (operation_started / operation_pending / operation_in_progress).
*
* When the session data shows the tool output has changed (e.g. to
* agent_saved), it calls `setMessages` with the updated messages.
*/
export function useLongRunningToolPolling(
sessionId: string | null,
messages: UIMessage<unknown, UIDataTypes, UITools>[],
setMessages: (
updater: (
prev: UIMessage<unknown, UIDataTypes, UITools>[],
) => UIMessage<unknown, UIDataTypes, UITools>[],
) => void,
) {
const queryClient = useQueryClient();
const intervalRef = useRef<ReturnType<typeof setInterval> | null>(null);
const stopPolling = useCallback(() => {
if (intervalRef.current) {
clearInterval(intervalRef.current);
intervalRef.current = null;
}
}, []);
const poll = useCallback(async () => {
if (!sessionId) return;
// Invalidate the query cache so the next fetch gets fresh data
await queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
});
// Fetch fresh session data
const data = queryClient.getQueryData<{
status: number;
data: { messages?: unknown[] };
}>(getGetV2GetSessionQueryKey(sessionId));
if (data?.status !== 200 || !data.data.messages) return;
const freshMessages = convertChatSessionMessagesToUiMessages(
sessionId,
data.data.messages,
);
if (!freshMessages || freshMessages.length === 0) return;
// Update when the long-running tool completed
if (!hasOperatingTool(freshMessages)) {
setMessages(() => freshMessages);
stopPolling();
}
}, [sessionId, queryClient, setMessages, stopPolling]);
useEffect(() => {
const shouldPoll = hasOperatingTool(messages);
// Always clear any previous interval first so we never leak timers
// when the effect re-runs due to dependency changes (e.g. messages
// updating as the LLM streams text after the tool call).
stopPolling();
if (shouldPoll && sessionId) {
intervalRef.current = setInterval(() => {
poll();
}, POLL_INTERVAL_MS);
}
return () => {
stopPolling();
};
}, [messages, sessionId, poll, stopPolling]);
}

View File

@@ -1120,6 +1120,56 @@ export default function StyleguidePage() {
/>
</SubSection>
<SubSection label="Output available (operation started)">
<CreateAgentTool
part={{
type: "tool-create_agent",
toolCallId: uid(),
state: "output-available",
output: {
type: ResponseType.operation_started,
operation_id: "op-create-123",
tool_name: "create_agent",
message:
"Agent creation has been started. This may take a moment.",
},
}}
/>
</SubSection>
<SubSection label="Output available (operation pending)">
<CreateAgentTool
part={{
type: "tool-create_agent",
toolCallId: uid(),
state: "output-available",
output: {
type: ResponseType.operation_pending,
operation_id: "op-create-123",
tool_name: "create_agent",
message:
"Agent creation is queued and will begin shortly.",
},
}}
/>
</SubSection>
<SubSection label="Output available (operation in progress)">
<CreateAgentTool
part={{
type: "tool-create_agent",
toolCallId: uid(),
state: "output-available",
output: {
type: ResponseType.operation_in_progress,
tool_call_id: "tc-456",
message:
"An agent creation operation is already in progress. Please wait for it to finish.",
},
}}
/>
</SubSection>
<SubSection label="Output available (agent preview)">
<CreateAgentTool
part={{
@@ -1242,6 +1292,22 @@ export default function StyleguidePage() {
/>
</SubSection>
<SubSection label="Output available (operation started)">
<EditAgentTool
part={{
type: "tool-edit_agent",
toolCallId: uid(),
state: "output-available",
output: {
type: ResponseType.operation_started,
operation_id: "op-edit-456",
tool_name: "edit_agent",
message: "Agent editing has started.",
},
}}
/>
</SubSection>
<SubSection label="Output available (agent preview)">
<EditAgentTool
part={{

View File

@@ -16,7 +16,6 @@ import {
ContentCardDescription,
ContentCodeBlock,
ContentGrid,
ContentHint,
ContentMessage,
} from "../../components/ToolAccordion/AccordionContent";
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
@@ -36,6 +35,9 @@ import {
isAgentSavedOutput,
isClarificationNeededOutput,
isErrorOutput,
isOperationInProgressOutput,
isOperationPendingOutput,
isOperationStartedOutput,
isSuggestedGoalOutput,
ToolIcon,
truncateText,
@@ -54,18 +56,9 @@ interface Props {
part: CreateAgentToolPart;
}
function getAccordionMeta(output: CreateAgentToolOutput | null) {
function getAccordionMeta(output: CreateAgentToolOutput) {
const icon = <AccordionIcon />;
if (!output) {
return {
icon,
title:
"Creating agent, this may take a few minutes. Play while you wait.",
expanded: true,
};
}
if (isAgentSavedOutput(output)) {
return { icon, title: output.agent_name, expanded: true };
}
@@ -92,6 +85,16 @@ function getAccordionMeta(output: CreateAgentToolOutput | null) {
expanded: true,
};
}
if (
isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output)
) {
return {
icon,
title: output.message || "Agent creation started",
};
}
return {
icon: (
<WarningDiamondIcon size={32} weight="light" className="text-red-500" />
@@ -113,11 +116,23 @@ export function CreateAgentTool({ part }: Props) {
const isError =
part.state === "output-error" || (!!output && isErrorOutput(output));
const isOperating = !output;
const isOperating =
!!output &&
(isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output));
// Show accordion for operating state and successful outputs, but not for errors
// (errors are shown inline so they get replaced when retrying)
const hasExpandableContent = !isError;
const hasExpandableContent =
part.state === "output-available" &&
!!output &&
(isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output) ||
isAgentPreviewOutput(output) ||
isAgentSavedOutput(output) ||
isClarificationNeededOutput(output) ||
isSuggestedGoalOutput(output) ||
isErrorOutput(output));
function handleUseSuggestedGoal(goal: string) {
onSend(`Please create an agent with this goal: ${goal}`);
@@ -143,77 +158,33 @@ export function CreateAgentTool({ part }: Props) {
return (
<div className="py-2">
{isOperating && (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
{isStreaming && (
<ToolAccordion
icon={<AccordionIcon />}
title="Creating agent, this may take a few minutes. Play while you wait."
expanded
>
<ContentGrid>
<MiniGame />
</ContentGrid>
</ToolAccordion>
)}
{isError && output && isErrorOutput(output) && (
<div className="space-y-3 rounded-lg border border-red-200 bg-red-50 p-4">
<div className="flex items-start gap-2">
<WarningDiamondIcon
size={20}
weight="regular"
className="mt-0.5 shrink-0 text-red-500"
/>
<div className="flex-1 space-y-2">
<Text variant="body-medium" className="text-red-900">
{output.message ||
"Failed to generate the agent. Please try again."}
</Text>
{output.error && (
<details className="text-xs text-red-700">
<summary className="cursor-pointer font-medium">
Technical details
</summary>
<pre className="mt-2 max-h-40 overflow-auto whitespace-pre-wrap break-words rounded bg-red-100 p-2">
{formatMaybeJson(output.error)}
</pre>
</details>
)}
{output.details && (
<pre className="max-h-40 overflow-auto whitespace-pre-wrap break-words rounded bg-red-100 p-2 text-xs text-red-700">
{formatMaybeJson(output.details)}
</pre>
)}
</div>
</div>
<div className="flex gap-2">
<Button
variant="outline"
size="small"
onClick={() => onSend("Please try creating the agent again.")}
>
Try again
</Button>
<Button
variant="outline"
size="small"
onClick={() => onSend("Can you help me simplify this goal?")}
>
Simplify goal
</Button>
</div>
</div>
)}
{hasExpandableContent && (
{hasExpandableContent && output && (
<ToolAccordion {...getAccordionMeta(output)}>
{isOperating && (
<ContentGrid>
<MiniGame />
<ContentHint>
This could take a few minutes play while you wait!
</ContentHint>
</ContentGrid>
{isOperating && output.message && (
<ContentMessage>{output.message}</ContentMessage>
)}
{output && isAgentSavedOutput(output) && (
{isAgentSavedOutput(output) && (
<div className="rounded-xl border border-border/60 bg-card p-4 shadow-sm">
<div className="flex items-baseline gap-2">
<Image
@@ -259,7 +230,7 @@ export function CreateAgentTool({ part }: Props) {
</div>
)}
{output && isAgentPreviewOutput(output) && (
{isAgentPreviewOutput(output) && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
{output.description?.trim() && (
@@ -273,7 +244,7 @@ export function CreateAgentTool({ part }: Props) {
</ContentGrid>
)}
{output && isClarificationNeededOutput(output) && (
{isClarificationNeededOutput(output) && (
<ClarificationQuestionsCard
questions={(output.questions ?? []).map((q) => {
const item: ClarifyingQuestion = {
@@ -292,7 +263,7 @@ export function CreateAgentTool({ part }: Props) {
/>
)}
{output && isSuggestedGoalOutput(output) && (
{isSuggestedGoalOutput(output) && (
<SuggestedGoalCard
message={output.message}
suggestedGoal={output.suggested_goal}
@@ -301,6 +272,38 @@ export function CreateAgentTool({ part }: Props) {
onUseSuggestedGoal={handleUseSuggestedGoal}
/>
)}
{isErrorOutput(output) && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
{output.error && (
<ContentCodeBlock>
{formatMaybeJson(output.error)}
</ContentCodeBlock>
)}
{output.details && (
<ContentCodeBlock>
{formatMaybeJson(output.details)}
</ContentCodeBlock>
)}
<div className="flex gap-2">
<Button
variant="outline"
size="small"
onClick={() => onSend("Please try creating the agent again.")}
>
Try again
</Button>
<Button
variant="outline"
size="small"
onClick={() => onSend("Can you help me simplify this goal?")}
>
Simplify goal
</Button>
</div>
</ContentGrid>
)}
</ToolAccordion>
)}
</div>

View File

@@ -2,6 +2,9 @@ import type { AgentPreviewResponse } from "@/app/api/__generated__/models/agentP
import type { AgentSavedResponse } from "@/app/api/__generated__/models/agentSavedResponse";
import type { ClarificationNeededResponse } from "@/app/api/__generated__/models/clarificationNeededResponse";
import type { ErrorResponse } from "@/app/api/__generated__/models/errorResponse";
import type { OperationInProgressResponse } from "@/app/api/__generated__/models/operationInProgressResponse";
import type { OperationPendingResponse } from "@/app/api/__generated__/models/operationPendingResponse";
import type { OperationStartedResponse } from "@/app/api/__generated__/models/operationStartedResponse";
import { ResponseType } from "@/app/api/__generated__/models/responseType";
import type { SuggestedGoalResponse } from "@/app/api/__generated__/models/suggestedGoalResponse";
import {
@@ -13,6 +16,9 @@ import type { ToolUIPart } from "ai";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
export type CreateAgentToolOutput =
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
| AgentPreviewResponse
| AgentSavedResponse
| ClarificationNeededResponse
@@ -33,6 +39,9 @@ function parseOutput(output: unknown): CreateAgentToolOutput | null {
if (typeof output === "object") {
const type = (output as { type?: unknown }).type;
if (
type === ResponseType.operation_started ||
type === ResponseType.operation_pending ||
type === ResponseType.operation_in_progress ||
type === ResponseType.agent_preview ||
type === ResponseType.agent_saved ||
type === ResponseType.clarification_needed ||
@@ -41,6 +50,9 @@ function parseOutput(output: unknown): CreateAgentToolOutput | null {
) {
return output as CreateAgentToolOutput;
}
if ("operation_id" in output && "tool_name" in output)
return output as OperationStartedResponse | OperationPendingResponse;
if ("tool_call_id" in output) return output as OperationInProgressResponse;
if ("agent_json" in output && "agent_name" in output)
return output as AgentPreviewResponse;
if ("agent_id" in output && "library_agent_id" in output)
@@ -60,6 +72,30 @@ export function getCreateAgentToolOutput(
return parseOutput((part as { output?: unknown }).output);
}
export function isOperationStartedOutput(
output: CreateAgentToolOutput,
): output is OperationStartedResponse {
return (
output.type === ResponseType.operation_started ||
("operation_id" in output && "tool_name" in output)
);
}
export function isOperationPendingOutput(
output: CreateAgentToolOutput,
): output is OperationPendingResponse {
return output.type === ResponseType.operation_pending;
}
export function isOperationInProgressOutput(
output: CreateAgentToolOutput,
): output is OperationInProgressResponse {
return (
output.type === ResponseType.operation_in_progress ||
"tool_call_id" in output
);
}
export function isAgentPreviewOutput(
output: CreateAgentToolOutput,
): output is AgentPreviewResponse {
@@ -108,6 +144,10 @@ export function getAnimationText(part: {
case "output-available": {
const output = parseOutput(part.output);
if (!output) return "Creating a new agent";
if (isOperationStartedOutput(output)) return "Agent creation started";
if (isOperationPendingOutput(output)) return "Agent creation in progress";
if (isOperationInProgressOutput(output))
return "Agent creation already in progress";
if (isAgentSavedOutput(output)) return `Saved ${output.agent_name}`;
if (isAgentPreviewOutput(output)) return `Preview "${output.agent_name}"`;
if (isClarificationNeededOutput(output)) return "Needs clarification";

View File

@@ -1,27 +1,18 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import {
BookOpenIcon,
PencilSimpleIcon,
WarningDiamondIcon,
} from "@phosphor-icons/react";
import { WarningDiamondIcon } from "@phosphor-icons/react";
import type { ToolUIPart } from "ai";
import Image from "next/image";
import NextLink from "next/link";
import { useCopilotChatActions } from "../../components/CopilotChatActionsProvider/useCopilotChatActions";
import sparklesImg from "../../components/MiniGame/assets/sparkles.png";
import { MiniGame } from "../../components/MiniGame/MiniGame";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import {
ContentCardDescription,
ContentCodeBlock,
ContentGrid,
ContentHint,
ContentLink,
ContentMessage,
} from "../../components/ToolAccordion/AccordionContent";
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
import { MiniGame } from "../../components/MiniGame/MiniGame";
import {
ClarificationQuestionsCard,
ClarifyingQuestion,
@@ -35,6 +26,9 @@ import {
isAgentSavedOutput,
isClarificationNeededOutput,
isErrorOutput,
isOperationInProgressOutput,
isOperationPendingOutput,
isOperationStartedOutput,
ToolIcon,
truncateText,
type EditAgentToolOutput,
@@ -52,7 +46,7 @@ interface Props {
part: EditAgentToolPart;
}
function getAccordionMeta(output: EditAgentToolOutput | null): {
function getAccordionMeta(output: EditAgentToolOutput): {
icon: React.ReactNode;
title: string;
titleClassName?: string;
@@ -61,16 +55,8 @@ function getAccordionMeta(output: EditAgentToolOutput | null): {
} {
const icon = <AccordionIcon />;
if (!output) {
return {
icon,
title: "Editing agent, this may take a few minutes. Play while you wait.",
expanded: true,
};
}
if (isAgentSavedOutput(output)) {
return { icon, title: output.agent_name, expanded: true };
return { icon, title: output.agent_name };
}
if (isAgentPreviewOutput(output)) {
return {
@@ -87,6 +73,16 @@ function getAccordionMeta(output: EditAgentToolOutput | null): {
description: `${questions.length} question${questions.length === 1 ? "" : "s"}`,
};
}
if (
isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output)
) {
return {
icon,
title: output.message || "Agent editing started",
};
}
return {
icon: (
<WarningDiamondIcon size={32} weight="light" className="text-red-500" />
@@ -105,12 +101,21 @@ export function EditAgentTool({ part }: Props) {
const output = getEditAgentToolOutput(part);
const isError =
part.state === "output-error" || (!!output && isErrorOutput(output));
const isOperating = !output;
// Show accordion for operating state and successful outputs, but not for errors
// (errors are shown inline so they get replaced when retrying)
const hasExpandableContent = !isError;
const isOperating =
!!output &&
(isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output));
const hasExpandableContent =
part.state === "output-available" &&
!!output &&
(isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output) ||
isAgentPreviewOutput(output) ||
isAgentSavedOutput(output) ||
isClarificationNeededOutput(output) ||
isErrorOutput(output));
function handleClarificationAnswers(answers: Record<string, string>) {
const questions =
@@ -132,114 +137,53 @@ export function EditAgentTool({ part }: Props) {
return (
<div className="py-2">
{isOperating && (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
{isStreaming && (
<ToolAccordion
icon={<AccordionIcon />}
title="Editing agent, this may take a few minutes. Play while you wait."
expanded
>
<ContentGrid>
<MiniGame />
</ContentGrid>
</ToolAccordion>
)}
{isError && output && isErrorOutput(output) && (
<div className="space-y-3 rounded-lg border border-red-200 bg-red-50 p-4">
<div className="flex items-start gap-2">
<WarningDiamondIcon
size={20}
weight="regular"
className="mt-0.5 shrink-0 text-red-500"
/>
<div className="flex-1 space-y-2">
<Text variant="body-medium" className="text-red-900">
{output.message ||
"Failed to edit the agent. Please try again."}
</Text>
{output.error && (
<details className="text-xs text-red-700">
<summary className="cursor-pointer font-medium">
Technical details
</summary>
<pre className="mt-2 max-h-40 overflow-auto whitespace-pre-wrap break-words rounded bg-red-100 p-2">
{formatMaybeJson(output.error)}
</pre>
</details>
)}
{output.details && (
<pre className="max-h-40 overflow-auto whitespace-pre-wrap break-words rounded bg-red-100 p-2 text-xs text-red-700">
{formatMaybeJson(output.details)}
</pre>
)}
</div>
</div>
<Button
variant="outline"
size="small"
onClick={() => onSend("Please try editing the agent again.")}
>
Try again
</Button>
</div>
)}
{hasExpandableContent && (
{hasExpandableContent && output && (
<ToolAccordion {...getAccordionMeta(output)}>
{isOperating && (
{isOperating && output.message && (
<ContentMessage>{output.message}</ContentMessage>
)}
{isAgentSavedOutput(output) && (
<ContentGrid>
<MiniGame />
<ContentHint>
This could take a few minutes play while you wait!
</ContentHint>
<ContentMessage>{output.message}</ContentMessage>
<div className="flex flex-wrap gap-2">
<ContentLink href={output.library_agent_link}>
Open in library
</ContentLink>
<ContentLink href={output.agent_page_link}>
Open in builder
</ContentLink>
</div>
<ContentCodeBlock>
{truncateText(
formatMaybeJson({ agent_id: output.agent_id }),
800,
)}
</ContentCodeBlock>
</ContentGrid>
)}
{output && isAgentSavedOutput(output) && (
<div className="rounded-xl border border-border/60 bg-card p-4 shadow-sm">
<div className="flex items-baseline gap-2">
<Image
src={sparklesImg}
alt="sparkles"
width={24}
height={24}
className="relative top-1"
/>
<Text
variant="body-medium"
className="mb-2 text-[16px] text-black"
>
Agent{" "}
<span className="text-violet-600">{output.agent_name}</span>{" "}
has been updated!
</Text>
</div>
<div className="mt-3 flex flex-wrap gap-4">
<Button variant="outline" size="small">
<NextLink
href={output.library_agent_link}
className="inline-flex items-center gap-1.5"
target="_blank"
rel="noopener noreferrer"
>
<BookOpenIcon size={14} weight="regular" />
Open in library
</NextLink>
</Button>
<Button variant="outline" size="small">
<NextLink
href={output.agent_page_link}
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center gap-1.5"
>
<PencilSimpleIcon size={14} weight="regular" />
Open in builder
</NextLink>
</Button>
</div>
</div>
)}
{output && isAgentPreviewOutput(output) && (
{isAgentPreviewOutput(output) && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
{output.description?.trim() && (
@@ -253,7 +197,7 @@ export function EditAgentTool({ part }: Props) {
</ContentGrid>
)}
{output && isClarificationNeededOutput(output) && (
{isClarificationNeededOutput(output) && (
<ClarificationQuestionsCard
questions={(output.questions ?? []).map((q) => {
const item: ClarifyingQuestion = {
@@ -271,6 +215,22 @@ export function EditAgentTool({ part }: Props) {
onSubmitAnswers={handleClarificationAnswers}
/>
)}
{isErrorOutput(output) && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
{output.error && (
<ContentCodeBlock>
{formatMaybeJson(output.error)}
</ContentCodeBlock>
)}
{output.details && (
<ContentCodeBlock>
{formatMaybeJson(output.details)}
</ContentCodeBlock>
)}
</ContentGrid>
)}
</ToolAccordion>
)}
</div>

View File

@@ -2,6 +2,9 @@ import type { AgentPreviewResponse } from "@/app/api/__generated__/models/agentP
import type { AgentSavedResponse } from "@/app/api/__generated__/models/agentSavedResponse";
import type { ClarificationNeededResponse } from "@/app/api/__generated__/models/clarificationNeededResponse";
import type { ErrorResponse } from "@/app/api/__generated__/models/errorResponse";
import type { OperationInProgressResponse } from "@/app/api/__generated__/models/operationInProgressResponse";
import type { OperationPendingResponse } from "@/app/api/__generated__/models/operationPendingResponse";
import type { OperationStartedResponse } from "@/app/api/__generated__/models/operationStartedResponse";
import { ResponseType } from "@/app/api/__generated__/models/responseType";
import {
NotePencilIcon,
@@ -12,6 +15,9 @@ import type { ToolUIPart } from "ai";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
export type EditAgentToolOutput =
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
| AgentPreviewResponse
| AgentSavedResponse
| ClarificationNeededResponse
@@ -31,6 +37,9 @@ function parseOutput(output: unknown): EditAgentToolOutput | null {
if (typeof output === "object") {
const type = (output as { type?: unknown }).type;
if (
type === ResponseType.operation_started ||
type === ResponseType.operation_pending ||
type === ResponseType.operation_in_progress ||
type === ResponseType.agent_preview ||
type === ResponseType.agent_saved ||
type === ResponseType.clarification_needed ||
@@ -38,6 +47,9 @@ function parseOutput(output: unknown): EditAgentToolOutput | null {
) {
return output as EditAgentToolOutput;
}
if ("operation_id" in output && "tool_name" in output)
return output as OperationStartedResponse | OperationPendingResponse;
if ("tool_call_id" in output) return output as OperationInProgressResponse;
if ("agent_json" in output && "agent_name" in output)
return output as AgentPreviewResponse;
if ("agent_id" in output && "library_agent_id" in output)
@@ -56,6 +68,30 @@ export function getEditAgentToolOutput(
return parseOutput((part as { output?: unknown }).output);
}
export function isOperationStartedOutput(
output: EditAgentToolOutput,
): output is OperationStartedResponse {
return (
output.type === ResponseType.operation_started ||
("operation_id" in output && "tool_name" in output)
);
}
export function isOperationPendingOutput(
output: EditAgentToolOutput,
): output is OperationPendingResponse {
return output.type === ResponseType.operation_pending;
}
export function isOperationInProgressOutput(
output: EditAgentToolOutput,
): output is OperationInProgressResponse {
return (
output.type === ResponseType.operation_in_progress ||
"tool_call_id" in output
);
}
export function isAgentPreviewOutput(
output: EditAgentToolOutput,
): output is AgentPreviewResponse {
@@ -96,6 +132,10 @@ export function getAnimationText(part: {
case "output-available": {
const output = parseOutput(part.output);
if (!output) return "Editing the agent";
if (isOperationStartedOutput(output)) return "Agent update started";
if (isOperationPendingOutput(output)) return "Agent update in progress";
if (isOperationInProgressOutput(output))
return "Agent update already in progress";
if (isAgentSavedOutput(output)) return `Saved "${output.agent_name}"`;
if (isAgentPreviewOutput(output)) return `Preview "${output.agent_name}"`;
if (isClarificationNeededOutput(output)) return "Needs clarification";

View File

@@ -686,20 +686,17 @@ export function GenericTool({ part }: Props) {
return (
<div className="py-2">
{/* Only show loading text when NOT showing accordion */}
{!showAccordion && (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon
category={category}
isStreaming={isStreaming}
isError={isError}
/>
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
)}
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon
category={category}
isStreaming={isStreaming}
isError={isError}
/>
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
{showAccordion && accordionData ? (
<ToolAccordion

View File

@@ -69,20 +69,13 @@ export function RunAgentTool({ part }: Props) {
return (
<div className="py-2">
{/* Only show loading text when NOT showing accordion or other content */}
{!isStreaming &&
!setupRequirementsOutput &&
!agentDetailsOutput &&
!needLoginOutput &&
!hasExpandableContent && (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
)}
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
{isStreaming && !output && (
<ToolAccordion

View File

@@ -115,7 +115,6 @@ export function useChatSession() {
hydratedMessages,
hasActiveStream,
isLoadingSession: sessionQuery.isLoading,
isSessionError: sessionQuery.isError,
createSession,
isCreatingSession,
};

View File

@@ -14,6 +14,7 @@ import { DefaultChatTransport } from "ai";
import type { UIMessage } from "ai";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { useChatSession } from "./useChatSession";
import { useLongRunningToolPolling } from "./hooks/useLongRunningToolPolling";
const STREAM_START_TIMEOUT_MS = 12_000;
@@ -35,46 +36,6 @@ function resolveInProgressTools(
}));
}
/** Build a fingerprint from a message's role + text/tool content for cross-boundary dedup. */
function messageFingerprint(msg: UIMessage): string {
const fragments = msg.parts.map((p) => {
if ("text" in p && typeof p.text === "string") return p.text;
if ("toolCallId" in p && typeof p.toolCallId === "string")
return `tool:${p.toolCallId}`;
return "";
});
return `${msg.role}::${fragments.join("\n")}`;
}
/**
* Deduplicate messages by ID *and* by content fingerprint.
* ID-based dedup catches duplicates within the same source (e.g. two
* identical stream events). Fingerprint-based dedup catches duplicates
* across the hydration/stream boundary where IDs differ (synthetic
* `${sessionId}-${index}` vs AI SDK nanoid).
*
* NOTE: Fingerprint dedup only applies to assistant messages, not user messages.
* Users should be able to send the same message multiple times.
*/
function deduplicateMessages(messages: UIMessage[]): UIMessage[] {
const seenIds = new Set<string>();
const seenFingerprints = new Set<string>();
return messages.filter((msg) => {
if (seenIds.has(msg.id)) return false;
seenIds.add(msg.id);
// Only apply fingerprint deduplication to assistant messages
// User messages should allow duplicates (same text sent multiple times)
if (msg.role === "assistant") {
const fp = messageFingerprint(msg);
if (fp !== "::" && seenFingerprints.has(fp)) return false;
seenFingerprints.add(fp);
}
return true;
});
}
export function useCopilotPage() {
const { isUserLoading, isLoggedIn } = useSupabase();
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
@@ -91,7 +52,6 @@ export function useCopilotPage() {
hydratedMessages,
hasActiveStream,
isLoadingSession,
isSessionError,
createSession,
isCreatingSession,
} = useChatSession();
@@ -154,7 +114,7 @@ export function useCopilotPage() {
);
const {
messages: rawMessages,
messages,
sendMessage,
stop: sdkStop,
status,
@@ -169,12 +129,6 @@ export function useCopilotPage() {
// call resumeStream() manually after hydration + active_stream detection.
});
// Deduplicate messages continuously to prevent duplicates when resuming streams
const messages = useMemo(
() => deduplicateMessages(rawMessages),
[rawMessages],
);
// Wrap AI SDK's stop() to also cancel the backend executor task.
// sdkStop() aborts the SSE fetch instantly (UI feedback), then we fire
// the cancel API to actually stop the executor and wait for confirmation.
@@ -230,14 +184,14 @@ export function useCopilotPage() {
if (status === "streaming" || status === "submitted") return;
setMessages((prev) => {
if (prev.length >= hydratedMessages.length) return prev;
// Deduplicate to handle rare cases where duplicate streams might occur
return deduplicateMessages(hydratedMessages);
return hydratedMessages;
});
}, [hydratedMessages, setMessages, status]);
// Ref: tracks whether we've already resumed for a given session.
// Format: Map<sessionId, hasResumed>
const hasResumedRef = useRef<Map<string, boolean>>(new Map());
// Reset when the stream ends so re-resume is possible if the backend
// task is still running (SSE dropped but executor didn't finish).
const hasResumedRef = useRef<string | null>(null);
// When the stream ends (or drops), invalidate the session cache so the
// next hydration fetches fresh messages from the backend. Without this,
@@ -254,27 +208,29 @@ export function useCopilotPage() {
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
});
// Allow re-resume if the backend task is still running.
hasResumedRef.current = null;
}
}, [status, sessionId, queryClient]);
// Resume an active stream AFTER hydration completes.
// IMPORTANT: Only runs when page loads with existing active stream (reconnection).
// Does NOT run when new streams start during active conversation.
// The backend returns active_stream info when a task is still running.
// We wait for hydration so the AI SDK has the conversation history
// before the resumed stream appends the in-progress assistant message.
useEffect(() => {
if (!sessionId) return;
if (!hasActiveStream) return;
if (!hasActiveStream || !sessionId) return;
if (!hydratedMessages || hydratedMessages.length === 0) return;
// Never resume if currently streaming
if (status === "streaming" || status === "submitted") return;
// Only resume once per session
if (hasResumedRef.current.get(sessionId)) return;
// Mark as resumed immediately to prevent race conditions
hasResumedRef.current.set(sessionId, true);
// Only resume once per session to avoid re-triggering after stream ends
if (hasResumedRef.current === sessionId) return;
hasResumedRef.current = sessionId;
resumeStream();
}, [sessionId, hasActiveStream, hydratedMessages, status, resumeStream]);
}, [hasActiveStream, sessionId, hydratedMessages, status, resumeStream]);
// Poll session endpoint when a long-running tool (create_agent, edit_agent)
// is in progress. When the backend completes, the session data will contain
// the final tool output — this hook detects the change and updates messages.
useLongRunningToolPolling(sessionId, messages, setMessages);
// Clear messages when session is null
useEffect(() => {
@@ -365,7 +321,6 @@ export function useCopilotPage() {
stop,
isReconnecting,
isLoadingSession,
isSessionError,
isCreatingSession,
isUserLoading,
isLoggedIn,

View File

@@ -0,0 +1,64 @@
import { environment } from "@/services/environment";
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
import { NextRequest } from "next/server";
import { normalizeSSEStream, SSE_HEADERS } from "../../../sse-helpers";
export async function GET(
request: NextRequest,
{ params }: { params: Promise<{ taskId: string }> },
) {
const { taskId } = await params;
const searchParams = request.nextUrl.searchParams;
const lastMessageId = searchParams.get("last_message_id") || "0-0";
try {
const token = await getServerAuthToken();
const backendUrl = environment.getAGPTServerBaseUrl();
const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl);
streamUrl.searchParams.set("last_message_id", lastMessageId);
const headers: Record<string, string> = {
Accept: "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
};
if (token) {
headers["Authorization"] = `Bearer ${token}`;
}
const response = await fetch(streamUrl.toString(), {
method: "GET",
headers,
});
if (!response.ok) {
const error = await response.text();
return new Response(error, {
status: response.status,
headers: { "Content-Type": "application/json" },
});
}
if (!response.body) {
return new Response(null, { status: 204 });
}
return new Response(normalizeSSEStream(response.body), {
headers: SSE_HEADERS,
});
} catch (error) {
console.error("Task stream proxy error:", error);
return new Response(
JSON.stringify({
error: "Failed to connect to task stream",
detail: error instanceof Error ? error.message : String(error),
}),
{
status: 500,
headers: { "Content-Type": "application/json" },
},
);
}
}

View File

@@ -961,6 +961,63 @@
}
}
},
"/api/chat/operations/{operation_id}/complete": {
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Complete Operation",
"description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.",
"operationId": "postV2CompleteOperation",
"parameters": [
{
"name": "operation_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Operation Id" }
},
{
"name": "x-api-key",
"in": "header",
"required": false,
"schema": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "X-Api-Key"
}
}
],
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/OperationCompleteRequest"
}
}
}
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": true,
"title": "Response Postv2Completeoperation"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/chat/schema/tool-responses": {
"get": {
"tags": ["v2", "chat", "chat"],
@@ -1000,7 +1057,12 @@
{ "$ref": "#/components/schemas/BlockDetailsResponse" },
{ "$ref": "#/components/schemas/BlockOutputResponse" },
{ "$ref": "#/components/schemas/DocSearchResultsResponse" },
{ "$ref": "#/components/schemas/DocPageResponse" }
{ "$ref": "#/components/schemas/DocPageResponse" },
{ "$ref": "#/components/schemas/OperationStartedResponse" },
{ "$ref": "#/components/schemas/OperationPendingResponse" },
{
"$ref": "#/components/schemas/OperationInProgressResponse"
}
],
"title": "Response Getv2[Dummy] Tool Response Type Export For Codegen"
}
@@ -1123,7 +1185,7 @@
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Session",
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns active_stream info for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.",
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.",
"operationId": "getV2GetSession",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -1275,7 +1337,7 @@
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Stream Chat Post",
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to a per-turn Redis stream for reconnection support. If the client\ndisconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. First chunk is a \"start\" event\n containing the task_id for reconnection.",
"operationId": "postV2StreamChatPost",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -1313,6 +1375,94 @@
}
}
},
"/api/chat/tasks/{task_id}": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Task Status",
"description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
"operationId": "getV2GetTaskStatus",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "task_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Task Id" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": true,
"title": "Response Getv2Gettaskstatus"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/chat/tasks/{task_id}/stream": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Stream Task",
"description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.",
"operationId": "getV2StreamTask",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "task_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Task Id" }
},
{
"name": "last_message_id",
"in": "query",
"required": false,
"schema": {
"type": "string",
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
"default": "0-0",
"title": "Last Message Id"
},
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay."
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": { "application/json": { "schema": {} } }
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/credits": {
"get": {
"tags": ["v1", "credits"],
@@ -6412,11 +6562,13 @@
},
"ActiveStreamInfo": {
"properties": {
"turn_id": { "type": "string", "title": "Turn Id" },
"last_message_id": { "type": "string", "title": "Last Message Id" }
"task_id": { "type": "string", "title": "Task Id" },
"last_message_id": { "type": "string", "title": "Last Message Id" },
"operation_id": { "type": "string", "title": "Operation Id" },
"tool_name": { "type": "string", "title": "Tool Name" }
},
"type": "object",
"required": ["turn_id", "last_message_id"],
"required": ["task_id", "last_message_id", "operation_id", "tool_name"],
"title": "ActiveStreamInfo",
"description": "Information about an active stream for reconnection."
},
@@ -7426,6 +7578,10 @@
"CancelTaskResponse": {
"properties": {
"cancelled": { "type": "boolean", "title": "Cancelled" },
"task_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Task Id"
},
"reason": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Reason"
@@ -9951,6 +10107,87 @@
],
"title": "OnboardingStep"
},
"OperationCompleteRequest": {
"properties": {
"success": { "type": "boolean", "title": "Success" },
"result": {
"anyOf": [
{ "additionalProperties": true, "type": "object" },
{ "type": "string" },
{ "type": "null" }
],
"title": "Result"
},
"error": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Error"
}
},
"type": "object",
"required": ["success"],
"title": "OperationCompleteRequest",
"description": "Request model for external completion webhook."
},
"OperationInProgressResponse": {
"properties": {
"type": {
"$ref": "#/components/schemas/ResponseType",
"default": "operation_in_progress"
},
"message": { "type": "string", "title": "Message" },
"session_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Session Id"
},
"tool_call_id": { "type": "string", "title": "Tool Call Id" }
},
"type": "object",
"required": ["message", "tool_call_id"],
"title": "OperationInProgressResponse",
"description": "Response when an operation is already in progress.\n\nReturned for idempotency when the same tool_call_id is requested again\nwhile the background task is still running."
},
"OperationPendingResponse": {
"properties": {
"type": {
"$ref": "#/components/schemas/ResponseType",
"default": "operation_pending"
},
"message": { "type": "string", "title": "Message" },
"session_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Session Id"
},
"operation_id": { "type": "string", "title": "Operation Id" },
"tool_name": { "type": "string", "title": "Tool Name" }
},
"type": "object",
"required": ["message", "operation_id", "tool_name"],
"title": "OperationPendingResponse",
"description": "Response stored in chat history while a long-running operation is executing.\n\nThis is persisted to the database so users see a pending state when they\nrefresh before the operation completes."
},
"OperationStartedResponse": {
"properties": {
"type": {
"$ref": "#/components/schemas/ResponseType",
"default": "operation_started"
},
"message": { "type": "string", "title": "Message" },
"session_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Session Id"
},
"operation_id": { "type": "string", "title": "Operation Id" },
"tool_name": { "type": "string", "title": "Tool Name" },
"task_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Task Id"
}
},
"type": "object",
"required": ["message", "operation_id", "tool_name"],
"title": "OperationStartedResponse",
"description": "Response when a long-running operation has been started in the background.\n\nThis is returned immediately to the client while the operation continues\nto execute. The user can close the tab and check back later.\n\nThe task_id can be used to reconnect to the SSE stream via\nGET /chat/tasks/{task_id}/stream?last_idx=0"
},
"Pagination": {
"properties": {
"total_items": {
@@ -10607,10 +10844,13 @@
"workspace_file_metadata",
"workspace_file_written",
"workspace_file_deleted",
"operation_started",
"operation_pending",
"operation_in_progress",
"input_validation_error",
"web_fetch",
"bash_exec",
"operation_status",
"feature_request_search",
"feature_request_created",
"suggested_goal"

View File

@@ -15,6 +15,7 @@
## Advanced Setup
* [Advanced Setup](advanced_setup.md)
* [Deployment Environment Variables](deployment-environment-variables.md)
## Building Blocks

View File

@@ -0,0 +1,397 @@
# Deployment Environment Variables
This guide documents **all environment variables that must be configured** when deploying AutoGPT to a new server or environment. Use this as a checklist to ensure your deployment works correctly.
## Quick Reference: What MUST Change
When deploying to a new server, these variables **must** be updated from their localhost defaults:
| Variable | Location | Default | Purpose |
|----------|----------|---------|---------|
| `SITE_URL` | `.env` | `http://localhost:3000` | Frontend URL for auth redirects |
| `API_EXTERNAL_URL` | `.env` | `http://localhost:8000` | Public Supabase API URL |
| `SUPABASE_PUBLIC_URL` | `.env` | `http://localhost:8000` | Studio dashboard URL |
| `PLATFORM_BASE_URL` | `backend/.env` | `http://localhost:8000` | Backend platform URL |
| `FRONTEND_BASE_URL` | `backend/.env` | `http://localhost:3000` | Frontend URL for webhooks/OAuth |
| `NEXT_PUBLIC_SUPABASE_URL` | `frontend/.env` | `http://localhost:8000` | Client-side Supabase URL |
| `NEXT_PUBLIC_AGPT_SERVER_URL` | `frontend/.env` | `http://localhost:8006/api` | Client-side backend API URL |
| `NEXT_PUBLIC_AGPT_WS_SERVER_URL` | `frontend/.env` | `ws://localhost:8001/ws` | Client-side WebSocket URL |
| `NEXT_PUBLIC_FRONTEND_BASE_URL` | `frontend/.env` | `http://localhost:3000` | Client-side frontend URL |
---
## Configuration Files
AutoGPT uses multiple `.env` files across different components:
```text
autogpt_platform/
├── .env # Supabase/infrastructure config
├── backend/
│ ├── .env.default # Backend defaults (DO NOT EDIT)
│ └── .env # Your backend overrides
└── frontend/
├── .env.default # Frontend defaults (DO NOT EDIT)
└── .env # Your frontend overrides
```
**Loading Order** (later overrides earlier):
1. `*.env.default` - Base defaults
2. `*.env` - Your overrides
3. Docker `environment:` section
4. Shell environment variables
---
## 1. URL Configuration (REQUIRED)
These URLs must be updated to match your deployment domain/IP.
### Root `.env` (Supabase)
```bash
# Auth redirects - where users return after login
SITE_URL=https://your-domain.com:3000
# Public API URL - exposed to clients
API_EXTERNAL_URL=https://your-domain.com:8000
# Studio dashboard URL
SUPABASE_PUBLIC_URL=https://your-domain.com:8000
```
### Backend `.env`
```bash
# Platform URLs for webhooks and OAuth callbacks
PLATFORM_BASE_URL=https://your-domain.com:8000
FRONTEND_BASE_URL=https://your-domain.com:3000
# Internal Supabase URL (use Docker service name if containerized)
SUPABASE_URL=http://kong:8000 # Docker
# SUPABASE_URL=https://your-domain.com:8000 # External
```
### Frontend `.env`
```bash
# Client-side URLs (used in browser)
NEXT_PUBLIC_SUPABASE_URL=https://your-domain.com:8000
NEXT_PUBLIC_AGPT_SERVER_URL=https://your-domain.com:8006/api
NEXT_PUBLIC_AGPT_WS_SERVER_URL=wss://your-domain.com:8001/ws
NEXT_PUBLIC_FRONTEND_BASE_URL=https://your-domain.com:3000
```
!!! warning "HTTPS Note"
For production, use HTTPS URLs and `wss://` for WebSocket. You'll need a reverse proxy (nginx, Caddy) with SSL certificates.
!!! info "Port Numbers"
The port numbers shown (`:3000`, `:8000`, `:8001`, `:8006`) are internal Docker service ports. In production with a reverse proxy, your public URLs typically won't include port numbers (e.g., `https://your-domain.com` instead of `https://your-domain.com:3000`). Configure your reverse proxy to route external traffic to the internal service ports.
---
## 2. Security Keys (MUST REGENERATE)
These default values are **public** and **must be changed** for production.
### Root `.env`
```bash
# Database password
POSTGRES_PASSWORD=<generate-strong-password>
# JWT secret for Supabase auth (min 32 chars)
JWT_SECRET=<generate-random-string>
# Supabase keys (regenerate with matching JWT_SECRET)
ANON_KEY=<regenerate>
SERVICE_ROLE_KEY=<regenerate>
# Studio dashboard credentials
DASHBOARD_USERNAME=<your-username>
DASHBOARD_PASSWORD=<strong-password>
# Encryption keys
SECRET_KEY_BASE=<generate-random-string>
VAULT_ENC_KEY=<generate-32-char-key> # Run: openssl rand -hex 16
```
### Backend `.env`
```bash
# Must match root POSTGRES_PASSWORD
DB_PASS=<same-as-POSTGRES_PASSWORD>
# Must match root SERVICE_ROLE_KEY
SUPABASE_SERVICE_ROLE_KEY=<same-as-SERVICE_ROLE_KEY>
# Must match root JWT_SECRET
JWT_VERIFY_KEY=<same-as-JWT_SECRET>
# Generate new encryption keys
# Run: python -c "from cryptography.fernet import Fernet;print(Fernet.generate_key().decode())"
ENCRYPTION_KEY=<generated-fernet-key>
UNSUBSCRIBE_SECRET_KEY=<generated-fernet-key>
```
### Generating Keys
```bash
# Generate Fernet encryption key (for ENCRYPTION_KEY, UNSUBSCRIBE_SECRET_KEY)
python -c "from cryptography.fernet import Fernet;print(Fernet.generate_key().decode())"
# Generate random string (for JWT_SECRET, SECRET_KEY_BASE)
openssl rand -base64 32
# Generate 32-character key (for VAULT_ENC_KEY)
openssl rand -hex 16
# Generate Supabase keys (requires matching JWT_SECRET)
# Use: https://supabase.com/docs/guides/self-hosting/docker#generate-api-keys
```
---
## 3. Database Configuration
### Root `.env`
```bash
POSTGRES_HOST=db # Docker service name or external host
POSTGRES_DB=postgres
POSTGRES_PORT=5432
POSTGRES_PASSWORD=<your-password>
```
### Backend `.env`
```bash
DB_USER=postgres
DB_PASS=<your-password>
DB_NAME=postgres
DB_PORT=5432
DB_HOST=localhost # Default is localhost; use 'db' in Docker
DB_SCHEMA=platform
# Connection pooling
DB_CONNECTION_LIMIT=12
DB_CONNECT_TIMEOUT=60
DB_POOL_TIMEOUT=300
# Full connection URL (auto-constructed from above in .env.default)
# Variable substitution is handled automatically; only override if you need custom parameters
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}"
```
---
## 4. Service Dependencies
### Redis
```bash
REDIS_HOST=redis # Docker: 'redis', External: hostname/IP
REDIS_PORT=6379
# REDIS_PASSWORD= # Uncomment if using authentication
```
### RabbitMQ
```bash
RABBITMQ_DEFAULT_USER=<username>
RABBITMQ_DEFAULT_PASS=<strong-password>
# In Docker, host is 'rabbitmq'
```
---
## 5. Default Ports
| Service | Port | Purpose |
|---------|------|---------|
| Frontend | 3000 | Next.js web UI |
| Kong (Supabase API) | 8000 | API gateway |
| WebSocket Server | 8001 | Real-time updates |
| Executor | 8002 | Agent execution |
| Scheduler | 8003 | Scheduled tasks |
| Database Manager | 8005 | DB operations |
| REST Server | 8006 | Main API |
| Notification Server | 8007 | Notifications |
| PostgreSQL | 5432 | Database |
| Redis | 6379 | Cache/queue |
| RabbitMQ | 5672/15672 | Message queue |
| ClamAV | 3310 | Antivirus scanning |
---
## 6. OAuth Callbacks
When configuring OAuth providers, use this callback URL format:
```text
https://your-domain.com/auth/integrations/oauth_callback
# Or with explicit port if not using a reverse proxy:
# https://your-domain.com:3000/auth/integrations/oauth_callback
```
### Supported OAuth Providers
| Provider | Env Variables | Setup URL |
|----------|---------------|-----------|
| GitHub | `GITHUB_CLIENT_ID`, `GITHUB_CLIENT_SECRET` | [github.com/settings/developers](https://github.com/settings/developers) |
| Google | `GOOGLE_CLIENT_ID`, `GOOGLE_CLIENT_SECRET` | [console.cloud.google.com](https://console.cloud.google.com/apis/credentials) |
| Discord | `DISCORD_CLIENT_ID`, `DISCORD_CLIENT_SECRET` | [discord.com/developers](https://discord.com/developers/applications) |
| Twitter/X | `TWITTER_CLIENT_ID`, `TWITTER_CLIENT_SECRET` | [developer.x.com](https://developer.x.com) |
| Notion | `NOTION_CLIENT_ID`, `NOTION_CLIENT_SECRET` | [developers.notion.com](https://developers.notion.com) |
| Linear | `LINEAR_CLIENT_ID`, `LINEAR_CLIENT_SECRET` | [linear.app/settings/api](https://linear.app/settings/api/applications/new) |
| Reddit | `REDDIT_CLIENT_ID`, `REDDIT_CLIENT_SECRET` | [reddit.com/prefs/apps](https://reddit.com/prefs/apps) |
| Todoist | `TODOIST_CLIENT_ID`, `TODOIST_CLIENT_SECRET` | [developer.todoist.com](https://developer.todoist.com/appconsole.html) |
---
## 7. Optional Services
### AI/LLM Providers
```bash
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
GROQ_API_KEY=
OPEN_ROUTER_API_KEY=
NVIDIA_API_KEY=
```
### Email (SMTP)
```bash
# Supabase auth emails
SMTP_HOST=smtp.example.com
SMTP_PORT=587
SMTP_USER=<username>
SMTP_PASS=<password>
SMTP_ADMIN_EMAIL=admin@example.com
# Application emails (Postmark)
POSTMARK_SERVER_API_TOKEN=
POSTMARK_SENDER_EMAIL=noreply@your-domain.com
```
### Payments (Stripe)
```bash
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=
```
### Error Tracking (Sentry)
```bash
SENTRY_DSN=
```
### Analytics (PostHog)
```bash
POSTHOG_API_KEY=
POSTHOG_HOST=https://eu.i.posthog.com
# Frontend
NEXT_PUBLIC_POSTHOG_KEY=
NEXT_PUBLIC_POSTHOG_HOST=https://eu.i.posthog.com
```
---
## 8. Deployment Checklist
Use this checklist when deploying to a new environment:
### Pre-deployment
- [ ] Clone repository and navigate to `autogpt_platform/`
- [ ] Copy all `.env.default` files to `.env`
- [ ] Determine your deployment domain/IP
### URL Configuration
- [ ] Update `SITE_URL` in root `.env`
- [ ] Update `API_EXTERNAL_URL` in root `.env`
- [ ] Update `SUPABASE_PUBLIC_URL` in root `.env`
- [ ] Update `PLATFORM_BASE_URL` in `backend/.env`
- [ ] Update `FRONTEND_BASE_URL` in `backend/.env`
- [ ] Update all `NEXT_PUBLIC_*` URLs in `frontend/.env`
### Security
- [ ] Generate new `POSTGRES_PASSWORD`
- [ ] Generate new `JWT_SECRET` (min 32 chars)
- [ ] Regenerate `ANON_KEY` and `SERVICE_ROLE_KEY`
- [ ] Change `DASHBOARD_USERNAME` and `DASHBOARD_PASSWORD`
- [ ] Generate new `ENCRYPTION_KEY` (backend)
- [ ] Generate new `UNSUBSCRIBE_SECRET_KEY` (backend)
- [ ] Update `DB_PASS` to match `POSTGRES_PASSWORD`
- [ ] Update `JWT_VERIFY_KEY` to match `JWT_SECRET`
- [ ] Update `SUPABASE_SERVICE_ROLE_KEY` to match
### Services
- [ ] Configure Redis connection (if external)
- [ ] Configure RabbitMQ credentials
- [ ] Configure SMTP for emails (if needed)
### OAuth (if using integrations)
- [ ] Register OAuth apps with your callback URL
- [ ] Add client IDs and secrets to `backend/.env`
### Post-deployment
- [ ] Run `docker compose up -d --build`
- [ ] Verify frontend loads at your URL
- [ ] Test authentication flow
- [ ] Test WebSocket connection (real-time updates)
---
## 9. Docker vs External Services
### Running Everything in Docker (Default)
The docker-compose files automatically set internal hostnames:
```yaml
# Internal Docker service names (container-to-container communication)
# These are set automatically in docker-compose.platform.yml
DB_HOST: db
REDIS_HOST: redis
RABBITMQ_HOST: rabbitmq
SUPABASE_URL: http://kong:8000
```
### Using External Services
If using managed services (AWS RDS, Redis Cloud, etc.), override in your `.env`:
```bash
# External PostgreSQL
DB_HOST=your-rds-instance.region.rds.amazonaws.com
DB_PORT=5432
# External Redis
REDIS_HOST=your-redis.cache.amazonaws.com
REDIS_PORT=6379
REDIS_PASSWORD=<if-required>
# External Supabase (hosted)
SUPABASE_URL=https://your-project.supabase.co
SUPABASE_SERVICE_ROLE_KEY=<your-service-role-key>
```
---
## Related Documentation
- [Getting Started](getting-started.md) - Basic setup guide
- [Advanced Setup](advanced_setup.md) - Development configuration
- [OAuth & SSO](integrating/oauth-guide.md) - Integration setup