Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-02-05 20:35:10 -05:00

Compare commits: classic-fr...dev (21 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 85b6520710 | |
| | bfa942e032 | |
| | 11256076d8 | |
| | 3ca2387631 | |
| | ed07f02738 | |
| | b121030c94 | |
| | c22c18374d | |
| | e40233a3ac | |
| | 3ae5eabf9d | |
| | a077ba9f03 | |
| | 5401d54eaa | |
| | 5ac89d7c0b | |
| | 4f908d5cb3 | |
| | c1aa684743 | |
| | 7e5b84cc5c | |
| | 09cb313211 | |
| | c026485023 | |
| | 1eabc60484 | |
| | f4bf492f24 | |
| | 81e48c00a4 | |
| | 7dc53071e8 | |
@@ -152,6 +152,7 @@ REPLICATE_API_KEY=
 REVID_API_KEY=
 SCREENSHOTONE_API_KEY=
 UNREAL_SPEECH_API_KEY=
+ELEVENLABS_API_KEY=

 # Data & Search Services
 E2B_API_KEY=
autogpt_platform/backend/.gitignore (vendored, 3 additions)
@@ -19,3 +19,6 @@ load-tests/*.json
 load-tests/*.log
 load-tests/node_modules/*
 migrations/*/rollback*.sql
+
+# Workspace files
+workspaces/
@@ -62,10 +62,12 @@ ENV POETRY_HOME=/opt/poetry \
     DEBIAN_FRONTEND=noninteractive
 ENV PATH=/opt/poetry/bin:$PATH

-# Install Python without upgrading system-managed packages
+# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
 RUN apt-get update && apt-get install -y \
     python3.13 \
     python3-pip \
+    ffmpeg \
+    imagemagick \
     && rm -rf /var/lib/apt/lists/*

 # Copy only necessary files from builder
@@ -0,0 +1,368 @@
"""Redis Streams consumer for operation completion messages.

This module provides a consumer (ChatCompletionConsumer) that listens for
completion notifications (OperationCompleteMessage) from external services
(like Agent Generator) and triggers the appropriate stream registry and
chat service updates via process_operation_success/process_operation_failure.

Why Redis Streams instead of RabbitMQ?
--------------------------------------
While the project typically uses RabbitMQ for async task queues (e.g., execution
queue), Redis Streams was chosen for chat completion notifications because:

1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
   Streams (via stream_registry) for message persistence and replay. Using Redis
   Streams for completion notifications keeps all chat streaming infrastructure
   in one system, simplifying operations and reducing cross-system coordination.

2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
   allowing consumers to replay missed messages after reconnection. This aligns
   with the SSE reconnection pattern where clients can resume from last_message_id.

3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
   load balancing across pods with explicit message claiming (XAUTOCLAIM) for
   recovering from dead consumers - ideal for the completion callback pattern.

4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
   stream_registry) provides lower latency than an additional RabbitMQ hop.

5. **Atomicity with Task State**: Completion processing often needs to update
   task metadata stored in Redis. Keeping both in Redis enables simpler
   transactional semantics without distributed coordination.

The consumer uses Redis Streams with consumer groups for reliable message
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
stale pending messages from dead consumers.
"""

import asyncio
import logging
import os
import uuid
from typing import Any

import orjson
from prisma import Prisma
from pydantic import BaseModel
from redis.exceptions import ResponseError

from backend.data.redis_client import get_redis_async

from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig

logger = logging.getLogger(__name__)
config = ChatConfig()


class OperationCompleteMessage(BaseModel):
    """Message format for operation completion notifications."""

    operation_id: str
    task_id: str
    success: bool
    result: dict | str | None = None
    error: str | None = None


class ChatCompletionConsumer:
    """Consumer for chat operation completion messages from Redis Streams.

    This consumer initializes its own Prisma client in start() to ensure
    database operations work correctly within this async context.

    Uses Redis consumer groups to allow multiple platform pods to consume
    messages reliably with automatic redelivery on failure.
    """

    def __init__(self):
        self._consumer_task: asyncio.Task | None = None
        self._running = False
        self._prisma: Prisma | None = None
        self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"

    async def start(self) -> None:
        """Start the completion consumer."""
        if self._running:
            logger.warning("Completion consumer already running")
            return

        # Create consumer group if it doesn't exist
        try:
            redis = await get_redis_async()
            await redis.xgroup_create(
                config.stream_completion_name,
                config.stream_consumer_group,
                id="0",
                mkstream=True,
            )
            logger.info(
                f"Created consumer group '{config.stream_consumer_group}' "
                f"on stream '{config.stream_completion_name}'"
            )
        except ResponseError as e:
            if "BUSYGROUP" in str(e):
                logger.debug(
                    f"Consumer group '{config.stream_consumer_group}' already exists"
                )
            else:
                raise

        self._running = True
        self._consumer_task = asyncio.create_task(self._consume_messages())
        logger.info(
            f"Chat completion consumer started (consumer: {self._consumer_name})"
        )

    async def _ensure_prisma(self) -> Prisma:
        """Lazily initialize Prisma client on first use."""
        if self._prisma is None:
            database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
            self._prisma = Prisma(datasource={"url": database_url})
            await self._prisma.connect()
            logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
        return self._prisma

    async def stop(self) -> None:
        """Stop the completion consumer."""
        self._running = False

        if self._consumer_task:
            self._consumer_task.cancel()
            try:
                await self._consumer_task
            except asyncio.CancelledError:
                pass
            self._consumer_task = None

        if self._prisma:
            await self._prisma.disconnect()
            self._prisma = None
            logger.info("[COMPLETION] Consumer Prisma client disconnected")

        logger.info("Chat completion consumer stopped")

    async def _consume_messages(self) -> None:
        """Main message consumption loop with retry logic."""
        max_retries = 10
        retry_delay = 5  # seconds
        retry_count = 0
        block_timeout = 5000  # milliseconds

        while self._running and retry_count < max_retries:
            try:
                redis = await get_redis_async()

                # Reset retry count on successful connection
                retry_count = 0

                while self._running:
                    # First, claim any stale pending messages from dead consumers.
                    # Redis does NOT auto-redeliver pending messages; we must explicitly
                    # claim them using XAUTOCLAIM
                    try:
                        claimed_result = await redis.xautoclaim(
                            name=config.stream_completion_name,
                            groupname=config.stream_consumer_group,
                            consumername=self._consumer_name,
                            min_idle_time=config.stream_claim_min_idle_ms,
                            start_id="0-0",
                            count=10,
                        )
                        # xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
                        if claimed_result and len(claimed_result) >= 2:
                            claimed_entries = claimed_result[1]
                            if claimed_entries:
                                logger.info(
                                    f"Claimed {len(claimed_entries)} stale pending messages"
                                )
                            for entry_id, data in claimed_entries:
                                if not self._running:
                                    return
                                await self._process_entry(redis, entry_id, data)
                    except Exception as e:
                        logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")

                    # Read new messages from the stream
                    messages = await redis.xreadgroup(
                        groupname=config.stream_consumer_group,
                        consumername=self._consumer_name,
                        streams={config.stream_completion_name: ">"},
                        block=block_timeout,
                        count=10,
                    )

                    if not messages:
                        continue

                    for stream_name, entries in messages:
                        for entry_id, data in entries:
                            if not self._running:
                                return
                            await self._process_entry(redis, entry_id, data)

            except asyncio.CancelledError:
                logger.info("Consumer cancelled")
                return
            except Exception as e:
                retry_count += 1
                logger.error(
                    f"Consumer error (retry {retry_count}/{max_retries}): {e}",
                    exc_info=True,
                )
                if self._running and retry_count < max_retries:
                    await asyncio.sleep(retry_delay)
                else:
                    logger.error("Max retries reached, stopping consumer")
                    return

    async def _process_entry(
        self, redis: Any, entry_id: str, data: dict[str, Any]
    ) -> None:
        """Process a single stream entry and acknowledge it on success.

        Args:
            redis: Redis client connection
            entry_id: The stream entry ID
            data: The entry data dict
        """
        try:
            # Handle the message
            message_data = data.get("data")
            if message_data:
                await self._handle_message(
                    message_data.encode()
                    if isinstance(message_data, str)
                    else message_data
                )

            # Acknowledge the message after successful processing
            await redis.xack(
                config.stream_completion_name,
                config.stream_consumer_group,
                entry_id,
            )
        except Exception as e:
            logger.error(
                f"Error processing completion message {entry_id}: {e}",
                exc_info=True,
            )
            # Message remains in pending state and will be claimed by
            # XAUTOCLAIM after min_idle_time expires

    async def _handle_message(self, body: bytes) -> None:
        """Handle a completion message using our own Prisma client."""
        try:
            data = orjson.loads(body)
            message = OperationCompleteMessage(**data)
        except Exception as e:
            logger.error(f"Failed to parse completion message: {e}")
            return

        logger.info(
            f"[COMPLETION] Received completion for operation {message.operation_id} "
            f"(task_id={message.task_id}, success={message.success})"
        )

        # Find task in registry
        task = await stream_registry.find_task_by_operation_id(message.operation_id)
        if task is None:
            task = await stream_registry.get_task(message.task_id)

        if task is None:
            logger.warning(
                f"[COMPLETION] Task not found for operation {message.operation_id} "
                f"(task_id={message.task_id})"
            )
            return

        logger.info(
            f"[COMPLETION] Found task: task_id={task.task_id}, "
            f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
        )

        # Guard against empty task fields
        if not task.task_id or not task.session_id or not task.tool_call_id:
            logger.error(
                f"[COMPLETION] Task has empty critical fields! "
                f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
                f"tool_call_id={task.tool_call_id!r}"
            )
            return

        if message.success:
            await self._handle_success(task, message)
        else:
            await self._handle_failure(task, message)

    async def _handle_success(
        self,
        task: stream_registry.ActiveTask,
        message: OperationCompleteMessage,
    ) -> None:
        """Handle successful operation completion."""
        prisma = await self._ensure_prisma()
        await process_operation_success(task, message.result, prisma)

    async def _handle_failure(
        self,
        task: stream_registry.ActiveTask,
        message: OperationCompleteMessage,
    ) -> None:
        """Handle failed operation completion."""
        prisma = await self._ensure_prisma()
        await process_operation_failure(task, message.error, prisma)


# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None


async def start_completion_consumer() -> None:
    """Start the global completion consumer."""
    global _consumer
    if _consumer is None:
        _consumer = ChatCompletionConsumer()
    await _consumer.start()


async def stop_completion_consumer() -> None:
    """Stop the global completion consumer."""
    global _consumer
    if _consumer:
        await _consumer.stop()
        _consumer = None


async def publish_operation_complete(
    operation_id: str,
    task_id: str,
    success: bool,
    result: dict | str | None = None,
    error: str | None = None,
) -> None:
    """Publish an operation completion message to Redis Streams.

    Args:
        operation_id: The operation ID that completed.
        task_id: The task ID associated with the operation.
        success: Whether the operation succeeded.
        result: The result data (for success).
        error: The error message (for failure).
    """
    message = OperationCompleteMessage(
        operation_id=operation_id,
        task_id=task_id,
        success=success,
        result=result,
        error=error,
    )

    redis = await get_redis_async()
    await redis.xadd(
        config.stream_completion_name,
        {"data": message.model_dump_json()},
        maxlen=config.stream_max_length,
    )
    logger.info(f"Published completion for operation {operation_id}")
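The consumer-group mechanics this file relies on (XGROUP CREATE with mkstream, XREADGROUP on `>`, XACK, and explicit XAUTOCLAIM recovery) can be exercised in isolation. A minimal sketch against a local Redis, assuming redis-py 5.x with asyncio support; the stream and group names here are placeholders, not the ChatConfig values:

```python
import asyncio

import redis.asyncio as aioredis
from redis.exceptions import ResponseError

STREAM, GROUP = "demo:completions", "demo_consumers"  # placeholder names


async def main() -> None:
    r = aioredis.Redis(decode_responses=True)
    try:
        # id="0" delivers entries that existed before the group was created
        await r.xgroup_create(STREAM, GROUP, id="0", mkstream=True)
    except ResponseError as e:
        if "BUSYGROUP" not in str(e):  # an existing group is fine
            raise

    # Producer side: what publish_operation_complete does, minus the model
    await r.xadd(STREAM, {"data": '{"operation_id": "op-1", "success": true}'})

    # Consumer side: ">" asks for entries never delivered to this group
    msgs = await r.xreadgroup(GROUP, "c1", {STREAM: ">"}, count=10, block=1000)
    for _stream, entries in msgs:
        for entry_id, fields in entries:
            print("got", entry_id, fields)
            await r.xack(STREAM, GROUP, entry_id)  # ack removes it from the PEL

    # Entries read by a consumer that died before XACK sit in the pending
    # entries list (PEL) until someone claims them explicitly:
    res = await r.xautoclaim(
        STREAM, GROUP, "c2", min_idle_time=60_000, start_id="0-0", count=10
    )
    # Redis 7 replies (next_start_id, claimed_entries, deleted_ids); 6.2 omits the last
    print("claimed:", res[1])
    await r.aclose()  # redis-py >= 5; use close() on 4.x


asyncio.run(main())
```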
@@ -0,0 +1,344 @@
"""Shared completion handling for operation success and failure.

This module provides common logic for handling operation completion from both:
- The Redis Streams consumer (completion_consumer.py)
- The HTTP webhook endpoint (routes.py)
"""

import logging
from typing import Any

import orjson
from prisma import Prisma

from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamToolOutputAvailable
from .tools.models import ErrorResponse

logger = logging.getLogger(__name__)

# Tools that produce agent_json that needs to be saved to library
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}

# Keys that should be stripped from agent_json when returning in error responses
SENSITIVE_KEYS = frozenset(
    {
        "api_key",
        "apikey",
        "api_secret",
        "password",
        "secret",
        "credentials",
        "credential",
        "token",
        "access_token",
        "refresh_token",
        "private_key",
        "privatekey",
        "auth",
        "authorization",
    }
)


def _sanitize_agent_json(obj: Any) -> Any:
    """Recursively sanitize agent_json by removing sensitive keys.

    Args:
        obj: The object to sanitize (dict, list, or primitive)

    Returns:
        Sanitized copy with sensitive keys removed/redacted
    """
    if isinstance(obj, dict):
        return {
            k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
            for k, v in obj.items()
        }
    elif isinstance(obj, list):
        return [_sanitize_agent_json(item) for item in obj]
    else:
        return obj


class ToolMessageUpdateError(Exception):
    """Raised when updating a tool message in the database fails."""

    pass


async def _update_tool_message(
    session_id: str,
    tool_call_id: str,
    content: str,
    prisma_client: Prisma | None,
) -> None:
    """Update tool message in database.

    Args:
        session_id: The session ID
        tool_call_id: The tool call ID to update
        content: The new content for the message
        prisma_client: Optional Prisma client. If None, uses chat_service.

    Raises:
        ToolMessageUpdateError: If the database update fails. The caller should
            handle this to avoid marking the task as completed with inconsistent state.
    """
    try:
        if prisma_client:
            # Use provided Prisma client (for consumer with its own connection)
            updated_count = await prisma_client.chatmessage.update_many(
                where={
                    "sessionId": session_id,
                    "toolCallId": tool_call_id,
                },
                data={"content": content},
            )
            # Check if any rows were updated - 0 means message not found
            if updated_count == 0:
                raise ToolMessageUpdateError(
                    f"No message found with tool_call_id={tool_call_id} in session {session_id}"
                )
        else:
            # Use service function (for webhook endpoint)
            await chat_service._update_pending_operation(
                session_id=session_id,
                tool_call_id=tool_call_id,
                result=content,
            )
    except ToolMessageUpdateError:
        raise
    except Exception as e:
        logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
        raise ToolMessageUpdateError(
            f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
        ) from e


def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
    """Serialize result to JSON string with sensible defaults.

    Args:
        result: The result to serialize. Can be a dict, list, string,
            number, boolean, or None.

    Returns:
        JSON string representation of the result. Returns '{"status": "completed"}'
        only when result is explicitly None.
    """
    if isinstance(result, str):
        return result
    if result is None:
        return '{"status": "completed"}'
    return orjson.dumps(result).decode("utf-8")


async def _save_agent_from_result(
    result: dict[str, Any],
    user_id: str | None,
    tool_name: str,
) -> dict[str, Any]:
    """Save agent to library if result contains agent_json.

    Args:
        result: The result dict that may contain agent_json
        user_id: The user ID to save the agent for
        tool_name: The tool name (create_agent or edit_agent)

    Returns:
        Updated result dict with saved agent details, or original result if no agent_json
    """
    if not user_id:
        logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
        return result

    agent_json = result.get("agent_json")
    if not agent_json:
        logger.warning(
            f"[COMPLETION] {tool_name} completed but no agent_json in result"
        )
        return result

    try:
        from .tools.agent_generator import save_agent_to_library

        is_update = tool_name == "edit_agent"
        created_graph, library_agent = await save_agent_to_library(
            agent_json, user_id, is_update=is_update
        )

        logger.info(
            f"[COMPLETION] Saved agent '{created_graph.name}' to library "
            f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
        )

        # Return a response similar to AgentSavedResponse
        return {
            "type": "agent_saved",
            "message": f"Agent '{created_graph.name}' has been saved to your library!",
            "agent_id": created_graph.id,
            "agent_name": created_graph.name,
            "library_agent_id": library_agent.id,
            "library_agent_link": f"/library/agents/{library_agent.id}",
            "agent_page_link": f"/build?flowID={created_graph.id}",
        }
    except Exception as e:
        logger.error(
            f"[COMPLETION] Failed to save agent to library: {e}",
            exc_info=True,
        )
        # Return error but don't fail the whole operation.
        # Sanitize agent_json to remove sensitive keys before returning
        return {
            "type": "error",
            "message": f"Agent was generated but failed to save: {str(e)}",
            "error": str(e),
            "agent_json": _sanitize_agent_json(agent_json),
        }


async def process_operation_success(
    task: stream_registry.ActiveTask,
    result: dict | str | None,
    prisma_client: Prisma | None = None,
) -> None:
    """Handle successful operation completion.

    Publishes the result to the stream registry, updates the database,
    generates LLM continuation, and marks the task as completed.

    Args:
        task: The active task that completed
        result: The result data from the operation
        prisma_client: Optional Prisma client for database operations.
            If None, uses chat_service._update_pending_operation instead.

    Raises:
        ToolMessageUpdateError: If the database update fails. The task will be
            marked as failed instead of completed to avoid inconsistent state.
    """
    # For agent generation tools, save the agent to library
    if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
        result = await _save_agent_from_result(result, task.user_id, task.tool_name)

    # Serialize result for output (only substitute default when result is exactly None)
    result_output = result if result is not None else {"status": "completed"}
    output_str = (
        result_output
        if isinstance(result_output, str)
        else orjson.dumps(result_output).decode("utf-8")
    )

    # Publish result to stream registry
    await stream_registry.publish_chunk(
        task.task_id,
        StreamToolOutputAvailable(
            toolCallId=task.tool_call_id,
            toolName=task.tool_name,
            output=output_str,
            success=True,
        ),
    )

    # Update pending operation in database.
    # If this fails, we must not continue to mark the task as completed
    result_str = serialize_result(result)
    try:
        await _update_tool_message(
            session_id=task.session_id,
            tool_call_id=task.tool_call_id,
            content=result_str,
            prisma_client=prisma_client,
        )
    except ToolMessageUpdateError:
        # DB update failed - mark task as failed to avoid inconsistent state
        logger.error(
            f"[COMPLETION] DB update failed for task {task.task_id}, "
            "marking as failed instead of completed"
        )
        await stream_registry.publish_chunk(
            task.task_id,
            StreamError(errorText="Failed to save operation result to database"),
        )
        await stream_registry.mark_task_completed(task.task_id, status="failed")
        raise

    # Generate LLM continuation with streaming
    try:
        await chat_service._generate_llm_continuation_with_streaming(
            session_id=task.session_id,
            user_id=task.user_id,
            task_id=task.task_id,
        )
    except Exception as e:
        logger.error(
            f"[COMPLETION] Failed to generate LLM continuation: {e}",
            exc_info=True,
        )

    # Mark task as completed and release Redis lock
    await stream_registry.mark_task_completed(task.task_id, status="completed")
    try:
        await chat_service._mark_operation_completed(task.tool_call_id)
    except Exception as e:
        logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")

    logger.info(
        f"[COMPLETION] Successfully processed completion for task {task.task_id}"
    )


async def process_operation_failure(
    task: stream_registry.ActiveTask,
    error: str | None,
    prisma_client: Prisma | None = None,
) -> None:
    """Handle failed operation completion.

    Publishes the error to the stream registry, updates the database with
    the error response, and marks the task as failed.

    Args:
        task: The active task that failed
        error: The error message from the operation
        prisma_client: Optional Prisma client for database operations.
            If None, uses chat_service._update_pending_operation instead.
    """
    error_msg = error or "Operation failed"

    # Publish error to stream registry
    await stream_registry.publish_chunk(
        task.task_id,
        StreamError(errorText=error_msg),
    )

    # Update pending operation with error.
    # If this fails, we still continue to mark the task as failed
    error_response = ErrorResponse(
        message=error_msg,
        error=error,
    )
    try:
        await _update_tool_message(
            session_id=task.session_id,
            tool_call_id=task.tool_call_id,
            content=error_response.model_dump_json(),
            prisma_client=prisma_client,
        )
    except ToolMessageUpdateError:
        # DB update failed - log but continue with cleanup
        logger.error(
            f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
            "continuing with cleanup"
        )

    # Mark task as failed and release Redis lock
    await stream_registry.mark_task_completed(task.task_id, status="failed")
    try:
        await chat_service._mark_operation_completed(task.tool_call_id)
    except Exception as e:
        logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")

    logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")
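A quick illustration of what _sanitize_agent_json does to a nested payload; the import path and the payload itself are made up for the example:

```python
from chat.completion_handler import _sanitize_agent_json  # module path assumed

agent_json = {
    "name": "news-digest",
    "nodes": [{"block": "http_request", "credentials": {"api_key": "sk-live-123"}}],
    "Authorization": "Bearer abc",  # matching is case-insensitive via k.lower()
}

print(_sanitize_agent_json(agent_json))
# {'name': 'news-digest',
#  'nodes': [{'block': 'http_request', 'credentials': '[REDACTED]'}],
#  'Authorization': '[REDACTED]'}
```

Note that a sensitive key's entire value is replaced, so the nested "credentials" dict is dropped wholesale rather than recursed into.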
@@ -11,7 +11,7 @@ class ChatConfig(BaseSettings):

     # OpenAI API Configuration
     model: str = Field(
-        default="anthropic/claude-opus-4.5", description="Default model to use"
+        default="anthropic/claude-opus-4.6", description="Default model to use"
     )
     title_model: str = Field(
         default="openai/gpt-4o-mini",
@@ -44,6 +44,48 @@ class ChatConfig(BaseSettings):
         description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
     )

+    # Stream registry configuration for SSE reconnection
+    stream_ttl: int = Field(
+        default=3600,
+        description="TTL in seconds for stream data in Redis (1 hour)",
+    )
+    stream_max_length: int = Field(
+        default=10000,
+        description="Maximum number of messages to store per stream",
+    )
+
+    # Redis Streams configuration for completion consumer
+    stream_completion_name: str = Field(
+        default="chat:completions",
+        description="Redis Stream name for operation completions",
+    )
+    stream_consumer_group: str = Field(
+        default="chat_consumers",
+        description="Consumer group name for completion stream",
+    )
+    stream_claim_min_idle_ms: int = Field(
+        default=60000,
+        description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
+    )
+
+    # Redis key prefixes for stream registry
+    task_meta_prefix: str = Field(
+        default="chat:task:meta:",
+        description="Prefix for task metadata hash keys",
+    )
+    task_stream_prefix: str = Field(
+        default="chat:stream:",
+        description="Prefix for task message stream keys",
+    )
+    task_op_prefix: str = Field(
+        default="chat:task:op:",
+        description="Prefix for operation ID to task ID mapping keys",
+    )
+    internal_api_key: str | None = Field(
+        default=None,
+        description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
+    )
+
     # Langfuse Prompt Management Configuration
     # Note: Langfuse credentials are in Settings().secrets (settings.py)
     langfuse_prompt_name: str = Field(
@@ -82,6 +124,14 @@ class ChatConfig(BaseSettings):
             v = "https://openrouter.ai/api/v1"
         return v

+    @field_validator("internal_api_key", mode="before")
+    @classmethod
+    def get_internal_api_key(cls, v):
+        """Get internal API key from environment if not provided."""
+        if v is None:
+            v = os.getenv("CHAT_INTERNAL_API_KEY")
+        return v
+
     # Prompt paths for different contexts
     PROMPT_PATHS: dict[str, str] = {
         "default": "prompts/chat_system.md",
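Since ChatConfig is a pydantic BaseSettings class, the new fields keep their defaults unless overridden, and the field_validator above additionally falls back to the CHAT_INTERNAL_API_KEY environment variable. A small sketch, with the import path assumed:

```python
import os

# The validator fills internal_api_key from this env var when the field is unset
os.environ["CHAT_INTERNAL_API_KEY"] = "s3cr3t"

from chat.config import ChatConfig  # import path assumed

config = ChatConfig()
assert config.internal_api_key == "s3cr3t"
assert config.stream_completion_name == "chat:completions"
assert config.stream_claim_min_idle_ms == 60_000  # one minute idle before XAUTOCLAIM
```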
@@ -52,6 +52,10 @@ class StreamStart(StreamBaseResponse):

     type: ResponseType = ResponseType.START
     messageId: str = Field(..., description="Unique message ID")
+    taskId: str | None = Field(
+        default=None,
+        description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
+    )


 class StreamFinish(StreamBaseResponse):
@@ -1,19 +1,23 @@
 """Chat API routes for chat session management and streaming via SSE."""

 import logging
+import uuid as uuid_module
 from collections.abc import AsyncGenerator
 from typing import Annotated

 from autogpt_libs import auth
-from fastapi import APIRouter, Depends, Query, Security
+from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel

 from backend.util.exceptions import NotFoundError

 from . import service as chat_service
+from . import stream_registry
+from .completion_handler import process_operation_failure, process_operation_success
 from .config import ChatConfig
 from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
+from .response_model import StreamFinish, StreamHeartbeat, StreamStart

 config = ChatConfig()
@@ -55,6 +59,15 @@ class CreateSessionResponse(BaseModel):
     user_id: str | None


+class ActiveStreamInfo(BaseModel):
+    """Information about an active stream for reconnection."""
+
+    task_id: str
+    last_message_id: str  # Redis Stream message ID for resumption
+    operation_id: str  # Operation ID for completion tracking
+    tool_name: str  # Name of the tool being executed
+
+
 class SessionDetailResponse(BaseModel):
     """Response model providing complete details for a chat session, including messages."""

@@ -63,6 +76,7 @@ class SessionDetailResponse(BaseModel):
     updated_at: str
     user_id: str | None
     messages: list[dict]
+    active_stream: ActiveStreamInfo | None = None  # Present if stream is still active


 class SessionSummaryResponse(BaseModel):
@@ -81,6 +95,14 @@ class ListSessionsResponse(BaseModel):
     total: int


+class OperationCompleteRequest(BaseModel):
+    """Request model for external completion webhook."""
+
+    success: bool
+    result: dict | str | None = None
+    error: str | None = None
+
+
 # ========== Routes ==========


@@ -166,13 +188,14 @@ async def get_session(
     Retrieve the details of a specific chat session.

     Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
+    If there's an active stream for this session, returns the task_id for reconnection.

     Args:
         session_id: The unique identifier for the desired chat session.
         user_id: The optional authenticated user ID, or None for anonymous access.

     Returns:
-        SessionDetailResponse: Details for the requested session, or None if not found.
+        SessionDetailResponse: Details for the requested session, including active_stream info if applicable.

     """
     session = await get_chat_session(session_id, user_id)
@@ -180,11 +203,28 @@ async def get_session(
         raise NotFoundError(f"Session {session_id} not found.")

     messages = [message.model_dump() for message in session.messages]
-    logger.info(
-        f"Returning session {session_id}: "
-        f"message_count={len(messages)}, "
-        f"roles={[m.get('role') for m in messages]}"
-    )
+    # Check if there's an active stream for this session
+    active_stream_info = None
+    active_task, last_message_id = await stream_registry.get_active_task_for_session(
+        session_id, user_id
+    )
+    if active_task:
+        # Filter out the in-progress assistant message from the session response.
+        # The client will receive the complete assistant response through the SSE
+        # stream replay instead, preventing duplicate content.
+        if messages and messages[-1].get("role") == "assistant":
+            messages = messages[:-1]
+
+        # Use "0-0" as last_message_id to replay the stream from the beginning.
+        # Since we filtered out the cached assistant message, the client needs
+        # the full stream to reconstruct the response.
+        active_stream_info = ActiveStreamInfo(
+            task_id=active_task.task_id,
+            last_message_id="0-0",
+            operation_id=active_task.operation_id,
+            tool_name=active_task.tool_name,
+        )
+
     return SessionDetailResponse(
         id=session.session_id,
@@ -192,6 +232,7 @@ async def get_session(
         updated_at=session.updated_at.isoformat(),
         user_id=session.user_id or None,
         messages=messages,
+        active_stream=active_stream_info,
     )


@@ -211,49 +252,112 @@ async def stream_chat_post(
     - Tool call UI elements (if invoked)
     - Tool execution results

+    The AI generation runs in a background task that continues even if the client disconnects.
+    All chunks are written to Redis for reconnection support. If the client disconnects,
+    they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
+
     Args:
         session_id: The chat session identifier to associate with the streamed messages.
         request: Request body containing message, is_user_message, and optional context.
         user_id: Optional authenticated user ID.
     Returns:
-        StreamingResponse: SSE-formatted response chunks.
+        StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
+        containing the task_id for reconnection.

     """
+    import asyncio
+
     session = await _validate_and_get_session(session_id, user_id)

+    # Create a task in the stream registry for reconnection support
+    task_id = str(uuid_module.uuid4())
+    operation_id = str(uuid_module.uuid4())
+    await stream_registry.create_task(
+        task_id=task_id,
+        session_id=session_id,
+        user_id=user_id,
+        tool_call_id="chat_stream",  # Not a tool call, but needed for the model
+        tool_name="chat",
+        operation_id=operation_id,
+    )
+
+    # Background task that runs the AI generation independently of SSE connection
+    async def run_ai_generation():
+        try:
+            # Emit a start event with task_id for reconnection
+            start_chunk = StreamStart(messageId=task_id, taskId=task_id)
+            await stream_registry.publish_chunk(task_id, start_chunk)
+
+            async for chunk in chat_service.stream_chat_completion(
+                session_id,
+                request.message,
+                is_user_message=request.is_user_message,
+                user_id=user_id,
+                session=session,  # Pass pre-fetched session to avoid double-fetch
+                context=request.context,
+            ):
+                # Write to Redis (subscribers will receive via XREAD)
+                await stream_registry.publish_chunk(task_id, chunk)
+
+            # Mark task as completed
+            await stream_registry.mark_task_completed(task_id, "completed")
+        except Exception as e:
+            logger.error(
+                f"Error in background AI generation for session {session_id}: {e}"
+            )
+            await stream_registry.mark_task_completed(task_id, "failed")
+
+    # Start the AI generation in a background task
+    bg_task = asyncio.create_task(run_ai_generation())
+    await stream_registry.set_task_asyncio_task(task_id, bg_task)
+
+    # SSE endpoint that subscribes to the task's stream
     async def event_generator() -> AsyncGenerator[str, None]:
-        chunk_count = 0
-        first_chunk_type: str | None = None
-        async for chunk in chat_service.stream_chat_completion(
-            session_id,
-            request.message,
-            is_user_message=request.is_user_message,
-            user_id=user_id,
-            session=session,  # Pass pre-fetched session to avoid double-fetch
-            context=request.context,
-        ):
-            if chunk_count < 3:
-                logger.info(
-                    "Chat stream chunk",
-                    extra={
-                        "session_id": session_id,
-                        "chunk_type": str(chunk.type),
-                    },
-                )
-                if not first_chunk_type:
-                    first_chunk_type = str(chunk.type)
-            chunk_count += 1
-            yield chunk.to_sse()
-        logger.info(
-            "Chat stream completed",
-            extra={
-                "session_id": session_id,
-                "chunk_count": chunk_count,
-                "first_chunk_type": first_chunk_type,
-            },
-        )
-        # AI SDK protocol termination
-        yield "data: [DONE]\n\n"
+        subscriber_queue = None
+        try:
+            # Subscribe to the task stream (this replays existing messages + live updates)
+            subscriber_queue = await stream_registry.subscribe_to_task(
+                task_id=task_id,
+                user_id=user_id,
+                last_message_id="0-0",  # Get all messages from the beginning
+            )
+
+            if subscriber_queue is None:
+                yield StreamFinish().to_sse()
+                yield "data: [DONE]\n\n"
+                return
+
+            # Read from the subscriber queue and yield to SSE
+            while True:
+                try:
+                    chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
+                    yield chunk.to_sse()
+
+                    # Check for finish signal
+                    if isinstance(chunk, StreamFinish):
+                        break
+                except asyncio.TimeoutError:
+                    # Send heartbeat to keep connection alive
+                    yield StreamHeartbeat().to_sse()
+
+        except GeneratorExit:
+            pass  # Client disconnected - background task continues
+        except Exception as e:
+            logger.error(f"Error in SSE stream for task {task_id}: {e}")
+        finally:
+            # Unsubscribe when client disconnects or stream ends to prevent resource leak
+            if subscriber_queue is not None:
+                try:
+                    await stream_registry.unsubscribe_from_task(
+                        task_id, subscriber_queue
+                    )
+                except Exception as unsub_err:
+                    logger.error(
+                        f"Error unsubscribing from task {task_id}: {unsub_err}",
+                        exc_info=True,
+                    )
+            # AI SDK protocol termination - always yield even if unsubscribe fails
+            yield "data: [DONE]\n\n"

     return StreamingResponse(
         event_generator(),
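The rewritten endpoint separates generation from delivery: run_ai_generation publishes every chunk to the stream registry whether or not a client is attached, and event_generator is just one subscriber. A toy in-memory sketch of that fan-out-with-replay shape (the class and method names here are illustrative, not the real stream_registry API):

```python
import asyncio


class TaskStream:
    """Stand-in for the registry's per-task stream: publishes persist to a
    backlog so late or reconnecting subscribers can replay from the start."""

    def __init__(self) -> None:
        self.backlog: list[str] = []
        self.subscribers: list[asyncio.Queue] = []

    def publish(self, chunk: str) -> None:
        self.backlog.append(chunk)  # persisted even with zero subscribers
        for q in self.subscribers:
            q.put_nowait(chunk)  # live delivery

    def subscribe(self) -> asyncio.Queue:
        q: asyncio.Queue = asyncio.Queue()
        for chunk in self.backlog:  # replay: the in-memory analogue of "0-0"
            q.put_nowait(chunk)
        self.subscribers.append(q)
        return q


async def main() -> None:
    stream = TaskStream()

    async def producer() -> None:  # plays the role of run_ai_generation
        for i in range(5):
            stream.publish(f"chunk-{i}")
            await asyncio.sleep(0.01)
        stream.publish("[FINISH]")

    task = asyncio.create_task(producer())
    await asyncio.sleep(0.03)  # the "client" connects late
    q = stream.subscribe()  # still sees chunk-0..: backlog is replayed first
    while (chunk := await q.get()) != "[FINISH]":
        print(chunk)
    await task


asyncio.run(main())
```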
@@ -366,6 +470,251 @@ async def session_assign_user(
     return {"status": "ok"}


+# ========== Task Streaming (SSE Reconnection) ==========
+
+
+@router.get(
+    "/tasks/{task_id}/stream",
+)
+async def stream_task(
+    task_id: str,
+    user_id: str | None = Depends(auth.get_user_id),
+    last_message_id: str = Query(
+        default="0-0",
+        description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
+    ),
+):
+    """
+    Reconnect to a long-running task's SSE stream.
+
+    When a long-running operation (like agent generation) starts, the client
+    receives a task_id. If the connection drops, the client can reconnect
+    using this endpoint to resume receiving updates.
+
+    Args:
+        task_id: The task ID from the operation_started response.
+        user_id: Authenticated user ID for ownership validation.
+        last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
+
+    Returns:
+        StreamingResponse: SSE-formatted response chunks starting after last_message_id.
+
+    Raises:
+        HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
+    """
+    # Check task existence and expiry before subscribing
+    task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
+
+    if error_code == "TASK_EXPIRED":
+        raise HTTPException(
+            status_code=410,
+            detail={
+                "code": "TASK_EXPIRED",
+                "message": "This operation has expired. Please try again.",
+            },
+        )
+
+    if error_code == "TASK_NOT_FOUND":
+        raise HTTPException(
+            status_code=404,
+            detail={
+                "code": "TASK_NOT_FOUND",
+                "message": f"Task {task_id} not found.",
+            },
+        )
+
+    # Validate ownership if task has an owner
+    if task and task.user_id and user_id != task.user_id:
+        raise HTTPException(
+            status_code=403,
+            detail={
+                "code": "ACCESS_DENIED",
+                "message": "You do not have access to this task.",
+            },
+        )
+
+    # Get subscriber queue from stream registry
+    subscriber_queue = await stream_registry.subscribe_to_task(
+        task_id=task_id,
+        user_id=user_id,
+        last_message_id=last_message_id,
+    )
+
+    if subscriber_queue is None:
+        raise HTTPException(
+            status_code=404,
+            detail={
+                "code": "TASK_NOT_FOUND",
+                "message": f"Task {task_id} not found or access denied.",
+            },
+        )
+
+    async def event_generator() -> AsyncGenerator[str, None]:
+        import asyncio
+
+        heartbeat_interval = 15.0  # Send heartbeat every 15 seconds
+        try:
+            while True:
+                try:
+                    # Wait for next chunk with timeout for heartbeats
+                    chunk = await asyncio.wait_for(
+                        subscriber_queue.get(), timeout=heartbeat_interval
+                    )
+                    yield chunk.to_sse()
+
+                    # Check for finish signal
+                    if isinstance(chunk, StreamFinish):
+                        break
+                except asyncio.TimeoutError:
+                    # Send heartbeat to keep connection alive
+                    yield StreamHeartbeat().to_sse()
+        except Exception as e:
+            logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
+        finally:
+            # Unsubscribe when client disconnects or stream ends
+            try:
+                await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
+            except Exception as unsub_err:
+                logger.error(
+                    f"Error unsubscribing from task {task_id}: {unsub_err}",
+                    exc_info=True,
+                )
+            # AI SDK protocol termination - always yield even if unsubscribe fails
+            yield "data: [DONE]\n\n"
+
+    return StreamingResponse(
+        event_generator(),
+        media_type="text/event-stream",
+        headers={
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+            "X-Accel-Buffering": "no",
+            "x-vercel-ai-ui-message-stream": "v1",
+        },
+    )
+
+
+@router.get(
+    "/tasks/{task_id}",
+)
+async def get_task_status(
+    task_id: str,
+    user_id: str | None = Depends(auth.get_user_id),
+) -> dict:
+    """
+    Get the status of a long-running task.
+
+    Args:
+        task_id: The task ID to check.
+        user_id: Authenticated user ID for ownership validation.
+
+    Returns:
+        dict: Task status including task_id, status, tool_name, and operation_id.
+
+    Raises:
+        NotFoundError: If task_id is not found or user doesn't have access.
+    """
+    task = await stream_registry.get_task(task_id)
+
+    if task is None:
+        raise NotFoundError(f"Task {task_id} not found.")
+
+    # Validate ownership - if task has an owner, requester must match
+    if task.user_id and user_id != task.user_id:
+        raise NotFoundError(f"Task {task_id} not found.")
+
+    return {
+        "task_id": task.task_id,
+        "session_id": task.session_id,
+        "status": task.status,
+        "tool_name": task.tool_name,
+        "operation_id": task.operation_id,
+        "created_at": task.created_at.isoformat(),
+    }
+
+
+# ========== External Completion Webhook ==========
+
+
+@router.post(
+    "/operations/{operation_id}/complete",
+    status_code=200,
+)
+async def complete_operation(
+    operation_id: str,
+    request: OperationCompleteRequest,
+    x_api_key: str | None = Header(default=None),
+) -> dict:
+    """
+    External completion webhook for long-running operations.
+
+    Called by Agent Generator (or other services) when an operation completes.
+    This triggers the stream registry to publish completion and continue LLM generation.
+
+    Args:
+        operation_id: The operation ID to complete.
+        request: Completion payload with success status and result/error.
+        x_api_key: Internal API key for authentication.
+
+    Returns:
+        dict: Status of the completion.
+
+    Raises:
+        HTTPException: If API key is invalid or operation not found.
+    """
+    # Validate internal API key - reject if not configured or invalid
+    if not config.internal_api_key:
+        logger.error(
+            "Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
+        )
+        raise HTTPException(
+            status_code=503,
+            detail="Webhook not available: internal API key not configured",
+        )
+    if x_api_key != config.internal_api_key:
+        raise HTTPException(status_code=401, detail="Invalid API key")
+
+    # Find task by operation_id
+    task = await stream_registry.find_task_by_operation_id(operation_id)
+    if task is None:
+        raise HTTPException(
+            status_code=404,
+            detail=f"Operation {operation_id} not found",
+        )
+
+    logger.info(
+        f"Received completion webhook for operation {operation_id} "
+        f"(task_id={task.task_id}, success={request.success})"
+    )
+
+    if request.success:
+        await process_operation_success(task, request.result)
+    else:
+        await process_operation_failure(task, request.error)
+
+    return {"status": "ok", "task_id": task.task_id}
+
+
+# ========== Configuration ==========
+
+
+@router.get("/config/ttl", status_code=200)
+async def get_ttl_config() -> dict:
+    """
+    Get the stream TTL configuration.
+
+    Returns the Time-To-Live settings for chat streams, which determines
+    how long clients can reconnect to an active stream.
+
+    Returns:
+        dict: TTL configuration with seconds and milliseconds values.
+    """
+    return {
+        "stream_ttl_seconds": config.stream_ttl,
+        "stream_ttl_ms": config.stream_ttl * 1000,
+    }
+
+
 # ========== Health Check ==========

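From the client's perspective, the contract added in these routes is: read taskId from the start event (or from active_stream on the session detail response), and after a dropped connection resume via GET /tasks/{task_id}/stream. A rough httpx sketch; the base URL, the POST route path, and the assumption that each SSE data: line carries JSON are unverified guesses:

```python
import json

import httpx

BASE = "http://localhost:8000/api/chat"  # base URL and prefix assumed


async def stream_with_reconnect(session_id: str, message: str) -> None:
    task_id = None
    async with httpx.AsyncClient(timeout=None) as client:
        try:
            async with client.stream(
                "POST",
                f"{BASE}/sessions/{session_id}/stream",  # POST route path assumed
                json={"message": message, "is_user_message": True},
            ) as resp:
                async for line in resp.aiter_lines():
                    if not line.startswith("data: ") or line == "data: [DONE]":
                        continue
                    event = json.loads(line[len("data: "):])  # payload format assumed
                    task_id = event.get("taskId") or task_id  # from the start event
                    print(event)
        except httpx.TransportError:
            if task_id is None:
                raise
            # Resume: "0-0" replays the task stream from the beginning
            async with client.stream(
                "GET",
                f"{BASE}/tasks/{task_id}/stream",
                params={"last_message_id": "0-0"},
            ) as resp:
                async for line in resp.aiter_lines():
                    print(line)
```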
@@ -33,9 +33,10 @@ from backend.data.understanding import (
|
|||||||
get_business_understanding,
|
get_business_understanding,
|
||||||
)
|
)
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import AppEnvironment, Settings
|
||||||
|
|
||||||
from . import db as chat_db
|
from . import db as chat_db
|
||||||
|
from . import stream_registry
|
||||||
from .config import ChatConfig
|
from .config import ChatConfig
|
||||||
from .model import (
|
from .model import (
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
@@ -221,8 +222,18 @@ async def _get_system_prompt_template(context: str) -> str:
|
|||||||
try:
|
try:
|
||||||
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
||||||
# Use asyncio.to_thread to avoid blocking the event loop
|
# Use asyncio.to_thread to avoid blocking the event loop
|
||||||
|
# In non-production environments, fetch the latest prompt version
|
||||||
|
# instead of the production-labeled version for easier testing
|
||||||
|
label = (
|
||||||
|
None
|
||||||
|
if settings.config.app_env == AppEnvironment.PRODUCTION
|
||||||
|
else "latest"
|
||||||
|
)
|
||||||
prompt = await asyncio.to_thread(
|
prompt = await asyncio.to_thread(
|
||||||
langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0
|
langfuse.get_prompt,
|
||||||
|
config.langfuse_prompt_name,
|
||||||
|
label=label,
|
||||||
|
cache_ttl_seconds=0,
|
||||||
)
|
)
|
||||||
return prompt.compile(users_information=context)
|
return prompt.compile(users_information=context)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -617,6 +628,9 @@ async def stream_chat_completion(
                         total_tokens=chunk.totalTokens,
                     )
                 )
+            elif isinstance(chunk, StreamHeartbeat):
+                # Pass through heartbeat to keep SSE connection alive
+                yield chunk
             else:
                 logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
 
@@ -1184,8 +1198,9 @@ async def _yield_tool_call(
         )
         return
 
-    # Generate operation ID
+    # Generate operation ID and task ID
     operation_id = str(uuid_module.uuid4())
+    task_id = str(uuid_module.uuid4())
 
     # Build a user-friendly message based on tool and arguments
     if tool_name == "create_agent":
@@ -1228,6 +1243,16 @@ async def _yield_tool_call(
 
     # Wrap session save and task creation in try-except to release lock on failure
     try:
+        # Create task in stream registry for SSE reconnection support
+        await stream_registry.create_task(
+            task_id=task_id,
+            session_id=session.session_id,
+            user_id=session.user_id,
+            tool_call_id=tool_call_id,
+            tool_name=tool_name,
+            operation_id=operation_id,
+        )
+
         # Save assistant message with tool_call FIRST (required by LLM)
         assistant_message = ChatMessage(
             role="assistant",
@@ -1249,23 +1274,27 @@ async def _yield_tool_call(
         session.messages.append(pending_message)
         await upsert_chat_session(session)
         logger.info(
-            f"Saved pending operation {operation_id} for tool {tool_name} "
-            f"in session {session.session_id}"
+            f"Saved pending operation {operation_id} (task_id={task_id}) "
+            f"for tool {tool_name} in session {session.session_id}"
         )
 
         # Store task reference in module-level set to prevent GC before completion
-        task = asyncio.create_task(
-            _execute_long_running_tool(
+        bg_task = asyncio.create_task(
+            _execute_long_running_tool_with_streaming(
                 tool_name=tool_name,
                 parameters=arguments,
                 tool_call_id=tool_call_id,
                 operation_id=operation_id,
+                task_id=task_id,
                 session_id=session.session_id,
                 user_id=session.user_id,
             )
        )
-        _background_tasks.add(task)
-        task.add_done_callback(_background_tasks.discard)
+        _background_tasks.add(bg_task)
+        bg_task.add_done_callback(_background_tasks.discard)
+
+        # Associate the asyncio task with the stream registry task
+        await stream_registry.set_task_asyncio_task(task_id, bg_task)
     except Exception as e:
         # Roll back appended messages to prevent data corruption on subsequent saves
         if (
@@ -1283,6 +1312,11 @@ async def _yield_tool_call(
 
         # Release the Redis lock since the background task won't be spawned
         await _mark_operation_completed(tool_call_id)
+        # Mark stream registry task as failed if it was created
+        try:
+            await stream_registry.mark_task_completed(task_id, status="failed")
+        except Exception:
+            pass
         logger.error(
             f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
         )
@@ -1296,6 +1330,7 @@ async def _yield_tool_call(
             message=started_msg,
             operation_id=operation_id,
             tool_name=tool_name,
+            task_id=task_id,  # Include task_id for SSE reconnection
         ).model_dump_json(),
         success=True,
     )
@@ -1365,6 +1400,9 @@ async def _execute_long_running_tool(
 
     This function runs independently of the SSE connection, so the operation
     survives if the user closes their browser tab.
+
+    NOTE: This is the legacy function without stream registry support.
+    Use _execute_long_running_tool_with_streaming for new implementations.
     """
     try:
         # Load fresh session (not stale reference)
@@ -1417,6 +1455,133 @@ async def _execute_long_running_tool(
         await _mark_operation_completed(tool_call_id)
 
 
+async def _execute_long_running_tool_with_streaming(
+    tool_name: str,
+    parameters: dict[str, Any],
+    tool_call_id: str,
+    operation_id: str,
+    task_id: str,
+    session_id: str,
+    user_id: str | None,
+) -> None:
+    """Execute a long-running tool with stream registry support for SSE reconnection.
+
+    This function runs independently of the SSE connection, publishes progress
+    to the stream registry, and survives if the user closes their browser tab.
+    Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming.
+
+    If the external service returns a 202 Accepted (async), this function exits
+    early and lets the Redis Streams completion consumer handle the rest.
+    """
+    # Track whether we delegated to async processing - if so, the Redis Streams
+    # completion consumer (stream_registry / completion_consumer) will handle cleanup, not us
+    delegated_to_async = False
+
+    try:
+        # Load fresh session (not stale reference)
+        session = await get_chat_session(session_id, user_id)
+        if not session:
+            logger.error(f"Session {session_id} not found for background tool")
+            await stream_registry.mark_task_completed(task_id, status="failed")
+            return
+
+        # Pass operation_id and task_id to the tool for async processing
+        enriched_parameters = {
+            **parameters,
+            "_operation_id": operation_id,
+            "_task_id": task_id,
+        }
+
+        # Execute the actual tool
+        result = await execute_tool(
+            tool_name=tool_name,
+            parameters=enriched_parameters,
+            tool_call_id=tool_call_id,
+            user_id=user_id,
+            session=session,
+        )
+
+        # Check if the tool result indicates async processing
+        # (e.g., Agent Generator returned 202 Accepted)
+        try:
+            if isinstance(result.output, dict):
+                result_data = result.output
+            elif result.output:
+                result_data = orjson.loads(result.output)
+            else:
+                result_data = {}
+            if result_data.get("status") == "accepted":
+                logger.info(
+                    f"Tool {tool_name} delegated to async processing "
+                    f"(operation_id={operation_id}, task_id={task_id}). "
+                    f"Redis Streams completion consumer will handle the rest."
+                )
+                # Don't publish result, don't continue with LLM, and don't cleanup
+                # The Redis Streams consumer (completion_consumer) will handle
+                # everything when the external service completes via webhook
+                delegated_to_async = True
+                return
+        except (orjson.JSONDecodeError, TypeError):
+            pass  # Not JSON or not async - continue normally
+
+        # Publish tool result to stream registry
+        await stream_registry.publish_chunk(task_id, result)
+
+        # Update the pending message with result
+        result_str = (
+            result.output
+            if isinstance(result.output, str)
+            else orjson.dumps(result.output).decode("utf-8")
+        )
+        await _update_pending_operation(
+            session_id=session_id,
+            tool_call_id=tool_call_id,
+            result=result_str,
+        )
+
+        logger.info(
+            f"Background tool {tool_name} completed for session {session_id} "
+            f"(task_id={task_id})"
+        )
+
+        # Generate LLM continuation and stream chunks to registry
+        await _generate_llm_continuation_with_streaming(
+            session_id=session_id,
+            user_id=user_id,
+            task_id=task_id,
+        )
+
+        # Mark task as completed in stream registry
+        await stream_registry.mark_task_completed(task_id, status="completed")
+
+    except Exception as e:
+        logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True)
+        error_response = ErrorResponse(
+            message=f"Tool {tool_name} failed: {str(e)}",
+        )
+
+        # Publish error to stream registry followed by finish event
+        await stream_registry.publish_chunk(
+            task_id,
+            StreamError(errorText=str(e)),
+        )
+        await stream_registry.publish_chunk(task_id, StreamFinish())
+
+        await _update_pending_operation(
+            session_id=session_id,
+            tool_call_id=tool_call_id,
+            result=error_response.model_dump_json(),
+        )
+
+        # Mark task as failed in stream registry
+        await stream_registry.mark_task_completed(task_id, status="failed")
+    finally:
+        # Only cleanup if we didn't delegate to async processing
+        # For async path, the Redis Streams completion consumer handles cleanup
+        if not delegated_to_async:
+            await _mark_operation_completed(tool_call_id)
+
+
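The delegation check above keys off a `status` field in the tool's output; a minimal sketch of that contract, with a hypothetical payload:

```python
import orjson

# Hypothetical output from a tool whose backend answered 202 Accepted
raw_output = orjson.dumps({"status": "accepted", "operation_id": "op-123"})

result_data = orjson.loads(raw_output)
if result_data.get("status") == "accepted":
    # The background task exits here; the completion consumer finishes the stream
    print("delegated to async processing")
```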
 async def _update_pending_operation(
     session_id: str,
     tool_call_id: str,
@@ -1597,3 +1762,128 @@ async def _generate_llm_continuation(
 
     except Exception as e:
         logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
+
+
+async def _generate_llm_continuation_with_streaming(
+    session_id: str,
+    user_id: str | None,
+    task_id: str,
+) -> None:
+    """Generate an LLM response with streaming to the stream registry.
+
+    This is called by background tasks to continue the conversation
+    after a tool result is saved. Chunks are published to the stream registry
+    so reconnecting clients can receive them.
+    """
+    import uuid as uuid_module
+
+    try:
+        # Load fresh session from DB (bypass cache to get the updated tool result)
+        await invalidate_session_cache(session_id)
+        session = await get_chat_session(session_id, user_id)
+        if not session:
+            logger.error(f"Session {session_id} not found for LLM continuation")
+            return
+
+        # Build system prompt
+        system_prompt, _ = await _build_system_prompt(user_id)
+
+        # Build messages in OpenAI format
+        messages = session.to_openai_messages()
+        if system_prompt:
+            from openai.types.chat import ChatCompletionSystemMessageParam
+
+            system_message = ChatCompletionSystemMessageParam(
+                role="system",
+                content=system_prompt,
+            )
+            messages = [system_message] + messages
+
+        # Build extra_body for tracing
+        extra_body: dict[str, Any] = {
+            "posthogProperties": {
+                "environment": settings.config.app_env.value,
+            },
+        }
+        if user_id:
+            extra_body["user"] = user_id[:128]
+            extra_body["posthogDistinctId"] = user_id
+        if session_id:
+            extra_body["session_id"] = session_id[:128]
+
+        # Make streaming LLM call (no tools - just text response)
+        from typing import cast
+
+        from openai.types.chat import ChatCompletionMessageParam
+
+        # Generate unique IDs for AI SDK protocol
+        message_id = str(uuid_module.uuid4())
+        text_block_id = str(uuid_module.uuid4())
+
+        # Publish start event
+        await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
+        await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
+
+        # Stream the response
+        stream = await client.chat.completions.create(
+            model=config.model,
+            messages=cast(list[ChatCompletionMessageParam], messages),
+            extra_body=extra_body,
+            stream=True,
+        )
+
+        assistant_content = ""
+        async for chunk in stream:
+            if chunk.choices and chunk.choices[0].delta.content:
+                delta = chunk.choices[0].delta.content
+                assistant_content += delta
+                # Publish delta to stream registry
+                await stream_registry.publish_chunk(
+                    task_id,
+                    StreamTextDelta(id=text_block_id, delta=delta),
+                )
+
+        # Publish end events
+        await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
+
+        if assistant_content:
+            # Reload session from DB to avoid race condition with user messages
+            fresh_session = await get_chat_session(session_id, user_id)
+            if not fresh_session:
+                logger.error(
+                    f"Session {session_id} disappeared during LLM continuation"
+                )
+                return
+
+            # Save assistant message to database
+            assistant_message = ChatMessage(
+                role="assistant",
+                content=assistant_content,
+            )
+            fresh_session.messages.append(assistant_message)
+
+            # Save to database (not cache) to persist the response
+            await upsert_chat_session(fresh_session)
+
+            # Invalidate cache so next poll/refresh gets fresh data
+            await invalidate_session_cache(session_id)
+
+            logger.info(
+                f"Generated streaming LLM continuation for session {session_id} "
+                f"(task_id={task_id}), response length: {len(assistant_content)}"
+            )
+        else:
+            logger.warning(
+                f"Streaming LLM continuation returned empty response for {session_id}"
+            )
+
+    except Exception as e:
+        logger.error(
+            f"Failed to generate streaming LLM continuation: {e}", exc_info=True
+        )
+        # Publish error to stream registry followed by finish event
+        await stream_registry.publish_chunk(
+            task_id,
+            StreamError(errorText=f"Failed to generate response: {e}"),
+        )
+        await stream_registry.publish_chunk(task_id, StreamFinish())
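Since the docstring above advertises GET /chat/tasks/{task_id}/stream for resuming, a reconnecting client might look like the following sketch; the base URL and the `last_message_id` query-parameter name are assumptions:

```python
import asyncio

import httpx


async def resume_stream(task_id: str, last_message_id: str = "0-0") -> None:
    # Hypothetical base URL; the query-parameter name is an assumption
    url = f"http://chat-service/chat/tasks/{task_id}/stream"
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "GET", url, params={"last_message_id": last_message_id}
        ) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("data:"):
                    print(line[5:].strip())  # one JSON-encoded stream chunk


asyncio.run(resume_stream("task-123"))
```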
@@ -0,0 +1,704 @@
"""Stream registry for managing reconnectable SSE streams.

This module provides a registry for tracking active streaming tasks and their
messages. It uses Redis for all state management (no in-memory state), making
pods stateless and horizontally scalable.

Architecture:
- Redis Stream: Persists all messages for replay and real-time delivery
- Redis Hash: Task metadata (status, session_id, etc.)

Subscribers:
1. Replay missed messages from Redis Stream (XREAD)
2. Listen for live updates via blocking XREAD
3. No in-memory state required on the subscribing pod
"""

import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal

import orjson

from backend.data.redis_client import get_redis_async

from .config import ChatConfig
from .response_model import StreamBaseResponse, StreamError, StreamFinish

logger = logging.getLogger(__name__)
config = ChatConfig()

# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_tasks: dict[str, asyncio.Task] = {}

# Track listener tasks per subscriber queue for cleanup
# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe
_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {}

# Timeout for putting chunks into subscriber queues (seconds)
# If the queue is full and doesn't drain within this time, send an overflow error
QUEUE_PUT_TIMEOUT = 5.0

# Lua script for atomic compare-and-swap status update (idempotent completion)
# Returns 1 if status was updated, 0 if already completed/failed
COMPLETE_TASK_SCRIPT = """
local current = redis.call("HGET", KEYS[1], "status")
if current == "running" then
    redis.call("HSET", KEYS[1], "status", ARGV[1])
    return 1
end
return 0
"""


@dataclass
class ActiveTask:
    """Represents an active streaming task (metadata only, no in-memory queues)."""

    task_id: str
    session_id: str
    user_id: str | None
    tool_call_id: str
    tool_name: str
    operation_id: str
    status: Literal["running", "completed", "failed"] = "running"
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    asyncio_task: asyncio.Task | None = None


def _get_task_meta_key(task_id: str) -> str:
    """Get Redis key for task metadata."""
    return f"{config.task_meta_prefix}{task_id}"


def _get_task_stream_key(task_id: str) -> str:
    """Get Redis key for task message stream."""
    return f"{config.task_stream_prefix}{task_id}"


def _get_operation_mapping_key(operation_id: str) -> str:
    """Get Redis key for operation_id to task_id mapping."""
    return f"{config.task_op_prefix}{operation_id}"


async def create_task(
    task_id: str,
    session_id: str,
    user_id: str | None,
    tool_call_id: str,
    tool_name: str,
    operation_id: str,
) -> ActiveTask:
    """Create a new streaming task in Redis.

    Args:
        task_id: Unique identifier for the task
        session_id: Chat session ID
        user_id: User ID (may be None for anonymous)
        tool_call_id: Tool call ID from the LLM
        tool_name: Name of the tool being executed
        operation_id: Operation ID for webhook callbacks

    Returns:
        The created ActiveTask instance (metadata only)
    """
    task = ActiveTask(
        task_id=task_id,
        session_id=session_id,
        user_id=user_id,
        tool_call_id=tool_call_id,
        tool_name=tool_name,
        operation_id=operation_id,
    )

    # Store metadata in Redis
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    op_key = _get_operation_mapping_key(operation_id)

    await redis.hset(  # type: ignore[misc]
        meta_key,
        mapping={
            "task_id": task_id,
            "session_id": session_id,
            "user_id": user_id or "",
            "tool_call_id": tool_call_id,
            "tool_name": tool_name,
            "operation_id": operation_id,
            "status": task.status,
            "created_at": task.created_at.isoformat(),
        },
    )
    await redis.expire(meta_key, config.stream_ttl)

    # Create operation_id -> task_id mapping for webhook lookups
    await redis.set(op_key, task_id, ex=config.stream_ttl)

    logger.debug(f"Created task {task_id} for session {session_id}")

    return task


async def publish_chunk(
    task_id: str,
    chunk: StreamBaseResponse,
) -> str:
    """Publish a chunk to Redis Stream.

    All delivery is via Redis Streams - no in-memory state.

    Args:
        task_id: Task ID to publish to
        chunk: The stream response chunk to publish

    Returns:
        The Redis Stream message ID
    """
    chunk_json = chunk.model_dump_json()
    message_id = "0-0"

    try:
        redis = await get_redis_async()
        stream_key = _get_task_stream_key(task_id)

        # Write to Redis Stream for persistence and real-time delivery
        raw_id = await redis.xadd(
            stream_key,
            {"data": chunk_json},
            maxlen=config.stream_max_length,
        )
        message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()

        # Set TTL on stream to match task metadata TTL
        await redis.expire(stream_key, config.stream_ttl)
    except Exception as e:
        logger.error(
            f"Failed to publish chunk for task {task_id}: {e}",
            exc_info=True,
        )

    return message_id

async def subscribe_to_task(
    task_id: str,
    user_id: str | None,
    last_message_id: str = "0-0",
) -> asyncio.Queue[StreamBaseResponse] | None:
    """Subscribe to a task's stream with replay of missed messages.

    This is fully stateless - uses Redis Stream for replay and blocking XREAD
    for live updates.

    Args:
        task_id: Task ID to subscribe to
        user_id: User ID for ownership validation
        last_message_id: Last Redis Stream message ID received ("0-0" for full replay)

    Returns:
        An asyncio Queue that will receive stream chunks, or None if task not found
        or user doesn't have access
    """
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    meta: dict[Any, Any] = await redis.hgetall(meta_key)  # type: ignore[misc]

    if not meta:
        logger.debug(f"Task {task_id} not found in Redis")
        return None

    # Note: Redis client uses decode_responses=True, so keys are strings
    task_status = meta.get("status", "")
    task_user_id = meta.get("user_id", "") or None

    # Validate ownership - if task has an owner, requester must match
    if task_user_id:
        if user_id != task_user_id:
            logger.warning(
                f"User {user_id} denied access to task {task_id} "
                f"owned by {task_user_id}"
            )
            return None

    subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
    stream_key = _get_task_stream_key(task_id)

    # Step 1: Replay messages from Redis Stream
    # Non-blocking read: replay whatever is already in the stream
    # (passing block=0 would block indefinitely on an empty stream)
    messages = await redis.xread({stream_key: last_message_id}, count=1000)

    replayed_count = 0
    replay_last_id = last_message_id
    if messages:
        for _stream_name, stream_messages in messages:
            for msg_id, msg_data in stream_messages:
                replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
                # Note: Redis client uses decode_responses=True, so keys are strings
                if "data" in msg_data:
                    try:
                        chunk_data = orjson.loads(msg_data["data"])
                        chunk = _reconstruct_chunk(chunk_data)
                        if chunk:
                            await subscriber_queue.put(chunk)
                            replayed_count += 1
                    except Exception as e:
                        logger.warning(f"Failed to replay message: {e}")

    logger.debug(f"Task {task_id}: replayed {replayed_count} messages")

    # Step 2: If task is still running, start stream listener for live updates
    if task_status == "running":
        listener_task = asyncio.create_task(
            _stream_listener(task_id, subscriber_queue, replay_last_id)
        )
        # Track listener task for cleanup on unsubscribe
        _listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
    else:
        # Task is completed/failed - add finish marker
        await subscriber_queue.put(StreamFinish())

    return subscriber_queue

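A consumer of subscribe_to_task just drains the returned queue until StreamFinish; a minimal sketch of an SSE generator built on it (only the registry functions shown above are assumed):

```python
from collections.abc import AsyncIterator

from .response_model import StreamFinish


async def sse_events(task_id: str, user_id: str | None) -> AsyncIterator[str]:
    queue = await subscribe_to_task(task_id, user_id)
    if queue is None:
        return  # unknown task or access denied
    try:
        while True:
            chunk = await queue.get()
            yield f"data: {chunk.model_dump_json()}\n\n"
            if isinstance(chunk, StreamFinish):
                break
    finally:
        # Cancel the XREAD listener backing this queue
        await unsubscribe_from_task(task_id, queue)
```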

async def _stream_listener(
    task_id: str,
    subscriber_queue: asyncio.Queue[StreamBaseResponse],
    last_replayed_id: str,
) -> None:
    """Listen to Redis Stream for new messages using blocking XREAD.

    This approach avoids the duplicate message issue that can occur with pub/sub
    when messages are published during the gap between replay and subscription.

    Args:
        task_id: Task ID to listen for
        subscriber_queue: Queue to deliver messages to
        last_replayed_id: Last message ID from replay (continue from here)
    """
    queue_id = id(subscriber_queue)
    # Track the last successfully delivered message ID for recovery hints
    last_delivered_id = last_replayed_id

    try:
        redis = await get_redis_async()
        stream_key = _get_task_stream_key(task_id)
        current_id = last_replayed_id

        while True:
            # Block for up to 30 seconds waiting for new messages
            # This allows periodic checking if task is still running
            messages = await redis.xread(
                {stream_key: current_id}, block=30000, count=100
            )

            if not messages:
                # Timeout - check if task is still running
                meta_key = _get_task_meta_key(task_id)
                status = await redis.hget(meta_key, "status")  # type: ignore[misc]
                if status and status != "running":
                    try:
                        await asyncio.wait_for(
                            subscriber_queue.put(StreamFinish()),
                            timeout=QUEUE_PUT_TIMEOUT,
                        )
                    except asyncio.TimeoutError:
                        logger.warning(
                            f"Timeout delivering finish event for task {task_id}"
                        )
                    break
                continue

            for _stream_name, stream_messages in messages:
                for msg_id, msg_data in stream_messages:
                    current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()

                    if "data" not in msg_data:
                        continue

                    try:
                        chunk_data = orjson.loads(msg_data["data"])
                        chunk = _reconstruct_chunk(chunk_data)
                        if chunk:
                            try:
                                await asyncio.wait_for(
                                    subscriber_queue.put(chunk),
                                    timeout=QUEUE_PUT_TIMEOUT,
                                )
                                # Update last delivered ID on successful delivery
                                last_delivered_id = current_id
                            except asyncio.TimeoutError:
                                logger.warning(
                                    f"Subscriber queue full for task {task_id}, "
                                    f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
                                )
                                # Send overflow error with recovery info
                                try:
                                    overflow_error = StreamError(
                                        errorText="Message delivery timeout - some messages may have been missed",
                                        code="QUEUE_OVERFLOW",
                                        details={
                                            "last_delivered_id": last_delivered_id,
                                            "recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
                                        },
                                    )
                                    subscriber_queue.put_nowait(overflow_error)
                                except asyncio.QueueFull:
                                    # Queue is completely stuck, nothing more we can do
                                    logger.error(
                                        f"Cannot deliver overflow error for task {task_id}, "
                                        "queue completely blocked"
                                    )

                            # Stop listening on finish
                            if isinstance(chunk, StreamFinish):
                                return
                    except Exception as e:
                        logger.warning(f"Error processing stream message: {e}")

    except asyncio.CancelledError:
        logger.debug(f"Stream listener cancelled for task {task_id}")
        raise  # Re-raise to propagate cancellation
    except Exception as e:
        logger.error(f"Stream listener error for task {task_id}: {e}")
        # On error, send finish to unblock subscriber
        try:
            await asyncio.wait_for(
                subscriber_queue.put(StreamFinish()),
                timeout=QUEUE_PUT_TIMEOUT,
            )
        except (asyncio.TimeoutError, asyncio.QueueFull):
            logger.warning(
                f"Could not deliver finish event for task {task_id} after error"
            )
    finally:
        # Clean up listener task mapping on exit
        _listener_tasks.pop(queue_id, None)


async def mark_task_completed(
    task_id: str,
    status: Literal["completed", "failed"] = "completed",
) -> bool:
    """Mark a task as completed and publish finish event.

    This is idempotent - calling multiple times with the same task_id is safe.
    Uses atomic compare-and-swap via Lua script to prevent race conditions.
    Status is updated first (source of truth), then finish event is published (best-effort).

    Args:
        task_id: Task ID to mark as completed
        status: Final status ("completed" or "failed")

    Returns:
        True if task was newly marked completed, False if already completed/failed
    """
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)

    # Atomic compare-and-swap: only update if status is "running"
    # This prevents race conditions when multiple callers try to complete simultaneously
    result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status)  # type: ignore[misc]

    if result == 0:
        logger.debug(f"Task {task_id} already completed/failed, skipping")
        return False

    # THEN publish finish event (best-effort - listeners can detect via status polling)
    try:
        await publish_chunk(task_id, StreamFinish())
    except Exception as e:
        logger.error(
            f"Failed to publish finish event for task {task_id}: {e}. "
            "Listeners will detect completion via status polling."
        )

    # Clean up local task reference if exists
    _local_tasks.pop(task_id, None)
    return True

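Because completion goes through the Lua compare-and-swap, double completion is harmless; a minimal sketch (assuming the task is still "running" when the first call lands):

```python
async def demo(task_id: str) -> None:
    first = await mark_task_completed(task_id, status="completed")
    second = await mark_task_completed(task_id, status="failed")
    # Only the first transition wins; the second is a no-op
    assert first is True and second is False
```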

async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
    """Find a task by its operation ID.

    Used by webhook callbacks to locate the task to update.

    Args:
        operation_id: Operation ID to search for

    Returns:
        ActiveTask if found, None otherwise
    """
    redis = await get_redis_async()
    op_key = _get_operation_mapping_key(operation_id)
    task_id = await redis.get(op_key)

    if not task_id:
        return None

    task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
    return await get_task(task_id_str)


async def get_task(task_id: str) -> ActiveTask | None:
    """Get a task by its ID from Redis.

    Args:
        task_id: Task ID to look up

    Returns:
        ActiveTask if found, None otherwise
    """
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    meta: dict[Any, Any] = await redis.hgetall(meta_key)  # type: ignore[misc]

    if not meta:
        return None

    # Note: Redis client uses decode_responses=True, so keys/values are strings
    return ActiveTask(
        task_id=meta.get("task_id", ""),
        session_id=meta.get("session_id", ""),
        user_id=meta.get("user_id", "") or None,
        tool_call_id=meta.get("tool_call_id", ""),
        tool_name=meta.get("tool_name", ""),
        operation_id=meta.get("operation_id", ""),
        status=meta.get("status", "running"),  # type: ignore[arg-type]
    )


async def get_task_with_expiry_info(
    task_id: str,
) -> tuple[ActiveTask | None, str | None]:
    """Get a task by its ID with expiration detection.

    Returns (task, error_code) where error_code is:
    - None if task found
    - "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired)
    - "TASK_NOT_FOUND" if neither exists

    Args:
        task_id: Task ID to look up

    Returns:
        Tuple of (ActiveTask or None, error_code or None)
    """
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    stream_key = _get_task_stream_key(task_id)

    meta: dict[Any, Any] = await redis.hgetall(meta_key)  # type: ignore[misc]

    if not meta:
        # Check if stream still has data (metadata expired but stream hasn't)
        stream_len = await redis.xlen(stream_key)
        if stream_len > 0:
            return None, "TASK_EXPIRED"
        return None, "TASK_NOT_FOUND"

    # Note: Redis client uses decode_responses=True, so keys/values are strings
    return (
        ActiveTask(
            task_id=meta.get("task_id", ""),
            session_id=meta.get("session_id", ""),
            user_id=meta.get("user_id", "") or None,
            tool_call_id=meta.get("tool_call_id", ""),
            tool_name=meta.get("tool_name", ""),
            operation_id=meta.get("operation_id", ""),
            status=meta.get("status", "running"),  # type: ignore[arg-type]
        ),
        None,
    )


async def get_active_task_for_session(
    session_id: str,
    user_id: str | None = None,
) -> tuple[ActiveTask | None, str]:
    """Get the active (running) task for a session, if any.

    Scans Redis for tasks matching the session_id with status="running".

    Args:
        session_id: Session ID to look up
        user_id: User ID for ownership validation (optional)

    Returns:
        Tuple of (ActiveTask if found and running, last_message_id from Redis Stream)
    """
    redis = await get_redis_async()

    # Scan Redis for task metadata keys
    cursor = 0
    tasks_checked = 0

    while True:
        cursor, keys = await redis.scan(
            cursor, match=f"{config.task_meta_prefix}*", count=100
        )

        for key in keys:
            tasks_checked += 1
            meta: dict[Any, Any] = await redis.hgetall(key)  # type: ignore[misc]
            if not meta:
                continue

            # Note: Redis client uses decode_responses=True, so keys/values are strings
            task_session_id = meta.get("session_id", "")
            task_status = meta.get("status", "")
            task_user_id = meta.get("user_id", "") or None
            task_id = meta.get("task_id", "")

            if task_session_id == session_id and task_status == "running":
                # Validate ownership - if task has an owner, requester must match
                if task_user_id and user_id != task_user_id:
                    continue

                # Get the last message ID from Redis Stream
                stream_key = _get_task_stream_key(task_id)
                last_id = "0-0"
                try:
                    messages = await redis.xrevrange(stream_key, count=1)
                    if messages:
                        msg_id = messages[0][0]
                        last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
                except Exception as e:
                    logger.warning(f"Failed to get last message ID: {e}")

                return (
                    ActiveTask(
                        task_id=task_id,
                        session_id=task_session_id,
                        user_id=task_user_id,
                        tool_call_id=meta.get("tool_call_id", ""),
                        tool_name=meta.get("tool_name", ""),
                        operation_id=meta.get("operation_id", ""),
                        status="running",
                    ),
                    last_id,
                )

        if cursor == 0:
            break

    return None, "0-0"


def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
    """Reconstruct a StreamBaseResponse from JSON data.

    Args:
        chunk_data: Parsed JSON data from Redis

    Returns:
        Reconstructed response object, or None if unknown type
    """
    from .response_model import (
        ResponseType,
        StreamError,
        StreamFinish,
        StreamHeartbeat,
        StreamStart,
        StreamTextDelta,
        StreamTextEnd,
        StreamTextStart,
        StreamToolInputAvailable,
        StreamToolInputStart,
        StreamToolOutputAvailable,
        StreamUsage,
    )

    # Map response types to their corresponding classes
    type_to_class: dict[str, type[StreamBaseResponse]] = {
        ResponseType.START.value: StreamStart,
        ResponseType.FINISH.value: StreamFinish,
        ResponseType.TEXT_START.value: StreamTextStart,
        ResponseType.TEXT_DELTA.value: StreamTextDelta,
        ResponseType.TEXT_END.value: StreamTextEnd,
        ResponseType.TOOL_INPUT_START.value: StreamToolInputStart,
        ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable,
        ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable,
        ResponseType.ERROR.value: StreamError,
        ResponseType.USAGE.value: StreamUsage,
        ResponseType.HEARTBEAT.value: StreamHeartbeat,
    }

    chunk_type = chunk_data.get("type")
    chunk_class = type_to_class.get(chunk_type)  # type: ignore[arg-type]

    if chunk_class is None:
        logger.warning(f"Unknown chunk type: {chunk_type}")
        return None

    try:
        return chunk_class(**chunk_data)
    except Exception as e:
        logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
        return None

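Serialization round-trips through the same model JSON that publish_chunk writes; a minimal sketch (assuming the response models serialize their `type` discriminator by default):

```python
import orjson

from .response_model import StreamTextDelta

original = StreamTextDelta(id="blk-1", delta="hello")
stored = original.model_dump_json()  # what XADD persists under the "data" field
restored = _reconstruct_chunk(orjson.loads(stored))
assert isinstance(restored, StreamTextDelta) and restored.delta == "hello"
```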

async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
    """Track the asyncio.Task for a task (local reference only).

    This is just for cleanup purposes - the task state is in Redis.

    Args:
        task_id: Task ID
        asyncio_task: The asyncio Task to track
    """
    _local_tasks[task_id] = asyncio_task


async def unsubscribe_from_task(
    task_id: str,
    subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> None:
    """Clean up when a subscriber disconnects.

    Cancels the XREAD-based listener task associated with this subscriber queue
    to prevent resource leaks.

    Args:
        task_id: Task ID
        subscriber_queue: The subscriber's queue used to look up the listener task
    """
    queue_id = id(subscriber_queue)
    listener_entry = _listener_tasks.pop(queue_id, None)

    if listener_entry is None:
        logger.debug(
            f"No listener task found for task {task_id} queue {queue_id} "
            "(may have already completed)"
        )
        return

    stored_task_id, listener_task = listener_entry

    if stored_task_id != task_id:
        logger.warning(
            f"Task ID mismatch in unsubscribe: expected {task_id}, "
            f"found {stored_task_id}"
        )

    if listener_task.done():
        logger.debug(f"Listener task for task {task_id} already completed")
        return

    # Cancel the listener task
    listener_task.cancel()

    try:
        # Wait for the task to be cancelled with a timeout
        await asyncio.wait_for(listener_task, timeout=5.0)
    except asyncio.CancelledError:
        # Expected - the task was successfully cancelled
        pass
    except asyncio.TimeoutError:
        logger.warning(
            f"Timeout waiting for listener task cancellation for task {task_id}"
        )
    except Exception as e:
        logger.error(f"Error during listener task cancellation for task {task_id}: {e}")

    logger.debug(f"Successfully unsubscribed from task {task_id}")
@@ -10,6 +10,7 @@ from .add_understanding import AddUnderstandingTool
 from .agent_output import AgentOutputTool
 from .base import BaseTool
 from .create_agent import CreateAgentTool
+from .customize_agent import CustomizeAgentTool
 from .edit_agent import EditAgentTool
 from .find_agent import FindAgentTool
 from .find_block import FindBlockTool
@@ -34,6 +35,7 @@ logger = logging.getLogger(__name__)
 TOOL_REGISTRY: dict[str, BaseTool] = {
     "add_understanding": AddUnderstandingTool(),
     "create_agent": CreateAgentTool(),
+    "customize_agent": CustomizeAgentTool(),
     "edit_agent": EditAgentTool(),
     "find_agent": FindAgentTool(),
     "find_block": FindBlockTool(),
@@ -8,6 +8,7 @@ from .core import (
     DecompositionStep,
     LibraryAgentSummary,
     MarketplaceAgentSummary,
+    customize_template,
     decompose_goal,
     enrich_library_agents_from_steps,
     extract_search_terms_from_steps,
@@ -19,6 +20,7 @@ from .core import (
     get_library_agent_by_graph_id,
     get_library_agent_by_id,
     get_library_agents_for_generation,
+    graph_to_json,
     json_to_graph,
     save_agent_to_library,
     search_marketplace_agents_for_generation,
@@ -36,6 +38,7 @@ __all__ = [
     "LibraryAgentSummary",
     "MarketplaceAgentSummary",
     "check_external_service_health",
+    "customize_template",
     "decompose_goal",
     "enrich_library_agents_from_steps",
     "extract_search_terms_from_steps",
@@ -48,6 +51,7 @@ __all__ = [
     "get_library_agent_by_id",
     "get_library_agents_for_generation",
     "get_user_message_for_error",
+    "graph_to_json",
     "is_external_service_configured",
     "json_to_graph",
     "save_agent_to_library",
@@ -7,18 +7,11 @@ from typing import Any, NotRequired, TypedDict
 
 from backend.api.features.library import db as library_db
 from backend.api.features.store import db as store_db
-from backend.data.graph import (
-    Graph,
-    Link,
-    Node,
-    create_graph,
-    get_graph,
-    get_graph_all_versions,
-    get_store_listed_graphs,
-)
+from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
 from backend.util.exceptions import DatabaseError, NotFoundError
 
 from .service import (
+    customize_template_external,
     decompose_goal_external,
     generate_agent_external,
     generate_agent_patch_external,
@@ -27,8 +20,6 @@ from .service import (
 
 logger = logging.getLogger(__name__)
 
-AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
-
-
 class ExecutionSummary(TypedDict):
     """Summary of a single execution for quality assessment."""
@@ -549,15 +540,21 @@ async def decompose_goal(
 async def generate_agent(
     instructions: DecompositionResult | dict[str, Any],
     library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
+    operation_id: str | None = None,
+    task_id: str | None = None,
 ) -> dict[str, Any] | None:
     """Generate agent JSON from instructions.
 
     Args:
         instructions: Structured instructions from decompose_goal
         library_agents: User's library agents available for sub-agent composition
+        operation_id: Operation ID for async processing (enables Redis Streams
+            completion notification)
+        task_id: Task ID for async processing (enables Redis Streams persistence
+            and SSE delivery)
 
     Returns:
-        Agent JSON dict, error dict {"type": "error", ...}, or None on error
+        Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
 
     Raises:
         AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -565,8 +562,13 @@ async def generate_agent(
     _check_service_configured()
     logger.info("Calling external Agent Generator service for generate_agent")
     result = await generate_agent_external(
-        dict(instructions), _to_dict_list(library_agents)
+        dict(instructions), _to_dict_list(library_agents), operation_id, task_id
     )
+
+    # Don't modify async response
+    if result and result.get("status") == "accepted":
+        return result
+
     if result:
         if isinstance(result, dict) and result.get("type") == "error":
             return result
@@ -657,45 +659,6 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
     )
 
 
-def _reassign_node_ids(graph: Graph) -> None:
-    """Reassign all node and link IDs to new UUIDs.
-
-    This is needed when creating a new version to avoid unique constraint violations.
-    """
-    id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
-
-    for node in graph.nodes:
-        node.id = id_map[node.id]
-
-    for link in graph.links:
-        link.id = str(uuid.uuid4())
-        if link.source_id in id_map:
-            link.source_id = id_map[link.source_id]
-        if link.sink_id in id_map:
-            link.sink_id = id_map[link.sink_id]
-
-
-def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
-    """Populate user_id in AgentExecutorBlock nodes.
-
-    The external agent generator creates AgentExecutorBlock nodes with empty user_id.
-    This function fills in the actual user_id so sub-agents run with correct permissions.
-
-    Args:
-        agent_json: Agent JSON dict (modified in place)
-        user_id: User ID to set
-    """
-    for node in agent_json.get("nodes", []):
-        if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
-            input_default = node.get("input_default") or {}
-            if not input_default.get("user_id"):
-                input_default["user_id"] = user_id
-                node["input_default"] = input_default
-                logger.debug(
-                    f"Set user_id for AgentExecutorBlock node {node.get('id')}"
-                )
-
-
 async def save_agent_to_library(
     agent_json: dict[str, Any], user_id: str, is_update: bool = False
 ) -> tuple[Graph, Any]:
@@ -709,63 +672,21 @@ async def save_agent_to_library(
     Returns:
         Tuple of (created Graph, LibraryAgent)
     """
-    # Populate user_id in AgentExecutorBlock nodes before conversion
-    _populate_agent_executor_user_ids(agent_json, user_id)
-
     graph = json_to_graph(agent_json)
-
     if is_update:
-        if graph.id:
-            existing_versions = await get_graph_all_versions(graph.id, user_id)
-            if existing_versions:
-                latest_version = max(v.version for v in existing_versions)
-                graph.version = latest_version + 1
-                _reassign_node_ids(graph)
-                logger.info(f"Updating agent {graph.id} to version {graph.version}")
-        else:
-            graph.id = str(uuid.uuid4())
-            graph.version = 1
-            _reassign_node_ids(graph)
-            logger.info(f"Creating new agent with ID {graph.id}")
-
-    created_graph = await create_graph(graph, user_id)
-
-    library_agents = await library_db.create_library_agent(
-        graph=created_graph,
-        user_id=user_id,
-        sensitive_action_safe_mode=True,
-        create_library_agents_for_sub_graphs=False,
-    )
-
-    return created_graph, library_agents[0]
-
-
-async def get_agent_as_json(
-    agent_id: str, user_id: str | None
-) -> dict[str, Any] | None:
-    """Fetch an agent and convert to JSON format for editing.
-
-    Args:
-        agent_id: Graph ID or library agent ID
-        user_id: User ID
-
-    Returns:
-        Agent as JSON dict or None if not found
-    """
-    graph = await get_graph(agent_id, version=None, user_id=user_id)
-
-    if not graph and user_id:
-        try:
-            library_agent = await library_db.get_library_agent(agent_id, user_id)
-            graph = await get_graph(
-                library_agent.graph_id, version=None, user_id=user_id
-            )
-        except NotFoundError:
-            pass
-
-    if not graph:
-        return None
-
+        return await library_db.update_graph_in_library(graph, user_id)
+    return await library_db.create_graph_in_library(graph, user_id)
+
+
+def graph_to_json(graph: Graph) -> dict[str, Any]:
+    """Convert a Graph object to JSON format for the agent generator.
+
+    Args:
+        graph: Graph object to convert
+
+    Returns:
+        Agent as JSON dict
+    """
     nodes = []
     for node in graph.nodes:
         nodes.append(
@@ -802,10 +723,41 @@ async def get_agent_as_json(
     }
 
 
+async def get_agent_as_json(
+    agent_id: str, user_id: str | None
+) -> dict[str, Any] | None:
+    """Fetch an agent and convert to JSON format for editing.
+
+    Args:
+        agent_id: Graph ID or library agent ID
+        user_id: User ID
+
+    Returns:
+        Agent as JSON dict or None if not found
+    """
+    graph = await get_graph(agent_id, version=None, user_id=user_id)
+
+    if not graph and user_id:
+        try:
+            library_agent = await library_db.get_library_agent(agent_id, user_id)
+            graph = await get_graph(
+                library_agent.graph_id, version=None, user_id=user_id
+            )
+        except NotFoundError:
+            pass
+
+    if not graph:
+        return None
+
+    return graph_to_json(graph)
+
+
 async def generate_agent_patch(
     update_request: str,
     current_agent: dict[str, Any],
     library_agents: list[AgentSummary] | None = None,
+    operation_id: str | None = None,
+    task_id: str | None = None,
 ) -> dict[str, Any] | None:
     """Update an existing agent using natural language.
 
@@ -818,10 +770,12 @@ async def generate_agent_patch(
         update_request: Natural language description of changes
         current_agent: Current agent JSON
         library_agents: User's library agents available for sub-agent composition
+        operation_id: Operation ID for async processing (enables Redis Streams callback)
+        task_id: Task ID for async processing (enables Redis Streams callback)
 
     Returns:
         Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
-        error dict {"type": "error", ...}, or None on unexpected error
+        {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
 
     Raises:
         AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -829,5 +783,43 @@ async def generate_agent_patch(
     _check_service_configured()
     logger.info("Calling external Agent Generator service for generate_agent_patch")
     return await generate_agent_patch_external(
-        update_request, current_agent, _to_dict_list(library_agents)
+        update_request,
+        current_agent,
+        _to_dict_list(library_agents),
+        operation_id,
+        task_id,
     )
+
+
+async def customize_template(
+    template_agent: dict[str, Any],
+    modification_request: str,
+    context: str = "",
+) -> dict[str, Any] | None:
+    """Customize a template/marketplace agent using natural language.
+
+    This is used when users want to modify a template or marketplace agent
+    to fit their specific needs before adding it to their library.
+
+    The external Agent Generator service handles:
+    - Understanding the modification request
+    - Applying changes to the template
+    - Fixing and validating the result
+
+    Args:
+        template_agent: The template agent JSON to customize
+        modification_request: Natural language description of customizations
+        context: Additional context (e.g., answers to previous questions)
+
+    Returns:
+        Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||||
|
error dict {"type": "error", ...}, or None on unexpected error
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
|
"""
|
||||||
|
_check_service_configured()
|
||||||
|
logger.info("Calling external Agent Generator service for customize_template")
|
||||||
|
return await customize_template_external(
|
||||||
|
template_agent, modification_request, context
|
||||||
)
|
)
|
||||||
|
|||||||
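Note: after this change, generate_agent_patch has four possible outcomes instead of three. A minimal dispatch sketch for callers, using only the return shapes documented above (the function name is illustrative, not part of this diff):

from typing import Any


async def dispatch_patch_result(result: dict[str, Any] | None) -> str:
    # None signals an unexpected failure in the generator service.
    if result is None:
        return "error: no response"
    # Async path: the external service took the job; the real result
    # arrives later via the Redis Streams completion message.
    if result.get("status") == "accepted":
        return f"accepted (operation_id={result.get('operation_id')})"
    # The service can also pass through clarifying questions or errors.
    if result.get("type") == "clarifying_questions":
        return f"need answers to {len(result.get('questions', []))} questions"
    if result.get("type") == "error":
        return f"error: {result.get('error')}"
    # Anything else is the updated agent JSON itself.
    return f"updated agent with {len(result.get('nodes', []))} nodes"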
@@ -212,24 +212,45 @@ async def decompose_goal_external(
|
|||||||
async def generate_agent_external(
|
async def generate_agent_external(
|
||||||
instructions: dict[str, Any],
|
instructions: dict[str, Any],
|
||||||
library_agents: list[dict[str, Any]] | None = None,
|
library_agents: list[dict[str, Any]] | None = None,
|
||||||
|
operation_id: str | None = None,
|
||||||
|
task_id: str | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Call the external service to generate an agent from instructions.
|
"""Call the external service to generate an agent from instructions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
instructions: Structured instructions from decompose_goal
|
instructions: Structured instructions from decompose_goal
|
||||||
library_agents: User's library agents available for sub-agent composition
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||||
|
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Agent JSON dict on success, or error dict {"type": "error", ...} on error
|
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
|
||||||
"""
|
"""
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
|
# Build request payload
|
||||||
payload: dict[str, Any] = {"instructions": instructions}
|
payload: dict[str, Any] = {"instructions": instructions}
|
||||||
if library_agents:
|
if library_agents:
|
||||||
payload["library_agents"] = library_agents
|
payload["library_agents"] = library_agents
|
||||||
|
if operation_id and task_id:
|
||||||
|
payload["operation_id"] = operation_id
|
||||||
|
payload["task_id"] = task_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.post("/api/generate-agent", json=payload)
|
response = await client.post("/api/generate-agent", json=payload)
|
||||||
|
|
||||||
|
# Handle 202 Accepted for async processing
|
||||||
|
if response.status_code == 202:
|
||||||
|
logger.info(
|
||||||
|
f"Agent Generator accepted async request "
|
||||||
|
f"(operation_id={operation_id}, task_id={task_id})"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "accepted",
|
||||||
|
"operation_id": operation_id,
|
||||||
|
"task_id": task_id,
|
||||||
|
}
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
@@ -261,6 +282,8 @@ async def generate_agent_patch_external(
|
|||||||
update_request: str,
|
update_request: str,
|
||||||
current_agent: dict[str, Any],
|
current_agent: dict[str, Any],
|
||||||
library_agents: list[dict[str, Any]] | None = None,
|
library_agents: list[dict[str, Any]] | None = None,
|
||||||
|
operation_id: str | None = None,
|
||||||
|
task_id: str | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Call the external service to generate a patch for an existing agent.
|
"""Call the external service to generate a patch for an existing agent.
|
||||||
|
|
||||||
@@ -268,21 +291,40 @@ async def generate_agent_patch_external(
|
|||||||
update_request: Natural language description of changes
|
update_request: Natural language description of changes
|
||||||
current_agent: Current agent JSON
|
current_agent: Current agent JSON
|
||||||
library_agents: User's library agents available for sub-agent composition
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||||
|
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated agent JSON, clarifying questions dict, or error dict on error
|
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
||||||
"""
|
"""
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
|
# Build request payload
|
||||||
payload: dict[str, Any] = {
|
payload: dict[str, Any] = {
|
||||||
"update_request": update_request,
|
"update_request": update_request,
|
||||||
"current_agent_json": current_agent,
|
"current_agent_json": current_agent,
|
||||||
}
|
}
|
||||||
if library_agents:
|
if library_agents:
|
||||||
payload["library_agents"] = library_agents
|
payload["library_agents"] = library_agents
|
||||||
|
if operation_id and task_id:
|
||||||
|
payload["operation_id"] = operation_id
|
||||||
|
payload["task_id"] = task_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.post("/api/update-agent", json=payload)
|
response = await client.post("/api/update-agent", json=payload)
|
||||||
|
|
||||||
|
# Handle 202 Accepted for async processing
|
||||||
|
if response.status_code == 202:
|
||||||
|
logger.info(
|
||||||
|
f"Agent Generator accepted async update request "
|
||||||
|
f"(operation_id={operation_id}, task_id={task_id})"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "accepted",
|
||||||
|
"operation_id": operation_id,
|
||||||
|
"task_id": task_id,
|
||||||
|
}
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
@@ -326,6 +368,77 @@ async def generate_agent_patch_external(
|
|||||||
return _create_error_response(error_msg, "unexpected_error")
|
return _create_error_response(error_msg, "unexpected_error")
|
||||||
|
|
||||||
|
|
||||||
|
async def customize_template_external(
|
||||||
|
template_agent: dict[str, Any],
|
||||||
|
modification_request: str,
|
||||||
|
context: str = "",
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Call the external service to customize a template/marketplace agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_agent: The template agent JSON to customize
|
||||||
|
modification_request: Natural language description of customizations
|
||||||
|
context: Additional context (e.g., answers to previous questions)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Customized agent JSON, clarifying questions dict, or error dict on error
|
||||||
|
"""
|
||||||
|
client = _get_client()
|
||||||
|
|
||||||
|
request = modification_request
|
||||||
|
if context:
|
||||||
|
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
|
||||||
|
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"template_agent_json": template_agent,
|
||||||
|
"modification_request": request,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post("/api/template-modification", json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if not data.get("success"):
|
||||||
|
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||||
|
error_type = data.get("error_type", "unknown")
|
||||||
|
logger.error(
|
||||||
|
f"Agent Generator template customization failed: {error_msg} "
|
||||||
|
f"(type: {error_type})"
|
||||||
|
)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
|
|
||||||
|
# Check if it's clarifying questions
|
||||||
|
if data.get("type") == "clarifying_questions":
|
||||||
|
return {
|
||||||
|
"type": "clarifying_questions",
|
||||||
|
"questions": data.get("questions", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if it's an error passed through
|
||||||
|
if data.get("type") == "error":
|
||||||
|
return _create_error_response(
|
||||||
|
data.get("error", "Unknown error"),
|
||||||
|
data.get("error_type", "unknown"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Otherwise return the customized agent JSON
|
||||||
|
return data.get("agent_json")
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
error_type, error_msg = _classify_http_error(e)
|
||||||
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
error_type, error_msg = _classify_request_error(e)
|
||||||
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, "unexpected_error")
|
||||||
|
|
||||||
|
|
||||||
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||||
"""Get available blocks from the external service.
|
"""Get available blocks from the external service.
|
||||||
|
|
||||||
|
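Note: the 202 handling above defines the client half of the async contract. For context, a hedged sketch of what the accepting side could look like — a hypothetical FastAPI endpoint, not the actual Agent Generator implementation:

from fastapi import BackgroundTasks, FastAPI, Response, status

app = FastAPI()


async def run_patch_job(payload: dict, operation_id, task_id) -> dict:
    # Hypothetical worker: builds the patch and, in async mode, publishes
    # the completion message to the Redis stream for the consumer.
    ...


@app.post("/api/update-agent")
async def update_agent(payload: dict, background_tasks: BackgroundTasks):
    operation_id, task_id = payload.get("operation_id"), payload.get("task_id")
    if operation_id and task_id:
        # Async mode: queue the job and acknowledge with 202 right away.
        background_tasks.add_task(run_patch_job, payload, operation_id, task_id)
        return Response(status_code=status.HTTP_202_ACCEPTED)
    # Sync mode: compute and return the result inline.
    return await run_patch_job(payload, None, None)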
@@ -206,9 +206,9 @@ async def search_agents(
         ]
     )
     no_results_msg = (
-        f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
+        f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
        if source == "marketplace"
-        else f"No agents matching '{query}' found in your library."
+        else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
     )
     return NoResultsResponse(
         message=no_results_msg, session_id=session_id, suggestions=suggestions
@@ -224,10 +224,10 @@ async def search_agents(
     message = (
         "Now you have found some options for the user to choose from. "
         "You can add a link to a recommended agent at: /marketplace/agent/agent_id "
-        "Please ask the user if they would like to use any of these agents."
+        "Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
         if source == "marketplace"
         else "Found agents in the user's library. You can provide a link to view an agent at: "
-        "/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
+        "/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
     )
 
     return AgentsFoundResponse(
@@ -18,6 +18,7 @@ from .base import BaseTool
 from .models import (
     AgentPreviewResponse,
     AgentSavedResponse,
+    AsyncProcessingResponse,
     ClarificationNeededResponse,
     ClarifyingQuestion,
     ErrorResponse,
@@ -98,6 +99,10 @@ class CreateAgentTool(BaseTool):
         save = kwargs.get("save", True)
         session_id = session.session_id if session else None
 
+        # Extract async processing params (passed by long-running tool handler)
+        operation_id = kwargs.get("_operation_id")
+        task_id = kwargs.get("_task_id")
+
         if not description:
             return ErrorResponse(
                 message="Please provide a description of what the agent should do.",
@@ -219,7 +224,12 @@ class CreateAgentTool(BaseTool):
             logger.warning(f"Failed to enrich library agents from steps: {e}")
 
         try:
-            agent_json = await generate_agent(decomposition_result, library_agents)
+            agent_json = await generate_agent(
+                decomposition_result,
+                library_agents,
+                operation_id=operation_id,
+                task_id=task_id,
+            )
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
                 message=(
@@ -263,6 +273,19 @@ class CreateAgentTool(BaseTool):
                 session_id=session_id,
             )
 
+        # Check if Agent Generator accepted for async processing
+        if agent_json.get("status") == "accepted":
+            logger.info(
+                f"Agent generation delegated to async processing "
+                f"(operation_id={operation_id}, task_id={task_id})"
+            )
+            return AsyncProcessingResponse(
+                message="Agent generation started. You'll be notified when it's complete.",
+                operation_id=operation_id,
+                task_id=task_id,
+                session_id=session_id,
+            )
+
         agent_name = agent_json.get("name", "Generated Agent")
         agent_description = agent_json.get("description", "")
         node_count = len(agent_json.get("nodes", []))
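Note: the tool reads _operation_id/_task_id out of kwargs, which implies the long-running tool handler injects them before dispatch. A sketch of that injection under that assumption (the handler name is illustrative):

import uuid


async def invoke_long_running_tool(tool, user_id, session, tool_args: dict):
    # Mint IDs up front so the tool can forward them to the external
    # service and the completion consumer can correlate the result later.
    tool_args["_operation_id"] = str(uuid.uuid4())
    tool_args["_task_id"] = str(uuid.uuid4())
    return await tool._execute(user_id, session, **tool_args)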
@@ -0,0 +1,337 @@
+"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
+
+import logging
+from typing import Any
+
+from backend.api.features.chat.model import ChatSession
+from backend.api.features.store import db as store_db
+from backend.api.features.store.exceptions import AgentNotFoundError
+
+from .agent_generator import (
+    AgentGeneratorNotConfiguredError,
+    customize_template,
+    get_user_message_for_error,
+    graph_to_json,
+    save_agent_to_library,
+)
+from .base import BaseTool
+from .models import (
+    AgentPreviewResponse,
+    AgentSavedResponse,
+    ClarificationNeededResponse,
+    ClarifyingQuestion,
+    ErrorResponse,
+    ToolResponseBase,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class CustomizeAgentTool(BaseTool):
+    """Tool for customizing marketplace/template agents using natural language."""
+
+    @property
+    def name(self) -> str:
+        return "customize_agent"
+
+    @property
+    def description(self) -> str:
+        return (
+            "Customize a marketplace or template agent using natural language. "
+            "Takes an existing agent from the marketplace and modifies it based on "
+            "the user's requirements before adding to their library."
+        )
+
+    @property
+    def requires_auth(self) -> bool:
+        return True
+
+    @property
+    def is_long_running(self) -> bool:
+        return True
+
+    @property
+    def parameters(self) -> dict[str, Any]:
+        return {
+            "type": "object",
+            "properties": {
+                "agent_id": {
+                    "type": "string",
+                    "description": (
+                        "The marketplace agent ID in format 'creator/slug' "
+                        "(e.g., 'autogpt/newsletter-writer'). "
+                        "Get this from find_agent results."
+                    ),
+                },
+                "modifications": {
+                    "type": "string",
+                    "description": (
+                        "Natural language description of how to customize the agent. "
+                        "Be specific about what changes you want to make."
+                    ),
+                },
+                "context": {
+                    "type": "string",
+                    "description": (
+                        "Additional context or answers to previous clarifying questions."
+                    ),
+                },
+                "save": {
+                    "type": "boolean",
+                    "description": (
+                        "Whether to save the customized agent to the user's library. "
+                        "Default is true. Set to false for preview only."
+                    ),
+                    "default": True,
+                },
+            },
+            "required": ["agent_id", "modifications"],
+        }
+
+    async def _execute(
+        self,
+        user_id: str | None,
+        session: ChatSession,
+        **kwargs,
+    ) -> ToolResponseBase:
+        """Execute the customize_agent tool.
+
+        Flow:
+        1. Parse the agent ID to get creator/slug
+        2. Fetch the template agent from the marketplace
+        3. Call customize_template with the modification request
+        4. Preview or save based on the save parameter
+        """
+        agent_id = kwargs.get("agent_id", "").strip()
+        modifications = kwargs.get("modifications", "").strip()
+        context = kwargs.get("context", "")
+        save = kwargs.get("save", True)
+        session_id = session.session_id if session else None
+
+        if not agent_id:
+            return ErrorResponse(
+                message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
+                error="missing_agent_id",
+                session_id=session_id,
+            )
+
+        if not modifications:
+            return ErrorResponse(
+                message="Please describe how you want to customize this agent.",
+                error="missing_modifications",
+                session_id=session_id,
+            )
+
+        # Parse agent_id in format "creator/slug"
+        parts = [p.strip() for p in agent_id.split("/")]
+        if len(parts) != 2 or not parts[0] or not parts[1]:
+            return ErrorResponse(
+                message=(
+                    f"Invalid agent ID format: '{agent_id}'. "
+                    "Expected format is 'creator/agent-name' "
+                    "(e.g., 'autogpt/newsletter-writer')."
+                ),
+                error="invalid_agent_id_format",
+                session_id=session_id,
+            )
+
+        creator_username, agent_slug = parts
+
+        # Fetch the marketplace agent details
+        try:
+            agent_details = await store_db.get_store_agent_details(
+                username=creator_username, agent_name=agent_slug
+            )
+        except AgentNotFoundError:
+            return ErrorResponse(
+                message=(
+                    f"Could not find marketplace agent '{agent_id}'. "
+                    "Please check the agent ID and try again."
+                ),
+                error="agent_not_found",
+                session_id=session_id,
+            )
+        except Exception as e:
+            logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
+            return ErrorResponse(
+                message="Failed to fetch the marketplace agent. Please try again.",
+                error="fetch_error",
+                session_id=session_id,
+            )
+
+        if not agent_details.store_listing_version_id:
+            return ErrorResponse(
+                message=(
+                    f"The agent '{agent_id}' does not have an available version. "
+                    "Please try a different agent."
+                ),
+                error="no_version_available",
+                session_id=session_id,
+            )
+
+        # Get the full agent graph
+        try:
+            graph = await store_db.get_agent(agent_details.store_listing_version_id)
+            template_agent = graph_to_json(graph)
+        except Exception as e:
+            logger.error(f"Error fetching agent graph for {agent_id}: {e}")
+            return ErrorResponse(
+                message="Failed to fetch the agent configuration. Please try again.",
+                error="graph_fetch_error",
+                session_id=session_id,
+            )
+
+        # Call customize_template
+        try:
+            result = await customize_template(
+                template_agent=template_agent,
+                modification_request=modifications,
+                context=context,
+            )
+        except AgentGeneratorNotConfiguredError:
+            return ErrorResponse(
+                message=(
+                    "Agent customization is not available. "
+                    "The Agent Generator service is not configured."
+                ),
+                error="service_not_configured",
+                session_id=session_id,
+            )
+        except Exception as e:
+            logger.error(f"Error calling customize_template for {agent_id}: {e}")
+            return ErrorResponse(
+                message=(
+                    "Failed to customize the agent due to a service error. "
+                    "Please try again."
+                ),
+                error="customization_service_error",
+                session_id=session_id,
+            )
+
+        if result is None:
+            return ErrorResponse(
+                message=(
+                    "Failed to customize the agent. "
+                    "The agent generation service may be unavailable or timed out. "
+                    "Please try again."
+                ),
+                error="customization_failed",
+                session_id=session_id,
+            )
+
+        # Handle error response
+        if isinstance(result, dict) and result.get("type") == "error":
+            error_msg = result.get("error", "Unknown error")
+            error_type = result.get("error_type", "unknown")
+            user_message = get_user_message_for_error(
+                error_type,
+                operation="customize the agent",
+                llm_parse_message=(
+                    "The AI had trouble customizing the agent. "
+                    "Please try again or simplify your request."
+                ),
+                validation_message=(
+                    "The customized agent failed validation. "
+                    "Please try rephrasing your request."
+                ),
+                error_details=error_msg,
+            )
+            return ErrorResponse(
+                message=user_message,
+                error=f"customization_failed:{error_type}",
+                session_id=session_id,
+            )
+
+        # Handle clarifying questions
+        if isinstance(result, dict) and result.get("type") == "clarifying_questions":
+            questions = result.get("questions") or []
+            if not isinstance(questions, list):
+                logger.error(
+                    f"Unexpected clarifying questions format: {type(questions)}"
+                )
+                questions = []
+            return ClarificationNeededResponse(
+                message=(
+                    "I need some more information to customize this agent. "
+                    "Please answer the following questions:"
+                ),
+                questions=[
+                    ClarifyingQuestion(
+                        question=q.get("question", ""),
+                        keyword=q.get("keyword", ""),
+                        example=q.get("example"),
+                    )
+                    for q in questions
+                    if isinstance(q, dict)
+                ],
+                session_id=session_id,
+            )
+
+        # Result should be the customized agent JSON
+        if not isinstance(result, dict):
+            logger.error(f"Unexpected customize_template response type: {type(result)}")
+            return ErrorResponse(
+                message="Failed to customize the agent due to an unexpected response.",
+                error="unexpected_response_type",
+                session_id=session_id,
+            )
+
+        customized_agent = result
+
+        agent_name = customized_agent.get(
+            "name", f"Customized {agent_details.agent_name}"
+        )
+        agent_description = customized_agent.get("description", "")
+        nodes = customized_agent.get("nodes")
+        links = customized_agent.get("links")
+        node_count = len(nodes) if isinstance(nodes, list) else 0
+        link_count = len(links) if isinstance(links, list) else 0
+
+        if not save:
+            return AgentPreviewResponse(
+                message=(
+                    f"I've customized the agent '{agent_details.agent_name}'. "
+                    f"The customized agent has {node_count} blocks. "
+                    f"Review it and call customize_agent with save=true to save it."
+                ),
+                agent_json=customized_agent,
+                agent_name=agent_name,
+                description=agent_description,
+                node_count=node_count,
+                link_count=link_count,
+                session_id=session_id,
+            )
+
+        if not user_id:
+            return ErrorResponse(
+                message="You must be logged in to save agents.",
+                error="auth_required",
+                session_id=session_id,
+            )
+
+        # Save to user's library
+        try:
+            created_graph, library_agent = await save_agent_to_library(
+                customized_agent, user_id, is_update=False
+            )
+
+            return AgentSavedResponse(
+                message=(
+                    f"Customized agent '{created_graph.name}' "
+                    f"(based on '{agent_details.agent_name}') "
+                    f"has been saved to your library!"
+                ),
+                agent_id=created_graph.id,
+                agent_name=created_graph.name,
+                library_agent_id=library_agent.id,
+                library_agent_link=f"/library/agents/{library_agent.id}",
+                agent_page_link=f"/build?flowID={created_graph.id}",
+                session_id=session_id,
+            )
+        except Exception as e:
+            logger.error(f"Error saving customized agent: {e}")
+            return ErrorResponse(
+                message="Failed to save the customized agent. Please try again.",
+                error="save_failed",
+                session_id=session_id,
+            )
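Note: a preview-then-save sequence is the intended use of the save flag above. A sketch under that assumption (tool wiring and session object are placeholders):

async def preview_then_customize(tool, user_id, session):
    # First pass: save=False returns an AgentPreviewResponse for review.
    preview = await tool._execute(
        user_id,
        session,
        agent_id="autogpt/newsletter-writer",
        modifications="Send the newsletter weekly instead of daily.",
        save=False,
    )
    # Second pass: the same request with save=True persists the result
    # to the user's library via save_agent_to_library.
    return await tool._execute(
        user_id,
        session,
        agent_id="autogpt/newsletter-writer",
        modifications="Send the newsletter weekly instead of daily.",
        save=True,
    )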
@@ -17,6 +17,7 @@ from .base import BaseTool
 from .models import (
     AgentPreviewResponse,
     AgentSavedResponse,
+    AsyncProcessingResponse,
     ClarificationNeededResponse,
     ClarifyingQuestion,
     ErrorResponse,
@@ -104,6 +105,10 @@ class EditAgentTool(BaseTool):
         save = kwargs.get("save", True)
         session_id = session.session_id if session else None
 
+        # Extract async processing params (passed by long-running tool handler)
+        operation_id = kwargs.get("_operation_id")
+        task_id = kwargs.get("_task_id")
+
         if not agent_id:
             return ErrorResponse(
                 message="Please provide the agent ID to edit.",
@@ -149,7 +154,11 @@ class EditAgentTool(BaseTool):
 
         try:
             result = await generate_agent_patch(
-                update_request, current_agent, library_agents
+                update_request,
+                current_agent,
+                library_agents,
+                operation_id=operation_id,
+                task_id=task_id,
             )
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
@@ -169,6 +178,20 @@ class EditAgentTool(BaseTool):
                 session_id=session_id,
             )
 
+        # Check if Agent Generator accepted for async processing
+        if result.get("status") == "accepted":
+            logger.info(
+                f"Agent edit delegated to async processing "
+                f"(operation_id={operation_id}, task_id={task_id})"
+            )
+            return AsyncProcessingResponse(
+                message="Agent edit started. You'll be notified when it's complete.",
+                operation_id=operation_id,
+                task_id=task_id,
+                session_id=session_id,
+            )
+
+        # Check if the result is an error from the external service
         if isinstance(result, dict) and result.get("type") == "error":
             error_msg = result.get("error", "Unknown error")
             error_type = result.get("error_type", "unknown")
@@ -372,11 +372,15 @@ class OperationStartedResponse(ToolResponseBase):
 
     This is returned immediately to the client while the operation continues
     to execute. The user can close the tab and check back later.
+
+    The task_id can be used to reconnect to the SSE stream via
+    GET /chat/tasks/{task_id}/stream?last_idx=0
     """
 
     type: ResponseType = ResponseType.OPERATION_STARTED
     operation_id: str
     tool_name: str
+    task_id: str | None = None  # For SSE reconnection
 
 
 class OperationPendingResponse(ToolResponseBase):
@@ -400,3 +404,20 @@ class OperationInProgressResponse(ToolResponseBase):
 
     type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
     tool_call_id: str
+
+
+class AsyncProcessingResponse(ToolResponseBase):
+    """Response when an operation has been delegated to async processing.
+
+    This is returned by tools when the external service accepts the request
+    for async processing (HTTP 202 Accepted). The Redis Streams completion
+    consumer will handle the result when the external service completes.
+
+    The status field is specifically "accepted" to allow the long-running tool
+    handler to detect this response and skip LLM continuation.
+    """
+
+    type: ResponseType = ResponseType.OPERATION_STARTED
+    status: str = "accepted"  # Must be "accepted" for detection
+    operation_id: str | None = None
+    task_id: str | None = None
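Note: keeping status == "accepted" as a plain field lets the long-running tool handler detect delegation without an isinstance check on the model. A minimal sketch of that detection, working for both the response object and a raw {"status": "accepted"} dict:

def should_skip_llm_continuation(response) -> bool:
    # Read the marker off an object attribute first, then fall back to
    # a dict key, so dicts from the external client behave the same way.
    status = getattr(response, "status", None)
    if status is None and isinstance(response, dict):
        status = response.get("status")
    return status == "accepted"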
@@ -8,7 +8,12 @@ from backend.api.features.library import model as library_model
 from backend.api.features.store import db as store_db
 from backend.data import graph as graph_db
 from backend.data.graph import GraphModel
-from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
+from backend.data.model import (
+    CredentialsFieldInfo,
+    CredentialsMetaInput,
+    HostScopedCredentials,
+    OAuth2Credentials,
+)
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util.exceptions import NotFoundError
 
@@ -273,7 +278,14 @@ async def match_user_credentials_to_graph(
            for cred in available_creds
            if cred.provider in credential_requirements.provider
            and cred.type in credential_requirements.supported_types
-           and _credential_has_required_scopes(cred, credential_requirements)
+           and (
+               cred.type != "oauth2"
+               or _credential_has_required_scopes(cred, credential_requirements)
+           )
+           and (
+               cred.type != "host_scoped"
+               or _credential_is_for_host(cred, credential_requirements)
+           )
         ),
         None,
     )
@@ -318,19 +330,10 @@ async def match_user_credentials_to_graph(
 
 
 def _credential_has_required_scopes(
-    credential: Credentials,
+    credential: OAuth2Credentials,
     requirements: CredentialsFieldInfo,
 ) -> bool:
-    """
-    Check if a credential has all the scopes required by the block.
-
-    For OAuth2 credentials, verifies that the credential's scopes are a superset
-    of the required scopes. For other credential types, returns True (no scope check).
-    """
-    # Only OAuth2 credentials have scopes to check
-    if credential.type != "oauth2":
-        return True
-
+    """Check if an OAuth2 credential has all the scopes required by the input."""
     # If no scopes are required, any credential matches
     if not requirements.required_scopes:
         return True
@@ -339,6 +342,22 @@ def _credential_has_required_scopes(
     return set(credential.scopes).issuperset(requirements.required_scopes)
 
 
+def _credential_is_for_host(
+    credential: HostScopedCredentials,
+    requirements: CredentialsFieldInfo,
+) -> bool:
+    """Check if a host-scoped credential matches the host required by the input."""
+    # We need to know the host to match host-scoped credentials to.
+    # Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
+    # to discriminator_values. No discriminator_values -> no host to match against.
+    if not requirements.discriminator_values:
+        return True
+
+    # Check that credential host matches required host.
+    # Host-scoped credential inputs are grouped by host, so any item from the set works.
+    return credential.matches_url(list(requirements.discriminator_values)[0])
+
+
 async def check_user_has_required_credentials(
     user_id: str,
     required_credentials: list[CredentialsMetaInput],
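Note: the matcher above gates each predicate by credential type, so _credential_has_required_scopes only ever sees OAuth2 credentials and _credential_is_for_host only host-scoped ones. The same short-circuit pattern in isolation, with toy stand-ins for the real credential classes:

from dataclasses import dataclass, field


@dataclass
class Cred:
    type: str
    scopes: set[str] = field(default_factory=set)
    host: str = ""


def matches(cred: Cred, required_scopes: set[str], required_host: str) -> bool:
    # Each predicate fires only for the credential type it understands;
    # other types (e.g. a plain api_key) pass through untouched.
    return (cred.type != "oauth2" or cred.scopes >= required_scopes) and (
        cred.type != "host_scoped" or cred.host == required_host
    )


assert matches(Cred("api_key"), {"repo"}, "api.example.com")
assert not matches(Cred("oauth2", scopes={"read"}), {"read", "write"}, "")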
@@ -19,7 +19,10 @@ from backend.data.graph import GraphSettings
 from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
 from backend.data.model import CredentialsMetaInput
 from backend.integrations.creds_manager import IntegrationCredentialsManager
-from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
+from backend.integrations.webhooks.graph_lifecycle_hooks import (
+    on_graph_activate,
+    on_graph_deactivate,
+)
 from backend.util.clients import get_scheduler_client
 from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
 from backend.util.json import SafeJson
@@ -537,6 +540,92 @@ async def update_agent_version_in_library(
     return library_model.LibraryAgent.from_db(lib)
 
 
+async def create_graph_in_library(
+    graph: graph_db.Graph,
+    user_id: str,
+) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
+    """Create a new graph and add it to the user's library."""
+    graph.version = 1
+    graph_model = graph_db.make_graph_model(graph, user_id)
+    graph_model.reassign_ids(user_id=user_id, reassign_graph_id=True)
+
+    created_graph = await graph_db.create_graph(graph_model, user_id)
+
+    library_agents = await create_library_agent(
+        graph=created_graph,
+        user_id=user_id,
+        sensitive_action_safe_mode=True,
+        create_library_agents_for_sub_graphs=False,
+    )
+
+    if created_graph.is_active:
+        created_graph = await on_graph_activate(created_graph, user_id=user_id)
+
+    return created_graph, library_agents[0]
+
+
+async def update_graph_in_library(
+    graph: graph_db.Graph,
+    user_id: str,
+) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
+    """Create a new version of an existing graph and update the library entry."""
+    existing_versions = await graph_db.get_graph_all_versions(graph.id, user_id)
+    current_active_version = (
+        next((v for v in existing_versions if v.is_active), None)
+        if existing_versions
+        else None
+    )
+    graph.version = (
+        max(v.version for v in existing_versions) + 1 if existing_versions else 1
+    )
+
+    graph_model = graph_db.make_graph_model(graph, user_id)
+    graph_model.reassign_ids(user_id=user_id, reassign_graph_id=False)
+
+    created_graph = await graph_db.create_graph(graph_model, user_id)
+
+    library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
+    if not library_agent:
+        raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
+
+    library_agent = await update_library_agent_version_and_settings(
+        user_id, created_graph
+    )
+
+    if created_graph.is_active:
+        created_graph = await on_graph_activate(created_graph, user_id=user_id)
+        await graph_db.set_graph_active_version(
+            graph_id=created_graph.id,
+            version=created_graph.version,
+            user_id=user_id,
+        )
+        if current_active_version:
+            await on_graph_deactivate(current_active_version, user_id=user_id)
+
+    return created_graph, library_agent
+
+
+async def update_library_agent_version_and_settings(
+    user_id: str, agent_graph: graph_db.GraphModel
+) -> library_model.LibraryAgent:
+    """Update library agent to point to new graph version and sync settings."""
+    library = await update_agent_version_in_library(
+        user_id, agent_graph.id, agent_graph.version
+    )
+    updated_settings = GraphSettings.from_graph(
+        graph=agent_graph,
+        hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
+        sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
+    )
+    if updated_settings != library.settings:
+        library = await update_library_agent(
+            library_agent_id=library.id,
+            user_id=user_id,
+            settings=updated_settings,
+        )
+    return library
+
+
 async def update_library_agent(
     library_agent_id: str,
     user_id: str,
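Note: update_graph_in_library derives the next version from what is stored rather than trusting the incoming graph, so concurrent editors cannot silently reuse a number. The arithmetic in isolation:

def next_version(existing_versions: list[int]) -> int:
    # New graphs start at 1; updates always go one past the highest
    # stored version, even if older versions are inactive.
    return max(existing_versions) + 1 if existing_versions else 1


assert next_version([]) == 1
assert next_version([1, 2, 5]) == 6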
@@ -454,6 +454,9 @@ async def test_unified_hybrid_search_pagination(
     cleanup_embeddings: list,
 ):
     """Test unified search pagination works correctly."""
+    # Use a unique search term to avoid matching other test data
+    unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}"
+
     # Create multiple items
     content_ids = []
     for i in range(5):
@@ -465,14 +468,14 @@ async def test_unified_hybrid_search_pagination(
             content_type=ContentType.BLOCK,
             content_id=content_id,
             embedding=mock_embedding,
-            searchable_text=f"pagination test item number {i}",
+            searchable_text=f"{unique_term} item number {i}",
             metadata={"index": i},
             user_id=None,
         )
 
     # Get first page
     page1_results, total1 = await unified_hybrid_search(
-        query="pagination test",
+        query=unique_term,
         content_types=[ContentType.BLOCK],
         page=1,
         page_size=2,
@@ -480,7 +483,7 @@ async def test_unified_hybrid_search_pagination(
 
     # Get second page
     page2_results, total2 = await unified_hybrid_search(
-        query="pagination test",
+        query=unique_term,
         content_types=[ContentType.BLOCK],
         page=2,
         page_size=2,

@@ -101,7 +101,6 @@ from backend.util.timezone_utils import (
 from backend.util.virus_scanner import scan_content_safe
 
 from .library import db as library_db
-from .library import model as library_model
 from .store.model import StoreAgentDetails
 
 
@@ -823,18 +822,16 @@ async def update_graph(
     graph: graph_db.Graph,
     user_id: Annotated[str, Security(get_user_id)],
 ) -> graph_db.GraphModel:
-    # Sanity check
     if graph.id and graph.id != graph_id:
         raise HTTPException(400, detail="Graph ID does not match ID in URI")
 
-    # Determine new version
     existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
     if not existing_versions:
         raise HTTPException(404, detail=f"Graph #{graph_id} not found")
-    latest_version_number = max(g.version for g in existing_versions)
-    graph.version = latest_version_number + 1
+
+    graph.version = max(g.version for g in existing_versions) + 1
     current_active_version = next((v for v in existing_versions if v.is_active), None)
 
     graph = graph_db.make_graph_model(graph, user_id)
     graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
     graph.validate_graph(for_run=False)
@@ -842,27 +839,23 @@ async def update_graph(
     new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
 
     if new_graph_version.is_active:
-        # Keep the library agent up to date with the new active version
-        await _update_library_agent_version_and_settings(user_id, new_graph_version)
-        # Handle activation of the new graph first to ensure continuity
+        await library_db.update_library_agent_version_and_settings(
+            user_id, new_graph_version
+        )
         new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
-        # Ensure new version is the only active version
         await graph_db.set_graph_active_version(
             graph_id=graph_id, version=new_graph_version.version, user_id=user_id
         )
         if current_active_version:
-            # Handle deactivation of the previously active version
             await on_graph_deactivate(current_active_version, user_id=user_id)
 
-    # Fetch new graph version *with sub-graphs* (needed for credentials input schema)
     new_graph_version_with_subgraphs = await graph_db.get_graph(
         graph_id,
         new_graph_version.version,
         user_id=user_id,
         include_subgraphs=True,
     )
-    assert new_graph_version_with_subgraphs  # make type checker happy
+    assert new_graph_version_with_subgraphs
     return new_graph_version_with_subgraphs
 
 
@@ -900,33 +893,15 @@ async def set_graph_active_version(
     )
 
     # Keep the library agent up to date with the new active version
-    await _update_library_agent_version_and_settings(user_id, new_active_graph)
+    await library_db.update_library_agent_version_and_settings(
+        user_id, new_active_graph
+    )
 
     if current_active_graph and current_active_graph.version != new_active_version:
         # Handle deactivation of the previously active version
         await on_graph_deactivate(current_active_graph, user_id=user_id)
 
 
-async def _update_library_agent_version_and_settings(
-    user_id: str, agent_graph: graph_db.GraphModel
-) -> library_model.LibraryAgent:
-    library = await library_db.update_agent_version_in_library(
-        user_id, agent_graph.id, agent_graph.version
-    )
-    updated_settings = GraphSettings.from_graph(
-        graph=agent_graph,
-        hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
-        sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
-    )
-    if updated_settings != library.settings:
-        library = await library_db.update_library_agent(
-            library_agent_id=library.id,
-            user_id=user_id,
-            settings=updated_settings,
-        )
-    return library
-
-
 @v1_router.patch(
     path="/graphs/{graph_id}/settings",
     summary="Update graph settings",
|||||||
@@ -40,6 +40,10 @@ import backend.data.user
|
|||||||
import backend.integrations.webhooks.utils
|
import backend.integrations.webhooks.utils
|
||||||
import backend.util.service
|
import backend.util.service
|
||||||
import backend.util.settings
|
import backend.util.settings
|
||||||
|
from backend.api.features.chat.completion_consumer import (
|
||||||
|
start_completion_consumer,
|
||||||
|
stop_completion_consumer,
|
||||||
|
)
|
||||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||||
from backend.data.model import Credentials
|
from backend.data.model import Credentials
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
@@ -118,9 +122,21 @@ async def lifespan_context(app: fastapi.FastAPI):
|
|||||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||||
|
|
||||||
|
# Start chat completion consumer for Redis Streams notifications
|
||||||
|
try:
|
||||||
|
await start_completion_consumer()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not start chat completion consumer: {e}")
|
||||||
|
|
||||||
with launch_darkly_context():
|
with launch_darkly_context():
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
# Stop chat completion consumer
|
||||||
|
try:
|
||||||
|
await stop_completion_consumer()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error stopping chat completion consumer: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await shutdown_cloud_storage_handler()
|
await shutdown_cloud_storage_handler()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
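Note: start_completion_consumer/stop_completion_consumer wrap a consumer-group reader over the completion stream. A generic sketch of such a loop with redis-py's asyncio client; the stream, group, and handler names here are placeholders, not the ones the real consumer uses:

import asyncio

import redis.asyncio as redis
from redis.exceptions import ResponseError

STREAM, GROUP, CONSUMER = "chat:completions", "chat-consumers", "api-1"


async def consume(stop: asyncio.Event) -> None:
    r = redis.Redis()
    try:
        await r.xgroup_create(STREAM, GROUP, id="$", mkstream=True)
    except ResponseError:
        pass  # group already exists
    while not stop.is_set():
        # Block up to 5s waiting for new, unclaimed messages (">").
        entries = await r.xreadgroup(
            GROUP, CONSUMER, {STREAM: ">"}, count=10, block=5000
        )
        for _stream, messages in entries or []:
            for msg_id, fields in messages:
                handle_completion(fields)  # placeholder for the real handler
                await r.xack(STREAM, GROUP, msg_id)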
autogpt_platform/backend/backend/blocks/elevenlabs/_auth.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+"""ElevenLabs integration blocks - test credentials and shared utilities."""
+
+from typing import Literal
+
+from pydantic import SecretStr
+
+from backend.data.model import APIKeyCredentials, CredentialsMetaInput
+from backend.integrations.providers import ProviderName
+
+TEST_CREDENTIALS = APIKeyCredentials(
+    id="01234567-89ab-cdef-0123-456789abcdef",
+    provider="elevenlabs",
+    api_key=SecretStr("mock-elevenlabs-api-key"),
+    title="Mock ElevenLabs API key",
+    expires_at=None,
+)
+
+TEST_CREDENTIALS_INPUT = {
+    "provider": TEST_CREDENTIALS.provider,
+    "id": TEST_CREDENTIALS.id,
+    "type": TEST_CREDENTIALS.type,
+    "title": TEST_CREDENTIALS.title,
+}
+
+ElevenLabsCredentials = APIKeyCredentials
+ElevenLabsCredentialsInput = CredentialsMetaInput[
+    Literal[ProviderName.ELEVENLABS], Literal["api_key"]
+]
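Note: these aliases follow the pattern of the other provider _auth.py modules, where a block types its credentials input with the alias. A sketch of a hypothetical block input consuming them — the field helpers are assumed to match the ones used by other blocks in this diff:

from backend.data.block import BlockSchemaInput
from backend.data.model import CredentialsField, SchemaField

from ._auth import ElevenLabsCredentialsInput


class Input(BlockSchemaInput):
    # The typed alias pins provider and credential type for the UI.
    credentials: ElevenLabsCredentialsInput = CredentialsField(
        description="ElevenLabs API key used for text-to-speech calls",
    )
    text: str = SchemaField(description="Text to synthesize")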
autogpt_platform/backend/backend/blocks/encoder_block.py (new file, 77 lines)
@@ -0,0 +1,77 @@
"""Text encoding block for converting special characters to escape sequences."""

import codecs

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.model import SchemaField


class TextEncoderBlock(Block):
    """
    Encodes a string by converting special characters into escape sequences.

    This block is the inverse of TextDecoderBlock. It takes text containing
    special characters (like newlines, tabs, etc.) and converts them into
    their escape sequence representations (e.g., newline becomes \\n).
    """

    class Input(BlockSchemaInput):
        """Input schema for TextEncoderBlock."""

        text: str = SchemaField(
            description="A string containing special characters to be encoded",
            placeholder="Your text with newlines and quotes to encode",
        )

    class Output(BlockSchemaOutput):
        """Output schema for TextEncoderBlock."""

        encoded_text: str = SchemaField(
            description="The encoded text with special characters converted to escape sequences"
        )
        error: str = SchemaField(description="Error message if encoding fails")

    def __init__(self):
        super().__init__(
            id="5185f32e-4b65-4ecf-8fbb-873f003f09d6",
            description="Encodes a string by converting special characters into escape sequences",
            categories={BlockCategory.TEXT},
            input_schema=TextEncoderBlock.Input,
            output_schema=TextEncoderBlock.Output,
            test_input={
                "text": """Hello
World!
This is a "quoted" string."""
            },
            test_output=[
                (
                    "encoded_text",
                    """Hello\\nWorld!\\nThis is a "quoted" string.""",
                )
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        """
        Encode the input text by converting special characters to escape sequences.

        Args:
            input_data: The input containing the text to encode.
            **kwargs: Additional keyword arguments (unused).

        Yields:
            The encoded text with escape sequences, or an error message if encoding fails.
        """
        try:
            encoded_text = codecs.encode(input_data.text, "unicode_escape").decode(
                "utf-8"
            )
            yield "encoded_text", encoded_text
        except Exception as e:
            yield "error", f"Encoding error: {str(e)}"
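The whole block rests on one stdlib codec. A minimal round-trip sketch, which also shows why the docstring calls this the inverse of TextDecoderBlock:

import codecs

original = 'Hello\nWorld!\tThis is a "quoted" string.'
encoded = codecs.encode(original, "unicode_escape").decode("utf-8")
print(encoded)  # Hello\nWorld!\tThis is a "quoted" string.  (literal backslashes)
assert codecs.decode(encoded, "unicode_escape") == original  # decoder reverses it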
@@ -162,8 +162,16 @@ class LinearClient:
                 "searchTerm": team_name,
             }

-            team_id = await self.query(query, variables)
-            return team_id["teams"]["nodes"][0]["id"]
+            result = await self.query(query, variables)
+            nodes = result["teams"]["nodes"]
+
+            if not nodes:
+                raise LinearAPIException(
+                    f"Team '{team_name}' not found. Check the team name or key and try again.",
+                    status_code=404,
+                )
+
+            return nodes[0]["id"]
         except LinearAPIException as e:
             raise e

@@ -240,17 +248,44 @@ class LinearClient:
         except LinearAPIException as e:
             raise e

-    async def try_search_issues(self, term: str) -> list[Issue]:
+    async def try_search_issues(
+        self,
+        term: str,
+        max_results: int = 10,
+        team_id: str | None = None,
+    ) -> list[Issue]:
         try:
             query = """
-                query SearchIssues($term: String!, $includeComments: Boolean!) {
-                    searchIssues(term: $term, includeComments: $includeComments) {
+                query SearchIssues(
+                    $term: String!,
+                    $first: Int,
+                    $teamId: String
+                ) {
+                    searchIssues(
+                        term: $term,
+                        first: $first,
+                        teamId: $teamId
+                    ) {
                         nodes {
                             id
                             identifier
                             title
                             description
                             priority
+                            createdAt
+                            state {
+                                id
+                                name
+                                type
+                            }
+                            project {
+                                id
+                                name
+                            }
+                            assignee {
+                                id
+                                name
+                            }
                         }
                     }
                 }

@@ -258,7 +293,8 @@ class LinearClient:

             variables: dict[str, Any] = {
                 "term": term,
-                "includeComments": True,
+                "first": max_results,
+                "teamId": team_id,
             }

             issues = await self.query(query, variables)
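For orientation, a standalone sketch of what this query looks like on the wire. The endpoint and plain-API-key Authorization header follow Linear's public GraphQL docs, and httpx is used here only for illustration; the project routes requests through its own LinearClient.query:

import httpx

query = """
query SearchIssues($term: String!, $first: Int, $teamId: String) {
  searchIssues(term: $term, first: $first, teamId: $teamId) {
    nodes { id identifier title priority }
  }
}
"""
resp = httpx.post(
    "https://api.linear.app/graphql",
    json={"query": query, "variables": {"term": "login bug", "first": 10, "teamId": None}},
    headers={"Authorization": "<LINEAR_API_KEY>"},  # placeholder key
)
nodes = resp.json()["data"]["searchIssues"]["nodes"]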
@@ -17,7 +17,7 @@ from ._config import (
     LinearScope,
     linear,
 )
-from .models import CreateIssueResponse, Issue
+from .models import CreateIssueResponse, Issue, State


 class LinearCreateIssueBlock(Block):
@@ -135,9 +135,20 @@ class LinearSearchIssuesBlock(Block):
             description="Linear credentials with read permissions",
             required_scopes={LinearScope.READ},
         )
+        max_results: int = SchemaField(
+            description="Maximum number of results to return",
+            default=10,
+            ge=1,
+            le=100,
+        )
+        team_name: str | None = SchemaField(
+            description="Optional team name to filter results (e.g., 'Internal', 'Open Source')",
+            default=None,
+        )

     class Output(BlockSchemaOutput):
         issues: list[Issue] = SchemaField(description="List of issues")
+        error: str = SchemaField(description="Error message if the search failed")

     def __init__(self):
         super().__init__(

@@ -145,8 +156,11 @@ class LinearSearchIssuesBlock(Block):
             description="Searches for issues on Linear",
             input_schema=self.Input,
             output_schema=self.Output,
+            categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
             test_input={
                 "term": "Test issue",
+                "max_results": 10,
+                "team_name": None,
                 "credentials": TEST_CREDENTIALS_INPUT_OAUTH,
             },
             test_credentials=TEST_CREDENTIALS_OAUTH,

@@ -156,10 +170,14 @@ class LinearSearchIssuesBlock(Block):
                 [
                     Issue(
                         id="abc123",
-                        identifier="abc123",
+                        identifier="TST-123",
                         title="Test issue",
                         description="Test description",
                         priority=1,
+                        state=State(
+                            id="state1", name="In Progress", type="started"
+                        ),
+                        createdAt="2026-01-15T10:00:00.000Z",
                     )
                 ],
             )

@@ -168,10 +186,12 @@ class LinearSearchIssuesBlock(Block):
                 "search_issues": lambda *args, **kwargs: [
                     Issue(
                         id="abc123",
-                        identifier="abc123",
+                        identifier="TST-123",
                         title="Test issue",
                         description="Test description",
                         priority=1,
+                        state=State(id="state1", name="In Progress", type="started"),
+                        createdAt="2026-01-15T10:00:00.000Z",
                     )
                 ]
             },

@@ -181,10 +201,22 @@ class LinearSearchIssuesBlock(Block):
     async def search_issues(
         credentials: OAuth2Credentials | APIKeyCredentials,
         term: str,
+        max_results: int = 10,
+        team_name: str | None = None,
     ) -> list[Issue]:
         client = LinearClient(credentials=credentials)
-        response: list[Issue] = await client.try_search_issues(term=term)
-        return response
+        # Resolve team name to ID if provided
+        # Raises LinearAPIException with descriptive message if team not found
+        team_id: str | None = None
+        if team_name:
+            team_id = await client.try_get_team_by_name(team_name=team_name)
+
+        return await client.try_search_issues(
+            term=term,
+            max_results=max_results,
+            team_id=team_id,
+        )

     async def run(
         self,

@@ -196,7 +228,10 @@ class LinearSearchIssuesBlock(Block):
         """Execute the issue search"""
         try:
             issues = await self.search_issues(
-                credentials=credentials, term=input_data.term
+                credentials=credentials,
+                term=input_data.term,
+                max_results=input_data.max_results,
+                team_name=input_data.team_name,
             )
             yield "issues", issues
         except LinearAPIException as e:
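Taken together, the changed helpers give a two-round-trip flow: resolve the human-readable team name to an ID first, then scope the search. A hypothetical driver, assuming client is a LinearClient built from credentials as in the block above:

# inside an async function, e.g. the block's run()
team_id = await client.try_get_team_by_name(team_name="Internal")  # raises 404 if missing
issues = await client.try_search_issues(term="login bug", max_results=5, team_id=team_id)
for issue in issues:
    print(issue.identifier, issue.title, issue.state.name if issue.state else "-")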
@@ -36,12 +36,21 @@ class Project(BaseModel):
     content: str | None = None


+class State(BaseModel):
+    id: str
+    name: str
+    type: str | None = (
+        None  # Workflow state type (e.g., "triage", "backlog", "started", "completed", "canceled")
+    )
+
+
 class Issue(BaseModel):
     id: str
     identifier: str
     title: str
     description: str | None
     priority: int
+    state: State | None = None
     project: Project | None = None
     createdAt: str | None = None
     comments: list[Comment] | None = None
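Since these are Pydantic models, a searchIssues node carrying the new nested state object should validate directly into Issue/State. A quick sketch; the import path is an assumption inferred from the relative "from .models import ..." above:

from backend.blocks.linear.models import Issue  # assumed module path

issue = Issue.model_validate(
    {
        "id": "abc123",
        "identifier": "TST-123",
        "title": "Test issue",
        "description": None,
        "priority": 1,
        "state": {"id": "state1", "name": "In Progress", "type": "started"},
        "createdAt": "2026-01-15T10:00:00.000Z",
    }
)
assert issue.state is not None and issue.state.type == "started"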
@@ -115,6 +115,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
     CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
     CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
     CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
+    CLAUDE_4_6_OPUS = "claude-opus-4-6"
     CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
     # AI/ML API models
     AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"

@@ -270,6 +271,9 @@ MODEL_METADATA = {
     LlmModel.CLAUDE_4_SONNET: ModelMetadata(
         "anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
     ),  # claude-4-sonnet-20250514
+    LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
+        "anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
+    ),  # claude-opus-4-6
     LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
         "anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
     ),  # claude-opus-4-5-20251101
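Because LlmModel is a str-valued enum, the new model can be looked up from its wire string and used to key MODEL_METADATA. A minimal sanity sketch, evaluated in the llm module's own scope (field meanings follow ModelMetadata's positional definition elsewhere in that file):

model = LlmModel("claude-opus-4-6")
assert model is LlmModel.CLAUDE_4_6_OPUS
meta = MODEL_METADATA[model]  # ("anthropic", 200000, 128000, "Claude Opus 4.6", ...)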
@@ -1,246 +0,0 @@ (entire file removed)
import os
import tempfile
from typing import Optional

from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.fx.Loop import Loop
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class MediaDurationBlock(Block):

    class Input(BlockSchemaInput):
        media_in: MediaFileType = SchemaField(
            description="Media input (URL, data URI, or local path)."
        )
        is_video: bool = SchemaField(
            description="Whether the media is a video (True) or audio (False).",
            default=True,
        )

    class Output(BlockSchemaOutput):
        duration: float = SchemaField(
            description="Duration of the media file (in seconds)."
        )

    def __init__(self):
        super().__init__(
            id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
            description="Block to get the duration of a media file.",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=MediaDurationBlock.Input,
            output_schema=MediaDurationBlock.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        **kwargs,
    ) -> BlockOutput:
        # 1) Store the input media locally
        local_media_path = await store_media_file(
            file=input_data.media_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        assert execution_context.graph_exec_id is not None
        media_abspath = get_exec_file_path(
            execution_context.graph_exec_id, local_media_path
        )

        # 2) Load the clip
        if input_data.is_video:
            clip = VideoFileClip(media_abspath)
        else:
            clip = AudioFileClip(media_abspath)

        yield "duration", clip.duration


class LoopVideoBlock(Block):
    """
    Block for looping (repeating) a video clip until a given duration or number of loops.
    """

    class Input(BlockSchemaInput):
        video_in: MediaFileType = SchemaField(
            description="The input video (can be a URL, data URI, or local path)."
        )
        # Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`.
        duration: Optional[float] = SchemaField(
            description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.",
            default=None,
            ge=0.0,
        )
        n_loops: Optional[int] = SchemaField(
            description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).",
            default=None,
            ge=1,
        )

    class Output(BlockSchemaOutput):
        video_out: str = SchemaField(
            description="Looped video returned either as a relative path or a data URI."
        )

    def __init__(self):
        super().__init__(
            id="8bf9eef6-5451-4213-b265-25306446e94b",
            description="Block to loop a video to a given duration or number of repeats.",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=LoopVideoBlock.Input,
            output_schema=LoopVideoBlock.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        **kwargs,
    ) -> BlockOutput:
        assert execution_context.graph_exec_id is not None
        assert execution_context.node_exec_id is not None
        graph_exec_id = execution_context.graph_exec_id
        node_exec_id = execution_context.node_exec_id

        # 1) Store the input video locally
        local_video_path = await store_media_file(
            file=input_data.video_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        input_abspath = get_exec_file_path(graph_exec_id, local_video_path)

        # 2) Load the clip
        clip = VideoFileClip(input_abspath)

        # 3) Apply the loop effect
        looped_clip = clip
        if input_data.duration:
            # Loop until we reach the specified duration
            looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)])
        elif input_data.n_loops:
            looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)])
        else:
            raise ValueError("Either 'duration' or 'n_loops' must be provided.")

        assert isinstance(looped_clip, VideoFileClip)

        # 4) Save the looped output
        output_filename = MediaFileType(
            f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
        )
        output_abspath = get_exec_file_path(graph_exec_id, output_filename)

        looped_clip = looped_clip.with_audio(clip.audio)
        looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")

        # Return output - for_block_output returns workspace:// if available, else data URI
        video_out = await store_media_file(
            file=output_filename,
            execution_context=execution_context,
            return_format="for_block_output",
        )

        yield "video_out", video_out


class AddAudioToVideoBlock(Block):
    """
    Block that adds (attaches) an audio track to an existing video.
    Optionally scale the volume of the new track.
    """

    class Input(BlockSchemaInput):
        video_in: MediaFileType = SchemaField(
            description="Video input (URL, data URI, or local path)."
        )
        audio_in: MediaFileType = SchemaField(
            description="Audio input (URL, data URI, or local path)."
        )
        volume: float = SchemaField(
            description="Volume scale for the newly attached audio track (1.0 = original).",
            default=1.0,
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(
            description="Final video (with attached audio), as a path or data URI."
        )

    def __init__(self):
        super().__init__(
            id="3503748d-62b6-4425-91d6-725b064af509",
            description="Block to attach an audio file to a video file using moviepy.",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=AddAudioToVideoBlock.Input,
            output_schema=AddAudioToVideoBlock.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        **kwargs,
    ) -> BlockOutput:
        assert execution_context.graph_exec_id is not None
        assert execution_context.node_exec_id is not None
        graph_exec_id = execution_context.graph_exec_id
        node_exec_id = execution_context.node_exec_id

        # 1) Store the inputs locally
        local_video_path = await store_media_file(
            file=input_data.video_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        local_audio_path = await store_media_file(
            file=input_data.audio_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )

        abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
        video_abspath = os.path.join(abs_temp_dir, local_video_path)
        audio_abspath = os.path.join(abs_temp_dir, local_audio_path)

        # 2) Load video + audio with moviepy
        video_clip = VideoFileClip(video_abspath)
        audio_clip = AudioFileClip(audio_abspath)
        # Optionally scale volume
        if input_data.volume != 1.0:
            audio_clip = audio_clip.with_volume_scaled(input_data.volume)

        # 3) Attach the new audio track
        final_clip = video_clip.with_audio(audio_clip)

        # 4) Write to output file
        output_filename = MediaFileType(
            f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
        )
        output_abspath = os.path.join(abs_temp_dir, output_filename)
        final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")

        # 5) Return output - for_block_output returns workspace:// if available, else data URI
        video_out = await store_media_file(
            file=output_filename,
            execution_context=execution_context,
            return_format="for_block_output",
        )

        yield "video_out", video_out
@@ -182,10 +182,7 @@ class StagehandObserveBlock(Block):
         **kwargs,
     ) -> BlockOutput:

-        logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}")
-        logger.info(
-            f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
-        )
+        logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")

         with disable_signal_handling():
             stagehand = Stagehand(

@@ -282,10 +279,7 @@ class StagehandActBlock(Block):
         **kwargs,
     ) -> BlockOutput:

-        logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}")
-        logger.info(
-            f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
-        )
+        logger.debug(f"ACT: Using model provider {model_credentials.provider}")

         with disable_signal_handling():
             stagehand = Stagehand(

@@ -370,10 +364,7 @@ class StagehandExtractBlock(Block):
         **kwargs,
     ) -> BlockOutput:

-        logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}")
-        logger.info(
-            f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
-        )
+        logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")

         with disable_signal_handling():
             stagehand = Stagehand(
@@ -0,0 +1,77 @@
import pytest

from backend.blocks.encoder_block import TextEncoderBlock


@pytest.mark.asyncio
async def test_text_encoder_basic():
    """Test basic encoding of newlines and special characters."""
    block = TextEncoderBlock()
    result = []
    async for output in block.run(TextEncoderBlock.Input(text="Hello\nWorld")):
        result.append(output)

    assert len(result) == 1
    assert result[0][0] == "encoded_text"
    assert result[0][1] == "Hello\\nWorld"


@pytest.mark.asyncio
async def test_text_encoder_multiple_escapes():
    """Test encoding of multiple escape sequences."""
    block = TextEncoderBlock()
    result = []
    async for output in block.run(
        TextEncoderBlock.Input(text="Line1\nLine2\tTabbed\rCarriage")
    ):
        result.append(output)

    assert len(result) == 1
    assert result[0][0] == "encoded_text"
    assert "\\n" in result[0][1]
    assert "\\t" in result[0][1]
    assert "\\r" in result[0][1]


@pytest.mark.asyncio
async def test_text_encoder_unicode():
    """Test that unicode characters are handled correctly."""
    block = TextEncoderBlock()
    result = []
    async for output in block.run(TextEncoderBlock.Input(text="Hello 世界\n")):
        result.append(output)

    assert len(result) == 1
    assert result[0][0] == "encoded_text"
    # Unicode characters should be escaped as \uXXXX sequences
    assert "\\n" in result[0][1]


@pytest.mark.asyncio
async def test_text_encoder_empty_string():
    """Test encoding of an empty string."""
    block = TextEncoderBlock()
    result = []
    async for output in block.run(TextEncoderBlock.Input(text="")):
        result.append(output)

    assert len(result) == 1
    assert result[0][0] == "encoded_text"
    assert result[0][1] == ""


@pytest.mark.asyncio
async def test_text_encoder_error_handling():
    """Test that encoding errors are handled gracefully."""
    from unittest.mock import patch

    block = TextEncoderBlock()
    result = []

    with patch("codecs.encode", side_effect=Exception("Mocked encoding error")):
        async for output in block.run(TextEncoderBlock.Input(text="test")):
            result.append(output)

    assert len(result) == 1
    assert result[0][0] == "error"
    assert "Mocked encoding error" in result[0][1]
autogpt_platform/backend/backend/blocks/video/__init__.py (new file, 37 lines)
@@ -0,0 +1,37 @@
"""Video editing blocks for AutoGPT Platform.

This module provides blocks for:
- Downloading videos from URLs (YouTube, Vimeo, news sites, direct links)
- Clipping/trimming video segments
- Concatenating multiple videos
- Adding text overlays
- Adding AI-generated narration
- Getting media duration
- Looping videos
- Adding audio to videos

Dependencies:
- yt-dlp: For video downloading
- moviepy: For video editing operations
- elevenlabs: For AI narration (optional)
"""

from backend.blocks.video.add_audio import AddAudioToVideoBlock
from backend.blocks.video.clip import VideoClipBlock
from backend.blocks.video.concat import VideoConcatBlock
from backend.blocks.video.download import VideoDownloadBlock
from backend.blocks.video.duration import MediaDurationBlock
from backend.blocks.video.loop import LoopVideoBlock
from backend.blocks.video.narration import VideoNarrationBlock
from backend.blocks.video.text_overlay import VideoTextOverlayBlock

__all__ = [
    "AddAudioToVideoBlock",
    "LoopVideoBlock",
    "MediaDurationBlock",
    "VideoClipBlock",
    "VideoConcatBlock",
    "VideoDownloadBlock",
    "VideoNarrationBlock",
    "VideoTextOverlayBlock",
]
autogpt_platform/backend/backend/blocks/video/_utils.py (new file, 131 lines)
@@ -0,0 +1,131 @@
"""Shared utilities for video blocks."""

from __future__ import annotations

import logging
import os
import re
import subprocess
from pathlib import Path

logger = logging.getLogger(__name__)

# Known operation tags added by video blocks
_VIDEO_OPS = (
    r"(?:clip|overlay|narrated|looped|concat|audio_attached|with_audio|narration)"
)

# Matches: {node_exec_id}_{operation}_ where node_exec_id contains a UUID
_BLOCK_PREFIX_RE = re.compile(
    r"^[a-zA-Z0-9_-]*"
    r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
    r"[a-zA-Z0-9_-]*"
    r"_" + _VIDEO_OPS + r"_"
)

# Matches: a lone {node_exec_id}_ prefix (no operation keyword, e.g. download output)
_UUID_PREFIX_RE = re.compile(
    r"^[a-zA-Z0-9_-]*"
    r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
    r"[a-zA-Z0-9_-]*_"
)


def extract_source_name(input_path: str, max_length: int = 50) -> str:
    """Extract the original source filename by stripping block-generated prefixes.

    Iteratively removes {node_exec_id}_{operation}_ prefixes that accumulate
    when chaining video blocks, recovering the original human-readable name.

    Safe for plain filenames (no UUID -> no stripping).
    Falls back to "video" if everything is stripped.
    """
    stem = Path(input_path).stem

    # Pass 1: strip {node_exec_id}_{operation}_ prefixes iteratively
    while _BLOCK_PREFIX_RE.match(stem):
        stem = _BLOCK_PREFIX_RE.sub("", stem, count=1)

    # Pass 2: strip a lone {node_exec_id}_ prefix (e.g. from download block)
    if _UUID_PREFIX_RE.match(stem):
        stem = _UUID_PREFIX_RE.sub("", stem, count=1)

    if not stem:
        return "video"

    return stem[:max_length]


def get_video_codecs(output_path: str) -> tuple[str, str]:
    """Get appropriate video and audio codecs based on output file extension.

    Args:
        output_path: Path to the output file (used to determine extension)

    Returns:
        Tuple of (video_codec, audio_codec)

    Codec mappings:
    - .mp4: H.264 + AAC (universal compatibility)
    - .webm: VP8 + Vorbis (web streaming)
    - .mkv: H.264 + AAC (container supports many codecs)
    - .mov: H.264 + AAC (Apple QuickTime, widely compatible)
    - .m4v: H.264 + AAC (Apple iTunes/devices)
    - .avi: MPEG-4 + MP3 (legacy Windows)
    """
    ext = os.path.splitext(output_path)[1].lower()

    codec_map: dict[str, tuple[str, str]] = {
        ".mp4": ("libx264", "aac"),
        ".webm": ("libvpx", "libvorbis"),
        ".mkv": ("libx264", "aac"),
        ".mov": ("libx264", "aac"),
        ".m4v": ("libx264", "aac"),
        ".avi": ("mpeg4", "libmp3lame"),
    }

    return codec_map.get(ext, ("libx264", "aac"))


def strip_chapters_inplace(video_path: str) -> None:
    """Strip chapter metadata from a media file in-place using ffmpeg.

    MoviePy 2.x crashes with IndexError when parsing files with embedded
    chapter metadata (https://github.com/Zulko/moviepy/issues/2419).
    This strips chapters without re-encoding.

    Args:
        video_path: Absolute path to the media file to strip chapters from.
    """
    base, ext = os.path.splitext(video_path)
    tmp_path = base + ".tmp" + ext
    try:
        result = subprocess.run(
            [
                "ffmpeg",
                "-y",
                "-i",
                video_path,
                "-map_chapters",
                "-1",
                "-codec",
                "copy",
                tmp_path,
            ],
            capture_output=True,
            text=True,
            timeout=300,
        )
        if result.returncode != 0:
            logger.warning(
                "ffmpeg chapter strip failed (rc=%d): %s",
                result.returncode,
                result.stderr,
            )
            return
        os.replace(tmp_path, video_path)
    except FileNotFoundError:
        logger.warning("ffmpeg not found; skipping chapter strip")
    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
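A quick behavioral sketch of the helpers above (the UUID is a made-up node_exec_id; the expected values follow the regexes and codec table as written):

from backend.blocks.video._utils import extract_source_name, get_video_codecs

# One block-generated prefix gets stripped, recovering the source name.
name = "f47ac10b-58cc-4372-a567-0e02b2c3d479_clip_quarterly_report.mp4"
assert extract_source_name(name) == "quarterly_report"

# Plain filenames (no UUID) pass through untouched, truncated to max_length.
assert extract_source_name("holiday.mov") == "holiday"

# Codec selection keys off the extension, with an H.264/AAC fallback.
assert get_video_codecs("out.webm") == ("libvpx", "libvorbis")
assert get_video_codecs("out.unknown") == ("libx264", "aac")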
autogpt_platform/backend/backend/blocks/video/add_audio.py (new file, 113 lines)
@@ -0,0 +1,113 @@
"""AddAudioToVideoBlock - Attach an audio track to a video file."""

from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class AddAudioToVideoBlock(Block):
    """Add (attach) an audio track to an existing video."""

    class Input(BlockSchemaInput):
        video_in: MediaFileType = SchemaField(
            description="Video input (URL, data URI, or local path)."
        )
        audio_in: MediaFileType = SchemaField(
            description="Audio input (URL, data URI, or local path)."
        )
        volume: float = SchemaField(
            description="Volume scale for the newly attached audio track (1.0 = original).",
            default=1.0,
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(
            description="Final video (with attached audio), as a path or data URI."
        )

    def __init__(self):
        super().__init__(
            id="3503748d-62b6-4425-91d6-725b064af509",
            description="Block to attach an audio file to a video file using moviepy.",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=AddAudioToVideoBlock.Input,
            output_schema=AddAudioToVideoBlock.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        **kwargs,
    ) -> BlockOutput:
        assert execution_context.graph_exec_id is not None
        assert execution_context.node_exec_id is not None
        graph_exec_id = execution_context.graph_exec_id
        node_exec_id = execution_context.node_exec_id

        # 1) Store the inputs locally
        local_video_path = await store_media_file(
            file=input_data.video_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        local_audio_path = await store_media_file(
            file=input_data.audio_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )

        video_abspath = get_exec_file_path(graph_exec_id, local_video_path)
        audio_abspath = get_exec_file_path(graph_exec_id, local_audio_path)

        # 2) Load video + audio with moviepy
        strip_chapters_inplace(video_abspath)
        strip_chapters_inplace(audio_abspath)
        video_clip = None
        audio_clip = None
        final_clip = None
        try:
            video_clip = VideoFileClip(video_abspath)
            audio_clip = AudioFileClip(audio_abspath)
            # Optionally scale volume
            if input_data.volume != 1.0:
                audio_clip = audio_clip.with_volume_scaled(input_data.volume)

            # 3) Attach the new audio track
            final_clip = video_clip.with_audio(audio_clip)

            # 4) Write to output file
            source = extract_source_name(local_video_path)
            output_filename = MediaFileType(f"{node_exec_id}_with_audio_{source}.mp4")
            output_abspath = get_exec_file_path(graph_exec_id, output_filename)
            final_clip.write_videofile(
                output_abspath, codec="libx264", audio_codec="aac"
            )
        finally:
            if final_clip:
                final_clip.close()
            if audio_clip:
                audio_clip.close()
            if video_clip:
                video_clip.close()

        # 5) Return output - for_block_output returns workspace:// if available, else data URI
        video_out = await store_media_file(
            file=output_filename,
            execution_context=execution_context,
            return_format="for_block_output",
        )

        yield "video_out", video_out
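Outside the block machinery, the same MoviePy 2.x flow in miniature, assuming two local files; with_volume_scaled and with_audio are the 2.x-style calls used above:

from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.VideoFileClip import VideoFileClip

video = VideoFileClip("video.mp4")          # placeholder input paths
audio = AudioFileClip("voiceover.mp3").with_volume_scaled(0.8)
try:
    video.with_audio(audio).write_videofile(
        "with_audio.mp4", codec="libx264", audio_codec="aac"
    )
finally:
    audio.close()
    video.close()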
autogpt_platform/backend/backend/blocks/video/clip.py (new file, 167 lines)
@@ -0,0 +1,167 @@
"""VideoClipBlock - Extract a segment from a video file."""

from typing import Literal

from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.blocks.video._utils import (
    extract_source_name,
    get_video_codecs,
    strip_chapters_inplace,
)
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class VideoClipBlock(Block):
    """Extract a time segment from a video."""

    class Input(BlockSchemaInput):
        video_in: MediaFileType = SchemaField(
            description="Input video (URL, data URI, or local path)"
        )
        start_time: float = SchemaField(description="Start time in seconds", ge=0.0)
        end_time: float = SchemaField(description="End time in seconds", ge=0.0)
        output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
            description="Output format", default="mp4", advanced=True
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(
            description="Clipped video file (path or data URI)"
        )
        duration: float = SchemaField(description="Clip duration in seconds")

    def __init__(self):
        super().__init__(
            id="8f539119-e580-4d86-ad41-86fbcb22abb1",
            description="Extract a time segment from a video",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "video_in": "/tmp/test.mp4",
                "start_time": 0.0,
                "end_time": 10.0,
            },
            test_output=[("video_out", str), ("duration", float)],
            test_mock={
                "_clip_video": lambda *args: 10.0,
                "_store_input_video": lambda *args, **kwargs: "test.mp4",
                "_store_output_video": lambda *args, **kwargs: "clip_test.mp4",
            },
        )

    async def _store_input_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store input video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_local_processing",
        )

    async def _store_output_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store output video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_block_output",
        )

    def _clip_video(
        self,
        video_abspath: str,
        output_abspath: str,
        start_time: float,
        end_time: float,
    ) -> float:
        """Extract a clip from a video. Extracted for testability."""
        clip = None
        subclip = None
        try:
            strip_chapters_inplace(video_abspath)
            clip = VideoFileClip(video_abspath)
            subclip = clip.subclipped(start_time, end_time)
            video_codec, audio_codec = get_video_codecs(output_abspath)
            subclip.write_videofile(
                output_abspath, codec=video_codec, audio_codec=audio_codec
            )
            return subclip.duration
        finally:
            if subclip:
                subclip.close()
            if clip:
                clip.close()

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        node_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        # Validate time range
        if input_data.end_time <= input_data.start_time:
            raise BlockExecutionError(
                message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
                block_name=self.name,
                block_id=str(self.id),
            )

        try:
            assert execution_context.graph_exec_id is not None

            # Store the input video locally
            local_video_path = await self._store_input_video(
                execution_context, input_data.video_in
            )
            video_abspath = get_exec_file_path(
                execution_context.graph_exec_id, local_video_path
            )

            # Build output path
            source = extract_source_name(local_video_path)
            output_filename = MediaFileType(
                f"{node_exec_id}_clip_{source}.{input_data.output_format}"
            )
            output_abspath = get_exec_file_path(
                execution_context.graph_exec_id, output_filename
            )

            duration = self._clip_video(
                video_abspath,
                output_abspath,
                input_data.start_time,
                input_data.end_time,
            )

            # Return as workspace path or data URI based on context
            video_out = await self._store_output_video(
                execution_context, output_filename
            )

            yield "video_out", video_out
            yield "duration", duration

        except BlockExecutionError:
            raise
        except Exception as e:
            raise BlockExecutionError(
                message=f"Failed to clip video: {e}",
                block_name=self.name,
                block_id=str(self.id),
            ) from e
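The core of _clip_video is MoviePy 2.x's subclipped(), the rename of the older subclip(). A standalone sketch assuming a local input.mp4:

from moviepy.video.io.VideoFileClip import VideoFileClip

clip = VideoFileClip("input.mp4")  # placeholder input
try:
    segment = clip.subclipped(0.0, 10.0)  # start/end in seconds
    segment.write_videofile("clip.mp4", codec="libx264", audio_codec="aac")
    print(f"wrote {segment.duration:.1f}s clip")
    segment.close()
finally:
    clip.close()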
autogpt_platform/backend/backend/blocks/video/concat.py (new file, 227 lines)
@@ -0,0 +1,227 @@
"""VideoConcatBlock - Concatenate multiple video clips into one."""

from typing import Literal

from moviepy import concatenate_videoclips
from moviepy.video.fx import CrossFadeIn, CrossFadeOut, FadeIn, FadeOut
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.blocks.video._utils import (
    extract_source_name,
    get_video_codecs,
    strip_chapters_inplace,
)
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class VideoConcatBlock(Block):
    """Merge multiple video clips into one continuous video."""

    class Input(BlockSchemaInput):
        videos: list[MediaFileType] = SchemaField(
            description="List of video files to concatenate (in order)"
        )
        transition: Literal["none", "crossfade", "fade_black"] = SchemaField(
            description="Transition between clips", default="none"
        )
        transition_duration: int = SchemaField(
            description="Transition duration in seconds",
            default=1,
            ge=0,
            advanced=True,
        )
        output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
            description="Output format", default="mp4", advanced=True
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(
            description="Concatenated video file (path or data URI)"
        )
        total_duration: float = SchemaField(description="Total duration in seconds")

    def __init__(self):
        super().__init__(
            id="9b0f531a-1118-487f-aeec-3fa63ea8900a",
            description="Merge multiple video clips into one continuous video",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "videos": ["/tmp/a.mp4", "/tmp/b.mp4"],
            },
            test_output=[
                ("video_out", str),
                ("total_duration", float),
            ],
            test_mock={
                "_concat_videos": lambda *args: 20.0,
                "_store_input_video": lambda *args, **kwargs: "test.mp4",
                "_store_output_video": lambda *args, **kwargs: "concat_test.mp4",
            },
        )

    async def _store_input_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store input video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_local_processing",
        )

    async def _store_output_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store output video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_block_output",
        )

    def _concat_videos(
        self,
        video_abspaths: list[str],
        output_abspath: str,
        transition: str,
        transition_duration: int,
    ) -> float:
        """Concatenate videos. Extracted for testability.

        Returns:
            Total duration of the concatenated video.
        """
        clips = []
        faded_clips = []
        final = None
        try:
            # Load clips
            for v in video_abspaths:
                strip_chapters_inplace(v)
                clips.append(VideoFileClip(v))

            # Validate transition_duration against shortest clip
            if transition in {"crossfade", "fade_black"} and transition_duration > 0:
                min_duration = min(c.duration for c in clips)
                if transition_duration >= min_duration:
                    raise BlockExecutionError(
                        message=(
                            f"transition_duration ({transition_duration}s) must be "
                            f"shorter than the shortest clip ({min_duration:.2f}s)"
                        ),
                        block_name=self.name,
                        block_id=str(self.id),
                    )

            if transition == "crossfade":
                for i, clip in enumerate(clips):
                    effects = []
                    if i > 0:
                        effects.append(CrossFadeIn(transition_duration))
                    if i < len(clips) - 1:
                        effects.append(CrossFadeOut(transition_duration))
                    if effects:
                        clip = clip.with_effects(effects)
                    faded_clips.append(clip)
                final = concatenate_videoclips(
                    faded_clips,
                    method="compose",
                    padding=-transition_duration,
                )
            elif transition == "fade_black":
                for clip in clips:
                    faded = clip.with_effects(
                        [FadeIn(transition_duration), FadeOut(transition_duration)]
                    )
                    faded_clips.append(faded)
                final = concatenate_videoclips(faded_clips)
            else:
                final = concatenate_videoclips(clips)

            video_codec, audio_codec = get_video_codecs(output_abspath)
            final.write_videofile(
                output_abspath, codec=video_codec, audio_codec=audio_codec
            )

            return final.duration
        finally:
            if final:
                final.close()
            for clip in faded_clips:
                clip.close()
            for clip in clips:
                clip.close()

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        node_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        # Validate minimum clips
        if len(input_data.videos) < 2:
            raise BlockExecutionError(
                message="At least 2 videos are required for concatenation",
                block_name=self.name,
                block_id=str(self.id),
            )

        try:
            assert execution_context.graph_exec_id is not None

            # Store all input videos locally
            video_abspaths = []
            for video in input_data.videos:
                local_path = await self._store_input_video(execution_context, video)
                video_abspaths.append(
                    get_exec_file_path(execution_context.graph_exec_id, local_path)
                )

            # Build output path
            source = (
                extract_source_name(video_abspaths[0]) if video_abspaths else "video"
            )
            output_filename = MediaFileType(
                f"{node_exec_id}_concat_{source}.{input_data.output_format}"
            )
            output_abspath = get_exec_file_path(
                execution_context.graph_exec_id, output_filename
            )

            total_duration = self._concat_videos(
                video_abspaths,
                output_abspath,
                input_data.transition,
                input_data.transition_duration,
            )

            # Return as workspace path or data URI based on context
            video_out = await self._store_output_video(
                execution_context, output_filename
            )

            yield "video_out", video_out
            yield "total_duration", total_duration

        except BlockExecutionError:
            raise
        except Exception as e:
            raise BlockExecutionError(
                message=f"Failed to concatenate videos: {e}",
                block_name=self.name,
                block_id=str(self.id),
            ) from e
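A back-of-envelope check of the crossfade branch: concatenate_videoclips(..., method="compose", padding=-d) overlaps adjacent clips by d seconds, so the result should run one overlap shorter per joint:

durations = [12.0, 8.0, 10.0]  # three hypothetical clip lengths, seconds
d = 1.0                        # transition_duration
expected = sum(durations) - (len(durations) - 1) * d
print(expected)  # 28.0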
172
autogpt_platform/backend/backend/blocks/video/download.py
Normal file
172
autogpt_platform/backend/backend/blocks/video/download.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
"""VideoDownloadBlock - Download video from URL (YouTube, Vimeo, news sites, direct links)."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import typing
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import yt_dlp
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from yt_dlp import _Params
|
||||||
|
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
class VideoDownloadBlock(Block):
    """Download video from URL using yt-dlp."""

    class Input(BlockSchemaInput):
        url: str = SchemaField(
            description="URL of the video to download (YouTube, Vimeo, direct link, etc.)",
            placeholder="https://www.youtube.com/watch?v=...",
        )
        quality: Literal["best", "1080p", "720p", "480p", "audio_only"] = SchemaField(
            description="Video quality preference", default="720p"
        )
        output_format: Literal["mp4", "webm", "mkv"] = SchemaField(
            description="Output video format", default="mp4", advanced=True
        )

    class Output(BlockSchemaOutput):
        video_file: MediaFileType = SchemaField(
            description="Downloaded video (path or data URI)"
        )
        duration: float = SchemaField(description="Video duration in seconds")
        title: str = SchemaField(description="Video title from source")
        source_url: str = SchemaField(description="Original source URL")

    def __init__(self):
        super().__init__(
            id="c35daabb-cd60-493b-b9ad-51f1fe4b50c4",
            description="Download video from URL (YouTube, Vimeo, news sites, direct links)",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=self.Input,
            output_schema=self.Output,
            disabled=True,  # Disable until we can sandbox yt-dlp and handle security implications
            test_input={
                "url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
                "quality": "480p",
            },
            test_output=[
                ("video_file", str),
                ("duration", float),
                ("title", str),
                ("source_url", str),
            ],
            test_mock={
                "_download_video": lambda *args: (
                    "video.mp4",
                    212.0,
                    "Test Video",
                ),
                "_store_output_video": lambda *args, **kwargs: "video.mp4",
            },
        )

    async def _store_output_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store output video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_block_output",
        )

    def _get_format_string(self, quality: str) -> str:
        formats = {
            "best": "bestvideo+bestaudio/best",
            "1080p": "bestvideo[height<=1080]+bestaudio/best[height<=1080]",
            "720p": "bestvideo[height<=720]+bestaudio/best[height<=720]",
            "480p": "bestvideo[height<=480]+bestaudio/best[height<=480]",
            "audio_only": "bestaudio/best",
        }
        return formats.get(quality, formats["720p"])

    def _download_video(
        self,
        url: str,
        quality: str,
        output_format: str,
        output_dir: str,
        node_exec_id: str,
    ) -> tuple[str, float, str]:
        """Download video. Extracted for testability."""
        output_template = os.path.join(
            output_dir, f"{node_exec_id}_%(title).50s.%(ext)s"
        )

        ydl_opts: "_Params" = {
            "format": f"{self._get_format_string(quality)}/best",
            "outtmpl": output_template,
            "merge_output_format": output_format,
            "quiet": True,
            "no_warnings": True,
        }

        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(url, download=True)
            video_path = ydl.prepare_filename(info)

        # Handle format conversion in filename
        if not video_path.endswith(f".{output_format}"):
            video_path = video_path.rsplit(".", 1)[0] + f".{output_format}"

        # Return just the filename, not the full path
        filename = os.path.basename(video_path)

        return (
            filename,
            info.get("duration") or 0.0,
            info.get("title") or "Unknown",
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        node_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        try:
            assert execution_context.graph_exec_id is not None

            # Get the exec file directory
            output_dir = get_exec_file_path(execution_context.graph_exec_id, "")
            os.makedirs(output_dir, exist_ok=True)

            filename, duration, title = self._download_video(
                input_data.url,
                input_data.quality,
                input_data.output_format,
                output_dir,
                node_exec_id,
            )

            # Return as workspace path or data URI based on context
            video_out = await self._store_output_video(
                execution_context, MediaFileType(filename)
            )

            yield "video_file", video_out
            yield "duration", duration
            yield "title", title
            yield "source_url", input_data.url

        except Exception as e:
            raise BlockExecutionError(
                message=f"Failed to download video: {e}",
                block_name=self.name,
                block_id=str(self.id),
            ) from e
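For context, a minimal standalone sketch of the yt-dlp options the block assembles above. The URL and output template here are placeholders, not values from the block; only the option keys and the trailing "/best" fallback mirror the code.

import yt_dlp

format_string = "bestvideo[height<=720]+bestaudio/best[height<=720]"  # the block's "720p" preset
ydl_opts = {
    "format": f"{format_string}/best",  # trailing /best is the final fallback
    "outtmpl": "%(title).50s.%(ext)s",  # truncate title to 50 chars, as the block does
    "merge_output_format": "mp4",
    "quiet": True,
    "no_warnings": True,
}

with yt_dlp.YoutubeDL(ydl_opts) as ydl:
    # download=False only probes metadata; the block itself passes download=True
    info = ydl.extract_info("https://example.com/video", download=False)
    print(info.get("title"), info.get("duration"))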
autogpt_platform/backend/backend/blocks/video/duration.py (new file, 77 lines)
@@ -0,0 +1,77 @@
"""MediaDurationBlock - Get the duration of a media file."""

from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.blocks.video._utils import strip_chapters_inplace
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class MediaDurationBlock(Block):
    """Get the duration of a media file (video or audio)."""

    class Input(BlockSchemaInput):
        media_in: MediaFileType = SchemaField(
            description="Media input (URL, data URI, or local path)."
        )
        is_video: bool = SchemaField(
            description="Whether the media is a video (True) or audio (False).",
            default=True,
        )

    class Output(BlockSchemaOutput):
        duration: float = SchemaField(
            description="Duration of the media file (in seconds)."
        )

    def __init__(self):
        super().__init__(
            id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
            description="Block to get the duration of a media file.",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=MediaDurationBlock.Input,
            output_schema=MediaDurationBlock.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        **kwargs,
    ) -> BlockOutput:
        # 1) Store the input media locally
        local_media_path = await store_media_file(
            file=input_data.media_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        assert execution_context.graph_exec_id is not None
        media_abspath = get_exec_file_path(
            execution_context.graph_exec_id, local_media_path
        )

        # 2) Strip chapters to avoid MoviePy crash, then load the clip
        strip_chapters_inplace(media_abspath)
        clip = None
        try:
            if input_data.is_video:
                clip = VideoFileClip(media_abspath)
            else:
                clip = AudioFileClip(media_abspath)

            duration = clip.duration
        finally:
            if clip:
                clip.close()

        yield "duration", duration
autogpt_platform/backend/backend/blocks/video/loop.py (new file, 115 lines)
@@ -0,0 +1,115 @@
"""LoopVideoBlock - Loop a video to a given duration or number of repeats."""

from typing import Optional

from moviepy.video.fx.Loop import Loop
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class LoopVideoBlock(Block):
    """Loop (repeat) a video clip until a given duration or number of loops."""

    class Input(BlockSchemaInput):
        video_in: MediaFileType = SchemaField(
            description="The input video (can be a URL, data URI, or local path)."
        )
        duration: Optional[float] = SchemaField(
            description="Target duration (in seconds) to loop the video to. Either duration or n_loops must be provided.",
            default=None,
            ge=0.0,
            le=3600.0,  # Max 1 hour to prevent disk exhaustion
        )
        n_loops: Optional[int] = SchemaField(
            description="Number of times to repeat the video. Either n_loops or duration must be provided.",
            default=None,
            ge=1,
            le=10,  # Max 10 loops to prevent disk exhaustion
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(
            description="Looped video returned either as a relative path or a data URI."
        )

    def __init__(self):
        super().__init__(
            id="8bf9eef6-5451-4213-b265-25306446e94b",
            description="Block to loop a video to a given duration or number of repeats.",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=LoopVideoBlock.Input,
            output_schema=LoopVideoBlock.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        **kwargs,
    ) -> BlockOutput:
        assert execution_context.graph_exec_id is not None
        assert execution_context.node_exec_id is not None
        graph_exec_id = execution_context.graph_exec_id
        node_exec_id = execution_context.node_exec_id

        # 1) Store the input video locally
        local_video_path = await store_media_file(
            file=input_data.video_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        input_abspath = get_exec_file_path(graph_exec_id, local_video_path)

        # 2) Load the clip
        strip_chapters_inplace(input_abspath)
        clip = None
        looped_clip = None
        try:
            clip = VideoFileClip(input_abspath)

            # 3) Apply the loop effect
            if input_data.duration:
                # Loop until we reach the specified duration
                looped_clip = clip.with_effects([Loop(duration=input_data.duration)])
            elif input_data.n_loops:
                looped_clip = clip.with_effects([Loop(n=input_data.n_loops)])
            else:
                raise ValueError("Either 'duration' or 'n_loops' must be provided.")

            assert isinstance(looped_clip, VideoFileClip)

            # 4) Save the looped output
            source = extract_source_name(local_video_path)
            output_filename = MediaFileType(f"{node_exec_id}_looped_{source}.mp4")
            output_abspath = get_exec_file_path(graph_exec_id, output_filename)

            looped_clip = looped_clip.with_audio(clip.audio)
            looped_clip.write_videofile(
                output_abspath, codec="libx264", audio_codec="aac"
            )
        finally:
            if looped_clip:
                looped_clip.close()
            if clip:
                clip.close()

        # Return output - for_block_output returns workspace:// if available, else data URI
        video_out = await store_media_file(
            file=output_filename,
            execution_context=execution_context,
            return_format="for_block_output",
        )

        yield "video_out", video_out
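A quick sketch of the MoviePy v2 Loop effect the block applies, outside the block machinery. The path is a placeholder; assumes moviepy >= 2.0, matching the imports above.

from moviepy.video.fx.Loop import Loop
from moviepy.video.io.VideoFileClip import VideoFileClip

clip = VideoFileClip("input.mp4")  # e.g. a 4-second clip
try:
    by_duration = clip.with_effects([Loop(duration=10)])  # repeat until ~10s
    by_count = clip.with_effects([Loop(n=3)])             # exactly 3 repeats (~12s)
    print(by_duration.duration, by_count.duration)
finally:
    clip.close()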
autogpt_platform/backend/backend/blocks/video/narration.py (new file, 267 lines)
@@ -0,0 +1,267 @@
"""VideoNarrationBlock - Generate AI voice narration and add to video."""

import os
from typing import Literal

from elevenlabs import ElevenLabs
from moviepy import CompositeAudioClip
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.blocks.elevenlabs._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    ElevenLabsCredentials,
    ElevenLabsCredentialsInput,
)
from backend.blocks.video._utils import (
    extract_source_name,
    get_video_codecs,
    strip_chapters_inplace,
)
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsField, SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class VideoNarrationBlock(Block):
    """Generate AI narration and add to video."""

    class Input(BlockSchemaInput):
        credentials: ElevenLabsCredentialsInput = CredentialsField(
            description="ElevenLabs API key for voice synthesis"
        )
        video_in: MediaFileType = SchemaField(
            description="Input video (URL, data URI, or local path)"
        )
        script: str = SchemaField(description="Narration script text")
        voice_id: str = SchemaField(
            description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM"  # Rachel
        )
        model_id: Literal[
            "eleven_multilingual_v2",
            "eleven_flash_v2_5",
            "eleven_turbo_v2_5",
            "eleven_turbo_v2",
        ] = SchemaField(
            description="ElevenLabs TTS model",
            default="eleven_multilingual_v2",
        )
        mix_mode: Literal["replace", "mix", "ducking"] = SchemaField(
            description="How to combine with original audio. 'ducking' applies stronger attenuation than 'mix'.",
            default="ducking",
        )
        narration_volume: float = SchemaField(
            description="Narration volume (0.0 to 2.0)",
            default=1.0,
            ge=0.0,
            le=2.0,
            advanced=True,
        )
        original_volume: float = SchemaField(
            description="Original audio volume when mixing (0.0 to 1.0)",
            default=0.3,
            ge=0.0,
            le=1.0,
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(
            description="Video with narration (path or data URI)"
        )
        audio_file: MediaFileType = SchemaField(
            description="Generated audio file (path or data URI)"
        )

    def __init__(self):
        super().__init__(
            id="3d036b53-859c-4b17-9826-ca340f736e0e",
            description="Generate AI narration and add to video",
            categories={BlockCategory.MULTIMEDIA, BlockCategory.AI},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "video_in": "/tmp/test.mp4",
                "script": "Hello world",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[("video_out", str), ("audio_file", str)],
            test_mock={
                "_generate_narration_audio": lambda *args: b"mock audio content",
                "_add_narration_to_video": lambda *args: None,
                "_store_input_video": lambda *args, **kwargs: "test.mp4",
                "_store_output_video": lambda *args, **kwargs: "narrated_test.mp4",
            },
        )

    async def _store_input_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store input video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_local_processing",
        )

    async def _store_output_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store output video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_block_output",
        )

    def _generate_narration_audio(
        self, api_key: str, script: str, voice_id: str, model_id: str
    ) -> bytes:
        """Generate narration audio via ElevenLabs API."""
        client = ElevenLabs(api_key=api_key)
        audio_generator = client.text_to_speech.convert(
            voice_id=voice_id,
            text=script,
            model_id=model_id,
        )
        # The SDK returns a generator, collect all chunks
        return b"".join(audio_generator)

    def _add_narration_to_video(
        self,
        video_abspath: str,
        audio_abspath: str,
        output_abspath: str,
        mix_mode: str,
        narration_volume: float,
        original_volume: float,
    ) -> None:
        """Add narration audio to video. Extracted for testability."""
        video = None
        final = None
        narration_original = None
        narration_scaled = None
        original = None

        try:
            strip_chapters_inplace(video_abspath)
            video = VideoFileClip(video_abspath)
            narration_original = AudioFileClip(audio_abspath)
            narration_scaled = narration_original.with_volume_scaled(narration_volume)
            narration = narration_scaled

            if mix_mode == "replace":
                final_audio = narration
            elif mix_mode == "mix":
                if video.audio:
                    original = video.audio.with_volume_scaled(original_volume)
                    final_audio = CompositeAudioClip([original, narration])
                else:
                    final_audio = narration
            else:  # ducking - apply stronger attenuation
                if video.audio:
                    # Ducking uses a much lower volume for original audio
                    ducking_volume = original_volume * 0.3
                    original = video.audio.with_volume_scaled(ducking_volume)
                    final_audio = CompositeAudioClip([original, narration])
                else:
                    final_audio = narration

            final = video.with_audio(final_audio)
            video_codec, audio_codec = get_video_codecs(output_abspath)
            final.write_videofile(
                output_abspath, codec=video_codec, audio_codec=audio_codec
            )

        finally:
            if original:
                original.close()
            if narration_scaled:
                narration_scaled.close()
            if narration_original:
                narration_original.close()
            if final:
                final.close()
            if video:
                video.close()

    async def run(
        self,
        input_data: Input,
        *,
        credentials: ElevenLabsCredentials,
        execution_context: ExecutionContext,
        node_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        try:
            assert execution_context.graph_exec_id is not None

            # Store the input video locally
            local_video_path = await self._store_input_video(
                execution_context, input_data.video_in
            )
            video_abspath = get_exec_file_path(
                execution_context.graph_exec_id, local_video_path
            )

            # Generate narration audio via ElevenLabs
            audio_content = self._generate_narration_audio(
                credentials.api_key.get_secret_value(),
                input_data.script,
                input_data.voice_id,
                input_data.model_id,
            )

            # Save audio to exec file path
            audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3")
            audio_abspath = get_exec_file_path(
                execution_context.graph_exec_id, audio_filename
            )
            os.makedirs(os.path.dirname(audio_abspath), exist_ok=True)
            with open(audio_abspath, "wb") as f:
                f.write(audio_content)

            # Add narration to video
            source = extract_source_name(local_video_path)
            output_filename = MediaFileType(f"{node_exec_id}_narrated_{source}.mp4")
            output_abspath = get_exec_file_path(
                execution_context.graph_exec_id, output_filename
            )

            self._add_narration_to_video(
                video_abspath,
                audio_abspath,
                output_abspath,
                input_data.mix_mode,
                input_data.narration_volume,
                input_data.original_volume,
            )

            # Return as workspace path or data URI based on context
            video_out = await self._store_output_video(
                execution_context, output_filename
            )
            audio_out = await self._store_output_video(
                execution_context, audio_filename
            )

            yield "video_out", video_out
            yield "audio_file", audio_out

        except Exception as e:
            raise BlockExecutionError(
                message=f"Failed to add narration: {e}",
                block_name=self.name,
                block_id=str(self.id),
            ) from e
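The three mix modes above reduce to two gain values before compositing. A small pure-Python illustration of the block's logic (the function name is illustrative):

def mixed_levels(mode: str, narration_volume: float, original_volume: float):
    """Return (narration_gain, original_gain) per the block's mix_mode logic."""
    if mode == "replace":
        return narration_volume, 0.0  # original track is dropped entirely
    if mode == "mix":
        return narration_volume, original_volume
    # "ducking" attenuates the original a further 70% beyond "mix"
    return narration_volume, original_volume * 0.3

print(mixed_levels("ducking", 1.0, 0.3))  # -> (1.0, 0.09)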
autogpt_platform/backend/backend/blocks/video/text_overlay.py (new file, 231 lines)
@@ -0,0 +1,231 @@
"""VideoTextOverlayBlock - Add text overlay to video."""

from typing import Literal

from moviepy import CompositeVideoClip, TextClip
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.blocks.video._utils import (
    extract_source_name,
    get_video_codecs,
    strip_chapters_inplace,
)
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class VideoTextOverlayBlock(Block):
    """Add text overlay/caption to video."""

    class Input(BlockSchemaInput):
        video_in: MediaFileType = SchemaField(
            description="Input video (URL, data URI, or local path)"
        )
        text: str = SchemaField(description="Text to overlay on video")
        position: Literal[
            "top",
            "center",
            "bottom",
            "top-left",
            "top-right",
            "bottom-left",
            "bottom-right",
        ] = SchemaField(description="Position of text on screen", default="bottom")
        start_time: float | None = SchemaField(
            description="When to show text (seconds). None = entire video",
            default=None,
            advanced=True,
        )
        end_time: float | None = SchemaField(
            description="When to hide text (seconds). None = until end",
            default=None,
            advanced=True,
        )
        font_size: int = SchemaField(
            description="Font size", default=48, ge=12, le=200, advanced=True
        )
        font_color: str = SchemaField(
            description="Font color (hex or name)", default="white", advanced=True
        )
        bg_color: str | None = SchemaField(
            description="Background color behind text (None for transparent)",
            default=None,
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(
            description="Video with text overlay (path or data URI)"
        )

    def __init__(self):
        super().__init__(
            id="8ef14de6-cc90-430a-8cfa-3a003be92454",
            description="Add text overlay/caption to video",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=self.Input,
            output_schema=self.Output,
            disabled=True,  # Disable until we can lockdown imagemagick security policy
            test_input={"video_in": "/tmp/test.mp4", "text": "Hello World"},
            test_output=[("video_out", str)],
            test_mock={
                "_add_text_overlay": lambda *args: None,
                "_store_input_video": lambda *args, **kwargs: "test.mp4",
                "_store_output_video": lambda *args, **kwargs: "overlay_test.mp4",
            },
        )

    async def _store_input_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store input video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_local_processing",
        )

    async def _store_output_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store output video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_block_output",
        )

    def _add_text_overlay(
        self,
        video_abspath: str,
        output_abspath: str,
        text: str,
        position: str,
        start_time: float | None,
        end_time: float | None,
        font_size: int,
        font_color: str,
        bg_color: str | None,
    ) -> None:
        """Add text overlay to video. Extracted for testability."""
        video = None
        final = None
        txt_clip = None
        try:
            strip_chapters_inplace(video_abspath)
            video = VideoFileClip(video_abspath)

            txt_clip = TextClip(
                text=text,
                font_size=font_size,
                color=font_color,
                bg_color=bg_color,
            )

            # Position mapping
            pos_map = {
                "top": ("center", "top"),
                "center": ("center", "center"),
                "bottom": ("center", "bottom"),
                "top-left": ("left", "top"),
                "top-right": ("right", "top"),
                "bottom-left": ("left", "bottom"),
                "bottom-right": ("right", "bottom"),
            }

            txt_clip = txt_clip.with_position(pos_map[position])

            # Set timing
            start = start_time or 0
            end = end_time or video.duration
            duration = max(0, end - start)
            txt_clip = txt_clip.with_start(start).with_end(end).with_duration(duration)

            final = CompositeVideoClip([video, txt_clip])
            video_codec, audio_codec = get_video_codecs(output_abspath)
            final.write_videofile(
                output_abspath, codec=video_codec, audio_codec=audio_codec
            )

        finally:
            if txt_clip:
                txt_clip.close()
            if final:
                final.close()
            if video:
                video.close()

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        node_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        # Validate time range if both are provided
        if (
            input_data.start_time is not None
            and input_data.end_time is not None
            and input_data.end_time <= input_data.start_time
        ):
            raise BlockExecutionError(
                message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
                block_name=self.name,
                block_id=str(self.id),
            )

        try:
            assert execution_context.graph_exec_id is not None

            # Store the input video locally
            local_video_path = await self._store_input_video(
                execution_context, input_data.video_in
            )
            video_abspath = get_exec_file_path(
                execution_context.graph_exec_id, local_video_path
            )

            # Build output path
            source = extract_source_name(local_video_path)
            output_filename = MediaFileType(f"{node_exec_id}_overlay_{source}.mp4")
            output_abspath = get_exec_file_path(
                execution_context.graph_exec_id, output_filename
            )

            self._add_text_overlay(
                video_abspath,
                output_abspath,
                input_data.text,
                input_data.position,
                input_data.start_time,
                input_data.end_time,
                input_data.font_size,
                input_data.font_color,
                input_data.bg_color,
            )

            # Return as workspace path or data URI based on context
            video_out = await self._store_output_video(
                execution_context, output_filename
            )

            yield "video_out", video_out

        except BlockExecutionError:
            raise
        except Exception as e:
            raise BlockExecutionError(
                message=f"Failed to add text overlay: {e}",
                block_name=self.name,
                block_id=str(self.id),
            ) from e
@@ -165,10 +165,13 @@ class TranscribeYoutubeVideoBlock(Block):
         credentials: WebshareProxyCredentials,
         **kwargs,
     ) -> BlockOutput:
-        video_id = self.extract_video_id(input_data.youtube_url)
-        yield "video_id", video_id
-
-        transcript = self.get_transcript(video_id, credentials)
-        transcript_text = self.format_transcript(transcript=transcript)
-
-        yield "transcript", transcript_text
+        try:
+            video_id = self.extract_video_id(input_data.youtube_url)
+            transcript = self.get_transcript(video_id, credentials)
+            transcript_text = self.format_transcript(transcript=transcript)
+
+            # Only yield after all operations succeed
+            yield "video_id", video_id
+            yield "transcript", transcript_text
+        except Exception as e:
+            yield "error", str(e)
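Why the reordering above matters: in a generator-based block, anything yielded before a later failure has already been delivered downstream and cannot be taken back. A toy sketch with stand-in steps (names hypothetical):

def risky_block(url: str):
    video_id = url.rsplit("=", 1)[-1]  # step 1: cheap, always succeeds
    if "bad" in url:                   # step 2: may fail
        raise ValueError("no transcript available")
    # Only yield after all operations succeed
    yield "video_id", video_id
    yield "transcript", f"transcript for {video_id}"

print(list(risky_block("https://example.com/watch?v=abc123")))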
@@ -873,14 +873,13 @@ def is_block_auth_configured(


 async def initialize_blocks() -> None:
-    # First, sync all provider costs to blocks
-    # Imported here to avoid circular import
     from backend.sdk.cost_integration import sync_all_provider_costs
+    from backend.util.retry import func_retry

     sync_all_provider_costs()

-    for cls in get_blocks().values():
-        block = cls()
+    @func_retry
+    async def sync_block_to_db(block: Block) -> None:
         existing_block = await AgentBlock.prisma().find_first(
             where={"OR": [{"id": block.id}, {"name": block.name}]}
         )
@@ -893,7 +892,7 @@ async def initialize_blocks() -> None:
                     outputSchema=json.dumps(block.output_schema.jsonschema()),
                 )
             )
-            continue
+            return

         input_schema = json.dumps(block.input_schema.jsonschema())
         output_schema = json.dumps(block.output_schema.jsonschema())
@@ -913,6 +912,25 @@ async def initialize_blocks() -> None:
             },
         )

+    failed_blocks: list[str] = []
+    for cls in get_blocks().values():
+        block = cls()
+        try:
+            await sync_block_to_db(block)
+        except Exception as e:
+            logger.warning(
+                f"Failed to sync block {block.name} to database: {e}. "
+                "Block is still available in memory.",
+                exc_info=True,
+            )
+            failed_blocks.append(block.name)
+
+    if failed_blocks:
+        logger.error(
+            f"Failed to sync {len(failed_blocks)} block(s) to database: "
+            f"{', '.join(failed_blocks)}. These blocks are still available in memory."
+        )
+
+
 # Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
 def get_block(block_id: str) -> AnyBlockSchema | None:
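The shape of the new sync loop, sketched standalone with a stand-in sync step (names are illustrative, not the production code):

import asyncio
import logging

logger = logging.getLogger(__name__)

async def sync_all(names: list[str]) -> None:
    failed: list[str] = []
    for name in names:
        try:
            if name == "bad":  # stand-in for a per-block DB failure
                raise RuntimeError("db unavailable")
        except Exception as e:
            logger.warning("Failed to sync %s: %s", name, e)
            failed.append(name)
    if failed:
        # One aggregate error at the end; blocks stay usable in memory
        logger.error("Failed to sync %d block(s): %s", len(failed), ", ".join(failed))

asyncio.run(sync_all(["a", "bad", "b"]))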
@@ -36,12 +36,14 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
 from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
 from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
 from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
+from backend.blocks.video.narration import VideoNarrationBlock
 from backend.data.block import Block, BlockCost, BlockCostType
 from backend.integrations.credentials_store import (
     aiml_api_credentials,
     anthropic_credentials,
     apollo_credentials,
     did_credentials,
+    elevenlabs_credentials,
     enrichlayer_credentials,
     groq_credentials,
     ideogram_credentials,
@@ -78,6 +80,7 @@ MODEL_COST: dict[LlmModel, int] = {
     LlmModel.CLAUDE_4_1_OPUS: 21,
     LlmModel.CLAUDE_4_OPUS: 21,
     LlmModel.CLAUDE_4_SONNET: 5,
+    LlmModel.CLAUDE_4_6_OPUS: 14,
     LlmModel.CLAUDE_4_5_HAIKU: 4,
     LlmModel.CLAUDE_4_5_OPUS: 14,
     LlmModel.CLAUDE_4_5_SONNET: 9,
@@ -639,4 +642,16 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
             },
         ),
     ],
+    VideoNarrationBlock: [
+        BlockCost(
+            cost_amount=5,  # ElevenLabs TTS cost
+            cost_filter={
+                "credentials": {
+                    "id": elevenlabs_credentials.id,
+                    "provider": elevenlabs_credentials.provider,
+                    "type": elevenlabs_credentials.type,
+                }
+            },
+        )
+    ],
 }
@@ -134,6 +134,16 @@ async def test_block_credit_reset(server: SpinTestServer):
     month1 = datetime.now(timezone.utc).replace(month=1, day=1)
     user_credit.time_now = lambda: month1

+    # IMPORTANT: Set updatedAt to December of previous year to ensure it's
+    # in a different month than month1 (January). This fixes a timing bug
+    # where if the test runs in early February, 35 days ago would be January,
+    # matching the mocked month1 and preventing the refill from triggering.
+    dec_previous_year = month1.replace(year=month1.year - 1, month=12, day=15)
+    await UserBalance.prisma().update(
+        where={"userId": DEFAULT_USER_ID},
+        data={"updatedAt": dec_previous_year},
+    )
+
     # First call in month 1 should trigger refill
     balance = await user_credit.get_credits(DEFAULT_USER_ID)
     assert balance == REFILL_VALUE  # Should get 1000 credits
|
|||||||
cast,
|
cast,
|
||||||
get_args,
|
get_args,
|
||||||
)
|
)
|
||||||
from urllib.parse import urlparse
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from prisma.enums import CreditTransactionType, OnboardingStep
|
from prisma.enums import CreditTransactionType, OnboardingStep
|
||||||
@@ -42,6 +41,7 @@ from typing_extensions import TypedDict
|
|||||||
|
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.json import loads as json_loads
|
from backend.util.json import loads as json_loads
|
||||||
|
from backend.util.request import parse_url
|
||||||
from backend.util.settings import Secrets
|
from backend.util.settings import Secrets
|
||||||
|
|
||||||
# Type alias for any provider name (including custom ones)
|
# Type alias for any provider name (including custom ones)
|
||||||
@@ -397,19 +397,25 @@ class HostScopedCredentials(_BaseCredentials):
|
|||||||
def matches_url(self, url: str) -> bool:
|
def matches_url(self, url: str) -> bool:
|
||||||
"""Check if this credential should be applied to the given URL."""
|
"""Check if this credential should be applied to the given URL."""
|
||||||
|
|
||||||
parsed_url = urlparse(url)
|
request_host, request_port = _extract_host_from_url(url)
|
||||||
# Extract hostname without port
|
cred_scope_host, cred_scope_port = _extract_host_from_url(self.host)
|
||||||
request_host = parsed_url.hostname
|
|
||||||
if not request_host:
|
if not request_host:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Simple host matching - exact match or wildcard subdomain match
|
# If a port is specified in credential host, the request host port must match
|
||||||
if self.host == request_host:
|
if cred_scope_port is not None and request_port != cred_scope_port:
|
||||||
|
return False
|
||||||
|
# Non-standard ports are only allowed if explicitly specified in credential host
|
||||||
|
elif cred_scope_port is None and request_port not in (80, 443, None):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Simple host matching
|
||||||
|
if cred_scope_host == request_host:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
|
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
|
||||||
if self.host.startswith("*."):
|
if cred_scope_host.startswith("*."):
|
||||||
domain = self.host[2:] # Remove "*."
|
domain = cred_scope_host[2:] # Remove "*."
|
||||||
return request_host.endswith(f".{domain}") or request_host == domain
|
return request_host.endswith(f".{domain}") or request_host == domain
|
||||||
|
|
||||||
return False
|
return False
|
||||||
@@ -551,13 +557,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _extract_host_from_url(url: str) -> str:
|
def _extract_host_from_url(url: str) -> tuple[str, int | None]:
|
||||||
"""Extract host from URL for grouping host-scoped credentials."""
|
"""Extract host and port from URL for grouping host-scoped credentials."""
|
||||||
try:
|
try:
|
||||||
parsed = urlparse(url)
|
parsed = parse_url(url)
|
||||||
return parsed.hostname or url
|
return parsed.hostname or url, parsed.port
|
||||||
except Exception:
|
except Exception:
|
||||||
return ""
|
return "", None
|
||||||
|
|
||||||
|
|
||||||
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||||
@@ -606,7 +612,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
|||||||
providers = frozenset(
|
providers = frozenset(
|
||||||
[cast(CP, "http")]
|
[cast(CP, "http")]
|
||||||
+ [
|
+ [
|
||||||
cast(CP, _extract_host_from_url(str(value)))
|
cast(CP, parse_url(str(value)).netloc)
|
||||||
for value in field.discriminator_values
|
for value in field.discriminator_values
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
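The new port semantics, re-implemented standalone for illustration (simplified: no wildcard handling, not the production code):

from urllib.parse import urlparse

def host_port(ref: str) -> tuple[str | None, int | None]:
    p = urlparse(ref if "://" in ref else f"http://{ref}")
    return p.hostname, p.port

def matches(cred_host: str, url: str) -> bool:
    req_host, req_port = host_port(url)
    cred_h, cred_p = host_port(cred_host)
    if cred_p is not None and req_port != cred_p:
        return False  # explicit port must match exactly
    if cred_p is None and req_port not in (80, 443, None):
        return False  # non-standard ports need an explicit opt-in
    return cred_h == req_host

print(matches("localhost", "http://localhost:3000/x"))       # False
print(matches("localhost:3000", "http://localhost:3000/x"))  # True
print(matches("localhost", "http://localhost/x"))            # True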
@@ -79,10 +79,23 @@ class TestHostScopedCredentials:
             headers={"Authorization": SecretStr("Bearer token")},
         )

-        assert creds.matches_url("http://localhost:8080/api/v1")
+        # Non-standard ports require explicit port in credential host
+        assert not creds.matches_url("http://localhost:8080/api/v1")
         assert creds.matches_url("https://localhost:443/secure/endpoint")
        assert creds.matches_url("http://localhost/simple")

+    def test_matches_url_with_explicit_port(self):
+        """Test URL matching with explicit port in credential host."""
+        creds = HostScopedCredentials(
+            provider="custom",
+            host="localhost:8080",
+            headers={"Authorization": SecretStr("Bearer token")},
+        )
+
+        assert creds.matches_url("http://localhost:8080/api/v1")
+        assert not creds.matches_url("http://localhost:3000/api/v1")
+        assert not creds.matches_url("http://localhost/simple")
+
     def test_empty_headers_dict(self):
         """Test HostScopedCredentials with empty headers."""
         creds = HostScopedCredentials(
@@ -128,8 +141,20 @@ class TestHostScopedCredentials:
             ("*.example.com", "https://sub.api.example.com/test", True),
             ("*.example.com", "https://example.com/test", True),
             ("*.example.com", "https://example.org/test", False),
-            ("localhost", "http://localhost:3000/test", True),
+            # Non-standard ports require explicit port in credential host
+            ("localhost", "http://localhost:3000/test", False),
+            ("localhost:3000", "http://localhost:3000/test", True),
             ("localhost", "http://127.0.0.1:3000/test", False),
+            # IPv6 addresses (frontend stores with brackets via URL.hostname)
+            ("[::1]", "http://[::1]/test", True),
+            ("[::1]", "http://[::1]:80/test", True),
+            ("[::1]", "https://[::1]:443/test", True),
+            ("[::1]", "http://[::1]:8080/test", False),  # Non-standard port
+            ("[::1]:8080", "http://[::1]:8080/test", True),
+            ("[::1]:8080", "http://[::1]:9090/test", False),
+            ("[2001:db8::1]", "http://[2001:db8::1]/path", True),
+            ("[2001:db8::1]", "https://[2001:db8::1]:443/path", True),
+            ("[2001:db8::1]", "http://[2001:db8::ff]/path", False),
         ],
     )
     def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
@@ -224,6 +224,14 @@ openweathermap_credentials = APIKeyCredentials(
     expires_at=None,
 )

+elevenlabs_credentials = APIKeyCredentials(
+    id="f4a8b6c2-3d1e-4f5a-9b8c-7d6e5f4a3b2c",
+    provider="elevenlabs",
+    api_key=SecretStr(settings.secrets.elevenlabs_api_key),
+    title="Use Credits for ElevenLabs",
+    expires_at=None,
+)
+
 DEFAULT_CREDENTIALS = [
     ollama_credentials,
     revid_credentials,
@@ -252,6 +260,7 @@ DEFAULT_CREDENTIALS = [
     v0_credentials,
     webshare_proxy_credentials,
     openweathermap_credentials,
+    elevenlabs_credentials,
 ]

 SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
@@ -366,6 +375,8 @@ class IntegrationCredentialsStore:
         all_credentials.append(webshare_proxy_credentials)
         if settings.secrets.openweathermap_api_key:
             all_credentials.append(openweathermap_credentials)
+        if settings.secrets.elevenlabs_api_key:
+            all_credentials.append(elevenlabs_credentials)
         return all_credentials

     async def get_creds_by_id(
@@ -18,6 +18,7 @@ class ProviderName(str, Enum):
     DISCORD = "discord"
     D_ID = "d_id"
     E2B = "e2b"
+    ELEVENLABS = "elevenlabs"
     FAL = "fal"
     GITHUB = "github"
     GOOGLE = "google"
@@ -8,6 +8,8 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Literal
 from urllib.parse import urlparse

+from pydantic import BaseModel
+
 from backend.util.cloud_storage import get_cloud_storage_handler
 from backend.util.request import Requests
 from backend.util.settings import Config
@@ -17,6 +19,35 @@ from backend.util.virus_scanner import scan_content_safe
 if TYPE_CHECKING:
     from backend.data.execution import ExecutionContext

+
+class WorkspaceUri(BaseModel):
+    """Parsed workspace:// URI."""
+
+    file_ref: str  # File ID or path (e.g. "abc123" or "/path/to/file.txt")
+    mime_type: str | None = None  # MIME type from fragment (e.g. "video/mp4")
+    is_path: bool = False  # True if file_ref is a path (starts with "/")
+
+
+def parse_workspace_uri(uri: str) -> WorkspaceUri:
+    """Parse a workspace:// URI into its components.
+
+    Examples:
+        "workspace://abc123" → WorkspaceUri(file_ref="abc123", mime_type=None, is_path=False)
+        "workspace://abc123#video/mp4" → WorkspaceUri(file_ref="abc123", mime_type="video/mp4", is_path=False)
+        "workspace:///path/to/file.txt" → WorkspaceUri(file_ref="/path/to/file.txt", mime_type=None, is_path=True)
+    """
+    raw = uri.removeprefix("workspace://")
+    mime_type: str | None = None
+    if "#" in raw:
+        raw, fragment = raw.split("#", 1)
+        mime_type = fragment or None
+    return WorkspaceUri(
+        file_ref=raw,
+        mime_type=mime_type,
+        is_path=raw.startswith("/"),
+    )
+
+
 # Return format options for store_media_file
 # - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
 # - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
@@ -183,22 +214,20 @@ async def store_media_file(
                 "This file type is only available in CoPilot sessions."
             )

-        # Parse workspace reference
-        # workspace://abc123 - by file ID
-        # workspace:///path/to/file.txt - by virtual path
-        file_ref = file[12:]  # Remove "workspace://"
+        # Parse workspace reference (strips #mimeType fragment from file ID)
+        ws = parse_workspace_uri(file)

-        if file_ref.startswith("/"):
-            # Path reference
-            workspace_content = await workspace_manager.read_file(file_ref)
-            file_info = await workspace_manager.get_file_info_by_path(file_ref)
+        if ws.is_path:
+            # Path reference: workspace:///path/to/file.txt
+            workspace_content = await workspace_manager.read_file(ws.file_ref)
+            file_info = await workspace_manager.get_file_info_by_path(ws.file_ref)
             filename = sanitize_filename(
                 file_info.name if file_info else f"{uuid.uuid4()}.bin"
             )
         else:
-            # ID reference
-            workspace_content = await workspace_manager.read_file_by_id(file_ref)
-            file_info = await workspace_manager.get_file_info(file_ref)
+            # ID reference: workspace://abc123 or workspace://abc123#video/mp4
+            workspace_content = await workspace_manager.read_file_by_id(ws.file_ref)
+            file_info = await workspace_manager.get_file_info(ws.file_ref)
             filename = sanitize_filename(
                 file_info.name if file_info else f"{uuid.uuid4()}.bin"
             )
@@ -334,7 +363,21 @@ async def store_media_file(

         # Don't re-save if input was already from workspace
         if is_from_workspace:
-            # Return original workspace reference
+            # Return original workspace reference, ensuring MIME type fragment
+            ws = parse_workspace_uri(file)
+            if not ws.mime_type:
+                # Add MIME type fragment if missing (older refs without it)
+                try:
+                    if ws.is_path:
+                        info = await workspace_manager.get_file_info_by_path(
+                            ws.file_ref
+                        )
+                    else:
+                        info = await workspace_manager.get_file_info(ws.file_ref)
+                    if info:
+                        return MediaFileType(f"{file}#{info.mimeType}")
+                except Exception:
+                    pass
             return MediaFileType(file)

         # Save new content to workspace
@@ -346,7 +389,7 @@ async def store_media_file(
             filename=filename,
             overwrite=True,
         )
-        return MediaFileType(f"workspace://{file_record.id}")
+        return MediaFileType(f"workspace://{file_record.id}#{file_record.mimeType}")

     else:
         raise ValueError(f"Invalid return_format: {return_format}")
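The parsing above in isolation, as a quick sanity check of the fragment handling:

uri = "workspace://abc123#video/mp4"
raw = uri.removeprefix("workspace://")
mime = None
if "#" in raw:
    raw, fragment = raw.split("#", 1)
    mime = fragment or None
print(raw, mime, raw.startswith("/"))  # abc123 video/mp4 False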
@@ -157,12 +157,7 @@ async def validate_url(
         is_trusted: Boolean indicating if the hostname is in trusted_origins
         ip_addresses: List of IP addresses for the host; empty if the host is trusted
     """
-    # Canonicalize URL
-    url = url.strip("/ ").replace("\\", "/")
-    parsed = urlparse(url)
-    if not parsed.scheme:
-        url = f"http://{url}"
-        parsed = urlparse(url)
+    parsed = parse_url(url)

     # Check scheme
     if parsed.scheme not in ALLOWED_SCHEMES:
@@ -220,6 +215,17 @@ async def validate_url(
     )


+def parse_url(url: str) -> URL:
+    """Canonicalizes and parses a URL string."""
+    url = url.strip("/ ").replace("\\", "/")
+
+    # Ensure scheme is present for proper parsing
+    if not re.match(r"[a-z0-9+.\-]+://", url):
+        url = f"http://{url}"
+
+    return urlparse(url)
+
+
 def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL:
     """
     Pins a URL to a specific IP address to prevent DNS rebinding attacks.
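parse_url's canonicalization, reproduced standalone:

import re
from urllib.parse import urlparse

def parse_url(url: str):
    url = url.strip("/ ").replace("\\", "/")
    if not re.match(r"[a-z0-9+.\-]+://", url):
        url = f"http://{url}"
    return urlparse(url)

print(parse_url("example.com/api/").geturl())   # http://example.com/api
print(parse_url("https://example.com").scheme)  # https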
@@ -656,6 +656,7 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
     e2b_api_key: str = Field(default="", description="E2B API key")
     nvidia_api_key: str = Field(default="", description="Nvidia API key")
     mem0_api_key: str = Field(default="", description="Mem0 API key")
+    elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key")

     linear_client_id: str = Field(default="", description="Linear client ID")
     linear_client_secret: str = Field(default="", description="Linear client secret")
autogpt_platform/backend/poetry.lock (generated, 47 lines changed)
@@ -1169,6 +1169,29 @@ attrs = ">=21.3.0"
 e2b = ">=1.5.4,<2.0.0"
 httpx = ">=0.20.0,<1.0.0"

+[[package]]
+name = "elevenlabs"
+version = "1.59.0"
+description = ""
+optional = false
+python-versions = "<4.0,>=3.8"
+groups = ["main"]
+files = [
+    {file = "elevenlabs-1.59.0-py3-none-any.whl", hash = "sha256:468145db81a0bc867708b4a8619699f75583e9481b395ec1339d0b443da771ed"},
+    {file = "elevenlabs-1.59.0.tar.gz", hash = "sha256:16e735bd594e86d415dd445d249c8cc28b09996cfd627fbc10102c0a84698859"},
+]
+
+[package.dependencies]
+httpx = ">=0.21.2"
+pydantic = ">=1.9.2"
+pydantic-core = ">=2.18.2,<3.0.0"
+requests = ">=2.20"
+typing_extensions = ">=4.0.0"
+websockets = ">=11.0"
+
+[package.extras]
+pyaudio = ["pyaudio (>=0.2.14)"]
+
 [[package]]
 name = "email-validator"
 version = "2.2.0"
@@ -7361,6 +7384,28 @@ files = [
 defusedxml = ">=0.7.1,<0.8.0"
 requests = "*"

+[[package]]
+name = "yt-dlp"
+version = "2025.12.8"
+description = "A feature-rich command-line audio/video downloader"
+optional = false
+python-versions = ">=3.10"
+groups = ["main"]
+files = [
+    {file = "yt_dlp-2025.12.8-py3-none-any.whl", hash = "sha256:36e2584342e409cfbfa0b5e61448a1c5189e345cf4564294456ee509e7d3e065"},
+    {file = "yt_dlp-2025.12.8.tar.gz", hash = "sha256:b773c81bb6b71cb2c111cfb859f453c7a71cf2ef44eff234ff155877184c3e4f"},
+]
+
+[package.extras]
+build = ["build", "hatchling (>=1.27.0)", "pip", "setuptools (>=71.0.2)", "wheel"]
+curl-cffi = ["curl-cffi (>=0.5.10,<0.6.dev0 || >=0.10.dev0,<0.14) ; implementation_name == \"cpython\""]
+default = ["brotli ; implementation_name == \"cpython\"", "brotlicffi ; implementation_name != \"cpython\"", "certifi", "mutagen", "pycryptodomex", "requests (>=2.32.2,<3)", "urllib3 (>=2.0.2,<3)", "websockets (>=13.0)", "yt-dlp-ejs (==0.3.2)"]
+dev = ["autopep8 (>=2.0,<3.0)", "pre-commit", "pytest (>=8.1,<9.0)", "pytest-rerunfailures (>=14.0,<15.0)", "ruff (>=0.14.0,<0.15.0)"]
+pyinstaller = ["pyinstaller (>=6.17.0)"]
+secretstorage = ["cffi", "secretstorage"]
+static-analysis = ["autopep8 (>=2.0,<3.0)", "ruff (>=0.14.0,<0.15.0)"]
+test = ["pytest (>=8.1,<9.0)", "pytest-rerunfailures (>=14.0,<15.0)"]
+
 [[package]]
 name = "zerobouncesdk"
 version = "1.1.2"
@@ -7512,4 +7557,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.10,<3.14"
-content-hash = "ee5742dc1a9df50dfc06d4b26a1682cbb2b25cab6b79ce5625ec272f93e4f4bf"
+content-hash = "8239323f9ae6713224dffd1fe8ba8b449fe88b6c3c7a90940294a74f43a0387a"
@@ -20,6 +20,7 @@ click = "^8.2.0"
 cryptography = "^45.0"
 discord-py = "^2.5.2"
 e2b-code-interpreter = "^1.5.2"
+elevenlabs = "^1.50.0"
 fastapi = "^0.116.1"
 feedparser = "^6.0.11"
 flake8 = "^7.3.0"
@@ -71,6 +72,7 @@ tweepy = "^4.16.0"
 uvicorn = { extras = ["standard"], version = "^0.35.0" }
 websockets = "^15.0"
 youtube-transcript-api = "^1.2.1"
+yt-dlp = "2025.12.08"
 zerobouncesdk = "^1.1.2"
 # NOTE: please insert new dependencies in their alphabetical location
 pytest-snapshot = "^0.9.0"

@@ -111,9 +111,7 @@ class TestGenerateAgent:
         instructions = {"type": "instructions", "steps": ["Step 1"]}
         result = await core.generate_agent(instructions)
 
-        # library_agents defaults to None
-        mock_external.assert_called_once_with(instructions, None)
-        # Result should have id, version, is_active added if not present
+        mock_external.assert_called_once_with(instructions, None, None, None)
         assert result is not None
         assert result["name"] == "Test Agent"
         assert "id" in result
@@ -177,8 +175,9 @@ class TestGenerateAgentPatch:
         current_agent = {"nodes": [], "links": []}
         result = await core.generate_agent_patch("Add a node", current_agent)
 
-        # library_agents defaults to None
-        mock_external.assert_called_once_with("Add a node", current_agent, None)
+        mock_external.assert_called_once_with(
+            "Add a node", current_agent, None, None, None
+        )
         assert result == expected_result
 
     @pytest.mark.asyncio

@@ -1,6 +1,17 @@
 import { OAuthPopupResultMessage } from "./types";
 import { NextResponse } from "next/server";
 
+/**
+ * Safely encode a value as JSON for embedding in a script tag.
+ * Escapes characters that could break out of the script context to prevent XSS.
+ */
+function safeJsonStringify(value: unknown): string {
+  return JSON.stringify(value)
+    .replace(/</g, "\\u003c")
+    .replace(/>/g, "\\u003e")
+    .replace(/&/g, "\\u0026");
+}
+
 // This route is intended to be used as the callback for integration OAuth flows,
 // controlled by the CredentialsInput component. The CredentialsInput opens the login
 // page in a pop-up window, which then redirects to this route to close the loop.
@@ -23,12 +34,13 @@ export async function GET(request: Request) {
   console.debug("Sending message to opener:", message);
 
   // Return a response with the message as JSON and a script to close the window
+  // Use safeJsonStringify to prevent XSS by escaping <, >, and & characters
   return new NextResponse(
     `
     <html>
      <body>
        <script>
-          window.opener.postMessage(${JSON.stringify(message)});
+          window.opener.postMessage(${safeJsonStringify(message)});
          window.close();
        </script>
      </body>
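
Why the helper matters: with plain `JSON.stringify`, a message value containing `</script>` would terminate the inline script block early and let attacker-controlled markup execute in the popup. A small illustration of the difference (not part of the change itself):

```typescript
const message = { state: "</script><script>alert(1)</script>" };

// Unsafe: the closing tag inside the string ends the inline <script> block.
JSON.stringify(message);
// => {"state":"</script><script>alert(1)</script>"}

// Safe: angle brackets and ampersands become unicode escapes, which JSON
// parsers and JS engines still read back as the same characters.
JSON.stringify(message)
  .replace(/</g, "\\u003c")
  .replace(/>/g, "\\u003e")
  .replace(/&/g, "\\u0026");
// => {"state":"\u003c/script\u003e\u003cscript\u003ealert(1)\u003c/script\u003e"}
```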
@@ -1,6 +1,6 @@
 import { beautifyString } from "@/lib/utils";
 import { Clipboard, Maximize2 } from "lucide-react";
-import React, { useState } from "react";
+import React, { useMemo, useState } from "react";
 import { Button } from "../../../../../components/__legacy__/ui/button";
 import { ContentRenderer } from "../../../../../components/__legacy__/ui/render";
 import {
@@ -11,6 +11,12 @@ import {
   TableHeader,
   TableRow,
 } from "../../../../../components/__legacy__/ui/table";
+import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
+import {
+  globalRegistry,
+  OutputItem,
+} from "@/components/contextual/OutputRenderers";
+import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
 import { useToast } from "../../../../../components/molecules/Toast/use-toast";
 import ExpandableOutputDialog from "./ExpandableOutputDialog";
 
@@ -26,6 +32,9 @@ export default function DataTable({
   data,
 }: DataTableProps) {
   const { toast } = useToast();
+  const enableEnhancedOutputHandling = useGetFlag(
+    Flag.ENABLE_ENHANCED_OUTPUT_HANDLING,
+  );
   const [expandedDialog, setExpandedDialog] = useState<{
     isOpen: boolean;
     execId: string;
@@ -33,6 +42,15 @@ export default function DataTable({
     data: any[];
   } | null>(null);
 
+  // Prepare renderers for each item when enhanced mode is enabled
+  const getItemRenderer = useMemo(() => {
+    if (!enableEnhancedOutputHandling) return null;
+    return (item: unknown) => {
+      const metadata: OutputMetadata = {};
+      return globalRegistry.getRenderer(item, metadata);
+    };
+  }, [enableEnhancedOutputHandling]);
+
   const copyData = (pin: string, data: string) => {
     navigator.clipboard.writeText(data).then(() => {
       toast({
@@ -102,15 +120,31 @@ export default function DataTable({
                 <Clipboard size={18} />
               </Button>
             </div>
-            {value.map((item, index) => (
-              <React.Fragment key={index}>
-                <ContentRenderer
-                  value={item}
-                  truncateLongData={truncateLongData}
-                />
-                {index < value.length - 1 && ", "}
-              </React.Fragment>
-            ))}
+            {value.map((item, index) => {
+              const renderer = getItemRenderer?.(item);
+              if (enableEnhancedOutputHandling && renderer) {
+                const metadata: OutputMetadata = {};
+                return (
+                  <React.Fragment key={index}>
+                    <OutputItem
+                      value={item}
+                      metadata={metadata}
+                      renderer={renderer}
+                    />
+                    {index < value.length - 1 && ", "}
+                  </React.Fragment>
+                );
+              }
+              return (
+                <React.Fragment key={index}>
+                  <ContentRenderer
+                    value={item}
+                    truncateLongData={truncateLongData}
+                  />
+                  {index < value.length - 1 && ", "}
+                </React.Fragment>
+              );
+            })}
           </div>
         </TableCell>
       </TableRow>
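
Both this table and the node-outputs panel below gate the new path behind the `ENABLE_ENHANCED_OUTPUT_HANDLING` flag and fall back to the legacy `ContentRenderer` when the registry has no match. The registry internals live in `@/components/contextual/OutputRenderers` and are not shown in this diff; the sketch below is a hypothetical reading of the lookup pattern it implies (only `getRenderer(value, metadata)` is actually exercised above), not the real implementation:

```typescript
// Hypothetical renderer-registry shape, for illustration only.
interface OutputRenderer {
  canRender(value: unknown, metadata: Record<string, unknown>): boolean;
  render(value: unknown, metadata: Record<string, unknown>): unknown;
}

class RendererRegistry {
  private renderers: OutputRenderer[] = [];

  register(renderer: OutputRenderer): void {
    this.renderers.push(renderer);
  }

  // Returns the first renderer that claims the value, or undefined so the
  // caller can take the legacy ContentRenderer fallback path.
  getRenderer(
    value: unknown,
    metadata: Record<string, unknown>,
  ): OutputRenderer | undefined {
    return this.renderers.find((r) => r.canRender(value, metadata));
  }
}
```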
@@ -1,8 +1,14 @@
-import React, { useContext, useState } from "react";
+import React, { useContext, useMemo, useState } from "react";
 import { Button } from "@/components/__legacy__/ui/button";
 import { Maximize2 } from "lucide-react";
 import * as Separator from "@radix-ui/react-separator";
 import { ContentRenderer } from "@/components/__legacy__/ui/render";
+import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
+import {
+  globalRegistry,
+  OutputItem,
+} from "@/components/contextual/OutputRenderers";
+import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
 
 import { beautifyString } from "@/lib/utils";
 
@@ -21,6 +27,9 @@ export default function NodeOutputs({
   data,
 }: NodeOutputsProps) {
   const builderContext = useContext(BuilderContext);
+  const enableEnhancedOutputHandling = useGetFlag(
+    Flag.ENABLE_ENHANCED_OUTPUT_HANDLING,
+  );
 
   const [expandedDialog, setExpandedDialog] = useState<{
     isOpen: boolean;
@@ -37,6 +46,15 @@ export default function NodeOutputs({
 
   const { getNodeTitle } = builderContext;
 
+  // Prepare renderers for each item when enhanced mode is enabled
+  const getItemRenderer = useMemo(() => {
+    if (!enableEnhancedOutputHandling) return null;
+    return (item: unknown) => {
+      const metadata: OutputMetadata = {};
+      return globalRegistry.getRenderer(item, metadata);
+    };
+  }, [enableEnhancedOutputHandling]);
+
   const getBeautifiedPinName = (pin: string) => {
     if (!pin.startsWith("tools_^_")) {
       return beautifyString(pin);
@@ -87,15 +105,31 @@ export default function NodeOutputs({
           <div className="mt-2">
             <strong className="mr-2">Data:</strong>
             <div className="mt-1">
-              {dataArray.slice(0, 10).map((item, index) => (
-                <React.Fragment key={index}>
-                  <ContentRenderer
-                    value={item}
-                    truncateLongData={truncateLongData}
-                  />
-                  {index < Math.min(dataArray.length, 10) - 1 && ", "}
-                </React.Fragment>
-              ))}
+              {dataArray.slice(0, 10).map((item, index) => {
+                const renderer = getItemRenderer?.(item);
+                if (enableEnhancedOutputHandling && renderer) {
+                  const metadata: OutputMetadata = {};
+                  return (
+                    <React.Fragment key={index}>
+                      <OutputItem
+                        value={item}
+                        metadata={metadata}
+                        renderer={renderer}
+                      />
+                      {index < Math.min(dataArray.length, 10) - 1 && ", "}
+                    </React.Fragment>
+                  );
+                }
+                return (
+                  <React.Fragment key={index}>
+                    <ContentRenderer
+                      value={item}
+                      truncateLongData={truncateLongData}
+                    />
+                    {index < Math.min(dataArray.length, 10) - 1 && ", "}
+                  </React.Fragment>
+                );
+              })}
               {dataArray.length > 10 && (
                 <span style={{ color: "#888" }}>
                   <br />
@@ -11,7 +11,6 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
 import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
 import { useQueryClient } from "@tanstack/react-query";
 import { usePathname, useSearchParams } from "next/navigation";
-import { useRef } from "react";
 import { useCopilotStore } from "../../copilot-page-store";
 import { useCopilotSessionId } from "../../useCopilotSessionId";
 import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
@@ -70,41 +69,16 @@ export function useCopilotShell() {
   });
 
   const stopStream = useChatStore((s) => s.stopStream);
-  const onStreamComplete = useChatStore((s) => s.onStreamComplete);
-  const isStreaming = useCopilotStore((s) => s.isStreaming);
   const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
-  const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
-  const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);
 
-  const pendingActionRef = useRef<(() => void) | null>(null);
-
-  async function stopCurrentStream() {
-    if (!currentSessionId) return;
-
-    setIsSwitchingSession(true);
-    await new Promise<void>((resolve) => {
-      const unsubscribe = onStreamComplete((completedId) => {
-        if (completedId === currentSessionId) {
-          clearTimeout(timeout);
-          unsubscribe();
-          resolve();
-        }
-      });
-      const timeout = setTimeout(() => {
-        unsubscribe();
-        resolve();
-      }, 3000);
-      stopStream(currentSessionId);
-    });
-
-    queryClient.invalidateQueries({
-      queryKey: getGetV2GetSessionQueryKey(currentSessionId),
-    });
-    setIsSwitchingSession(false);
-  }
-
-  function selectSession(sessionId: string) {
+  function handleSessionClick(sessionId: string) {
     if (sessionId === currentSessionId) return;
 
+    // Stop current stream - SSE reconnection allows resuming later
+    if (currentSessionId) {
+      stopStream(currentSessionId);
+    }
+
     if (recentlyCreatedSessionsRef.current.has(sessionId)) {
       queryClient.invalidateQueries({
         queryKey: getGetV2GetSessionQueryKey(sessionId),
@@ -114,7 +88,12 @@ export function useCopilotShell() {
     if (isMobile) handleCloseDrawer();
   }
 
-  function startNewChat() {
+  function handleNewChatClick() {
+    // Stop current stream - SSE reconnection allows resuming later
+    if (currentSessionId) {
+      stopStream(currentSessionId);
+    }
+
     resetPagination();
     queryClient.invalidateQueries({
       queryKey: getGetV2ListSessionsQueryKey(),
@@ -123,32 +102,6 @@ export function useCopilotShell() {
     if (isMobile) handleCloseDrawer();
   }
 
-  function handleSessionClick(sessionId: string) {
-    if (sessionId === currentSessionId) return;
-
-    if (isStreaming) {
-      pendingActionRef.current = async () => {
-        await stopCurrentStream();
-        selectSession(sessionId);
-      };
-      openInterruptModal(pendingActionRef.current);
-    } else {
-      selectSession(sessionId);
-    }
-  }
-
-  function handleNewChatClick() {
-    if (isStreaming) {
-      pendingActionRef.current = async () => {
-        await stopCurrentStream();
-        startNewChat();
-      };
-      openInterruptModal(pendingActionRef.current);
-    } else {
-      startNewChat();
-    }
-  }
-
   return {
     isMobile,
     isDrawerOpen,
@@ -26,8 +26,20 @@ export function buildCopilotChatUrl(prompt: string): string {
 
 export function getQuickActions(): string[] {
   return [
-    "Show me what I can automate",
-    "Design a custom workflow",
-    "Help me with content creation",
+    "I don't know where to start, just ask me stuff",
+    "I do the same thing every week and it's killing me",
+    "Help me find where I'm wasting my time",
   ];
 }
+
+export function getInputPlaceholder(width?: number) {
+  if (!width) return "What's your role and what eats up most of your day?";
+
+  if (width < 500) {
+    return "I'm a chef and I hate...";
+  }
+  if (width <= 1080) {
+    return "What's your role and what eats up most of your day?";
+  }
+  return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
+}
@@ -6,7 +6,9 @@ import { Text } from "@/components/atoms/Text/Text";
 import { Chat } from "@/components/contextual/Chat/Chat";
 import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
 import { Dialog } from "@/components/molecules/Dialog/Dialog";
+import { useEffect, useState } from "react";
 import { useCopilotStore } from "./copilot-page-store";
+import { getInputPlaceholder } from "./helpers";
 import { useCopilotPage } from "./useCopilotPage";
 
 export default function CopilotPage() {
@@ -14,8 +16,25 @@ export default function CopilotPage() {
   const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
   const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
   const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
 
+  const [inputPlaceholder, setInputPlaceholder] = useState(
+    getInputPlaceholder(),
+  );
+
+  useEffect(() => {
+    const handleResize = () => {
+      setInputPlaceholder(getInputPlaceholder(window.innerWidth));
+    };
+
+    handleResize();
+
+    window.addEventListener("resize", handleResize);
+    return () => window.removeEventListener("resize", handleResize);
+  }, []);
+
   const { greetingName, quickActions, isLoading, hasSession, initialPrompt } =
     state;
 
   const {
     handleQuickAction,
     startChatWithPrompt,
@@ -73,7 +92,7 @@ export default function CopilotPage() {
   }
 
   return (
-    <div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
+    <div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-3 py-5 md:px-6 md:py-10">
       <div className="w-full text-center">
         {isLoading ? (
           <div className="mx-auto max-w-2xl">
@@ -90,25 +109,25 @@ export default function CopilotPage() {
           </div>
         ) : (
           <>
-            <div className="mx-auto max-w-2xl">
+            <div className="mx-auto max-w-3xl">
               <Text
                 variant="h3"
-                className="mb-3 !text-[1.375rem] text-zinc-700"
+                className="mb-1 !text-[1.375rem] text-zinc-700"
               >
                 Hey, <span className="text-violet-600">{greetingName}</span>
               </Text>
               <Text variant="h3" className="mb-8 !font-normal">
-                What do you want to automate?
+                Tell me about your work — I'll find what to automate.
               </Text>
 
               <div className="mb-6">
                 <ChatInput
                   onSend={startChatWithPrompt}
-                  placeholder='You can search or just ask - e.g. "create a blog post outline"'
+                  placeholder={inputPlaceholder}
                 />
               </div>
             </div>
-            <div className="flex flex-nowrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
+            <div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
               {quickActions.map((action) => (
                 <Button
                   key={action}
@@ -116,7 +135,7 @@ export default function CopilotPage() {
                   variant="outline"
                   size="small"
                   onClick={() => handleQuickAction(action)}
-                  className="h-auto shrink-0 border-zinc-600 !px-4 !py-2 text-[1rem] text-zinc-600"
+                  className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
                 >
                   {action}
                 </Button>
@@ -0,0 +1,81 @@
import { environment } from "@/services/environment";
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
import { NextRequest } from "next/server";

/**
 * SSE Proxy for task stream reconnection.
 *
 * This endpoint allows clients to reconnect to an ongoing or recently completed
 * background task's stream. It replays missed messages from Redis Streams and
 * subscribes to live updates if the task is still running.
 *
 * Client contract:
 * 1. When receiving an operation_started event, store the task_id
 * 2. To reconnect: GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
 * 3. Messages are replayed from the last_message_id position
 * 4. Stream ends when "finish" event is received
 */
export async function GET(
  request: NextRequest,
  { params }: { params: Promise<{ taskId: string }> },
) {
  const { taskId } = await params;
  const searchParams = request.nextUrl.searchParams;
  const lastMessageId = searchParams.get("last_message_id") || "0-0";

  try {
    // Get auth token from server-side session
    const token = await getServerAuthToken();

    // Build backend URL
    const backendUrl = environment.getAGPTServerBaseUrl();
    const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl);
    streamUrl.searchParams.set("last_message_id", lastMessageId);

    // Forward request to backend with auth header
    const headers: Record<string, string> = {
      Accept: "text/event-stream",
      "Cache-Control": "no-cache",
      Connection: "keep-alive",
    };

    if (token) {
      headers["Authorization"] = `Bearer ${token}`;
    }

    const response = await fetch(streamUrl.toString(), {
      method: "GET",
      headers,
    });

    if (!response.ok) {
      const error = await response.text();
      return new Response(error, {
        status: response.status,
        headers: { "Content-Type": "application/json" },
      });
    }

    // Return the SSE stream directly
    return new Response(response.body, {
      headers: {
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache, no-transform",
        Connection: "keep-alive",
        "X-Accel-Buffering": "no",
      },
    });
  } catch (error) {
    console.error("Task stream proxy error:", error);
    return new Response(
      JSON.stringify({
        error: "Failed to connect to task stream",
        detail: error instanceof Error ? error.message : String(error),
      }),
      {
        status: 500,
        headers: { "Content-Type": "application/json" },
      },
    );
  }
}
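
Because this proxy pipes the backend's SSE body through unchanged, a browser client can consume it with a plain `fetch` and a stream reader. A minimal sketch (the parsing assumes standard `data:`-framed SSE events; the logging is a placeholder for real chunk handling):

```typescript
// Minimal sketch of consuming the proxied task stream from the browser.
// taskId and lastMessageId come from the operation_started event.
async function readTaskStream(taskId: string, lastMessageId = "0-0") {
  const res = await fetch(
    `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`,
    { headers: { Accept: "text/event-stream" } },
  );
  if (!res.ok || !res.body) throw new Error(`Stream failed: ${res.status}`);

  const reader = res.body.getReader();
  const decoder = new TextDecoder();
  let buffer = "";

  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    buffer += decoder.decode(value, { stream: true });

    // SSE events are separated by a blank line; "data:" lines carry payloads.
    const events = buffer.split("\n\n");
    buffer = events.pop() ?? "";
    for (const event of events) {
      for (const line of event.split("\n")) {
        if (line.startsWith("data:")) {
          console.log("chunk:", line.slice(5).trim());
        }
      }
    }
  }
}
```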

@@ -917,6 +917,28 @@
       "security": [{ "HTTPBearerJWT": [] }]
     }
   },
+  "/api/chat/config/ttl": {
+    "get": {
+      "tags": ["v2", "chat", "chat"],
+      "summary": "Get Ttl Config",
+      "description": "Get the stream TTL configuration.\n\nReturns the Time-To-Live settings for chat streams, which determines\nhow long clients can reconnect to an active stream.\n\nReturns:\n dict: TTL configuration with seconds and milliseconds values.",
+      "operationId": "getV2GetTtlConfig",
+      "responses": {
+        "200": {
+          "description": "Successful Response",
+          "content": {
+            "application/json": {
+              "schema": {
+                "additionalProperties": true,
+                "type": "object",
+                "title": "Response Getv2Getttlconfig"
+              }
+            }
+          }
+        }
+      }
+    }
+  },
   "/api/chat/health": {
     "get": {
       "tags": ["v2", "chat", "chat"],
@@ -939,6 +961,63 @@
       }
     }
   },
+  "/api/chat/operations/{operation_id}/complete": {
+    "post": {
+      "tags": ["v2", "chat", "chat"],
+      "summary": "Complete Operation",
+      "description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.",
+      "operationId": "postV2CompleteOperation",
+      "parameters": [
+        {
+          "name": "operation_id",
+          "in": "path",
+          "required": true,
+          "schema": { "type": "string", "title": "Operation Id" }
+        },
+        {
+          "name": "x-api-key",
+          "in": "header",
+          "required": false,
+          "schema": {
+            "anyOf": [{ "type": "string" }, { "type": "null" }],
+            "title": "X-Api-Key"
+          }
+        }
+      ],
+      "requestBody": {
+        "required": true,
+        "content": {
+          "application/json": {
+            "schema": {
+              "$ref": "#/components/schemas/OperationCompleteRequest"
+            }
+          }
+        }
+      },
+      "responses": {
+        "200": {
+          "description": "Successful Response",
+          "content": {
+            "application/json": {
+              "schema": {
+                "type": "object",
+                "additionalProperties": true,
+                "title": "Response Postv2Completeoperation"
+              }
+            }
+          }
+        },
+        "422": {
+          "description": "Validation Error",
+          "content": {
+            "application/json": {
+              "schema": { "$ref": "#/components/schemas/HTTPValidationError" }
+            }
+          }
+        }
+      }
+    }
+  },
   "/api/chat/sessions": {
     "get": {
       "tags": ["v2", "chat", "chat"],
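
For orientation, a sketch of how an external service such as the Agent Generator might call this completion webhook once its work finishes. The host and API-key variable names here are placeholders, not real configuration keys:

```typescript
// Hypothetical caller of the completion webhook defined above.
async function notifyOperationComplete(
  operationId: string,
  result: Record<string, unknown>,
): Promise<void> {
  const res = await fetch(
    `${process.env.BACKEND_URL}/api/chat/operations/${operationId}/complete`,
    {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        "x-api-key": process.env.INTERNAL_API_KEY ?? "",
      },
      // Matches OperationCompleteRequest: success is required;
      // result and error are optional.
      body: JSON.stringify({ success: true, result }),
    },
  );
  if (!res.ok) {
    throw new Error(`Completion webhook failed: ${res.status}`);
  }
}
```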
@@ -1022,7 +1101,7 @@
     "get": {
       "tags": ["v2", "chat", "chat"],
       "summary": "Get Session",
-      "description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.",
+      "description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.",
       "operationId": "getV2GetSession",
       "security": [{ "HTTPBearerJWT": [] }],
       "parameters": [
@@ -1157,7 +1236,7 @@
     "post": {
       "tags": ["v2", "chat", "chat"],
       "summary": "Stream Chat Post",
-      "description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
+      "description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. First chunk is a \"start\" event\n containing the task_id for reconnection.",
       "operationId": "postV2StreamChatPost",
       "security": [{ "HTTPBearerJWT": [] }],
       "parameters": [
@@ -1195,6 +1274,94 @@
       }
     }
   },
+  "/api/chat/tasks/{task_id}": {
+    "get": {
+      "tags": ["v2", "chat", "chat"],
+      "summary": "Get Task Status",
+      "description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
+      "operationId": "getV2GetTaskStatus",
+      "security": [{ "HTTPBearerJWT": [] }],
+      "parameters": [
+        {
+          "name": "task_id",
+          "in": "path",
+          "required": true,
+          "schema": { "type": "string", "title": "Task Id" }
+        }
+      ],
+      "responses": {
+        "200": {
+          "description": "Successful Response",
+          "content": {
+            "application/json": {
+              "schema": {
+                "type": "object",
+                "additionalProperties": true,
+                "title": "Response Getv2Gettaskstatus"
+              }
+            }
+          }
+        },
+        "401": {
+          "$ref": "#/components/responses/HTTP401NotAuthenticatedError"
+        },
+        "422": {
+          "description": "Validation Error",
+          "content": {
+            "application/json": {
+              "schema": { "$ref": "#/components/schemas/HTTPValidationError" }
+            }
+          }
+        }
+      }
+    }
+  },
+  "/api/chat/tasks/{task_id}/stream": {
+    "get": {
+      "tags": ["v2", "chat", "chat"],
+      "summary": "Stream Task",
+      "description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.",
+      "operationId": "getV2StreamTask",
+      "security": [{ "HTTPBearerJWT": [] }],
+      "parameters": [
+        {
+          "name": "task_id",
+          "in": "path",
+          "required": true,
+          "schema": { "type": "string", "title": "Task Id" }
+        },
+        {
+          "name": "last_message_id",
+          "in": "query",
+          "required": false,
+          "schema": {
+            "type": "string",
+            "description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
+            "default": "0-0",
+            "title": "Last Message Id"
+          },
+          "description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay."
+        }
+      ],
+      "responses": {
+        "200": {
+          "description": "Successful Response",
+          "content": { "application/json": { "schema": {} } }
+        },
+        "401": {
+          "$ref": "#/components/responses/HTTP401NotAuthenticatedError"
+        },
+        "422": {
+          "description": "Validation Error",
+          "content": {
+            "application/json": {
+              "schema": { "$ref": "#/components/schemas/HTTPValidationError" }
+            }
+          }
+        }
+      }
+    }
+  },
   "/api/credits": {
     "get": {
       "tags": ["v1", "credits"],
@@ -6168,6 +6335,18 @@
       "title": "AccuracyTrendsResponse",
       "description": "Response model for accuracy trends and alerts."
     },
+    "ActiveStreamInfo": {
+      "properties": {
+        "task_id": { "type": "string", "title": "Task Id" },
+        "last_message_id": { "type": "string", "title": "Last Message Id" },
+        "operation_id": { "type": "string", "title": "Operation Id" },
+        "tool_name": { "type": "string", "title": "Tool Name" }
+      },
+      "type": "object",
+      "required": ["task_id", "last_message_id", "operation_id", "tool_name"],
+      "title": "ActiveStreamInfo",
+      "description": "Information about an active stream for reconnection."
+    },
     "AddUserCreditsResponse": {
       "properties": {
         "new_balance": { "type": "integer", "title": "New Balance" },
@@ -8823,6 +9002,27 @@
       ],
       "title": "OnboardingStep"
     },
+    "OperationCompleteRequest": {
+      "properties": {
+        "success": { "type": "boolean", "title": "Success" },
+        "result": {
+          "anyOf": [
+            { "additionalProperties": true, "type": "object" },
+            { "type": "string" },
+            { "type": "null" }
+          ],
+          "title": "Result"
+        },
+        "error": {
+          "anyOf": [{ "type": "string" }, { "type": "null" }],
+          "title": "Error"
+        }
+      },
+      "type": "object",
+      "required": ["success"],
+      "title": "OperationCompleteRequest",
+      "description": "Request model for external completion webhook."
+    },
     "Pagination": {
       "properties": {
         "total_items": {
@@ -9678,6 +9878,12 @@
         "items": { "additionalProperties": true, "type": "object" },
         "type": "array",
         "title": "Messages"
+      },
+      "active_stream": {
+        "anyOf": [
+          { "$ref": "#/components/schemas/ActiveStreamInfo" },
+          { "type": "null" }
+        ]
       }
     },
     "type": "object",
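
The new `active_stream` field on the session response is what lets a freshly loaded page discover an in-flight task without any client-side state. A sketch of the check a client might run after fetching a session (the types mirror `ActiveStreamInfo` above; `reconnect` stands in for the store action introduced later in this changeset):

```typescript
interface ActiveStreamInfo {
  task_id: string;
  last_message_id: string;
  operation_id: string;
  tool_name: string;
}

interface SessionDetail {
  active_stream?: ActiveStreamInfo | null;
}

// After fetching the session, resume the stream if one is still active.
function maybeResume(
  session: SessionDetail,
  reconnect: (taskId: string, lastMessageId: string) => Promise<void>,
) {
  const stream = session.active_stream;
  if (!stream) return; // nothing running; render the stored messages as usual
  void reconnect(stream.task_id, stream.last_message_id);
}
```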
@@ -22,7 +22,7 @@ const isValidVideoUrl = (url: string): boolean => {
   if (url.startsWith("data:video")) {
     return true;
   }
-  const videoExtensions = /\.(mp4|webm|ogg)$/i;
+  const videoExtensions = /\.(mp4|webm|ogg|mov|avi|mkv|m4v)$/i;
   const youtubeRegex = /^(https?:\/\/)?(www\.)?(youtube\.com|youtu\.?be)\/.+$/;
   const cleanedUrl = url.split("?")[0];
   return (
@@ -44,11 +44,29 @@ const isValidAudioUrl = (url: string): boolean => {
   if (url.startsWith("data:audio")) {
     return true;
   }
-  const audioExtensions = /\.(mp3|wav)$/i;
+  const audioExtensions = /\.(mp3|wav|ogg|m4a|aac|flac)$/i;
   const cleanedUrl = url.split("?")[0];
   return isValidMediaUri(url) && audioExtensions.test(cleanedUrl);
 };
 
+const getVideoMimeType = (url: string): string => {
+  if (url.startsWith("data:video/")) {
+    const match = url.match(/^data:(video\/[^;]+)/);
+    return match?.[1] || "video/mp4";
+  }
+  const extension = url.split("?")[0].split(".").pop()?.toLowerCase();
+  const mimeMap: Record<string, string> = {
+    mp4: "video/mp4",
+    webm: "video/webm",
+    ogg: "video/ogg",
+    mov: "video/quicktime",
+    avi: "video/x-msvideo",
+    mkv: "video/x-matroska",
+    m4v: "video/mp4",
+  };
+  return mimeMap[extension || ""] || "video/mp4";
+};
+
 const VideoRenderer: React.FC<{ videoUrl: string }> = ({ videoUrl }) => {
   const videoId = getYouTubeVideoId(videoUrl);
   return (
@@ -63,7 +81,7 @@ const VideoRenderer: React.FC<{ videoUrl: string }> = ({ videoUrl }) => {
       ></iframe>
     ) : (
       <video controls width="100%" height="315">
-        <source src={videoUrl} type="video/mp4" />
+        <source src={videoUrl} type={getVideoMimeType(videoUrl)} />
         Your browser does not support the video tag.
       </video>
     )}
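
A few representative inputs and what the new helper resolves them to, read directly off the mapping above:

```typescript
// Assuming getVideoMimeType from the diff above is in scope.
getVideoMimeType("https://example.com/clip.mov");      // "video/quicktime"
getVideoMimeType("https://example.com/clip.MKV?t=42"); // "video/x-matroska" (query stripped, case-insensitive)
getVideoMimeType("data:video/webm;base64,GkXf...");    // "video/webm" (taken from the data URI itself)
getVideoMimeType("https://example.com/clip.unknown");  // "video/mp4" (fallback)
```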

@@ -1,7 +1,6 @@
 "use client";
 
 import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
-import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store";
 import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
 import { Text } from "@/components/atoms/Text/Text";
 import { cn } from "@/lib/utils";
@@ -25,8 +24,8 @@ export function Chat({
 }: ChatProps) {
   const { urlSessionId } = useCopilotSessionId();
   const hasHandledNotFoundRef = useRef(false);
-  const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession);
   const {
+    session,
     messages,
     isLoading,
     isCreating,
@@ -38,6 +37,18 @@ export function Chat({
     startPollingForOperation,
   } = useChat({ urlSessionId });
 
+  // Extract active stream info for reconnection
+  const activeStream = (
+    session as {
+      active_stream?: {
+        task_id: string;
+        last_message_id: string;
+        operation_id: string;
+        tool_name: string;
+      };
+    }
+  )?.active_stream;
+
   useEffect(() => {
     if (!onSessionNotFound) return;
     if (!urlSessionId) return;
@@ -53,8 +64,7 @@ export function Chat({
     isCreating,
   ]);
 
-  const shouldShowLoader =
-    (showLoader && (isLoading || isCreating)) || isSwitchingSession;
+  const shouldShowLoader = showLoader && (isLoading || isCreating);
 
   return (
     <div className={cn("flex h-full flex-col", className)}>
@@ -66,21 +76,19 @@ export function Chat({
           <div className="flex flex-col items-center gap-3">
             <LoadingSpinner size="large" className="text-neutral-400" />
             <Text variant="body" className="text-zinc-500">
-              {isSwitchingSession
-                ? "Switching chat..."
-                : "Loading your chat..."}
+              Loading your chat...
             </Text>
           </div>
         </div>
       )}
 
       {/* Error State */}
-      {error && !isLoading && !isSwitchingSession && (
+      {error && !isLoading && (
         <ChatErrorState error={error} onRetry={createSession} />
       )}
 
       {/* Session Content */}
-      {sessionId && !isLoading && !error && !isSwitchingSession && (
+      {sessionId && !isLoading && !error && (
         <ChatContainer
           sessionId={sessionId}
           initialMessages={messages}
@@ -88,6 +96,16 @@ export function Chat({
           className="flex-1"
           onStreamingChange={onStreamingChange}
           onOperationStarted={startPollingForOperation}
+          activeStream={
+            activeStream
+              ? {
+                  taskId: activeStream.task_id,
+                  lastMessageId: activeStream.last_message_id,
+                  operationId: activeStream.operation_id,
+                  toolName: activeStream.tool_name,
+                }
+              : undefined
+          }
         />
       )}
     </main>

@@ -0,0 +1,159 @@
# SSE Reconnection Contract for Long-Running Operations

This document describes the client-side contract for handling SSE (Server-Sent Events) disconnections and reconnecting to long-running background tasks.

## Overview

When a user triggers a long-running operation (like agent generation), the backend:

1. Spawns a background task that survives SSE disconnections
2. Returns an `operation_started` response with a `task_id`
3. Stores stream messages in Redis Streams for replay

Clients can reconnect to the task stream at any time to receive missed messages.

## Client-Side Flow

### 1. Receiving Operation Started

When you receive an `operation_started` tool response:

```typescript
// The response includes a task_id for reconnection
{
  type: "operation_started",
  tool_name: "generate_agent",
  operation_id: "uuid-...",
  task_id: "task-uuid-...", // <-- Store this for reconnection
  message: "Operation started. You can close this tab."
}
```

### 2. Storing Task Info

Use the chat store to track the active task:

```typescript
import { useChatStore } from "./chat-store";

// When operation_started is received:
useChatStore.getState().setActiveTask(sessionId, {
  taskId: response.task_id,
  operationId: response.operation_id,
  toolName: response.tool_name,
  lastMessageId: "0",
});
```

### 3. Reconnecting to a Task

To reconnect (e.g., after page refresh or tab reopen):

```typescript
const { reconnectToTask, getActiveTask } = useChatStore.getState();

// Check if there's an active task for this session
const activeTask = getActiveTask(sessionId);

if (activeTask) {
  // Reconnect to the task stream
  await reconnectToTask(
    sessionId,
    activeTask.taskId,
    activeTask.lastMessageId, // Resume from last position
    (chunk) => {
      // Handle incoming chunks
      console.log("Received chunk:", chunk);
    },
  );
}
```

### 4. Tracking Message Position

To enable precise replay, update the last message ID as chunks arrive:

```typescript
const { updateTaskLastMessageId } = useChatStore.getState();

function handleChunk(chunk: StreamChunk) {
  // If chunk has an index/id, track it
  if (chunk.idx !== undefined) {
    updateTaskLastMessageId(sessionId, String(chunk.idx));
  }
}
```

## API Endpoints

### Task Stream Reconnection

```
GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
```

- `taskId`: The task ID from `operation_started`
- `last_message_id`: Last received message index (default: "0" for full replay)

Returns: SSE stream of missed messages + live updates

## Chunk Types

The reconnected stream follows the same Vercel AI SDK protocol:

| Type                    | Description             |
| ----------------------- | ----------------------- |
| `start`                 | Message lifecycle start |
| `text-delta`            | Streaming text content  |
| `text-end`              | Text block completed    |
| `tool-output-available` | Tool result available   |
| `finish`                | Stream completed        |
| `error`                 | Error occurred          |

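As a sketch, a client might fan these chunk types out like this (assuming the `StreamChunk` shape from the examples above carries a `type` discriminator; the handler callbacks are placeholders, not chat-store exports):

```typescript
// Placeholder handlers - wire these to your own UI.
type ChunkHandlers = {
  onText: (chunk: StreamChunk) => void;
  onToolOutput: (chunk: StreamChunk) => void;
  onFinish: () => void;
  onError: (chunk: StreamChunk) => void;
};

function dispatchChunk(chunk: StreamChunk, handlers: ChunkHandlers) {
  switch (chunk.type) {
    case "text-delta":
      handlers.onText(chunk); // append streamed text to the current message
      break;
    case "tool-output-available":
      handlers.onToolOutput(chunk); // tool result is ready to render
      break;
    case "finish":
      handlers.onFinish(); // stream complete; clear the active task
      break;
    case "error":
      handlers.onError(chunk); // surface the failure in the UI
      break;
    default:
      break; // "start" and "text-end" need no special handling here
  }
}
```
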
## Error Handling

If reconnection fails:

1. Check whether the task still exists (it may have expired; the default TTL is 1 hour)
2. Fall back to polling the session for the final state
3. Show an appropriate UI message to the user (see the sketch below)

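A sketch tying these steps together, assuming `reconnectToTask` rejects when the task stream is unavailable (`fetchSession` and `showInfo` are placeholder helpers, not store exports):

```typescript
async function resumeOrFallBack(sessionId: string, task: ActiveTaskInfo) {
  try {
    await useChatStore
      .getState()
      .reconnectToTask(sessionId, task.taskId, task.lastMessageId);
  } catch {
    // Task stream is gone (expired TTL, 404/410, etc.) - clear it and
    // fall back to the session record, which holds the final messages.
    useChatStore.getState().clearActiveTask(sessionId);
    const session = await fetchSession(sessionId); // placeholder fetch helper
    if (!session.active_stream) {
      showInfo("This operation finished while you were away."); // placeholder UI hook
    }
  }
}
```
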
## Persistence Considerations

For robust reconnection across browser restarts:

```typescript
// Store in localStorage/sessionStorage
const ACTIVE_TASKS_KEY = "chat_active_tasks";

function persistActiveTask(sessionId: string, task: ActiveTaskInfo) {
  const tasks = JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
  tasks[sessionId] = task;
  localStorage.setItem(ACTIVE_TASKS_KEY, JSON.stringify(tasks));
}

function loadPersistedTasks(): Record<string, ActiveTaskInfo> {
  return JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
}
```

## Backend Configuration

The following backend settings affect reconnection behavior:

| Setting             | Default | Description                        |
| ------------------- | ------- | ---------------------------------- |
| `stream_ttl`        | 3600s   | How long streams are kept in Redis |
| `stream_max_length` | 1000    | Max messages per stream            |

## Testing

To test reconnection locally:

1. Start a long-running operation (e.g., agent generation)
2. Note the `task_id` from the `operation_started` response
3. Close the browser tab
4. Reopen and call `reconnectToTask` with the saved `task_id`
5. Verify that missed messages are replayed

See the main README for full local development setup.
@@ -0,0 +1,16 @@
/**
 * Constants for the chat system.
 *
 * Centralizes magic strings and values used across chat components.
 */

// LocalStorage keys
export const STORAGE_KEY_ACTIVE_TASKS = "chat_active_tasks";

// Redis Stream IDs
export const INITIAL_MESSAGE_ID = "0";
export const INITIAL_STREAM_ID = "0-0";

// TTL values (in milliseconds)
export const COMPLETED_STREAM_TTL_MS = 5 * 60 * 1000; // 5 minutes
export const ACTIVE_TASK_TTL_MS = 60 * 60 * 1000; // 1 hour
@@ -1,6 +1,12 @@
 "use client";

 import { create } from "zustand";
+import {
+  ACTIVE_TASK_TTL_MS,
+  COMPLETED_STREAM_TTL_MS,
+  INITIAL_STREAM_ID,
+  STORAGE_KEY_ACTIVE_TASKS,
+} from "./chat-constants";
 import type {
   ActiveStream,
   StreamChunk,
@@ -8,15 +14,59 @@ import type {
   StreamResult,
   StreamStatus,
 } from "./chat-types";
-import { executeStream } from "./stream-executor";
+import { executeStream, executeTaskReconnect } from "./stream-executor";

-const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes
+export interface ActiveTaskInfo {
+  taskId: string;
+  sessionId: string;
+  operationId: string;
+  toolName: string;
+  lastMessageId: string;
+  startedAt: number;
+}
+
+/** Load active tasks from localStorage */
+function loadPersistedTasks(): Map<string, ActiveTaskInfo> {
+  if (typeof window === "undefined") return new Map();
+  try {
+    const stored = localStorage.getItem(STORAGE_KEY_ACTIVE_TASKS);
+    if (!stored) return new Map();
+    const parsed = JSON.parse(stored) as Record<string, ActiveTaskInfo>;
+    const now = Date.now();
+    const tasks = new Map<string, ActiveTaskInfo>();
+    // Filter out expired tasks
+    for (const [sessionId, task] of Object.entries(parsed)) {
+      if (now - task.startedAt < ACTIVE_TASK_TTL_MS) {
+        tasks.set(sessionId, task);
+      }
+    }
+    return tasks;
+  } catch {
+    return new Map();
+  }
+}
+
+/** Save active tasks to localStorage */
+function persistTasks(tasks: Map<string, ActiveTaskInfo>): void {
+  if (typeof window === "undefined") return;
+  try {
+    const obj: Record<string, ActiveTaskInfo> = {};
+    for (const [sessionId, task] of tasks) {
+      obj[sessionId] = task;
+    }
+    localStorage.setItem(STORAGE_KEY_ACTIVE_TASKS, JSON.stringify(obj));
+  } catch {
+    // Ignore storage errors
+  }
+}

 interface ChatStoreState {
   activeStreams: Map<string, ActiveStream>;
   completedStreams: Map<string, StreamResult>;
   activeSessions: Set<string>;
   streamCompleteCallbacks: Set<StreamCompleteCallback>;
+  /** Active tasks for SSE reconnection - keyed by sessionId */
+  activeTasks: Map<string, ActiveTaskInfo>;
 }

 interface ChatStoreActions {
@@ -41,6 +91,24 @@ interface ChatStoreActions {
   unregisterActiveSession: (sessionId: string) => void;
   isSessionActive: (sessionId: string) => boolean;
   onStreamComplete: (callback: StreamCompleteCallback) => () => void;
+  /** Track active task for SSE reconnection */
+  setActiveTask: (
+    sessionId: string,
+    taskInfo: Omit<ActiveTaskInfo, "sessionId" | "startedAt">,
+  ) => void;
+  /** Get active task for a session */
+  getActiveTask: (sessionId: string) => ActiveTaskInfo | undefined;
+  /** Clear active task when operation completes */
+  clearActiveTask: (sessionId: string) => void;
+  /** Reconnect to an existing task stream */
+  reconnectToTask: (
+    sessionId: string,
+    taskId: string,
+    lastMessageId?: string,
+    onChunk?: (chunk: StreamChunk) => void,
+  ) => Promise<void>;
+  /** Update last message ID for a task (for tracking replay position) */
+  updateTaskLastMessageId: (sessionId: string, lastMessageId: string) => void;
 }

 type ChatStore = ChatStoreState & ChatStoreActions;
@@ -64,18 +132,126 @@ function cleanupExpiredStreams(
   const now = Date.now();
   const cleaned = new Map(completedStreams);
   for (const [sessionId, result] of cleaned) {
-    if (now - result.completedAt > COMPLETED_STREAM_TTL) {
+    if (now - result.completedAt > COMPLETED_STREAM_TTL_MS) {
       cleaned.delete(sessionId);
     }
   }
   return cleaned;
 }

+/**
+ * Finalize a stream by moving it from activeStreams to completedStreams.
+ * Also handles cleanup and notifications.
+ */
+function finalizeStream(
+  sessionId: string,
+  stream: ActiveStream,
+  onChunk: ((chunk: StreamChunk) => void) | undefined,
+  get: () => ChatStoreState & ChatStoreActions,
+  set: (state: Partial<ChatStoreState>) => void,
+): void {
+  if (onChunk) stream.onChunkCallbacks.delete(onChunk);
+
+  if (stream.status !== "streaming") {
+    const currentState = get();
+    const finalActiveStreams = new Map(currentState.activeStreams);
+    let finalCompletedStreams = new Map(currentState.completedStreams);
+
+    const storedStream = finalActiveStreams.get(sessionId);
+    if (storedStream === stream) {
+      const result: StreamResult = {
+        sessionId,
+        status: stream.status,
+        chunks: stream.chunks,
+        completedAt: Date.now(),
+        error: stream.error,
+      };
+      finalCompletedStreams.set(sessionId, result);
+      finalActiveStreams.delete(sessionId);
+      finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
+      set({
+        activeStreams: finalActiveStreams,
+        completedStreams: finalCompletedStreams,
+      });
+
+      if (stream.status === "completed" || stream.status === "error") {
+        notifyStreamComplete(currentState.streamCompleteCallbacks, sessionId);
+      }
+    }
+  }
+}
+
+/**
+ * Clean up an existing stream for a session and move it to completed streams.
+ * Returns updated maps for both active and completed streams.
+ */
+function cleanupExistingStream(
+  sessionId: string,
+  activeStreams: Map<string, ActiveStream>,
+  completedStreams: Map<string, StreamResult>,
+  callbacks: Set<StreamCompleteCallback>,
+): {
+  activeStreams: Map<string, ActiveStream>;
+  completedStreams: Map<string, StreamResult>;
+} {
+  const newActiveStreams = new Map(activeStreams);
+  let newCompletedStreams = new Map(completedStreams);
+
+  const existingStream = newActiveStreams.get(sessionId);
+  if (existingStream) {
+    existingStream.abortController.abort();
+    const normalizedStatus =
+      existingStream.status === "streaming"
+        ? "completed"
+        : existingStream.status;
+    const result: StreamResult = {
+      sessionId,
+      status: normalizedStatus,
+      chunks: existingStream.chunks,
+      completedAt: Date.now(),
+      error: existingStream.error,
+    };
+    newCompletedStreams.set(sessionId, result);
+    newActiveStreams.delete(sessionId);
+    newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
+    if (normalizedStatus === "completed" || normalizedStatus === "error") {
+      notifyStreamComplete(callbacks, sessionId);
+    }
+  }
+
+  return {
+    activeStreams: newActiveStreams,
+    completedStreams: newCompletedStreams,
+  };
+}
+
+/**
+ * Create a new active stream with initial state.
+ */
+function createActiveStream(
+  sessionId: string,
+  onChunk?: (chunk: StreamChunk) => void,
+): ActiveStream {
+  const abortController = new AbortController();
+  const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
+  if (onChunk) initialCallbacks.add(onChunk);
+
+  return {
+    sessionId,
+    abortController,
+    status: "streaming",
+    startedAt: Date.now(),
+    chunks: [],
+    onChunkCallbacks: initialCallbacks,
+  };
+}
+
 export const useChatStore = create<ChatStore>((set, get) => ({
   activeStreams: new Map(),
   completedStreams: new Map(),
   activeSessions: new Set(),
   streamCompleteCallbacks: new Set(),
+  activeTasks: loadPersistedTasks(),

   startStream: async function startStream(
     sessionId,
@@ -85,45 +261,21 @@ export const useChatStore = create<ChatStore>((set, get) => ({
     onChunk,
   ) {
     const state = get();
-    const newActiveStreams = new Map(state.activeStreams);
-    let newCompletedStreams = new Map(state.completedStreams);
     const callbacks = state.streamCompleteCallbacks;

-    const existingStream = newActiveStreams.get(sessionId);
-    if (existingStream) {
-      existingStream.abortController.abort();
-      const normalizedStatus =
-        existingStream.status === "streaming"
-          ? "completed"
-          : existingStream.status;
-      const result: StreamResult = {
-        sessionId,
-        status: normalizedStatus,
-        chunks: existingStream.chunks,
-        completedAt: Date.now(),
-        error: existingStream.error,
-      };
-      newCompletedStreams.set(sessionId, result);
-      newActiveStreams.delete(sessionId);
-      newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
-      if (normalizedStatus === "completed" || normalizedStatus === "error") {
-        notifyStreamComplete(callbacks, sessionId);
-      }
-    }
-
-    const abortController = new AbortController();
-    const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
-    if (onChunk) initialCallbacks.add(onChunk);
-
-    const stream: ActiveStream = {
-      sessionId,
-      abortController,
-      status: "streaming",
-      startedAt: Date.now(),
-      chunks: [],
-      onChunkCallbacks: initialCallbacks,
-    };
+    // Clean up any existing stream for this session
+    const {
+      activeStreams: newActiveStreams,
+      completedStreams: newCompletedStreams,
+    } = cleanupExistingStream(
+      sessionId,
+      state.activeStreams,
+      state.completedStreams,
+      callbacks,
+    );
+
+    // Create new stream
+    const stream = createActiveStream(sessionId, onChunk);
     newActiveStreams.set(sessionId, stream);
     set({
       activeStreams: newActiveStreams,
@@ -133,36 +285,7 @@ export const useChatStore = create<ChatStore>((set, get) => ({
     try {
       await executeStream(stream, message, isUserMessage, context);
     } finally {
-      if (onChunk) stream.onChunkCallbacks.delete(onChunk);
-      if (stream.status !== "streaming") {
-        const currentState = get();
-        const finalActiveStreams = new Map(currentState.activeStreams);
-        let finalCompletedStreams = new Map(currentState.completedStreams);
-
-        const storedStream = finalActiveStreams.get(sessionId);
-        if (storedStream === stream) {
-          const result: StreamResult = {
-            sessionId,
-            status: stream.status,
-            chunks: stream.chunks,
-            completedAt: Date.now(),
-            error: stream.error,
-          };
-          finalCompletedStreams.set(sessionId, result);
-          finalActiveStreams.delete(sessionId);
-          finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
-          set({
-            activeStreams: finalActiveStreams,
-            completedStreams: finalCompletedStreams,
-          });
-          if (stream.status === "completed" || stream.status === "error") {
-            notifyStreamComplete(
-              currentState.streamCompleteCallbacks,
-              sessionId,
-            );
-          }
-        }
-      }
+      finalizeStream(sessionId, stream, onChunk, get, set);
     }
   },

@@ -286,4 +409,93 @@ export const useChatStore = create<ChatStore>((set, get) => ({
       set({ streamCompleteCallbacks: cleanedCallbacks });
     };
   },
+
+  setActiveTask: function setActiveTask(sessionId, taskInfo) {
+    const state = get();
+    const newActiveTasks = new Map(state.activeTasks);
+    newActiveTasks.set(sessionId, {
+      ...taskInfo,
+      sessionId,
+      startedAt: Date.now(),
+    });
+    set({ activeTasks: newActiveTasks });
+    persistTasks(newActiveTasks);
+  },
+
+  getActiveTask: function getActiveTask(sessionId) {
+    return get().activeTasks.get(sessionId);
+  },
+
+  clearActiveTask: function clearActiveTask(sessionId) {
+    const state = get();
+    if (!state.activeTasks.has(sessionId)) return;
+
+    const newActiveTasks = new Map(state.activeTasks);
+    newActiveTasks.delete(sessionId);
+    set({ activeTasks: newActiveTasks });
+    persistTasks(newActiveTasks);
+  },
+
+  reconnectToTask: async function reconnectToTask(
+    sessionId,
+    taskId,
+    lastMessageId = INITIAL_STREAM_ID,
+    onChunk,
+  ) {
+    const state = get();
+    const callbacks = state.streamCompleteCallbacks;
+
+    // Clean up any existing stream for this session
+    const {
+      activeStreams: newActiveStreams,
+      completedStreams: newCompletedStreams,
+    } = cleanupExistingStream(
+      sessionId,
+      state.activeStreams,
+      state.completedStreams,
+      callbacks,
+    );
+
+    // Create new stream for reconnection
+    const stream = createActiveStream(sessionId, onChunk);
+    newActiveStreams.set(sessionId, stream);
+    set({
+      activeStreams: newActiveStreams,
+      completedStreams: newCompletedStreams,
+    });
+
+    try {
+      await executeTaskReconnect(stream, taskId, lastMessageId);
+    } finally {
+      finalizeStream(sessionId, stream, onChunk, get, set);
+
+      // Clear active task on completion
+      if (stream.status === "completed" || stream.status === "error") {
+        const taskState = get();
+        if (taskState.activeTasks.has(sessionId)) {
+          const newActiveTasks = new Map(taskState.activeTasks);
+          newActiveTasks.delete(sessionId);
+          set({ activeTasks: newActiveTasks });
+          persistTasks(newActiveTasks);
+        }
+      }
+    }
+  },
+
+  updateTaskLastMessageId: function updateTaskLastMessageId(
+    sessionId,
+    lastMessageId,
+  ) {
+    const state = get();
+    const task = state.activeTasks.get(sessionId);
+    if (!task) return;
+
+    const newActiveTasks = new Map(state.activeTasks);
+    newActiveTasks.set(sessionId, {
+      ...task,
+      lastMessageId,
+    });
+    set({ activeTasks: newActiveTasks });
+    persistTasks(newActiveTasks);
+  },
 }));
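A minimal usage sketch of the new store actions, assuming a caller that records a long-running task and later resumes it after a reload; the session and task IDs are example values:

// Hypothetical consumer; zustand stores created with create() expose getState().
const { setActiveTask, getActiveTask, reconnectToTask } =
  useChatStore.getState();

// When the backend reports a task ID, record it so a reload can resume:
setActiveTask("session-123", {
  taskId: "task-abc",
  operationId: "op-1",
  toolName: "create_agent",
  lastMessageId: "0-0",
});

// After a page reload, loadPersistedTasks() restores activeTasks from
// localStorage, so the session can reattach to the stream:
const task = getActiveTask("session-123");
if (task) {
  void reconnectToTask("session-123", task.taskId, task.lastMessageId, (chunk) =>
    console.log(chunk.type),
  );
}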
@@ -4,6 +4,7 @@ export type StreamStatus = "idle" | "streaming" | "completed" | "error";

 export interface StreamChunk {
   type:
+    | "stream_start"
     | "text_chunk"
     | "text_ended"
     | "tool_call"
@@ -15,6 +16,7 @@ export interface StreamChunk {
     | "error"
     | "usage"
     | "stream_end";
+  taskId?: string;
   timestamp?: string;
   content?: string;
   message?: string;
@@ -41,7 +43,7 @@ export interface StreamChunk {
 }

 export type VercelStreamChunk =
-  | { type: "start"; messageId: string }
+  | { type: "start"; messageId: string; taskId?: string }
  | { type: "finish" }
  | { type: "text-start"; id: string }
  | { type: "text-delta"; id: string; delta: string }
@@ -92,3 +94,70 @@ export interface StreamResult {
 }

 export type StreamCompleteCallback = (sessionId: string) => void;
+
+// Type guards for message types
+
+/**
+ * Check if a message has a toolId property.
+ */
+export function hasToolId<T extends { type: string }>(
+  msg: T,
+): msg is T & { toolId: string } {
+  return (
+    "toolId" in msg &&
+    typeof (msg as Record<string, unknown>).toolId === "string"
+  );
+}
+
+/**
+ * Check if a message has an operationId property.
+ */
+export function hasOperationId<T extends { type: string }>(
+  msg: T,
+): msg is T & { operationId: string } {
+  return (
+    "operationId" in msg &&
+    typeof (msg as Record<string, unknown>).operationId === "string"
+  );
+}
+
+/**
+ * Check if a message has a toolCallId property.
+ */
+export function hasToolCallId<T extends { type: string }>(
+  msg: T,
+): msg is T & { toolCallId: string } {
+  return (
+    "toolCallId" in msg &&
+    typeof (msg as Record<string, unknown>).toolCallId === "string"
+  );
+}
+
+/**
+ * Check if a message is an operation message type.
+ */
+export function isOperationMessage<T extends { type: string }>(
+  msg: T,
+): msg is T & {
+  type: "operation_started" | "operation_pending" | "operation_in_progress";
+} {
+  return (
+    msg.type === "operation_started" ||
+    msg.type === "operation_pending" ||
+    msg.type === "operation_in_progress"
+  );
+}
+
+/**
+ * Get the tool ID from a message if available.
+ * Checks toolId, operationId, and toolCallId properties.
+ */
+export function getToolIdFromMessage<T extends { type: string }>(
+  msg: T,
+): string | undefined {
+  const record = msg as Record<string, unknown>;
+  if (typeof record.toolId === "string") return record.toolId;
+  if (typeof record.operationId === "string") return record.operationId;
+  if (typeof record.toolCallId === "string") return record.toolCallId;
+  return undefined;
+}
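These guards let callers narrow the loosely-typed message union without `as any` casts. A short sketch, assuming `msg` is any member of the union:

// Illustrative only - `msg` stands in for any member of the message union.
declare const msg: { type: string };

if (hasToolId(msg)) {
  // Narrowed: msg.toolId is a string here, no cast needed.
  console.log(msg.toolId.toUpperCase());
}

if (isOperationMessage(msg)) {
  // Narrowed to the three operation_* message types.
  const id = getToolIdFromMessage(msg) ?? "(no id)";
  console.log(`pending operation ${id}`);
}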
@@ -2,7 +2,6 @@ import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessi
 import { Button } from "@/components/atoms/Button/Button";
 import { Text } from "@/components/atoms/Text/Text";
 import { Dialog } from "@/components/molecules/Dialog/Dialog";
-import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
 import { cn } from "@/lib/utils";
 import { GlobeHemisphereEastIcon } from "@phosphor-icons/react";
 import { useEffect } from "react";
@@ -17,6 +16,13 @@ export interface ChatContainerProps {
   className?: string;
   onStreamingChange?: (isStreaming: boolean) => void;
   onOperationStarted?: () => void;
+  /** Active stream info from the server for reconnection */
+  activeStream?: {
+    taskId: string;
+    lastMessageId: string;
+    operationId: string;
+    toolName: string;
+  };
 }

 export function ChatContainer({
@@ -26,6 +32,7 @@ export function ChatContainer({
   className,
   onStreamingChange,
   onOperationStarted,
+  activeStream,
 }: ChatContainerProps) {
   const {
     messages,
@@ -41,16 +48,13 @@ export function ChatContainer({
     initialMessages,
     initialPrompt,
     onOperationStarted,
+    activeStream,
   });

   useEffect(() => {
     onStreamingChange?.(isStreaming);
   }, [isStreaming, onStreamingChange]);

-  const breakpoint = useBreakpoint();
-  const isMobile =
-    breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
-
   return (
     <div
       className={cn(
@@ -118,11 +122,7 @@ export function ChatContainer({
           disabled={isStreaming || !sessionId}
           isStreaming={isStreaming}
           onStop={stopStreaming}
-          placeholder={
-            isMobile
-              ? "You can search or just ask"
-              : 'You can search or just ask — e.g. "create a blog post outline"'
-          }
+          placeholder="What else can I help with?"
         />
       </div>
     </div>
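Wiring-wise, the page that owns the session would thread the server-reported stream through the new prop, roughly as below; `sessionId` and `initialMessages` sit outside the visible hunk, and the literal values are examples:

// Hypothetical caller; only the activeStream prop shape is confirmed by the diff.
<ChatContainer
  sessionId={session.id}
  initialMessages={session.messages}
  activeStream={{
    taskId: "task-abc",
    lastMessageId: "0-0",
    operationId: "op-1",
    toolName: "create_agent",
  }}
/>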
@@ -2,6 +2,7 @@ import { toast } from "sonner";
 import type { StreamChunk } from "../../chat-types";
 import type { HandlerDependencies } from "./handlers";
 import {
+  getErrorDisplayMessage,
   handleError,
   handleLoginNeeded,
   handleStreamEnd,
@@ -24,16 +25,22 @@ export function createStreamEventDispatcher(
       chunk.type === "need_login" ||
       chunk.type === "error"
     ) {
-      if (!deps.hasResponseRef.current) {
-        console.info("[ChatStream] First response chunk:", {
-          type: chunk.type,
-          sessionId: deps.sessionId,
-        });
-      }
       deps.hasResponseRef.current = true;
     }

     switch (chunk.type) {
+      case "stream_start":
+        // Store task ID for SSE reconnection
+        if (chunk.taskId && deps.onActiveTaskStarted) {
+          deps.onActiveTaskStarted({
+            taskId: chunk.taskId,
+            operationId: chunk.taskId,
+            toolName: "chat",
+            toolCallId: "chat_stream",
+          });
+        }
+        break;
+
       case "text_chunk":
         handleTextChunk(chunk, deps);
         break;
@@ -56,11 +63,7 @@ export function createStreamEventDispatcher(
         break;

       case "stream_end":
-        console.info("[ChatStream] Stream ended:", {
-          sessionId: deps.sessionId,
-          hasResponse: deps.hasResponseRef.current,
-          chunkCount: deps.streamingChunksRef.current.length,
-        });
+        // Note: "finish" type from backend gets normalized to "stream_end" by normalizeStreamChunk
         handleStreamEnd(chunk, deps);
         break;

@@ -70,7 +73,7 @@ export function createStreamEventDispatcher(
         // Show toast at dispatcher level to avoid circular dependencies
         if (!isRegionBlocked) {
           toast.error("Chat Error", {
-            description: chunk.message || chunk.content || "An error occurred",
+            description: getErrorDisplayMessage(chunk),
           });
         }
         break;
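To illustrate the new branch: the first event of a resumable stream carries a task ID, which the dispatcher hands to onActiveTaskStarted so the session can be resumed later. A sketch of the chunk shape involved, with example values:

// Example chunk that would hit the new "stream_start" case.
const chunk: StreamChunk = {
  type: "stream_start",
  taskId: "task-abc", // example value; supplied by the backend
};
// With deps.onActiveTaskStarted wired up, dispatching this chunk records
// { taskId: "task-abc", operationId: "task-abc", toolName: "chat",
//   toolCallId: "chat_stream" } for SSE reconnection.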
@@ -18,11 +18,19 @@ export interface HandlerDependencies {
   setStreamingChunks: Dispatch<SetStateAction<string[]>>;
   streamingChunksRef: MutableRefObject<string[]>;
   hasResponseRef: MutableRefObject<boolean>;
+  textFinalizedRef: MutableRefObject<boolean>;
+  streamEndedRef: MutableRefObject<boolean>;
   setMessages: Dispatch<SetStateAction<ChatMessageData[]>>;
   setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
   setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
   sessionId: string;
   onOperationStarted?: () => void;
+  onActiveTaskStarted?: (taskInfo: {
+    taskId: string;
+    operationId: string;
+    toolName: string;
+    toolCallId: string;
+  }) => void;
 }

 export function isRegionBlockedError(chunk: StreamChunk): boolean {
@@ -32,6 +40,25 @@ export function isRegionBlockedError(chunk: StreamChunk): boolean {
   return message.toLowerCase().includes("not available in your region");
 }

+export function getUserFriendlyErrorMessage(
+  code: string | undefined,
+): string | undefined {
+  switch (code) {
+    case "TASK_EXPIRED":
+      return "This operation has expired. Please try again.";
+    case "TASK_NOT_FOUND":
+      return "Could not find the requested operation.";
+    case "ACCESS_DENIED":
+      return "You do not have access to this operation.";
+    case "QUEUE_OVERFLOW":
+      return "Connection was interrupted. Please refresh to continue.";
+    case "MODEL_NOT_AVAILABLE_REGION":
+      return "This model is not available in your region.";
+    default:
+      return undefined;
+  }
+}
+
 export function handleTextChunk(chunk: StreamChunk, deps: HandlerDependencies) {
   if (!chunk.content) return;
   deps.setHasTextChunks(true);
@@ -46,10 +73,15 @@ export function handleTextEnded(
   _chunk: StreamChunk,
   deps: HandlerDependencies,
 ) {
+  if (deps.textFinalizedRef.current) {
+    return;
+  }
+
   const completedText = deps.streamingChunksRef.current.join("");
   if (completedText.trim()) {
+    deps.textFinalizedRef.current = true;
+
     deps.setMessages((prev) => {
-      // Check if this exact message already exists to prevent duplicates
       const exists = prev.some(
         (msg) =>
           msg.type === "message" &&
@@ -76,9 +108,14 @@ export function handleToolCallStart(
   chunk: StreamChunk,
   deps: HandlerDependencies,
 ) {
+  // Use deterministic fallback instead of Date.now() to ensure same ID on replay
+  const toolId =
+    chunk.tool_id ||
+    `tool-${deps.sessionId}-${chunk.idx ?? "unknown"}-${chunk.tool_name || "unknown"}`;
+
   const toolCallMessage: Extract<ChatMessageData, { type: "tool_call" }> = {
     type: "tool_call",
-    toolId: chunk.tool_id || `tool-${Date.now()}-${chunk.idx || 0}`,
+    toolId,
     toolName: chunk.tool_name || "Executing",
     arguments: chunk.arguments || {},
     timestamp: new Date(),
@@ -111,6 +148,29 @@ export function handleToolCallStart(
   deps.setMessages(updateToolCallMessages);
 }

+const TOOL_RESPONSE_TYPES = new Set([
+  "tool_response",
+  "operation_started",
+  "operation_pending",
+  "operation_in_progress",
+  "execution_started",
+  "agent_carousel",
+  "clarification_needed",
+]);
+
+function hasResponseForTool(
+  messages: ChatMessageData[],
+  toolId: string,
+): boolean {
+  return messages.some((msg) => {
+    if (!TOOL_RESPONSE_TYPES.has(msg.type)) return false;
+    const msgToolId =
+      (msg as { toolId?: string }).toolId ||
+      (msg as { toolCallId?: string }).toolCallId;
+    return msgToolId === toolId;
+  });
+}
+
 export function handleToolResponse(
   chunk: StreamChunk,
   deps: HandlerDependencies,
@@ -152,31 +212,49 @@ export function handleToolResponse(
   ) {
     const inputsMessage = extractInputsNeeded(parsedResult, chunk.tool_name);
     if (inputsMessage) {
-      deps.setMessages((prev) => [...prev, inputsMessage]);
+      deps.setMessages((prev) => {
+        // Check for duplicate inputs_needed message
+        const exists = prev.some((msg) => msg.type === "inputs_needed");
+        if (exists) return prev;
+        return [...prev, inputsMessage];
+      });
     }
     const credentialsMessage = extractCredentialsNeeded(
       parsedResult,
       chunk.tool_name,
     );
     if (credentialsMessage) {
-      deps.setMessages((prev) => [...prev, credentialsMessage]);
+      deps.setMessages((prev) => {
+        // Check for duplicate credentials_needed message
+        const exists = prev.some((msg) => msg.type === "credentials_needed");
+        if (exists) return prev;
+        return [...prev, credentialsMessage];
+      });
     }
   }
   return;
 }
-  // Trigger polling when operation_started is received
   if (responseMessage.type === "operation_started") {
     deps.onOperationStarted?.();
+    const taskId = (responseMessage as { taskId?: string }).taskId;
+    if (taskId && deps.onActiveTaskStarted) {
+      deps.onActiveTaskStarted({
+        taskId,
+        operationId:
+          (responseMessage as { operationId?: string }).operationId || "",
+        toolName: (responseMessage as { toolName?: string }).toolName || "",
+        toolCallId: (responseMessage as { toolId?: string }).toolId || "",
+      });
+    }
   }

   deps.setMessages((prev) => {
     const toolCallIndex = prev.findIndex(
       (msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
     );
-    const hasResponse = prev.some(
-      (msg) => msg.type === "tool_response" && msg.toolId === chunk.tool_id,
-    );
-    if (hasResponse) return prev;
+    if (hasResponseForTool(prev, chunk.tool_id!)) {
+      return prev;
+    }
     if (toolCallIndex !== -1) {
       const newMessages = [...prev];
       newMessages.splice(toolCallIndex + 1, 0, responseMessage);
@@ -198,28 +276,48 @@ export function handleLoginNeeded(
     agentInfo: chunk.agent_info,
     timestamp: new Date(),
   };
-  deps.setMessages((prev) => [...prev, loginNeededMessage]);
+  deps.setMessages((prev) => {
+    // Check for duplicate login_needed message
+    const exists = prev.some((msg) => msg.type === "login_needed");
+    if (exists) return prev;
+    return [...prev, loginNeededMessage];
+  });
 }

 export function handleStreamEnd(
   _chunk: StreamChunk,
   deps: HandlerDependencies,
 ) {
+  if (deps.streamEndedRef.current) {
+    return;
+  }
+  deps.streamEndedRef.current = true;
+
   const completedContent = deps.streamingChunksRef.current.join("");
   if (!completedContent.trim() && !deps.hasResponseRef.current) {
-    deps.setMessages((prev) => [
-      ...prev,
-      {
-        type: "message",
-        role: "assistant",
-        content: "No response received. Please try again.",
-        timestamp: new Date(),
-      },
-    ]);
-  }
-  if (completedContent.trim()) {
     deps.setMessages((prev) => {
-      // Check if this exact message already exists to prevent duplicates
+      const exists = prev.some(
+        (msg) =>
+          msg.type === "message" &&
+          msg.role === "assistant" &&
+          msg.content === "No response received. Please try again.",
+      );
+      if (exists) return prev;
+      return [
+        ...prev,
+        {
+          type: "message",
+          role: "assistant",
+          content: "No response received. Please try again.",
+          timestamp: new Date(),
+        },
+      ];
+    });
+  }
+  if (completedContent.trim() && !deps.textFinalizedRef.current) {
+    deps.textFinalizedRef.current = true;
+
+    deps.setMessages((prev) => {
       const exists = prev.some(
         (msg) =>
           msg.type === "message" &&
@@ -244,8 +342,6 @@ export function handleStreamEnd(
 }

 export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
-  const errorMessage = chunk.message || chunk.content || "An error occurred";
-  console.error("Stream error:", errorMessage);
   if (isRegionBlockedError(chunk)) {
     deps.setIsRegionBlockedModalOpen(true);
   }
@@ -253,4 +349,14 @@ export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
   deps.setHasTextChunks(false);
   deps.setStreamingChunks([]);
   deps.streamingChunksRef.current = [];
+  deps.textFinalizedRef.current = false;
+  deps.streamEndedRef.current = true;
+}
+
+export function getErrorDisplayMessage(chunk: StreamChunk): string {
+  const friendlyMessage = getUserFriendlyErrorMessage(chunk.code);
+  if (friendlyMessage) {
+    return friendlyMessage;
+  }
+  return chunk.message || chunk.content || "An error occurred";
 }
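The deterministic fallback in handleToolCallStart matters because SSE replay re-delivers the same tool_call chunk; an ID derived from Date.now() would differ per delivery and defeat deduplication. A small illustration of the template, with example inputs:

// Same template as the diff, extracted for illustration.
function fallbackToolId(sessionId: string, idx?: number, toolName?: string) {
  return `tool-${sessionId}-${idx ?? "unknown"}-${toolName || "unknown"}`;
}
// fallbackToolId("s1", 3, "create_agent") === "tool-s1-3-create_agent"
// Replaying the same chunk yields the same ID, so hasResponseForTool and the
// dedup keys keep working across reconnects; the old Date.now() fallback
// would have produced a different ID on every delivery.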
@@ -349,6 +349,7 @@ export function parseToolResponse(
     toolName: (parsedResult.tool_name as string) || toolName,
     toolId,
     operationId: (parsedResult.operation_id as string) || "",
+    taskId: (parsedResult.task_id as string) || undefined, // For SSE reconnection
     message:
       (parsedResult.message as string) ||
       "Operation started. You can close this tab.",
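Concretely, a tool result payload shaped like the sketch below (field values are examples) now surfaces its task ID to the UI, which handleToolResponse forwards to onActiveTaskStarted:

// Example parsed payload; task_id is the new passthrough field.
const parsedResult = {
  tool_name: "create_agent",
  operation_id: "op-1",
  task_id: "task-abc",
  message: "Operation started. You can close this tab.",
};
// parseToolResponse maps task_id -> taskId, giving the UI a reconnectable
// handle on the long-running operation.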
@@ -1,10 +1,17 @@
|
|||||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||||
import { useEffect, useMemo, useRef, useState } from "react";
|
import { useEffect, useMemo, useRef, useState } from "react";
|
||||||
|
import { INITIAL_STREAM_ID } from "../../chat-constants";
|
||||||
import { useChatStore } from "../../chat-store";
|
import { useChatStore } from "../../chat-store";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { useChatStream } from "../../useChatStream";
|
import { useChatStream } from "../../useChatStream";
|
||||||
import { usePageContext } from "../../usePageContext";
|
import { usePageContext } from "../../usePageContext";
|
||||||
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
||||||
|
import {
|
||||||
|
getToolIdFromMessage,
|
||||||
|
hasToolId,
|
||||||
|
isOperationMessage,
|
||||||
|
type StreamChunk,
|
||||||
|
} from "../../chat-types";
|
||||||
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
|
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
|
||||||
import {
|
import {
|
||||||
createUserMessage,
|
createUserMessage,
|
||||||
@@ -14,6 +21,13 @@ import {
|
|||||||
processInitialMessages,
|
processInitialMessages,
|
||||||
} from "./helpers";
|
} from "./helpers";
|
||||||
|
|
||||||
|
const TOOL_RESULT_TYPES = new Set([
|
||||||
|
"tool_response",
|
||||||
|
"agent_carousel",
|
||||||
|
"execution_started",
|
||||||
|
"clarification_needed",
|
||||||
|
]);
|
||||||
|
|
||||||
// Helper to generate deduplication key for a message
|
// Helper to generate deduplication key for a message
|
||||||
function getMessageKey(msg: ChatMessageData): string {
|
function getMessageKey(msg: ChatMessageData): string {
|
||||||
if (msg.type === "message") {
|
if (msg.type === "message") {
|
||||||
@@ -23,14 +37,18 @@ function getMessageKey(msg: ChatMessageData): string {
|
|||||||
return `msg:${msg.role}:${msg.content}`;
|
return `msg:${msg.role}:${msg.content}`;
|
||||||
} else if (msg.type === "tool_call") {
|
} else if (msg.type === "tool_call") {
|
||||||
return `toolcall:${msg.toolId}`;
|
return `toolcall:${msg.toolId}`;
|
||||||
} else if (msg.type === "tool_response") {
|
} else if (TOOL_RESULT_TYPES.has(msg.type)) {
|
||||||
return `toolresponse:${(msg as any).toolId}`;
|
// Unified key for all tool result types - same toolId with different types
|
||||||
} else if (
|
// (tool_response vs agent_carousel) should deduplicate to the same key
|
||||||
msg.type === "operation_started" ||
|
const toolId = getToolIdFromMessage(msg);
|
||||||
msg.type === "operation_pending" ||
|
// If no toolId, fall back to content-based key to avoid empty key collisions
|
||||||
msg.type === "operation_in_progress"
|
if (!toolId) {
|
||||||
) {
|
return `toolresult:content:${JSON.stringify(msg).slice(0, 200)}`;
|
||||||
return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`;
|
}
|
||||||
|
return `toolresult:${toolId}`;
|
||||||
|
} else if (isOperationMessage(msg)) {
|
||||||
|
const toolId = getToolIdFromMessage(msg) || "";
|
||||||
|
return `op:${toolId}:${msg.toolName}`;
|
||||||
} else {
|
} else {
|
||||||
return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
|
return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
|
||||||
}
|
}
|
||||||
@@ -41,6 +59,13 @@ interface Args {
|
|||||||
initialMessages: SessionDetailResponse["messages"];
|
initialMessages: SessionDetailResponse["messages"];
|
||||||
initialPrompt?: string;
|
initialPrompt?: string;
|
||||||
onOperationStarted?: () => void;
|
onOperationStarted?: () => void;
|
||||||
|
/** Active stream info from the server for reconnection */
|
||||||
|
activeStream?: {
|
||||||
|
taskId: string;
|
||||||
|
lastMessageId: string;
|
||||||
|
operationId: string;
|
||||||
|
toolName: string;
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useChatContainer({
|
export function useChatContainer({
|
||||||
@@ -48,6 +73,7 @@ export function useChatContainer({
|
|||||||
initialMessages,
|
initialMessages,
|
||||||
initialPrompt,
|
initialPrompt,
|
||||||
onOperationStarted,
|
onOperationStarted,
|
||||||
|
activeStream,
|
||||||
}: Args) {
|
}: Args) {
|
||||||
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
||||||
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
||||||
@@ -57,6 +83,8 @@ export function useChatContainer({
|
|||||||
useState(false);
|
useState(false);
|
||||||
const hasResponseRef = useRef(false);
|
const hasResponseRef = useRef(false);
|
||||||
const streamingChunksRef = useRef<string[]>([]);
|
const streamingChunksRef = useRef<string[]>([]);
|
||||||
|
const textFinalizedRef = useRef(false);
|
||||||
|
const streamEndedRef = useRef(false);
|
||||||
const previousSessionIdRef = useRef<string | null>(null);
|
const previousSessionIdRef = useRef<string | null>(null);
|
||||||
const {
|
const {
|
||||||
error,
|
error,
|
||||||
@@ -65,44 +93,182 @@ export function useChatContainer({
|
|||||||
} = useChatStream();
|
} = useChatStream();
|
||||||
const activeStreams = useChatStore((s) => s.activeStreams);
|
const activeStreams = useChatStore((s) => s.activeStreams);
|
||||||
const subscribeToStream = useChatStore((s) => s.subscribeToStream);
|
const subscribeToStream = useChatStore((s) => s.subscribeToStream);
|
||||||
|
const setActiveTask = useChatStore((s) => s.setActiveTask);
|
||||||
|
const getActiveTask = useChatStore((s) => s.getActiveTask);
|
||||||
|
const reconnectToTask = useChatStore((s) => s.reconnectToTask);
|
||||||
const isStreaming = isStreamingInitiated || hasTextChunks;
|
const isStreaming = isStreamingInitiated || hasTextChunks;
|
||||||
|
// Track whether we've already connected to this activeStream to avoid duplicate connections
|
||||||
|
const connectedActiveStreamRef = useRef<string | null>(null);
|
||||||
|
// Track if component is mounted to prevent state updates after unmount
|
||||||
|
const isMountedRef = useRef(true);
|
||||||
|
// Track current dispatcher to prevent multiple dispatchers from adding messages
|
||||||
|
const currentDispatcherIdRef = useRef(0);
|
||||||
|
|
||||||
|
// Set mounted flag - reset on every mount, cleanup on unmount
|
||||||
|
useEffect(function trackMountedState() {
|
||||||
|
isMountedRef.current = true;
|
||||||
|
return function cleanup() {
|
||||||
|
isMountedRef.current = false;
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// Callback to store active task info for SSE reconnection
|
||||||
|
function handleActiveTaskStarted(taskInfo: {
|
||||||
|
taskId: string;
|
||||||
|
operationId: string;
|
||||||
|
toolName: string;
|
||||||
|
toolCallId: string;
|
||||||
|
}) {
|
||||||
|
if (!sessionId) return;
|
||||||
|
setActiveTask(sessionId, {
|
||||||
|
taskId: taskInfo.taskId,
|
||||||
|
operationId: taskInfo.operationId,
|
||||||
|
toolName: taskInfo.toolName,
|
||||||
|
lastMessageId: INITIAL_STREAM_ID,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create dispatcher for stream events - stable reference for current sessionId
|
||||||
|
// Each dispatcher gets a unique ID to prevent stale dispatchers from updating state
|
||||||
|
function createDispatcher() {
|
||||||
|
if (!sessionId) return () => {};
|
||||||
|
// Increment dispatcher ID - only the most recent dispatcher should update state
|
||||||
|
const dispatcherId = ++currentDispatcherIdRef.current;
|
||||||
|
|
||||||
|
const baseDispatcher = createStreamEventDispatcher({
|
||||||
|
setHasTextChunks,
|
||||||
|
setStreamingChunks,
|
||||||
|
streamingChunksRef,
|
||||||
|
hasResponseRef,
|
||||||
|
textFinalizedRef,
|
||||||
|
streamEndedRef,
|
||||||
|
setMessages,
|
||||||
|
setIsRegionBlockedModalOpen,
|
||||||
|
sessionId,
|
||||||
|
setIsStreamingInitiated,
|
||||||
|
onOperationStarted,
|
||||||
|
onActiveTaskStarted: handleActiveTaskStarted,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Wrap dispatcher to check if it's still the current one
|
||||||
|
return function guardedDispatcher(chunk: StreamChunk) {
|
||||||
|
// Skip if component unmounted or this is a stale dispatcher
|
||||||
|
if (!isMountedRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (dispatcherId !== currentDispatcherIdRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
baseDispatcher(chunk);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
useEffect(
|
useEffect(
|
||||||
function handleSessionChange() {
|
function handleSessionChange() {
|
||||||
if (sessionId === previousSessionIdRef.current) return;
|
const isSessionChange = sessionId !== previousSessionIdRef.current;
|
||||||
|
|
||||||
const prevSession = previousSessionIdRef.current;
|
// Handle session change - reset state
|
||||||
if (prevSession) {
|
if (isSessionChange) {
|
||||||
stopStreaming(prevSession);
|
const prevSession = previousSessionIdRef.current;
|
||||||
|
if (prevSession) {
|
||||||
|
stopStreaming(prevSession);
|
||||||
|
}
|
||||||
|
previousSessionIdRef.current = sessionId;
|
||||||
|
connectedActiveStreamRef.current = null;
|
||||||
|
setMessages([]);
|
||||||
|
setStreamingChunks([]);
|
||||||
|
streamingChunksRef.current = [];
|
||||||
|
setHasTextChunks(false);
|
||||||
|
setIsStreamingInitiated(false);
|
||||||
|
hasResponseRef.current = false;
|
||||||
|
textFinalizedRef.current = false;
|
||||||
|
streamEndedRef.current = false;
|
||||||
}
|
}
|
||||||
previousSessionIdRef.current = sessionId;
|
|
||||||
setMessages([]);
|
|
||||||
setStreamingChunks([]);
|
|
||||||
streamingChunksRef.current = [];
|
|
||||||
setHasTextChunks(false);
|
|
||||||
setIsStreamingInitiated(false);
|
|
||||||
hasResponseRef.current = false;
|
|
||||||
|
|
||||||
if (!sessionId) return;
|
if (!sessionId) return;
|
||||||
|
|
||||||
const activeStream = activeStreams.get(sessionId);
|
// Priority 1: Check if server told us there's an active stream (most authoritative)
|
||||||
if (!activeStream || activeStream.status !== "streaming") return;
|
if (activeStream) {
|
||||||
|
const streamKey = `${sessionId}:${activeStream.taskId}`;
|
||||||
|
|
||||||
const dispatcher = createStreamEventDispatcher({
|
if (connectedActiveStreamRef.current === streamKey) {
|
||||||
setHasTextChunks,
|
return;
|
||||||
setStreamingChunks,
|
}
|
||||||
streamingChunksRef,
|
|
||||||
hasResponseRef,
|
// Skip if there's already an active stream for this session in the store
|
||||||
setMessages,
|
const existingStream = activeStreams.get(sessionId);
|
||||||
setIsRegionBlockedModalOpen,
|
if (existingStream && existingStream.status === "streaming") {
|
||||||
sessionId,
|
connectedActiveStreamRef.current = streamKey;
|
||||||
setIsStreamingInitiated,
|
return;
|
||||||
onOperationStarted,
|
}
|
||||||
});
|
|
||||||
|
connectedActiveStreamRef.current = streamKey;
|
||||||
|
|
||||||
|
// Clear all state before reconnection to prevent duplicates
|
||||||
|
// Server's initialMessages is authoritative; local state will be rebuilt from SSE replay
|
||||||
|
setMessages([]);
|
||||||
|
setStreamingChunks([]);
|
||||||
|
streamingChunksRef.current = [];
|
||||||
|
setHasTextChunks(false);
|
||||||
|
textFinalizedRef.current = false;
|
||||||
|
streamEndedRef.current = false;
|
||||||
|
hasResponseRef.current = false;
|
||||||
|
|
||||||
|
setIsStreamingInitiated(true);
|
||||||
|
setActiveTask(sessionId, {
|
||||||
|
taskId: activeStream.taskId,
|
||||||
|
operationId: activeStream.operationId,
|
||||||
|
toolName: activeStream.toolName,
|
||||||
|
lastMessageId: activeStream.lastMessageId,
|
||||||
|
});
|
||||||
|
reconnectToTask(
|
||||||
|
sessionId,
|
||||||
|
activeStream.taskId,
|
||||||
|
activeStream.lastMessageId,
|
||||||
|
createDispatcher(),
|
||||||
|
);
|
||||||
|
// Don't return cleanup here - the guarded dispatcher handles stale events
|
||||||
|
// and the stream will complete naturally. Cleanup would prematurely stop
|
||||||
|
// the stream when effect re-runs due to activeStreams changing.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only check localStorage/in-memory on session change
|
||||||
|
if (!isSessionChange) return;
|
||||||
|
|
||||||
|
// Priority 2: Check localStorage for active task
|
||||||
|
const activeTask = getActiveTask(sessionId);
|
||||||
|
if (activeTask) {
|
||||||
|
// Clear all state before reconnection to prevent duplicates
|
||||||
|
// Server's initialMessages is authoritative; local state will be rebuilt from SSE replay
|
||||||
|
setMessages([]);
|
||||||
|
setStreamingChunks([]);
|
||||||
|
streamingChunksRef.current = [];
|
||||||
|
setHasTextChunks(false);
|
||||||
|
textFinalizedRef.current = false;
|
||||||
|
streamEndedRef.current = false;
|
||||||
|
hasResponseRef.current = false;
|
||||||
|
|
||||||
|
setIsStreamingInitiated(true);
|
||||||
|
reconnectToTask(
|
||||||
|
sessionId,
|
||||||
|
activeTask.taskId,
|
||||||
|
activeTask.lastMessageId,
|
||||||
|
createDispatcher(),
|
||||||
|
);
|
||||||
|
// Don't return cleanup here - the guarded dispatcher handles stale events
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Priority 3: Check for an in-memory active stream (same-tab scenario)
|
||||||
|
const inMemoryStream = activeStreams.get(sessionId);
|
||||||
|
if (!inMemoryStream || inMemoryStream.status !== "streaming") {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
setIsStreamingInitiated(true);
|
setIsStreamingInitiated(true);
|
||||||
const skipReplay = initialMessages.length > 0;
|
const skipReplay = initialMessages.length > 0;
|
||||||
return subscribeToStream(sessionId, dispatcher, skipReplay);
|
return subscribeToStream(sessionId, createDispatcher(), skipReplay);
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
sessionId,
|
sessionId,
|
||||||
@@ -110,6 +276,10 @@ export function useChatContainer({
|
|||||||
activeStreams,
|
activeStreams,
|
||||||
subscribeToStream,
|
subscribeToStream,
|
||||||
onOperationStarted,
|
onOperationStarted,
|
||||||
|
getActiveTask,
|
||||||
|
reconnectToTask,
|
||||||
|
activeStream,
|
||||||
|
setActiveTask,
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -124,7 +294,7 @@ export function useChatContainer({
|
|||||||
msg.type === "agent_carousel" ||
|
msg.type === "agent_carousel" ||
|
||||||
msg.type === "execution_started"
|
msg.type === "execution_started"
|
||||||
) {
|
) {
|
||||||
const toolId = (msg as any).toolId;
|
const toolId = hasToolId(msg) ? msg.toolId : undefined;
|
||||||
if (toolId) {
|
if (toolId) {
|
||||||
ids.add(toolId);
|
ids.add(toolId);
|
||||||
}
|
}
|
||||||
@@ -141,12 +311,8 @@ export function useChatContainer({
|
|||||||
|
|
||||||
setMessages((prev) => {
|
setMessages((prev) => {
|
||||||
const filtered = prev.filter((msg) => {
|
const filtered = prev.filter((msg) => {
|
||||||
if (
|
if (isOperationMessage(msg)) {
|
||||||
msg.type === "operation_started" ||
|
const toolId = getToolIdFromMessage(msg);
|
||||||
msg.type === "operation_pending" ||
|
|
||||||
msg.type === "operation_in_progress"
|
|
||||||
) {
|
|
||||||
const toolId = (msg as any).toolId || (msg as any).toolCallId;
|
|
||||||
if (toolId && completedToolIds.has(toolId)) {
|
if (toolId && completedToolIds.has(toolId)) {
|
||||||
return false; // Remove - operation completed
|
return false; // Remove - operation completed
|
||||||
}
|
}
|
||||||
@@ -174,12 +340,8 @@ export function useChatContainer({
|
|||||||
// Filter local messages: remove duplicates and completed operation messages
|
// Filter local messages: remove duplicates and completed operation messages
|
||||||
const newLocalMessages = messages.filter((msg) => {
|
const newLocalMessages = messages.filter((msg) => {
|
||||||
// Remove operation messages for completed tools
|
// Remove operation messages for completed tools
|
||||||
if (
|
if (isOperationMessage(msg)) {
|
||||||
msg.type === "operation_started" ||
|
const toolId = getToolIdFromMessage(msg);
|
||||||
msg.type === "operation_pending" ||
|
|
||||||
msg.type === "operation_in_progress"
|
|
||||||
) {
|
|
||||||
const toolId = (msg as any).toolId || (msg as any).toolCallId;
|
|
||||||
if (toolId && completedToolIds.has(toolId)) {
|
if (toolId && completedToolIds.has(toolId)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -190,7 +352,70 @@ export function useChatContainer({
|
|||||||
});
|
});
|
||||||
|
|
||||||
// Server messages first (correct order), then new local messages
|
// Server messages first (correct order), then new local messages
|
||||||
return [...processedInitial, ...newLocalMessages];
|
const combined = [...processedInitial, ...newLocalMessages];
|
||||||
|
|
||||||
|
// Post-processing: Remove duplicate assistant messages that can occur during
|
||||||
|
// race conditions (e.g., rapid screen switching during SSE reconnection).
|
||||||
|
// Two assistant messages are considered duplicates if:
|
||||||
|
// - They are both text messages with role "assistant"
|
||||||
|
// - One message's content starts with the other's content (partial vs complete)
|
||||||
|
// - Or they have very similar content (>80% overlap at the start)
|
||||||
|
const deduplicated: ChatMessageData[] = [];
|
||||||
|
for (let i = 0; i < combined.length; i++) {
|
||||||
|
const current = combined[i];
|
||||||
|
|
||||||
|
// Check if this is an assistant text message
|
||||||
|
if (current.type !== "message" || current.role !== "assistant") {
|
||||||
|
deduplicated.push(current);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look for duplicate assistant messages in the rest of the array
|
||||||
|
let dominated = false;
|
||||||
|
for (let j = 0; j < combined.length; j++) {
|
||||||
|
if (i === j) continue;
|
||||||
|
const other = combined[j];
|
||||||
|
if (other.type !== "message" || other.role !== "assistant") continue;
|
||||||
|
|
||||||
|
const currentContent = current.content || "";
|
||||||
|
const otherContent = other.content || "";
|
||||||
|
|
||||||
|
// Skip empty messages
|
||||||
|
if (!currentContent.trim() || !otherContent.trim()) continue;
|
||||||
|
|
||||||
|
// Check if current is a prefix of other (current is incomplete version)
|
||||||
|
if (
|
||||||
|
otherContent.length > currentContent.length &&
|
||||||
|
otherContent.startsWith(currentContent.slice(0, 100))
|
||||||
|
) {
|
||||||
|
// Current is a shorter/incomplete version of other - skip it
|
||||||
|
dominated = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if messages are nearly identical (within a small difference)
|
||||||
|
// This catches cases where content differs only slightly
|
||||||
|
const minLen = Math.min(currentContent.length, otherContent.length);
|
||||||
|
const compareLen = Math.min(minLen, 200); // Compare first 200 chars
|
||||||
|
if (
|
||||||
|
compareLen > 50 &&
|
||||||
|
currentContent.slice(0, compareLen) ===
|
||||||
|
otherContent.slice(0, compareLen)
|
||||||
|
) {
|
||||||
|
// Same prefix - keep the longer one
|
||||||
|
if (otherContent.length > currentContent.length) {
|
||||||
|
dominated = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!dominated) {
|
||||||
|
deduplicated.push(current);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return deduplicated;
|
||||||
}, [initialMessages, messages, completedToolIds]);
|
}, [initialMessages, messages, completedToolIds]);
|
||||||
|
|
||||||
async function sendMessage(
|
async function sendMessage(
|
||||||
@@ -198,10 +423,8 @@ export function useChatContainer({
     isUserMessage: boolean = true,
     context?: { url: string; content: string },
   ) {
-    if (!sessionId) {
-      console.error("[useChatContainer] Cannot send message: no session ID");
-      return;
-    }
+    if (!sessionId) return;
     setIsRegionBlockedModalOpen(false);
     if (isUserMessage) {
       const userMessage = createUserMessage(content);
@@ -214,31 +437,19 @@ export function useChatContainer({
     setHasTextChunks(false);
     setIsStreamingInitiated(true);
     hasResponseRef.current = false;
-    const dispatcher = createStreamEventDispatcher({
-      setHasTextChunks,
-      setStreamingChunks,
-      streamingChunksRef,
-      hasResponseRef,
-      setMessages,
-      setIsRegionBlockedModalOpen,
-      sessionId,
-      setIsStreamingInitiated,
-      onOperationStarted,
-    });
+    textFinalizedRef.current = false;
+    streamEndedRef.current = false;
 
     try {
       await sendStreamMessage(
         sessionId,
         content,
-        dispatcher,
+        createDispatcher(),
         isUserMessage,
         context,
       );
     } catch (err) {
-      console.error("[useChatContainer] Failed to send message:", err);
       setIsStreamingInitiated(false);
 
       if (err instanceof Error && err.name === "AbortError") return;
 
       const errorMessage =
@@ -74,19 +74,20 @@ export function ChatInput({
           hasMultipleLines ? "rounded-xlarge" : "rounded-full",
         )}
       >
+        {!value && !isRecording && (
+          <div
+            className="pointer-events-none absolute inset-0 top-0.5 flex items-center justify-start pl-14 text-[1rem] text-zinc-400"
+            aria-hidden="true"
+          >
+            {isTranscribing ? "Transcribing..." : placeholder}
+          </div>
+        )}
         <textarea
           id={inputId}
           aria-label="Chat message input"
           value={value}
           onChange={handleChange}
           onKeyDown={handleKeyDown}
-          placeholder={
-            isTranscribing
-              ? "Transcribing..."
-              : isRecording
-                ? ""
-                : placeholder
-          }
           disabled={isInputDisabled}
           rows={1}
           className={cn(
 
@@ -122,13 +123,14 @@ export function ChatInput({
           size="icon"
           aria-label={isRecording ? "Stop recording" : "Start recording"}
           onClick={toggleRecording}
-          disabled={disabled || isTranscribing}
+          disabled={disabled || isTranscribing || isStreaming}
           className={cn(
             isRecording
               ? "animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600"
               : isTranscribing
                 ? "border-zinc-300 bg-zinc-100 text-zinc-400"
                 : "border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
+            isStreaming && "opacity-40",
           )}
         >
           {isTranscribing ? (
 
@@ -38,8 +38,8 @@ export function AudioWaveform({
     // Create audio context and analyser
     const audioContext = new AudioContext();
     const analyser = audioContext.createAnalyser();
-    analyser.fftSize = 512;
-    analyser.smoothingTimeConstant = 0.8;
+    analyser.fftSize = 256;
+    analyser.smoothingTimeConstant = 0.3;
 
     // Connect the stream to the analyser
     const source = audioContext.createMediaStreamSource(stream);
 
@@ -73,10 +73,11 @@ export function AudioWaveform({
         maxAmplitude = Math.max(maxAmplitude, amplitude);
       }
 
-      // Map amplitude (0-128) to bar height
-      const normalized = (maxAmplitude / 128) * 255;
-      const height =
-        minBarHeight + (normalized / 255) * (maxBarHeight - minBarHeight);
+      // Normalize amplitude (0-128 range) to 0-1
+      const normalized = maxAmplitude / 128;
+      // Apply sensitivity boost (multiply by 4) and use sqrt curve to amplify quiet sounds
+      const boosted = Math.min(1, Math.sqrt(normalized) * 4);
+      const height = minBarHeight + boosted * (maxBarHeight - minBarHeight);
       newBars.push(height);
     }
 
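Worked numbers for the new boost curve (illustrative values, not from the diff):

// A quiet signal: maxAmplitude = 5 out of 128.
const normalized = 5 / 128;                             // ≈ 0.039
const boosted = Math.min(1, Math.sqrt(normalized) * 4); // ≈ 0.79
// The old linear map would have drawn a ~4%-height bar; the sqrt-plus-gain curve
// draws ~79%, and clamps to full height once normalized ≥ 1/16 ≈ 0.0625.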
@@ -224,7 +224,7 @@ export function useVoiceRecording({
     [value, isTranscribing, toggleRecording, baseHandleKeyDown],
   );
 
-  const showMicButton = isSupported && !isStreaming;
+  const showMicButton = isSupported;
   const isInputDisabled = disabled || isStreaming || isTranscribing;
 
   // Cleanup on unmount
 
@@ -346,6 +346,7 @@ export function ChatMessage({
           toolId={message.toolId}
           toolName={message.toolName}
           result={message.result}
+          onSendMessage={onSendMessage}
         />
       </div>
     );
 
@@ -111,6 +111,7 @@ export type ChatMessageData =
       toolName: string;
       toolId: string;
       operationId: string;
+      taskId?: string; // For SSE reconnection
       message: string;
       timestamp?: string | Date;
     }
 
@@ -3,7 +3,7 @@
 import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
 import { cn } from "@/lib/utils";
 import { EyeSlash } from "@phosphor-icons/react";
-import React from "react";
+import React, { useState } from "react";
 import ReactMarkdown from "react-markdown";
 import remarkGfm from "remark-gfm";
 
@@ -48,7 +48,9 @@ interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {
  */
 function resolveWorkspaceUrl(src: string): string {
   if (src.startsWith("workspace://")) {
-    const fileId = src.replace("workspace://", "");
+    // Strip MIME type fragment if present (e.g., workspace://abc123#video/mp4 → abc123)
+    const withoutPrefix = src.replace("workspace://", "");
+    const fileId = withoutPrefix.split("#")[0];
     // Use the generated API URL helper to get the correct path
     const apiPath = getGetWorkspaceDownloadFileByIdUrl(fileId);
     // Route through the Next.js proxy (same pattern as customMutator for client-side)
@@ -65,13 +67,49 @@ function isWorkspaceImage(src: string | undefined): boolean {
   return src?.includes("/workspace/files/") ?? false;
 }
 
+/**
+ * Renders a workspace video with controls and an optional "AI cannot see" badge.
+ */
+function WorkspaceVideo({
+  src,
+  aiCannotSee,
+}: {
+  src: string;
+  aiCannotSee: boolean;
+}) {
+  return (
+    <span className="relative my-2 inline-block">
+      <video
+        controls
+        className="h-auto max-w-full rounded-md border border-zinc-200"
+        preload="metadata"
+      >
+        <source src={src} />
+        Your browser does not support the video tag.
+      </video>
+      {aiCannotSee && (
+        <span
+          className="absolute bottom-2 right-2 flex items-center gap-1 rounded bg-black/70 px-2 py-1 text-xs text-white"
+          title="The AI cannot see this video"
+        >
+          <EyeSlash size={14} />
+          <span>AI cannot see this video</span>
+        </span>
+      )}
+    </span>
+  );
+}
+
 /**
  * Custom image component that shows an indicator when the AI cannot see the image.
+ * Also handles the "video:" alt-text prefix convention to render <video> elements.
+ * For workspace files with unknown types, falls back to <video> if <img> fails.
  * Note: src is already transformed by urlTransform, so workspace:// is now /api/workspace/...
  */
 function MarkdownImage(props: Record<string, unknown>) {
   const src = props.src as string | undefined;
   const alt = props.alt as string | undefined;
+  const [imgFailed, setImgFailed] = useState(false);
 
   const aiCannotSee = isWorkspaceImage(src);
 
@@ -84,6 +122,18 @@ function MarkdownImage(props: Record<string, unknown>) {
     );
   }
 
+  // Detect video: prefix in alt text (set by formatOutputValue in helpers.ts)
+  if (alt?.startsWith("video:")) {
+    return <WorkspaceVideo src={src} aiCannotSee={aiCannotSee} />;
+  }
+
+  // If the <img> failed to load and this is a workspace file, try as video.
+  // This handles generic output keys like "file_out" where the MIME type
+  // isn't known from the key name alone.
+  if (imgFailed && aiCannotSee) {
+    return <WorkspaceVideo src={src} aiCannotSee={aiCannotSee} />;
+  }
+
   return (
     <span className="relative my-2 inline-block">
       {/* eslint-disable-next-line @next/next/no-img-element */}
 
@@ -92,6 +142,9 @@ function MarkdownImage(props: Record<string, unknown>) {
         alt={alt || "Image"}
         className="h-auto max-w-full rounded-md border border-zinc-200"
         loading="lazy"
+        onError={() => {
+          if (aiCannotSee) setImgFailed(true);
+        }}
       />
       {aiCannotSee && (
         <span
@@ -31,11 +31,6 @@ export function MessageList({
     isStreaming,
   });
 
-  /**
-   * Keeps this for debugging purposes 💆🏽
-   */
-  console.log(messages);
-
   return (
     <div className="relative flex min-h-0 flex-1 flex-col">
       {/* Top fade shadow */}
 
@@ -78,6 +73,7 @@ export function MessageList({
             key={index}
             message={message}
             prevMessage={messages[index - 1]}
+            onSendMessage={onSendMessage}
           />
         );
       }
 
@@ -5,11 +5,13 @@ import { shouldSkipAgentOutput } from "../../helpers";
 export interface LastToolResponseProps {
   message: ChatMessageData;
   prevMessage: ChatMessageData | undefined;
+  onSendMessage?: (content: string) => void;
 }
 
 export function LastToolResponse({
   message,
   prevMessage,
+  onSendMessage,
 }: LastToolResponseProps) {
   if (message.type !== "tool_response") return null;
 
@@ -21,6 +23,7 @@ export function LastToolResponse({
         toolId={message.toolId}
         toolName={message.toolName}
         result={message.result}
+        onSendMessage={onSendMessage}
       />
     </div>
   );
 
@@ -1,6 +1,8 @@
+import { Progress } from "@/components/atoms/Progress/Progress";
 import { cn } from "@/lib/utils";
 import { useEffect, useRef, useState } from "react";
 import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
+import { useAsymptoticProgress } from "../ToolCallMessage/useAsymptoticProgress";
 
 export interface ThinkingMessageProps {
   className?: string;
 
@@ -11,18 +13,19 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
   const [showCoffeeMessage, setShowCoffeeMessage] = useState(false);
   const timerRef = useRef<NodeJS.Timeout | null>(null);
   const coffeeTimerRef = useRef<NodeJS.Timeout | null>(null);
+  const progress = useAsymptoticProgress(showCoffeeMessage);
 
   useEffect(() => {
     if (timerRef.current === null) {
       timerRef.current = setTimeout(() => {
         setShowSlowLoader(true);
-      }, 8000);
+      }, 3000);
     }
 
     if (coffeeTimerRef.current === null) {
       coffeeTimerRef.current = setTimeout(() => {
         setShowCoffeeMessage(true);
-      }, 10000);
+      }, 8000);
     }
 
     return () => {
 
@@ -49,9 +52,18 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
     <AIChatBubble>
       <div className="transition-all duration-500 ease-in-out">
         {showCoffeeMessage ? (
-          <span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
-            This could take a few minutes, grab a coffee ☕️
-          </span>
+          <div className="flex flex-col items-center gap-3">
+            <div className="flex w-full max-w-[280px] flex-col gap-1.5">
+              <div className="flex items-center justify-between text-xs text-neutral-500">
+                <span>Working on it...</span>
+                <span>{Math.round(progress)}%</span>
+              </div>
+              <Progress value={progress} className="h-2 w-full" />
+            </div>
+            <span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
+              This could take a few minutes, grab a coffee ☕️
+            </span>
+          </div>
         ) : showSlowLoader ? (
           <span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
             Taking a bit more time...
 
@@ -0,0 +1,50 @@
+import { useEffect, useRef, useState } from "react";
+
+/**
+ * Hook that returns a progress value that starts fast and slows down,
+ * asymptotically approaching but never reaching the max value.
+ *
+ * Uses a half-life formula: progress = max * (1 - 0.5^(time/halfLife))
+ * This creates the "game loading bar" effect where:
+ * - 50% is reached at halfLifeSeconds
+ * - 75% is reached at 2 * halfLifeSeconds
+ * - 87.5% is reached at 3 * halfLifeSeconds
+ * - and so on...
+ *
+ * @param isActive - Whether the progress should be animating
+ * @param halfLifeSeconds - Time in seconds to reach 50% progress (default: 30)
+ * @param maxProgress - Maximum progress value to approach (default: 100)
+ * @param intervalMs - Update interval in milliseconds (default: 100)
+ * @returns Current progress value (0-maxProgress)
+ */
+export function useAsymptoticProgress(
+  isActive: boolean,
+  halfLifeSeconds = 30,
+  maxProgress = 100,
+  intervalMs = 100,
+) {
+  const [progress, setProgress] = useState(0);
+  const elapsedTimeRef = useRef(0);
+
+  useEffect(() => {
+    if (!isActive) {
+      setProgress(0);
+      elapsedTimeRef.current = 0;
+      return;
+    }
+
+    const interval = setInterval(() => {
+      elapsedTimeRef.current += intervalMs / 1000;
+      // Half-life approach: progress = max * (1 - 0.5^(time/halfLife))
+      // At t=halfLife: 50%, at t=2*halfLife: 75%, at t=3*halfLife: 87.5%, etc.
+      const newProgress =
+        maxProgress *
+        (1 - Math.pow(0.5, elapsedTimeRef.current / halfLifeSeconds));
+      setProgress(newProgress);
+    }, intervalMs);
+
+    return () => clearInterval(interval);
+  }, [isActive, halfLifeSeconds, maxProgress, intervalMs]);
+
+  return progress;
+}
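Spot-checking the half-life curve with the hook's defaults (standalone sketch):

// progress(t) = 100 * (1 - 0.5^(t / 30)) for the default 30 s half-life.
const progressAt = (t: number) => 100 * (1 - Math.pow(0.5, t / 30));
progressAt(30);  // 50
progressAt(60);  // 75
progressAt(90);  // 87.5
progressAt(300); // ≈ 99.9 - approaches but never reaches 100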
@@ -0,0 +1,128 @@
+"use client";
+
+import { useGetV2GetLibraryAgent } from "@/app/api/__generated__/endpoints/library/library";
+import { GraphExecutionJobInfo } from "@/app/api/__generated__/models/graphExecutionJobInfo";
+import { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
+import { RunAgentModal } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/RunAgentModal";
+import { Button } from "@/components/atoms/Button/Button";
+import { Text } from "@/components/atoms/Text/Text";
+import {
+  CheckCircleIcon,
+  PencilLineIcon,
+  PlayIcon,
+} from "@phosphor-icons/react";
+import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
+
+interface Props {
+  agentName: string;
+  libraryAgentId: string;
+  onSendMessage?: (content: string) => void;
+}
+
+export function AgentCreatedPrompt({
+  agentName,
+  libraryAgentId,
+  onSendMessage,
+}: Props) {
+  // Fetch library agent eagerly so modal is ready when user clicks
+  const { data: libraryAgentResponse, isLoading } = useGetV2GetLibraryAgent(
+    libraryAgentId,
+    {
+      query: {
+        enabled: !!libraryAgentId,
+      },
+    },
+  );
+
+  const libraryAgent =
+    libraryAgentResponse?.status === 200 ? libraryAgentResponse.data : null;
+
+  function handleRunWithPlaceholders() {
+    onSendMessage?.(
+      `Run the agent "${agentName}" with placeholder/example values so I can test it.`,
+    );
+  }
+
+  function handleRunCreated(execution: GraphExecutionMeta) {
+    onSendMessage?.(
+      `I've started the agent "${agentName}". The execution ID is ${execution.id}. Please monitor its progress and let me know when it completes.`,
+    );
+  }
+
+  function handleScheduleCreated(schedule: GraphExecutionJobInfo) {
+    const scheduleInfo = schedule.cron
+      ? `with cron schedule "${schedule.cron}"`
+      : "to run on the specified schedule";
+    onSendMessage?.(
+      `I've scheduled the agent "${agentName}" ${scheduleInfo}. The schedule ID is ${schedule.id}.`,
+    );
+  }
+
+  return (
+    <AIChatBubble>
+      <div className="flex flex-col gap-4">
+        <div className="flex items-center gap-2">
+          <div className="flex h-8 w-8 items-center justify-center rounded-full bg-green-100">
+            <CheckCircleIcon
+              size={18}
+              weight="fill"
+              className="text-green-600"
+            />
+          </div>
+          <div>
+            <Text variant="body-medium" className="text-neutral-900">
+              Agent Created Successfully
+            </Text>
+            <Text variant="small" className="text-neutral-500">
+              "{agentName}" is ready to test
+            </Text>
+          </div>
+        </div>
+
+        <div className="flex flex-col gap-2">
+          <Text variant="small-medium" className="text-neutral-700">
+            Ready to test?
+          </Text>
+          <div className="flex flex-wrap gap-2">
+            <Button
+              variant="outline"
+              size="small"
+              onClick={handleRunWithPlaceholders}
+              className="gap-2"
+            >
+              <PlayIcon size={16} />
+              Run with example values
+            </Button>
+            {libraryAgent ? (
+              <RunAgentModal
+                triggerSlot={
+                  <Button variant="outline" size="small" className="gap-2">
+                    <PencilLineIcon size={16} />
+                    Run with my inputs
+                  </Button>
+                }
+                agent={libraryAgent}
+                onRunCreated={handleRunCreated}
+                onScheduleCreated={handleScheduleCreated}
+              />
+            ) : (
+              <Button
+                variant="outline"
+                size="small"
+                loading={isLoading}
+                disabled
+                className="gap-2"
+              >
+                <PencilLineIcon size={16} />
+                Run with my inputs
+              </Button>
+            )}
+          </div>
+          <Text variant="small" className="text-neutral-500">
+            or just ask me
+          </Text>
+        </div>
+      </div>
+    </AIChatBubble>
+  );
+}
 
@@ -2,11 +2,13 @@ import { Text } from "@/components/atoms/Text/Text";
 import { cn } from "@/lib/utils";
 import type { ToolResult } from "@/types/chat";
 import { WarningCircleIcon } from "@phosphor-icons/react";
+import { AgentCreatedPrompt } from "./AgentCreatedPrompt";
 import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
 import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
 import {
   formatToolResponse,
   getErrorMessage,
+  isAgentSavedResponse,
   isErrorResponse,
 } from "./helpers";
 
@@ -16,6 +18,7 @@ export interface ToolResponseMessageProps {
   result?: ToolResult;
   success?: boolean;
   className?: string;
+  onSendMessage?: (content: string) => void;
 }
 
 export function ToolResponseMessage({
 
@@ -24,6 +27,7 @@ export function ToolResponseMessage({
   result,
   success: _success,
   className,
+  onSendMessage,
 }: ToolResponseMessageProps) {
   if (isErrorResponse(result)) {
     const errorMessage = getErrorMessage(result);
 
@@ -43,6 +47,18 @@ export function ToolResponseMessage({
     );
   }
 
+  // Check for agent_saved response - show special prompt
+  const agentSavedData = isAgentSavedResponse(result);
+  if (agentSavedData.isSaved) {
+    return (
+      <AgentCreatedPrompt
+        agentName={agentSavedData.agentName}
+        libraryAgentId={agentSavedData.libraryAgentId}
+        onSendMessage={onSendMessage}
+      />
+    );
+  }
+
   const formattedText = formatToolResponse(result, toolName);
 
   return (
 
@@ -6,6 +6,43 @@ function stripInternalReasoning(content: string): string {
     .trim();
 }
 
+export interface AgentSavedData {
+  isSaved: boolean;
+  agentName: string;
+  agentId: string;
+  libraryAgentId: string;
+  libraryAgentLink: string;
+}
+
+export function isAgentSavedResponse(result: unknown): AgentSavedData {
+  if (typeof result !== "object" || result === null) {
+    return {
+      isSaved: false,
+      agentName: "",
+      agentId: "",
+      libraryAgentId: "",
+      libraryAgentLink: "",
+    };
+  }
+  const response = result as Record<string, unknown>;
+  if (response.type === "agent_saved") {
+    return {
+      isSaved: true,
+      agentName: (response.agent_name as string) || "Agent",
+      agentId: (response.agent_id as string) || "",
+      libraryAgentId: (response.library_agent_id as string) || "",
+      libraryAgentLink: (response.library_agent_link as string) || "",
+    };
+  }
+  return {
+    isSaved: false,
+    agentName: "",
+    agentId: "",
+    libraryAgentId: "",
+    libraryAgentLink: "",
+  };
+}
+
 export function isErrorResponse(result: unknown): boolean {
   if (typeof result === "string") {
     const lower = result.toLowerCase();
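A matching tool result maps to the prompt's props like this (field values hypothetical):

// Hypothetical agent_saved payload:
const result = {
  type: "agent_saved",
  agent_name: "Daily Digest",
  agent_id: "graph-123",
  library_agent_id: "lib-456",
  library_agent_link: "/library/agents/lib-456",
};
isAgentSavedResponse(result).isSaved;            // true
isAgentSavedResponse(result).agentName;          // "Daily Digest"
isAgentSavedResponse({ type: "other" }).isSaved; // false (all-empty AgentSavedData)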
@@ -39,69 +76,101 @@ export function getErrorMessage(result: unknown): string {
 
 /**
  * Check if a value is a workspace file reference.
+ * Format: workspace://{fileId} or workspace://{fileId}#{mimeType}
  */
 function isWorkspaceRef(value: unknown): value is string {
   return typeof value === "string" && value.startsWith("workspace://");
 }
 
 /**
- * Check if a workspace reference appears to be an image based on common patterns.
- * Since workspace refs don't have extensions, we check the context or assume image
- * for certain block types.
- *
- * TODO: Replace keyword matching with MIME type encoded in workspace ref.
- * e.g., workspace://abc123#image/png or workspace://abc123#video/mp4
- * This would let frontend render correctly without fragile keyword matching.
+ * Extract MIME type from a workspace reference fragment.
+ * e.g., "workspace://abc123#video/mp4" → "video/mp4"
+ * Returns undefined if no fragment is present.
  */
-function isLikelyImageRef(value: string, outputKey?: string): boolean {
-  if (!isWorkspaceRef(value)) return false;
-
-  // Check output key name for video-related hints (these are NOT images)
-  const videoKeywords = ["video", "mp4", "mov", "avi", "webm", "movie", "clip"];
-  if (outputKey) {
-    const lowerKey = outputKey.toLowerCase();
-    if (videoKeywords.some((kw) => lowerKey.includes(kw))) {
-      return false;
-    }
-  }
-
-  // Check output key name for image-related hints
-  const imageKeywords = [
-    "image",
-    "img",
-    "photo",
-    "picture",
-    "thumbnail",
-    "avatar",
-    "icon",
-    "screenshot",
-  ];
-  if (outputKey) {
-    const lowerKey = outputKey.toLowerCase();
-    if (imageKeywords.some((kw) => lowerKey.includes(kw))) {
-      return true;
-    }
-  }
-
-  // Default to treating workspace refs as potential images
-  // since that's the most common case for generated content
-  return true;
+function getWorkspaceMimeType(value: string): string | undefined {
+  const hashIndex = value.indexOf("#");
+  if (hashIndex === -1) return undefined;
+  return value.slice(hashIndex + 1) || undefined;
 }
 
 /**
- * Format a single output value, converting workspace refs to markdown images.
+ * Determine the media category of a workspace ref or data URI.
+ * Uses the MIME type fragment on workspace refs when available,
+ * falls back to output key keyword matching for older refs without it.
  */
-function formatOutputValue(value: unknown, outputKey?: string): string {
-  if (isWorkspaceRef(value) && isLikelyImageRef(value, outputKey)) {
-    // Format as markdown image
-    return `![${outputKey || "image"}](${value})`;
+function getMediaCategory(
+  value: string,
+  outputKey?: string,
+): "video" | "image" | "audio" | "unknown" {
+  // Data URIs carry their own MIME type
+  if (value.startsWith("data:video/")) return "video";
+  if (value.startsWith("data:image/")) return "image";
+  if (value.startsWith("data:audio/")) return "audio";
+
+  // Workspace refs: prefer MIME type fragment
+  if (isWorkspaceRef(value)) {
+    const mime = getWorkspaceMimeType(value);
+    if (mime) {
+      if (mime.startsWith("video/")) return "video";
+      if (mime.startsWith("image/")) return "image";
+      if (mime.startsWith("audio/")) return "audio";
+      return "unknown";
+    }
+
+    // Fallback: keyword matching on output key for older refs without fragment
+    if (outputKey) {
+      const lowerKey = outputKey.toLowerCase();
+
+      const videoKeywords = [
+        "video",
+        "mp4",
+        "mov",
+        "avi",
+        "webm",
+        "movie",
+        "clip",
+      ];
+      if (videoKeywords.some((kw) => lowerKey.includes(kw))) return "video";
+
+      const imageKeywords = [
+        "image",
+        "img",
+        "photo",
+        "picture",
+        "thumbnail",
+        "avatar",
+        "icon",
+        "screenshot",
+      ];
+      if (imageKeywords.some((kw) => lowerKey.includes(kw))) return "image";
+    }
+
+    // Default to image for backward compatibility
+    return "image";
   }
+
+  return "unknown";
+}
+
+/**
+ * Format a single output value, converting workspace refs to markdown images/videos.
+ * Videos use a "video:" alt-text prefix so the MarkdownContent renderer can
+ * distinguish them from images and render a <video> element.
+ */
+function formatOutputValue(value: unknown, outputKey?: string): string {
   if (typeof value === "string") {
-    // Check for data URIs (images)
-    if (value.startsWith("data:image/")) {
+    const category = getMediaCategory(value, outputKey);
+
+    if (category === "video") {
+      // Format with "video:" prefix so MarkdownContent renders <video>
+      return `![video:${outputKey || "video"}](${value})`;
+    }
+
+    if (category === "image") {
       return `![${outputKey || "image"}](${value})`;
     }
+
+    // For audio, unknown workspace refs, data URIs, etc. - return as-is
     return value;
   }
 
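Concretely (file IDs hypothetical; the alt-text templates follow the reconstruction above, the originals having been lost in extraction):

getMediaCategory("workspace://abc123#video/mp4");       // "video" - MIME fragment wins
getMediaCategory("workspace://abc123", "result_image"); // "image" - keyword fallback
formatOutputValue("workspace://abc123#video/mp4", "clip");
// → markdown image syntax whose alt text starts with "video:", which
//   MarkdownImage later detects and renders as a <video> element.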
@@ -1,3 +1,4 @@
+import { INITIAL_STREAM_ID } from "./chat-constants";
 import type {
   ActiveStream,
   StreamChunk,
 
@@ -10,8 +11,14 @@ import {
   parseSSELine,
 } from "./stream-utils";
 
-function notifySubscribers(stream: ActiveStream, chunk: StreamChunk) {
-  stream.chunks.push(chunk);
+function notifySubscribers(
+  stream: ActiveStream,
+  chunk: StreamChunk,
+  skipStore = false,
+) {
+  if (!skipStore) {
+    stream.chunks.push(chunk);
+  }
   for (const callback of stream.onChunkCallbacks) {
     try {
       callback(chunk);
 
@@ -21,36 +28,114 @@ function notifySubscribers(stream: ActiveStream, chunk: StreamChunk) {
   }
 }
 
-export async function executeStream(
-  stream: ActiveStream,
-  message: string,
-  isUserMessage: boolean,
-  context?: { url: string; content: string },
-  retryCount: number = 0,
+interface StreamExecutionOptions {
+  stream: ActiveStream;
+  mode: "new" | "reconnect";
+  message?: string;
+  isUserMessage?: boolean;
+  context?: { url: string; content: string };
+  taskId?: string;
+  lastMessageId?: string;
+  retryCount?: number;
+}
+
+async function executeStreamInternal(
+  options: StreamExecutionOptions,
 ): Promise<void> {
+  const {
+    stream,
+    mode,
+    message,
+    isUserMessage,
+    context,
+    taskId,
+    lastMessageId = INITIAL_STREAM_ID,
+    retryCount = 0,
+  } = options;
+
   const { sessionId, abortController } = stream;
+  const isReconnect = mode === "reconnect";
+
+  if (isReconnect) {
+    if (!taskId) {
+      throw new Error("taskId is required for reconnect mode");
+    }
+    if (lastMessageId === null || lastMessageId === undefined) {
+      throw new Error("lastMessageId is required for reconnect mode");
+    }
+  } else {
+    if (!message) {
+      throw new Error("message is required for new stream mode");
+    }
+    if (isUserMessage === undefined) {
+      throw new Error("isUserMessage is required for new stream mode");
+    }
+  }
+
   try {
-    const url = `/api/chat/sessions/${sessionId}/stream`;
-    const body = JSON.stringify({
-      message,
-      is_user_message: isUserMessage,
-      context: context || null,
-    });
-
-    const response = await fetch(url, {
-      method: "POST",
-      headers: {
-        "Content-Type": "application/json",
-        Accept: "text/event-stream",
-      },
-      body,
-      signal: abortController.signal,
-    });
+    let url: string;
+    let fetchOptions: RequestInit;
+
+    if (isReconnect) {
+      url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
+      fetchOptions = {
+        method: "GET",
+        headers: {
+          Accept: "text/event-stream",
+        },
+        signal: abortController.signal,
+      };
+    } else {
+      url = `/api/chat/sessions/${sessionId}/stream`;
+      fetchOptions = {
+        method: "POST",
+        headers: {
+          "Content-Type": "application/json",
+          Accept: "text/event-stream",
+        },
+        body: JSON.stringify({
+          message,
+          is_user_message: isUserMessage,
+          context: context || null,
+        }),
+        signal: abortController.signal,
+      };
+    }
+
+    const response = await fetch(url, fetchOptions);
 
     if (!response.ok) {
       const errorText = await response.text();
-      throw new Error(errorText || `HTTP ${response.status}`);
+      let errorCode: string | undefined;
+      let errorMessage = errorText || `HTTP ${response.status}`;
+      try {
+        const parsed = JSON.parse(errorText);
+        if (parsed.detail) {
+          const detail =
+            typeof parsed.detail === "string"
+              ? parsed.detail
+              : parsed.detail.message || JSON.stringify(parsed.detail);
+          errorMessage = detail;
+          errorCode =
+            typeof parsed.detail === "object" ? parsed.detail.code : undefined;
+        }
+      } catch {}
+
+      const isPermanentError =
+        isReconnect &&
+        (response.status === 404 ||
+          response.status === 403 ||
+          response.status === 410);
+
+      const error = new Error(errorMessage) as Error & {
+        status?: number;
+        isPermanent?: boolean;
+        taskErrorCode?: string;
+      };
+      error.status = response.status;
+      error.isPermanent = isPermanentError;
+      error.taskErrorCode = errorCode;
+      throw error;
    }
 
     if (!response.body) {
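For reference, the parsing above targets JSON error bodies with a `detail` field (payloads illustrative):

// Object detail → both message and code are extracted:
// '{"detail": {"message": "Task not found", "code": "task_not_found"}}'
//   → errorMessage = "Task not found", errorCode = "task_not_found"
// String detail → message only:
// '{"detail": "Rate limited"}' → errorMessage = "Rate limited", errorCode = undefined
// Non-JSON bodies fall through the empty catch and keep the raw text.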
@@ -104,9 +189,7 @@ export async function executeStream(
             );
             return;
           }
-        } catch (err) {
-          console.warn("[StreamExecutor] Failed to parse SSE chunk:", err);
-        }
+        } catch {}
       }
     }
   }
 
@@ -117,19 +200,17 @@ export async function executeStream(
       return;
     }
 
-    if (retryCount < MAX_RETRIES) {
+    const isPermanentError =
+      err instanceof Error &&
+      (err as Error & { isPermanent?: boolean }).isPermanent;
+
+    if (!isPermanentError && retryCount < MAX_RETRIES) {
       const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
-      console.log(
-        `[StreamExecutor] Retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
-      );
       await new Promise((resolve) => setTimeout(resolve, retryDelay));
-      return executeStream(
-        stream,
-        message,
-        isUserMessage,
-        context,
-        retryCount + 1,
-      );
+      return executeStreamInternal({
+        ...options,
+        retryCount: retryCount + 1,
+      });
     }
 
     stream.status = "error";
 
@@ -140,3 +221,35 @@ export async function executeStream(
     });
   }
 }
+
+export async function executeStream(
+  stream: ActiveStream,
+  message: string,
+  isUserMessage: boolean,
+  context?: { url: string; content: string },
+  retryCount: number = 0,
+): Promise<void> {
+  return executeStreamInternal({
+    stream,
+    mode: "new",
+    message,
+    isUserMessage,
+    context,
+    retryCount,
+  });
+}
+
+export async function executeTaskReconnect(
+  stream: ActiveStream,
+  taskId: string,
+  lastMessageId: string = INITIAL_STREAM_ID,
+  retryCount: number = 0,
+): Promise<void> {
+  return executeStreamInternal({
+    stream,
+    mode: "reconnect",
+    taskId,
+    lastMessageId,
+    retryCount,
+  });
+}
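A minimal reconnection sketch built on the new wrapper (the stream lookup and IDs are hypothetical):

// Hypothetical caller: resume a dropped SSE stream from the last replayed message.
// `activeStream`, `taskId`, and `lastMessageId` come from the app's stream registry.
await executeTaskReconnect(activeStream, taskId, lastMessageId ?? INITIAL_STREAM_ID);
// Internally: GET /api/chat/tasks/{taskId}/stream?last_message_id=..., with
// exponential-backoff retries unless the server returns 403/404/410, which
// executeStreamInternal marks as permanent and does not retry.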
@@ -28,6 +28,7 @@ export function normalizeStreamChunk(
 
   switch (chunk.type) {
     case "text-delta":
+      // Vercel AI SDK sends "delta" for text content
       return { type: "text_chunk", content: chunk.delta };
     case "text-end":
       return { type: "text_ended" };
 
@@ -63,6 +64,10 @@ export function normalizeStreamChunk(
     case "finish":
       return { type: "stream_end" };
     case "start":
+      // Start event with optional taskId for reconnection
+      return chunk.taskId
+        ? { type: "stream_start", taskId: chunk.taskId }
+        : null;
     case "text-start":
       return null;
     case "tool-input-start":
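So a raw start event maps as follows (values illustrative):

normalizeStreamChunk({ type: "start", taskId: "task-789" });
// → { type: "stream_start", taskId: "task-789" } - enables SSE reconnection
normalizeStreamChunk({ type: "start" });
// → null - without a taskId there is nothing for the client to store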
@@ -41,7 +41,17 @@ export function HostScopedCredentialsModal({
   const currentHost = currentUrl ? getHostFromUrl(currentUrl) : "";
 
   const formSchema = z.object({
-    host: z.string().min(1, "Host is required"),
+    host: z
+      .string()
+      .min(1, "Host is required")
+      .refine((val) => !/^[a-zA-Z][a-zA-Z\d+\-.]*:\/\//.test(val), {
+        message: "Enter only the host (e.g. api.example.com), not a full URL",
+      })
+      .refine((val) => !val.includes("/"), {
+        message:
+          "Enter only the host (e.g. api.example.com), without a trailing path. " +
+          "You may specify a port (e.g. api.example.com:8080) if needed.",
+      }),
     title: z.string().optional(),
     headers: z.record(z.string()).optional(),
   });
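Sample inputs against the schema above (sketch; assumes the zod object is in scope):

formSchema.shape.host.safeParse("api.example.com").success;         // true
formSchema.shape.host.safeParse("api.example.com:8080").success;    // true - port allowed
formSchema.shape.host.safeParse("https://api.example.com").success; // false - scheme rejected
formSchema.shape.host.safeParse("api.example.com/v1").success;      // false - path rejected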
@@ -26,6 +26,7 @@ export const providerIcons: Partial<
   nvidia: fallbackIcon,
   discord: FaDiscord,
   d_id: fallbackIcon,
+  elevenlabs: fallbackIcon,
   google_maps: FaGoogle,
   jina: fallbackIcon,
   ideogram: fallbackIcon,
 
@@ -47,7 +47,7 @@ export function Navbar() {
 
   const actualLoggedInLinks = [
     { name: "Home", href: homeHref },
-    ...(isChatEnabled === true ? [{ name: "Tasks", href: "/library" }] : []),
+    ...(isChatEnabled === true ? [{ name: "Agents", href: "/library" }] : []),
     ...loggedInLinks,
   ];
 
@@ -15,7 +15,6 @@ import {
 import { cn } from "@/lib/utils";
 import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
 import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
-import { storage, Key as StorageKey } from "@/services/storage/local-storage";
 import { WalletIcon } from "@phosphor-icons/react";
 import { PopoverClose } from "@radix-ui/react-popover";
 import { X } from "lucide-react";
 
@@ -175,7 +174,6 @@ export function Wallet() {
   const [prevCredits, setPrevCredits] = useState<number | null>(credits);
   const [flash, setFlash] = useState(false);
   const [walletOpen, setWalletOpen] = useState(false);
-  const [lastSeenCredits, setLastSeenCredits] = useState<number | null>(null);
 
   const totalCount = useMemo(() => {
     return groups.reduce((acc, group) => acc + group.tasks.length, 0);
 
@@ -200,38 +198,6 @@ export function Wallet() {
     setCompletedCount(completed);
   }, [groups, state?.completedSteps]);
 
-  // Load last seen credits from localStorage once on mount
-  useEffect(() => {
-    const stored = storage.get(StorageKey.WALLET_LAST_SEEN_CREDITS);
-    if (stored !== undefined && stored !== null) {
-      const parsed = parseFloat(stored);
-      if (!Number.isNaN(parsed)) setLastSeenCredits(parsed);
-      else setLastSeenCredits(0);
-    } else {
-      setLastSeenCredits(0);
-    }
-  }, []);
-
-  // Auto-open once if never shown, otherwise open only when credits increase beyond last seen
-  useEffect(() => {
-    if (typeof credits !== "number") return;
-    // Open once for first-time users
-    if (state && state.walletShown === false) {
-      requestAnimationFrame(() => setWalletOpen(true));
-      // Mark as shown so it won't reopen on every reload
-      updateState({ walletShown: true });
-      return;
-    }
-    // Open if user gained more credits than last acknowledged
-    if (
-      lastSeenCredits !== null &&
-      credits > lastSeenCredits &&
-      walletOpen === false
-    ) {
-      requestAnimationFrame(() => setWalletOpen(true));
-    }
-  }, [credits, lastSeenCredits, state?.walletShown, updateState, walletOpen]);
-
   const onWalletOpen = useCallback(async () => {
     if (!state?.walletShown) {
       updateState({ walletShown: true });
 
@@ -324,19 +290,7 @@ export function Wallet() {
   if (credits === null || !state) return null;
 
   return (
-    <Popover
-      open={walletOpen}
-      onOpenChange={(open) => {
-        setWalletOpen(open);
-        if (!open) {
-          // Persist the latest acknowledged credits so we only auto-open on future gains
-          if (typeof credits === "number") {
-            storage.set(StorageKey.WALLET_LAST_SEEN_CREDITS, String(credits));
-            setLastSeenCredits(credits);
-          }
-        }
-      }}
-    >
+    <Popover open={walletOpen} onOpenChange={(open) => setWalletOpen(open)}>
       <PopoverTrigger asChild>
         <div className="relative inline-block">
           <button
68  classic/frontend/build/web/flutter_service_worker.js  (generated)
@@ -3,45 +3,45 @@ const MANIFEST = 'flutter-app-manifest';
|
|||||||
const TEMP = 'flutter-temp-cache';
|
const TEMP = 'flutter-temp-cache';
|
||||||
const CACHE_NAME = 'flutter-app-cache';
|
const CACHE_NAME = 'flutter-app-cache';
|
||||||
|
|
||||||
const RESOURCES = {"flutter.js": "6fef97aeca90b426343ba6c5c9dc5d4a",
|
const RESOURCES = {"canvaskit/skwasm.worker.js": "51253d3321b11ddb8d73fa8aa87d3b15",
|
||||||
"icons/Icon-512.png": "96e752610906ba2a93c65f8abe1645f1",
|
|
||||||
"icons/Icon-maskable-512.png": "301a7604d45b3e739efc881eb04896ea",
|
|
||||||
"icons/Icon-192.png": "ac9a721a12bbc803b44f645561ecb1e1",
|
|
||||||
"icons/Icon-maskable-192.png": "c457ef57daa1d16f64b27b786ec2ea3c",
|
|
||||||
"manifest.json": "0fa552613b8ec0fda5cda565914e3b16",
|
|
||||||
"index.html": "3442c510a9ea217672c82e799ae070f7",
|
|
||||||
"/": "3442c510a9ea217672c82e799ae070f7",
|
|
||||||
"assets/shaders/ink_sparkle.frag": "f8b80e740d33eb157090be4e995febdf",
|
|
||||||
"assets/assets/tree_structure.json": "cda9b1a239f956c547411efad9f7c794",
|
|
||||||
"assets/assets/coding_tree_structure.json": "017a857cf3e274346a0a7eab4ce02eed",
|
|
||||||
"assets/assets/general_tree_structure.json": "41dfbcdc2349dcdda2b082e597c6d5ee",
|
|
||||||
"assets/assets/github_logo.svg.png": "ba087b073efdc4996b035d3a12bad0e4",
|
|
||||||
"assets/assets/images/discord_logo.png": "0e4a4162c5de8665a7d63ae9665405ae",
|
|
||||||
"assets/assets/images/github_logo.svg.png": "ba087b073efdc4996b035d3a12bad0e4",
|
|
||||||
"assets/assets/images/twitter_logo.png": "af6c11b96a5e732b8dfda86a2351ecab",
|
|
||||||
"assets/assets/images/google_logo.svg.png": "0e29f8e1acfb8996437dbb2b0f591f19",
|
|
||||||
"assets/assets/images/autogpt_logo.png": "6a5362a7d1f2f840e43ee259e733476c",
|
|
||||||
"assets/assets/google_logo.svg.png": "0e29f8e1acfb8996437dbb2b0f591f19",
|
|
||||||
"assets/assets/scrape_synthesize_tree_structure.json": "a9665c1b465bb0cb939c7210f2bf0b13",
|
|
||||||
"assets/assets/data_tree_structure.json": "5f9627548304155821968182f3883ca7",
|
|
||||||
"assets/fonts/MaterialIcons-Regular.otf": "245e0462249d95ad589a087f1c9f58e1",
|
|
||||||
"assets/NOTICES": "28ba0c63fc6e4d1ef829af7441e27f78",
|
|
||||||
"assets/packages/fluttertoast/assets/toastify.css": "a85675050054f179444bc5ad70ffc635",
|
|
||||||
"assets/packages/fluttertoast/assets/toastify.js": "56e2c9cedd97f10e7e5f1cebd85d53e3",
|
|
||||||
"assets/packages/cupertino_icons/assets/CupertinoIcons.ttf": "055d9e87e4a40dbf72b2af1a20865d57",
|
|
||||||
"assets/FontManifest.json": "dc3d03800ccca4601324923c0b1d6d57",
|
|
||||||
"assets/AssetManifest.bin": "791447d17744ac2ade3999c1672fdbe8",
|
|
||||||
"assets/AssetManifest.json": "1b1e4a4276722b65eb1ef765e2991840",
|
|
||||||
"canvaskit/chromium/canvaskit.wasm": "393ec8fb05d94036734f8104fa550a67",
|
|
||||||
"canvaskit/chromium/canvaskit.js": "ffb2bb6484d5689d91f393b60664d530",
|
|
||||||
"canvaskit/skwasm.worker.js": "51253d3321b11ddb8d73fa8aa87d3b15",
|
|
||||||
"canvaskit/skwasm.js": "95f16c6690f955a45b2317496983dbe9",
|
"canvaskit/skwasm.js": "95f16c6690f955a45b2317496983dbe9",
|
||||||
"canvaskit/canvaskit.wasm": "d9f69e0f428f695dc3d66b3a83a4aa8e",
|
"canvaskit/canvaskit.wasm": "d9f69e0f428f695dc3d66b3a83a4aa8e",
|
||||||
"canvaskit/canvaskit.js": "5caccb235fad20e9b72ea6da5a0094e6",
|
|
||||||
"canvaskit/skwasm.wasm": "d1fde2560be92c0b07ad9cf9acb10d05",
|
"canvaskit/skwasm.wasm": "d1fde2560be92c0b07ad9cf9acb10d05",
|
||||||
|
"canvaskit/canvaskit.js": "5caccb235fad20e9b72ea6da5a0094e6",
|
||||||
|
"canvaskit/chromium/canvaskit.wasm": "393ec8fb05d94036734f8104fa550a67",
|
||||||
|
"canvaskit/chromium/canvaskit.js": "ffb2bb6484d5689d91f393b60664d530",
|
||||||
|
"icons/Icon-maskable-192.png": "c457ef57daa1d16f64b27b786ec2ea3c",
|
||||||
|
"icons/Icon-maskable-512.png": "301a7604d45b3e739efc881eb04896ea",
|
||||||
|
"icons/Icon-512.png": "96e752610906ba2a93c65f8abe1645f1",
|
||||||
|
"icons/Icon-192.png": "ac9a721a12bbc803b44f645561ecb1e1",
|
||||||
|
"manifest.json": "0fa552613b8ec0fda5cda565914e3b16",
|
||||||
"favicon.png": "5dcef449791fa27946b3d35ad8803796",
|
"favicon.png": "5dcef449791fa27946b3d35ad8803796",
|
||||||
"version.json": "46a52461e018faa623d9196334aa3f50",
|
"version.json": "46a52461e018faa623d9196334aa3f50",
|
||||||
"main.dart.js": "6fcbf8bbcb0a76fae9029f72ac7fbdc3"};
|
"index.html": "e6981504a32bf86f892909c1875df208",
|
||||||
|
"/": "e6981504a32bf86f892909c1875df208",
|
||||||
|
"main.dart.js": "6fcbf8bbcb0a76fae9029f72ac7fbdc3",
|
||||||
|
"assets/AssetManifest.json": "1b1e4a4276722b65eb1ef765e2991840",
|
||||||
|
"assets/packages/cupertino_icons/assets/CupertinoIcons.ttf": "055d9e87e4a40dbf72b2af1a20865d57",
|
||||||
|
"assets/packages/fluttertoast/assets/toastify.js": "56e2c9cedd97f10e7e5f1cebd85d53e3",
|
||||||
|
"assets/packages/fluttertoast/assets/toastify.css": "a85675050054f179444bc5ad70ffc635",
|
||||||
|
"assets/shaders/ink_sparkle.frag": "f8b80e740d33eb157090be4e995febdf",
|
||||||
|
"assets/fonts/MaterialIcons-Regular.otf": "245e0462249d95ad589a087f1c9f58e1",
|
||||||
|
"assets/assets/images/twitter_logo.png": "af6c11b96a5e732b8dfda86a2351ecab",
|
||||||
|
"assets/assets/images/discord_logo.png": "0e4a4162c5de8665a7d63ae9665405ae",
|
||||||
|
"assets/assets/images/google_logo.svg.png": "0e29f8e1acfb8996437dbb2b0f591f19",
|
||||||
|
"assets/assets/images/autogpt_logo.png": "6a5362a7d1f2f840e43ee259e733476c",
|
||||||
|
"assets/assets/images/github_logo.svg.png": "ba087b073efdc4996b035d3a12bad0e4",
|
||||||
|
"assets/assets/scrape_synthesize_tree_structure.json": "a9665c1b465bb0cb939c7210f2bf0b13",
|
||||||
|
"assets/assets/coding_tree_structure.json": "017a857cf3e274346a0a7eab4ce02eed",
|
||||||
|
"assets/assets/general_tree_structure.json": "41dfbcdc2349dcdda2b082e597c6d5ee",
|
||||||
|
"assets/assets/google_logo.svg.png": "0e29f8e1acfb8996437dbb2b0f591f19",
|
||||||
|
"assets/assets/tree_structure.json": "cda9b1a239f956c547411efad9f7c794",
|
||||||
|
"assets/assets/data_tree_structure.json": "5f9627548304155821968182f3883ca7",
|
||||||
|
"assets/assets/github_logo.svg.png": "ba087b073efdc4996b035d3a12bad0e4",
|
||||||
|
"assets/NOTICES": "28ba0c63fc6e4d1ef829af7441e27f78",
|
||||||
|
"assets/AssetManifest.bin": "791447d17744ac2ade3999c1672fdbe8",
|
||||||
|
"assets/FontManifest.json": "dc3d03800ccca4601324923c0b1d6d57",
|
||||||
|
"flutter.js": "6fef97aeca90b426343ba6c5c9dc5d4a"};
|
||||||
// The application shell files that are downloaded before a service worker can
|
// The application shell files that are downloaded before a service worker can
|
||||||
// start.
|
// start.
|
||||||
const CORE = ["main.dart.js",
|
const CORE = ["main.dart.js",
|
||||||
|
classic/frontend/build/web/index.html (2 changes, generated)
@@ -35,7 +35,7 @@
 <script>
 // The value below is injected by flutter build, do not touch.
-const serviceWorkerVersion = "1550046101";
+const serviceWorkerVersion = "726743092";
 </script>
 <!-- This script adds the flutter initialization JS code -->
 <script src="flutter.js" defer></script>
@@ -62,7 +62,6 @@ Below is a comprehensive list of all available blocks, categorized by their prim
 | [Get Store Agent Details](block-integrations/system/store_operations.md#get-store-agent-details) | Get detailed information about an agent from the store |
 | [Get Weather Information](block-integrations/basic.md#get-weather-information) | Retrieves weather information for a specified location using OpenWeatherMap API |
 | [Human In The Loop](block-integrations/basic.md#human-in-the-loop) | Pause execution and wait for human approval or modification of data |
-| [Linear Search Issues](block-integrations/linear/issues.md#linear-search-issues) | Searches for issues on Linear |
 | [List Is Empty](block-integrations/basic.md#list-is-empty) | Checks if a list is empty |
 | [List Library Agents](block-integrations/system/library_operations.md#list-library-agents) | List all agents in your personal library |
 | [Note](block-integrations/basic.md#note) | A visual annotation block that displays a sticky note in the workflow editor for documentation and organization purposes |
@@ -193,6 +192,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim
 | [Get Current Time](block-integrations/text.md#get-current-time) | This block outputs the current time |
 | [Match Text Pattern](block-integrations/text.md#match-text-pattern) | Matches text against a regex pattern and forwards data to positive or negative output based on the match |
 | [Text Decoder](block-integrations/text.md#text-decoder) | Decodes a string containing escape sequences into actual text |
+| [Text Encoder](block-integrations/text.md#text-encoder) | Encodes a string by converting special characters into escape sequences |
 | [Text Replace](block-integrations/text.md#text-replace) | This block is used to replace a text with a new text |
 | [Text Split](block-integrations/text.md#text-split) | This block is used to split a text into a list of strings |
 | [Word Character Count](block-integrations/text.md#word-character-count) | Counts the number of words and characters in a given text |
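The Text Encoder row added in this hunk is the inverse of the existing Text Decoder block. As a rough plain-Python illustration of that pairing (not the platform's block code; the function names are made up for the sketch):

```python
# Illustrative only: plain-Python equivalents of the Text Encoder /
# Text Decoder behavior described above, not the platform's block code.

def encode_text(text: str) -> str:
    """Convert special characters (newlines, tabs, ...) into escape sequences."""
    return text.encode("unicode_escape").decode("ascii")

def decode_text(text: str) -> str:
    """Convert escape sequences back into the actual characters."""
    return text.encode("ascii").decode("unicode_escape")

assert encode_text("a\tb\nc") == "a\\tb\\nc"
assert decode_text("a\\tb\\nc") == "a\tb\nc"
```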
@@ -233,6 +233,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim
 | [Stagehand Extract](block-integrations/stagehand/blocks.md#stagehand-extract) | Extract structured data from a webpage |
 | [Stagehand Observe](block-integrations/stagehand/blocks.md#stagehand-observe) | Find suggested actions for your workflows |
 | [Unreal Text To Speech](block-integrations/llm.md#unreal-text-to-speech) | Converts text to speech using the Unreal Speech API |
+| [Video Narration](block-integrations/video/narration.md#video-narration) | Generate AI narration and add to video |
 
 ## Search and Information Retrieval
 
@@ -472,9 +473,13 @@ Below is a comprehensive list of all available blocks, categorized by their prim
 
 | Block Name | Description |
 |------------|-------------|
-| [Add Audio To Video](block-integrations/multimedia.md#add-audio-to-video) | Block to attach an audio file to a video file using moviepy |
+| [Add Audio To Video](block-integrations/video/add_audio.md#add-audio-to-video) | Block to attach an audio file to a video file using moviepy |
-| [Loop Video](block-integrations/multimedia.md#loop-video) | Block to loop a video to a given duration or number of repeats |
+| [Loop Video](block-integrations/video/loop.md#loop-video) | Block to loop a video to a given duration or number of repeats |
-| [Media Duration](block-integrations/multimedia.md#media-duration) | Block to get the duration of a media file |
+| [Media Duration](block-integrations/video/duration.md#media-duration) | Block to get the duration of a media file |
+| [Video Clip](block-integrations/video/clip.md#video-clip) | Extract a time segment from a video |
+| [Video Concat](block-integrations/video/concat.md#video-concat) | Merge multiple video clips into one continuous video |
+| [Video Download](block-integrations/video/download.md#video-download) | Download video from URL (YouTube, Vimeo, news sites, direct links) |
+| [Video Text Overlay](block-integrations/video/text_overlay.md#video-text-overlay) | Add text overlay/caption to video |
 
 ## Productivity
 
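The block descriptions above already note that these operations wrap moviepy. A minimal sketch of the clip, concat, and add-audio flow using the moviepy 1.x API directly (file names are placeholders; this is not the blocks' actual implementation):

```python
# Minimal moviepy 1.x sketch of the operations the video blocks above wrap:
# clip a segment, concatenate segments, and attach a narration track.
# File names are placeholders; this is not the blocks' actual code.
from moviepy.editor import AudioFileClip, VideoFileClip, concatenate_videoclips

intro = VideoFileClip("intro.mp4").subclip(0, 5)   # Video Clip: first 5 seconds
body = VideoFileClip("body.mp4")
merged = concatenate_videoclips([intro, body])     # Video Concat

narration = AudioFileClip("narration.mp3")
final = merged.set_audio(narration)                # Add Audio To Video

final.write_videofile("final.mp4")
print(f"duration: {final.duration:.1f}s")          # Media Duration
```

Note that moviepy 2.x renames some of these calls (e.g. `subclip` becomes `subclipped`), so the sketch assumes the 1.x API, which matches the FFmpeg/ImageMagick dependencies added for these blocks.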
@@ -571,6 +576,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim
 | [Linear Create Comment](block-integrations/linear/comment.md#linear-create-comment) | Creates a new comment on a Linear issue |
 | [Linear Create Issue](block-integrations/linear/issues.md#linear-create-issue) | Creates a new issue on Linear |
 | [Linear Get Project Issues](block-integrations/linear/issues.md#linear-get-project-issues) | Gets issues from a Linear project filtered by status and assignee |
+| [Linear Search Issues](block-integrations/linear/issues.md#linear-search-issues) | Searches for issues on Linear |
 | [Linear Search Projects](block-integrations/linear/projects.md#linear-search-projects) | Searches for projects on Linear |
 
 ## Hardware
@@ -85,7 +85,6 @@
 * [LLM](block-integrations/llm.md)
 * [Logic](block-integrations/logic.md)
 * [Misc](block-integrations/misc.md)
-* [Multimedia](block-integrations/multimedia.md)
 * [Notion Create Page](block-integrations/notion/create_page.md)
 * [Notion Read Database](block-integrations/notion/read_database.md)
 * [Notion Read Page](block-integrations/notion/read_page.md)
@@ -129,5 +128,13 @@
 * [Twitter Timeline](block-integrations/twitter/timeline.md)
 * [Twitter Tweet Lookup](block-integrations/twitter/tweet_lookup.md)
 * [Twitter User Lookup](block-integrations/twitter/user_lookup.md)
+* [Video Add Audio](block-integrations/video/add_audio.md)
+* [Video Clip](block-integrations/video/clip.md)
+* [Video Concat](block-integrations/video/concat.md)
+* [Video Download](block-integrations/video/download.md)
+* [Video Duration](block-integrations/video/duration.md)
+* [Video Loop](block-integrations/video/loop.md)
+* [Video Narration](block-integrations/video/narration.md)
+* [Video Text Overlay](block-integrations/video/text_overlay.md)
 * [Wolfram LLM API](block-integrations/wolfram/llm_api.md)
 * [Zerobounce Validate Emails](block-integrations/zerobounce/validate_emails.md)
@@ -90,9 +90,9 @@ Searches for issues on Linear
 
 ### How it works
 <!-- MANUAL: how_it_works -->
-This block searches for issues in Linear using a text query. It searches across issue titles, descriptions, and other fields to find matching issues.
+This block searches for issues in Linear using a text query. It searches across issue titles, descriptions, and other fields to find matching issues. You can limit the number of results returned using the `max_results` parameter (default: 10, max: 100) to control token consumption and response size.
 
-Returns a list of issues matching the search term.
+Optionally filter results by team name to narrow searches to specific workspaces. If a team name is provided, the block resolves it to a team ID before searching. Returns matching issues with their state, creation date, project, and assignee information. If the search or team resolution fails, an error message is returned.
 <!-- END MANUAL -->
 
 ### Inputs
@@ -100,12 +100,14 @@ Returns a list of issues matching the search term.
 | Input | Description | Type | Required |
 |-------|-------------|------|----------|
 | term | Term to search for issues | str | Yes |
+| max_results | Maximum number of results to return | int | No |
+| team_name | Optional team name to filter results (e.g., 'Internal', 'Open Source') | str | No |
 
 ### Outputs
 
 | Output | Description | Type |
 |--------|-------------|------|
-| error | Error message if the operation failed | str |
+| error | Error message if the search failed | str |
 | issues | List of issues | List[Issue] |
 
 ### Possible use case
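For context on what the updated doc describes (a free-text term, a max_results cap, optional team filtering), here is a hedged sketch of an equivalent query against Linear's public GraphQL API. This is not the block's implementation; the filter shape and auth header follow Linear's API docs, but the selected fields and helper name are illustrative:

```python
# Hedged sketch of a Linear issue search like the one documented above.
# Not the block's actual code; fields selected here are illustrative.
import requests

LINEAR_API = "https://api.linear.app/graphql"

QUERY = """
query SearchIssues($term: String!, $first: Int!) {
  issues(filter: { title: { containsIgnoreCase: $term } }, first: $first) {
    nodes {
      identifier
      title
      state { name }
      createdAt
      assignee { name }
    }
  }
}
"""

def search_issues(api_key: str, term: str, max_results: int = 10) -> list[dict]:
    """Return up to max_results issues whose titles contain the term."""
    resp = requests.post(
        LINEAR_API,
        json={"query": QUERY, "variables": {"term": term, "first": max_results}},
        # Personal API keys go in Authorization directly (no "Bearer" prefix).
        headers={"Authorization": api_key},
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()["data"]["issues"]["nodes"]
```

Team filtering as described would add one extra step: resolve the team name to an ID first (for example via the `teams` query), then include a team-id comparator alongside the title filter.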
Some files were not shown because too many files have changed in this diff.