mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-01 02:15:27 -05:00
Compare commits
7 Commits
swiftyos/s
...
docker/opt
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1ed748a356 | ||
|
|
9c28639c32 | ||
|
|
4f37a12743 | ||
|
|
7ee94d986c | ||
|
|
18a1661fa3 | ||
|
|
b72521daa9 | ||
|
|
350ad3591b |
@@ -54,7 +54,7 @@ Before proceeding with the installation, ensure your system meets the following
|
||||
### Updated Setup Instructions:
|
||||
We've moved to a fully maintained and regularly updated documentation site.
|
||||
|
||||
👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)
|
||||
👉 [Follow the official self-hosting guide here](https://agpt.co/docs/platform/getting-started/getting-started)
|
||||
|
||||
|
||||
This tutorial assumes you have Docker, VSCode, git and npm installed.
|
||||
|
||||
@@ -37,13 +37,15 @@ ENV POETRY_VIRTUALENVS_CREATE=true
|
||||
ENV POETRY_VIRTUALENVS_IN_PROJECT=true
|
||||
ENV PATH=/opt/poetry/bin:$PATH
|
||||
|
||||
RUN pip3 install poetry --break-system-packages
|
||||
RUN pip3 install --no-cache-dir poetry --break-system-packages
|
||||
|
||||
# Copy and install dependencies
|
||||
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
|
||||
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
RUN poetry install --no-ansi --no-root
|
||||
# Production image only needs runtime deps; dev deps (pytest, black, ruff, etc.)
|
||||
# are installed locally via `poetry install --with dev` per the development docs
|
||||
RUN poetry install --no-ansi --no-root --only main
|
||||
|
||||
# Generate Prisma client
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
@@ -51,6 +53,15 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# Clean up build artifacts and caches to reduce layer size
|
||||
# Note: setuptools is kept as it's a direct dependency (used by aioclamd via pkg_resources)
|
||||
RUN find /app -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true; \
|
||||
find /app -type d -name tests -exec rm -rf {} + 2>/dev/null || true; \
|
||||
find /app -type d -name test -exec rm -rf {} + 2>/dev/null || true; \
|
||||
rm -rf /app/autogpt_platform/backend/.venv/lib/python*/site-packages/pip* \
|
||||
/root/.cache/pip \
|
||||
/root/.cache/pypoetry
|
||||
|
||||
FROM debian:13-slim AS server_dependencies
|
||||
|
||||
WORKDIR /app
|
||||
@@ -68,7 +79,7 @@ RUN apt-get update && apt-get install -y \
|
||||
python3-pip \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy only necessary files from builder
|
||||
# Copy built artifacts from builder (cleaned of caches, __pycache__, and test dirs)
|
||||
COPY --from=builder /app /app
|
||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
||||
@@ -81,9 +92,7 @@ COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-pyth
|
||||
|
||||
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
|
||||
|
||||
RUN mkdir -p /app/autogpt_platform/autogpt_libs
|
||||
RUN mkdir -p /app/autogpt_platform/backend
|
||||
|
||||
# Copy fresh source from context (overwrites builder's copy with latest source)
|
||||
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
|
||||
|
||||
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
|
||||
|
||||
@@ -1,325 +0,0 @@
|
||||
"""RabbitMQ consumer for operation completion messages.
|
||||
|
||||
This module provides a consumer that listens for completion notifications
|
||||
from external services (like Agent Generator) and triggers the appropriate
|
||||
stream registry and chat service updates.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import orjson
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.rabbitmq import (
|
||||
AsyncRabbitMQ,
|
||||
Exchange,
|
||||
ExchangeType,
|
||||
Queue,
|
||||
RabbitMQConfig,
|
||||
)
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .response_model import StreamError, StreamFinish, StreamToolOutputAvailable
|
||||
from .tools.models import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Queue and exchange configuration
|
||||
OPERATION_COMPLETE_EXCHANGE = Exchange(
|
||||
name="chat_operations",
|
||||
type=ExchangeType.DIRECT,
|
||||
durable=True,
|
||||
)
|
||||
|
||||
OPERATION_COMPLETE_QUEUE = Queue(
|
||||
name="chat_operation_complete",
|
||||
durable=True,
|
||||
exchange=OPERATION_COMPLETE_EXCHANGE,
|
||||
routing_key="operation.complete",
|
||||
)
|
||||
|
||||
RABBITMQ_CONFIG = RabbitMQConfig(
|
||||
exchanges=[OPERATION_COMPLETE_EXCHANGE],
|
||||
queues=[OPERATION_COMPLETE_QUEUE],
|
||||
)
|
||||
|
||||
|
||||
class OperationCompleteMessage(BaseModel):
|
||||
"""Message format for operation completion notifications."""
|
||||
|
||||
operation_id: str
|
||||
task_id: str
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionConsumer:
|
||||
"""Consumer for chat operation completion messages from RabbitMQ."""
|
||||
|
||||
def __init__(self):
|
||||
self._rabbitmq: AsyncRabbitMQ | None = None
|
||||
self._consumer_task: asyncio.Task | None = None
|
||||
self._running = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the completion consumer."""
|
||||
if self._running:
|
||||
logger.warning("Completion consumer already running")
|
||||
return
|
||||
|
||||
self._rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG)
|
||||
await self._rabbitmq.connect()
|
||||
|
||||
self._running = True
|
||||
self._consumer_task = asyncio.create_task(self._consume_messages())
|
||||
logger.info("Chat completion consumer started")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the completion consumer."""
|
||||
self._running = False
|
||||
|
||||
if self._consumer_task:
|
||||
self._consumer_task.cancel()
|
||||
try:
|
||||
await self._consumer_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._consumer_task = None
|
||||
|
||||
if self._rabbitmq:
|
||||
await self._rabbitmq.disconnect()
|
||||
self._rabbitmq = None
|
||||
|
||||
logger.info("Chat completion consumer stopped")
|
||||
|
||||
async def _consume_messages(self) -> None:
|
||||
"""Main message consumption loop with retry logic."""
|
||||
max_retries = 10
|
||||
retry_delay = 5 # seconds
|
||||
retry_count = 0
|
||||
|
||||
while self._running and retry_count < max_retries:
|
||||
if not self._rabbitmq:
|
||||
logger.error("RabbitMQ not initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
channel = await self._rabbitmq.get_channel()
|
||||
queue = await channel.get_queue(OPERATION_COMPLETE_QUEUE.name)
|
||||
|
||||
# Reset retry count on successful connection
|
||||
retry_count = 0
|
||||
|
||||
async with queue.iterator() as queue_iter:
|
||||
async for message in queue_iter:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
try:
|
||||
async with message.process():
|
||||
await self._handle_message(message.body)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing completion message: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Message will be requeued due to exception
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Consumer cancelled")
|
||||
return
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
logger.error(
|
||||
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
if self._running and retry_count < max_retries:
|
||||
await asyncio.sleep(retry_delay)
|
||||
else:
|
||||
logger.error("Max retries reached, stopping consumer")
|
||||
return
|
||||
|
||||
async def _handle_message(self, body: bytes) -> None:
|
||||
"""Handle a single completion message."""
|
||||
try:
|
||||
data = orjson.loads(body)
|
||||
message = OperationCompleteMessage(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse completion message: {e}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Received completion for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id}, success={message.success})"
|
||||
)
|
||||
|
||||
# Find task in registry
|
||||
task = await stream_registry.find_task_by_operation_id(message.operation_id)
|
||||
if task is None:
|
||||
# Try to look up by task_id directly
|
||||
task = await stream_registry.get_task(message.task_id)
|
||||
|
||||
if task is None:
|
||||
logger.warning(
|
||||
f"Task not found for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id})"
|
||||
)
|
||||
return
|
||||
|
||||
if message.success:
|
||||
await self._handle_success(task, message)
|
||||
else:
|
||||
await self._handle_failure(task, message)
|
||||
|
||||
async def _handle_success(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle successful operation completion."""
|
||||
# Publish result to stream registry
|
||||
result_output = message.result if message.result else {"status": "completed"}
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=task.tool_call_id,
|
||||
toolName=task.tool_name,
|
||||
output=(
|
||||
result_output
|
||||
if isinstance(result_output, str)
|
||||
else orjson.dumps(result_output).decode("utf-8")
|
||||
),
|
||||
success=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Update pending operation in database
|
||||
result_str = (
|
||||
message.result
|
||||
if isinstance(message.result, str)
|
||||
else (
|
||||
orjson.dumps(message.result).decode("utf-8")
|
||||
if message.result
|
||||
else '{"status": "completed"}'
|
||||
)
|
||||
)
|
||||
await chat_service._update_pending_operation(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
result=result_str,
|
||||
)
|
||||
|
||||
# Generate LLM continuation with streaming
|
||||
await chat_service._generate_llm_continuation_with_streaming(
|
||||
session_id=task.session_id,
|
||||
user_id=task.user_id,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
|
||||
# Mark task as completed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
|
||||
logger.info(
|
||||
f"Successfully processed completion for task {task.task_id} "
|
||||
f"(operation {message.operation_id})"
|
||||
)
|
||||
|
||||
async def _handle_failure(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle failed operation completion."""
|
||||
error_msg = message.error or "Operation failed"
|
||||
|
||||
# Publish error to stream registry followed by finish event
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText=error_msg),
|
||||
)
|
||||
await stream_registry.publish_chunk(task.task_id, StreamFinish())
|
||||
|
||||
# Update pending operation with error
|
||||
error_response = ErrorResponse(
|
||||
message=error_msg,
|
||||
error=message.error,
|
||||
)
|
||||
await chat_service._update_pending_operation(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
result=error_response.model_dump_json(),
|
||||
)
|
||||
|
||||
# Mark task as failed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
|
||||
logger.info(
|
||||
f"Processed failure for task {task.task_id} "
|
||||
f"(operation {message.operation_id}): {error_msg}"
|
||||
)
|
||||
|
||||
|
||||
# Module-level consumer instance
|
||||
_consumer: ChatCompletionConsumer | None = None
|
||||
|
||||
|
||||
async def start_completion_consumer() -> None:
|
||||
"""Start the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer is None:
|
||||
_consumer = ChatCompletionConsumer()
|
||||
await _consumer.start()
|
||||
|
||||
|
||||
async def stop_completion_consumer() -> None:
|
||||
"""Stop the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer:
|
||||
await _consumer.stop()
|
||||
_consumer = None
|
||||
|
||||
|
||||
async def publish_operation_complete(
|
||||
operation_id: str,
|
||||
task_id: str,
|
||||
success: bool,
|
||||
result: dict | str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""Publish an operation completion message.
|
||||
|
||||
This is a helper function for testing or for services that want to
|
||||
publish completion messages directly.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID that completed.
|
||||
task_id: The task ID associated with the operation.
|
||||
success: Whether the operation succeeded.
|
||||
result: The result data (for success).
|
||||
error: The error message (for failure).
|
||||
"""
|
||||
message = OperationCompleteMessage(
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
success=success,
|
||||
result=result,
|
||||
error=error,
|
||||
)
|
||||
|
||||
rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG)
|
||||
try:
|
||||
await rabbitmq.connect()
|
||||
await rabbitmq.publish_message(
|
||||
routing_key="operation.complete",
|
||||
message=message.model_dump_json(),
|
||||
exchange=OPERATION_COMPLETE_EXCHANGE,
|
||||
)
|
||||
logger.info(f"Published completion for operation {operation_id}")
|
||||
finally:
|
||||
await rabbitmq.disconnect()
|
||||
@@ -44,20 +44,6 @@ class ChatConfig(BaseSettings):
|
||||
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
||||
)
|
||||
|
||||
# Stream registry configuration for SSE reconnection
|
||||
stream_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL in seconds for stream data in Redis (1 hour)",
|
||||
)
|
||||
stream_max_length: int = Field(
|
||||
default=1000,
|
||||
description="Maximum number of messages to store per stream",
|
||||
)
|
||||
internal_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
|
||||
)
|
||||
|
||||
# Langfuse Prompt Management Configuration
|
||||
# Note: Langfuse credentials are in Settings().secrets (settings.py)
|
||||
langfuse_prompt_name: str = Field(
|
||||
@@ -96,14 +82,6 @@ class ChatConfig(BaseSettings):
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
return v
|
||||
|
||||
@field_validator("internal_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_internal_api_key(cls, v):
|
||||
"""Get internal API key from environment if not provided."""
|
||||
if v is None:
|
||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||
return v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
|
||||
@@ -4,19 +4,16 @@ import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
import orjson
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
|
||||
from fastapi import APIRouter, Depends, Query, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
from .response_model import StreamFinish, StreamHeartbeat
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -84,14 +81,6 @@ class ListSessionsResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class OperationCompleteRequest(BaseModel):
|
||||
"""Request model for external completion webhook."""
|
||||
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@@ -377,267 +366,6 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Task Streaming (SSE Reconnection) ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}/stream",
|
||||
)
|
||||
async def stream_task(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
last_message_id: str = Query(
|
||||
default="0-0",
|
||||
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Reconnect to a long-running task's SSE stream.
|
||||
|
||||
When a long-running operation (like agent generation) starts, the client
|
||||
receives a task_id. If the connection drops, the client can reconnect
|
||||
using this endpoint to resume receiving updates.
|
||||
|
||||
Args:
|
||||
task_id: The task ID from the operation_started response.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
|
||||
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If task_id is not found or user doesn't have access.
|
||||
"""
|
||||
# Get subscriber queue from stream registry
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
raise NotFoundError(f"Task {task_id} not found or access denied.")
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import asyncio
|
||||
|
||||
chunk_count = 0
|
||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Wait for next chunk with timeout for heartbeats
|
||||
chunk = await asyncio.wait_for(
|
||||
subscriber_queue.get(), timeout=heartbeat_interval
|
||||
)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
logger.info(
|
||||
f"Task stream completed for task {task_id}, "
|
||||
f"chunk_count={chunk_count}"
|
||||
)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# Send heartbeat to keep connection alive
|
||||
yield StreamHeartbeat().to_sse()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
|
||||
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}",
|
||||
)
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
) -> dict:
|
||||
"""
|
||||
Get the status of a long-running task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to check.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
|
||||
Returns:
|
||||
dict: Task status including task_id, status, tool_name, and operation_id.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If task_id is not found or user doesn't have access.
|
||||
"""
|
||||
task = await stream_registry.get_task(task_id)
|
||||
|
||||
if task is None:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
# Validate ownership
|
||||
if user_id and task.user_id and task.user_id != user_id:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
return {
|
||||
"task_id": task.task_id,
|
||||
"session_id": task.session_id,
|
||||
"status": task.status,
|
||||
"tool_name": task.tool_name,
|
||||
"operation_id": task.operation_id,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# ========== External Completion Webhook ==========
|
||||
|
||||
|
||||
@router.post(
|
||||
"/operations/{operation_id}/complete",
|
||||
status_code=200,
|
||||
)
|
||||
async def complete_operation(
|
||||
operation_id: str,
|
||||
request: OperationCompleteRequest,
|
||||
x_api_key: str | None = Header(default=None),
|
||||
) -> dict:
|
||||
"""
|
||||
External completion webhook for long-running operations.
|
||||
|
||||
Called by Agent Generator (or other services) when an operation completes.
|
||||
This triggers the stream registry to publish completion and continue LLM generation.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID to complete.
|
||||
request: Completion payload with success status and result/error.
|
||||
x_api_key: Internal API key for authentication.
|
||||
|
||||
Returns:
|
||||
dict: Status of the completion.
|
||||
|
||||
Raises:
|
||||
HTTPException: If API key is invalid or operation not found.
|
||||
"""
|
||||
# Validate internal API key - reject if not configured or invalid
|
||||
if not config.internal_api_key:
|
||||
logger.error(
|
||||
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Webhook not available: internal API key not configured",
|
||||
)
|
||||
if x_api_key != config.internal_api_key:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
# Find task by operation_id
|
||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Operation {operation_id} not found",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Received completion webhook for operation {operation_id} "
|
||||
f"(task_id={task.task_id}, success={request.success})"
|
||||
)
|
||||
|
||||
if request.success:
|
||||
# Publish result to stream registry
|
||||
from .response_model import StreamToolOutputAvailable
|
||||
|
||||
result_output = request.result if request.result else {"status": "completed"}
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=task.tool_call_id,
|
||||
toolName=task.tool_name,
|
||||
output=(
|
||||
result_output
|
||||
if isinstance(result_output, str)
|
||||
else orjson.dumps(result_output).decode("utf-8")
|
||||
),
|
||||
success=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Update pending operation in database
|
||||
from . import service as svc
|
||||
|
||||
result_str = (
|
||||
request.result
|
||||
if isinstance(request.result, str)
|
||||
else (
|
||||
orjson.dumps(request.result).decode("utf-8")
|
||||
if request.result
|
||||
else '{"status": "completed"}'
|
||||
)
|
||||
)
|
||||
await svc._update_pending_operation(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
result=result_str,
|
||||
)
|
||||
|
||||
# Generate LLM continuation with streaming
|
||||
await svc._generate_llm_continuation_with_streaming(
|
||||
session_id=task.session_id,
|
||||
user_id=task.user_id,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
|
||||
# Mark task as completed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
||||
await svc._mark_operation_completed(task.tool_call_id)
|
||||
else:
|
||||
# Publish error to stream registry
|
||||
from .response_model import StreamError
|
||||
|
||||
error_msg = request.error or "Operation failed"
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText=error_msg),
|
||||
)
|
||||
# Send finish event to end the stream
|
||||
await stream_registry.publish_chunk(task.task_id, StreamFinish())
|
||||
|
||||
# Update pending operation with error
|
||||
from . import service as svc
|
||||
from .tools.models import ErrorResponse
|
||||
|
||||
error_response = ErrorResponse(
|
||||
message=error_msg,
|
||||
error=request.error,
|
||||
)
|
||||
await svc._update_pending_operation(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
result=error_response.model_dump_json(),
|
||||
)
|
||||
|
||||
# Mark task as failed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
await svc._mark_operation_completed(task.tool_call_id)
|
||||
|
||||
return {"status": "ok", "task_id": task.task_id}
|
||||
|
||||
|
||||
# ========== Health Check ==========
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ from backend.util.exceptions import NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from . import db as chat_db
|
||||
from . import stream_registry
|
||||
from .config import ChatConfig
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
@@ -1611,9 +1610,8 @@ async def _yield_tool_call(
|
||||
)
|
||||
return
|
||||
|
||||
# Generate operation ID and task ID
|
||||
# Generate operation ID
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
task_id = str(uuid_module.uuid4())
|
||||
|
||||
# Build a user-friendly message based on tool and arguments
|
||||
if tool_name == "create_agent":
|
||||
@@ -1656,16 +1654,6 @@ async def _yield_tool_call(
|
||||
|
||||
# Wrap session save and task creation in try-except to release lock on failure
|
||||
try:
|
||||
# Create task in stream registry for SSE reconnection support
|
||||
await stream_registry.create_task(
|
||||
task_id=task_id,
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
operation_id=operation_id,
|
||||
)
|
||||
|
||||
# Save assistant message with tool_call FIRST (required by LLM)
|
||||
assistant_message = ChatMessage(
|
||||
role="assistant",
|
||||
@@ -1687,27 +1675,23 @@ async def _yield_tool_call(
|
||||
session.messages.append(pending_message)
|
||||
await upsert_chat_session(session)
|
||||
logger.info(
|
||||
f"Saved pending operation {operation_id} (task_id={task_id}) "
|
||||
f"for tool {tool_name} in session {session.session_id}"
|
||||
f"Saved pending operation {operation_id} for tool {tool_name} "
|
||||
f"in session {session.session_id}"
|
||||
)
|
||||
|
||||
# Store task reference in module-level set to prevent GC before completion
|
||||
bg_task = asyncio.create_task(
|
||||
_execute_long_running_tool_with_streaming(
|
||||
task = asyncio.create_task(
|
||||
_execute_long_running_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=arguments,
|
||||
tool_call_id=tool_call_id,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
)
|
||||
_background_tasks.add(bg_task)
|
||||
bg_task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
# Associate the asyncio task with the stream registry task
|
||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
except Exception as e:
|
||||
# Roll back appended messages to prevent data corruption on subsequent saves
|
||||
if (
|
||||
@@ -1725,11 +1709,6 @@ async def _yield_tool_call(
|
||||
|
||||
# Release the Redis lock since the background task won't be spawned
|
||||
await _mark_operation_completed(tool_call_id)
|
||||
# Mark stream registry task as failed if it was created
|
||||
try:
|
||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(
|
||||
f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
|
||||
)
|
||||
@@ -1743,7 +1722,6 @@ async def _yield_tool_call(
|
||||
message=started_msg,
|
||||
operation_id=operation_id,
|
||||
tool_name=tool_name,
|
||||
task_id=task_id, # Include task_id for SSE reconnection
|
||||
).model_dump_json(),
|
||||
success=True,
|
||||
)
|
||||
@@ -1813,9 +1791,6 @@ async def _execute_long_running_tool(
|
||||
|
||||
This function runs independently of the SSE connection, so the operation
|
||||
survives if the user closes their browser tab.
|
||||
|
||||
NOTE: This is the legacy function without stream registry support.
|
||||
Use _execute_long_running_tool_with_streaming for new implementations.
|
||||
"""
|
||||
try:
|
||||
# Load fresh session (not stale reference)
|
||||
@@ -1859,132 +1834,15 @@ async def _execute_long_running_tool(
|
||||
tool_call_id=tool_call_id,
|
||||
result=error_response.model_dump_json(),
|
||||
)
|
||||
# Generate LLM continuation so user sees explanation even for errors
|
||||
try:
|
||||
await _generate_llm_continuation(session_id=session_id, user_id=user_id)
|
||||
except Exception as llm_err:
|
||||
logger.warning(f"Failed to generate LLM continuation for error: {llm_err}")
|
||||
finally:
|
||||
await _mark_operation_completed(tool_call_id)
|
||||
|
||||
|
||||
async def _execute_long_running_tool_with_streaming(
|
||||
tool_name: str,
|
||||
parameters: dict[str, Any],
|
||||
tool_call_id: str,
|
||||
operation_id: str,
|
||||
task_id: str,
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> None:
|
||||
"""Execute a long-running tool with stream registry support for SSE reconnection.
|
||||
|
||||
This function runs independently of the SSE connection, publishes progress
|
||||
to the stream registry, and survives if the user closes their browser tab.
|
||||
Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming.
|
||||
|
||||
If the external service returns a 202 Accepted (async), this function exits
|
||||
early and lets the RabbitMQ completion consumer handle the rest.
|
||||
"""
|
||||
# Track whether we delegated to async processing - if so, the RabbitMQ
|
||||
# completion consumer will handle cleanup, not us
|
||||
delegated_to_async = False
|
||||
|
||||
try:
|
||||
# Load fresh session (not stale reference)
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
logger.error(f"Session {session_id} not found for background tool")
|
||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||
return
|
||||
|
||||
# Pass operation_id and task_id to the tool for async processing
|
||||
enriched_parameters = {
|
||||
**parameters,
|
||||
"_operation_id": operation_id,
|
||||
"_task_id": task_id,
|
||||
}
|
||||
|
||||
# Execute the actual tool
|
||||
result = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=enriched_parameters,
|
||||
tool_call_id=tool_call_id,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Check if the tool result indicates async processing
|
||||
# (e.g., Agent Generator returned 202 Accepted)
|
||||
try:
|
||||
result_data = orjson.loads(result.output) if result.output else {}
|
||||
if result_data.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Tool {tool_name} delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id}). "
|
||||
f"RabbitMQ completion consumer will handle the rest."
|
||||
)
|
||||
# Don't publish result, don't continue with LLM, and don't cleanup
|
||||
# The RabbitMQ consumer will handle everything when the external
|
||||
# service completes and publishes to the queue
|
||||
delegated_to_async = True
|
||||
return
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
pass # Not JSON or not async - continue normally
|
||||
|
||||
# Publish tool result to stream registry
|
||||
await stream_registry.publish_chunk(task_id, result)
|
||||
|
||||
# Update the pending message with result
|
||||
result_str = (
|
||||
result.output
|
||||
if isinstance(result.output, str)
|
||||
else orjson.dumps(result.output).decode("utf-8")
|
||||
)
|
||||
await _update_pending_operation(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
result=result_str,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Background tool {tool_name} completed for session {session_id} "
|
||||
f"(task_id={task_id})"
|
||||
)
|
||||
|
||||
# Generate LLM continuation and stream chunks to registry
|
||||
await _generate_llm_continuation_with_streaming(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
# Mark task as completed in stream registry
|
||||
await stream_registry.mark_task_completed(task_id, status="completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True)
|
||||
error_response = ErrorResponse(
|
||||
message=f"Tool {tool_name} failed: {str(e)}",
|
||||
)
|
||||
|
||||
# Publish error to stream registry followed by finish event
|
||||
await stream_registry.publish_chunk(
|
||||
task_id,
|
||||
StreamError(errorText=str(e)),
|
||||
)
|
||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||
|
||||
await _update_pending_operation(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
result=error_response.model_dump_json(),
|
||||
)
|
||||
|
||||
# Mark task as failed in stream registry
|
||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||
finally:
|
||||
# Only cleanup if we didn't delegate to async processing
|
||||
# For async path, the RabbitMQ completion consumer handles cleanup
|
||||
if not delegated_to_async:
|
||||
await _mark_operation_completed(tool_call_id)
|
||||
|
||||
|
||||
async def _update_pending_operation(
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
@@ -2111,128 +1969,3 @@ async def _generate_llm_continuation(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def _generate_llm_continuation_with_streaming(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
task_id: str,
|
||||
) -> None:
|
||||
"""Generate an LLM response with streaming to the stream registry.
|
||||
|
||||
This is called by background tasks to continue the conversation
|
||||
after a tool result is saved. Chunks are published to the stream registry
|
||||
so reconnecting clients can receive them.
|
||||
"""
|
||||
import uuid as uuid_module
|
||||
|
||||
try:
|
||||
# Load fresh session from DB (bypass cache to get the updated tool result)
|
||||
await invalidate_session_cache(session_id)
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
logger.error(f"Session {session_id} not found for LLM continuation")
|
||||
return
|
||||
|
||||
# Build system prompt
|
||||
system_prompt, _ = await _build_system_prompt(user_id)
|
||||
|
||||
# Build messages in OpenAI format
|
||||
messages = session.to_openai_messages()
|
||||
if system_prompt:
|
||||
from openai.types.chat import ChatCompletionSystemMessageParam
|
||||
|
||||
system_message = ChatCompletionSystemMessageParam(
|
||||
role="system",
|
||||
content=system_prompt,
|
||||
)
|
||||
messages = [system_message] + messages
|
||||
|
||||
# Build extra_body for tracing
|
||||
extra_body: dict[str, Any] = {
|
||||
"posthogProperties": {
|
||||
"environment": settings.config.app_env.value,
|
||||
},
|
||||
}
|
||||
if user_id:
|
||||
extra_body["user"] = user_id[:128]
|
||||
extra_body["posthogDistinctId"] = user_id
|
||||
if session_id:
|
||||
extra_body["session_id"] = session_id[:128]
|
||||
|
||||
# Make streaming LLM call (no tools - just text response)
|
||||
from typing import cast
|
||||
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
# Generate unique IDs for AI SDK protocol
|
||||
message_id = str(uuid_module.uuid4())
|
||||
text_block_id = str(uuid_module.uuid4())
|
||||
|
||||
# Publish start event
|
||||
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
|
||||
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
|
||||
|
||||
# Stream the response
|
||||
stream = await client.chat.completions.create(
|
||||
model=config.model,
|
||||
messages=cast(list[ChatCompletionMessageParam], messages),
|
||||
extra_body=extra_body,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
assistant_content = ""
|
||||
async for chunk in stream:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
delta = chunk.choices[0].delta.content
|
||||
assistant_content += delta
|
||||
# Publish delta to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task_id,
|
||||
StreamTextDelta(id=text_block_id, delta=delta),
|
||||
)
|
||||
|
||||
# Publish end events
|
||||
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
|
||||
|
||||
if assistant_content:
|
||||
# Reload session from DB to avoid race condition with user messages
|
||||
fresh_session = await get_chat_session(session_id, user_id)
|
||||
if not fresh_session:
|
||||
logger.error(
|
||||
f"Session {session_id} disappeared during LLM continuation"
|
||||
)
|
||||
return
|
||||
|
||||
# Save assistant message to database
|
||||
assistant_message = ChatMessage(
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
)
|
||||
fresh_session.messages.append(assistant_message)
|
||||
|
||||
# Save to database (not cache) to persist the response
|
||||
await upsert_chat_session(fresh_session)
|
||||
|
||||
# Invalidate cache so next poll/refresh gets fresh data
|
||||
await invalidate_session_cache(session_id)
|
||||
|
||||
logger.info(
|
||||
f"Generated streaming LLM continuation for session {session_id} "
|
||||
f"(task_id={task_id}), response length: {len(assistant_content)}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Streaming LLM continuation returned empty response for {session_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to generate streaming LLM continuation: {e}", exc_info=True
|
||||
)
|
||||
# Publish error to stream registry followed by finish event
|
||||
await stream_registry.publish_chunk(
|
||||
task_id,
|
||||
StreamError(errorText=f"Failed to generate response: {e}"),
|
||||
)
|
||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||
|
||||
@@ -1,648 +0,0 @@
|
||||
"""Stream registry for managing reconnectable SSE streams.
|
||||
|
||||
This module provides a registry for tracking active streaming tasks and their
|
||||
messages. It supports:
|
||||
- Creating tasks with unique IDs for long-running operations
|
||||
- Publishing stream messages to both Redis Streams and in-memory queues
|
||||
- Subscribing to tasks with replay of missed messages
|
||||
- Looking up tasks by operation_id for webhook callbacks
|
||||
- Cross-pod real-time delivery via Redis pub/sub
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
import orjson
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from .config import ChatConfig
|
||||
from .response_model import StreamBaseResponse, StreamFinish
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
# Track active pub/sub listeners for cross-pod delivery
|
||||
_pubsub_listeners: dict[str, asyncio.Task] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveTask:
|
||||
"""Represents an active streaming task."""
|
||||
|
||||
task_id: str
|
||||
session_id: str
|
||||
user_id: str | None
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
operation_id: str
|
||||
status: Literal["running", "completed", "failed"] = "running"
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
queue: asyncio.Queue[StreamBaseResponse] = field(default_factory=asyncio.Queue)
|
||||
asyncio_task: asyncio.Task | None = None
|
||||
# Lock for atomic status checks and subscriber management
|
||||
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
# Set of subscriber queues for fan-out
|
||||
subscribers: set[asyncio.Queue[StreamBaseResponse]] = field(default_factory=set)
|
||||
|
||||
|
||||
# Module-level registry for active tasks
|
||||
_active_tasks: dict[str, ActiveTask] = {}
|
||||
|
||||
# Redis key patterns
|
||||
TASK_META_PREFIX = "chat:task:meta:" # Hash for task metadata
|
||||
TASK_STREAM_PREFIX = "chat:stream:" # Redis Stream for messages
|
||||
TASK_OP_PREFIX = "chat:task:op:" # Operation ID -> task_id mapping
|
||||
TASK_PUBSUB_PREFIX = "chat:task:pubsub:" # Pub/sub channel for cross-pod delivery
|
||||
|
||||
|
||||
def _get_task_meta_key(task_id: str) -> str:
|
||||
"""Get Redis key for task metadata."""
|
||||
return f"{TASK_META_PREFIX}{task_id}"
|
||||
|
||||
|
||||
def _get_task_stream_key(task_id: str) -> str:
|
||||
"""Get Redis key for task message stream."""
|
||||
return f"{TASK_STREAM_PREFIX}{task_id}"
|
||||
|
||||
|
||||
def _get_operation_mapping_key(operation_id: str) -> str:
|
||||
"""Get Redis key for operation_id to task_id mapping."""
|
||||
return f"{TASK_OP_PREFIX}{operation_id}"
|
||||
|
||||
|
||||
def _get_task_pubsub_channel(task_id: str) -> str:
|
||||
"""Get Redis pub/sub channel for task cross-pod delivery."""
|
||||
return f"{TASK_PUBSUB_PREFIX}{task_id}"
|
||||
|
||||
|
||||
async def create_task(
|
||||
task_id: str,
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
operation_id: str,
|
||||
) -> ActiveTask:
|
||||
"""Create a new streaming task in memory and Redis.
|
||||
|
||||
Args:
|
||||
task_id: Unique identifier for the task
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (may be None for anonymous)
|
||||
tool_call_id: Tool call ID from the LLM
|
||||
tool_name: Name of the tool being executed
|
||||
operation_id: Operation ID for webhook callbacks
|
||||
|
||||
Returns:
|
||||
The created ActiveTask instance
|
||||
"""
|
||||
task = ActiveTask(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
operation_id=operation_id,
|
||||
)
|
||||
|
||||
# Store in memory registry
|
||||
_active_tasks[task_id] = task
|
||||
|
||||
# Store metadata in Redis for durability
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
op_key = _get_operation_mapping_key(operation_id)
|
||||
|
||||
await redis.hset( # type: ignore[misc]
|
||||
meta_key,
|
||||
mapping={
|
||||
"task_id": task_id,
|
||||
"session_id": session_id,
|
||||
"user_id": user_id or "",
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"operation_id": operation_id,
|
||||
"status": task.status,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
},
|
||||
)
|
||||
await redis.expire(meta_key, config.stream_ttl)
|
||||
|
||||
# Create operation_id -> task_id mapping for webhook lookups
|
||||
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
||||
|
||||
logger.info(
|
||||
f"Created streaming task {task_id} for operation {operation_id} "
|
||||
f"in session {session_id}"
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
async def publish_chunk(
|
||||
task_id: str,
|
||||
chunk: StreamBaseResponse,
|
||||
) -> str:
|
||||
"""Publish a chunk to the task's stream.
|
||||
|
||||
Delivers to in-memory subscribers first (for real-time), then persists to
|
||||
Redis Stream (for replay). This order ensures live subscribers get messages
|
||||
even if Redis temporarily fails.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to publish to
|
||||
chunk: The stream response chunk to publish
|
||||
|
||||
Returns:
|
||||
The Redis Stream message ID (format: "timestamp-sequence"), or "0-0" if
|
||||
Redis persistence failed
|
||||
"""
|
||||
# Deliver to in-memory subscribers FIRST for real-time updates
|
||||
task = _active_tasks.get(task_id)
|
||||
if task:
|
||||
async with task.lock:
|
||||
for subscriber_queue in task.subscribers:
|
||||
try:
|
||||
subscriber_queue.put_nowait(chunk)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
f"Subscriber queue full for task {task_id}, dropping chunk"
|
||||
)
|
||||
|
||||
# Then persist to Redis Stream for replay (with error handling)
|
||||
message_id = "0-0"
|
||||
chunk_json = chunk.model_dump_json()
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
# Add to Redis Stream with auto-generated ID
|
||||
# The ID format is "timestamp-sequence" which gives us ordering
|
||||
raw_id = await redis.xadd(
|
||||
stream_key,
|
||||
{"data": chunk_json},
|
||||
maxlen=config.stream_max_length,
|
||||
)
|
||||
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
||||
|
||||
# Publish to pub/sub for cross-pod real-time delivery
|
||||
pubsub_channel = _get_task_pubsub_channel(task_id)
|
||||
await redis.publish(pubsub_channel, chunk_json)
|
||||
|
||||
logger.debug(f"Published chunk to task {task_id}, message_id={message_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to persist chunk to Redis for task {task_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return message_id
|
||||
|
||||
|
||||
async def subscribe_to_task(
|
||||
task_id: str,
|
||||
user_id: str | None,
|
||||
last_message_id: str = "0-0",
|
||||
) -> asyncio.Queue[StreamBaseResponse] | None:
|
||||
"""Subscribe to a task's stream with replay of missed messages.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to subscribe to
|
||||
user_id: User ID for ownership validation
|
||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
|
||||
|
||||
Returns:
|
||||
An asyncio Queue that will receive stream chunks, or None if task not found
|
||||
or user doesn't have access
|
||||
"""
|
||||
# Check in-memory first
|
||||
task = _active_tasks.get(task_id)
|
||||
|
||||
if task:
|
||||
# Validate ownership
|
||||
if user_id and task.user_id and task.user_id != user_id:
|
||||
logger.warning(
|
||||
f"User {user_id} attempted to subscribe to task {task_id} "
|
||||
f"owned by {task.user_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Create a new queue for this subscriber
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
|
||||
|
||||
# Replay from Redis Stream
|
||||
redis = await get_redis_async()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
# Track the last message ID we've seen for gap detection
|
||||
replay_last_id = last_message_id
|
||||
|
||||
# Read all messages from stream starting after last_message_id
|
||||
# xread returns messages with ID > last_message_id
|
||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||
|
||||
if messages:
|
||||
# messages format: [[stream_name, [(id, {data: json}), ...]]]
|
||||
for _stream_name, stream_messages in messages:
|
||||
for msg_id, msg_data in stream_messages:
|
||||
# Track the last message ID we've processed
|
||||
replay_last_id = (
|
||||
msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
)
|
||||
if b"data" in msg_data:
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data[b"data"])
|
||||
# Reconstruct the appropriate response type
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
await subscriber_queue.put(chunk)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replay message: {e}")
|
||||
|
||||
# Atomically check status and register subscriber under lock
|
||||
# This prevents race condition where task completes between check and subscribe
|
||||
should_start_pubsub = False
|
||||
async with task.lock:
|
||||
if task.status == "running":
|
||||
# Register this subscriber for live updates
|
||||
task.subscribers.add(subscriber_queue)
|
||||
# Start pub/sub listener if this is the first subscriber
|
||||
should_start_pubsub = len(task.subscribers) == 1
|
||||
logger.debug(
|
||||
f"Registered subscriber for task {task_id}, "
|
||||
f"total subscribers: {len(task.subscribers)}"
|
||||
)
|
||||
else:
|
||||
# Task is done, add finish marker
|
||||
await subscriber_queue.put(StreamFinish())
|
||||
|
||||
# After registering, do a second read to catch any messages published
|
||||
# between the first read and registration (closes the race window)
|
||||
if task.status == "running":
|
||||
gap_messages = await redis.xread(
|
||||
{stream_key: replay_last_id}, block=0, count=1000
|
||||
)
|
||||
if gap_messages:
|
||||
for _stream_name, stream_messages in gap_messages:
|
||||
for _msg_id, msg_data in stream_messages:
|
||||
if b"data" in msg_data:
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data[b"data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
await subscriber_queue.put(chunk)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replay gap message: {e}")
|
||||
|
||||
# Start pub/sub listener outside the lock to avoid deadlocks
|
||||
if should_start_pubsub:
|
||||
await start_pubsub_listener(task_id)
|
||||
|
||||
return subscriber_queue
|
||||
|
||||
# Try to load from Redis if not in memory
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
|
||||
if not meta:
|
||||
logger.warning(f"Task {task_id} not found in memory or Redis")
|
||||
return None
|
||||
|
||||
# Validate ownership
|
||||
task_user_id = meta.get(b"user_id", b"").decode() or None
|
||||
if user_id and task_user_id and task_user_id != user_id:
|
||||
logger.warning(
|
||||
f"User {user_id} attempted to subscribe to task {task_id} "
|
||||
f"owned by {task_user_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Replay from Redis Stream only (task is not in memory, so it's completed/crashed)
|
||||
subscriber_queue = asyncio.Queue()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
# Read all messages starting after last_message_id
|
||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||
|
||||
if messages:
|
||||
for _stream_name, stream_messages in messages:
|
||||
for _msg_id, msg_data in stream_messages:
|
||||
if b"data" in msg_data:
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data[b"data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
await subscriber_queue.put(chunk)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replay message: {e}")
|
||||
|
||||
# Add finish marker since task is not active
|
||||
await subscriber_queue.put(StreamFinish())
|
||||
|
||||
return subscriber_queue
|
||||
|
||||
|
||||
async def mark_task_completed(
|
||||
task_id: str,
|
||||
status: Literal["completed", "failed"] = "completed",
|
||||
) -> None:
|
||||
"""Mark a task as completed and publish final event.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to mark as completed
|
||||
status: Final status ("completed" or "failed")
|
||||
"""
|
||||
task = _active_tasks.get(task_id)
|
||||
|
||||
if task:
|
||||
# Acquire lock to prevent new subscribers during completion
|
||||
async with task.lock:
|
||||
task.status = status
|
||||
# Send finish event directly to all current subscribers
|
||||
finish_event = StreamFinish()
|
||||
for subscriber_queue in task.subscribers:
|
||||
try:
|
||||
subscriber_queue.put_nowait(finish_event)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
f"Subscriber queue full for task {task_id} during completion"
|
||||
)
|
||||
# Clear subscribers since task is done
|
||||
task.subscribers.clear()
|
||||
|
||||
# Stop pub/sub listener since task is done
|
||||
await stop_pubsub_listener(task_id)
|
||||
|
||||
# Also publish to Redis Stream for replay (and pub/sub for cross-pod)
|
||||
await publish_chunk(task_id, StreamFinish())
|
||||
|
||||
# Remove from active tasks after a short delay to allow subscribers to finish
|
||||
async def _cleanup():
|
||||
await asyncio.sleep(5)
|
||||
_active_tasks.pop(task_id, None)
|
||||
logger.info(f"Cleaned up task {task_id} from memory")
|
||||
|
||||
asyncio.create_task(_cleanup())
|
||||
|
||||
# Update Redis metadata
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
await redis.hset(meta_key, "status", status) # type: ignore[misc]
|
||||
|
||||
logger.info(f"Marked task {task_id} as {status}")
|
||||
|
||||
|
||||
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
|
||||
"""Find a task by its operation ID.
|
||||
|
||||
Used by webhook callbacks to locate the task to update.
|
||||
|
||||
Args:
|
||||
operation_id: Operation ID to search for
|
||||
|
||||
Returns:
|
||||
ActiveTask if found, None otherwise
|
||||
"""
|
||||
# Check in-memory first
|
||||
for task in _active_tasks.values():
|
||||
if task.operation_id == operation_id:
|
||||
return task
|
||||
|
||||
# Try Redis lookup
|
||||
redis = await get_redis_async()
|
||||
op_key = _get_operation_mapping_key(operation_id)
|
||||
task_id = await redis.get(op_key)
|
||||
|
||||
if task_id:
|
||||
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
|
||||
# Check if task is in memory
|
||||
if task_id_str in _active_tasks:
|
||||
return _active_tasks[task_id_str]
|
||||
|
||||
# Load metadata from Redis
|
||||
meta_key = _get_task_meta_key(task_id_str)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
|
||||
if meta:
|
||||
# Reconstruct task object (not fully active, but has metadata)
|
||||
return ActiveTask(
|
||||
task_id=meta.get(b"task_id", b"").decode(),
|
||||
session_id=meta.get(b"session_id", b"").decode(),
|
||||
user_id=meta.get(b"user_id", b"").decode() or None,
|
||||
tool_call_id=meta.get(b"tool_call_id", b"").decode(),
|
||||
tool_name=meta.get(b"tool_name", b"").decode(),
|
||||
operation_id=operation_id,
|
||||
status=meta.get(b"status", b"running").decode(), # type: ignore
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def get_task(task_id: str) -> ActiveTask | None:
|
||||
"""Get a task by its ID.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to look up
|
||||
|
||||
Returns:
|
||||
ActiveTask if found, None otherwise
|
||||
"""
|
||||
# Check in-memory first
|
||||
if task_id in _active_tasks:
|
||||
return _active_tasks[task_id]
|
||||
|
||||
# Try Redis lookup
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
|
||||
if meta:
|
||||
return ActiveTask(
|
||||
task_id=meta.get(b"task_id", b"").decode(),
|
||||
session_id=meta.get(b"session_id", b"").decode(),
|
||||
user_id=meta.get(b"user_id", b"").decode() or None,
|
||||
tool_call_id=meta.get(b"tool_call_id", b"").decode(),
|
||||
tool_name=meta.get(b"tool_name", b"").decode(),
|
||||
operation_id=meta.get(b"operation_id", b"").decode(),
|
||||
status=meta.get(b"status", b"running").decode(), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
||||
"""Reconstruct a StreamBaseResponse from JSON data.
|
||||
|
||||
Args:
|
||||
chunk_data: Parsed JSON data from Redis
|
||||
|
||||
Returns:
|
||||
Reconstructed response object, or None if unknown type
|
||||
"""
|
||||
from .response_model import (
|
||||
ResponseType,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
|
||||
chunk_type = chunk_data.get("type")
|
||||
|
||||
try:
|
||||
if chunk_type == ResponseType.START.value:
|
||||
return StreamStart(**chunk_data)
|
||||
elif chunk_type == ResponseType.FINISH.value:
|
||||
return StreamFinish(**chunk_data)
|
||||
elif chunk_type == ResponseType.TEXT_START.value:
|
||||
return StreamTextStart(**chunk_data)
|
||||
elif chunk_type == ResponseType.TEXT_DELTA.value:
|
||||
return StreamTextDelta(**chunk_data)
|
||||
elif chunk_type == ResponseType.TEXT_END.value:
|
||||
return StreamTextEnd(**chunk_data)
|
||||
elif chunk_type == ResponseType.TOOL_INPUT_START.value:
|
||||
return StreamToolInputStart(**chunk_data)
|
||||
elif chunk_type == ResponseType.TOOL_INPUT_AVAILABLE.value:
|
||||
return StreamToolInputAvailable(**chunk_data)
|
||||
elif chunk_type == ResponseType.TOOL_OUTPUT_AVAILABLE.value:
|
||||
return StreamToolOutputAvailable(**chunk_data)
|
||||
elif chunk_type == ResponseType.ERROR.value:
|
||||
return StreamError(**chunk_data)
|
||||
elif chunk_type == ResponseType.USAGE.value:
|
||||
return StreamUsage(**chunk_data)
|
||||
elif chunk_type == ResponseType.HEARTBEAT.value:
|
||||
return StreamHeartbeat(**chunk_data)
|
||||
else:
|
||||
logger.warning(f"Unknown chunk type: {chunk_type}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
|
||||
"""Associate an asyncio.Task with an ActiveTask.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
asyncio_task: The asyncio Task to associate
|
||||
"""
|
||||
task = _active_tasks.get(task_id)
|
||||
if task:
|
||||
task.asyncio_task = asyncio_task
|
||||
|
||||
|
||||
async def unsubscribe_from_task(
|
||||
task_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
) -> None:
|
||||
"""Unsubscribe a queue from a task's stream.
|
||||
|
||||
Should be called when a client disconnects to clean up resources.
|
||||
Also stops the pub/sub listener if there are no more local subscribers.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to unsubscribe from
|
||||
subscriber_queue: The queue to remove from subscribers
|
||||
"""
|
||||
task = _active_tasks.get(task_id)
|
||||
if task:
|
||||
async with task.lock:
|
||||
task.subscribers.discard(subscriber_queue)
|
||||
remaining = len(task.subscribers)
|
||||
logger.debug(
|
||||
f"Unsubscribed from task {task_id}, "
|
||||
f"remaining subscribers: {remaining}"
|
||||
)
|
||||
# Stop pub/sub listener if no more local subscribers
|
||||
if remaining == 0:
|
||||
await stop_pubsub_listener(task_id)
|
||||
|
||||
|
||||
async def start_pubsub_listener(task_id: str) -> None:
|
||||
"""Start listening to Redis pub/sub for cross-pod delivery.
|
||||
|
||||
This enables real-time updates when another pod publishes chunks for a task
|
||||
that has local subscribers on this pod.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to listen for
|
||||
"""
|
||||
if task_id in _pubsub_listeners:
|
||||
return # Already listening
|
||||
|
||||
task = _active_tasks.get(task_id)
|
||||
if not task:
|
||||
return
|
||||
|
||||
async def _listener():
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
pubsub = redis.pubsub()
|
||||
channel = _get_task_pubsub_channel(task_id)
|
||||
await pubsub.subscribe(channel)
|
||||
logger.debug(f"Started pub/sub listener for task {task_id}")
|
||||
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] != "message":
|
||||
continue
|
||||
|
||||
try:
|
||||
chunk_data = orjson.loads(message["data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
# Deliver to local subscribers
|
||||
local_task = _active_tasks.get(task_id)
|
||||
if local_task:
|
||||
async with local_task.lock:
|
||||
for queue in local_task.subscribers:
|
||||
try:
|
||||
queue.put_nowait(chunk)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
# Stop listening if this was a finish event
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing pub/sub message: {e}")
|
||||
|
||||
await pubsub.unsubscribe(channel)
|
||||
await pubsub.close()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Pub/sub listener error for task {task_id}: {e}")
|
||||
finally:
|
||||
_pubsub_listeners.pop(task_id, None)
|
||||
logger.debug(f"Stopped pub/sub listener for task {task_id}")
|
||||
|
||||
listener_task = asyncio.create_task(_listener())
|
||||
_pubsub_listeners[task_id] = listener_task
|
||||
|
||||
|
||||
async def stop_pubsub_listener(task_id: str) -> None:
|
||||
"""Stop the pub/sub listener for a task.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to stop listening for
|
||||
"""
|
||||
listener = _pubsub_listeners.pop(task_id, None)
|
||||
if listener and not listener.done():
|
||||
listener.cancel()
|
||||
try:
|
||||
await listener
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.debug(f"Cancelled pub/sub listener for task {task_id}")
|
||||
@@ -2,30 +2,54 @@
|
||||
|
||||
from .core import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
AgentJsonValidationError,
|
||||
AgentSummary,
|
||||
DecompositionResult,
|
||||
DecompositionStep,
|
||||
LibraryAgentSummary,
|
||||
MarketplaceAgentSummary,
|
||||
decompose_goal,
|
||||
enrich_library_agents_from_steps,
|
||||
extract_search_terms_from_steps,
|
||||
extract_uuids_from_text,
|
||||
generate_agent,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_library_agent_by_graph_id,
|
||||
get_library_agent_by_id,
|
||||
get_library_agents_for_generation,
|
||||
json_to_graph,
|
||||
save_agent_to_library,
|
||||
search_marketplace_agents_for_generation,
|
||||
)
|
||||
from .errors import get_user_message_for_error
|
||||
from .service import health_check as check_external_service_health
|
||||
from .service import is_external_service_configured
|
||||
|
||||
__all__ = [
|
||||
# Core functions
|
||||
"AgentGeneratorNotConfiguredError",
|
||||
"AgentJsonValidationError",
|
||||
"AgentSummary",
|
||||
"DecompositionResult",
|
||||
"DecompositionStep",
|
||||
"LibraryAgentSummary",
|
||||
"MarketplaceAgentSummary",
|
||||
"check_external_service_health",
|
||||
"decompose_goal",
|
||||
"enrich_library_agents_from_steps",
|
||||
"extract_search_terms_from_steps",
|
||||
"extract_uuids_from_text",
|
||||
"generate_agent",
|
||||
"generate_agent_patch",
|
||||
"save_agent_to_library",
|
||||
"get_agent_as_json",
|
||||
"json_to_graph",
|
||||
# Exceptions
|
||||
"AgentGeneratorNotConfiguredError",
|
||||
# Service
|
||||
"is_external_service_configured",
|
||||
"check_external_service_health",
|
||||
# Error handling
|
||||
"get_all_relevant_agents_for_generation",
|
||||
"get_library_agent_by_graph_id",
|
||||
"get_library_agent_by_id",
|
||||
"get_library_agents_for_generation",
|
||||
"get_user_message_for_error",
|
||||
"is_external_service_configured",
|
||||
"json_to_graph",
|
||||
"save_agent_to_library",
|
||||
"search_marketplace_agents_for_generation",
|
||||
]
|
||||
|
||||
@@ -1,11 +1,21 @@
|
||||
"""Core agent generation functions."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data.graph import (
|
||||
Graph,
|
||||
Link,
|
||||
Node,
|
||||
create_graph,
|
||||
get_graph,
|
||||
get_graph_all_versions,
|
||||
)
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .service import (
|
||||
decompose_goal_external,
|
||||
@@ -16,6 +26,74 @@ from .service import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
|
||||
|
||||
|
||||
class ExecutionSummary(TypedDict):
|
||||
"""Summary of a single execution for quality assessment."""
|
||||
|
||||
status: str
|
||||
correctness_score: NotRequired[float]
|
||||
activity_summary: NotRequired[str]
|
||||
|
||||
|
||||
class LibraryAgentSummary(TypedDict):
|
||||
"""Summary of a library agent for sub-agent composition.
|
||||
|
||||
Includes recent executions to help the LLM decide whether to use this agent.
|
||||
Each execution shows status, correctness_score (0-1), and activity_summary.
|
||||
"""
|
||||
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
name: str
|
||||
description: str
|
||||
input_schema: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
recent_executions: NotRequired[list[ExecutionSummary]]
|
||||
|
||||
|
||||
class MarketplaceAgentSummary(TypedDict):
|
||||
"""Summary of a marketplace agent for sub-agent composition."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
sub_heading: str
|
||||
creator: str
|
||||
is_marketplace_agent: bool
|
||||
|
||||
|
||||
class DecompositionStep(TypedDict, total=False):
|
||||
"""A single step in decomposed instructions."""
|
||||
|
||||
description: str
|
||||
action: str
|
||||
block_name: str
|
||||
tool: str
|
||||
name: str
|
||||
|
||||
|
||||
class DecompositionResult(TypedDict, total=False):
|
||||
"""Result from decompose_goal - can be instructions, questions, or error."""
|
||||
|
||||
type: str
|
||||
steps: list[DecompositionStep]
|
||||
questions: list[dict[str, Any]]
|
||||
error: str
|
||||
error_type: str
|
||||
|
||||
|
||||
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
|
||||
|
||||
|
||||
def _to_dict_list(
|
||||
agents: list[AgentSummary] | list[dict[str, Any]] | None,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Convert typed agent summaries to plain dicts for external service calls."""
|
||||
if agents is None:
|
||||
return None
|
||||
return [dict(a) for a in agents]
|
||||
|
||||
|
||||
class AgentGeneratorNotConfiguredError(Exception):
|
||||
"""Raised when the external Agent Generator service is not configured."""
|
||||
@@ -36,15 +114,414 @@ def _check_service_configured() -> None:
|
||||
)
|
||||
|
||||
|
||||
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
|
||||
_UUID_PATTERN = re.compile(
|
||||
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def extract_uuids_from_text(text: str) -> list[str]:
|
||||
"""Extract all UUID v4 strings from text.
|
||||
|
||||
Args:
|
||||
text: Text that may contain UUIDs (e.g., user's goal description)
|
||||
|
||||
Returns:
|
||||
List of unique UUIDs found in the text (lowercase)
|
||||
"""
|
||||
matches = _UUID_PATTERN.findall(text)
|
||||
return list({m.lower() for m in matches})
|
||||
|
||||
|
||||
async def get_library_agent_by_id(
|
||||
user_id: str, agent_id: str
|
||||
) -> LibraryAgentSummary | None:
|
||||
"""Fetch a specific library agent by its ID (library agent ID or graph_id).
|
||||
|
||||
This function tries multiple lookup strategies:
|
||||
1. First tries to find by graph_id (AgentGraph primary key)
|
||||
2. If not found, tries to find by library agent ID (LibraryAgent primary key)
|
||||
|
||||
This handles both cases:
|
||||
- User provides graph_id (e.g., from AgentExecutorBlock)
|
||||
- User provides library agent ID (e.g., from library URL)
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
agent_id: The ID to look up (can be graph_id or library agent ID)
|
||||
|
||||
Returns:
|
||||
LibraryAgentSummary if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||
return LibraryAgentSummary(
|
||||
graph_id=agent.graph_id,
|
||||
graph_version=agent.graph_version,
|
||||
name=agent.name,
|
||||
description=agent.description,
|
||||
input_schema=agent.input_schema,
|
||||
output_schema=agent.output_schema,
|
||||
)
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
|
||||
|
||||
try:
|
||||
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||
return LibraryAgentSummary(
|
||||
graph_id=agent.graph_id,
|
||||
graph_version=agent.graph_version,
|
||||
name=agent.name,
|
||||
description=agent.description,
|
||||
input_schema=agent.input_schema,
|
||||
output_schema=agent.output_schema,
|
||||
)
|
||||
except NotFoundError:
|
||||
logger.debug(f"Library agent not found by library_id: {agent_id}")
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch library agent by library_id {agent_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
get_library_agent_by_graph_id = get_library_agent_by_id
|
||||
|
||||
|
||||
async def get_library_agents_for_generation(
|
||||
user_id: str,
|
||||
search_query: str | None = None,
|
||||
exclude_graph_id: str | None = None,
|
||||
max_results: int = 15,
|
||||
) -> list[LibraryAgentSummary]:
|
||||
"""Fetch user's library agents formatted for Agent Generator.
|
||||
|
||||
Uses search-based fetching to return relevant agents instead of all agents.
|
||||
This is more scalable for users with large libraries.
|
||||
|
||||
Includes recent_executions list to help the LLM assess agent quality:
|
||||
- Each execution has status, correctness_score (0-1), and activity_summary
|
||||
- This gives the LLM concrete examples of recent performance
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
search_query: Optional search term to find relevant agents (user's goal/description)
|
||||
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
|
||||
max_results: Maximum number of agents to return (default 15)
|
||||
|
||||
Returns:
|
||||
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
|
||||
"""
|
||||
try:
|
||||
response = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=search_query,
|
||||
page=1,
|
||||
page_size=max_results,
|
||||
include_executions=True,
|
||||
)
|
||||
|
||||
results: list[LibraryAgentSummary] = []
|
||||
for agent in response.agents:
|
||||
if exclude_graph_id is not None and agent.graph_id == exclude_graph_id:
|
||||
continue
|
||||
|
||||
summary = LibraryAgentSummary(
|
||||
graph_id=agent.graph_id,
|
||||
graph_version=agent.graph_version,
|
||||
name=agent.name,
|
||||
description=agent.description,
|
||||
input_schema=agent.input_schema,
|
||||
output_schema=agent.output_schema,
|
||||
)
|
||||
if agent.recent_executions:
|
||||
exec_summaries: list[ExecutionSummary] = []
|
||||
for ex in agent.recent_executions:
|
||||
exec_sum = ExecutionSummary(status=ex.status)
|
||||
if ex.correctness_score is not None:
|
||||
exec_sum["correctness_score"] = ex.correctness_score
|
||||
if ex.activity_summary:
|
||||
exec_sum["activity_summary"] = ex.activity_summary
|
||||
exec_summaries.append(exec_sum)
|
||||
summary["recent_executions"] = exec_summaries
|
||||
results.append(summary)
|
||||
return results
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def search_marketplace_agents_for_generation(
|
||||
search_query: str,
|
||||
max_results: int = 10,
|
||||
) -> list[MarketplaceAgentSummary]:
|
||||
"""Search marketplace agents formatted for Agent Generator.
|
||||
|
||||
Note: This returns basic agent info. Full input/output schemas would require
|
||||
additional graph fetches and is a potential future enhancement.
|
||||
|
||||
Args:
|
||||
search_query: Search term to find relevant public agents
|
||||
max_results: Maximum number of agents to return (default 10)
|
||||
|
||||
Returns:
|
||||
List of MarketplaceAgentSummary (without detailed schemas for now)
|
||||
"""
|
||||
try:
|
||||
response = await store_db.get_store_agents(
|
||||
search_query=search_query,
|
||||
page=1,
|
||||
page_size=max_results,
|
||||
)
|
||||
|
||||
results: list[MarketplaceAgentSummary] = []
|
||||
for agent in response.agents:
|
||||
results.append(
|
||||
MarketplaceAgentSummary(
|
||||
name=agent.agent_name,
|
||||
description=agent.description,
|
||||
sub_heading=agent.sub_heading,
|
||||
creator=agent.creator,
|
||||
is_marketplace_agent=True,
|
||||
)
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to search marketplace agents: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_all_relevant_agents_for_generation(
|
||||
user_id: str,
|
||||
search_query: str | None = None,
|
||||
exclude_graph_id: str | None = None,
|
||||
include_library: bool = True,
|
||||
include_marketplace: bool = True,
|
||||
max_library_results: int = 15,
|
||||
max_marketplace_results: int = 10,
|
||||
) -> list[AgentSummary]:
|
||||
"""Fetch relevant agents from library and/or marketplace.
|
||||
|
||||
Searches both user's library and marketplace by default.
|
||||
Explicitly mentioned UUIDs in the search query are always looked up.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
search_query: Search term to find relevant agents (user's goal/description)
|
||||
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
|
||||
include_library: Whether to search user's library (default True)
|
||||
include_marketplace: Whether to also search marketplace (default True)
|
||||
max_library_results: Max library agents to return (default 15)
|
||||
max_marketplace_results: Max marketplace agents to return (default 10)
|
||||
|
||||
Returns:
|
||||
List of AgentSummary, library agents first (with full schemas),
|
||||
then marketplace agents (basic info only)
|
||||
"""
|
||||
agents: list[AgentSummary] = []
|
||||
seen_graph_ids: set[str] = set()
|
||||
|
||||
if search_query:
|
||||
mentioned_uuids = extract_uuids_from_text(search_query)
|
||||
for graph_id in mentioned_uuids:
|
||||
if graph_id == exclude_graph_id:
|
||||
continue
|
||||
agent = await get_library_agent_by_graph_id(user_id, graph_id)
|
||||
agent_graph_id = agent.get("graph_id") if agent else None
|
||||
if agent and agent_graph_id and agent_graph_id not in seen_graph_ids:
|
||||
agents.append(agent)
|
||||
seen_graph_ids.add(agent_graph_id)
|
||||
logger.debug(
|
||||
f"Found explicitly mentioned agent: {agent.get('name') or 'Unknown'}"
|
||||
)
|
||||
|
||||
if include_library:
|
||||
library_agents = await get_library_agents_for_generation(
|
||||
user_id=user_id,
|
||||
search_query=search_query,
|
||||
exclude_graph_id=exclude_graph_id,
|
||||
max_results=max_library_results,
|
||||
)
|
||||
for agent in library_agents:
|
||||
graph_id = agent.get("graph_id")
|
||||
if graph_id and graph_id not in seen_graph_ids:
|
||||
agents.append(agent)
|
||||
seen_graph_ids.add(graph_id)
|
||||
|
||||
if include_marketplace and search_query:
|
||||
marketplace_agents = await search_marketplace_agents_for_generation(
|
||||
search_query=search_query,
|
||||
max_results=max_marketplace_results,
|
||||
)
|
||||
library_names: set[str] = set()
|
||||
for a in agents:
|
||||
name = a.get("name")
|
||||
if name and isinstance(name, str):
|
||||
library_names.add(name.lower())
|
||||
for agent in marketplace_agents:
|
||||
agent_name = agent.get("name")
|
||||
if agent_name and isinstance(agent_name, str):
|
||||
if agent_name.lower() not in library_names:
|
||||
agents.append(agent)
|
||||
|
||||
return agents
|
||||
|
||||
|
||||
def extract_search_terms_from_steps(
|
||||
decomposition_result: DecompositionResult | dict[str, Any],
|
||||
) -> list[str]:
|
||||
"""Extract search terms from decomposed instruction steps.
|
||||
|
||||
Analyzes the decomposition result to extract relevant keywords
|
||||
for additional library agent searches.
|
||||
|
||||
Args:
|
||||
decomposition_result: Result from decompose_goal containing steps
|
||||
|
||||
Returns:
|
||||
List of unique search terms extracted from steps
|
||||
"""
|
||||
search_terms: list[str] = []
|
||||
|
||||
if decomposition_result.get("type") != "instructions":
|
||||
return search_terms
|
||||
|
||||
steps = decomposition_result.get("steps", [])
|
||||
if not steps:
|
||||
return search_terms
|
||||
|
||||
step_keys: list[str] = ["description", "action", "block_name", "tool", "name"]
|
||||
|
||||
for step in steps:
|
||||
for key in step_keys:
|
||||
value = step.get(key) # type: ignore[union-attr]
|
||||
if isinstance(value, str) and len(value) > 3:
|
||||
search_terms.append(value)
|
||||
|
||||
seen: set[str] = set()
|
||||
unique_terms: list[str] = []
|
||||
for term in search_terms:
|
||||
term_lower = term.lower()
|
||||
if term_lower not in seen:
|
||||
seen.add(term_lower)
|
||||
unique_terms.append(term)
|
||||
|
||||
return unique_terms
|
||||
|
||||
|
||||
async def enrich_library_agents_from_steps(
|
||||
user_id: str,
|
||||
decomposition_result: DecompositionResult | dict[str, Any],
|
||||
existing_agents: list[AgentSummary] | list[dict[str, Any]],
|
||||
exclude_graph_id: str | None = None,
|
||||
include_marketplace: bool = True,
|
||||
max_additional_results: int = 10,
|
||||
) -> list[AgentSummary] | list[dict[str, Any]]:
|
||||
"""Enrich library agents list with additional searches based on decomposed steps.
|
||||
|
||||
This implements two-phase search: after decomposition, we search for additional
|
||||
relevant agents based on the specific steps identified.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
decomposition_result: Result from decompose_goal containing steps
|
||||
existing_agents: Already fetched library agents from initial search
|
||||
exclude_graph_id: Optional graph ID to exclude
|
||||
include_marketplace: Whether to also search marketplace
|
||||
max_additional_results: Max additional agents per search term (default 10)
|
||||
|
||||
Returns:
|
||||
Combined list of library agents (existing + newly discovered)
|
||||
"""
|
||||
search_terms = extract_search_terms_from_steps(decomposition_result)
|
||||
|
||||
if not search_terms:
|
||||
return existing_agents
|
||||
|
||||
existing_ids: set[str] = set()
|
||||
existing_names: set[str] = set()
|
||||
|
||||
for agent in existing_agents:
|
||||
agent_name = agent.get("name")
|
||||
if agent_name and isinstance(agent_name, str):
|
||||
existing_names.add(agent_name.lower())
|
||||
graph_id = agent.get("graph_id") # type: ignore[call-overload]
|
||||
if graph_id and isinstance(graph_id, str):
|
||||
existing_ids.add(graph_id)
|
||||
|
||||
all_agents: list[AgentSummary] | list[dict[str, Any]] = list(existing_agents)
|
||||
|
||||
for term in search_terms[:3]:
|
||||
try:
|
||||
additional_agents = await get_all_relevant_agents_for_generation(
|
||||
user_id=user_id,
|
||||
search_query=term,
|
||||
exclude_graph_id=exclude_graph_id,
|
||||
include_marketplace=include_marketplace,
|
||||
max_library_results=max_additional_results,
|
||||
max_marketplace_results=5,
|
||||
)
|
||||
|
||||
for agent in additional_agents:
|
||||
agent_name = agent.get("name")
|
||||
if not agent_name or not isinstance(agent_name, str):
|
||||
continue
|
||||
agent_name_lower = agent_name.lower()
|
||||
|
||||
if agent_name_lower in existing_names:
|
||||
continue
|
||||
|
||||
graph_id = agent.get("graph_id") # type: ignore[call-overload]
|
||||
if graph_id and graph_id in existing_ids:
|
||||
continue
|
||||
|
||||
all_agents.append(agent)
|
||||
existing_names.add(agent_name_lower)
|
||||
if graph_id and isinstance(graph_id, str):
|
||||
existing_ids.add(graph_id)
|
||||
|
||||
except DatabaseError:
|
||||
logger.error(f"Database error searching for agents with term '{term}'")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to search for additional agents with term '{term}': {e}"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Enriched library agents: {len(existing_agents)} initial + "
|
||||
f"{len(all_agents) - len(existing_agents)} additional = {len(all_agents)} total"
|
||||
)
|
||||
|
||||
return all_agents
|
||||
|
||||
|
||||
async def decompose_goal(
|
||||
description: str,
|
||||
context: str = "",
|
||||
library_agents: list[AgentSummary] | None = None,
|
||||
) -> DecompositionResult | None:
|
||||
"""Break down a goal into steps or return clarifying questions.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Dict with either:
|
||||
DecompositionResult with either:
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
Or None on error
|
||||
@@ -54,40 +531,36 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for decompose_goal")
|
||||
return await decompose_goal_external(description, context)
|
||||
result = await decompose_goal_external(
|
||||
description, context, _to_dict_list(library_agents)
|
||||
)
|
||||
return result # type: ignore[return-value]
|
||||
|
||||
|
||||
async def generate_agent(
|
||||
instructions: dict[str, Any],
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
instructions: DecompositionResult | dict[str, Any],
|
||||
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Generate agent JSON from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
operation_id: Operation ID for async processing (enables RabbitMQ callback)
|
||||
task_id: Task ID for async processing (enables RabbitMQ callback)
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||
Agent JSON dict, error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent")
|
||||
result = await generate_agent_external(instructions, operation_id, task_id)
|
||||
|
||||
# Don't modify async response
|
||||
if result and result.get("status") == "accepted":
|
||||
return result
|
||||
|
||||
result = await generate_agent_external(
|
||||
dict(instructions), _to_dict_list(library_agents)
|
||||
)
|
||||
if result:
|
||||
# Check if it's an error response - pass through as-is
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
return result
|
||||
# Ensure required fields for successful agent generation
|
||||
if "id" not in result:
|
||||
result["id"] = str(uuid.uuid4())
|
||||
if "version" not in result:
|
||||
@@ -97,6 +570,12 @@ async def generate_agent(
|
||||
return result
|
||||
|
||||
|
||||
class AgentJsonValidationError(Exception):
|
||||
"""Raised when agent JSON is invalid or missing required fields."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
"""Convert agent JSON dict to Graph model.
|
||||
|
||||
@@ -105,25 +584,55 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
|
||||
Returns:
|
||||
Graph ready for saving
|
||||
|
||||
Raises:
|
||||
AgentJsonValidationError: If required fields are missing from nodes or links
|
||||
"""
|
||||
nodes = []
|
||||
for n in agent_json.get("nodes", []):
|
||||
for idx, n in enumerate(agent_json.get("nodes", [])):
|
||||
block_id = n.get("block_id")
|
||||
if not block_id:
|
||||
node_id = n.get("id", f"index_{idx}")
|
||||
raise AgentJsonValidationError(
|
||||
f"Node '{node_id}' is missing required field 'block_id'"
|
||||
)
|
||||
node = Node(
|
||||
id=n.get("id", str(uuid.uuid4())),
|
||||
block_id=n["block_id"],
|
||||
block_id=block_id,
|
||||
input_default=n.get("input_default", {}),
|
||||
metadata=n.get("metadata", {}),
|
||||
)
|
||||
nodes.append(node)
|
||||
|
||||
links = []
|
||||
for link_data in agent_json.get("links", []):
|
||||
for idx, link_data in enumerate(agent_json.get("links", [])):
|
||||
source_id = link_data.get("source_id")
|
||||
sink_id = link_data.get("sink_id")
|
||||
source_name = link_data.get("source_name")
|
||||
sink_name = link_data.get("sink_name")
|
||||
|
||||
missing_fields = []
|
||||
if not source_id:
|
||||
missing_fields.append("source_id")
|
||||
if not sink_id:
|
||||
missing_fields.append("sink_id")
|
||||
if not source_name:
|
||||
missing_fields.append("source_name")
|
||||
if not sink_name:
|
||||
missing_fields.append("sink_name")
|
||||
|
||||
if missing_fields:
|
||||
link_id = link_data.get("id", f"index_{idx}")
|
||||
raise AgentJsonValidationError(
|
||||
f"Link '{link_id}' is missing required fields: {', '.join(missing_fields)}"
|
||||
)
|
||||
|
||||
link = Link(
|
||||
id=link_data.get("id", str(uuid.uuid4())),
|
||||
source_id=link_data["source_id"],
|
||||
sink_id=link_data["sink_id"],
|
||||
source_name=link_data["source_name"],
|
||||
sink_name=link_data["sink_name"],
|
||||
source_id=source_id,
|
||||
sink_id=sink_id,
|
||||
source_name=source_name,
|
||||
sink_name=sink_name,
|
||||
is_static=link_data.get("is_static", False),
|
||||
)
|
||||
links.append(link)
|
||||
@@ -144,22 +653,40 @@ def _reassign_node_ids(graph: Graph) -> None:
|
||||
|
||||
This is needed when creating a new version to avoid unique constraint violations.
|
||||
"""
|
||||
# Create mapping from old node IDs to new UUIDs
|
||||
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
||||
|
||||
# Reassign node IDs
|
||||
for node in graph.nodes:
|
||||
node.id = id_map[node.id]
|
||||
|
||||
# Update link references to use new node IDs
|
||||
for link in graph.links:
|
||||
link.id = str(uuid.uuid4()) # Also give links new IDs
|
||||
link.id = str(uuid.uuid4())
|
||||
if link.source_id in id_map:
|
||||
link.source_id = id_map[link.source_id]
|
||||
if link.sink_id in id_map:
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
|
||||
|
||||
def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
|
||||
"""Populate user_id in AgentExecutorBlock nodes.
|
||||
|
||||
The external agent generator creates AgentExecutorBlock nodes with empty user_id.
|
||||
This function fills in the actual user_id so sub-agents run with correct permissions.
|
||||
|
||||
Args:
|
||||
agent_json: Agent JSON dict (modified in place)
|
||||
user_id: User ID to set
|
||||
"""
|
||||
for node in agent_json.get("nodes", []):
|
||||
if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
|
||||
input_default = node.get("input_default") or {}
|
||||
if not input_default.get("user_id"):
|
||||
input_default["user_id"] = user_id
|
||||
node["input_default"] = input_default
|
||||
logger.debug(
|
||||
f"Set user_id for AgentExecutorBlock node {node.get('id')}"
|
||||
)
|
||||
|
||||
|
||||
async def save_agent_to_library(
|
||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||
) -> tuple[Graph, Any]:
|
||||
@@ -173,33 +700,27 @@ async def save_agent_to_library(
|
||||
Returns:
|
||||
Tuple of (created Graph, LibraryAgent)
|
||||
"""
|
||||
from backend.data.graph import get_graph_all_versions
|
||||
# Populate user_id in AgentExecutorBlock nodes before conversion
|
||||
_populate_agent_executor_user_ids(agent_json, user_id)
|
||||
|
||||
graph = json_to_graph(agent_json)
|
||||
|
||||
if is_update:
|
||||
# For updates, keep the same graph ID but increment version
|
||||
# and reassign node/link IDs to avoid conflicts
|
||||
if graph.id:
|
||||
existing_versions = await get_graph_all_versions(graph.id, user_id)
|
||||
if existing_versions:
|
||||
latest_version = max(v.version for v in existing_versions)
|
||||
graph.version = latest_version + 1
|
||||
# Reassign node IDs (but keep graph ID the same)
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Updating agent {graph.id} to version {graph.version}")
|
||||
else:
|
||||
# For new agents, always generate a fresh UUID to avoid collisions
|
||||
graph.id = str(uuid.uuid4())
|
||||
graph.version = 1
|
||||
# Reassign all node IDs as well
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Creating new agent with ID {graph.id}")
|
||||
|
||||
# Save to database
|
||||
created_graph = await create_graph(graph, user_id)
|
||||
|
||||
# Add to user's library (or update existing library agent)
|
||||
library_agents = await library_db.create_library_agent(
|
||||
graph=created_graph,
|
||||
user_id=user_id,
|
||||
@@ -211,25 +732,31 @@ async def save_agent_to_library(
|
||||
|
||||
|
||||
async def get_agent_as_json(
|
||||
graph_id: str, user_id: str | None
|
||||
agent_id: str, user_id: str | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch an agent and convert to JSON format for editing.
|
||||
|
||||
Args:
|
||||
graph_id: Graph ID or library agent ID
|
||||
agent_id: Graph ID or library agent ID
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Agent as JSON dict or None if not found
|
||||
"""
|
||||
from backend.data.graph import get_graph
|
||||
graph = await get_graph(agent_id, version=None, user_id=user_id)
|
||||
|
||||
if not graph and user_id:
|
||||
try:
|
||||
library_agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id, version=None, user_id=user_id
|
||||
)
|
||||
except NotFoundError:
|
||||
pass
|
||||
|
||||
# Try to get the graph (version=None gets the active version)
|
||||
graph = await get_graph(graph_id, version=None, user_id=user_id)
|
||||
if not graph:
|
||||
return None
|
||||
|
||||
# Convert to JSON format
|
||||
nodes = []
|
||||
for node in graph.nodes:
|
||||
nodes.append(
|
||||
@@ -269,8 +796,7 @@ async def get_agent_as_json(
|
||||
async def generate_agent_patch(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
library_agents: list[AgentSummary] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Update an existing agent using natural language.
|
||||
|
||||
@@ -282,12 +808,11 @@ async def generate_agent_patch(
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
operation_id: Operation ID for async processing (enables RabbitMQ callback)
|
||||
task_id: Task ID for async processing (enables RabbitMQ callback)
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||
error dict {"type": "error", ...}, or None on unexpected error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
@@ -295,5 +820,5 @@ async def generate_agent_patch(
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||
return await generate_agent_patch_external(
|
||||
update_request, current_agent, operation_id, task_id
|
||||
update_request, current_agent, _to_dict_list(library_agents)
|
||||
)
|
||||
|
||||
@@ -1,11 +1,43 @@
|
||||
"""Error handling utilities for agent generator."""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def _sanitize_error_details(details: str) -> str:
|
||||
"""Sanitize error details to remove sensitive information.
|
||||
|
||||
Strips common patterns that could expose internal system info:
|
||||
- File paths (Unix and Windows)
|
||||
- Database connection strings
|
||||
- URLs with credentials
|
||||
- Stack trace internals
|
||||
|
||||
Args:
|
||||
details: Raw error details string
|
||||
|
||||
Returns:
|
||||
Sanitized error details safe for user display
|
||||
"""
|
||||
sanitized = re.sub(
|
||||
r"/[a-zA-Z0-9_./\-]+\.(py|js|ts|json|yaml|yml)", "[path]", details
|
||||
)
|
||||
sanitized = re.sub(r"[A-Z]:\\[a-zA-Z0-9_\\.\\-]+", "[path]", sanitized)
|
||||
sanitized = re.sub(
|
||||
r"(postgres|mysql|mongodb|redis)://[^\s]+", "[database_url]", sanitized
|
||||
)
|
||||
sanitized = re.sub(r"https?://[^:]+:[^@]+@[^\s]+", "[url]", sanitized)
|
||||
sanitized = re.sub(r", line \d+", "", sanitized)
|
||||
sanitized = re.sub(r'File "[^"]+",?', "", sanitized)
|
||||
|
||||
return sanitized.strip()
|
||||
|
||||
|
||||
def get_user_message_for_error(
|
||||
error_type: str,
|
||||
operation: str = "process the request",
|
||||
llm_parse_message: str | None = None,
|
||||
validation_message: str | None = None,
|
||||
error_details: str | None = None,
|
||||
) -> str:
|
||||
"""Get a user-friendly error message based on error type.
|
||||
|
||||
@@ -19,25 +51,45 @@ def get_user_message_for_error(
|
||||
message (e.g., "analyze the goal", "generate the agent")
|
||||
llm_parse_message: Custom message for llm_parse_error type
|
||||
validation_message: Custom message for validation_error type
|
||||
error_details: Optional additional details about the error
|
||||
|
||||
Returns:
|
||||
User-friendly error message suitable for display to the user
|
||||
"""
|
||||
base_message = ""
|
||||
|
||||
if error_type == "llm_parse_error":
|
||||
return (
|
||||
base_message = (
|
||||
llm_parse_message
|
||||
or "The AI had trouble processing this request. Please try again."
|
||||
)
|
||||
elif error_type == "validation_error":
|
||||
return (
|
||||
base_message = (
|
||||
validation_message
|
||||
or "The request failed validation. Please try rephrasing."
|
||||
or "The generated agent failed validation. "
|
||||
"This usually happens when the agent structure doesn't match "
|
||||
"what the platform expects. Please try simplifying your goal "
|
||||
"or breaking it into smaller parts."
|
||||
)
|
||||
elif error_type == "patch_error":
|
||||
return "Failed to apply the changes. Please try a different approach."
|
||||
base_message = (
|
||||
"Failed to apply the changes. The modification couldn't be "
|
||||
"validated. Please try a different approach or simplify the change."
|
||||
)
|
||||
elif error_type in ("timeout", "llm_timeout"):
|
||||
return "The request took too long. Please try again."
|
||||
base_message = (
|
||||
"The request took too long to process. This can happen with "
|
||||
"complex agents. Please try again or simplify your goal."
|
||||
)
|
||||
elif error_type in ("rate_limit", "llm_rate_limit"):
|
||||
return "The service is currently busy. Please try again in a moment."
|
||||
base_message = "The service is currently busy. Please try again in a moment."
|
||||
else:
|
||||
return f"Failed to {operation}. Please try again."
|
||||
base_message = f"Failed to {operation}. Please try again."
|
||||
|
||||
if error_details:
|
||||
details = _sanitize_error_details(error_details)
|
||||
if len(details) > 200:
|
||||
details = details[:200] + "..."
|
||||
base_message += f"\n\nTechnical details: {details}"
|
||||
|
||||
return base_message
|
||||
|
||||
@@ -117,13 +117,16 @@ def _get_client() -> httpx.AsyncClient:
|
||||
|
||||
|
||||
async def decompose_goal_external(
|
||||
description: str, context: str = ""
|
||||
description: str,
|
||||
context: str = "",
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to decompose a goal.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Dict with either:
|
||||
@@ -141,6 +144,8 @@ async def decompose_goal_external(
|
||||
if context:
|
||||
# The external service uses user_instruction for additional context
|
||||
payload["user_instruction"] = context
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
|
||||
try:
|
||||
response = await client.post("/api/decompose-description", json=payload)
|
||||
@@ -207,42 +212,25 @@ async def decompose_goal_external(
|
||||
|
||||
async def generate_agent_external(
|
||||
instructions: dict[str, Any],
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate an agent from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
operation_id: Operation ID for async processing (enables RabbitMQ callback)
|
||||
task_id: Task ID for async processing (enables RabbitMQ callback)
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
|
||||
Agent JSON dict on success, or error dict {"type": "error", ...} on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {"instructions": instructions}
|
||||
if operation_id and task_id:
|
||||
payload["operation_id"] = operation_id
|
||||
payload["task_id"] = task_id
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
|
||||
try:
|
||||
response = await client.post("/api/generate-agent", json=payload)
|
||||
|
||||
# Handle 202 Accepted for async processing
|
||||
if response.status_code == 202:
|
||||
logger.info(
|
||||
f"Agent Generator accepted async request "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return {
|
||||
"status": "accepted",
|
||||
"operation_id": operation_id,
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
@@ -250,8 +238,7 @@ async def generate_agent_external(
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator generation failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
@@ -274,46 +261,29 @@ async def generate_agent_external(
|
||||
async def generate_agent_patch_external(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate a patch for an existing agent.
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
operation_id: Operation ID for async processing (enables RabbitMQ callback)
|
||||
task_id: Task ID for async processing (enables RabbitMQ callback)
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
||||
Updated agent JSON, clarifying questions dict, or error dict on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {
|
||||
"update_request": update_request,
|
||||
"current_agent_json": current_agent,
|
||||
}
|
||||
if operation_id and task_id:
|
||||
payload["operation_id"] = operation_id
|
||||
payload["task_id"] = task_id
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
|
||||
try:
|
||||
response = await client.post("/api/update-agent", json=payload)
|
||||
|
||||
# Handle 202 Accepted for async processing
|
||||
if response.status_code == 202:
|
||||
logger.info(
|
||||
f"Agent Generator accepted async update request "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return {
|
||||
"status": "accepted",
|
||||
"operation_id": operation_id,
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Shared agent search functionality for find_agent and find_library_agent tools."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
@@ -19,6 +20,85 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
SearchSource = Literal["marketplace", "library"]
|
||||
|
||||
_UUID_PATTERN = re.compile(
|
||||
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(text: str) -> bool:
|
||||
"""Check if text is a valid UUID v4."""
|
||||
return bool(_UUID_PATTERN.match(text.strip()))
|
||||
|
||||
|
||||
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
|
||||
"""Fetch a library agent by ID (library agent ID or graph_id).
|
||||
|
||||
Tries multiple lookup strategies:
|
||||
1. First by graph_id (AgentGraph primary key)
|
||||
2. Then by library agent ID (LibraryAgent primary key)
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
agent_id: The ID to look up (can be graph_id or library agent ID)
|
||||
|
||||
Returns:
|
||||
AgentInfo if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||
return AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
)
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch library agent by graph_id {agent_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||
return AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
)
|
||||
except NotFoundError:
|
||||
logger.debug(f"Library agent not found by library_id: {agent_id}")
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch library agent by library_id {agent_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def search_agents(
|
||||
query: str,
|
||||
@@ -69,29 +149,37 @@ async def search_agents(
|
||||
is_featured=False,
|
||||
)
|
||||
)
|
||||
else: # library
|
||||
logger.info(f"Searching user library for: {query}")
|
||||
results = await library_db.list_library_agents(
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
)
|
||||
else:
|
||||
if _is_uuid(query):
|
||||
logger.info(f"Query looks like UUID, trying direct lookup: {query}")
|
||||
agent = await _get_library_agent_by_id(user_id, query) # type: ignore[arg-type]
|
||||
if agent:
|
||||
agents.append(agent)
|
||||
logger.info(f"Found agent by direct ID lookup: {agent.name}")
|
||||
|
||||
if not agents:
|
||||
logger.info(f"Searching user library for: {query}")
|
||||
results = await library_db.list_library_agents(
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
)
|
||||
)
|
||||
logger.info(f"Found {len(agents)} agents in {source}")
|
||||
except NotFoundError:
|
||||
pass
|
||||
|
||||
@@ -8,7 +8,9 @@ from backend.api.features.chat.model import ChatSession
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
decompose_goal,
|
||||
enrich_library_agents_from_steps,
|
||||
generate_agent,
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_user_message_for_error,
|
||||
save_agent_to_library,
|
||||
)
|
||||
@@ -16,7 +18,6 @@ from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AsyncProcessingResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
@@ -97,10 +98,6 @@ class CreateAgentTool(BaseTool):
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
|
||||
if not description:
|
||||
return ErrorResponse(
|
||||
message="Please provide a description of what the agent should do.",
|
||||
@@ -108,9 +105,24 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 1: Decompose goal into steps
|
||||
library_agents = None
|
||||
if user_id:
|
||||
try:
|
||||
library_agents = await get_all_relevant_agents_for_generation(
|
||||
user_id=user_id,
|
||||
search_query=description,
|
||||
include_marketplace=True,
|
||||
)
|
||||
logger.debug(
|
||||
f"Found {len(library_agents)} relevant agents for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents: {e}")
|
||||
|
||||
try:
|
||||
decomposition_result = await decompose_goal(description, context)
|
||||
decomposition_result = await decompose_goal(
|
||||
description, context, library_agents
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
@@ -129,7 +141,6 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if decomposition_result.get("type") == "error":
|
||||
error_msg = decomposition_result.get("error", "Unknown error")
|
||||
error_type = decomposition_result.get("error_type", "unknown")
|
||||
@@ -149,7 +160,6 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if LLM returned clarifying questions
|
||||
if decomposition_result.get("type") == "clarifying_questions":
|
||||
questions = decomposition_result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
@@ -168,7 +178,6 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check for unachievable/vague goals
|
||||
if decomposition_result.get("type") == "unachievable_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
reason = decomposition_result.get("reason", "")
|
||||
@@ -195,13 +204,22 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 2: Generate agent JSON (external service handles fixing and validation)
|
||||
if user_id and library_agents is not None:
|
||||
try:
|
||||
library_agents = await enrich_library_agents_from_steps(
|
||||
user_id=user_id,
|
||||
decomposition_result=decomposition_result,
|
||||
existing_agents=library_agents,
|
||||
include_marketplace=True,
|
||||
)
|
||||
logger.debug(
|
||||
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to enrich library agents from steps: {e}")
|
||||
|
||||
try:
|
||||
agent_json = await generate_agent(
|
||||
decomposition_result,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
agent_json = await generate_agent(decomposition_result, library_agents)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
@@ -220,7 +238,6 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
|
||||
error_msg = agent_json.get("error", "Unknown error")
|
||||
error_type = agent_json.get("error_type", "unknown")
|
||||
@@ -228,7 +245,12 @@ class CreateAgentTool(BaseTool):
|
||||
error_type,
|
||||
operation="generate the agent",
|
||||
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
|
||||
validation_message="The generated agent failed validation. Please try rephrasing your goal.",
|
||||
validation_message=(
|
||||
"I wasn't able to create a valid agent for this request. "
|
||||
"The generated workflow had some structural issues. "
|
||||
"Please try simplifying your goal or breaking it into smaller steps."
|
||||
),
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
@@ -241,25 +263,11 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if Agent Generator accepted for async processing
|
||||
if agent_json.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Agent generation delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return AsyncProcessingResponse(
|
||||
message="Agent generation started. You'll be notified when it's complete.",
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", "Generated Agent")
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
link_count = len(agent_json.get("links", []))
|
||||
|
||||
# Step 3: Preview or save
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
@@ -274,7 +282,6 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to library
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
@@ -292,7 +299,7 @@ class CreateAgentTool(BaseTool):
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/{library_agent.id}",
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_user_message_for_error,
|
||||
save_agent_to_library,
|
||||
)
|
||||
@@ -16,7 +17,6 @@ from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AsyncProcessingResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
@@ -104,10 +104,6 @@ class EditAgentTool(BaseTool):
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the agent ID to edit.",
|
||||
@@ -122,7 +118,6 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 1: Fetch current agent
|
||||
current_agent = await get_agent_as_json(agent_id, user_id)
|
||||
|
||||
if current_agent is None:
|
||||
@@ -132,18 +127,29 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Build the update request with context
|
||||
library_agents = None
|
||||
if user_id:
|
||||
try:
|
||||
graph_id = current_agent.get("id")
|
||||
library_agents = await get_all_relevant_agents_for_generation(
|
||||
user_id=user_id,
|
||||
search_query=changes,
|
||||
exclude_graph_id=graph_id,
|
||||
include_marketplace=True,
|
||||
)
|
||||
logger.debug(
|
||||
f"Found {len(library_agents)} relevant agents for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents: {e}")
|
||||
|
||||
update_request = changes
|
||||
if context:
|
||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
||||
|
||||
# Step 2: Generate updated agent (external service handles fixing and validation)
|
||||
try:
|
||||
result = await generate_agent_patch(
|
||||
update_request,
|
||||
current_agent,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
update_request, current_agent, library_agents
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
@@ -163,20 +169,6 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if Agent Generator accepted for async processing
|
||||
if result.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Agent edit delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return AsyncProcessingResponse(
|
||||
message="Agent edit started. You'll be notified when it's complete.",
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
error_type = result.get("error_type", "unknown")
|
||||
@@ -185,6 +177,7 @@ class EditAgentTool(BaseTool):
|
||||
operation="generate the changes",
|
||||
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
|
||||
validation_message="The generated changes failed validation. Please try rephrasing your request.",
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
@@ -198,7 +191,6 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if LLM returned clarifying questions
|
||||
if result.get("type") == "clarifying_questions":
|
||||
questions = result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
@@ -217,7 +209,6 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Result is the updated agent JSON
|
||||
updated_agent = result
|
||||
|
||||
agent_name = updated_agent.get("name", "Updated Agent")
|
||||
@@ -225,7 +216,6 @@ class EditAgentTool(BaseTool):
|
||||
node_count = len(updated_agent.get("nodes", []))
|
||||
link_count = len(updated_agent.get("links", []))
|
||||
|
||||
# Step 3: Preview or save
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
@@ -241,7 +231,6 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to library (creates a new version)
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
@@ -259,7 +248,7 @@ class EditAgentTool(BaseTool):
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/{library_agent.id}",
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -352,15 +352,11 @@ class OperationStartedResponse(ToolResponseBase):
|
||||
|
||||
This is returned immediately to the client while the operation continues
|
||||
to execute. The user can close the tab and check back later.
|
||||
|
||||
The task_id can be used to reconnect to the SSE stream via
|
||||
GET /chat/tasks/{task_id}/stream?last_idx=0
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
operation_id: str
|
||||
tool_name: str
|
||||
task_id: str | None = None # For SSE reconnection
|
||||
|
||||
|
||||
class OperationPendingResponse(ToolResponseBase):
|
||||
@@ -384,20 +380,3 @@ class OperationInProgressResponse(ToolResponseBase):
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class AsyncProcessingResponse(ToolResponseBase):
|
||||
"""Response when an operation has been delegated to async processing.
|
||||
|
||||
This is returned by tools when the external service accepts the request
|
||||
for async processing (HTTP 202 Accepted). The RabbitMQ completion consumer
|
||||
will handle the result when the external service completes.
|
||||
|
||||
The status field is specifically "accepted" to allow the long-running tool
|
||||
handler to detect this response and skip LLM continuation.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
status: str = "accepted" # Must be "accepted" for detection
|
||||
operation_id: str | None = None
|
||||
task_id: str | None = None
|
||||
|
||||
@@ -8,7 +8,7 @@ from backend.api.features.library import model as library_model
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
@@ -266,13 +266,14 @@ async def match_user_credentials_to_graph(
|
||||
credential_requirements,
|
||||
_node_fields,
|
||||
) in aggregated_creds.items():
|
||||
# Find first matching credential by provider and type
|
||||
# Find first matching credential by provider, type, and scopes
|
||||
matching_cred = next(
|
||||
(
|
||||
cred
|
||||
for cred in available_creds
|
||||
if cred.provider in credential_requirements.provider
|
||||
and cred.type in credential_requirements.supported_types
|
||||
and _credential_has_required_scopes(cred, credential_requirements)
|
||||
),
|
||||
None,
|
||||
)
|
||||
@@ -296,10 +297,17 @@ async def match_user_credentials_to_graph(
|
||||
f"{credential_field_name} (validation failed: {e})"
|
||||
)
|
||||
else:
|
||||
# Build a helpful error message including scope requirements
|
||||
error_parts = [
|
||||
f"provider in {list(credential_requirements.provider)}",
|
||||
f"type in {list(credential_requirements.supported_types)}",
|
||||
]
|
||||
if credential_requirements.required_scopes:
|
||||
error_parts.append(
|
||||
f"scopes including {list(credential_requirements.required_scopes)}"
|
||||
)
|
||||
missing_creds.append(
|
||||
f"{credential_field_name} "
|
||||
f"(requires provider in {list(credential_requirements.provider)}, "
|
||||
f"type in {list(credential_requirements.supported_types)})"
|
||||
f"{credential_field_name} (requires {', '.join(error_parts)})"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -309,6 +317,28 @@ async def match_user_credentials_to_graph(
|
||||
return graph_credentials_inputs, missing_creds
|
||||
|
||||
|
||||
def _credential_has_required_scopes(
|
||||
credential: Credentials,
|
||||
requirements: CredentialsFieldInfo,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a credential has all the scopes required by the block.
|
||||
|
||||
For OAuth2 credentials, verifies that the credential's scopes are a superset
|
||||
of the required scopes. For other credential types, returns True (no scope check).
|
||||
"""
|
||||
# Only OAuth2 credentials have scopes to check
|
||||
if credential.type != "oauth2":
|
||||
return True
|
||||
|
||||
# If no scopes are required, any credential matches
|
||||
if not requirements.required_scopes:
|
||||
return True
|
||||
|
||||
# Check that credential scopes are a superset of required scopes
|
||||
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||
|
||||
|
||||
async def check_user_has_required_credentials(
|
||||
user_id: str,
|
||||
required_credentials: list[CredentialsMetaInput],
|
||||
|
||||
@@ -39,6 +39,7 @@ async def list_library_agents(
|
||||
sort_by: library_model.LibraryAgentSort = library_model.LibraryAgentSort.UPDATED_AT,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
include_executions: bool = False,
|
||||
) -> library_model.LibraryAgentResponse:
|
||||
"""
|
||||
Retrieves a paginated list of LibraryAgent records for a given user.
|
||||
@@ -49,6 +50,9 @@ async def list_library_agents(
|
||||
sort_by: Sorting field (createdAt, updatedAt, isFavorite, isCreatedByUser).
|
||||
page: Current page (1-indexed).
|
||||
page_size: Number of items per page.
|
||||
include_executions: Whether to include execution data for status calculation.
|
||||
Defaults to False for performance (UI fetches status separately).
|
||||
Set to True when accurate status/metrics are needed (e.g., agent generator).
|
||||
|
||||
Returns:
|
||||
A LibraryAgentResponse containing the list of agents and pagination details.
|
||||
@@ -76,7 +80,6 @@ async def list_library_agents(
|
||||
"isArchived": False,
|
||||
}
|
||||
|
||||
# Build search filter if applicable
|
||||
if search_term:
|
||||
where_clause["OR"] = [
|
||||
{
|
||||
@@ -93,7 +96,6 @@ async def list_library_agents(
|
||||
},
|
||||
]
|
||||
|
||||
# Determine sorting
|
||||
order_by: prisma.types.LibraryAgentOrderByInput | None = None
|
||||
|
||||
if sort_by == library_model.LibraryAgentSort.CREATED_AT:
|
||||
@@ -105,7 +107,7 @@ async def list_library_agents(
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
user_id, include_nodes=False, include_executions=include_executions
|
||||
),
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
|
||||
@@ -9,6 +9,7 @@ import pydantic
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
|
||||
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
|
||||
from backend.util.json import loads as json_loads
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -16,10 +17,10 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class LibraryAgentStatus(str, Enum):
|
||||
COMPLETED = "COMPLETED" # All runs completed
|
||||
HEALTHY = "HEALTHY" # Agent is running (not all runs have completed)
|
||||
WAITING = "WAITING" # Agent is queued or waiting to start
|
||||
ERROR = "ERROR" # Agent is in an error state
|
||||
COMPLETED = "COMPLETED"
|
||||
HEALTHY = "HEALTHY"
|
||||
WAITING = "WAITING"
|
||||
ERROR = "ERROR"
|
||||
|
||||
|
||||
class MarketplaceListingCreator(pydantic.BaseModel):
|
||||
@@ -39,6 +40,30 @@ class MarketplaceListing(pydantic.BaseModel):
|
||||
creator: MarketplaceListingCreator
|
||||
|
||||
|
||||
class RecentExecution(pydantic.BaseModel):
|
||||
"""Summary of a recent execution for quality assessment.
|
||||
|
||||
Used by the LLM to understand the agent's recent performance with specific examples
|
||||
rather than just aggregate statistics.
|
||||
"""
|
||||
|
||||
status: str
|
||||
correctness_score: float | None = None
|
||||
activity_summary: str | None = None
|
||||
|
||||
|
||||
def _parse_settings(settings: dict | str | None) -> GraphSettings:
|
||||
"""Parse settings from database, handling both dict and string formats."""
|
||||
if settings is None:
|
||||
return GraphSettings()
|
||||
try:
|
||||
if isinstance(settings, str):
|
||||
settings = json_loads(settings)
|
||||
return GraphSettings.model_validate(settings)
|
||||
except Exception:
|
||||
return GraphSettings()
|
||||
|
||||
|
||||
class LibraryAgent(pydantic.BaseModel):
|
||||
"""
|
||||
Represents an agent in the library, including metadata for display and
|
||||
@@ -48,7 +73,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
owner_user_id: str # ID of user who owns/created this agent graph
|
||||
owner_user_id: str
|
||||
|
||||
image_url: str | None
|
||||
|
||||
@@ -64,7 +89,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
|
||||
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
|
||||
input_schema: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
credentials_input_schema: dict[str, Any] | None = pydantic.Field(
|
||||
description="Input schema for credentials required by the agent",
|
||||
@@ -81,25 +106,19 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
)
|
||||
trigger_setup_info: Optional[GraphTriggerInfo] = None
|
||||
|
||||
# Indicates whether there's a new output (based on recent runs)
|
||||
new_output: bool
|
||||
|
||||
# Whether the user can access the underlying graph
|
||||
execution_count: int = 0
|
||||
success_rate: float | None = None
|
||||
avg_correctness_score: float | None = None
|
||||
recent_executions: list[RecentExecution] = pydantic.Field(
|
||||
default_factory=list,
|
||||
description="List of recent executions with status, score, and summary",
|
||||
)
|
||||
can_access_graph: bool
|
||||
|
||||
# Indicates if this agent is the latest version
|
||||
is_latest_version: bool
|
||||
|
||||
# Whether the agent is marked as favorite by the user
|
||||
is_favorite: bool
|
||||
|
||||
# Recommended schedule cron (from marketplace agents)
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
# User-specific settings for this library agent
|
||||
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
|
||||
|
||||
# Marketplace listing information if the agent has been published
|
||||
marketplace_listing: Optional["MarketplaceListing"] = None
|
||||
|
||||
@staticmethod
|
||||
@@ -123,7 +142,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
agent_updated_at = agent.AgentGraph.updatedAt
|
||||
lib_agent_updated_at = agent.updatedAt
|
||||
|
||||
# Compute updated_at as the latest between library agent and graph
|
||||
updated_at = (
|
||||
max(agent_updated_at, lib_agent_updated_at)
|
||||
if agent_updated_at
|
||||
@@ -136,7 +154,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
creator_name = agent.Creator.name or "Unknown"
|
||||
creator_image_url = agent.Creator.avatarUrl or ""
|
||||
|
||||
# Logic to calculate status and new_output
|
||||
week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
|
||||
days=7
|
||||
)
|
||||
@@ -145,13 +162,55 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
status = status_result.status
|
||||
new_output = status_result.new_output
|
||||
|
||||
# Check if user can access the graph
|
||||
can_access_graph = agent.AgentGraph.userId == agent.userId
|
||||
execution_count = len(executions)
|
||||
success_rate: float | None = None
|
||||
avg_correctness_score: float | None = None
|
||||
if execution_count > 0:
|
||||
success_count = sum(
|
||||
1
|
||||
for e in executions
|
||||
if e.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED
|
||||
)
|
||||
success_rate = (success_count / execution_count) * 100
|
||||
|
||||
# Hard-coded to True until a method to check is implemented
|
||||
correctness_scores = []
|
||||
for e in executions:
|
||||
if e.stats and isinstance(e.stats, dict):
|
||||
score = e.stats.get("correctness_score")
|
||||
if score is not None and isinstance(score, (int, float)):
|
||||
correctness_scores.append(float(score))
|
||||
if correctness_scores:
|
||||
avg_correctness_score = sum(correctness_scores) / len(
|
||||
correctness_scores
|
||||
)
|
||||
|
||||
recent_executions: list[RecentExecution] = []
|
||||
for e in executions:
|
||||
exec_score: float | None = None
|
||||
exec_summary: str | None = None
|
||||
if e.stats and isinstance(e.stats, dict):
|
||||
score = e.stats.get("correctness_score")
|
||||
if score is not None and isinstance(score, (int, float)):
|
||||
exec_score = float(score)
|
||||
summary = e.stats.get("activity_status")
|
||||
if summary is not None and isinstance(summary, str):
|
||||
exec_summary = summary
|
||||
exec_status = (
|
||||
e.executionStatus.value
|
||||
if hasattr(e.executionStatus, "value")
|
||||
else str(e.executionStatus)
|
||||
)
|
||||
recent_executions.append(
|
||||
RecentExecution(
|
||||
status=exec_status,
|
||||
correctness_score=exec_score,
|
||||
activity_summary=exec_summary,
|
||||
)
|
||||
)
|
||||
|
||||
can_access_graph = agent.AgentGraph.userId == agent.userId
|
||||
is_latest_version = True
|
||||
|
||||
# Build marketplace_listing if available
|
||||
marketplace_listing_data = None
|
||||
if store_listing and store_listing.ActiveVersion and profile:
|
||||
creator_data = MarketplaceListingCreator(
|
||||
@@ -190,11 +249,15 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
has_sensitive_action=graph.has_sensitive_action,
|
||||
trigger_setup_info=graph.trigger_setup_info,
|
||||
new_output=new_output,
|
||||
execution_count=execution_count,
|
||||
success_rate=success_rate,
|
||||
avg_correctness_score=avg_correctness_score,
|
||||
recent_executions=recent_executions,
|
||||
can_access_graph=can_access_graph,
|
||||
is_latest_version=is_latest_version,
|
||||
is_favorite=agent.isFavorite,
|
||||
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
|
||||
settings=GraphSettings.model_validate(agent.settings),
|
||||
settings=_parse_settings(agent.settings),
|
||||
marketplace_listing=marketplace_listing_data,
|
||||
)
|
||||
|
||||
@@ -220,18 +283,15 @@ def _calculate_agent_status(
|
||||
if not executions:
|
||||
return AgentStatusResult(status=LibraryAgentStatus.COMPLETED, new_output=False)
|
||||
|
||||
# Track how many times each execution status appears
|
||||
status_counts = {status: 0 for status in prisma.enums.AgentExecutionStatus}
|
||||
new_output = False
|
||||
|
||||
for execution in executions:
|
||||
# Check if there's a completed run more recent than `recent_threshold`
|
||||
if execution.createdAt >= recent_threshold:
|
||||
if execution.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED:
|
||||
new_output = True
|
||||
status_counts[execution.executionStatus] += 1
|
||||
|
||||
# Determine the final status based on counts
|
||||
if status_counts[prisma.enums.AgentExecutionStatus.FAILED] > 0:
|
||||
return AgentStatusResult(status=LibraryAgentStatus.ERROR, new_output=new_output)
|
||||
elif status_counts[prisma.enums.AgentExecutionStatus.QUEUED] > 0:
|
||||
|
||||
@@ -40,10 +40,6 @@ import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.api.features.chat.completion_consumer import (
|
||||
start_completion_consumer,
|
||||
stop_completion_consumer,
|
||||
)
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -122,21 +118,9 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
# Start chat completion consumer for RabbitMQ notifications
|
||||
try:
|
||||
await start_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not start chat completion consumer: {e}")
|
||||
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
# Stop chat completion consumer
|
||||
try:
|
||||
await stop_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error stopping chat completion consumer: {e}")
|
||||
|
||||
try:
|
||||
await shutdown_cloud_storage_handler()
|
||||
except Exception as e:
|
||||
|
||||
@@ -31,6 +31,10 @@
|
||||
"has_sensitive_action": false,
|
||||
"trigger_setup_info": null,
|
||||
"new_output": false,
|
||||
"execution_count": 0,
|
||||
"success_rate": null,
|
||||
"avg_correctness_score": null,
|
||||
"recent_executions": [],
|
||||
"can_access_graph": true,
|
||||
"is_latest_version": true,
|
||||
"is_favorite": false,
|
||||
@@ -72,6 +76,10 @@
|
||||
"has_sensitive_action": false,
|
||||
"trigger_setup_info": null,
|
||||
"new_output": false,
|
||||
"execution_count": 0,
|
||||
"success_rate": null,
|
||||
"avg_correctness_score": null,
|
||||
"recent_executions": [],
|
||||
"can_access_graph": false,
|
||||
"is_latest_version": true,
|
||||
"is_favorite": false,
|
||||
|
||||
@@ -57,7 +57,8 @@ class TestDecomposeGoal:
|
||||
|
||||
result = await core.decompose_goal("Build a chatbot")
|
||||
|
||||
mock_external.assert_called_once_with("Build a chatbot", "")
|
||||
# library_agents defaults to None
|
||||
mock_external.assert_called_once_with("Build a chatbot", "", None)
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -74,7 +75,8 @@ class TestDecomposeGoal:
|
||||
|
||||
await core.decompose_goal("Build a chatbot", "Use Python")
|
||||
|
||||
mock_external.assert_called_once_with("Build a chatbot", "Use Python")
|
||||
# library_agents defaults to None
|
||||
mock_external.assert_called_once_with("Build a chatbot", "Use Python", None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_service_failure(self):
|
||||
@@ -109,7 +111,8 @@ class TestGenerateAgent:
|
||||
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||
result = await core.generate_agent(instructions)
|
||||
|
||||
mock_external.assert_called_once_with(instructions)
|
||||
# library_agents defaults to None
|
||||
mock_external.assert_called_once_with(instructions, None)
|
||||
# Result should have id, version, is_active added if not present
|
||||
assert result is not None
|
||||
assert result["name"] == "Test Agent"
|
||||
@@ -174,7 +177,8 @@ class TestGenerateAgentPatch:
|
||||
current_agent = {"nodes": [], "links": []}
|
||||
result = await core.generate_agent_patch("Add a node", current_agent)
|
||||
|
||||
mock_external.assert_called_once_with("Add a node", current_agent)
|
||||
# library_agents defaults to None
|
||||
mock_external.assert_called_once_with("Add a node", current_agent, None)
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -0,0 +1,841 @@
|
||||
"""
|
||||
Tests for library agent fetching functionality in agent generator.
|
||||
|
||||
This test suite verifies the search-based library agent fetching,
|
||||
including the combination of library and marketplace agents.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.agent_generator import core
|
||||
|
||||
|
||||
class TestGetLibraryAgentsForGeneration:
|
||||
"""Test get_library_agents_for_generation function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetches_agents_with_search_term(self):
|
||||
"""Test that search_term is passed to the library db."""
|
||||
# Create a mock agent with proper attribute values
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.graph_id = "agent-123"
|
||||
mock_agent.graph_version = 1
|
||||
mock_agent.name = "Email Agent"
|
||||
mock_agent.description = "Sends emails"
|
||||
mock_agent.input_schema = {"properties": {}}
|
||||
mock_agent.output_schema = {"properties": {}}
|
||||
mock_agent.recent_executions = []
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.agents = [mock_agent]
|
||||
|
||||
with patch.object(
|
||||
core.library_db,
|
||||
"list_library_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_list:
|
||||
result = await core.get_library_agents_for_generation(
|
||||
user_id="user-123",
|
||||
search_query="send email",
|
||||
)
|
||||
|
||||
mock_list.assert_called_once_with(
|
||||
user_id="user-123",
|
||||
search_term="send email",
|
||||
page=1,
|
||||
page_size=15,
|
||||
include_executions=True,
|
||||
)
|
||||
|
||||
# Verify result format
|
||||
assert len(result) == 1
|
||||
assert result[0]["graph_id"] == "agent-123"
|
||||
assert result[0]["name"] == "Email Agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_excludes_specified_graph_id(self):
|
||||
"""Test that agents with excluded graph_id are filtered out."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.agents = [
|
||||
MagicMock(
|
||||
graph_id="agent-123",
|
||||
graph_version=1,
|
||||
name="Agent 1",
|
||||
description="First agent",
|
||||
input_schema={},
|
||||
output_schema={},
|
||||
recent_executions=[],
|
||||
),
|
||||
MagicMock(
|
||||
graph_id="agent-456",
|
||||
graph_version=1,
|
||||
name="Agent 2",
|
||||
description="Second agent",
|
||||
input_schema={},
|
||||
output_schema={},
|
||||
recent_executions=[],
|
||||
),
|
||||
]
|
||||
|
||||
with patch.object(
|
||||
core.library_db,
|
||||
"list_library_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
):
|
||||
result = await core.get_library_agents_for_generation(
|
||||
user_id="user-123",
|
||||
exclude_graph_id="agent-123",
|
||||
)
|
||||
|
||||
# Verify the excluded agent is not in results
|
||||
assert len(result) == 1
|
||||
assert result[0]["graph_id"] == "agent-456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_respects_max_results(self):
|
||||
"""Test that max_results parameter limits the page_size."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.agents = []
|
||||
|
||||
with patch.object(
|
||||
core.library_db,
|
||||
"list_library_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_list:
|
||||
await core.get_library_agents_for_generation(
|
||||
user_id="user-123",
|
||||
max_results=5,
|
||||
)
|
||||
|
||||
mock_list.assert_called_once_with(
|
||||
user_id="user-123",
|
||||
search_term=None,
|
||||
page=1,
|
||||
page_size=5,
|
||||
include_executions=True,
|
||||
)
|
||||
|
||||
|
||||
class TestSearchMarketplaceAgentsForGeneration:
|
||||
"""Test search_marketplace_agents_for_generation function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_searches_marketplace_with_query(self):
|
||||
"""Test that marketplace is searched with the query."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.agents = [
|
||||
MagicMock(
|
||||
agent_name="Public Agent",
|
||||
description="A public agent",
|
||||
sub_heading="Does something useful",
|
||||
creator="creator-1",
|
||||
)
|
||||
]
|
||||
|
||||
# The store_db is dynamically imported, so patch the import path
|
||||
with patch(
|
||||
"backend.api.features.store.db.get_store_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_search:
|
||||
result = await core.search_marketplace_agents_for_generation(
|
||||
search_query="automation",
|
||||
max_results=10,
|
||||
)
|
||||
|
||||
mock_search.assert_called_once_with(
|
||||
search_query="automation",
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "Public Agent"
|
||||
assert result[0]["is_marketplace_agent"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_marketplace_error_gracefully(self):
|
||||
"""Test that marketplace errors don't crash the function."""
|
||||
with patch(
|
||||
"backend.api.features.store.db.get_store_agents",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Marketplace unavailable"),
|
||||
):
|
||||
result = await core.search_marketplace_agents_for_generation(
|
||||
search_query="test"
|
||||
)
|
||||
|
||||
# Should return empty list, not raise exception
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestGetAllRelevantAgentsForGeneration:
|
||||
"""Test get_all_relevant_agents_for_generation function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_combines_library_and_marketplace_agents(self):
|
||||
"""Test that agents from both sources are combined."""
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": "lib-123",
|
||||
"graph_version": 1,
|
||||
"name": "Library Agent",
|
||||
"description": "From library",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
marketplace_agents = [
|
||||
{
|
||||
"name": "Market Agent",
|
||||
"description": "From marketplace",
|
||||
"sub_heading": "Sub heading",
|
||||
"creator": "creator-1",
|
||||
"is_marketplace_agent": True,
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"get_library_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=library_agents,
|
||||
):
|
||||
with patch.object(
|
||||
core,
|
||||
"search_marketplace_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=marketplace_agents,
|
||||
):
|
||||
result = await core.get_all_relevant_agents_for_generation(
|
||||
user_id="user-123",
|
||||
search_query="test query",
|
||||
include_marketplace=True,
|
||||
)
|
||||
|
||||
# Library agents should come first
|
||||
assert len(result) == 2
|
||||
assert result[0]["name"] == "Library Agent"
|
||||
assert result[1]["name"] == "Market Agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deduplicates_by_name(self):
|
||||
"""Test that marketplace agents with same name as library are excluded."""
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": "lib-123",
|
||||
"graph_version": 1,
|
||||
"name": "Shared Agent",
|
||||
"description": "From library",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
marketplace_agents = [
|
||||
{
|
||||
"name": "Shared Agent", # Same name, should be deduplicated
|
||||
"description": "From marketplace",
|
||||
"sub_heading": "Sub heading",
|
||||
"creator": "creator-1",
|
||||
"is_marketplace_agent": True,
|
||||
},
|
||||
{
|
||||
"name": "Unique Agent",
|
||||
"description": "Only in marketplace",
|
||||
"sub_heading": "Sub heading",
|
||||
"creator": "creator-2",
|
||||
"is_marketplace_agent": True,
|
||||
},
|
||||
]
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"get_library_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=library_agents,
|
||||
):
|
||||
with patch.object(
|
||||
core,
|
||||
"search_marketplace_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=marketplace_agents,
|
||||
):
|
||||
result = await core.get_all_relevant_agents_for_generation(
|
||||
user_id="user-123",
|
||||
search_query="test",
|
||||
include_marketplace=True,
|
||||
)
|
||||
|
||||
# Shared Agent from marketplace should be excluded
|
||||
assert len(result) == 2
|
||||
names = [a["name"] for a in result]
|
||||
assert "Shared Agent" in names
|
||||
assert "Unique Agent" in names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_marketplace_when_disabled(self):
|
||||
"""Test that marketplace is not searched when include_marketplace=False."""
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": "lib-123",
|
||||
"graph_version": 1,
|
||||
"name": "Library Agent",
|
||||
"description": "From library",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"get_library_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=library_agents,
|
||||
):
|
||||
with patch.object(
|
||||
core,
|
||||
"search_marketplace_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_marketplace:
|
||||
result = await core.get_all_relevant_agents_for_generation(
|
||||
user_id="user-123",
|
||||
search_query="test",
|
||||
include_marketplace=False,
|
||||
)
|
||||
|
||||
# Marketplace should not be called
|
||||
mock_marketplace.assert_not_called()
|
||||
assert len(result) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_marketplace_when_no_search_query(self):
|
||||
"""Test that marketplace is not searched without a search query."""
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": "lib-123",
|
||||
"graph_version": 1,
|
||||
"name": "Library Agent",
|
||||
"description": "From library",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"get_library_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=library_agents,
|
||||
):
|
||||
with patch.object(
|
||||
core,
|
||||
"search_marketplace_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_marketplace:
|
||||
result = await core.get_all_relevant_agents_for_generation(
|
||||
user_id="user-123",
|
||||
search_query=None, # No search query
|
||||
include_marketplace=True,
|
||||
)
|
||||
|
||||
# Marketplace should not be called without search query
|
||||
mock_marketplace.assert_not_called()
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
class TestExtractSearchTermsFromSteps:
|
||||
"""Test extract_search_terms_from_steps function."""
|
||||
|
||||
def test_extracts_terms_from_instructions_type(self):
|
||||
"""Test extraction from valid instructions decomposition result."""
|
||||
decomposition_result = {
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{
|
||||
"description": "Send an email notification",
|
||||
"block_name": "GmailSendBlock",
|
||||
},
|
||||
{"description": "Fetch weather data", "action": "Get weather API"},
|
||||
],
|
||||
}
|
||||
|
||||
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||
|
||||
assert "Send an email notification" in result
|
||||
assert "GmailSendBlock" in result
|
||||
assert "Fetch weather data" in result
|
||||
assert "Get weather API" in result
|
||||
|
||||
def test_returns_empty_for_non_instructions_type(self):
|
||||
"""Test that non-instructions types return empty list."""
|
||||
decomposition_result = {
|
||||
"type": "clarifying_questions",
|
||||
"questions": [{"question": "What email?"}],
|
||||
}
|
||||
|
||||
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_deduplicates_terms_case_insensitively(self):
|
||||
"""Test that duplicate terms are removed (case-insensitive)."""
|
||||
decomposition_result = {
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{"description": "Send Email", "name": "send email"},
|
||||
{"description": "Other task"},
|
||||
],
|
||||
}
|
||||
|
||||
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||
|
||||
# Should only have one "send email" variant
|
||||
email_terms = [t for t in result if "email" in t.lower()]
|
||||
assert len(email_terms) == 1
|
||||
|
||||
def test_filters_short_terms(self):
|
||||
"""Test that terms with 3 or fewer characters are filtered out."""
|
||||
decomposition_result = {
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{"description": "ab", "action": "xyz"}, # Both too short
|
||||
{"description": "Valid term here"},
|
||||
],
|
||||
}
|
||||
|
||||
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||
|
||||
assert "ab" not in result
|
||||
assert "xyz" not in result
|
||||
assert "Valid term here" in result
|
||||
|
||||
def test_handles_empty_steps(self):
|
||||
"""Test handling of empty steps list."""
|
||||
decomposition_result = {
|
||||
"type": "instructions",
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestEnrichLibraryAgentsFromSteps:
|
||||
"""Test enrich_library_agents_from_steps function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enriches_with_additional_agents(self):
|
||||
"""Test that additional agents are found based on steps."""
|
||||
existing_agents = [
|
||||
{
|
||||
"graph_id": "existing-123",
|
||||
"graph_version": 1,
|
||||
"name": "Existing Agent",
|
||||
"description": "Already fetched",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
additional_agents = [
|
||||
{
|
||||
"graph_id": "new-456",
|
||||
"graph_version": 1,
|
||||
"name": "Email Agent",
|
||||
"description": "For sending emails",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
decomposition_result = {
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{"description": "Send email notification"},
|
||||
],
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=additional_agents,
|
||||
):
|
||||
result = await core.enrich_library_agents_from_steps(
|
||||
user_id="user-123",
|
||||
decomposition_result=decomposition_result,
|
||||
existing_agents=existing_agents,
|
||||
)
|
||||
|
||||
# Should have both existing and new agents
|
||||
assert len(result) == 2
|
||||
names = [a["name"] for a in result]
|
||||
assert "Existing Agent" in names
|
||||
assert "Email Agent" in names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deduplicates_by_graph_id(self):
|
||||
"""Test that agents with same graph_id are not duplicated."""
|
||||
existing_agents = [
|
||||
{
|
||||
"graph_id": "agent-123",
|
||||
"graph_version": 1,
|
||||
"name": "Existing Agent",
|
||||
"description": "Already fetched",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
# Additional search returns same agent
|
||||
additional_agents = [
|
||||
{
|
||||
"graph_id": "agent-123", # Same ID
|
||||
"graph_version": 1,
|
||||
"name": "Existing Agent Copy",
|
||||
"description": "Same agent different name",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
decomposition_result = {
|
||||
"type": "instructions",
|
||||
"steps": [{"description": "Some action"}],
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=additional_agents,
|
||||
):
|
||||
result = await core.enrich_library_agents_from_steps(
|
||||
user_id="user-123",
|
||||
decomposition_result=decomposition_result,
|
||||
existing_agents=existing_agents,
|
||||
)
|
||||
|
||||
# Should not duplicate
|
||||
assert len(result) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deduplicates_by_name(self):
|
||||
"""Test that agents with same name are not duplicated."""
|
||||
existing_agents = [
|
||||
{
|
||||
"graph_id": "agent-123",
|
||||
"graph_version": 1,
|
||||
"name": "Email Agent",
|
||||
"description": "Already fetched",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
# Additional search returns agent with same name but different ID
|
||||
additional_agents = [
|
||||
{
|
||||
"graph_id": "agent-456", # Different ID
|
||||
"graph_version": 1,
|
||||
"name": "Email Agent", # Same name
|
||||
"description": "Different agent same name",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
decomposition_result = {
|
||||
"type": "instructions",
|
||||
"steps": [{"description": "Send email"}],
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=additional_agents,
|
||||
):
|
||||
result = await core.enrich_library_agents_from_steps(
|
||||
user_id="user-123",
|
||||
decomposition_result=decomposition_result,
|
||||
existing_agents=existing_agents,
|
||||
)
|
||||
|
||||
# Should not duplicate by name
|
||||
assert len(result) == 1
|
||||
assert result[0].get("graph_id") == "agent-123" # Original kept
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_existing_when_no_steps(self):
|
||||
"""Test that existing agents are returned when no search terms extracted."""
|
||||
existing_agents = [
|
||||
{
|
||||
"graph_id": "existing-123",
|
||||
"graph_version": 1,
|
||||
"name": "Existing Agent",
|
||||
"description": "Already fetched",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
decomposition_result = {
|
||||
"type": "clarifying_questions", # Not instructions type
|
||||
"questions": [],
|
||||
}
|
||||
|
||||
result = await core.enrich_library_agents_from_steps(
|
||||
user_id="user-123",
|
||||
decomposition_result=decomposition_result,
|
||||
existing_agents=existing_agents,
|
||||
)
|
||||
|
||||
# Should return existing unchanged
|
||||
assert result == existing_agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_limits_search_terms_to_three(self):
|
||||
"""Test that only first 3 search terms are used."""
|
||||
existing_agents = []
|
||||
|
||||
decomposition_result = {
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{"description": "First action"},
|
||||
{"description": "Second action"},
|
||||
{"description": "Third action"},
|
||||
{"description": "Fourth action"},
|
||||
{"description": "Fifth action"},
|
||||
],
|
||||
}
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_get_agents(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return []
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"get_all_relevant_agents_for_generation",
|
||||
side_effect=mock_get_agents,
|
||||
):
|
||||
await core.enrich_library_agents_from_steps(
|
||||
user_id="user-123",
|
||||
decomposition_result=decomposition_result,
|
||||
existing_agents=existing_agents,
|
||||
)
|
||||
|
||||
# Should only make 3 calls (limited to first 3 terms)
|
||||
assert call_count == 3
|
||||
|
||||
|
||||
class TestExtractUuidsFromText:
|
||||
"""Test extract_uuids_from_text function."""
|
||||
|
||||
def test_extracts_single_uuid(self):
|
||||
"""Test extraction of a single UUID from text."""
|
||||
text = "Use my agent 46631191-e8a8-486f-ad90-84f89738321d for this task"
|
||||
result = core.extract_uuids_from_text(text)
|
||||
assert len(result) == 1
|
||||
assert "46631191-e8a8-486f-ad90-84f89738321d" in result
|
||||
|
||||
def test_extracts_multiple_uuids(self):
|
||||
"""Test extraction of multiple UUIDs from text."""
|
||||
text = (
|
||||
"Combine agents 11111111-1111-4111-8111-111111111111 "
|
||||
"and 22222222-2222-4222-9222-222222222222"
|
||||
)
|
||||
result = core.extract_uuids_from_text(text)
|
||||
assert len(result) == 2
|
||||
assert "11111111-1111-4111-8111-111111111111" in result
|
||||
assert "22222222-2222-4222-9222-222222222222" in result
|
||||
|
||||
def test_deduplicates_uuids(self):
|
||||
"""Test that duplicate UUIDs are deduplicated."""
|
||||
text = (
|
||||
"Use 46631191-e8a8-486f-ad90-84f89738321d twice: "
|
||||
"46631191-e8a8-486f-ad90-84f89738321d"
|
||||
)
|
||||
result = core.extract_uuids_from_text(text)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_normalizes_to_lowercase(self):
|
||||
"""Test that UUIDs are normalized to lowercase."""
|
||||
text = "Use 46631191-E8A8-486F-AD90-84F89738321D"
|
||||
result = core.extract_uuids_from_text(text)
|
||||
assert result[0] == "46631191-e8a8-486f-ad90-84f89738321d"
|
||||
|
||||
def test_returns_empty_for_no_uuids(self):
|
||||
"""Test that empty list is returned when no UUIDs found."""
|
||||
text = "Create an email agent that sends notifications"
|
||||
result = core.extract_uuids_from_text(text)
|
||||
assert result == []
|
||||
|
||||
def test_ignores_invalid_uuids(self):
|
||||
"""Test that invalid UUID-like strings are ignored."""
|
||||
text = "Not a valid UUID: 12345678-1234-1234-1234-123456789abc"
|
||||
result = core.extract_uuids_from_text(text)
|
||||
# UUID v4 requires specific patterns (4 in third group, 8/9/a/b in fourth)
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
class TestGetLibraryAgentById:
|
||||
"""Test get_library_agent_by_id function (and its alias get_library_agent_by_graph_id)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_agent_when_found_by_graph_id(self):
|
||||
"""Test that agent is returned when found by graph_id."""
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.graph_id = "agent-123"
|
||||
mock_agent.graph_version = 1
|
||||
mock_agent.name = "Test Agent"
|
||||
mock_agent.description = "Test description"
|
||||
mock_agent.input_schema = {"properties": {}}
|
||||
mock_agent.output_schema = {"properties": {}}
|
||||
|
||||
with patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_agent,
|
||||
):
|
||||
result = await core.get_library_agent_by_id("user-123", "agent-123")
|
||||
|
||||
assert result is not None
|
||||
assert result["graph_id"] == "agent-123"
|
||||
assert result["name"] == "Test Agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_to_library_agent_id(self):
|
||||
"""Test that lookup falls back to library agent ID when graph_id not found."""
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.graph_id = "graph-456" # Different from the lookup ID
|
||||
mock_agent.graph_version = 1
|
||||
mock_agent.name = "Library Agent"
|
||||
mock_agent.description = "Found by library ID"
|
||||
mock_agent.input_schema = {"properties": {}}
|
||||
mock_agent.output_schema = {"properties": {}}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None, # Not found by graph_id
|
||||
),
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_agent, # Found by library ID
|
||||
),
|
||||
):
|
||||
result = await core.get_library_agent_by_id("user-123", "library-id-123")
|
||||
|
||||
assert result is not None
|
||||
assert result["graph_id"] == "graph-456"
|
||||
assert result["name"] == "Library Agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_not_found_by_either_method(self):
|
||||
"""Test that None is returned when agent not found by either method."""
|
||||
with (
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=core.NotFoundError("Not found"),
|
||||
),
|
||||
):
|
||||
result = await core.get_library_agent_by_id("user-123", "nonexistent")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_exception(self):
|
||||
"""Test that None is returned when exception occurs in both lookups."""
|
||||
with (
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Database error"),
|
||||
),
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Database error"),
|
||||
),
|
||||
):
|
||||
result = await core.get_library_agent_by_id("user-123", "agent-123")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alias_works(self):
|
||||
"""Test that get_library_agent_by_graph_id is an alias for get_library_agent_by_id."""
|
||||
assert core.get_library_agent_by_graph_id is core.get_library_agent_by_id
|
||||
|
||||
|
||||
class TestGetAllRelevantAgentsWithUuids:
|
||||
"""Test UUID extraction in get_all_relevant_agents_for_generation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetches_explicitly_mentioned_agents(self):
|
||||
"""Test that agents mentioned by UUID are fetched directly."""
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.graph_id = "46631191-e8a8-486f-ad90-84f89738321d"
|
||||
mock_agent.graph_version = 1
|
||||
mock_agent.name = "Mentioned Agent"
|
||||
mock_agent.description = "Explicitly mentioned"
|
||||
mock_agent.input_schema = {}
|
||||
mock_agent.output_schema = {}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.agents = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_agent,
|
||||
),
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"list_library_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
),
|
||||
):
|
||||
result = await core.get_all_relevant_agents_for_generation(
|
||||
user_id="user-123",
|
||||
search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d",
|
||||
include_marketplace=False,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].get("graph_id") == "46631191-e8a8-486f-ad90-84f89738321d"
|
||||
assert result[0].get("name") == "Mentioned Agent"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -433,5 +433,139 @@ class TestGetBlocksExternal:
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestLibraryAgentsPassthrough:
|
||||
"""Test that library_agents are passed correctly in all requests."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset client singleton before each test."""
|
||||
service._settings = None
|
||||
service._client = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_passes_library_agents(self):
|
||||
"""Test that library_agents are included in decompose goal payload."""
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": "agent-123",
|
||||
"graph_version": 1,
|
||||
"name": "Email Sender",
|
||||
"description": "Sends emails",
|
||||
"input_schema": {"properties": {"to": {"type": "string"}}},
|
||||
"output_schema": {"properties": {"sent": {"type": "boolean"}}},
|
||||
},
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
await service.decompose_goal_external(
|
||||
"Send an email",
|
||||
library_agents=library_agents,
|
||||
)
|
||||
|
||||
# Verify library_agents was passed in the payload
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[1]["json"]["library_agents"] == library_agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_passes_library_agents(self):
|
||||
"""Test that library_agents are included in generate agent payload."""
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": "agent-456",
|
||||
"graph_version": 2,
|
||||
"name": "Data Fetcher",
|
||||
"description": "Fetches data from API",
|
||||
"input_schema": {"properties": {"url": {"type": "string"}}},
|
||||
"output_schema": {"properties": {"data": {"type": "object"}}},
|
||||
},
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"agent_json": {"name": "Test Agent", "nodes": []},
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
await service.generate_agent_external(
|
||||
{"steps": ["Step 1"]},
|
||||
library_agents=library_agents,
|
||||
)
|
||||
|
||||
# Verify library_agents was passed in the payload
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[1]["json"]["library_agents"] == library_agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_patch_passes_library_agents(self):
|
||||
"""Test that library_agents are included in patch generation payload."""
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": "agent-789",
|
||||
"graph_version": 1,
|
||||
"name": "Slack Notifier",
|
||||
"description": "Sends Slack messages",
|
||||
"input_schema": {"properties": {"message": {"type": "string"}}},
|
||||
"output_schema": {"properties": {"success": {"type": "boolean"}}},
|
||||
},
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"agent_json": {"name": "Updated Agent", "nodes": []},
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
await service.generate_agent_patch_external(
|
||||
"Add error handling",
|
||||
{"name": "Original Agent", "nodes": []},
|
||||
library_agents=library_agents,
|
||||
)
|
||||
|
||||
# Verify library_agents was passed in the payload
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[1]["json"]["library_agents"] == library_agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_without_library_agents(self):
|
||||
"""Test that decompose goal works without library_agents."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
await service.decompose_goal_external("Build a workflow")
|
||||
|
||||
# Verify library_agents was NOT passed when not provided
|
||||
call_args = mock_client.post.call_args
|
||||
assert "library_agents" not in call_args[1]["json"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
@@ -857,7 +857,7 @@ export const CustomNode = React.memo(
|
||||
})();
|
||||
|
||||
const hasAdvancedFields =
|
||||
data.inputSchema &&
|
||||
data.inputSchema?.properties &&
|
||||
Object.entries(data.inputSchema.properties).some(([key, value]) => {
|
||||
return (
|
||||
value.advanced === true && !data.inputSchema.required?.includes(key)
|
||||
|
||||
@@ -11,6 +11,7 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { usePathname, useSearchParams } from "next/navigation";
|
||||
import { useRef } from "react";
|
||||
import { useCopilotStore } from "../../copilot-page-store";
|
||||
import { useCopilotSessionId } from "../../useCopilotSessionId";
|
||||
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
|
||||
@@ -69,16 +70,41 @@ export function useCopilotShell() {
|
||||
});
|
||||
|
||||
const stopStream = useChatStore((s) => s.stopStream);
|
||||
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
|
||||
const isStreaming = useCopilotStore((s) => s.isStreaming);
|
||||
const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
|
||||
const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
|
||||
const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);
|
||||
|
||||
function handleSessionClick(sessionId: string) {
|
||||
if (sessionId === currentSessionId) return;
|
||||
const pendingActionRef = useRef<(() => void) | null>(null);
|
||||
|
||||
// Stop current stream - SSE reconnection allows resuming later
|
||||
if (currentSessionId) {
|
||||
async function stopCurrentStream() {
|
||||
if (!currentSessionId) return;
|
||||
|
||||
setIsSwitchingSession(true);
|
||||
await new Promise<void>((resolve) => {
|
||||
const unsubscribe = onStreamComplete((completedId) => {
|
||||
if (completedId === currentSessionId) {
|
||||
clearTimeout(timeout);
|
||||
unsubscribe();
|
||||
resolve();
|
||||
}
|
||||
});
|
||||
const timeout = setTimeout(() => {
|
||||
unsubscribe();
|
||||
resolve();
|
||||
}, 3000);
|
||||
stopStream(currentSessionId);
|
||||
}
|
||||
});
|
||||
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(currentSessionId),
|
||||
});
|
||||
setIsSwitchingSession(false);
|
||||
}
|
||||
|
||||
function selectSession(sessionId: string) {
|
||||
if (sessionId === currentSessionId) return;
|
||||
if (recentlyCreatedSessionsRef.current.has(sessionId)) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||
@@ -88,12 +114,7 @@ export function useCopilotShell() {
|
||||
if (isMobile) handleCloseDrawer();
|
||||
}
|
||||
|
||||
function handleNewChatClick() {
|
||||
// Stop current stream - SSE reconnection allows resuming later
|
||||
if (currentSessionId) {
|
||||
stopStream(currentSessionId);
|
||||
}
|
||||
|
||||
function startNewChat() {
|
||||
resetPagination();
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
@@ -102,6 +123,32 @@ export function useCopilotShell() {
|
||||
if (isMobile) handleCloseDrawer();
|
||||
}
|
||||
|
||||
function handleSessionClick(sessionId: string) {
|
||||
if (sessionId === currentSessionId) return;
|
||||
|
||||
if (isStreaming) {
|
||||
pendingActionRef.current = async () => {
|
||||
await stopCurrentStream();
|
||||
selectSession(sessionId);
|
||||
};
|
||||
openInterruptModal(pendingActionRef.current);
|
||||
} else {
|
||||
selectSession(sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
function handleNewChatClick() {
|
||||
if (isStreaming) {
|
||||
pendingActionRef.current = async () => {
|
||||
await stopCurrentStream();
|
||||
startNewChat();
|
||||
};
|
||||
openInterruptModal(pendingActionRef.current);
|
||||
} else {
|
||||
startNewChat();
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
isMobile,
|
||||
isDrawerOpen,
|
||||
|
||||
@@ -4,25 +4,53 @@ import { create } from "zustand";
|
||||
|
||||
interface CopilotStoreState {
|
||||
isStreaming: boolean;
|
||||
isSwitchingSession: boolean;
|
||||
isCreatingSession: boolean;
|
||||
isInterruptModalOpen: boolean;
|
||||
pendingAction: (() => void) | null;
|
||||
}
|
||||
|
||||
interface CopilotStoreActions {
|
||||
setIsStreaming: (isStreaming: boolean) => void;
|
||||
setIsSwitchingSession: (isSwitchingSession: boolean) => void;
|
||||
setIsCreatingSession: (isCreating: boolean) => void;
|
||||
openInterruptModal: (onConfirm: () => void) => void;
|
||||
confirmInterrupt: () => void;
|
||||
cancelInterrupt: () => void;
|
||||
}
|
||||
|
||||
type CopilotStore = CopilotStoreState & CopilotStoreActions;
|
||||
|
||||
export const useCopilotStore = create<CopilotStore>((set) => ({
|
||||
export const useCopilotStore = create<CopilotStore>((set, get) => ({
|
||||
isStreaming: false,
|
||||
isSwitchingSession: false,
|
||||
isCreatingSession: false,
|
||||
isInterruptModalOpen: false,
|
||||
pendingAction: null,
|
||||
|
||||
setIsStreaming(isStreaming) {
|
||||
set({ isStreaming });
|
||||
},
|
||||
|
||||
setIsSwitchingSession(isSwitchingSession) {
|
||||
set({ isSwitchingSession });
|
||||
},
|
||||
|
||||
setIsCreatingSession(isCreatingSession) {
|
||||
set({ isCreatingSession });
|
||||
},
|
||||
|
||||
openInterruptModal(onConfirm) {
|
||||
set({ isInterruptModalOpen: true, pendingAction: onConfirm });
|
||||
},
|
||||
|
||||
confirmInterrupt() {
|
||||
const { pendingAction } = get();
|
||||
set({ isInterruptModalOpen: false, pendingAction: null });
|
||||
if (pendingAction) pendingAction();
|
||||
},
|
||||
|
||||
cancelInterrupt() {
|
||||
set({ isInterruptModalOpen: false, pendingAction: null });
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -5,10 +5,15 @@ import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Chat } from "@/components/contextual/Chat/Chat";
|
||||
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { useCopilotStore } from "./copilot-page-store";
|
||||
import { useCopilotPage } from "./useCopilotPage";
|
||||
|
||||
export default function CopilotPage() {
|
||||
const { state, handlers } = useCopilotPage();
|
||||
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
|
||||
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
|
||||
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
|
||||
const {
|
||||
greetingName,
|
||||
quickActions,
|
||||
@@ -35,6 +40,42 @@ export default function CopilotPage() {
|
||||
onSessionNotFound={handleSessionNotFound}
|
||||
onStreamingChange={handleStreamingChange}
|
||||
/>
|
||||
<Dialog
|
||||
title="Interrupt current chat?"
|
||||
styling={{ maxWidth: 300, width: "100%" }}
|
||||
controlled={{
|
||||
isOpen: isInterruptModalOpen,
|
||||
set: (open) => {
|
||||
if (!open) cancelInterrupt();
|
||||
},
|
||||
}}
|
||||
onClose={cancelInterrupt}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="flex flex-col gap-4">
|
||||
<Text variant="body">
|
||||
The current chat response will be interrupted. Are you sure you
|
||||
want to continue?
|
||||
</Text>
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
onClick={cancelInterrupt}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
variant="primary"
|
||||
onClick={confirmInterrupt}
|
||||
>
|
||||
Continue
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
|
||||
import {
|
||||
Flag,
|
||||
type FlagValues,
|
||||
@@ -25,12 +26,20 @@ export function useCopilotPage() {
|
||||
const queryClient = useQueryClient();
|
||||
const { user, isLoggedIn, isUserLoading } = useSupabase();
|
||||
const { toast } = useToast();
|
||||
const { completeStep } = useOnboarding();
|
||||
|
||||
const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
|
||||
const setIsStreaming = useCopilotStore((s) => s.setIsStreaming);
|
||||
const isCreating = useCopilotStore((s) => s.isCreatingSession);
|
||||
const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);
|
||||
|
||||
// Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
|
||||
useEffect(() => {
|
||||
if (isLoggedIn) {
|
||||
completeStep("VISIT_COPILOT");
|
||||
}
|
||||
}, [completeStep, isLoggedIn]);
|
||||
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const flags = useFlags<FlagValues>();
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
import { environment } from "@/services/environment";
|
||||
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
|
||||
import { NextRequest } from "next/server";
|
||||
|
||||
/**
|
||||
* SSE Proxy for task stream reconnection.
|
||||
*
|
||||
* This endpoint allows clients to reconnect to an ongoing or recently completed
|
||||
* background task's stream. It replays missed messages from Redis Streams and
|
||||
* subscribes to live updates if the task is still running.
|
||||
*
|
||||
* Client contract:
|
||||
* 1. When receiving an operation_started event, store the task_id
|
||||
* 2. To reconnect: GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
|
||||
* 3. Messages are replayed from the last_message_id position
|
||||
* 4. Stream ends when "finish" event is received
|
||||
*/
|
||||
export async function GET(
|
||||
request: NextRequest,
|
||||
{ params }: { params: Promise<{ taskId: string }> },
|
||||
) {
|
||||
const { taskId } = await params;
|
||||
const searchParams = request.nextUrl.searchParams;
|
||||
const lastMessageId = searchParams.get("last_message_id") || "0-0";
|
||||
|
||||
try {
|
||||
// Get auth token from server-side session
|
||||
const token = await getServerAuthToken();
|
||||
|
||||
// Build backend URL
|
||||
const backendUrl = environment.getAGPTServerBaseUrl();
|
||||
const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl);
|
||||
streamUrl.searchParams.set("last_message_id", lastMessageId);
|
||||
|
||||
// Forward request to backend with auth header
|
||||
const headers: Record<string, string> = {
|
||||
Accept: "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
Connection: "keep-alive",
|
||||
};
|
||||
|
||||
if (token) {
|
||||
headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
const response = await fetch(streamUrl.toString(), {
|
||||
method: "GET",
|
||||
headers,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
return new Response(error, {
|
||||
status: response.status,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
}
|
||||
|
||||
// Return the SSE stream directly
|
||||
return new Response(response.body, {
|
||||
headers: {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
Connection: "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Task stream proxy error:", error);
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
error: "Failed to connect to task stream",
|
||||
detail: error instanceof Error ? error.message : String(error),
|
||||
}),
|
||||
{
|
||||
status: 500,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -939,63 +939,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/operations/{operation_id}/complete": {
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Complete Operation",
|
||||
"description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.",
|
||||
"operationId": "postV2CompleteOperation",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "operation_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Operation Id" }
|
||||
},
|
||||
{
|
||||
"name": "x-api-key",
|
||||
"in": "header",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "X-Api-Key"
|
||||
}
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/OperationCompleteRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"title": "Response Postv2Completeoperation"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
@@ -1252,94 +1195,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/tasks/{task_id}": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Task Status",
|
||||
"description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
|
||||
"operationId": "getV2GetTaskStatus",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "task_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Task Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"title": "Response Getv2Gettaskstatus"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/tasks/{task_id}/stream": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Stream Task",
|
||||
"description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
|
||||
"operationId": "getV2StreamTask",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "task_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Task Id" }
|
||||
},
|
||||
{
|
||||
"name": "last_message_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
||||
"default": "0-0",
|
||||
"title": "Last Message Id"
|
||||
},
|
||||
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay."
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": { "application/json": { "schema": {} } }
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/credits": {
|
||||
"get": {
|
||||
"tags": ["v1", "credits"],
|
||||
@@ -8126,6 +7981,25 @@
|
||||
]
|
||||
},
|
||||
"new_output": { "type": "boolean", "title": "New Output" },
|
||||
"execution_count": {
|
||||
"type": "integer",
|
||||
"title": "Execution Count",
|
||||
"default": 0
|
||||
},
|
||||
"success_rate": {
|
||||
"anyOf": [{ "type": "number" }, { "type": "null" }],
|
||||
"title": "Success Rate"
|
||||
},
|
||||
"avg_correctness_score": {
|
||||
"anyOf": [{ "type": "number" }, { "type": "null" }],
|
||||
"title": "Avg Correctness Score"
|
||||
},
|
||||
"recent_executions": {
|
||||
"items": { "$ref": "#/components/schemas/RecentExecution" },
|
||||
"type": "array",
|
||||
"title": "Recent Executions",
|
||||
"description": "List of recent executions with status, score, and summary"
|
||||
},
|
||||
"can_access_graph": {
|
||||
"type": "boolean",
|
||||
"title": "Can Access Graph"
|
||||
@@ -8949,27 +8823,6 @@
|
||||
],
|
||||
"title": "OnboardingStep"
|
||||
},
|
||||
"OperationCompleteRequest": {
|
||||
"properties": {
|
||||
"success": { "type": "boolean", "title": "Success" },
|
||||
"result": {
|
||||
"anyOf": [
|
||||
{ "additionalProperties": true, "type": "object" },
|
||||
{ "type": "string" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Result"
|
||||
},
|
||||
"error": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Error"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["success"],
|
||||
"title": "OperationCompleteRequest",
|
||||
"description": "Request model for external completion webhook."
|
||||
},
|
||||
"Pagination": {
|
||||
"properties": {
|
||||
"total_items": {
|
||||
@@ -9540,6 +9393,23 @@
|
||||
"required": ["providers", "pagination"],
|
||||
"title": "ProviderResponse"
|
||||
},
|
||||
"RecentExecution": {
|
||||
"properties": {
|
||||
"status": { "type": "string", "title": "Status" },
|
||||
"correctness_score": {
|
||||
"anyOf": [{ "type": "number" }, { "type": "null" }],
|
||||
"title": "Correctness Score"
|
||||
},
|
||||
"activity_summary": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Activity Summary"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["status"],
|
||||
"title": "RecentExecution",
|
||||
"description": "Summary of a recent execution for quality assessment.\n\nUsed by the LLM to understand the agent's recent performance with specific examples\nrather than just aggregate statistics."
|
||||
},
|
||||
"RefundRequest": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
|
||||
import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
@@ -24,6 +25,7 @@ export function Chat({
|
||||
}: ChatProps) {
|
||||
const { urlSessionId } = useCopilotSessionId();
|
||||
const hasHandledNotFoundRef = useRef(false);
|
||||
const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession);
|
||||
const {
|
||||
messages,
|
||||
isLoading,
|
||||
@@ -51,7 +53,8 @@ export function Chat({
|
||||
isCreating,
|
||||
]);
|
||||
|
||||
const shouldShowLoader = showLoader && (isLoading || isCreating);
|
||||
const shouldShowLoader =
|
||||
(showLoader && (isLoading || isCreating)) || isSwitchingSession;
|
||||
|
||||
return (
|
||||
<div className={cn("flex h-full flex-col", className)}>
|
||||
@@ -63,19 +66,21 @@ export function Chat({
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
<LoadingSpinner size="large" className="text-neutral-400" />
|
||||
<Text variant="body" className="text-zinc-500">
|
||||
Loading your chat...
|
||||
{isSwitchingSession
|
||||
? "Switching chat..."
|
||||
: "Loading your chat..."}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Error State */}
|
||||
{error && !isLoading && (
|
||||
{error && !isLoading && !isSwitchingSession && (
|
||||
<ChatErrorState error={error} onRetry={createSession} />
|
||||
)}
|
||||
|
||||
{/* Session Content */}
|
||||
{sessionId && !isLoading && !error && (
|
||||
{sessionId && !isLoading && !error && !isSwitchingSession && (
|
||||
<ChatContainer
|
||||
sessionId={sessionId}
|
||||
initialMessages={messages}
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
# SSE Reconnection Contract for Long-Running Operations
|
||||
|
||||
This document describes the client-side contract for handling SSE (Server-Sent Events) disconnections and reconnecting to long-running background tasks.
|
||||
|
||||
## Overview
|
||||
|
||||
When a user triggers a long-running operation (like agent generation), the backend:
|
||||
1. Spawns a background task that survives SSE disconnections
|
||||
2. Returns an `operation_started` response with a `task_id`
|
||||
3. Stores stream messages in Redis Streams for replay
|
||||
|
||||
Clients can reconnect to the task stream at any time to receive missed messages.
|
||||
|
||||
## Client-Side Flow
|
||||
|
||||
### 1. Receiving Operation Started
|
||||
|
||||
When you receive an `operation_started` tool response:
|
||||
|
||||
```typescript
|
||||
// The response includes a task_id for reconnection
|
||||
{
|
||||
type: "operation_started",
|
||||
tool_name: "generate_agent",
|
||||
operation_id: "uuid-...",
|
||||
task_id: "task-uuid-...", // <-- Store this for reconnection
|
||||
message: "Operation started. You can close this tab."
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Storing Task Info
|
||||
|
||||
Use the chat store to track the active task:
|
||||
|
||||
```typescript
|
||||
import { useChatStore } from "./chat-store";
|
||||
|
||||
// When operation_started is received:
|
||||
useChatStore.getState().setActiveTask(sessionId, {
|
||||
taskId: response.task_id,
|
||||
operationId: response.operation_id,
|
||||
toolName: response.tool_name,
|
||||
lastMessageId: "0",
|
||||
});
|
||||
```
|
||||
|
||||
### 3. Reconnecting to a Task
|
||||
|
||||
To reconnect (e.g., after page refresh or tab reopen):
|
||||
|
||||
```typescript
|
||||
const { reconnectToTask, getActiveTask } = useChatStore.getState();
|
||||
|
||||
// Check if there's an active task for this session
|
||||
const activeTask = getActiveTask(sessionId);
|
||||
|
||||
if (activeTask) {
|
||||
// Reconnect to the task stream
|
||||
await reconnectToTask(
|
||||
sessionId,
|
||||
activeTask.taskId,
|
||||
activeTask.lastMessageId, // Resume from last position
|
||||
(chunk) => {
|
||||
// Handle incoming chunks
|
||||
console.log("Received chunk:", chunk);
|
||||
}
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Tracking Message Position
|
||||
|
||||
To enable precise replay, update the last message ID as chunks arrive:
|
||||
|
||||
```typescript
|
||||
const { updateTaskLastMessageId } = useChatStore.getState();
|
||||
|
||||
function handleChunk(chunk: StreamChunk) {
|
||||
// If chunk has an index/id, track it
|
||||
if (chunk.idx !== undefined) {
|
||||
updateTaskLastMessageId(sessionId, String(chunk.idx));
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Task Stream Reconnection
|
||||
```
|
||||
GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
|
||||
```
|
||||
|
||||
- `taskId`: The task ID from `operation_started`
|
||||
- `last_message_id`: Last received message index (default: "0" for full replay)
|
||||
|
||||
Returns: SSE stream of missed messages + live updates
|
||||
|
||||
## Chunk Types
|
||||
|
||||
The reconnected stream follows the same Vercel AI SDK protocol:
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `start` | Message lifecycle start |
|
||||
| `text-delta` | Streaming text content |
|
||||
| `text-end` | Text block completed |
|
||||
| `tool-output-available` | Tool result available |
|
||||
| `finish` | Stream completed |
|
||||
| `error` | Error occurred |
|
||||
|
||||
## Error Handling
|
||||
|
||||
If reconnection fails:
|
||||
1. Check if task still exists (may have expired - default TTL: 1 hour)
|
||||
2. Fall back to polling the session for final state
|
||||
3. Show appropriate UI message to user
|
||||
|
||||
## Persistence Considerations
|
||||
|
||||
For robust reconnection across browser restarts:
|
||||
|
||||
```typescript
|
||||
// Store in localStorage/sessionStorage
|
||||
const ACTIVE_TASKS_KEY = "chat_active_tasks";
|
||||
|
||||
function persistActiveTask(sessionId: string, task: ActiveTaskInfo) {
|
||||
const tasks = JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
|
||||
tasks[sessionId] = task;
|
||||
localStorage.setItem(ACTIVE_TASKS_KEY, JSON.stringify(tasks));
|
||||
}
|
||||
|
||||
function loadPersistedTasks(): Record<string, ActiveTaskInfo> {
|
||||
return JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
|
||||
}
|
||||
```
|
||||
|
||||
## Backend Configuration
|
||||
|
||||
The following backend settings affect reconnection behavior:
|
||||
|
||||
| Setting | Default | Description |
|
||||
|---------|---------|-------------|
|
||||
| `stream_ttl` | 3600s | How long streams are kept in Redis |
|
||||
| `stream_max_length` | 1000 | Max messages per stream |
|
||||
|
||||
## Testing
|
||||
|
||||
To test reconnection locally:
|
||||
|
||||
1. Start a long-running operation (e.g., agent generation)
|
||||
2. Note the `task_id` from the `operation_started` response
|
||||
3. Close the browser tab
|
||||
4. Reopen and call `reconnectToTask` with the saved `task_id`
|
||||
5. Verify that missed messages are replayed
|
||||
|
||||
See the main README for full local development setup.
|
||||
@@ -8,68 +8,15 @@ import type {
|
||||
StreamResult,
|
||||
StreamStatus,
|
||||
} from "./chat-types";
|
||||
import { executeStream, executeTaskReconnect } from "./stream-executor";
|
||||
import { executeStream } from "./stream-executor";
|
||||
|
||||
const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes
|
||||
const ACTIVE_TASKS_STORAGE_KEY = "chat_active_tasks";
|
||||
const TASK_TTL = 60 * 60 * 1000; // 1 hour - tasks expire after this
|
||||
|
||||
/**
|
||||
* Tracks active task info for SSE reconnection.
|
||||
* When a long-running operation starts, we store this so clients can reconnect
|
||||
* if the browser tab is closed and reopened.
|
||||
*/
|
||||
export interface ActiveTaskInfo {
|
||||
taskId: string;
|
||||
sessionId: string;
|
||||
operationId: string;
|
||||
toolName: string;
|
||||
lastMessageId: string; // Last processed message ID for replay (Redis Stream format: "0-0")
|
||||
startedAt: number;
|
||||
}
|
||||
|
||||
/** Load active tasks from localStorage */
|
||||
function loadPersistedTasks(): Map<string, ActiveTaskInfo> {
|
||||
if (typeof window === "undefined") return new Map();
|
||||
try {
|
||||
const stored = localStorage.getItem(ACTIVE_TASKS_STORAGE_KEY);
|
||||
if (!stored) return new Map();
|
||||
const parsed = JSON.parse(stored) as Record<string, ActiveTaskInfo>;
|
||||
const now = Date.now();
|
||||
const tasks = new Map<string, ActiveTaskInfo>();
|
||||
// Filter out expired tasks
|
||||
for (const [sessionId, task] of Object.entries(parsed)) {
|
||||
if (now - task.startedAt < TASK_TTL) {
|
||||
tasks.set(sessionId, task);
|
||||
}
|
||||
}
|
||||
return tasks;
|
||||
} catch {
|
||||
return new Map();
|
||||
}
|
||||
}
|
||||
|
||||
/** Save active tasks to localStorage */
|
||||
function persistTasks(tasks: Map<string, ActiveTaskInfo>): void {
|
||||
if (typeof window === "undefined") return;
|
||||
try {
|
||||
const obj: Record<string, ActiveTaskInfo> = {};
|
||||
for (const [sessionId, task] of tasks) {
|
||||
obj[sessionId] = task;
|
||||
}
|
||||
localStorage.setItem(ACTIVE_TASKS_STORAGE_KEY, JSON.stringify(obj));
|
||||
} catch {
|
||||
// Ignore storage errors
|
||||
}
|
||||
}
|
||||
|
||||
interface ChatStoreState {
|
||||
activeStreams: Map<string, ActiveStream>;
|
||||
completedStreams: Map<string, StreamResult>;
|
||||
activeSessions: Set<string>;
|
||||
streamCompleteCallbacks: Set<StreamCompleteCallback>;
|
||||
/** Active tasks for SSE reconnection - keyed by sessionId */
|
||||
activeTasks: Map<string, ActiveTaskInfo>;
|
||||
}
|
||||
|
||||
interface ChatStoreActions {
|
||||
@@ -94,24 +41,6 @@ interface ChatStoreActions {
|
||||
unregisterActiveSession: (sessionId: string) => void;
|
||||
isSessionActive: (sessionId: string) => boolean;
|
||||
onStreamComplete: (callback: StreamCompleteCallback) => () => void;
|
||||
/** Track active task for SSE reconnection */
|
||||
setActiveTask: (
|
||||
sessionId: string,
|
||||
taskInfo: Omit<ActiveTaskInfo, "sessionId" | "startedAt">,
|
||||
) => void;
|
||||
/** Get active task for a session */
|
||||
getActiveTask: (sessionId: string) => ActiveTaskInfo | undefined;
|
||||
/** Clear active task when operation completes */
|
||||
clearActiveTask: (sessionId: string) => void;
|
||||
/** Reconnect to an existing task stream */
|
||||
reconnectToTask: (
|
||||
sessionId: string,
|
||||
taskId: string,
|
||||
lastMessageId?: string,
|
||||
onChunk?: (chunk: StreamChunk) => void,
|
||||
) => Promise<void>;
|
||||
/** Update last message ID for a task (for tracking replay position) */
|
||||
updateTaskLastMessageId: (sessionId: string, lastMessageId: string) => void;
|
||||
}
|
||||
|
||||
type ChatStore = ChatStoreState & ChatStoreActions;
|
||||
@@ -147,7 +76,6 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
completedStreams: new Map(),
|
||||
activeSessions: new Set(),
|
||||
streamCompleteCallbacks: new Set(),
|
||||
activeTasks: loadPersistedTasks(),
|
||||
|
||||
startStream: async function startStream(
|
||||
sessionId,
|
||||
@@ -358,139 +286,4 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
set({ streamCompleteCallbacks: cleanedCallbacks });
|
||||
};
|
||||
},
|
||||
|
||||
setActiveTask: function setActiveTask(sessionId, taskInfo) {
|
||||
const state = get();
|
||||
const newActiveTasks = new Map(state.activeTasks);
|
||||
newActiveTasks.set(sessionId, {
|
||||
...taskInfo,
|
||||
sessionId,
|
||||
startedAt: Date.now(),
|
||||
});
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
},
|
||||
|
||||
getActiveTask: function getActiveTask(sessionId) {
|
||||
return get().activeTasks.get(sessionId);
|
||||
},
|
||||
|
||||
clearActiveTask: function clearActiveTask(sessionId) {
|
||||
const state = get();
|
||||
if (!state.activeTasks.has(sessionId)) return;
|
||||
|
||||
const newActiveTasks = new Map(state.activeTasks);
|
||||
newActiveTasks.delete(sessionId);
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
},
|
||||
|
||||
reconnectToTask: async function reconnectToTask(
|
||||
sessionId,
|
||||
taskId,
|
||||
lastMessageId = "0-0", // Redis Stream ID format
|
||||
onChunk,
|
||||
) {
|
||||
const state = get();
|
||||
const newActiveStreams = new Map(state.activeStreams);
|
||||
let newCompletedStreams = new Map(state.completedStreams);
|
||||
const callbacks = state.streamCompleteCallbacks;
|
||||
|
||||
// Clean up any existing stream for this session
|
||||
const existingStream = newActiveStreams.get(sessionId);
|
||||
if (existingStream) {
|
||||
existingStream.abortController.abort();
|
||||
const normalizedStatus =
|
||||
existingStream.status === "streaming"
|
||||
? "completed"
|
||||
: existingStream.status;
|
||||
const result: StreamResult = {
|
||||
sessionId,
|
||||
status: normalizedStatus,
|
||||
chunks: existingStream.chunks,
|
||||
completedAt: Date.now(),
|
||||
error: existingStream.error,
|
||||
};
|
||||
newCompletedStreams.set(sessionId, result);
|
||||
newActiveStreams.delete(sessionId);
|
||||
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
|
||||
}
|
||||
|
||||
const abortController = new AbortController();
|
||||
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
|
||||
if (onChunk) initialCallbacks.add(onChunk);
|
||||
|
||||
const stream: ActiveStream = {
|
||||
sessionId,
|
||||
abortController,
|
||||
status: "streaming",
|
||||
startedAt: Date.now(),
|
||||
chunks: [],
|
||||
onChunkCallbacks: initialCallbacks,
|
||||
};
|
||||
|
||||
newActiveStreams.set(sessionId, stream);
|
||||
set({
|
||||
activeStreams: newActiveStreams,
|
||||
completedStreams: newCompletedStreams,
|
||||
});
|
||||
|
||||
try {
|
||||
await executeTaskReconnect(stream, taskId, lastMessageId);
|
||||
} finally {
|
||||
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
|
||||
if (stream.status !== "streaming") {
|
||||
const currentState = get();
|
||||
const finalActiveStreams = new Map(currentState.activeStreams);
|
||||
let finalCompletedStreams = new Map(currentState.completedStreams);
|
||||
|
||||
const storedStream = finalActiveStreams.get(sessionId);
|
||||
if (storedStream === stream) {
|
||||
const result: StreamResult = {
|
||||
sessionId,
|
||||
status: stream.status,
|
||||
chunks: stream.chunks,
|
||||
completedAt: Date.now(),
|
||||
error: stream.error,
|
||||
};
|
||||
finalCompletedStreams.set(sessionId, result);
|
||||
finalActiveStreams.delete(sessionId);
|
||||
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
|
||||
set({
|
||||
activeStreams: finalActiveStreams,
|
||||
completedStreams: finalCompletedStreams,
|
||||
});
|
||||
if (stream.status === "completed" || stream.status === "error") {
|
||||
notifyStreamComplete(
|
||||
currentState.streamCompleteCallbacks,
|
||||
sessionId,
|
||||
);
|
||||
// Clear active task on completion
|
||||
const taskState = get();
|
||||
const newActiveTasks = new Map(taskState.activeTasks);
|
||||
newActiveTasks.delete(sessionId);
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
updateTaskLastMessageId: function updateTaskLastMessageId(
|
||||
sessionId,
|
||||
lastMessageId,
|
||||
) {
|
||||
const state = get();
|
||||
const task = state.activeTasks.get(sessionId);
|
||||
if (!task) return;
|
||||
|
||||
const newActiveTasks = new Map(state.activeTasks);
|
||||
newActiveTasks.set(sessionId, {
|
||||
...task,
|
||||
lastMessageId,
|
||||
});
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -23,12 +23,6 @@ export interface HandlerDependencies {
|
||||
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
|
||||
sessionId: string;
|
||||
onOperationStarted?: () => void;
|
||||
onActiveTaskStarted?: (taskInfo: {
|
||||
taskId: string;
|
||||
operationId: string;
|
||||
toolName: string;
|
||||
toolCallId: string;
|
||||
}) => void;
|
||||
}
|
||||
|
||||
export function isRegionBlockedError(chunk: StreamChunk): boolean {
|
||||
@@ -170,19 +164,9 @@ export function handleToolResponse(
|
||||
}
|
||||
return;
|
||||
}
|
||||
// Trigger polling and store task info when operation_started is received
|
||||
// Trigger polling when operation_started is received
|
||||
if (responseMessage.type === "operation_started") {
|
||||
deps.onOperationStarted?.();
|
||||
// Store task info for SSE reconnection if taskId is present
|
||||
const taskId = (responseMessage as any).taskId;
|
||||
if (taskId && deps.onActiveTaskStarted) {
|
||||
deps.onActiveTaskStarted({
|
||||
taskId,
|
||||
operationId: (responseMessage as any).operationId || "",
|
||||
toolName: (responseMessage as any).toolName || "",
|
||||
toolCallId: (responseMessage as any).toolId || "",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
deps.setMessages((prev) => {
|
||||
|
||||
@@ -349,7 +349,6 @@ export function parseToolResponse(
|
||||
toolName: (parsedResult.tool_name as string) || toolName,
|
||||
toolId,
|
||||
operationId: (parsedResult.operation_id as string) || "",
|
||||
taskId: (parsedResult.task_id as string) || undefined, // For SSE reconnection
|
||||
message:
|
||||
(parsedResult.message as string) ||
|
||||
"Operation started. You can close this tab.",
|
||||
|
||||
@@ -65,27 +65,8 @@ export function useChatContainer({
|
||||
} = useChatStream();
|
||||
const activeStreams = useChatStore((s) => s.activeStreams);
|
||||
const subscribeToStream = useChatStore((s) => s.subscribeToStream);
|
||||
const setActiveTask = useChatStore((s) => s.setActiveTask);
|
||||
const getActiveTask = useChatStore((s) => s.getActiveTask);
|
||||
const reconnectToTask = useChatStore((s) => s.reconnectToTask);
|
||||
const isStreaming = isStreamingInitiated || hasTextChunks;
|
||||
|
||||
// Callback to store active task info for SSE reconnection
|
||||
function handleActiveTaskStarted(taskInfo: {
|
||||
taskId: string;
|
||||
operationId: string;
|
||||
toolName: string;
|
||||
toolCallId: string;
|
||||
}) {
|
||||
if (!sessionId) return;
|
||||
setActiveTask(sessionId, {
|
||||
taskId: taskInfo.taskId,
|
||||
operationId: taskInfo.operationId,
|
||||
toolName: taskInfo.toolName,
|
||||
lastMessageId: "0-0", // Redis Stream ID format for full replay
|
||||
});
|
||||
}
|
||||
|
||||
useEffect(
|
||||
function handleSessionChange() {
|
||||
if (sessionId === previousSessionIdRef.current) return;
|
||||
@@ -104,34 +85,6 @@ export function useChatContainer({
|
||||
|
||||
if (!sessionId) return;
|
||||
|
||||
// Check if there's an active task for this session that we should reconnect to
|
||||
const activeTask = getActiveTask(sessionId);
|
||||
if (activeTask) {
|
||||
const dispatcher = createStreamEventDispatcher({
|
||||
setHasTextChunks,
|
||||
setStreamingChunks,
|
||||
streamingChunksRef,
|
||||
hasResponseRef,
|
||||
setMessages,
|
||||
setIsRegionBlockedModalOpen,
|
||||
sessionId,
|
||||
setIsStreamingInitiated,
|
||||
onOperationStarted,
|
||||
onActiveTaskStarted: handleActiveTaskStarted,
|
||||
});
|
||||
|
||||
setIsStreamingInitiated(true);
|
||||
// Reconnect to the task stream
|
||||
reconnectToTask(
|
||||
sessionId,
|
||||
activeTask.taskId,
|
||||
activeTask.lastMessageId,
|
||||
dispatcher,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise check for an in-memory active stream
|
||||
const activeStream = activeStreams.get(sessionId);
|
||||
if (!activeStream || activeStream.status !== "streaming") return;
|
||||
|
||||
@@ -145,7 +98,6 @@ export function useChatContainer({
|
||||
sessionId,
|
||||
setIsStreamingInitiated,
|
||||
onOperationStarted,
|
||||
onActiveTaskStarted: handleActiveTaskStarted,
|
||||
});
|
||||
|
||||
setIsStreamingInitiated(true);
|
||||
@@ -158,8 +110,6 @@ export function useChatContainer({
|
||||
activeStreams,
|
||||
subscribeToStream,
|
||||
onOperationStarted,
|
||||
getActiveTask,
|
||||
reconnectToTask,
|
||||
],
|
||||
);
|
||||
|
||||
@@ -275,7 +225,6 @@ export function useChatContainer({
|
||||
sessionId,
|
||||
setIsStreamingInitiated,
|
||||
onOperationStarted,
|
||||
onActiveTaskStarted: handleActiveTaskStarted,
|
||||
});
|
||||
|
||||
try {
|
||||
|
||||
@@ -156,11 +156,19 @@ export function ChatMessage({
|
||||
}
|
||||
|
||||
if (isClarificationNeeded && message.type === "clarification_needed") {
|
||||
const hasUserReplyAfter =
|
||||
index >= 0 &&
|
||||
messages
|
||||
.slice(index + 1)
|
||||
.some((m) => m.type === "message" && m.role === "user");
|
||||
|
||||
return (
|
||||
<ClarificationQuestionsWidget
|
||||
questions={message.questions}
|
||||
message={message.message}
|
||||
sessionId={message.sessionId}
|
||||
onSubmitAnswers={handleClarificationAnswers}
|
||||
isAnswered={hasUserReplyAfter}
|
||||
className={className}
|
||||
/>
|
||||
);
|
||||
|
||||
@@ -111,7 +111,6 @@ export type ChatMessageData =
|
||||
toolName: string;
|
||||
toolId: string;
|
||||
operationId: string;
|
||||
taskId?: string; // For SSE reconnection
|
||||
message: string;
|
||||
timestamp?: string | Date;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import { Input } from "@/components/atoms/Input/Input";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { CheckCircleIcon, QuestionIcon } from "@phosphor-icons/react";
|
||||
import { useState } from "react";
|
||||
import { useState, useEffect, useRef } from "react";
|
||||
|
||||
export interface ClarifyingQuestion {
|
||||
question: string;
|
||||
@@ -17,39 +17,96 @@ export interface ClarifyingQuestion {
|
||||
interface Props {
|
||||
questions: ClarifyingQuestion[];
|
||||
message: string;
|
||||
sessionId?: string;
|
||||
onSubmitAnswers: (answers: Record<string, string>) => void;
|
||||
onCancel?: () => void;
|
||||
isAnswered?: boolean;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
function getStorageKey(sessionId?: string): string | null {
|
||||
if (!sessionId) return null;
|
||||
return `clarification_answers_${sessionId}`;
|
||||
}
|
||||
|
||||
export function ClarificationQuestionsWidget({
|
||||
questions,
|
||||
message,
|
||||
sessionId,
|
||||
onSubmitAnswers,
|
||||
onCancel,
|
||||
isAnswered = false,
|
||||
className,
|
||||
}: Props) {
|
||||
const [answers, setAnswers] = useState<Record<string, string>>({});
|
||||
const [isSubmitted, setIsSubmitted] = useState(false);
|
||||
const lastSessionIdRef = useRef<string | undefined>(undefined);
|
||||
|
||||
useEffect(() => {
|
||||
const storageKey = getStorageKey(sessionId);
|
||||
if (!storageKey) {
|
||||
setAnswers({});
|
||||
setIsSubmitted(false);
|
||||
lastSessionIdRef.current = sessionId;
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const saved = localStorage.getItem(storageKey);
|
||||
if (saved) {
|
||||
const parsed = JSON.parse(saved) as Record<string, string>;
|
||||
setAnswers(parsed);
|
||||
} else {
|
||||
setAnswers({});
|
||||
}
|
||||
setIsSubmitted(false);
|
||||
} catch {
|
||||
setAnswers({});
|
||||
setIsSubmitted(false);
|
||||
}
|
||||
lastSessionIdRef.current = sessionId;
|
||||
}, [sessionId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (lastSessionIdRef.current !== sessionId) {
|
||||
return;
|
||||
}
|
||||
const storageKey = getStorageKey(sessionId);
|
||||
if (!storageKey) return;
|
||||
|
||||
const hasAnswers = Object.values(answers).some((v) => v.trim());
|
||||
try {
|
||||
if (hasAnswers) {
|
||||
localStorage.setItem(storageKey, JSON.stringify(answers));
|
||||
} else {
|
||||
localStorage.removeItem(storageKey);
|
||||
}
|
||||
} catch {}
|
||||
}, [answers, sessionId]);
|
||||
|
||||
function handleAnswerChange(keyword: string, value: string) {
|
||||
setAnswers((prev) => ({ ...prev, [keyword]: value }));
|
||||
}
|
||||
|
||||
function handleSubmit() {
|
||||
// Check if all questions are answered
|
||||
const allAnswered = questions.every((q) => answers[q.keyword]?.trim());
|
||||
if (!allAnswered) {
|
||||
return;
|
||||
}
|
||||
setIsSubmitted(true);
|
||||
onSubmitAnswers(answers);
|
||||
|
||||
const storageKey = getStorageKey(sessionId);
|
||||
try {
|
||||
if (storageKey) {
|
||||
localStorage.removeItem(storageKey);
|
||||
}
|
||||
} catch {}
|
||||
}
|
||||
|
||||
const allAnswered = questions.every((q) => answers[q.keyword]?.trim());
|
||||
|
||||
// Show submitted state after answers are submitted
|
||||
if (isSubmitted) {
|
||||
if (isAnswered || isSubmitted) {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
|
||||
@@ -30,9 +30,9 @@ export function getErrorMessage(result: unknown): string {
|
||||
}
|
||||
if (typeof result === "object" && result !== null) {
|
||||
const response = result as Record<string, unknown>;
|
||||
if (response.error) return stripInternalReasoning(String(response.error));
|
||||
if (response.message)
|
||||
return stripInternalReasoning(String(response.message));
|
||||
if (response.error) return stripInternalReasoning(String(response.error));
|
||||
}
|
||||
return "An error occurred";
|
||||
}
|
||||
@@ -363,8 +363,8 @@ export function formatToolResponse(result: unknown, toolName: string): string {
|
||||
|
||||
case "error":
|
||||
const errorMsg =
|
||||
(response.error as string) || response.message || "An error occurred";
|
||||
return `Error: ${errorMsg}`;
|
||||
(response.message as string) || response.error || "An error occurred";
|
||||
return stripInternalReasoning(String(errorMsg));
|
||||
|
||||
case "no_results":
|
||||
const suggestions = (response.suggestions as string[]) || [];
|
||||
|
||||
@@ -10,14 +10,8 @@ import {
|
||||
parseSSELine,
|
||||
} from "./stream-utils";
|
||||
|
||||
function notifySubscribers(
|
||||
stream: ActiveStream,
|
||||
chunk: StreamChunk,
|
||||
skipStore = false,
|
||||
) {
|
||||
if (!skipStore) {
|
||||
stream.chunks.push(chunk);
|
||||
}
|
||||
function notifySubscribers(stream: ActiveStream, chunk: StreamChunk) {
|
||||
stream.chunks.push(chunk);
|
||||
for (const callback of stream.onChunkCallbacks) {
|
||||
try {
|
||||
callback(chunk);
|
||||
@@ -146,124 +140,3 @@ export async function executeStream(
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reconnect to an existing task stream.
|
||||
*
|
||||
* This is used when a client wants to resume receiving updates from a
|
||||
* long-running background task. Messages are replayed from the last_message_id
|
||||
* position, allowing clients to catch up on missed events.
|
||||
*
|
||||
* @param stream - The active stream state
|
||||
* @param taskId - The task ID to reconnect to
|
||||
* @param lastMessageId - The last message ID received (for replay)
|
||||
* @param retryCount - Current retry count
|
||||
*/
|
||||
export async function executeTaskReconnect(
|
||||
stream: ActiveStream,
|
||||
taskId: string,
|
||||
lastMessageId: string = "0",
|
||||
retryCount: number = 0,
|
||||
): Promise<void> {
|
||||
const { abortController } = stream;
|
||||
|
||||
try {
|
||||
const url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
Accept: "text/event-stream",
|
||||
},
|
||||
signal: abortController.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(errorText || `HTTP ${response.status}`);
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error("Response body is null");
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
|
||||
if (done) {
|
||||
notifySubscribers(stream, { type: "stream_end" });
|
||||
stream.status = "completed";
|
||||
return;
|
||||
}
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() || "";
|
||||
|
||||
for (const line of lines) {
|
||||
const data = parseSSELine(line);
|
||||
if (data !== null) {
|
||||
if (data === "[DONE]") {
|
||||
notifySubscribers(stream, { type: "stream_end" });
|
||||
stream.status = "completed";
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const rawChunk = JSON.parse(data) as
|
||||
| StreamChunk
|
||||
| VercelStreamChunk;
|
||||
const chunk = normalizeStreamChunk(rawChunk);
|
||||
if (!chunk) continue;
|
||||
|
||||
notifySubscribers(stream, chunk);
|
||||
|
||||
if (chunk.type === "stream_end") {
|
||||
stream.status = "completed";
|
||||
return;
|
||||
}
|
||||
|
||||
if (chunk.type === "error") {
|
||||
stream.status = "error";
|
||||
stream.error = new Error(
|
||||
chunk.message || chunk.content || "Stream error",
|
||||
);
|
||||
return;
|
||||
}
|
||||
} catch (err) {
|
||||
console.warn(
|
||||
"[StreamExecutor] Failed to parse task reconnect SSE chunk:",
|
||||
err,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
notifySubscribers(stream, { type: "stream_end" });
|
||||
stream.status = "completed";
|
||||
return;
|
||||
}
|
||||
|
||||
if (retryCount < MAX_RETRIES) {
|
||||
const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
|
||||
console.log(
|
||||
`[StreamExecutor] Task reconnect retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
|
||||
);
|
||||
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
||||
return executeTaskReconnect(stream, taskId, lastMessageId, retryCount + 1);
|
||||
}
|
||||
|
||||
stream.status = "error";
|
||||
stream.error = err instanceof Error ? err : new Error("Task reconnect failed");
|
||||
notifySubscribers(stream, {
|
||||
type: "error",
|
||||
message: stream.error.message,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,28 @@
|
||||
|
||||
This guide walks through creating a simple question-answer AI agent using AutoGPT's visual builder. This is a basic example that can be expanded into more complex agents.
|
||||
|
||||
## **Prerequisites**
|
||||
|
||||
### **Cloud-Hosted AutoGPT**
|
||||
If you're using the cloud-hosted version at [agpt.co](https://agpt.co), you're ready to go! AI blocks come with **built-in credits** — no API keys required to get started. If you'd prefer to use your own API keys, you can add them via **Profile → Integrations**.
|
||||
|
||||
### **Self-Hosted (Docker)**
|
||||
If you're running AutoGPT locally with Docker, you'll need to add your own API keys to `autogpt_platform/backend/.env`:
|
||||
|
||||
```bash
|
||||
# Create or edit backend/.env
|
||||
OPENAI_API_KEY=sk-your-key-here
|
||||
ANTHROPIC_API_KEY=sk-ant-your-key-here
|
||||
# Add other provider keys as needed
|
||||
```
|
||||
|
||||
After adding keys, restart the services:
|
||||
```bash
|
||||
docker compose down && docker compose up -d
|
||||
```
|
||||
|
||||
**Note:** The Calculator example below doesn't require any API credentials — it's a good way to test your setup before adding AI blocks.
|
||||
|
||||
## **Example Agent: Q&A (with AI)**
|
||||
|
||||
A step-by-step guide to creating a simple Q&A agent using input and output blocks.
|
||||
|
||||
Reference in New Issue
Block a user