Compare commits

...

8 Commits

Author SHA1 Message Date
Zamil Majdy
ec7c7ebea2 refactor(backend): extract _extract_agent_json helper, fail fast on unknown poll status 2026-02-25 16:57:49 +07:00
Zamil Majdy
8ef8bec14f fix(backend): validate completed job result type in _submit_and_poll 2026-02-25 16:22:19 +07:00
Zamil Majdy
9b3e25d98e fix(backend): retry transient HTTP errors during polling, validate agent_json responses 2026-02-25 15:44:11 +07:00
Zamil Majdy
0bc098acb1 fix(backend): address PR review - wire timeout setting, use monotonic clock, cap poll errors 2026-02-25 14:53:10 +07:00
Zamil Majdy
d78e0ee122 feat(backend/copilot): use async polling for agent-generator + frontend SSE reconnect
Platform service now submits jobs to agent-generator and polls for results
(10s interval) instead of blocking on a single HTTP call for up to 30 min.
asyncio.sleep in the poll loop yields to the event loop, keeping SSE
heartbeats alive through GCP's L7 load balancer.

Frontend auto-reconnects up to 3 times when SSE drops mid-stream.
2026-02-24 21:01:51 +07:00
Zamil Majdy
0e72e1f5e7 fix(platform/copilot): fix stuck sessions, stop button, and StreamFinish reliability (#12191)
## Summary

- **Fix stuck sessions**: Root cause was `_stream_listener` infinite
xread loop when Redis session metadata TTL expired — `hget` returned
`None` which bypassed the `status != "running"` break condition. Fixed
by treating `None` status as non-running.
- **Fix stop button reliability**: Cancel endpoint now force-completes
via `mark_session_completed` when executor doesn't respond within 5s.
Returns `cancelled=True` for already-expired sessions.
- **Single-owner StreamFinish**: All `yield StreamFinish()` removed from
service layers (sdk/service.py, service.py, dummy.py).
`mark_session_completed` is now the single atomic source of truth for
publishing StreamFinish via Lua CAS script.
- **Rename task → session/turn**: Consistent terminology across
stream_registry and processor.
- **Frontend session refetch**: Added `refetchOnMount: true` so page
refresh re-fetches session state.
- **Test fixes**: Updated e2e, service, and run_agent tests for
StreamFinish removal; fixed async fixture decorators.

## Test plan
- [x] E2E dummy streaming tests pass (13 passed, 1 xfailed)
- [x] run_agent_test.py passes (async fixture decorator fix)
- [x] service_test.py passes (StreamFinish assertions removed)
- [ ] Manual: verify stuck sessions recover on page refresh
- [ ] Manual: verify stop button works for active and expired sessions
- [ ] Manual: verify no duplicate StreamFinish events in SSE stream
2026-02-24 10:49:22 +00:00
Swifty
163b0b3c9d feat(backend): pre-populate CoPilotUnderstanding from Tally form on signup (#12119)
When new users sign up, check if they previously filled out the Tally
beta application form and, if so, pre-populate their
CoPilotUnderstanding with business data extracted from that form. This
gives the CoPilot (Otto) immediate context about the user on their very
first chat interaction.

### Changes 🏗️

- **`backend/util/settings.py`**: Added `tally_api_key` to `Secrets`
class
- **`backend/.env.default`**: Added `TALLY_API_KEY=` env var entry
- **`backend/data/tally.py`** (new): Core Tally integration module
- Redis-cached email index of form submissions (1h TTL) with incremental
refresh via `startDate`
  - Paginated Tally API fetching with Bearer token auth
  - Email matching (case-insensitive) against submission data
- LLM extraction (gpt-4o-mini via OpenRouter) of
`BusinessUnderstandingInput` fields
  - Fire-and-forget orchestrator that is idempotent and never raises
- **`backend/api/features/v1.py`**: Added background task in
`get_or_create_user_route` to trigger Tally lookup on login (skips if
understanding already exists)
- **`backend/data/tally_test.py`** (new): 15 unit tests covering index
building, email case-insensitivity, cache hit/miss, format helpers,
idempotency, graceful degradation, and error resilience

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] All 15 unit tests pass (`poetry run pytest
backend/data/tally_test.py --noconftest -xvs`)
  - [x] Lint clean (`poetry run ruff check` on changed files)
  - [x] Type check clean (`poetry run pyright` on new files)
- [ ] Manual: Set `TALLY_API_KEY` in `.env`, create a new user, verify
CoPilotUnderstanding is populated
- [ ] Manual: Verify user creation succeeds when Tally API key is
missing or API is down

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)
- Added `TALLY_API_KEY=` to `.env.default` (optional, empty by default —
feature is a no-op without it)

<!-- greptile_comment -->

<details><summary><h3>Greptile Summary</h3></summary>

This PR adds a Tally form integration that pre-populates
`CoPilotUnderstanding` for new users by matching their signup email
against cached Tally beta application form submissions, then using an
LLM (gpt-4o-mini via OpenRouter) to extract structured business data.

- **New module `tally.py`** implements Redis-cached email indexing of
Tally form submissions with incremental refresh, email matching, LLM
extraction, and an idempotent fire-and-forget orchestrator
- **`v1.py`** adds a background task on the `get_or_create_user_route`
to trigger Tally lookup on every login (idempotency check is inside the
called function)
- **`settings.py` / `.env.default`** adds `tally_api_key` as an optional
secret — feature is a no-op without it
- **`tally_test.py`** adds 15 unit tests with thorough mocking coverage
- **Bug: TTL mismatch** — `_LAST_FETCH_TTL` (2h) > `_INDEX_TTL` (1h)
creates a window where incremental refresh loses all previously indexed
emails because the base index has expired but `last_fetch` persists.
This will cause silent data loss for users whose form submissions were
indexed before the cache expiry
- **Bug: `str.format()` on LLM prompt** — form data containing `{` or
`}` will crash the prompt formatting, silently preventing understanding
population for those users
</details>


<details><summary><h3>Confidence Score: 2/5</h3></summary>

- This PR has two logic bugs that will cause silent data loss in
production — recommend fixing before merge.
- The TTL mismatch between `_LAST_FETCH_TTL` and `_INDEX_TTL` will
intermittently cause incomplete caches, silently dropping users from the
email index. The `str.format()` issue will cause failures for any form
submission containing curly braces. Both bugs are caught by the
top-level exception handler, so they won't crash the service, but they
will silently prevent the feature from working correctly for affected
users. The overall architecture is sound and well-tested for normal
paths.
- `autogpt_platform/backend/backend/data/tally.py` — contains both the
TTL mismatch bug in `_refresh_cache` and the `str.format()` issue in
`extract_business_understanding`
</details>


<details><summary><h3>Sequence Diagram</h3></summary>

```mermaid
sequenceDiagram
    participant User
    participant API as v1.py (get_or_create_user_route)
    participant Tally as tally.py (populate_understanding_from_tally)
    participant DB as Database (understanding)
    participant Redis
    participant TallyAPI as Tally API
    participant LLM as OpenRouter (gpt-4o-mini)

    User->>API: POST /auth/user (JWT)
    API->>API: get_or_create_user(user_data)
    API-->>User: Return user (immediate)
    API->>Tally: asyncio.create_task(populate_understanding_from_tally)

    Tally->>DB: get_business_understanding(user_id)
    alt Understanding exists
        DB-->>Tally: existing understanding
        Note over Tally: Skip (idempotent)
    else No understanding
        DB-->>Tally: None
        Tally->>Tally: Check tally_api_key configured
        Tally->>Redis: Check cached email index
        alt Cache hit
            Redis-->>Tally: email_index + questions
        else Cache miss
            Redis-->>Tally: None
            Tally->>TallyAPI: GET /forms/{id}/submissions (paginated)
            TallyAPI-->>Tally: submissions + questions
            Tally->>Tally: Build email index
            Tally->>Redis: Cache index (1h TTL)
        end
        Tally->>Tally: Lookup email in index
        alt Email found
            Tally->>Tally: format_submission_for_llm()
            Tally->>LLM: Extract BusinessUnderstandingInput
            LLM-->>Tally: JSON structured data
            Tally->>DB: upsert_business_understanding(user_id, input)
        end
    end
```
</details>


<sub>Last reviewed commit: 92d2da4</sub>

<!-- greptile_other_comments_section -->

<!-- /greptile_comment -->

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Otto (AGPT) <otto@agpt.co>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-02-24 11:31:29 +01:00
Bently
ef42b17e3b docs: add Podman compatibility warning (#12120)
## Summary
Adds a warning to the Getting Started docs clarifying that **Podman and
podman-compose are not supported**.

## Problem
Users on Windows using `podman-compose` instead of Docker get errors
like:
```
Error: the specified Containerfile or Dockerfile does not exist, ..\..\autogpt_platform\backend\Dockerfile
```

This is because Podman handles relative paths differently than Docker,
causing incorrect path resolution on Windows.

## Solution
- Added a clear warning section after the Windows WSL 2 notes
- Explains the error users might see
- Directs them to install Docker Desktop instead

Closes #11358

<!-- greptile_comment -->

<details><summary><h3>Greptile Summary</h3></summary>

Adds a "Podman Not Supported" warning section to the Getting Started
documentation, placed after the Windows/WSL 2 installation notes. The
section clarifies that Docker is required, shows the typical error
message users encounter when using Podman, and directs them to install
Docker Desktop instead. This addresses issue #11358 where Windows users
using `podman-compose` hit path resolution errors.

- Adds `### ⚠️ Podman Not Supported` section under Manual Setup, after
Windows Installation Note
- Includes the specific error message users see with Podman for easy
identification
- Links to Docker Desktop installation docs as the recommended solution
- Formatting is consistent with existing sections in the document (emoji
headings, code blocks for errors)
</details>


<details><summary><h3>Confidence Score: 5/5</h3></summary>

- This PR is safe to merge — it only adds a documentation warning
section with no code changes.
- The change is a small, well-written documentation addition that adds a
Podman compatibility warning. It touches only one markdown file,
introduces no code changes, and is consistent with the existing document
structure and style. No issues were found.
- No files require special attention.
</details>


<details><summary><h3>Flowchart</h3></summary>

```mermaid
flowchart TD
    A[User wants to run AutoGPT] --> B{Which container runtime?}
    B -->|Docker / Docker Desktop| C[docker compose up -d --build]
    C --> D[AutoGPT starts successfully]
    B -->|Podman / podman-compose| E[podman-compose up -d --build]
    E --> F[Error: Containerfile or Dockerfile does not exist]
    F --> G[New warning section directs user to install Docker Desktop]
    G --> C
```
</details>


<sub>Last reviewed commit: 23ea6bd</sub>

<!-- greptile_other_comments_section -->

<!-- /greptile_comment -->
2026-02-23 15:19:24 +00:00
62 changed files with 3395 additions and 4629 deletions

View File

@@ -190,5 +190,8 @@ ZEROBOUNCE_API_KEY=
POSTHOG_API_KEY=
POSTHOG_HOST=https://eu.i.posthog.com
# Tally Form Integration (pre-populate business understanding on signup)
TALLY_API_KEY=
# Other Services
AUTOMOD_API_KEY=

View File

@@ -2,23 +2,19 @@
import asyncio
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
from typing import Annotated
from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.completion_handler import (
process_operation_failure,
process_operation_success,
)
from backend.copilot.config import ChatConfig
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_task
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.model import (
ChatMessage,
ChatSession,
@@ -46,9 +42,6 @@ from backend.copilot.tools.models import (
InputValidationErrorResponse,
NeedLoginResponse,
NoResultsResponse,
OperationInProgressResponse,
OperationPendingResponse,
OperationStartedResponse,
SetupRequirementsResponse,
SuggestedGoalResponse,
UnderstandingUpdatedResponse,
@@ -99,10 +92,8 @@ class CreateSessionResponse(BaseModel):
class ActiveStreamInfo(BaseModel):
"""Information about an active stream for reconnection."""
task_id: str
turn_id: str
last_message_id: str # Redis Stream message ID for resumption
operation_id: str # Operation ID for completion tracking
tool_name: str # Name of the tool being executed
class SessionDetailResponse(BaseModel):
@@ -132,22 +123,13 @@ class ListSessionsResponse(BaseModel):
total: int
class CancelTaskResponse(BaseModel):
"""Response model for the cancel task endpoint."""
class CancelSessionResponse(BaseModel):
"""Response model for the cancel session endpoint."""
cancelled: bool
task_id: str | None = None
reason: str | None = None
class OperationCompleteRequest(BaseModel):
"""Request model for external completion webhook."""
success: bool
result: dict | str | None = None
error: str | None = None
# ========== Routes ==========
@@ -270,7 +252,7 @@ async def get_session(
Retrieve the details of a specific chat session.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns the task_id for reconnection.
If there's an active stream for this session, returns active_stream info for reconnection.
Args:
session_id: The unique identifier for the desired chat session.
@@ -288,28 +270,21 @@ async def get_session(
# Check if there's an active stream for this session
active_stream_info = None
active_task, last_message_id = await stream_registry.get_active_task_for_session(
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
# stream replay instead, preventing duplicate content.
if messages and messages[-1].get("role") == "assistant":
messages = messages[:-1]
# Use "0-0" as last_message_id to replay the stream from the beginning.
# Since we filtered out the cached assistant message, the client needs
# the full stream to reconstruct the response.
if active_session:
# Keep the assistant message (including tool_calls) so the frontend can
# render the correct tool UI (e.g. CreateAgent with mini game).
# convertChatSessionToUiMessages handles isComplete=false by setting
# tool parts without output to state "input-available".
active_stream_info = ActiveStreamInfo(
task_id=active_task.task_id,
last_message_id="0-0",
operation_id=active_task.operation_id,
tool_name=active_task.tool_name,
turn_id=active_session.turn_id,
last_message_id=last_message_id,
)
return SessionDetailResponse(
@@ -329,7 +304,7 @@ async def get_session(
async def cancel_session_task(
session_id: str,
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CancelTaskResponse:
) -> CancelSessionResponse:
"""Cancel the active streaming task for a session.
Publishes a cancel event to the executor via RabbitMQ FANOUT, then
@@ -338,39 +313,33 @@ async def cancel_session_task(
"""
await _validate_and_get_session(session_id, user_id)
active_task, _ = await stream_registry.get_active_task_for_session(
session_id, user_id
)
if not active_task:
return CancelTaskResponse(cancelled=False, reason="no_active_task")
active_session, _ = await stream_registry.get_active_session(session_id, user_id)
if not active_session:
return CancelSessionResponse(cancelled=True, reason="no_active_session")
task_id = active_task.task_id
await enqueue_cancel_task(task_id)
logger.info(
f"[CANCEL] Published cancel for task ...{task_id[-8:]} "
f"session ...{session_id[-8:]}"
)
await enqueue_cancel_task(session_id)
logger.info(f"[CANCEL] Published cancel for session ...{session_id[-8:]}")
# Poll until the executor confirms the task is no longer running.
# Keep max_wait below typical reverse-proxy read timeouts.
poll_interval = 0.5
max_wait = 5.0
waited = 0.0
while waited < max_wait:
await asyncio.sleep(poll_interval)
waited += poll_interval
task = await stream_registry.get_task(task_id)
if task is None or task.status != "running":
session_state = await stream_registry.get_session(session_id)
if session_state is None or session_state.status != "running":
logger.info(
f"[CANCEL] Task ...{task_id[-8:]} confirmed stopped "
f"(status={task.status if task else 'gone'}) after {waited:.1f}s"
f"[CANCEL] Session ...{session_id[-8:]} confirmed stopped "
f"(status={session_state.status if session_state else 'gone'}) after {waited:.1f}s"
)
return CancelTaskResponse(cancelled=True, task_id=task_id)
return CancelSessionResponse(cancelled=True)
logger.warning(f"[CANCEL] Task ...{task_id[-8:]} not confirmed after {max_wait}s")
return CancelTaskResponse(
cancelled=True, task_id=task_id, reason="cancel_published_not_confirmed"
logger.warning(
f"[CANCEL] Session ...{session_id[-8:]} not confirmed after {max_wait}s, force-completing"
)
await stream_registry.mark_session_completed(session_id, error_message="Cancelled")
return CancelSessionResponse(cancelled=True)
@router.post(
@@ -390,16 +359,15 @@ async def stream_chat_post(
- Tool execution results
The AI generation runs in a background task that continues even if the client disconnects.
All chunks are written to Redis for reconnection support. If the client disconnects,
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
All chunks are written to a per-turn Redis stream for reconnection support. If the client
disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
containing the task_id for reconnection.
StreamingResponse: SSE-formatted response chunks.
"""
import asyncio
@@ -446,35 +414,35 @@ async def stream_chat_post(
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
log_meta["task_id"] = task_id
turn_id = str(uuid4())
log_meta["turn_id"] = turn_id
task_create_start = time.perf_counter()
await stream_registry.create_task(
task_id=task_id,
session_create_start = time.perf_counter()
await stream_registry.create_session(
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream", # Not a tool call, but needed for the model
tool_call_id="chat_stream",
tool_name="chat",
operation_id=operation_id,
turn_id=turn_id,
)
logger.info(
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
"duration_ms": (time.perf_counter() - session_create_start) * 1000,
}
},
)
await enqueue_copilot_task(
task_id=task_id,
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
subscribe_from_id = "0-0"
await enqueue_copilot_turn(
session_id=session_id,
user_id=user_id,
operation_id=operation_id,
message=request.message,
turn_id=turn_id,
is_user_message=request.is_user_message,
context=request.context,
)
@@ -491,7 +459,7 @@ async def stream_chat_post(
event_gen_start = time_module.perf_counter()
logger.info(
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
f"[TIMING] event_generator STARTED, turn={turn_id}, session={session_id}, "
f"user={user_id}",
extra={"json_fields": log_meta},
)
@@ -499,11 +467,12 @@ async def stream_chat_post(
first_chunk_yielded = False
chunks_yielded = 0
try:
# Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
# Subscribe from the position we captured before enqueuing
# This avoids replaying old messages while catching all new ones
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
user_id=user_id,
last_message_id="0-0", # Get all messages from the beginning
last_message_id=subscribe_from_id,
)
if subscriber_queue is None:
@@ -586,19 +555,19 @@ async def stream_chat_post(
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_task(
task_id, subscriber_queue
await stream_registry.unsubscribe_from_session(
session_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
f"Error unsubscribing from session {session_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
f"turn={turn_id}, session={session_id}, n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
@@ -645,17 +614,22 @@ async def resume_session_stream(
"""
import asyncio
active_task, _last_id = await stream_registry.get_active_task_for_session(
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
if not active_task:
if not active_session:
return Response(status_code=204)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=active_task.task_id,
# Subscribe from the beginning ("0-0") to replay all chunks for this turn.
# This is necessary because hydrated messages filter out incomplete tool calls
# to avoid "No tool invocation found" errors. The resume stream delivers
# those tool calls fresh with proper SDK state.
# The AI SDK's deduplication will handle any duplicate chunks.
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
user_id=user_id,
last_message_id="0-0", # Full replay so useChat rebuilds the message
last_message_id="0-0",
)
if subscriber_queue is None:
@@ -691,12 +665,12 @@ async def resume_session_stream(
logger.error(f"Error in resume stream for session {session_id}: {e}")
finally:
try:
await stream_registry.unsubscribe_from_task(
active_task.task_id, subscriber_queue
await stream_registry.unsubscribe_from_session(
session_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
f"Error unsubscribing from session {active_session.session_id}: {unsub_err}",
exc_info=True,
)
logger.info(
@@ -747,229 +721,6 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Task Streaming (SSE Reconnection) ==========
@router.get(
"/tasks/{task_id}/stream",
)
async def stream_task(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
last_message_id: str = Query(
default="0-0",
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
),
):
"""
Reconnect to a long-running task's SSE stream.
When a long-running operation (like agent generation) starts, the client
receives a task_id. If the connection drops, the client can reconnect
using this endpoint to resume receiving updates.
Args:
task_id: The task ID from the operation_started response.
user_id: Authenticated user ID for ownership validation.
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
Returns:
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
Raises:
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
"""
# Check task existence and expiry before subscribing
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
if error_code == "TASK_EXPIRED":
raise HTTPException(
status_code=410,
detail={
"code": "TASK_EXPIRED",
"message": "This operation has expired. Please try again.",
},
)
if error_code == "TASK_NOT_FOUND":
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found.",
},
)
# Validate ownership if task has an owner
if task and task.user_id and user_id != task.user_id:
raise HTTPException(
status_code=403,
detail={
"code": "ACCESS_DENIED",
"message": "You do not have access to this task.",
},
)
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=last_message_id,
)
if subscriber_queue is None:
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found or access denied.",
},
)
async def event_generator() -> AsyncGenerator[str, None]:
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:
try:
# Wait for next chunk with timeout for heartbeats
chunk = await asyncio.wait_for(
subscriber_queue.get(), timeout=heartbeat_interval
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except Exception as e:
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
finally:
# Unsubscribe when client disconnects or stream ends
try:
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
)
@router.get(
"/tasks/{task_id}",
)
async def get_task_status(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
) -> dict:
"""
Get the status of a long-running task.
Args:
task_id: The task ID to check.
user_id: Authenticated user ID for ownership validation.
Returns:
dict: Task status including task_id, status, tool_name, and operation_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
task = await stream_registry.get_task(task_id)
if task is None:
raise NotFoundError(f"Task {task_id} not found.")
# Validate ownership - if task has an owner, requester must match
if task.user_id and user_id != task.user_id:
raise NotFoundError(f"Task {task_id} not found.")
return {
"task_id": task.task_id,
"session_id": task.session_id,
"status": task.status,
"tool_name": task.tool_name,
"operation_id": task.operation_id,
"created_at": task.created_at.isoformat(),
}
# ========== External Completion Webhook ==========
@router.post(
"/operations/{operation_id}/complete",
status_code=200,
)
async def complete_operation(
operation_id: str,
request: OperationCompleteRequest,
x_api_key: str | None = Header(default=None),
) -> dict:
"""
External completion webhook for long-running operations.
Called by Agent Generator (or other services) when an operation completes.
This triggers the stream registry to publish completion and continue LLM generation.
Args:
operation_id: The operation ID to complete.
request: Completion payload with success status and result/error.
x_api_key: Internal API key for authentication.
Returns:
dict: Status of the completion.
Raises:
HTTPException: If API key is invalid or operation not found.
"""
# Validate internal API key - reject if not configured or invalid
if not config.internal_api_key:
logger.error(
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
)
raise HTTPException(
status_code=503,
detail="Webhook not available: internal API key not configured",
)
if x_api_key != config.internal_api_key:
raise HTTPException(status_code=401, detail="Invalid API key")
# Find task by operation_id
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Operation {operation_id} not found",
)
logger.info(
f"Received completion webhook for operation {operation_id} "
f"(task_id={task.task_id}, success={request.success})"
)
if request.success:
await process_operation_success(task, request.result)
else:
await process_operation_failure(task, request.error)
return {"status": "ok", "task_id": task.task_id}
# ========== Configuration ==========
@@ -1050,9 +801,6 @@ ToolResponseUnion = (
| BlockOutputResponse
| DocSearchResultsResponse
| DocPageResponse
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
)

View File

@@ -126,6 +126,9 @@ v1_router = APIRouter()
########################################################
_tally_background_tasks: set[asyncio.Task] = set()
@v1_router.post(
"/auth/user",
summary="Get or create user",
@@ -134,6 +137,24 @@ v1_router = APIRouter()
)
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
user = await get_or_create_user(user_data)
# Fire-and-forget: populate business understanding from Tally form.
# We use created_at proximity instead of an is_new flag because
# get_or_create_user is cached — a separate is_new return value would be
# unreliable on repeated calls within the cache TTL.
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
if age_seconds < 30:
try:
from backend.data.tally import populate_understanding_from_tally
task = asyncio.create_task(
populate_understanding_from_tally(user.id, user.email)
)
_tally_background_tasks.add(task)
task.add_done_callback(_tally_background_tasks.discard)
except Exception:
logger.debug("Failed to start Tally population task", exc_info=True)
return user.model_dump()

View File

@@ -1,5 +1,5 @@
import json
from datetime import datetime
from datetime import datetime, timezone
from io import BytesIO
from unittest.mock import AsyncMock, Mock, patch
@@ -43,6 +43,7 @@ def test_get_or_create_user_route(
) -> None:
"""Test get or create user endpoint"""
mock_user = Mock()
mock_user.created_at = datetime.now(timezone.utc)
mock_user.model_dump.return_value = {
"id": test_user_id,
"email": "test@example.com",

View File

@@ -42,10 +42,6 @@ import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.copilot.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
)
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
@@ -123,21 +119,9 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
# Start chat completion consumer for Redis Streams notifications
try:
await start_completion_consumer()
except Exception as e:
logger.warning(f"Could not start chat completion consumer: {e}")
with launch_darkly_context():
yield
# Stop chat completion consumer
try:
await stop_completion_consumer()
except Exception as e:
logger.warning(f"Error stopping chat completion consumer: {e}")
try:
await shutdown_cloud_storage_handler()
except Exception as e:

View File

@@ -24,7 +24,7 @@ def run_processes(*processes: "AppProcess", **kwargs):
# Run the last process in the foreground.
processes[-1].start(background=False, **kwargs)
finally:
for process in processes:
for process in reversed(processes):
try:
process.stop()
except Exception as e:

View File

@@ -1,349 +0,0 @@
"""Redis Streams consumer for operation completion messages.
This module provides a consumer (ChatCompletionConsumer) that listens for
completion notifications (OperationCompleteMessage) from external services
(like Agent Generator) and triggers the appropriate stream registry and
chat service updates via process_operation_success/process_operation_failure.
Why Redis Streams instead of RabbitMQ?
--------------------------------------
While the project typically uses RabbitMQ for async task queues (e.g., execution
queue), Redis Streams was chosen for chat completion notifications because:
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
Streams (via stream_registry) for message persistence and replay. Using Redis
Streams for completion notifications keeps all chat streaming infrastructure
in one system, simplifying operations and reducing cross-system coordination.
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
allowing consumers to replay missed messages after reconnection. This aligns
with the SSE reconnection pattern where clients can resume from last_message_id.
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
recovering from dead consumers - ideal for the completion callback pattern.
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
stream_registry) provides lower latency than an additional RabbitMQ hop.
5. **Atomicity with Task State**: Completion processing often needs to update
task metadata stored in Redis. Keeping both in Redis enables simpler
transactional semantics without distributed coordination.
The consumer uses Redis Streams with consumer groups for reliable message
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
stale pending messages from dead consumers.
"""
import asyncio
import logging
import uuid
from typing import Any
import orjson
from pydantic import BaseModel
from redis.exceptions import ResponseError
from backend.data.redis_client import get_redis_async
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
class OperationCompleteMessage(BaseModel):
"""Message format for operation completion notifications."""
operation_id: str
task_id: str
success: bool
result: dict | str | None = None
error: str | None = None
class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from Redis Streams.
Database operations are handled through the chat_db() accessor, which
routes through DatabaseManager RPC when Prisma is not directly connected.
Uses Redis consumer groups to allow multiple platform pods to consume
messages reliably with automatic redelivery on failure.
"""
def __init__(self):
self._consumer_task: asyncio.Task | None = None
self._running = False
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
async def start(self) -> None:
"""Start the completion consumer."""
if self._running:
logger.warning("Completion consumer already running")
return
# Create consumer group if it doesn't exist
try:
redis = await get_redis_async()
await redis.xgroup_create(
config.stream_completion_name,
config.stream_consumer_group,
id="0",
mkstream=True,
)
logger.info(
f"Created consumer group '{config.stream_consumer_group}' "
f"on stream '{config.stream_completion_name}'"
)
except ResponseError as e:
if "BUSYGROUP" in str(e):
logger.debug(
f"Consumer group '{config.stream_consumer_group}' already exists"
)
else:
raise
self._running = True
self._consumer_task = asyncio.create_task(self._consume_messages())
logger.info(
f"Chat completion consumer started (consumer: {self._consumer_name})"
)
async def stop(self) -> None:
"""Stop the completion consumer."""
self._running = False
if self._consumer_task:
self._consumer_task.cancel()
try:
await self._consumer_task
except asyncio.CancelledError:
pass
self._consumer_task = None
logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None:
"""Main message consumption loop with retry logic."""
max_retries = 10
retry_delay = 5 # seconds
retry_count = 0
block_timeout = 5000 # milliseconds
while self._running and retry_count < max_retries:
try:
redis = await get_redis_async()
# Reset retry count on successful connection
retry_count = 0
while self._running:
# First, claim any stale pending messages from dead consumers
# Redis does NOT auto-redeliver pending messages; we must explicitly
# claim them using XAUTOCLAIM
try:
claimed_result = await redis.xautoclaim(
name=config.stream_completion_name,
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
min_idle_time=config.stream_claim_min_idle_ms,
start_id="0-0",
count=10,
)
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
if claimed_result and len(claimed_result) >= 2:
claimed_entries = claimed_result[1]
if claimed_entries:
logger.info(
f"Claimed {len(claimed_entries)} stale pending messages"
)
for entry_id, data in claimed_entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except Exception as e:
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
# Read new messages from the stream
messages = await redis.xreadgroup(
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
streams={config.stream_completion_name: ">"},
block=block_timeout,
count=10,
)
if not messages:
continue
for stream_name, entries in messages:
for entry_id, data in entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except asyncio.CancelledError:
logger.info("Consumer cancelled")
return
except Exception as e:
retry_count += 1
logger.error(
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
exc_info=True,
)
if self._running and retry_count < max_retries:
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached, stopping consumer")
return
async def _process_entry(
self, redis: Any, entry_id: str, data: dict[str, Any]
) -> None:
"""Process a single stream entry and acknowledge it on success.
Args:
redis: Redis client connection
entry_id: The stream entry ID
data: The entry data dict
"""
try:
# Handle the message
message_data = data.get("data")
if message_data:
await self._handle_message(
message_data.encode()
if isinstance(message_data, str)
else message_data
)
# Acknowledge the message after successful processing
await redis.xack(
config.stream_completion_name,
config.stream_consumer_group,
entry_id,
)
except Exception as e:
logger.error(
f"Error processing completion message {entry_id}: {e}",
exc_info=True,
)
# Message remains in pending state and will be claimed by
# XAUTOCLAIM after min_idle_time expires
async def _handle_message(self, body: bytes) -> None:
"""Handle a completion message."""
try:
data = orjson.loads(body)
message = OperationCompleteMessage(**data)
except Exception as e:
logger.error(f"Failed to parse completion message: {e}")
return
logger.info(
f"[COMPLETION] Received completion for operation {message.operation_id} "
f"(task_id={message.task_id}, success={message.success})"
)
# Find task in registry
task = await stream_registry.find_task_by_operation_id(message.operation_id)
if task is None:
task = await stream_registry.get_task(message.task_id)
if task is None:
logger.warning(
f"[COMPLETION] Task not found for operation {message.operation_id} "
f"(task_id={message.task_id})"
)
return
logger.info(
f"[COMPLETION] Found task: task_id={task.task_id}, "
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
)
# Guard against empty task fields
if not task.task_id or not task.session_id or not task.tool_call_id:
logger.error(
f"[COMPLETION] Task has empty critical fields! "
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
f"tool_call_id={task.tool_call_id!r}"
)
return
if message.success:
await self._handle_success(task, message)
else:
await self._handle_failure(task, message)
async def _handle_success(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle successful operation completion."""
await process_operation_success(task, message.result)
async def _handle_failure(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle failed operation completion."""
await process_operation_failure(task, message.error)
# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None
async def start_completion_consumer() -> None:
"""Start the global completion consumer."""
global _consumer
if _consumer is None:
_consumer = ChatCompletionConsumer()
await _consumer.start()
async def stop_completion_consumer() -> None:
"""Stop the global completion consumer."""
global _consumer
if _consumer:
await _consumer.stop()
_consumer = None
async def publish_operation_complete(
operation_id: str,
task_id: str,
success: bool,
result: dict | str | None = None,
error: str | None = None,
) -> None:
"""Publish an operation completion message to Redis Streams.
Args:
operation_id: The operation ID that completed.
task_id: The task ID associated with the operation.
success: Whether the operation succeeded.
result: The result data (for success).
error: The error message (for failure).
"""
message = OperationCompleteMessage(
operation_id=operation_id,
task_id=task_id,
success=success,
result=result,
error=error,
)
redis = await get_redis_async()
await redis.xadd(
config.stream_completion_name,
{"data": message.model_dump_json()},
maxlen=config.stream_max_length,
)
logger.info(f"Published completion for operation {operation_id}")

View File

@@ -1,329 +0,0 @@
"""Shared completion handling for operation success and failure.
This module provides common logic for handling operation completion from both:
- The Redis Streams consumer (completion_consumer.py)
- The HTTP webhook endpoint (routes.py)
"""
import logging
from typing import Any
import orjson
from backend.data.db_accessors import chat_db
from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamToolOutputAvailable
from .tools.models import ErrorResponse
logger = logging.getLogger(__name__)
# Tools that produce agent_json that needs to be saved to library
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
# Keys that should be stripped from agent_json when returning in error responses
SENSITIVE_KEYS = frozenset(
{
"api_key",
"apikey",
"api_secret",
"password",
"secret",
"credentials",
"credential",
"token",
"access_token",
"refresh_token",
"private_key",
"privatekey",
"auth",
"authorization",
}
)
def _sanitize_agent_json(obj: Any) -> Any:
"""Recursively sanitize agent_json by removing sensitive keys.
Args:
obj: The object to sanitize (dict, list, or primitive)
Returns:
Sanitized copy with sensitive keys removed/redacted
"""
if isinstance(obj, dict):
return {
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
for k, v in obj.items()
}
elif isinstance(obj, list):
return [_sanitize_agent_json(item) for item in obj]
else:
return obj
class ToolMessageUpdateError(Exception):
"""Raised when updating a tool message in the database fails."""
pass
async def _update_tool_message(
session_id: str,
tool_call_id: str,
content: str,
) -> None:
"""Update tool message in database using the chat_db accessor.
Routes through DatabaseManager RPC when Prisma is not directly
connected (e.g. in the CoPilot Executor microservice).
Args:
session_id: The session ID
tool_call_id: The tool call ID to update
content: The new content for the message
Raises:
ToolMessageUpdateError: If the database update fails.
"""
try:
updated = await chat_db().update_tool_message_content(
session_id=session_id,
tool_call_id=tool_call_id,
new_content=content,
)
if not updated:
raise ToolMessageUpdateError(
f"No message found with tool_call_id="
f"{tool_call_id} in session {session_id}"
)
except ToolMessageUpdateError:
raise
except Exception as e:
logger.error(
f"[COMPLETION] Failed to update tool message: {e}",
exc_info=True,
)
raise ToolMessageUpdateError(
f"Failed to update tool message for tool call #{tool_call_id}: {e}"
) from e
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
"""Serialize result to JSON string with sensible defaults.
Args:
result: The result to serialize. Can be a dict, list, string,
number, boolean, or None.
Returns:
JSON string representation of the result. Returns '{"status": "completed"}'
only when result is explicitly None.
"""
if isinstance(result, str):
return result
if result is None:
return '{"status": "completed"}'
return orjson.dumps(result).decode("utf-8")
async def _save_agent_from_result(
result: dict[str, Any],
user_id: str | None,
tool_name: str,
) -> dict[str, Any]:
"""Save agent to library if result contains agent_json.
Args:
result: The result dict that may contain agent_json
user_id: The user ID to save the agent for
tool_name: The tool name (create_agent or edit_agent)
Returns:
Updated result dict with saved agent details, or original result if no agent_json
"""
if not user_id:
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
return result
agent_json = result.get("agent_json")
if not agent_json:
logger.warning(
f"[COMPLETION] {tool_name} completed but no agent_json in result"
)
return result
try:
from .tools.agent_generator import save_agent_to_library
is_update = tool_name == "edit_agent"
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, is_update=is_update
)
logger.info(
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
)
# Return a response similar to AgentSavedResponse
return {
"type": "agent_saved",
"message": f"Agent '{created_graph.name}' has been saved to your library!",
"agent_id": created_graph.id,
"agent_name": created_graph.name,
"library_agent_id": library_agent.id,
"library_agent_link": f"/library/agents/{library_agent.id}",
"agent_page_link": f"/build?flowID={created_graph.id}",
}
except Exception as e:
logger.error(
f"[COMPLETION] Failed to save agent to library: {e}",
exc_info=True,
)
# Return error but don't fail the whole operation
# Sanitize agent_json to remove sensitive keys before returning
return {
"type": "error",
"message": f"Agent was generated but failed to save: {str(e)}",
"error": str(e),
"agent_json": _sanitize_agent_json(agent_json),
}
async def process_operation_success(
task: stream_registry.ActiveTask,
result: dict | str | None,
) -> None:
"""Handle successful operation completion.
Publishes the result to the stream registry, updates the database,
generates LLM continuation, and marks the task as completed.
Args:
task: The active task that completed
result: The result data from the operation
Raises:
ToolMessageUpdateError: If the database update fails. The task
will be marked as failed instead of completed.
"""
# For agent generation tools, save the agent to library
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
# Serialize result for output (only substitute default when result is exactly None)
result_output = result if result is not None else {"status": "completed"}
output_str = (
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
)
# Publish result to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=output_str,
success=True,
),
)
# Update pending operation in database
# If this fails, we must not continue to mark the task as completed
result_str = serialize_result(result)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=result_str,
)
except ToolMessageUpdateError:
# DB update failed - mark task as failed to avoid inconsistent state
logger.error(
f"[COMPLETION] DB update failed for task {task.task_id}, "
"marking as failed instead of completed"
)
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText="Failed to save operation result to database"),
)
await stream_registry.mark_task_completed(task.task_id, status="failed")
raise
# Generate LLM continuation with streaming
try:
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to generate LLM continuation: {e}",
exc_info=True,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
)
async def process_operation_failure(
task: stream_registry.ActiveTask,
error: str | None,
) -> None:
"""Handle failed operation completion.
Publishes the error to the stream registry, updates the database
with the error response, and marks the task as failed.
Args:
task: The active task that failed
error: The error message from the operation
"""
error_msg = error or "Operation failed"
# Publish error to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
# Update pending operation with error
# If this fails, we still continue to mark the task as failed
error_response = ErrorResponse(
message=error_msg,
error=error,
)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=error_response.model_dump_json(),
)
except ToolMessageUpdateError:
# DB update failed - log but continue with cleanup
logger.error(
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
"continuing with cleanup"
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")

View File

@@ -36,14 +36,6 @@ class ChatConfig(BaseSettings):
default=30, description="Maximum number of agent schedules"
)
# Long-running operation configuration
long_running_operation_ttl: int = Field(
default=3600,
description="TTL in seconds for long-running operation deduplication lock "
"(1 hour, matches stream_ttl). Prevents duplicate operations if pod dies. "
"For longer operations, the stream_registry heartbeat keeps them alive.",
)
# Stream registry configuration for SSE reconnection
stream_ttl: int = Field(
default=3600,
@@ -59,36 +51,14 @@ class ChatConfig(BaseSettings):
description="Maximum number of messages to store per stream",
)
# Redis Streams configuration for completion consumer
stream_completion_name: str = Field(
default="chat:completions",
description="Redis Stream name for operation completions",
)
stream_consumer_group: str = Field(
default="chat_consumers",
description="Consumer group name for completion stream",
)
stream_claim_min_idle_ms: int = Field(
default=60000,
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
)
# Redis key prefixes for stream registry
task_meta_prefix: str = Field(
session_meta_prefix: str = Field(
default="chat:task:meta:",
description="Prefix for task metadata hash keys",
description="Prefix for session metadata hash keys",
)
task_stream_prefix: str = Field(
turn_stream_prefix: str = Field(
default="chat:stream:",
description="Prefix for task message stream keys",
)
task_op_prefix: str = Field(
default="chat:task:op:",
description="Prefix for operation ID to task ID mapping keys",
)
internal_api_key: str | None = Field(
default=None,
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
description="Prefix for turn message stream keys",
)
# Langfuse Prompt Management Configuration
@@ -160,14 +130,6 @@ class ChatConfig(BaseSettings):
v = "https://openrouter.ai/api/v1"
return v
@field_validator("internal_api_key", mode="before")
@classmethod
def get_internal_api_key(cls, v):
"""Get internal API key from environment if not provided."""
if v is None:
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
@field_validator("use_claude_agent_sdk", mode="before")
@classmethod
def get_use_claude_agent_sdk(cls, v):

View File

@@ -25,7 +25,7 @@ from backend.util.process import AppProcess
from backend.util.retry import continuous_retry
from backend.util.settings import Settings
from .processor import execute_copilot_task, init_worker
from .processor import execute_copilot_turn, init_worker
from .utils import (
COPILOT_CANCEL_QUEUE_NAME,
COPILOT_EXECUTION_QUEUE_NAME,
@@ -181,13 +181,13 @@ class CoPilotExecutor(AppProcess):
self._executor.shutdown(wait=False)
# Release any remaining locks
for task_id, lock in list(self._task_locks.items()):
for session_id, lock in list(self._task_locks.items()):
try:
lock.release()
logger.info(f"[cleanup {pid}] Released lock for {task_id}")
logger.info(f"[cleanup {pid}] Released lock for {session_id}")
except Exception as e:
logger.error(
f"[cleanup {pid}] Failed to release lock for {task_id}: {e}"
f"[cleanup {pid}] Failed to release lock for {session_id}: {e}"
)
logger.info(f"[cleanup {pid}] Graceful shutdown completed")
@@ -267,20 +267,20 @@ class CoPilotExecutor(AppProcess):
):
"""Handle cancel message from FANOUT exchange."""
request = CancelCoPilotEvent.model_validate_json(body)
task_id = request.task_id
if not task_id:
logger.warning("Cancel message missing 'task_id'")
session_id = request.session_id
if not session_id:
logger.warning("Cancel message missing 'session_id'")
return
if task_id not in self.active_tasks:
logger.debug(f"Cancel received for {task_id} but not active")
if session_id not in self.active_tasks:
logger.debug(f"Cancel received for {session_id} but not active")
return
_, cancel_event = self.active_tasks[task_id]
logger.info(f"Received cancel for {task_id}")
_, cancel_event = self.active_tasks[session_id]
logger.info(f"Received cancel for {session_id}")
if not cancel_event.is_set():
cancel_event.set()
else:
logger.debug(f"Cancel already set for {task_id}")
logger.debug(f"Cancel already set for {session_id}")
def _handle_run_message(
self,
@@ -352,12 +352,12 @@ class CoPilotExecutor(AppProcess):
ack_message(reject=True, requeue=False)
return
task_id = entry.task_id
session_id = entry.session_id
# Check for local duplicate - task is already running on this executor
if task_id in self.active_tasks:
# Check for local duplicate - session is already running on this executor
if session_id in self.active_tasks:
logger.warning(
f"Task {task_id} already running locally, rejecting duplicate"
f"Session {session_id} already running locally, rejecting duplicate"
)
ack_message(reject=True, requeue=False)
return
@@ -365,53 +365,53 @@ class CoPilotExecutor(AppProcess):
# Try to acquire cluster-wide lock
cluster_lock = ClusterLock(
redis=redis.get_redis(),
key=f"copilot:task:{task_id}:lock",
key=f"copilot:session:{session_id}:lock",
owner_id=self.executor_id,
timeout=settings.config.cluster_lock_timeout,
)
current_owner = cluster_lock.try_acquire()
if current_owner != self.executor_id:
if current_owner is not None:
logger.warning(f"Task {task_id} already running on pod {current_owner}")
logger.warning(
f"Session {session_id} already running on pod {current_owner}"
)
ack_message(reject=True, requeue=False)
else:
logger.warning(
f"Could not acquire lock for {task_id} - Redis unavailable"
f"Could not acquire lock for {session_id} - Redis unavailable"
)
ack_message(reject=True, requeue=True)
return
# Execute the task
try:
self._task_locks[task_id] = cluster_lock
self._task_locks[session_id] = cluster_lock
logger.info(
f"Acquired cluster lock for {task_id}, executor_id={self.executor_id}"
f"Acquired cluster lock for {session_id}, "
f"executor_id={self.executor_id}"
)
cancel_event = threading.Event()
future = self.executor.submit(
execute_copilot_task, entry, cancel_event, cluster_lock
execute_copilot_turn, entry, cancel_event, cluster_lock
)
self.active_tasks[task_id] = (future, cancel_event)
self.active_tasks[session_id] = (future, cancel_event)
except Exception as e:
logger.warning(f"Failed to setup execution for {task_id}: {e}")
logger.warning(f"Failed to setup execution for {session_id}: {e}")
cluster_lock.release()
if task_id in self._task_locks:
del self._task_locks[task_id]
if session_id in self._task_locks:
del self._task_locks[session_id]
ack_message(reject=True, requeue=True)
return
self._update_metrics()
def on_run_done(f: Future):
logger.info(f"Run completed for {task_id}")
logger.info(f"Run completed for {session_id}")
try:
if exec_error := f.exception():
logger.error(f"Execution for {task_id} failed: {exec_error}")
# Don't requeue failed tasks - they've been marked as failed
# in the stream registry. Requeuing would cause infinite retries
# for deterministic failures.
logger.error(f"Execution for {session_id} failed: {exec_error}")
ack_message(reject=True, requeue=False)
else:
ack_message(reject=False, requeue=False)
@@ -419,10 +419,10 @@ class CoPilotExecutor(AppProcess):
logger.exception(f"Error in run completion callback: {e}")
finally:
# Release the cluster lock
if task_id in self._task_locks:
logger.info(f"Releasing cluster lock for {task_id}")
self._task_locks[task_id].release()
del self._task_locks[task_id]
if session_id in self._task_locks:
logger.info(f"Releasing cluster lock for {session_id}")
self._task_locks[session_id].release()
del self._task_locks[session_id]
self._cleanup_completed_tasks()
future.add_done_callback(on_run_done)
@@ -433,11 +433,11 @@ class CoPilotExecutor(AppProcess):
"""Remove completed futures from active_tasks and update metrics."""
completed_tasks = []
with self._active_tasks_lock:
for task_id, (future, _) in list(self.active_tasks.items()):
for session_id, (future, _) in list(self.active_tasks.items()):
if future.done():
completed_tasks.append(task_id)
self.active_tasks.pop(task_id, None)
logger.info(f"Cleaned up completed task {task_id}")
completed_tasks.append(session_id)
self.active_tasks.pop(session_id, None)
logger.info(f"Cleaned up completed session {session_id}")
self._update_metrics()
return completed_tasks

View File

@@ -1,6 +1,6 @@
"""CoPilot execution processor - per-worker execution logic.
This module contains the processor class that handles CoPilot task execution
This module contains the processor class that handles CoPilot session execution
in a thread-local context, following the graph executor pattern.
"""
@@ -12,7 +12,7 @@ import time
from backend.copilot import service as copilot_service
from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig
from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep
from backend.copilot.response_model import StreamFinish
from backend.copilot.sdk import service as sdk_service
from backend.executor.cluster_lock import ClusterLock
from backend.util.decorator import error_logged
@@ -32,17 +32,17 @@ logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]"
_tls = threading.local()
def execute_copilot_task(
def execute_copilot_turn(
entry: CoPilotExecutionEntry,
cancel: threading.Event,
cluster_lock: ClusterLock,
):
"""Execute a CoPilot task using the thread-local processor.
"""Execute a single CoPilot turn (user message → AI response).
This function is the entry point called by the thread pool executor.
Args:
entry: The task payload
entry: The turn payload
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock for this execution
"""
@@ -76,16 +76,16 @@ def cleanup_worker():
class CoPilotProcessor:
"""Per-worker execution logic for CoPilot tasks.
"""Per-worker execution logic for CoPilot sessions.
This class is instantiated once per worker thread and handles the execution
of CoPilot chat generation tasks. It maintains an async event loop for
of CoPilot chat generation sessions. It maintains an async event loop for
running the async service code.
The execution flow:
1. CoPilot task is picked from RabbitMQ queue
2. Manager submits task to thread pool
3. Processor executes the task in its event loop
1. Session entry is picked from RabbitMQ queue
2. Manager submits to thread pool
3. Processor executes in its event loop
4. Results are published to Redis Streams
"""
@@ -139,19 +139,17 @@ class CoPilotProcessor:
cancel: threading.Event,
cluster_lock: ClusterLock,
):
"""Execute a CoPilot task.
"""Execute a CoPilot turn.
This is the main entry point for task execution. It runs the async
execution logic in the worker's event loop and handles errors.
Runs the async logic in the worker's event loop and handles errors.
Args:
entry: The task payload containing session and message info
entry: The turn payload containing session and message info
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock to prevent duplicate execution
"""
log = CoPilotLogMetadata(
logging.getLogger(__name__),
task_id=entry.task_id,
session_id=entry.session_id,
user_id=entry.user_id,
)
@@ -185,11 +183,20 @@ class CoPilotProcessor:
elapsed = time.monotonic() - start_time
log.info(f"Execution completed in {elapsed:.2f}s")
except Exception as e:
except BaseException as e:
elapsed = time.monotonic() - start_time
log.error(f"Execution failed after {elapsed:.2f}s: {e}")
# Note: _execute_async already marks the task as failed before re-raising,
# so we don't call _mark_task_failed here to avoid duplicate error events.
# Safety net: if _execute_async's error handler failed to mark
# the session (e.g. RuntimeError from SDK cleanup), do it here.
try:
asyncio.run_coroutine_threadsafe(
stream_registry.mark_session_completed(
entry.session_id, error_message=str(e) or "Unknown error"
),
self.execution_loop,
).result(timeout=5.0)
except Exception as cleanup_err:
log.error(f"Safety net mark_session_completed failed: {cleanup_err}")
raise
async def _execute_async(
@@ -199,16 +206,16 @@ class CoPilotProcessor:
cluster_lock: ClusterLock,
log: CoPilotLogMetadata,
):
"""Async execution logic for CoPilot task.
"""Async execution logic for a CoPilot turn.
This method calls the existing stream_chat_completion service function
and publishes results to the stream registry.
Calls the stream_chat_completion service function and publishes
results to the stream registry.
Args:
entry: The task payload
entry: The turn payload
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock for refresh
log: Structured logger for this task
log: Structured logger
"""
last_refresh = time.monotonic()
refresh_interval = 30.0 # Refresh lock every 30 seconds
@@ -228,7 +235,7 @@ class CoPilotProcessor:
)
log.info(f"Using {'SDK' if use_sdk else 'standard'} service")
# Stream chat completion and publish chunks to Redis
# Stream chat completion and publish chunks to Redis.
async for chunk in stream_fn(
session_id=entry.session_id,
message=entry.message if entry.message else None,
@@ -236,56 +243,38 @@ class CoPilotProcessor:
user_id=entry.user_id,
context=entry.context,
):
# Check for cancellation
if cancel.is_set():
log.info("Cancelled during streaming")
await stream_registry.publish_chunk(
entry.task_id, StreamError(errorText="Operation cancelled")
)
await stream_registry.publish_chunk(
entry.task_id, StreamFinishStep()
)
await stream_registry.publish_chunk(entry.task_id, StreamFinish())
await stream_registry.mark_task_completed(
entry.task_id, status="failed"
)
return
log.info("Cancel requested, breaking stream")
break
# Refresh cluster lock periodically
current_time = time.monotonic()
if current_time - last_refresh >= refresh_interval:
cluster_lock.refresh()
last_refresh = current_time
# Publish chunk to stream registry
await stream_registry.publish_chunk(entry.task_id, chunk)
# Skip StreamFinish — mark_session_completed publishes it.
if isinstance(chunk, StreamFinish):
continue
# Mark task as completed
await stream_registry.mark_task_completed(entry.task_id, status="completed")
log.info("Task completed successfully")
try:
await stream_registry.publish_chunk(entry.turn_id, chunk)
except Exception as e:
log.error(
f"Error publishing chunk {type(chunk).__name__}: {e}",
exc_info=True,
)
except asyncio.CancelledError:
log.info("Task cancelled")
await stream_registry.mark_task_completed(
entry.task_id,
status="failed",
error_message="Task was cancelled",
error_message = "Operation cancelled" if cancel.is_set() else None
await stream_registry.mark_session_completed(
entry.session_id, error_message=error_message
)
raise
except Exception as e:
log.error(f"Task failed: {e}")
await self._mark_task_failed(entry.task_id, str(e))
except BaseException as e:
log.error(f"Turn failed: {e}")
try:
await stream_registry.mark_session_completed(
entry.session_id, error_message=str(e) or "Unknown error"
)
except Exception as mark_err:
log.error(f"mark_session_completed also failed: {mark_err}")
raise
async def _mark_task_failed(self, task_id: str, error_message: str):
"""Mark a task as failed and publish error to stream registry."""
try:
await stream_registry.publish_chunk(
task_id, StreamError(errorText=error_message)
)
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish())
await stream_registry.mark_task_completed(task_id, status="failed")
except Exception as e:
logger.error(f"Failed to mark task {task_id} as failed: {e}")

View File

@@ -28,7 +28,7 @@ class CoPilotLogMetadata(TruncatedLogger):
Args:
logger: The underlying logger instance
max_length: Maximum log message length before truncation
**kwargs: Metadata key-value pairs (e.g., task_id="abc", session_id="xyz")
**kwargs: Metadata key-value pairs (e.g., session_id="xyz", turn_id="abc")
These are added to json_fields in cloud mode, or to the prefix in local mode.
"""
@@ -135,18 +135,15 @@ class CoPilotExecutionEntry(BaseModel):
This model represents a chat generation task to be processed by the executor.
"""
task_id: str
"""Unique identifier for this task (used for stream registry)"""
session_id: str
"""Chat session ID"""
"""Chat session ID (also used for dedup/locking)"""
turn_id: str = ""
"""Per-turn UUID for Redis stream isolation"""
user_id: str | None
"""User ID (may be None for anonymous users)"""
operation_id: str
"""Operation ID for webhook callbacks and completion tracking"""
message: str
"""User's message to process"""
@@ -160,40 +157,37 @@ class CoPilotExecutionEntry(BaseModel):
class CancelCoPilotEvent(BaseModel):
"""Event to cancel a CoPilot operation."""
task_id: str
"""Task ID to cancel"""
session_id: str
"""Session ID to cancel"""
# ============ Queue Publishing Helpers ============ #
async def enqueue_copilot_task(
task_id: str,
async def enqueue_copilot_turn(
session_id: str,
user_id: str | None,
operation_id: str,
message: str,
turn_id: str,
is_user_message: bool = True,
context: dict[str, str] | None = None,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.
Args:
task_id: Unique identifier for this task (used for stream registry)
session_id: Chat session ID
session_id: Chat session ID (also used for dedup/locking)
user_id: User ID (may be None for anonymous users)
operation_id: Operation ID for webhook callbacks and completion tracking
message: User's message to process
turn_id: Per-turn UUID for Redis stream isolation
is_user_message: Whether the message is from the user (vs system/assistant)
context: Optional context for the message (e.g., {url: str, content: str})
"""
from backend.util.clients import get_async_copilot_queue
entry = CoPilotExecutionEntry(
task_id=task_id,
session_id=session_id,
turn_id=turn_id,
user_id=user_id,
operation_id=operation_id,
message=message,
is_user_message=is_user_message,
context=context,
@@ -207,15 +201,15 @@ async def enqueue_copilot_task(
)
async def enqueue_cancel_task(task_id: str) -> None:
"""Publish a cancel request for a running CoPilot task.
async def enqueue_cancel_task(session_id: str) -> None:
"""Publish a cancel request for a running CoPilot session.
Sends a ``CancelCoPilotEvent`` to the FANOUT exchange so all executor
pods receive the cancellation signal.
"""
from backend.util.clients import get_async_copilot_queue
event = CancelCoPilotEvent(task_id=task_id)
event = CancelCoPilotEvent(session_id=session_id)
queue_client = await get_async_copilot_queue()
await queue_client.publish_message(
routing_key="", # FANOUT ignores routing key

View File

@@ -14,7 +14,6 @@ import pytest
@pytest.mark.asyncio
async def test_parallel_tool_calls_run_concurrently():
"""Multiple tool calls should complete in ~max(delays), not sum(delays)."""
# Import here to allow module-level mocking if needed
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
@@ -32,7 +31,6 @@ async def test_parallel_tool_calls_run_concurrently():
for i in range(n_tools)
]
# Minimal session mock
class FakeSession:
session_id = "test"
user_id = "test"
@@ -42,7 +40,7 @@ async def test_parallel_tool_calls_run_concurrently():
original_yield = None
async def fake_yield(tc_list, idx, sess, lock=None):
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"],
toolName=tc_list[idx]["function"]["name"],
@@ -101,7 +99,7 @@ async def test_single_tool_call_works():
def __init__(self):
self.messages = []
async def fake_yield(tc_list, idx, sess, lock=None):
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")
@@ -144,7 +142,7 @@ async def test_retryable_error_propagates():
def __init__(self):
self.messages = []
async def fake_yield(tc_list, idx, sess, lock=None):
async def fake_yield(tc_list, idx, sess):
if idx == 1:
raise KeyError("bad")
from backend.copilot.response_model import StreamToolInputAvailable
@@ -175,8 +173,8 @@ async def test_retryable_error_propagates():
@pytest.mark.asyncio
async def test_session_lock_shared():
"""All parallel tools should receive the same lock instance."""
async def test_session_shared_across_parallel_tools():
"""All parallel tools should receive the same session instance."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
@@ -199,10 +197,10 @@ async def test_session_lock_shared():
def __init__(self):
self.messages = []
observed_locks = []
observed_sessions = []
async def fake_yield(tc_list, idx, sess, lock=None):
observed_locks.append(lock)
async def fake_yield(tc_list, idx, sess):
observed_sessions.append(sess)
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
)
@@ -222,9 +220,8 @@ async def test_session_lock_shared():
finally:
svc._yield_tool_call = orig
assert len(observed_locks) == 3
assert observed_locks[0] is observed_locks[1] is observed_locks[2]
assert isinstance(observed_locks[0], asyncio.Lock)
assert len(observed_sessions) == 3
assert observed_sessions[0] is observed_sessions[1] is observed_sessions[2]
@pytest.mark.asyncio
@@ -251,7 +248,7 @@ async def test_cancellation_cleans_up():
started = asyncio.Event()
async def fake_yield(tc_list, idx, sess, lock=None):
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
)

View File

@@ -5,6 +5,8 @@ This module implements the AI SDK UI Stream Protocol (v1) for streaming chat res
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
"""
import json
import logging
from enum import Enum
from typing import Any
@@ -12,6 +14,8 @@ from pydantic import BaseModel, Field
from backend.util.json import dumps as json_dumps
logger = logging.getLogger(__name__)
class ResponseType(str, Enum):
"""Types of streaming responses following AI SDK protocol."""
@@ -47,7 +51,8 @@ class StreamBaseResponse(BaseModel):
def to_sse(self) -> str:
"""Convert to SSE format."""
return f"data: {self.model_dump_json()}\n\n"
json_str = self.model_dump_json(exclude_none=True)
return f"data: {json_str}\n\n"
# ========== Message Lifecycle ==========
@@ -58,15 +63,13 @@ class StreamStart(StreamBaseResponse):
type: ResponseType = ResponseType.START
messageId: str = Field(..., description="Unique message ID")
taskId: str | None = Field(
sessionId: str | None = Field(
default=None,
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
description="Session ID for SSE reconnection.",
)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-protocol fields like taskId."""
import json
"""Convert to SSE format, excluding non-protocol fields like sessionId."""
data: dict[str, Any] = {
"type": self.type.value,
"messageId": self.messageId,
@@ -163,8 +166,6 @@ class StreamToolOutputAvailable(StreamBaseResponse):
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-spec fields."""
import json
data = {
"type": self.type.value,
"toolCallId": self.toolCallId,

View File

@@ -0,0 +1,57 @@
"""Dummy SDK service for testing copilot streaming.
Returns mock streaming responses without calling Claude Agent SDK.
Enable via COPILOT_TEST_MODE=true environment variable.
WARNING: This is for testing only. Do not use in production.
"""
import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
from ..model import ChatSession
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
logger = logging.getLogger(__name__)
async def stream_chat_completion_dummy(
session_id: str,
message: str | None = None,
tool_call_response: str | None = None,
is_user_message: bool = True,
user_id: str | None = None,
retry_count: int = 0,
session: ChatSession | None = None,
context: dict[str, str] | None = None,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream dummy chat completion for testing.
Returns a simple streaming response with text deltas to test:
- Streaming infrastructure works
- No timeout occurs
- Text arrives in chunks
- StreamFinish is sent by mark_session_completed
"""
logger.warning(
f"[TEST MODE] Using dummy copilot streaming for session {session_id}"
)
message_id = str(uuid.uuid4())
text_block_id = str(uuid.uuid4())
# Start the stream
yield StreamStart(messageId=message_id, sessionId=session_id)
# Simulate streaming text response with delays
dummy_response = "I counted: 1... 2... 3. All done!"
words = dummy_response.split()
for i, word in enumerate(words):
# Add space except for last word
text = word if i == len(words) - 1 else f"{word} "
yield StreamTextDelta(id=text_block_id, delta=text)
# Small delay to simulate real streaming
await asyncio.sleep(0.1)

View File

@@ -55,13 +55,8 @@ class SDKResponseAdapter:
self.has_ended_text = False
self.current_tool_calls: dict[str, dict[str, str]] = {}
self.resolved_tool_calls: set[str] = set()
self.task_id: str | None = None
self.step_open = False
def set_task_id(self, task_id: str) -> None:
"""Set the task ID for reconnection support."""
self.task_id = task_id
@property
def has_unresolved_tool_calls(self) -> bool:
"""True when there are tool calls that haven't received output yet."""
@@ -74,7 +69,7 @@ class SDKResponseAdapter:
if isinstance(sdk_message, SystemMessage):
if sdk_message.subtype == "init":
responses.append(
StreamStart(messageId=self.message_id, taskId=self.task_id)
StreamStart(messageId=self.message_id, sessionId=self.session_id)
)
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
responses.append(StreamStartStep())

View File

@@ -37,9 +37,7 @@ from .tool_adapter import wait_for_stash
def _adapter() -> SDKResponseAdapter:
a = SDKResponseAdapter(message_id="msg-1")
a.set_task_id("task-1")
return a
return SDKResponseAdapter(message_id="msg-1", session_id="session-1")
# -- SystemMessage -----------------------------------------------------------
@@ -51,7 +49,7 @@ def test_system_init_emits_start_and_step():
assert len(results) == 2
assert isinstance(results[0], StreamStart)
assert results[0].messageId == "msg-1"
assert results[0].taskId == "task-1"
assert results[0].sessionId == "session-1"
assert isinstance(results[1], StreamStartStep)

View File

@@ -13,7 +13,6 @@ from backend.data.redis_client import get_redis_async
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import NotFoundError
from .. import stream_registry
from ..config import ChatConfig
from ..model import (
ChatMessage,
@@ -26,19 +25,13 @@ from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from ..service import (
_build_system_prompt,
_execute_long_running_tool_with_streaming,
_generate_session_title,
)
from ..tools.models import OperationPendingResponse, OperationStartedResponse
from ..service import _build_system_prompt, _generate_session_title
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from .response_adapter import SDKResponseAdapter
@@ -46,7 +39,6 @@ from .security_hooks import create_security_hooks
from .tool_adapter import (
COPILOT_TOOL_NAMES,
SDK_DISALLOWED_TOOLS,
LongRunningCallback,
create_copilot_mcp_server,
set_execution_context,
wait_for_stash,
@@ -84,7 +76,8 @@ class CapturedTranscript:
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
# Heartbeat interval — keep SSE alive through proxies/LBs during tool execution.
_HEARTBEAT_INTERVAL = 15.0 # seconds
# IMPORTANT: Must be less than frontend timeout (12s in useCopilotPage.ts)
_HEARTBEAT_INTERVAL = 10.0 # seconds
# Appended to the system prompt to inform the agent about available tools.
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
@@ -138,127 +131,6 @@ is delivered to the user via a background stream.
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
def _build_long_running_callback(
user_id: str | None,
) -> LongRunningCallback:
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
existing background infrastructure: stream_registry (Redis Streams),
database persistence, and SSE reconnection. This means results survive
page refreshes / pod restarts, and the frontend shows the proper loading
widget with progress updates.
Args:
user_id: User ID for the session
The returned callback matches the ``LongRunningCallback`` signature:
``(tool_name, args, session) -> MCP response dict``.
"""
async def _callback(
tool_name: str, args: dict[str, Any], session: ChatSession
) -> dict[str, Any]:
operation_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}"
session_id = session.session_id
# --- Build user-friendly messages (matches non-SDK service) ---
if tool_name == "create_agent":
desc = args.get("description", "")
desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc
pending_msg = (
f"Creating your agent: {desc_preview}"
if desc_preview
else "Creating agent... This may take a few minutes."
)
started_msg = (
"Agent creation started. You can close this tab - "
"check your library in a few minutes."
)
elif tool_name == "edit_agent":
changes = args.get("changes", "")
changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
pending_msg = (
f"Editing agent: {changes_preview}"
if changes_preview
else "Editing agent... This may take a few minutes."
)
started_msg = (
"Agent edit started. You can close this tab - "
"check your library in a few minutes."
)
else:
pending_msg = f"Running {tool_name}... This may take a few minutes."
started_msg = (
f"{tool_name} started. You can close this tab - "
"check back in a few minutes."
)
# --- Register task in Redis for SSE reconnection ---
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# --- Save OperationPendingResponse to chat history ---
pending_message = ChatMessage(
role="tool",
content=OperationPendingResponse(
message=pending_msg,
operation_id=operation_id,
tool_name=tool_name,
).model_dump_json(),
tool_call_id=tool_call_id,
)
session.messages.append(pending_message)
# Collision detection happens in add_chat_messages_batch (db.py)
session = await upsert_chat_session(session)
# --- Spawn background task (reuses non-SDK infrastructure) ---
bg_task = asyncio.create_task(
_execute_long_running_tool_with_streaming(
tool_name=tool_name,
parameters=args,
tool_call_id=tool_call_id,
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
user_id=user_id,
)
)
_background_tasks.add(bg_task)
bg_task.add_done_callback(_background_tasks.discard)
await stream_registry.set_task_asyncio_task(task_id, bg_task)
logger.info(
f"[SDK] Long-running tool {tool_name} delegated to background "
f"(operation_id={operation_id}, task_id={task_id})"
)
# --- Return OperationStartedResponse as MCP tool result ---
# This flows through SDK → response adapter → frontend, triggering
# the loading widget with SSE reconnection support.
started_json = OperationStartedResponse(
message=started_msg,
operation_id=operation_id,
tool_name=tool_name,
task_id=task_id,
).model_dump_json()
return {
"content": [{"type": "text", "text": started_json}],
"isError": False,
}
return _callback
def _resolve_sdk_model() -> str | None:
"""Resolve the model name for the Claude Agent SDK CLI.
@@ -577,8 +449,7 @@ async def stream_chat_completion_sdk(
)
system_prompt += _SDK_TOOL_SUPPLEMENT
message_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
stream_id = task_id # Use task_id as unique stream identifier
stream_id = str(uuid.uuid4())
# Acquire stream lock to prevent concurrent streams to the same session
lock = AsyncClusterLock(
@@ -599,10 +470,9 @@ async def stream_chat_completion_sdk(
"Please wait or stop it.",
code="stream_already_active",
)
yield StreamFinish()
return
yield StreamStart(messageId=message_id, taskId=task_id)
yield StreamStart(messageId=message_id, sessionId=session_id)
stream_completed = False
# Initialise variables before the try so the finally block can
@@ -618,11 +488,7 @@ async def stream_chat_completion_sdk(
sdk_cwd = _make_sdk_cwd(session_id)
os.makedirs(sdk_cwd, exist_ok=True)
set_execution_context(
user_id,
session,
long_running_callback=_build_long_running_callback(user_id),
)
set_execution_context(user_id, session)
try:
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
@@ -714,7 +580,6 @@ async def stream_chat_completion_sdk(
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type]
adapter = SDKResponseAdapter(message_id=message_id, session_id=session_id)
adapter.set_task_id(task_id)
async with ClaudeSDKClient(options=options) as client:
current_message = message or ""
@@ -728,7 +593,6 @@ async def stream_chat_completion_sdk(
errorText="Message cannot be empty.",
code="empty_prompt",
)
yield StreamFinish()
return
query_message = await _build_query_message(
@@ -739,8 +603,7 @@ async def stream_chat_completion_sdk(
session_id,
)
logger.info(
"[SDK] [%s] Sending query — resume=%s, "
"total_msgs=%d, query_len=%d",
"[SDK] [%s] Sending query — resume=%s, total_msgs=%d, query_len=%d",
session_id[:12],
use_resume,
len(session.messages),
@@ -789,8 +652,7 @@ async def stream_chat_completion_sdk(
sdk_msg = done.pop().result()
except StopAsyncIteration:
logger.info(
"[SDK] [%s] Stream ended normally "
"(StopAsyncIteration)",
"[SDK] [%s] Stream ended normally (StopAsyncIteration)",
session_id[:12],
)
break
@@ -927,18 +789,6 @@ async def stream_chat_completion_sdk(
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
# Save before tool execution starts so the
# pending tool call is visible on refresh /
# other devices. Collision detection happens
# in add_chat_messages_batch (db.py).
try:
session = await upsert_chat_session(session)
except Exception as save_err:
logger.warning(
"[SDK] [%s] Incremental save " "failed: %s",
session_id[:12],
save_err,
)
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
@@ -953,17 +803,6 @@ async def stream_chat_completion_sdk(
)
)
has_tool_results = True
# Save after tool completes so the result is
# visible on refresh / other devices.
# Collision detection happens in add_chat_messages_batch (db.py).
try:
session = await upsert_chat_session(session)
except Exception as save_err:
logger.warning(
"[SDK] [%s] Incremental save " "failed: %s",
session_id[:12],
save_err,
)
elif isinstance(response, StreamFinish):
stream_completed = True
@@ -973,8 +812,7 @@ async def stream_chat_completion_sdk(
# server shutdown). Log and let the safety-net / finally
# blocks handle cleanup.
logger.warning(
"[SDK] [%s] Streaming loop cancelled "
"(asyncio.CancelledError)",
"[SDK] [%s] Streaming loop cancelled (asyncio.CancelledError)",
session_id[:12],
)
raise
@@ -1016,25 +854,20 @@ async def stream_chat_completion_sdk(
)
yield response
# If the stream ended without a ResultMessage (no
# StreamFinish), the SDK CLI exited unexpectedly. Close
# the open step and emit StreamFinish so the frontend
# transitions to the "ready" state.
# If the stream ended without a ResultMessage, the SDK
# CLI exited unexpectedly. Close any open text/step so
# the chunks are well-formed. StreamFinish is published
# by mark_session_completed in the processor.
if not stream_completed:
logger.warning(
"[SDK] [%s] Stream ended without ResultMessage "
"(StopAsyncIteration) — emitting StreamFinish",
"(StopAsyncIteration)",
session_id[:12],
)
if adapter.step_open:
yield StreamFinishStep()
adapter.step_open = False
closing_responses: list[StreamBaseResponse] = []
adapter._end_text_if_open(closing_responses)
for r in closing_responses:
yield r
yield StreamFinish()
stream_completed = True
if (
assistant_response.content or assistant_response.tool_calls
@@ -1054,7 +887,7 @@ async def stream_chat_completion_sdk(
elif captured_transcript.path:
raw_transcript = read_transcript_file(captured_transcript.path)
logger.debug(
"[SDK] Transcript source: stop hook (%s), " "read result: %s",
"[SDK] Transcript source: stop hook (%s), read result: %s",
captured_transcript.path,
f"{len(raw_transcript)}B" if raw_transcript else "None",
)
@@ -1095,14 +928,24 @@ async def stream_chat_completion_sdk(
session_id[:12],
len(session.messages),
)
if not stream_completed:
yield StreamFinish()
except asyncio.CancelledError:
# Client disconnect / server shutdown — log but re-raise so
# the framework can clean up. The finally block still runs
# for transcript upload.
# Client disconnect / server shutdown — save session before re-raising
# so accumulated messages aren't lost.
logger.warning("[SDK] [%s] Session cancelled (CancelledError)", session_id[:12])
if session:
try:
await asyncio.shield(upsert_chat_session(session))
logger.info(
"[SDK] [%s] Session saved on cancel (%d messages)",
session_id[:12],
len(session.messages),
)
except Exception as save_err:
logger.error(
"[SDK] [%s] Failed to save session on cancel: %s",
session_id[:12],
save_err,
)
raise
except Exception as e:
logger.error(f"[SDK] Error: {e}", exc_info=True)
@@ -1115,7 +958,6 @@ async def stream_chat_completion_sdk(
errorText="An error occurred. Please try again.",
code="sdk_error",
)
yield StreamFinish()
finally:
# --- Upload transcript for next-turn --resume ---
# This MUST run in finally so the transcript is uploaded even when

View File

@@ -2,11 +2,6 @@
This module provides the adapter layer that converts existing BaseTool implementations
into in-process MCP tools that can be used with the Claude Agent SDK.
Long-running tools (``is_long_running=True``) are delegated to the non-SDK
background infrastructure (stream_registry, Redis persistence, SSE reconnection)
via a callback provided by the service layer. This avoids wasteful SDK polling
and makes results survive page refreshes.
"""
import asyncio
@@ -15,7 +10,6 @@ import json
import logging
import os
import uuid
from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from typing import Any
@@ -43,7 +37,8 @@ _current_session: ContextVar[ChatSession | None] = ContextVar(
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, list[str]]] = ContextVar(
"pending_tool_outputs", default=None # type: ignore[arg-type]
"pending_tool_outputs",
default=None, # type: ignore[arg-type]
)
# Event signaled whenever stash_pending_tool_output() adds a new entry.
# Used by the streaming loop to wait for PostToolUse hooks to complete
@@ -54,22 +49,10 @@ _stash_event: ContextVar[asyncio.Event | None] = ContextVar(
"_stash_event", default=None
)
# Callback type for delegating long-running tools to the non-SDK infrastructure.
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
LongRunningCallback = Callable[
[str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]]
]
# ContextVar so the service layer can inject the callback per-request.
_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar(
"long_running_callback", default=None
)
def set_execution_context(
user_id: str | None,
session: ChatSession,
long_running_callback: LongRunningCallback | None = None,
) -> None:
"""Set the execution context for tool calls.
@@ -79,14 +62,11 @@ def set_execution_context(
Args:
user_id: Current user's ID.
session: Current chat session.
long_running_callback: Optional callback to delegate long-running tools
to the non-SDK background infrastructure (stream_registry + Redis).
"""
_current_user_id.set(user_id)
_current_session.set(session)
_pending_tool_outputs.set({})
_stash_event.set(asyncio.Event())
_long_running_callback.set(long_running_callback)
def get_execution_context() -> tuple[str | None, ChatSession | None]:
@@ -276,11 +256,6 @@ def create_tool_handler(base_tool: BaseTool):
This wraps the existing BaseTool._execute method to be compatible
with the Claude Agent SDK MCP tool format.
Long-running tools (``is_long_running=True``) are delegated to the
non-SDK background infrastructure via a callback set in the execution
context. The callback persists the operation in Redis (stream_registry)
so results survive page refreshes and pod restarts.
"""
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
@@ -290,25 +265,6 @@ def create_tool_handler(base_tool: BaseTool):
if session is None:
return _mcp_error("No session context available")
# --- Long-running: delegate to non-SDK background infrastructure ---
if base_tool.is_long_running:
callback = _long_running_callback.get(None)
if callback:
try:
return await callback(base_tool.name, args, session)
except Exception as e:
logger.error(
f"Long-running callback failed for {base_tool.name}: {e}",
exc_info=True,
)
return _mcp_error(f"Failed to start {base_tool.name}: {e}")
# No callback — fall through to synchronous execution
logger.warning(
f"[SDK] No long-running callback for {base_tool.name}, "
f"executing synchronously (may block)"
)
# --- Normal (fast) tool: execute synchronously ---
try:
return await _execute_tool_sync(base_tool, user_id, session, args)
except Exception as e:

File diff suppressed because it is too large Load Diff

View File

@@ -6,12 +6,7 @@ import pytest
from . import service as chat_service
from .model import create_chat_session, get_chat_session, upsert_chat_session
from .response_model import (
StreamError,
StreamFinish,
StreamTextDelta,
StreamToolOutputAvailable,
)
from .response_model import StreamError, StreamTextDelta, StreamToolOutputAvailable
from .sdk import service as sdk_service
from .sdk.transcript import download_transcript
@@ -30,7 +25,6 @@ async def test_stream_chat_completion(setup_test_user, test_user_id):
session = await create_chat_session(test_user_id)
has_errors = False
has_ended = False
assistant_message = ""
async for chunk in chat_service.stream_chat_completion(
session.session_id, "Hello, how are you?", user_id=session.user_id
@@ -40,10 +34,9 @@ async def test_stream_chat_completion(setup_test_user, test_user_id):
has_errors = True
if isinstance(chunk, StreamTextDelta):
assistant_message += chunk.delta
if isinstance(chunk, StreamFinish):
has_ended = True
assert has_ended, "Chat completion did not end"
# StreamFinish is published by mark_session_completed (processor layer),
# not by the service. The generator completing means the stream ended.
assert not has_errors, "Error occurred while streaming chat completion"
assert assistant_message, "Assistant message is empty"
@@ -61,7 +54,6 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
session = await upsert_chat_session(session)
has_errors = False
has_ended = False
had_tool_calls = False
async for chunk in chat_service.stream_chat_completion(
session.session_id,
@@ -71,13 +63,9 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamFinish):
has_ended = True
if isinstance(chunk, StreamToolOutputAvailable):
had_tool_calls = True
assert has_ended, "Chat completion did not end"
assert not has_errors, "Error occurred while streaming chat completion"
assert had_tool_calls, "Tool calls did not occur"
session = await get_chat_session(session.session_id)
@@ -114,7 +102,6 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
)
turn1_text = ""
turn1_errors: list[str] = []
turn1_ended = False
async for chunk in sdk_service.stream_chat_completion_sdk(
session.session_id,
@@ -125,10 +112,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
turn1_text += chunk.delta
elif isinstance(chunk, StreamError):
turn1_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
turn1_ended = True
assert turn1_ended, "Turn 1 did not finish"
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
assert turn1_text, "Turn 1 produced no text"
@@ -159,7 +143,6 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
turn2_msg = "What was the special keyword I asked you to remember?"
turn2_text = ""
turn2_errors: list[str] = []
turn2_ended = False
async for chunk in sdk_service.stream_chat_completion_sdk(
session.session_id,
@@ -171,10 +154,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
turn2_text += chunk.delta
elif isinstance(chunk, StreamError):
turn2_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
turn2_ended = True
assert turn2_ended, "Turn 2 did not finish"
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
assert turn2_text, "Turn 2 produced no text"
assert keyword in turn2_text, (

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,401 @@
"""End-to-end tests for Copilot streaming with dummy implementations.
These tests verify the complete copilot flow using dummy implementations
for agent generator and SDK service, allowing automated testing without
external LLM calls.
Enable test mode with COPILOT_TEST_MODE=true environment variable.
Note: StreamFinish is NOT emitted by the dummy service — it is published
by mark_session_completed in the processor layer. These tests only cover
the service-level streaming output (StreamStart + StreamTextDelta).
"""
import asyncio
import os
from uuid import uuid4
import pytest
from backend.copilot.model import ChatMessage, ChatSession, upsert_chat_session
from backend.copilot.response_model import (
StreamError,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
)
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
@pytest.fixture(autouse=True)
def enable_test_mode():
"""Enable test mode for all tests in this module."""
os.environ["COPILOT_TEST_MODE"] = "true"
yield
os.environ.pop("COPILOT_TEST_MODE", None)
@pytest.mark.asyncio
async def test_dummy_streaming_basic_flow():
"""Test that dummy streaming produces correct event sequence."""
events = []
async for event in stream_chat_completion_dummy(
session_id="test-session-basic",
message="Hello",
is_user_message=True,
user_id="test-user",
):
events.append(event)
# Verify we got events
assert len(events) > 0, "Should receive events"
# Verify StreamStart
start_events = [e for e in events if isinstance(e, StreamStart)]
assert len(start_events) == 1
assert start_events[0].messageId
assert start_events[0].sessionId
# Verify StreamTextDelta events
text_events = [e for e in events if isinstance(e, StreamTextDelta)]
assert len(text_events) > 0
full_text = "".join(e.delta for e in text_events)
assert len(full_text) > 0
# Verify order: start before text
start_idx = events.index(start_events[0])
first_text_idx = events.index(text_events[0]) if text_events else -1
if first_text_idx >= 0:
assert start_idx < first_text_idx
print(f"✅ Basic flow: {len(events)} events, {len(text_events)} text deltas")
@pytest.mark.asyncio
async def test_streaming_no_timeout():
"""Test that streaming completes within reasonable time without timeout."""
import time
start_time = time.monotonic()
event_count = 0
async for _event in stream_chat_completion_dummy(
session_id="test-session-timeout",
message="count to 10",
is_user_message=True,
user_id="test-user",
):
event_count += 1
elapsed = time.monotonic() - start_time
# Should complete in < 5 seconds (dummy has 0.1s delays between words)
assert elapsed < 5.0, f"Streaming took {elapsed:.1f}s, expected < 5s"
assert event_count > 0, "Should receive events"
print(f"✅ No timeout: completed in {elapsed:.2f}s with {event_count} events")
@pytest.mark.asyncio
async def test_streaming_event_types():
"""Test that all expected event types are present."""
event_types = set()
async for event in stream_chat_completion_dummy(
session_id="test-session-types",
message="test",
is_user_message=True,
user_id="test-user",
):
event_types.add(type(event).__name__)
# Required event types (StreamFinish is published by processor, not service)
assert "StreamStart" in event_types, "Missing StreamStart"
assert "StreamTextDelta" in event_types, "Missing StreamTextDelta"
print(f"✅ Event types: {sorted(event_types)}")
@pytest.mark.asyncio
async def test_streaming_text_content():
"""Test that streamed text is coherent and complete."""
text_events = []
async for event in stream_chat_completion_dummy(
session_id="test-session-content",
message="count to 3",
is_user_message=True,
user_id="test-user",
):
if isinstance(event, StreamTextDelta):
text_events.append(event)
# Verify text deltas
assert len(text_events) > 0, "Should have text deltas"
# Reconstruct full text
full_text = "".join(e.delta for e in text_events)
assert len(full_text) > 0, "Text should not be empty"
assert (
"1" in full_text or "counted" in full_text.lower()
), "Text should contain count"
# Verify all deltas have IDs
for text_event in text_events:
assert text_event.id, "Text delta must have ID"
assert text_event.delta, "Text delta must have content"
print(f"✅ Text content: '{full_text}' ({len(text_events)} deltas)")
@pytest.mark.asyncio
async def test_streaming_heartbeat_timing():
"""Test that heartbeats are sent at correct interval during long operations."""
# This test would need a dummy that takes longer
# For now, just verify heartbeat structure if we receive one
heartbeats = []
async for event in stream_chat_completion_dummy(
session_id="test-session-heartbeat",
message="test",
is_user_message=True,
user_id="test-user",
):
if isinstance(event, StreamHeartbeat):
heartbeats.append(event)
# Dummy is fast, so we might not get heartbeats
# But if we do, verify they're valid
if heartbeats:
print(f"✅ Heartbeat structure verified ({len(heartbeats)} received)")
else:
print("✅ No heartbeats (dummy executes quickly)")
@pytest.mark.asyncio
async def test_error_handling():
"""Test that errors are properly formatted and sent."""
# This would require a dummy that can trigger errors
# For now, just verify error event structure
error = StreamError(errorText="Test error", code="test_error")
assert error.errorText == "Test error"
assert error.code == "test_error"
assert str(error.type.value) in ["error", "error"]
print("✅ Error structure verified")
@pytest.mark.asyncio
async def test_concurrent_sessions():
"""Test that multiple sessions can stream concurrently."""
async def stream_session(session_id: str) -> int:
count = 0
async for _event in stream_chat_completion_dummy(
session_id=session_id,
message="test",
is_user_message=True,
user_id="test-user",
):
count += 1
return count
# Run 3 concurrent sessions
results = await asyncio.gather(
stream_session("session-1"),
stream_session("session-2"),
stream_session("session-3"),
)
# All should complete successfully
assert all(count > 0 for count in results), "All sessions should produce events"
print(f"✅ Concurrent sessions: {results} events each")
@pytest.mark.asyncio
@pytest.mark.xfail(
reason="Event loop isolation issue with DB operations in tests - needs fixture refactoring"
)
async def test_session_state_persistence():
"""Test that session state is maintained across multiple messages."""
from datetime import datetime, timezone
session_id = f"test-session-{uuid4()}"
user_id = "test-user"
# Create session with first message
session = ChatSession(
session_id=session_id,
user_id=user_id,
messages=[
ChatMessage(role="user", content="Hello"),
ChatMessage(role="assistant", content="Hi there!"),
],
usage=[],
started_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
await upsert_chat_session(session)
# Stream second message
events = []
async for event in stream_chat_completion_dummy(
session_id=session_id,
message="How are you?",
is_user_message=True,
user_id=user_id,
session=session, # Pass existing session
):
events.append(event)
# Verify events were produced
assert len(events) > 0, "Should produce events for second message"
print(f"✅ Session persistence: {len(events)} events for second message")
@pytest.mark.asyncio
async def test_message_deduplication():
"""Test that duplicate messages are filtered out."""
# Simulate receiving duplicate events (e.g., from reconnection)
events = []
# First stream
async for event in stream_chat_completion_dummy(
session_id="test-dedup-1",
message="Hello",
is_user_message=True,
user_id="test-user",
):
events.append(event)
# Count unique message IDs in StreamStart events
start_events = [e for e in events if isinstance(e, StreamStart)]
message_ids = [e.messageId for e in start_events]
# Verify all IDs are present
assert len(message_ids) == len(set(message_ids)), "Message IDs should be unique"
print(f"✅ Deduplication: {len(events)} events, all unique")
@pytest.mark.asyncio
async def test_event_ordering():
"""Test that events arrive in correct order."""
events = []
async for event in stream_chat_completion_dummy(
session_id="test-ordering",
message="Test",
is_user_message=True,
user_id="test-user",
):
events.append(event)
# Find event indices
start_idx = next(
(i for i, e in enumerate(events) if isinstance(e, StreamStart)), None
)
text_indices = [i for i, e in enumerate(events) if isinstance(e, StreamTextDelta)]
# Verify ordering
assert start_idx is not None, "Should have StreamStart"
assert start_idx == 0, "StreamStart should be first"
if text_indices:
assert all(
start_idx < i for i in text_indices
), "Text deltas should be after start"
print(f"✅ Event ordering: start({start_idx}) < text deltas")
@pytest.mark.asyncio
async def test_stream_completeness():
"""Test that stream includes all required event types."""
events = []
async for event in stream_chat_completion_dummy(
session_id="test-completeness",
message="Complete stream test",
is_user_message=True,
user_id="test-user",
):
events.append(event)
# Check for required events (StreamFinish is published by processor)
has_start = any(isinstance(e, StreamStart) for e in events)
has_text = any(isinstance(e, StreamTextDelta) for e in events)
assert has_start, "Stream must include StreamStart"
assert has_text, "Stream must include text deltas"
# Verify exactly one start
start_count = sum(1 for e in events if isinstance(e, StreamStart))
assert start_count == 1, f"Should have exactly 1 StreamStart, got {start_count}"
print(
f"✅ Completeness: 1 start, {sum(1 for e in events if isinstance(e, StreamTextDelta))} text deltas"
)
@pytest.mark.asyncio
async def test_text_delta_consistency():
"""Test that text deltas have consistent IDs and build coherent text."""
text_events = []
async for event in stream_chat_completion_dummy(
session_id="test-consistency",
message="Test consistency",
is_user_message=True,
user_id="test-user",
):
if isinstance(event, StreamTextDelta):
text_events.append(event)
# Verify all text deltas have IDs
assert all(e.id for e in text_events), "All text deltas must have IDs"
# Verify all deltas have the same ID (same text block)
if text_events:
first_id = text_events[0].id
assert all(
e.id == first_id for e in text_events
), "All text deltas should share the same block ID"
# Verify deltas build coherent text
full_text = "".join(e.delta for e in text_events)
assert len(full_text) > 0, "Deltas should build non-empty text"
assert (
full_text == full_text.strip()
), "Text should not have leading/trailing whitespace artifacts"
print(
f"✅ Consistency: {len(text_events)} deltas with ID '{text_events[0].id if text_events else 'N/A'}', text: '{full_text}'"
)
if __name__ == "__main__":
# Run tests directly
print("Running Copilot E2E tests with dummy implementations...")
print("=" * 60)
asyncio.run(test_dummy_streaming_basic_flow())
asyncio.run(test_streaming_no_timeout())
asyncio.run(test_streaming_event_types())
asyncio.run(test_streaming_text_content())
asyncio.run(test_streaming_heartbeat_timing())
asyncio.run(test_error_handling())
asyncio.run(test_concurrent_sessions())
asyncio.run(test_session_state_persistence())
asyncio.run(test_message_deduplication())
asyncio.run(test_event_ordering())
asyncio.run(test_stream_completeness())
asyncio.run(test_text_delta_consistency())
print("=" * 60)
print("✅ All E2E tests passed!")

View File

@@ -10,7 +10,6 @@ from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .check_operation_status import CheckOperationStatusTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
@@ -47,7 +46,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"view_agent_output": AgentOutputTool(),
"check_operation_status": CheckOperationStatusTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
# Web fetch for safe URL retrieval

View File

@@ -3,6 +3,7 @@ from datetime import UTC, datetime
from os import getenv
import pytest
import pytest_asyncio
from prisma.types import ProfileCreateInput
from pydantic import SecretStr
@@ -31,14 +32,16 @@ def make_session(user_id: str):
)
@pytest.fixture(scope="session")
async def setup_test_data():
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def setup_test_data(server):
"""
Set up test data for run_agent tests:
1. Create a test user
2. Create a test graph (agent input -> agent output)
3. Create a store listing and store listing version
4. Approve the store listing version
Depends on ``server`` to ensure Prisma is connected.
"""
# 1. Create a test user
user_data = {
@@ -150,14 +153,16 @@ async def setup_test_data():
}
@pytest.fixture(scope="session")
async def setup_llm_test_data():
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def setup_llm_test_data(server):
"""
Set up test data for LLM agent tests:
1. Create a test user
2. Create test OpenAI credentials for the user
3. Create a test graph with input -> LLM block -> output
4. Create and approve a store listing
Depends on ``server`` to ensure Prisma is connected.
"""
key = getenv("OPENAI_API_KEY")
if not key:
@@ -315,13 +320,15 @@ async def setup_llm_test_data():
}
@pytest.fixture(scope="session")
async def setup_firecrawl_test_data():
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def setup_firecrawl_test_data(server):
"""
Set up test data for Firecrawl agent tests (missing credentials scenario):
1. Create a test user (WITHOUT Firecrawl credentials)
2. Create a test graph with input -> Firecrawl block -> output
3. Create and approve a store listing
Depends on ``server`` to ensure Prisma is connected.
"""
# 1. Create a test user
user_data = {

View File

@@ -19,6 +19,7 @@ from .core import (
get_all_relevant_agents_for_generation,
get_library_agent_by_graph_id,
get_library_agent_by_id,
get_library_agents_by_ids,
get_library_agents_for_generation,
graph_to_json,
json_to_graph,
@@ -49,6 +50,7 @@ __all__ = [
"get_all_relevant_agents_for_generation",
"get_library_agent_by_graph_id",
"get_library_agent_by_id",
"get_library_agents_by_ids",
"get_library_agents_for_generation",
"get_user_message_for_error",
"graph_to_json",

View File

@@ -3,6 +3,7 @@
import logging
import re
import uuid
from collections.abc import Sequence
from typing import Any, NotRequired, TypedDict
from backend.data.db_accessors import graph_db, library_db, store_db
@@ -78,7 +79,7 @@ AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
def _to_dict_list(
agents: list[AgentSummary] | list[dict[str, Any]] | None,
agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
"""Convert typed agent summaries to plain dicts for external service calls."""
if agents is None:
@@ -190,6 +191,36 @@ async def get_library_agent_by_id(
get_library_agent_by_graph_id = get_library_agent_by_id
async def get_library_agents_by_ids(
user_id: str,
agent_ids: list[str],
) -> list[LibraryAgentSummary]:
"""Fetch multiple library agents by their IDs.
Args:
user_id: The user ID
agent_ids: List of agent IDs (can be graph_ids or library agent IDs)
Returns:
List of LibraryAgentSummary for found agents (silently skips not found)
"""
agents: list[LibraryAgentSummary] = []
for agent_id in agent_ids:
try:
agent = await get_library_agent_by_id(user_id, agent_id)
if agent:
agents.append(agent)
logger.debug(f"Fetched library agent by ID: {agent['name']}")
else:
logger.warning(f"Library agent not found for ID: {agent_id}")
except Exception as e:
logger.warning(f"Failed to fetch library agent {agent_id}: {e}")
continue
logger.info(f"Fetched {len(agents)}/{len(agent_ids)} library agents by ID")
return agents
async def get_library_agents_for_generation(
user_id: str,
search_query: str | None = None,
@@ -214,10 +245,17 @@ async def get_library_agents_for_generation(
Returns:
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
"""
search_term = search_query.strip() if search_query else None
if search_term and len(search_term) > 100:
raise ValueError(
f"Search query is too long ({len(search_term)} chars, max 100). "
f"Please use a shorter, more specific search term."
)
try:
response = await library_db().list_library_agents(
user_id=user_id,
search_term=search_query,
search_term=search_term,
page=1,
page_size=max_results,
include_executions=True,
@@ -271,9 +309,16 @@ async def search_marketplace_agents_for_generation(
Returns:
List of LibraryAgentSummary with full input/output schemas
"""
search_term = search_query.strip()
if len(search_term) > 100:
raise ValueError(
f"Search query is too long ({len(search_term)} chars, max 100). "
f"Please use a shorter, more specific search term."
)
try:
response = await store_db().get_store_agents(
search_query=search_query,
search_query=search_term,
page=1,
page_size=max_results,
)
@@ -424,7 +469,7 @@ def extract_search_terms_from_steps(
async def enrich_library_agents_from_steps(
user_id: str,
decomposition_result: DecompositionResult | dict[str, Any],
existing_agents: list[AgentSummary] | list[dict[str, Any]],
existing_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]],
exclude_graph_id: str | None = None,
include_marketplace: bool = True,
max_additional_results: int = 10,
@@ -448,7 +493,7 @@ async def enrich_library_agents_from_steps(
search_terms = extract_search_terms_from_steps(decomposition_result)
if not search_terms:
return existing_agents
return list(existing_agents)
existing_ids: set[str] = set()
existing_names: set[str] = set()
@@ -511,7 +556,7 @@ async def enrich_library_agents_from_steps(
async def decompose_goal(
description: str,
context: str = "",
library_agents: list[AgentSummary] | None = None,
library_agents: Sequence[AgentSummary] | None = None,
) -> DecompositionResult | None:
"""Break down a goal into steps or return clarifying questions.
@@ -539,22 +584,16 @@ async def decompose_goal(
async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
library_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams
completion notification)
task_id: Task ID for async processing (enables Redis Streams persistence
and SSE delivery)
Returns:
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Agent JSON dict, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -562,13 +601,9 @@ async def generate_agent(
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
dict(instructions), _to_dict_list(library_agents)
)
# Don't modify async response
if result and result.get("status") == "accepted":
return result
if result:
if isinstance(result, dict) and result.get("type") == "error":
return result
@@ -758,9 +793,7 @@ async def get_agent_as_json(
async def generate_agent_patch(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[AgentSummary] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
library_agents: Sequence[AgentSummary] | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
@@ -773,12 +806,10 @@ async def generate_agent_patch(
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -789,8 +820,6 @@ async def generate_agent_patch(
update_request,
current_agent,
_to_dict_list(library_agents),
operation_id,
task_id,
)

View File

@@ -102,10 +102,15 @@ async def generate_agent_dummy(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy agent JSON after a simulated delay."""
logger.info("Using dummy agent generator for generate_agent (30s delay)")
"""Return dummy agent synchronously (blocks for 30s, returns agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator (sync mode): returning agent JSON after 30s"
)
await asyncio.sleep(30)
return _generate_dummy_agent_json()
@@ -115,10 +120,16 @@ async def generate_agent_patch_dummy(
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy patched agent (returns the current agent with updated description)."""
logger.info("Using dummy agent generator for generate_agent_patch")
"""Return dummy patched agent synchronously (blocks for 30s, returns patched agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator patch (sync mode): returning patched agent after 30s"
)
await asyncio.sleep(30)
patched = current_agent.copy()
patched["description"] = (
f"{current_agent.get('description', '')} (updated: {update_request})"

View File

@@ -1,11 +1,13 @@
"""External Agent Generator service client.
This module provides a client for communicating with the external Agent Generator
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
will delegate to the external service instead of using the built-in LLM-based implementation.
microservice. All generation endpoints use async polling: submit a job (202),
then poll GET /api/jobs/{job_id} every few seconds until the result is ready.
"""
import asyncio
import logging
import time
from typing import Any
import httpx
@@ -25,22 +27,21 @@ logger = logging.getLogger(__name__)
_dummy_mode_warned = False
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
POLL_INTERVAL_SECONDS = 10.0
MAX_POLL_TIME_SECONDS = 1800.0 # 30 minutes
MAX_CONSECUTIVE_POLL_ERRORS = 5
def _create_error_response(
error_message: str,
error_type: str = "unknown",
details: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Create a standardized error response dict.
Args:
error_message: Human-readable error message
error_type: Machine-readable error type
details: Optional additional error details
Returns:
Error dict with type="error" and error details
"""
"""Create a standardized error response dict."""
response: dict[str, Any] = {
"type": "error",
"error": error_message,
@@ -52,14 +53,7 @@ def _create_error_response(
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
"""Classify an HTTP error into error_type and message.
Args:
e: The HTTP status error
Returns:
Tuple of (error_type, error_message)
"""
"""Classify an HTTP error into error_type and message."""
status = e.response.status_code
if status == 429:
return "rate_limit", f"Agent Generator rate limited: {e}"
@@ -72,14 +66,7 @@ def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
"""Classify a request error into error_type and message.
Args:
e: The request error
Returns:
Tuple of (error_type, error_message)
"""
"""Classify a request error into error_type and message."""
error_str = str(e).lower()
if "timeout" in error_str or "timed out" in error_str:
return "timeout", f"Agent Generator request timed out: {e}"
@@ -89,6 +76,10 @@ def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
return "request_error", f"Request error calling Agent Generator: {e}"
# ---------------------------------------------------------------------------
# Client / settings singletons
# ---------------------------------------------------------------------------
_client: httpx.AsyncClient | None = None
_settings: Settings | None = None
@@ -136,13 +127,149 @@ def _get_client() -> httpx.AsyncClient:
global _client
if _client is None:
settings = _get_settings()
timeout = httpx.Timeout(float(settings.config.agentgenerator_timeout))
_client = httpx.AsyncClient(
base_url=_get_base_url(),
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
timeout=timeout,
)
return _client
# ---------------------------------------------------------------------------
# Core polling helper
# ---------------------------------------------------------------------------
async def _submit_and_poll(
endpoint: str,
payload: dict[str, Any],
) -> dict[str, Any]:
"""Submit a job to the agent-generator and poll until the result is ready.
The endpoint is expected to return 202 with ``{"job_id": "..."}`` on success.
We then poll ``GET /api/jobs/{job_id}`` every ``POLL_INTERVAL_SECONDS``
until the job completes or fails.
Returns:
The *result* dict from a completed job, or an error dict.
"""
client = _get_client()
# 1. Submit ----------------------------------------------------------------
try:
response = await client.post(endpoint, json=payload)
response.raise_for_status()
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
data = response.json()
job_id = data.get("job_id")
if not job_id:
return _create_error_response(
"Agent Generator did not return a job_id", "invalid_response"
)
logger.info(f"Agent Generator job submitted: {job_id} via {endpoint}")
# 2. Poll ------------------------------------------------------------------
start = time.monotonic()
consecutive_errors = 0
while (time.monotonic() - start) < MAX_POLL_TIME_SECONDS:
await asyncio.sleep(POLL_INTERVAL_SECONDS)
try:
poll_resp = await client.get(f"/api/jobs/{job_id}")
poll_resp.raise_for_status()
except httpx.HTTPStatusError as e:
if e.response.status_code == 404:
return _create_error_response(
"Agent Generator job not found or expired", "job_not_found"
)
status_code = e.response.status_code
if status_code in {429, 503, 504, 408}:
consecutive_errors += 1
logger.warning(
f"Transient HTTP {status_code} polling job {job_id} "
f"({consecutive_errors}/{MAX_CONSECUTIVE_POLL_ERRORS}): {e}"
)
if consecutive_errors >= MAX_CONSECUTIVE_POLL_ERRORS:
error_type, error_msg = _classify_http_error(e)
logger.error(
f"Giving up on job {job_id} after "
f"{MAX_CONSECUTIVE_POLL_ERRORS} consecutive poll errors: {error_msg}"
)
return _create_error_response(error_msg, error_type)
continue
error_type, error_msg = _classify_http_error(e)
logger.error(f"Poll error for job {job_id}: {error_msg}")
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
consecutive_errors += 1
logger.warning(
f"Transient poll error for job {job_id} "
f"({consecutive_errors}/{MAX_CONSECUTIVE_POLL_ERRORS}): {e}"
)
if consecutive_errors >= MAX_CONSECUTIVE_POLL_ERRORS:
error_msg = (
f"Giving up on job {job_id} after "
f"{MAX_CONSECUTIVE_POLL_ERRORS} consecutive poll errors: {e}"
)
logger.error(error_msg)
return _create_error_response(error_msg, "poll_error")
continue
consecutive_errors = 0
poll_data = poll_resp.json()
status = poll_data.get("status")
if status == "completed":
logger.info(f"Agent Generator job {job_id} completed")
result = poll_data.get("result", {})
if not isinstance(result, dict):
return _create_error_response(
"Agent Generator returned invalid result payload",
"invalid_response",
)
return result
elif status == "failed":
error_msg = poll_data.get("error", "Job failed")
logger.error(f"Agent Generator job {job_id} failed: {error_msg}")
return _create_error_response(error_msg, "job_failed")
elif status in {"running", "pending", "queued"}:
continue
else:
return _create_error_response(
f"Agent Generator returned unexpected job status: {status}",
"invalid_response",
)
return _create_error_response("Agent generation timed out after polling", "timeout")
def _extract_agent_json(result: dict[str, Any]) -> dict[str, Any]:
"""Extract and validate agent_json from a job result.
Returns the agent_json dict, or an error response if missing/invalid.
"""
agent_json = result.get("agent_json")
if not isinstance(agent_json, dict):
return _create_error_response(
"Agent Generator returned no agent_json in result", "invalid_response"
)
return agent_json
# ---------------------------------------------------------------------------
# Public functions — same signatures as before, now using polling
# ---------------------------------------------------------------------------
async def decompose_goal_external(
description: str,
context: str = "",
@@ -150,25 +277,17 @@ async def decompose_goal_external(
) -> dict[str, Any] | None:
"""Call the external service to decompose a goal.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns one of the following dicts (keyed by ``"type"``):
Returns:
Dict with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
- {"type": "unachievable_goal", ...}
- {"type": "vague_goal", ...}
- {"type": "error", "error": "...", "error_type": "..."} on error
Or None on unexpected error
* ``{"type": "instructions", "steps": [...]}``
* ``{"type": "clarifying_questions", "questions": [...]}``
* ``{"type": "unachievable_goal", "reason": ..., "suggested_goal": ...}``
* ``{"type": "vague_goal", "suggested_goal": ...}``
* ``{"type": "error", "error": ..., "error_type": ...}``
"""
if _is_dummy_mode():
return await decompose_goal_dummy(description, context, library_agents)
client = _get_client()
if context:
description = f"{description}\n\nAdditional context from user:\n{context}"
@@ -177,236 +296,113 @@ async def decompose_goal_external(
payload["library_agents"] = library_agents
try:
response = await client.post("/api/decompose-description", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator decomposition failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Map the response to the expected format
response_type = data.get("type")
if response_type == "instructions":
return {"type": "instructions", "steps": data.get("steps", [])}
elif response_type == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
elif response_type == "unachievable_goal":
return {
"type": "unachievable_goal",
"reason": data.get("reason"),
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "vague_goal":
return {
"type": "vague_goal",
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "error":
# Pass through error from the service
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
else:
logger.error(
f"Unknown response type from external service: {response_type}"
)
return _create_error_response(
f"Unknown response type from Agent Generator: {response_type}",
"invalid_response",
)
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
result = await _submit_and_poll("/api/decompose-description", payload)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
# The result dict from the job is already in the expected format
# (type, steps, questions, etc.) — just return it as-is.
if result.get("type") == "error":
return result
response_type = result.get("type")
if response_type == "instructions":
return {"type": "instructions", "steps": result.get("steps", [])}
elif response_type == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": result.get("questions", []),
}
elif response_type == "unachievable_goal":
return {
"type": "unachievable_goal",
"reason": result.get("reason"),
"suggested_goal": result.get("suggested_goal"),
}
elif response_type == "vague_goal":
return {
"type": "vague_goal",
"suggested_goal": result.get("suggested_goal"),
}
else:
logger.error(f"Unknown response type from Agent Generator job: {response_type}")
return _create_error_response(
f"Unknown response type: {response_type}",
"invalid_response",
)
async def generate_agent_external(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
Agent JSON dict or error dict {"type": "error", ...} on error.
"""
if _is_dummy_mode():
return await generate_agent_dummy(
instructions, library_agents, operation_id, task_id
)
return await generate_agent_dummy(instructions, library_agents)
client = _get_client()
# Build request payload
payload: dict[str, Any] = {"instructions": instructions}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/generate-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
)
return _create_error_response(error_msg, error_type)
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
result = await _submit_and_poll("/api/generate-agent", payload)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
if result.get("type") == "error":
return result
return _extract_agent_json(result)
async def generate_agent_patch_external(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
Updated agent JSON, clarifying questions dict, or error dict.
"""
if _is_dummy_mode():
return await generate_agent_patch_dummy(
update_request, current_agent, library_agents, operation_id, task_id
update_request, current_agent, library_agents
)
client = _get_client()
# Build request payload
payload: dict[str, Any] = {
"update_request": update_request,
"current_agent_json": current_agent,
}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/update-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async update request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator patch generation failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the updated agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
result = await _submit_and_poll("/api/update-agent", payload)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
if result.get("type") == "error":
return result
if result.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": result.get("questions", []),
}
return _extract_agent_json(result)
async def customize_template_external(
template_agent: dict[str, Any],
@@ -415,81 +411,51 @@ async def customize_template_external(
) -> dict[str, Any] | None:
"""Call the external service to customize a template/marketplace agent.
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
Returns:
Customized agent JSON, clarifying questions dict, or error dict on error
Customized agent JSON, clarifying questions dict, or error dict.
"""
if _is_dummy_mode():
return await customize_template_dummy(
template_agent, modification_request, context
)
client = _get_client()
request = modification_request
request_text = modification_request
if context:
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
request_text = (
f"{modification_request}\n\nAdditional context from user:\n{context}"
)
payload: dict[str, Any] = {
"template_agent_json": template_agent,
"modification_request": request,
"modification_request": request_text,
}
try:
response = await client.post("/api/template-modification", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator template customization failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the customized agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
result = await _submit_and_poll("/api/template-modification", payload)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
if result.get("type") == "error":
return result
if result.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": result.get("questions", []),
}
return _extract_agent_json(result)
# ---------------------------------------------------------------------------
# Non-generation endpoints (still synchronous — quick responses)
# ---------------------------------------------------------------------------
async def get_blocks_external() -> list[dict[str, Any]] | None:
"""Get available blocks from the external service.
Returns:
List of block info dicts or None on error
"""
"""Get available blocks from the external service."""
if _is_dummy_mode():
return await get_blocks_dummy()
@@ -518,11 +484,7 @@ async def get_blocks_external() -> list[dict[str, Any]] | None:
async def health_check() -> bool:
"""Check if the external service is healthy.
Returns:
True if healthy, False otherwise
"""
"""Check if the external service is healthy."""
if not is_external_service_configured():
return False

View File

@@ -36,16 +36,6 @@ class BaseTool:
"""Whether this tool requires authentication."""
return False
@property
def is_long_running(self) -> bool:
"""Whether this tool is long-running and should execute in background.
Long-running tools (like agent generation) are executed via background
tasks to survive SSE disconnections. The result is persisted to chat
history and visible when the user refreshes.
"""
return False
def as_openai_tool(self) -> ChatCompletionToolParam:
"""Convert to OpenAI tool format."""
return ChatCompletionToolParam(

View File

@@ -1,124 +0,0 @@
"""CheckOperationStatusTool — query the status of a long-running operation."""
import logging
from typing import Any
from backend.copilot.model import ChatSession
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
class OperationStatusResponse(ToolResponseBase):
"""Response for check_operation_status tool."""
type: ResponseType = ResponseType.OPERATION_STATUS
task_id: str
operation_id: str
status: str # "running", "completed", "failed"
tool_name: str | None = None
message: str = ""
class CheckOperationStatusTool(BaseTool):
"""Check the status of a long-running operation (create_agent, edit_agent, etc.).
The CoPilot uses this tool to report back to the user whether an
operation that was started earlier has completed, failed, or is still
running.
"""
@property
def name(self) -> str:
return "check_operation_status"
@property
def description(self) -> str:
return (
"Check the current status of a long-running operation such as "
"create_agent or edit_agent. Accepts either an operation_id or "
"task_id from a previous operation_started response. "
"Returns the current status: running, completed, or failed."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"operation_id": {
"type": "string",
"description": (
"The operation_id from an operation_started response."
),
},
"task_id": {
"type": "string",
"description": (
"The task_id from an operation_started response. "
"Used as fallback if operation_id is not provided."
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
from backend.copilot import stream_registry
operation_id = (kwargs.get("operation_id") or "").strip()
task_id = (kwargs.get("task_id") or "").strip()
if not operation_id and not task_id:
return ErrorResponse(
message="Please provide an operation_id or task_id.",
error="missing_parameter",
)
task = None
if operation_id:
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None and task_id:
task = await stream_registry.get_task(task_id)
if task is None:
# Task not in Redis — it may have already expired (TTL).
# Check conversation history for the result instead.
return ErrorResponse(
message=(
"Operation not found — it may have already completed and "
"expired from the status tracker. Check the conversation "
"history for the result."
),
error="not_found",
)
status_messages = {
"running": (
f"The {task.tool_name or 'operation'} is still running. "
"Please wait for it to complete."
),
"completed": (
f"The {task.tool_name or 'operation'} has completed successfully."
),
"failed": f"The {task.tool_name or 'operation'} has failed.",
}
return OperationStatusResponse(
task_id=task.task_id,
operation_id=task.operation_id,
status=task.status,
tool_name=task.tool_name,
message=status_messages.get(task.status, f"Status: {task.status}"),
)

View File

@@ -10,7 +10,6 @@ from .agent_generator import (
decompose_goal,
enrich_library_agents_from_steps,
generate_agent,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
@@ -18,7 +17,6 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -40,17 +38,16 @@ class CreateAgentTool(BaseTool):
def description(self) -> str:
return (
"Create a new agent workflow from a natural language description. "
"First generates a preview, then saves to library if save=true."
"First generates a preview, then saves to library if save=true. "
"\n\nIMPORTANT: Before calling this tool, search for relevant existing agents "
"using find_library_agent that could be used as building blocks. "
"Pass their IDs in the library_agent_ids parameter so the generator can compose them."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
@@ -70,6 +67,15 @@ class CreateAgentTool(BaseTool):
"Include any preferences or constraints mentioned by the user."
),
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks. "
"Search for relevant agents using find_library_agent first, "
"then pass their IDs here so they can be composed into the new agent."
),
},
"save": {
"type": "boolean",
"description": (
@@ -97,12 +103,14 @@ class CreateAgentTool(BaseTool):
"""
description = kwargs.get("description", "").strip()
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
logger.info(
f"[AGENT_CREATE_DEBUG] START - description_len={len(description)}, "
f"library_agent_ids={library_agent_ids}, save={save}, user_id={user_id}, session_id={session_id}"
)
if not description:
return ErrorResponse(
@@ -111,25 +119,34 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Fetch library agents by IDs if provided
library_agents = None
if user_id:
if user_id and library_agent_ids:
try:
library_agents = await get_all_relevant_agents_for_generation(
from .agent_generator import get_library_agents_by_ids
library_agents = await get_library_agents_by_ids(
user_id=user_id,
search_query=description,
include_marketplace=True,
agent_ids=library_agent_ids,
)
logger.debug(
f"Found {len(library_agents)} relevant agents for sub-agent composition"
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
logger.warning(f"Failed to fetch library agents by IDs: {e}")
try:
decomposition_result = await decompose_goal(
description, context, library_agents
)
logger.info(
f"[AGENT_CREATE_DEBUG] DECOMPOSE - type={decomposition_result.get('type') if decomposition_result else None}, "
f"session_id={session_id}"
)
except AgentGeneratorNotConfiguredError:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured, session_id={session_id}"
)
return ErrorResponse(
message=(
"Agent generation is not available. "
@@ -230,10 +247,17 @@ class CreateAgentTool(BaseTool):
agent_json = await generate_agent(
decomposition_result,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
logger.info(
f"[AGENT_CREATE_DEBUG] GENERATE - "
f"success={agent_json is not None}, "
f"is_error={isinstance(agent_json, dict) and agent_json.get('type') == 'error'}, "
f"session_id={session_id}"
)
except AgentGeneratorNotConfiguredError:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured during generation, session_id={session_id}"
)
return ErrorResponse(
message=(
"Agent generation is not available. "
@@ -276,25 +300,20 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if agent_json.get("status") == "accepted":
logger.info(
f"Agent generation delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent generation started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
logger.info(
f"[AGENT_CREATE_DEBUG] AGENT_JSON - name={agent_name}, "
f"nodes={node_count}, links={link_count}, save={save}, session_id={session_id}"
)
if not save:
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - AgentPreviewResponse, session_id={session_id}"
)
return AgentPreviewResponse(
message=(
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
@@ -320,6 +339,13 @@ class CreateAgentTool(BaseTool):
agent_json, user_id
)
logger.info(
f"[AGENT_CREATE_DEBUG] SAVED - graph_id={created_graph.id}, "
f"library_agent_id={library_agent.id}, session_id={session_id}"
)
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - AgentSavedResponse, session_id={session_id}"
)
return AgentSavedResponse(
message=f"Agent '{created_graph.name}' has been saved to your library!",
agent_id=created_graph.id,
@@ -330,6 +356,12 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
except Exception as e:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - save_failed: {str(e)}, session_id={session_id}"
)
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - ErrorResponse (save_failed), session_id={session_id}"
)
return ErrorResponse(
message=f"Failed to save the agent: {str(e)}",
error="save_failed",

View File

@@ -43,11 +43,6 @@ async def test_vague_goal_returns_suggested_goal_response(tool, session):
}
with (
patch(
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=[],
),
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
@@ -78,11 +73,6 @@ async def test_unachievable_goal_returns_suggested_goal_response(tool, session):
}
with (
patch(
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=[],
),
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
@@ -120,11 +110,6 @@ async def test_clarifying_questions_returns_clarification_needed_response(
}
with (
patch(
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=[],
),
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,

View File

@@ -46,10 +46,6 @@ class CustomizeAgentTool(BaseTool):
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {

View File

@@ -9,7 +9,6 @@ from .agent_generator import (
AgentGeneratorNotConfiguredError,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
@@ -17,7 +16,6 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -38,17 +36,16 @@ class EditAgentTool(BaseTool):
def description(self) -> str:
return (
"Edit an existing agent from the user's library using natural language. "
"Generates updates to the agent while preserving unchanged parts."
"Generates updates to the agent while preserving unchanged parts. "
"\n\nIMPORTANT: Before calling this tool, if the changes involve adding new "
"functionality, search for relevant existing agents using find_library_agent "
"that could be used as building blocks. Pass their IDs in library_agent_ids."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
@@ -74,6 +71,15 @@ class EditAgentTool(BaseTool):
"Additional context or answers to previous clarifying questions."
),
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks for the changes. "
"If adding new functionality, search for relevant agents using "
"find_library_agent first, then pass their IDs here."
),
},
"save": {
"type": "boolean",
"description": (
@@ -102,13 +108,10 @@ class EditAgentTool(BaseTool):
agent_id = kwargs.get("agent_id", "").strip()
changes = kwargs.get("changes", "").strip()
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
@@ -132,21 +135,25 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Fetch library agents by IDs if provided
library_agents = None
if user_id:
if user_id and library_agent_ids:
try:
from .agent_generator import get_library_agents_by_ids
graph_id = current_agent.get("id")
library_agents = await get_all_relevant_agents_for_generation(
# Filter out the current agent being edited
filtered_ids = [id for id in library_agent_ids if id != graph_id]
library_agents = await get_library_agents_by_ids(
user_id=user_id,
search_query=changes,
exclude_graph_id=graph_id,
include_marketplace=True,
agent_ids=filtered_ids,
)
logger.debug(
f"Found {len(library_agents)} relevant agents for sub-agent composition"
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
logger.warning(f"Failed to fetch library agents by IDs: {e}")
update_request = changes
if context:
@@ -157,8 +164,6 @@ class EditAgentTool(BaseTool):
update_request,
current_agent,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
@@ -178,19 +183,6 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if result.get("status") == "accepted":
logger.info(
f"Agent edit delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent edit started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")

View File

@@ -366,12 +366,15 @@ class TestFindBlockFiltering:
return_value=(search_results, len(search_results))
)
with patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
), patch(
"backend.copilot.tools.find_block.get_block",
side_effect=lambda bid: mock_blocks.get(bid),
with (
patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
),
patch(
"backend.copilot.tools.find_block.get_block",
side_effect=lambda bid: mock_blocks.get(bid),
),
):
tool = FindBlockTool()
response = await tool._execute(

View File

@@ -36,8 +36,6 @@ class ResponseType(str, Enum):
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
# Long-running operation types
OPERATION_STARTED = "operation_started"
OPERATION_PENDING = "operation_pending"
OPERATION_IN_PROGRESS = "operation_in_progress"
# Input validation
INPUT_VALIDATION_ERROR = "input_validation_error"
@@ -45,8 +43,6 @@ class ResponseType(str, Enum):
WEB_FETCH = "web_fetch"
# Code execution
BASH_EXEC = "bash_exec"
# Operation status check
OPERATION_STATUS = "operation_status"
# Feature request types
FEATURE_REQUEST_SEARCH = "feature_request_search"
FEATURE_REQUEST_CREATED = "feature_request_created"
@@ -420,34 +416,6 @@ class BlockOutputResponse(ToolResponseBase):
# Long-running operation models
class OperationStartedResponse(ToolResponseBase):
"""Response when a long-running operation has been started in the background.
This is returned immediately to the client while the operation continues
to execute. The user can close the tab and check back later.
The task_id can be used to reconnect to the SSE stream via
GET /chat/tasks/{task_id}/stream?last_idx=0
"""
type: ResponseType = ResponseType.OPERATION_STARTED
operation_id: str
tool_name: str
task_id: str | None = None # For SSE reconnection
class OperationPendingResponse(ToolResponseBase):
"""Response stored in chat history while a long-running operation is executing.
This is persisted to the database so users see a pending state when they
refresh before the operation completes.
"""
type: ResponseType = ResponseType.OPERATION_PENDING
operation_id: str
tool_name: str
class OperationInProgressResponse(ToolResponseBase):
"""Response when an operation is already in progress.
@@ -459,23 +427,6 @@ class OperationInProgressResponse(ToolResponseBase):
tool_call_id: str
class AsyncProcessingResponse(ToolResponseBase):
"""Response when an operation has been delegated to async processing.
This is returned by tools when the external service accepts the request
for async processing (HTTP 202 Accepted). The Redis Streams completion
consumer will handle the result when the external service completes.
The status field is specifically "accepted" to allow the long-running tool
handler to detect this response and skip LLM continuation.
"""
type: ResponseType = ResponseType.OPERATION_STARTED
status: str = "accepted" # Must be "accepted" for detection
operation_id: str | None = None
task_id: str | None = None
class WebFetchResponse(ToolResponseBase):
"""Response for web_fetch tool."""

View File

@@ -160,9 +160,10 @@ class RunBlockTool(BaseTool):
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = (
await self._resolve_block_credentials(user_id, block, input_data)
)
(
matched_credentials,
missing_credentials,
) = await self._resolve_block_credentials(user_id, block, input_data)
# Get block schemas for details/validation
try:

View File

@@ -0,0 +1,426 @@
"""Tally form integration: cache submissions, match by email, extract business understanding."""
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Optional
from openai import AsyncOpenAI
from backend.data.redis_client import get_redis_async
from backend.data.understanding import (
BusinessUnderstandingInput,
get_business_understanding,
upsert_business_understanding,
)
from backend.util.request import Requests
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
TALLY_API_BASE = "https://api.tally.so"
_settings = Settings()
TALLY_FORM_ID = _settings.secrets.tally_form_id
# Redis key templates
_EMAIL_INDEX_KEY = "tally:form:{form_id}:email_index"
_QUESTIONS_KEY = "tally:form:{form_id}:questions"
_LAST_FETCH_KEY = "tally:form:{form_id}:last_fetch"
# TTLs — keep aligned so last_fetch never outlives the index
_INDEX_TTL = 3600 # 1 hour
_LAST_FETCH_TTL = 3600 # 1 hour (same as index)
# Pagination
_PAGE_LIMIT = 500
_MAX_PAGES = 100
# LLM extraction timeout (seconds)
_LLM_TIMEOUT = 30
def _mask_email(email: str) -> str:
"""Mask an email for safe logging: 'alice@example.com' -> 'a***e@example.com'."""
try:
local, domain = email.rsplit("@", 1)
if len(local) <= 2:
masked_local = local[0] + "***"
else:
masked_local = local[0] + "***" + local[-1]
return f"{masked_local}@{domain}"
except (ValueError, IndexError):
return "***"
async def _fetch_tally_page(
client: Requests,
form_id: str,
page: int,
limit: int = _PAGE_LIMIT,
start_date: Optional[str] = None,
) -> dict:
"""Fetch a single page of submissions from the Tally API."""
url = f"{TALLY_API_BASE}/forms/{form_id}/submissions?page={page}&limit={limit}"
if start_date:
url += f"&startDate={start_date}"
response = await client.get(url)
return response.json()
def _make_tally_client(api_key: str) -> Requests:
"""Create a Requests client configured for the Tally API."""
return Requests(
trusted_origins=[TALLY_API_BASE],
raise_for_status=True,
extra_headers={
"Authorization": f"Bearer {api_key}",
"Accept": "application/json",
},
)
async def _fetch_all_submissions(
client: Requests,
form_id: str,
start_date: Optional[str] = None,
max_pages: int = _MAX_PAGES,
) -> tuple[list[dict], list[dict]]:
"""Paginate through all Tally submissions. Returns (questions, submissions)."""
questions: list[dict] = []
all_submissions: list[dict] = []
page = 1
while True:
data = await _fetch_tally_page(client, form_id, page, start_date=start_date)
if page == 1:
questions = data.get("questions", [])
submissions = data.get("submissions", [])
all_submissions.extend(submissions)
# Tally API uses `hasMore` for pagination
has_more = data.get("hasMore", False)
if not has_more:
break
if page >= max_pages:
total = data.get("totalNumberOfSubmissionsPerFilter", {}).get("all", "?")
logger.warning(
f"Tally: hit max page cap ({max_pages}) for form {form_id}, "
f"fetched {len(all_submissions)} of {total} total submissions"
)
break
page += 1
return questions, all_submissions
def _build_email_index(
submissions: list[dict], questions: list[dict]
) -> dict[str, dict]:
"""Build an {email -> submission_data} index from submissions.
Scans question titles for email/contact fields to find the email answer.
"""
# Find question IDs that are likely email fields
email_question_ids: list[str] = []
for q in questions:
label = (q.get("label") or q.get("title") or q.get("name") or "").lower()
q_type = (q.get("type") or "").lower()
if q_type in ("input_email", "email"):
email_question_ids.append(q["id"])
elif any(kw in label for kw in ("email", "e-mail", "contact")):
email_question_ids.append(q["id"])
index: dict[str, dict] = {}
for sub in submissions:
email = _extract_email_from_submission(sub, email_question_ids)
if email:
index[email.lower()] = {
"responses": sub.get("responses", sub.get("fields", [])),
"submitted_at": sub.get("submittedAt", sub.get("createdAt", "")),
"questions": sub.get("questions", []),
}
return index
def _extract_email_from_submission(
submission: dict, email_question_ids: list[str]
) -> Optional[str]:
"""Extract email address from a submission by checking respondentEmail, then field responses."""
# Try respondent email first (Tally often includes this)
respondent_email = submission.get("respondentEmail")
if respondent_email:
return respondent_email
# Search through responses/fields for matching question IDs
responses = submission.get("responses", submission.get("fields", []))
if isinstance(responses, list):
for resp in responses:
q_id = resp.get("questionId") or resp.get("key") or resp.get("id")
if q_id in email_question_ids:
value = resp.get("value") or resp.get("answer")
if isinstance(value, str) and "@" in value:
return value
elif isinstance(responses, dict):
for q_id in email_question_ids:
value = responses.get(q_id)
if isinstance(value, str) and "@" in value:
return value
return None
async def _get_cached_index(
form_id: str,
) -> tuple[Optional[dict], Optional[list]]:
"""Return (email_index, questions) from Redis, or (None, None) on cache miss."""
redis = await get_redis_async()
index_key = _EMAIL_INDEX_KEY.format(form_id=form_id)
questions_key = _QUESTIONS_KEY.format(form_id=form_id)
raw_index = await redis.get(index_key)
raw_questions = await redis.get(questions_key)
if raw_index and raw_questions:
return json.loads(raw_index), json.loads(raw_questions)
return None, None
async def _refresh_cache(form_id: str) -> tuple[dict, list]:
"""Refresh the Tally submission cache. Uses incremental fetch when possible.
Returns (email_index, questions).
"""
settings = Settings()
client = _make_tally_client(settings.secrets.tally_api_key)
redis = await get_redis_async()
last_fetch_key = _LAST_FETCH_KEY.format(form_id=form_id)
index_key = _EMAIL_INDEX_KEY.format(form_id=form_id)
questions_key = _QUESTIONS_KEY.format(form_id=form_id)
last_fetch = await redis.get(last_fetch_key)
if last_fetch:
# Try to load existing index for incremental merge
raw_existing = await redis.get(index_key)
if raw_existing is None:
# Index expired but last_fetch still present — fall back to full fetch
logger.info("Tally: last_fetch present but index missing, doing full fetch")
questions, submissions = await _fetch_all_submissions(client, form_id)
email_index = _build_email_index(submissions, questions)
else:
# Incremental fetch: only get new submissions since last fetch
logger.info(f"Tally incremental fetch since {last_fetch}")
questions, new_submissions = await _fetch_all_submissions(
client, form_id, start_date=last_fetch
)
existing_index: dict[str, dict] = json.loads(raw_existing)
if not questions:
raw_q = await redis.get(questions_key)
if raw_q:
questions = json.loads(raw_q)
new_index = _build_email_index(new_submissions, questions)
existing_index.update(new_index)
email_index = existing_index
else:
# Full initial fetch
logger.info("Tally full initial fetch")
questions, submissions = await _fetch_all_submissions(client, form_id)
email_index = _build_email_index(submissions, questions)
# Store in Redis
now = datetime.now(timezone.utc).isoformat()
await redis.setex(index_key, _INDEX_TTL, json.dumps(email_index))
await redis.setex(questions_key, _INDEX_TTL, json.dumps(questions))
await redis.setex(last_fetch_key, _LAST_FETCH_TTL, now)
logger.info(f"Tally cache refreshed: {len(email_index)} emails indexed")
return email_index, questions
async def find_submission_by_email(
form_id: str, email: str
) -> Optional[tuple[dict, list]]:
"""Look up a Tally submission by email. Uses cache when available.
Returns (submission_data, questions) or None.
"""
email_lower = email.lower()
# Try cache first
email_index, questions = await _get_cached_index(form_id)
if email_index is not None and questions is not None:
sub = email_index.get(email_lower)
if sub is not None:
return sub, questions
return None
# Cache miss - refresh
email_index, questions = await _refresh_cache(form_id)
sub = email_index.get(email_lower)
if sub is not None:
return sub, questions
return None
def format_submission_for_llm(submission: dict, questions: list[dict]) -> str:
"""Format a submission as readable Q&A text for LLM consumption."""
# Build question ID -> title lookup
q_titles: dict[str, str] = {}
for q in questions:
q_id = q.get("id", "")
title = q.get("label") or q.get("title") or q.get("name") or f"Question {q_id}"
q_titles[q_id] = title
lines: list[str] = []
responses = submission.get("responses", [])
if isinstance(responses, list):
for resp in responses:
q_id = resp.get("questionId") or resp.get("key") or resp.get("id") or ""
title = q_titles.get(q_id, f"Question {q_id}")
value = resp.get("value") or resp.get("answer") or ""
lines.append(f"Q: {title}\nA: {_format_answer(value)}")
elif isinstance(responses, dict):
for q_id, value in responses.items():
title = q_titles.get(q_id, f"Question {q_id}")
lines.append(f"Q: {title}\nA: {_format_answer(value)}")
return "\n\n".join(lines)
def _format_answer(value: object) -> str:
"""Convert an answer value (str, list, dict, None) to a human-readable string."""
if value is None:
return "(no answer)"
if isinstance(value, list):
return ", ".join(str(v) for v in value)
if isinstance(value, dict):
parts = [f"{k}: {v}" for k, v in value.items() if v]
return "; ".join(parts) if parts else "(no answer)"
return str(value)
_EXTRACTION_PROMPT = """\
You are a business analyst. Given the following form submission data, extract structured business understanding information.
Return a JSON object with ONLY the fields that can be confidently extracted. Use null for fields that cannot be determined.
Fields:
- user_name (string): the person's name
- job_title (string): their job title
- business_name (string): company/business name
- industry (string): industry or sector
- business_size (string): company size e.g. "1-10", "11-50", "51-200"
- user_role (string): their role context e.g. "decision maker", "implementer"
- key_workflows (list of strings): key business workflows
- daily_activities (list of strings): daily activities performed
- pain_points (list of strings): current pain points
- bottlenecks (list of strings): process bottlenecks
- manual_tasks (list of strings): manual/repetitive tasks
- automation_goals (list of strings): desired automation goals
- current_software (list of strings): software/tools currently used
- existing_automation (list of strings): existing automations
- additional_notes (string): any additional context
Form data:
"""
_EXTRACTION_SUFFIX = "\n\nReturn ONLY valid JSON."
async def extract_business_understanding(
formatted_text: str,
) -> BusinessUnderstandingInput:
"""Use an LLM to extract structured business understanding from form text.
Raises on timeout or unparseable response so the caller can handle it.
"""
settings = Settings()
api_key = settings.secrets.open_router_api_key
client = AsyncOpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1")
try:
response = await asyncio.wait_for(
client.chat.completions.create(
model="openai/gpt-4o-mini",
messages=[
{
"role": "user",
"content": f"{_EXTRACTION_PROMPT}{formatted_text}{_EXTRACTION_SUFFIX}",
}
],
response_format={"type": "json_object"},
temperature=0.0,
),
timeout=_LLM_TIMEOUT,
)
except asyncio.TimeoutError:
logger.warning("Tally: LLM extraction timed out")
raise
raw = response.choices[0].message.content or "{}"
try:
data = json.loads(raw)
except json.JSONDecodeError:
logger.warning("Tally: LLM returned invalid JSON, skipping extraction")
raise
# Filter out null values before constructing
cleaned = {k: v for k, v in data.items() if v is not None}
return BusinessUnderstandingInput(**cleaned)
async def populate_understanding_from_tally(user_id: str, email: str) -> None:
"""Main orchestrator: check Tally for a matching submission and populate understanding.
Fire-and-forget safe — all exceptions are caught and logged.
"""
try:
# Check if understanding already exists (idempotency)
existing = await get_business_understanding(user_id)
if existing is not None:
logger.debug(
f"Tally: user {user_id} already has business understanding, skipping"
)
return
# Check API key is configured
settings = Settings()
if not settings.secrets.tally_api_key:
logger.debug("Tally: no API key configured, skipping")
return
# Look up submission by email
masked = _mask_email(email)
result = await find_submission_by_email(TALLY_FORM_ID, email)
if result is None:
logger.debug(f"Tally: no submission found for {masked}")
return
submission, questions = result
logger.info(f"Tally: found submission for {masked}, extracting understanding")
# Format and extract
formatted = format_submission_for_llm(submission, questions)
if not formatted.strip():
logger.warning("Tally: formatted submission was empty, skipping")
return
understanding_input = await extract_business_understanding(formatted)
# Upsert into database
await upsert_business_understanding(user_id, understanding_input)
logger.info(f"Tally: successfully populated understanding for user {user_id}")
except Exception:
logger.exception(f"Tally: error populating understanding for user {user_id}")

View File

@@ -0,0 +1,589 @@
"""Tests for backend.data.tally module."""
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.data.tally import (
_EXTRACTION_PROMPT,
_EXTRACTION_SUFFIX,
_build_email_index,
_format_answer,
_make_tally_client,
_mask_email,
_refresh_cache,
extract_business_understanding,
find_submission_by_email,
format_submission_for_llm,
populate_understanding_from_tally,
)
# ── Fixtures ──────────────────────────────────────────────────────────────────
SAMPLE_QUESTIONS = [
{"id": "q1", "label": "What is your name?", "type": "INPUT_TEXT"},
{"id": "q2", "label": "Email address", "type": "INPUT_EMAIL"},
{"id": "q3", "label": "Company name", "type": "INPUT_TEXT"},
{"id": "q4", "label": "Industry", "type": "INPUT_TEXT"},
]
SAMPLE_SUBMISSIONS = [
{
"respondentEmail": None,
"responses": [
{"questionId": "q1", "value": "Alice Smith"},
{"questionId": "q2", "value": "alice@example.com"},
{"questionId": "q3", "value": "Acme Corp"},
{"questionId": "q4", "value": "Technology"},
],
"submittedAt": "2025-01-15T10:00:00Z",
},
{
"respondentEmail": "bob@example.com",
"responses": [
{"questionId": "q1", "value": "Bob Jones"},
{"questionId": "q2", "value": "bob@example.com"},
{"questionId": "q3", "value": "Bob's Burgers"},
{"questionId": "q4", "value": "Food"},
],
"submittedAt": "2025-01-16T10:00:00Z",
},
]
# ── _build_email_index ────────────────────────────────────────────────────────
def test_build_email_index():
index = _build_email_index(SAMPLE_SUBMISSIONS, SAMPLE_QUESTIONS)
assert "alice@example.com" in index
assert "bob@example.com" in index
assert len(index) == 2
def test_build_email_index_case_insensitive():
submissions = [
{
"respondentEmail": None,
"responses": [
{"questionId": "q2", "value": "Alice@Example.COM"},
],
"submittedAt": "2025-01-15T10:00:00Z",
},
]
index = _build_email_index(submissions, SAMPLE_QUESTIONS)
assert "alice@example.com" in index
assert "Alice@Example.COM" not in index
def test_build_email_index_empty():
index = _build_email_index([], SAMPLE_QUESTIONS)
assert index == {}
def test_build_email_index_no_email_field():
questions = [{"id": "q1", "label": "Name", "type": "INPUT_TEXT"}]
submissions = [
{
"responses": [{"questionId": "q1", "value": "Alice"}],
"submittedAt": "2025-01-15T10:00:00Z",
}
]
index = _build_email_index(submissions, questions)
assert index == {}
def test_build_email_index_respondent_email():
"""respondentEmail takes precedence over field scanning."""
submissions = [
{
"respondentEmail": "direct@example.com",
"responses": [
{"questionId": "q2", "value": "field@example.com"},
],
"submittedAt": "2025-01-15T10:00:00Z",
}
]
index = _build_email_index(submissions, SAMPLE_QUESTIONS)
assert "direct@example.com" in index
assert "field@example.com" not in index
# ── format_submission_for_llm ─────────────────────────────────────────────────
def test_format_submission_for_llm():
submission = {
"responses": [
{"questionId": "q1", "value": "Alice Smith"},
{"questionId": "q3", "value": "Acme Corp"},
],
}
result = format_submission_for_llm(submission, SAMPLE_QUESTIONS)
assert "Q: What is your name?" in result
assert "A: Alice Smith" in result
assert "Q: Company name" in result
assert "A: Acme Corp" in result
def test_format_submission_for_llm_dict_responses():
submission = {
"responses": {
"q1": "Alice Smith",
"q3": "Acme Corp",
},
}
result = format_submission_for_llm(submission, SAMPLE_QUESTIONS)
assert "A: Alice Smith" in result
assert "A: Acme Corp" in result
def test_format_answer_types():
assert _format_answer(None) == "(no answer)"
assert _format_answer("hello") == "hello"
assert _format_answer(["a", "b"]) == "a, b"
assert _format_answer({"key": "val"}) == "key: val"
assert _format_answer(42) == "42"
# ── find_submission_by_email ──────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_find_submission_by_email_cache_hit():
cached_index = {
"alice@example.com": {"responses": [], "submitted_at": "2025-01-15"},
}
cached_questions = SAMPLE_QUESTIONS
with patch(
"backend.data.tally._get_cached_index",
new_callable=AsyncMock,
return_value=(cached_index, cached_questions),
) as mock_cache:
result = await find_submission_by_email("form123", "alice@example.com")
mock_cache.assert_awaited_once_with("form123")
assert result is not None
sub, questions = result
assert sub["submitted_at"] == "2025-01-15"
@pytest.mark.asyncio
async def test_find_submission_by_email_cache_miss():
refreshed_index = {
"alice@example.com": {"responses": [], "submitted_at": "2025-01-15"},
}
with (
patch(
"backend.data.tally._get_cached_index",
new_callable=AsyncMock,
return_value=(None, None),
),
patch(
"backend.data.tally._refresh_cache",
new_callable=AsyncMock,
return_value=(refreshed_index, SAMPLE_QUESTIONS),
) as mock_refresh,
):
result = await find_submission_by_email("form123", "alice@example.com")
mock_refresh.assert_awaited_once_with("form123")
assert result is not None
@pytest.mark.asyncio
async def test_find_submission_by_email_no_match():
cached_index = {
"alice@example.com": {"responses": [], "submitted_at": "2025-01-15"},
}
with patch(
"backend.data.tally._get_cached_index",
new_callable=AsyncMock,
return_value=(cached_index, SAMPLE_QUESTIONS),
):
result = await find_submission_by_email("form123", "unknown@example.com")
assert result is None
# ── populate_understanding_from_tally ─────────────────────────────────────────
@pytest.mark.asyncio
async def test_populate_understanding_skips_existing():
"""If user already has understanding, skip entirely."""
mock_understanding = MagicMock()
with (
patch(
"backend.data.tally.get_business_understanding",
new_callable=AsyncMock,
return_value=mock_understanding,
) as mock_get,
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
) as mock_find,
):
await populate_understanding_from_tally("user-1", "test@example.com")
mock_get.assert_awaited_once_with("user-1")
mock_find.assert_not_awaited()
@pytest.mark.asyncio
async def test_populate_understanding_skips_no_api_key():
"""If no Tally API key, skip gracefully."""
mock_settings = MagicMock()
mock_settings.secrets.tally_api_key = ""
with (
patch(
"backend.data.tally.get_business_understanding",
new_callable=AsyncMock,
return_value=None,
),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
) as mock_find,
):
await populate_understanding_from_tally("user-1", "test@example.com")
mock_find.assert_not_awaited()
@pytest.mark.asyncio
async def test_populate_understanding_handles_errors():
"""Must never raise, even on unexpected errors."""
with patch(
"backend.data.tally.get_business_understanding",
new_callable=AsyncMock,
side_effect=RuntimeError("DB down"),
):
# Should not raise
await populate_understanding_from_tally("user-1", "test@example.com")
@pytest.mark.asyncio
async def test_populate_understanding_full_flow():
"""Happy path: no existing understanding, finds submission, extracts, upserts."""
mock_settings = MagicMock()
mock_settings.secrets.tally_api_key = "test-key"
submission = {
"responses": [
{"questionId": "q1", "value": "Alice"},
{"questionId": "q3", "value": "Acme"},
],
}
mock_input = MagicMock()
with (
patch(
"backend.data.tally.get_business_understanding",
new_callable=AsyncMock,
return_value=None,
),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
return_value=(submission, SAMPLE_QUESTIONS),
),
patch(
"backend.data.tally.extract_business_understanding",
new_callable=AsyncMock,
return_value=mock_input,
) as mock_extract,
patch(
"backend.data.tally.upsert_business_understanding",
new_callable=AsyncMock,
) as mock_upsert,
):
await populate_understanding_from_tally("user-1", "alice@example.com")
mock_extract.assert_awaited_once()
mock_upsert.assert_awaited_once_with("user-1", mock_input)
@pytest.mark.asyncio
async def test_populate_understanding_handles_llm_timeout():
"""LLM timeout is caught and doesn't raise."""
import asyncio
mock_settings = MagicMock()
mock_settings.secrets.tally_api_key = "test-key"
submission = {
"responses": [{"questionId": "q1", "value": "Alice"}],
}
with (
patch(
"backend.data.tally.get_business_understanding",
new_callable=AsyncMock,
return_value=None,
),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
return_value=(submission, SAMPLE_QUESTIONS),
),
patch(
"backend.data.tally.extract_business_understanding",
new_callable=AsyncMock,
side_effect=asyncio.TimeoutError(),
),
patch(
"backend.data.tally.upsert_business_understanding",
new_callable=AsyncMock,
) as mock_upsert,
):
await populate_understanding_from_tally("user-1", "alice@example.com")
mock_upsert.assert_not_awaited()
# ── _mask_email ───────────────────────────────────────────────────────────────
def test_mask_email():
assert _mask_email("alice@example.com") == "a***e@example.com"
assert _mask_email("ab@example.com") == "a***@example.com"
assert _mask_email("a@example.com") == "a***@example.com"
def test_mask_email_invalid():
assert _mask_email("no-at-sign") == "***"
# ── Prompt construction (curly-brace safety) ─────────────────────────────────
def test_extraction_prompt_safe_with_curly_braces():
"""User content with curly braces must not break prompt construction.
Previously _EXTRACTION_PROMPT.format(submission_text=...) would raise
KeyError/ValueError if the user text contained { or }.
"""
text_with_braces = "Q: What tools do you use?\nA: We use {Slack} and {{Jira}}"
# This must not raise — the old .format() call would fail here
prompt = f"{_EXTRACTION_PROMPT}{text_with_braces}{_EXTRACTION_SUFFIX}"
assert text_with_braces in prompt
assert prompt.startswith("You are a business analyst.")
assert prompt.endswith("Return ONLY valid JSON.")
def test_extraction_prompt_no_format_placeholders():
"""_EXTRACTION_PROMPT must not contain Python format placeholders."""
assert "{submission_text}" not in _EXTRACTION_PROMPT
# Ensure no stray single-brace placeholders
# (double braces {{ are fine — they're literal in format strings)
import re
single_braces = re.findall(r"(?<!\{)\{[^{].*?\}(?!\})", _EXTRACTION_PROMPT)
assert single_braces == [], f"Found format placeholders: {single_braces}"
# ── extract_business_understanding ────────────────────────────────────────────
@pytest.mark.asyncio
async def test_extract_business_understanding_success():
"""Happy path: LLM returns valid JSON that maps to BusinessUnderstandingInput."""
mock_choice = MagicMock()
mock_choice.message.content = json.dumps(
{
"user_name": "Alice",
"business_name": "Acme Corp",
"industry": "Technology",
"pain_points": ["manual reporting"],
}
)
mock_response = MagicMock()
mock_response.choices = [mock_choice]
mock_client = AsyncMock()
mock_client.chat.completions.create.return_value = mock_response
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
result = await extract_business_understanding("Q: Name?\nA: Alice")
assert result.user_name == "Alice"
assert result.business_name == "Acme Corp"
assert result.industry == "Technology"
assert result.pain_points == ["manual reporting"]
@pytest.mark.asyncio
async def test_extract_business_understanding_filters_nulls():
"""Null values from LLM should be excluded from the result."""
mock_choice = MagicMock()
mock_choice.message.content = json.dumps(
{"user_name": "Alice", "business_name": None, "industry": None}
)
mock_response = MagicMock()
mock_response.choices = [mock_choice]
mock_client = AsyncMock()
mock_client.chat.completions.create.return_value = mock_response
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
result = await extract_business_understanding("Q: Name?\nA: Alice")
assert result.user_name == "Alice"
assert result.business_name is None
assert result.industry is None
@pytest.mark.asyncio
async def test_extract_business_understanding_invalid_json():
"""Invalid JSON from LLM should raise JSONDecodeError."""
mock_choice = MagicMock()
mock_choice.message.content = "not valid json {"
mock_response = MagicMock()
mock_response.choices = [mock_choice]
mock_client = AsyncMock()
mock_client.chat.completions.create.return_value = mock_response
with (
patch("backend.data.tally.AsyncOpenAI", return_value=mock_client),
pytest.raises(json.JSONDecodeError),
):
await extract_business_understanding("Q: Name?\nA: Alice")
@pytest.mark.asyncio
async def test_extract_business_understanding_timeout():
"""LLM timeout should propagate as asyncio.TimeoutError."""
mock_client = AsyncMock()
mock_client.chat.completions.create.side_effect = asyncio.TimeoutError()
with (
patch("backend.data.tally.AsyncOpenAI", return_value=mock_client),
patch("backend.data.tally._LLM_TIMEOUT", 0.001),
pytest.raises(asyncio.TimeoutError),
):
await extract_business_understanding("Q: Name?\nA: Alice")
# ── _refresh_cache ───────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_refresh_cache_full_fetch():
"""First fetch (no last_fetch in Redis) should do a full fetch and store in Redis."""
mock_settings = MagicMock()
mock_settings.secrets.tally_api_key = "test-key"
mock_redis = AsyncMock()
mock_redis.get.return_value = None # No last_fetch, no cached index
questions = SAMPLE_QUESTIONS
submissions = SAMPLE_SUBMISSIONS
with (
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
),
patch(
"backend.data.tally._fetch_all_submissions",
new_callable=AsyncMock,
return_value=(questions, submissions),
) as mock_fetch,
):
index, returned_questions = await _refresh_cache("form123")
mock_fetch.assert_awaited_once()
assert "alice@example.com" in index
assert "bob@example.com" in index
assert returned_questions == questions
# Verify Redis setex was called for index, questions, and last_fetch
assert mock_redis.setex.await_count == 3
@pytest.mark.asyncio
async def test_refresh_cache_incremental_fetch():
"""When last_fetch and index both exist, should do incremental fetch and merge."""
mock_settings = MagicMock()
mock_settings.secrets.tally_api_key = "test-key"
existing_index = {
"old@example.com": {"responses": [], "submitted_at": "2025-01-01"}
}
mock_redis = AsyncMock()
def mock_get(key):
if "last_fetch" in key:
return "2025-01-14T00:00:00Z"
if "email_index" in key:
return json.dumps(existing_index)
if "questions" in key:
return json.dumps(SAMPLE_QUESTIONS)
return None
mock_redis.get.side_effect = mock_get
new_submissions = [SAMPLE_SUBMISSIONS[0]] # Just Alice
with (
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
),
patch(
"backend.data.tally._fetch_all_submissions",
new_callable=AsyncMock,
return_value=(SAMPLE_QUESTIONS, new_submissions),
),
):
index, _ = await _refresh_cache("form123")
# Should contain both old and new entries
assert "old@example.com" in index
assert "alice@example.com" in index
# ── _make_tally_client ───────────────────────────────────────────────────────
def test_make_tally_client_returns_configured_client():
"""_make_tally_client should create a Requests client with auth headers."""
client = _make_tally_client("test-api-key")
assert client.extra_headers is not None
assert client.extra_headers.get("Authorization") == "Bearer test-api-key"
@pytest.mark.asyncio
async def test_fetch_tally_page_uses_provided_client():
"""_fetch_tally_page should use the passed client, not create its own."""
from backend.data.tally import _fetch_tally_page
mock_response = MagicMock()
mock_response.json.return_value = {"submissions": [], "questions": []}
mock_client = AsyncMock()
mock_client.get.return_value = mock_response
result = await _fetch_tally_page(mock_client, "form123", page=1)
mock_client.get.assert_awaited_once()
call_url = mock_client.get.call_args[0][0]
assert "form123" in call_url
assert "page=1" in call_url
assert result == {"submissions": [], "questions": []}

View File

@@ -372,8 +372,8 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The port for the Agent Generator service",
)
agentgenerator_timeout: int = Field(
default=600,
description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)",
default=30,
description="The timeout in seconds for individual Agent Generator HTTP requests (submit and poll)",
)
agentgenerator_use_dummy: bool = Field(
default=False,
@@ -691,6 +691,15 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
screenshotone_api_key: str = Field(default="", description="ScreenshotOne API Key")
tally_api_key: str = Field(
default="",
description="Tally API key for form submission lookup on signup",
)
tally_form_id: str = Field(
default="npGe0q",
description="Tally form ID for signup business understanding form",
)
apollo_api_key: str = Field(default="", description="Apollo API Key")
smartlead_api_key: str = Field(default="", description="SmartLead API Key")
zerobounce_api_key: str = Field(default="", description="ZeroBounce API Key")

View File

@@ -109,7 +109,7 @@ class TestGenerateAgent:
instructions = {"type": "instructions", "steps": ["Step 1"]}
result = await core.generate_agent(instructions)
mock_external.assert_called_once_with(instructions, None, None, None)
mock_external.assert_called_once_with(instructions, None)
assert result is not None
assert result["name"] == "Test Agent"
assert "id" in result
@@ -173,9 +173,7 @@ class TestGenerateAgentPatch:
current_agent = {"nodes": [], "links": []}
result = await core.generate_agent_patch("Add a node", current_agent)
mock_external.assert_called_once_with(
"Add a node", current_agent, None, None, None
)
mock_external.assert_called_once_with("Add a node", current_agent, None)
assert result == expected_result
@pytest.mark.asyncio

View File

@@ -2,7 +2,7 @@
Tests for the Agent Generator external service client.
This test suite verifies the external Agent Generator service integration,
including service detection, API calls, and error handling.
including service detection, async polling, and error handling.
"""
from unittest.mock import AsyncMock, MagicMock, patch
@@ -49,6 +49,292 @@ class TestServiceConfiguration:
assert url == "http://agent-generator.local:8000"
class TestSubmitAndPoll:
"""Test the _submit_and_poll helper that handles async job polling."""
def setup_method(self):
service._settings = None
service._client = None
@pytest.mark.asyncio
async def test_successful_submit_and_poll(self):
"""Test normal submit -> poll -> completed flow."""
submit_resp = MagicMock()
submit_resp.json.return_value = {"job_id": "job-123", "status": "accepted"}
submit_resp.raise_for_status = MagicMock()
poll_resp = MagicMock()
poll_resp.json.return_value = {
"job_id": "job-123",
"status": "completed",
"result": {"type": "instructions", "steps": ["Step 1"]},
}
poll_resp.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = submit_resp
mock_client.get.return_value = poll_resp
with (
patch.object(service, "_get_client", return_value=mock_client),
patch("asyncio.sleep", new_callable=AsyncMock),
):
result = await service._submit_and_poll("/api/test", {"key": "value"})
assert result == {"type": "instructions", "steps": ["Step 1"]}
mock_client.post.assert_called_once_with("/api/test", json={"key": "value"})
mock_client.get.assert_called_once_with("/api/jobs/job-123")
@pytest.mark.asyncio
async def test_poll_returns_failed_job(self):
"""Test submit -> poll -> failed flow."""
submit_resp = MagicMock()
submit_resp.json.return_value = {"job_id": "job-456", "status": "accepted"}
submit_resp.raise_for_status = MagicMock()
poll_resp = MagicMock()
poll_resp.json.return_value = {
"job_id": "job-456",
"status": "failed",
"error": "Generation failed",
}
poll_resp.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = submit_resp
mock_client.get.return_value = poll_resp
with (
patch.object(service, "_get_client", return_value=mock_client),
patch("asyncio.sleep", new_callable=AsyncMock),
):
result = await service._submit_and_poll("/api/test", {})
assert result["type"] == "error"
assert result["error_type"] == "job_failed"
assert "Generation failed" in result["error"]
@pytest.mark.asyncio
async def test_submit_http_error(self):
"""Test HTTP error during job submission."""
mock_response = MagicMock()
mock_response.status_code = 500
mock_client = AsyncMock()
mock_client.post.side_effect = httpx.HTTPStatusError(
"Server error", request=MagicMock(), response=mock_response
)
with patch.object(service, "_get_client", return_value=mock_client):
result = await service._submit_and_poll("/api/test", {})
assert result["type"] == "error"
assert result["error_type"] == "http_error"
@pytest.mark.asyncio
async def test_submit_connection_error(self):
"""Test connection error during job submission."""
mock_client = AsyncMock()
mock_client.post.side_effect = httpx.RequestError("Connection failed")
with patch.object(service, "_get_client", return_value=mock_client):
result = await service._submit_and_poll("/api/test", {})
assert result["type"] == "error"
assert result["error_type"] == "connection_error"
@pytest.mark.asyncio
async def test_no_job_id_in_submit_response(self):
"""Test submit response missing job_id."""
submit_resp = MagicMock()
submit_resp.json.return_value = {"status": "accepted"} # no job_id
submit_resp.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = submit_resp
with patch.object(service, "_get_client", return_value=mock_client):
result = await service._submit_and_poll("/api/test", {})
assert result["type"] == "error"
assert result["error_type"] == "invalid_response"
@pytest.mark.asyncio
async def test_poll_retries_on_transient_network_error(self):
"""Test that transient network errors during polling are retried."""
submit_resp = MagicMock()
submit_resp.json.return_value = {"job_id": "job-789"}
submit_resp.raise_for_status = MagicMock()
ok_poll_resp = MagicMock()
ok_poll_resp.json.return_value = {
"job_id": "job-789",
"status": "completed",
"result": {"data": "ok"},
}
ok_poll_resp.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = submit_resp
# First poll fails with transient error, second succeeds
mock_client.get.side_effect = [
httpx.RequestError("transient"),
ok_poll_resp,
]
with (
patch.object(service, "_get_client", return_value=mock_client),
patch("asyncio.sleep", new_callable=AsyncMock),
):
result = await service._submit_and_poll("/api/test", {})
assert result == {"data": "ok"}
assert mock_client.get.call_count == 2
@pytest.mark.asyncio
async def test_poll_returns_404_for_expired_job(self):
"""Test that 404 during polling returns job_not_found error."""
submit_resp = MagicMock()
submit_resp.json.return_value = {"job_id": "job-expired"}
submit_resp.raise_for_status = MagicMock()
mock_404_response = MagicMock()
mock_404_response.status_code = 404
mock_client = AsyncMock()
mock_client.post.return_value = submit_resp
mock_client.get.side_effect = httpx.HTTPStatusError(
"Not Found", request=MagicMock(), response=mock_404_response
)
with (
patch.object(service, "_get_client", return_value=mock_client),
patch("asyncio.sleep", new_callable=AsyncMock),
):
result = await service._submit_and_poll("/api/test", {})
assert result["type"] == "error"
assert result["error_type"] == "job_not_found"
@pytest.mark.asyncio
async def test_poll_retries_on_transient_http_status(self):
"""Test that transient HTTP status codes (429, 503, etc.) are retried."""
submit_resp = MagicMock()
submit_resp.json.return_value = {"job_id": "job-transient"}
submit_resp.raise_for_status = MagicMock()
mock_429_response = MagicMock()
mock_429_response.status_code = 429
ok_poll_resp = MagicMock()
ok_poll_resp.json.return_value = {
"job_id": "job-transient",
"status": "completed",
"result": {"data": "recovered"},
}
ok_poll_resp.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = submit_resp
mock_client.get.side_effect = [
httpx.HTTPStatusError(
"Too Many Requests", request=MagicMock(), response=mock_429_response
),
ok_poll_resp,
]
with (
patch.object(service, "_get_client", return_value=mock_client),
patch("asyncio.sleep", new_callable=AsyncMock),
):
result = await service._submit_and_poll("/api/test", {})
assert result == {"data": "recovered"}
assert mock_client.get.call_count == 2
@pytest.mark.asyncio
async def test_poll_does_not_retry_non_transient_http_status(self):
"""Test that non-transient HTTP status codes (e.g. 500) fail immediately."""
submit_resp = MagicMock()
submit_resp.json.return_value = {"job_id": "job-500"}
submit_resp.raise_for_status = MagicMock()
mock_500_response = MagicMock()
mock_500_response.status_code = 500
mock_client = AsyncMock()
mock_client.post.return_value = submit_resp
mock_client.get.side_effect = httpx.HTTPStatusError(
"Internal Server Error", request=MagicMock(), response=mock_500_response
)
with (
patch.object(service, "_get_client", return_value=mock_client),
patch("asyncio.sleep", new_callable=AsyncMock),
):
result = await service._submit_and_poll("/api/test", {})
assert result["type"] == "error"
assert result["error_type"] == "http_error"
assert mock_client.get.call_count == 1
@pytest.mark.asyncio
async def test_poll_timeout(self):
"""Test that polling times out after MAX_POLL_TIME_SECONDS."""
submit_resp = MagicMock()
submit_resp.json.return_value = {"job_id": "job-slow"}
submit_resp.raise_for_status = MagicMock()
running_resp = MagicMock()
running_resp.json.return_value = {"job_id": "job-slow", "status": "running"}
running_resp.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = submit_resp
mock_client.get.return_value = running_resp
# Simulate time passing: first call returns 0.0 (start), then jumps past limit
monotonic_values = iter([0.0, 0.0, 100.0])
with (
patch.object(service, "_get_client", return_value=mock_client),
patch.object(service, "MAX_POLL_TIME_SECONDS", 50.0),
patch.object(service, "POLL_INTERVAL_SECONDS", 0.01),
patch("asyncio.sleep", new_callable=AsyncMock),
patch("backend.copilot.tools.agent_generator.service.time") as mock_time,
):
mock_time.monotonic.side_effect = monotonic_values
result = await service._submit_and_poll("/api/test", {})
assert result["type"] == "error"
assert result["error_type"] == "timeout"
@pytest.mark.asyncio
async def test_poll_gives_up_after_consecutive_transient_errors(self):
"""Test that polling gives up after MAX_CONSECUTIVE_POLL_ERRORS."""
submit_resp = MagicMock()
submit_resp.json.return_value = {"job_id": "job-flaky"}
submit_resp.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = submit_resp
mock_client.get.side_effect = httpx.RequestError("network down")
# Ensure monotonic always returns 0 so timeout doesn't kick in
with (
patch.object(service, "_get_client", return_value=mock_client),
patch.object(service, "MAX_POLL_TIME_SECONDS", 9999.0),
patch.object(service, "POLL_INTERVAL_SECONDS", 0.01),
patch("asyncio.sleep", new_callable=AsyncMock),
patch("backend.copilot.tools.agent_generator.service.time") as mock_time,
):
mock_time.monotonic.return_value = 0.0
result = await service._submit_and_poll("/api/test", {})
assert result["type"] == "error"
assert result["error_type"] == "poll_error"
assert mock_client.get.call_count == service.MAX_CONSECUTIVE_POLL_ERRORS
class TestDecomposeGoalExternal:
"""Test decompose_goal_external function."""
@@ -60,40 +346,37 @@ class TestDecomposeGoalExternal:
@pytest.mark.asyncio
async def test_decompose_goal_returns_instructions(self):
"""Test successful decomposition returning instructions."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "instructions",
"steps": ["Step 1", "Step 2"],
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(
service, "_submit_and_poll", new_callable=AsyncMock
) as mock_poll,
):
mock_poll.return_value = {
"type": "instructions",
"steps": ["Step 1", "Step 2"],
}
result = await service.decompose_goal_external("Build a chatbot")
assert result == {"type": "instructions", "steps": ["Step 1", "Step 2"]}
mock_client.post.assert_called_once_with(
"/api/decompose-description", json={"description": "Build a chatbot"}
mock_poll.assert_called_once_with(
"/api/decompose-description",
{"description": "Build a chatbot"},
)
@pytest.mark.asyncio
async def test_decompose_goal_returns_clarifying_questions(self):
"""Test decomposition returning clarifying questions."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "clarifying_questions",
"questions": ["What platform?", "What language?"],
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(
service, "_submit_and_poll", new_callable=AsyncMock
) as mock_poll,
):
mock_poll.return_value = {
"type": "clarifying_questions",
"questions": ["What platform?", "What language?"],
}
result = await service.decompose_goal_external("Build something")
assert result == {
@@ -104,18 +387,13 @@ class TestDecomposeGoalExternal:
@pytest.mark.asyncio
async def test_decompose_goal_with_context(self):
"""Test decomposition with additional context enriched into description."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "instructions",
"steps": ["Step 1"],
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(
service, "_submit_and_poll", new_callable=AsyncMock
) as mock_poll,
):
mock_poll.return_value = {"type": "instructions", "steps": ["Step 1"]}
await service.decompose_goal_external(
"Build a chatbot", context="Use Python"
)
@@ -123,27 +401,25 @@ class TestDecomposeGoalExternal:
expected_description = (
"Build a chatbot\n\nAdditional context from user:\nUse Python"
)
mock_client.post.assert_called_once_with(
mock_poll.assert_called_once_with(
"/api/decompose-description",
json={"description": expected_description},
{"description": expected_description},
)
@pytest.mark.asyncio
async def test_decompose_goal_returns_unachievable_goal(self):
"""Test decomposition returning unachievable goal response."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "unachievable_goal",
"reason": "Cannot do X",
"suggested_goal": "Try Y instead",
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(
service, "_submit_and_poll", new_callable=AsyncMock
) as mock_poll,
):
mock_poll.return_value = {
"type": "unachievable_goal",
"reason": "Cannot do X",
"suggested_goal": "Try Y instead",
}
result = await service.decompose_goal_external("Do something impossible")
assert result == {
@@ -153,58 +429,40 @@ class TestDecomposeGoalExternal:
}
@pytest.mark.asyncio
async def test_decompose_goal_handles_http_error(self):
"""Test decomposition handles HTTP errors gracefully."""
mock_response = MagicMock()
mock_response.status_code = 500
mock_client = AsyncMock()
mock_client.post.side_effect = httpx.HTTPStatusError(
"Server error", request=MagicMock(), response=mock_response
)
with patch.object(service, "_get_client", return_value=mock_client):
async def test_decompose_goal_handles_poll_error(self):
"""Test that errors from _submit_and_poll are passed through."""
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(
service, "_submit_and_poll", new_callable=AsyncMock
) as mock_poll,
):
mock_poll.return_value = {
"type": "error",
"error": "HTTP error calling Agent Generator: Server error",
"error_type": "http_error",
}
result = await service.decompose_goal_external("Build a chatbot")
assert result is not None
assert result.get("type") == "error"
assert result.get("error_type") == "http_error"
assert "Server error" in result.get("error", "")
@pytest.mark.asyncio
async def test_decompose_goal_handles_request_error(self):
"""Test decomposition handles request errors gracefully."""
mock_client = AsyncMock()
mock_client.post.side_effect = httpx.RequestError("Connection failed")
with patch.object(service, "_get_client", return_value=mock_client):
async def test_decompose_goal_handles_unexpected_exception(self):
"""Test that unexpected exceptions are caught and returned as errors."""
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(
service, "_submit_and_poll", new_callable=AsyncMock
) as mock_poll,
):
mock_poll.side_effect = RuntimeError("unexpected")
result = await service.decompose_goal_external("Build a chatbot")
assert result is not None
assert result.get("type") == "error"
assert result.get("error_type") == "connection_error"
assert "Connection failed" in result.get("error", "")
@pytest.mark.asyncio
async def test_decompose_goal_handles_service_error(self):
"""Test decomposition handles service returning error."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": False,
"error": "Internal error",
"error_type": "internal_error",
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.decompose_goal_external("Build a chatbot")
assert result is not None
assert result.get("type") == "error"
assert result.get("error") == "Internal error"
assert result.get("error_type") == "internal_error"
assert result.get("error_type") == "unexpected_error"
class TestGenerateAgentExternal:
@@ -223,39 +481,59 @@ class TestGenerateAgentExternal:
"nodes": [],
"links": [],
}
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"agent_json": agent_json,
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(
service, "_submit_and_poll", new_callable=AsyncMock
) as mock_poll,
):
mock_poll.return_value = {"success": True, "agent_json": agent_json}
instructions = {"type": "instructions", "steps": ["Step 1"]}
with patch.object(service, "_get_client", return_value=mock_client):
instructions = {"type": "instructions", "steps": ["Step 1"]}
result = await service.generate_agent_external(instructions)
assert result == agent_json
mock_client.post.assert_called_once_with(
"/api/generate-agent", json={"instructions": instructions}
mock_poll.assert_called_once_with(
"/api/generate-agent",
{"instructions": instructions},
)
@pytest.mark.asyncio
async def test_generate_agent_handles_error(self):
"""Test agent generation handles errors gracefully."""
mock_client = AsyncMock()
mock_client.post.side_effect = httpx.RequestError("Connection failed")
with patch.object(service, "_get_client", return_value=mock_client):
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(
service, "_submit_and_poll", new_callable=AsyncMock
) as mock_poll,
):
mock_poll.return_value = {
"type": "error",
"error": "Connection failed",
"error_type": "connection_error",
}
result = await service.generate_agent_external({"steps": []})
assert result is not None
assert result.get("type") == "error"
assert result.get("error_type") == "connection_error"
assert "Connection failed" in result.get("error", "")
@pytest.mark.asyncio
async def test_generate_agent_missing_agent_json(self):
"""Test that missing agent_json in result returns an error."""
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(
service, "_submit_and_poll", new_callable=AsyncMock
) as mock_poll,
):
mock_poll.return_value = {"success": True}
result = await service.generate_agent_external({"steps": ["Step 1"]})
assert result is not None
assert result.get("type") == "error"
assert result.get("error_type") == "invalid_response"
class TestGenerateAgentPatchExternal:
@@ -274,27 +552,24 @@ class TestGenerateAgentPatchExternal:
"nodes": [{"id": "1", "block_id": "test"}],
"links": [],
}
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"agent_json": updated_agent,
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(
service, "_submit_and_poll", new_callable=AsyncMock
) as mock_poll,
):
mock_poll.return_value = {"success": True, "agent_json": updated_agent}
current_agent = {"name": "Old Agent", "nodes": [], "links": []}
with patch.object(service, "_get_client", return_value=mock_client):
current_agent = {"name": "Old Agent", "nodes": [], "links": []}
result = await service.generate_agent_patch_external(
"Add a new node", current_agent
)
assert result == updated_agent
mock_client.post.assert_called_once_with(
mock_poll.assert_called_once_with(
"/api/update-agent",
json={
{
"update_request": "Add a new node",
"current_agent_json": current_agent,
},
@@ -303,18 +578,16 @@ class TestGenerateAgentPatchExternal:
@pytest.mark.asyncio
async def test_generate_patch_returns_clarifying_questions(self):
"""Test patch generation returning clarifying questions."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "clarifying_questions",
"questions": ["What type of node?"],
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(
service, "_submit_and_poll", new_callable=AsyncMock
) as mock_poll,
):
mock_poll.return_value = {
"type": "clarifying_questions",
"questions": ["What type of node?"],
}
result = await service.generate_agent_patch_external(
"Add something", {"nodes": []}
)
@@ -355,9 +628,12 @@ class TestHealthCheck:
mock_client = AsyncMock()
mock_client.get.return_value = mock_response
with patch.object(service, "is_external_service_configured", return_value=True):
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.health_check()
with (
patch.object(service, "is_external_service_configured", return_value=True),
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(service, "_get_client", return_value=mock_client),
):
result = await service.health_check()
assert result is True
mock_client.get.assert_called_once_with("/health")
@@ -375,9 +651,12 @@ class TestHealthCheck:
mock_client = AsyncMock()
mock_client.get.return_value = mock_response
with patch.object(service, "is_external_service_configured", return_value=True):
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.health_check()
with (
patch.object(service, "is_external_service_configured", return_value=True),
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(service, "_get_client", return_value=mock_client),
):
result = await service.health_check()
assert result is False
@@ -387,9 +666,12 @@ class TestHealthCheck:
mock_client = AsyncMock()
mock_client.get.side_effect = httpx.RequestError("Connection failed")
with patch.object(service, "is_external_service_configured", return_value=True):
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.health_check()
with (
patch.object(service, "is_external_service_configured", return_value=True),
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(service, "_get_client", return_value=mock_client),
):
result = await service.health_check()
assert result is False
@@ -419,7 +701,10 @@ class TestGetBlocksExternal:
mock_client = AsyncMock()
mock_client.get.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(service, "_get_client", return_value=mock_client),
):
result = await service.get_blocks_external()
assert result == blocks
@@ -431,7 +716,10 @@ class TestGetBlocksExternal:
mock_client = AsyncMock()
mock_client.get.side_effect = httpx.RequestError("Connection failed")
with patch.object(service, "_get_client", return_value=mock_client):
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(service, "_get_client", return_value=mock_client),
):
result = await service.get_blocks_external()
assert result is None
@@ -459,26 +747,22 @@ class TestLibraryAgentsPassthrough:
},
]
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "instructions",
"steps": ["Step 1"],
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(
service, "_submit_and_poll", new_callable=AsyncMock
) as mock_poll,
):
mock_poll.return_value = {"type": "instructions", "steps": ["Step 1"]}
await service.decompose_goal_external(
"Send an email",
library_agents=library_agents,
)
# Verify library_agents was passed in the payload
call_args = mock_client.post.call_args
assert call_args[1]["json"]["library_agents"] == library_agents
call_args = mock_poll.call_args
payload = call_args[0][1]
assert payload["library_agents"] == library_agents
@pytest.mark.asyncio
async def test_generate_agent_passes_library_agents(self):
@@ -494,25 +778,24 @@ class TestLibraryAgentsPassthrough:
},
]
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"agent_json": {"name": "Test Agent", "nodes": []},
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(
service, "_submit_and_poll", new_callable=AsyncMock
) as mock_poll,
):
mock_poll.return_value = {
"agent_json": {"name": "Test Agent", "nodes": []},
}
await service.generate_agent_external(
{"steps": ["Step 1"]},
library_agents=library_agents,
)
# Verify library_agents was passed in the payload
call_args = mock_client.post.call_args
assert call_args[1]["json"]["library_agents"] == library_agents
call_args = mock_poll.call_args
payload = call_args[0][1]
assert payload["library_agents"] == library_agents
@pytest.mark.asyncio
async def test_generate_agent_patch_passes_library_agents(self):
@@ -528,17 +811,15 @@ class TestLibraryAgentsPassthrough:
},
]
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"agent_json": {"name": "Updated Agent", "nodes": []},
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(
service, "_submit_and_poll", new_callable=AsyncMock
) as mock_poll,
):
mock_poll.return_value = {
"agent_json": {"name": "Updated Agent", "nodes": []},
}
await service.generate_agent_patch_external(
"Add error handling",
{"name": "Original Agent", "nodes": []},
@@ -546,29 +827,26 @@ class TestLibraryAgentsPassthrough:
)
# Verify library_agents was passed in the payload
call_args = mock_client.post.call_args
assert call_args[1]["json"]["library_agents"] == library_agents
call_args = mock_poll.call_args
payload = call_args[0][1]
assert payload["library_agents"] == library_agents
@pytest.mark.asyncio
async def test_decompose_goal_without_library_agents(self):
"""Test that decompose goal works without library_agents."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "instructions",
"steps": ["Step 1"],
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
with (
patch.object(service, "_is_dummy_mode", return_value=False),
patch.object(
service, "_submit_and_poll", new_callable=AsyncMock
) as mock_poll,
):
mock_poll.return_value = {"type": "instructions", "steps": ["Step 1"]}
await service.decompose_goal_external("Build a workflow")
# Verify library_agents was NOT passed when not provided
call_args = mock_client.post.call_args
assert "library_agents" not in call_args[1]["json"]
call_args = mock_poll.call_args
payload = call_args[0][1]
assert "library_agents" not in payload
if __name__ == "__main__":

View File

@@ -1,349 +0,0 @@
#!/usr/bin/env python3
"""
Integration test for the requeue fix implementation.
Tests actual RabbitMQ behavior to verify that republishing sends messages to back of queue.
"""
import json
import time
from threading import Event
from typing import List
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.utils import create_execution_queue_config
class QueueOrderTester:
"""Helper class to test message ordering in RabbitMQ using a dedicated test queue."""
def __init__(self):
self.received_messages: List[dict] = []
self.stop_consuming = Event()
self.queue_client = SyncRabbitMQ(create_execution_queue_config())
self.queue_client.connect()
# Use a dedicated test queue name to avoid conflicts
self.test_queue_name = "test_requeue_ordering"
self.test_exchange = "test_exchange"
self.test_routing_key = "test.requeue"
def setup_queue(self):
"""Set up a dedicated test queue for testing."""
channel = self.queue_client.get_channel()
# Declare test exchange
channel.exchange_declare(
exchange=self.test_exchange, exchange_type="direct", durable=True
)
# Declare test queue
channel.queue_declare(
queue=self.test_queue_name, durable=True, auto_delete=False
)
# Bind queue to exchange
channel.queue_bind(
exchange=self.test_exchange,
queue=self.test_queue_name,
routing_key=self.test_routing_key,
)
# Purge the queue to start fresh
channel.queue_purge(self.test_queue_name)
print(f"✅ Test queue {self.test_queue_name} setup and purged")
def create_test_message(self, message_id: str, user_id: str = "test-user") -> str:
"""Create a test graph execution message."""
return json.dumps(
{
"graph_exec_id": f"exec-{message_id}",
"graph_id": f"graph-{message_id}",
"user_id": user_id,
"execution_context": {"timezone": "UTC"},
"nodes_input_masks": {},
"starting_nodes_input": [],
}
)
def publish_message(self, message: str):
"""Publish a message to the test queue."""
channel = self.queue_client.get_channel()
channel.basic_publish(
exchange=self.test_exchange,
routing_key=self.test_routing_key,
body=message,
)
def consume_messages(self, max_messages: int = 10, timeout: float = 5.0):
"""Consume messages and track their order."""
def callback(ch, method, properties, body):
try:
message_data = json.loads(body.decode())
self.received_messages.append(message_data)
ch.basic_ack(delivery_tag=method.delivery_tag)
if len(self.received_messages) >= max_messages:
self.stop_consuming.set()
except Exception as e:
print(f"Error processing message: {e}")
ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
# Use synchronous consumption with blocking
channel = self.queue_client.get_channel()
# Check if there are messages in the queue first
method_frame, header_frame, body = channel.basic_get(
queue=self.test_queue_name, auto_ack=False
)
if method_frame:
# There are messages, set up consumer
channel.basic_nack(
delivery_tag=method_frame.delivery_tag, requeue=True
) # Put message back
# Set up consumer
channel.basic_consume(
queue=self.test_queue_name,
on_message_callback=callback,
)
# Consume with timeout
start_time = time.time()
while (
not self.stop_consuming.is_set()
and (time.time() - start_time) < timeout
and len(self.received_messages) < max_messages
):
try:
channel.connection.process_data_events(time_limit=0.1)
except Exception as e:
print(f"Error during consumption: {e}")
break
# Cancel the consumer
try:
channel.cancel()
except Exception:
pass
else:
# No messages in queue - this might be expected for some tests
pass
return self.received_messages
def cleanup(self):
"""Clean up test resources."""
try:
channel = self.queue_client.get_channel()
channel.queue_delete(queue=self.test_queue_name)
channel.exchange_delete(exchange=self.test_exchange)
print(f"✅ Test queue {self.test_queue_name} cleaned up")
except Exception as e:
print(f"⚠️ Cleanup issue: {e}")
def test_queue_ordering_behavior():
"""
Integration test to verify that our republishing method sends messages to back of queue.
This tests the actual fix for the rate limiting queue blocking issue.
"""
tester = QueueOrderTester()
try:
tester.setup_queue()
print("🧪 Testing actual RabbitMQ queue ordering behavior...")
# Test 1: Normal FIFO behavior
print("1. Testing normal FIFO queue behavior")
# Publish messages in order: A, B, C
msg_a = tester.create_test_message("A")
msg_b = tester.create_test_message("B")
msg_c = tester.create_test_message("C")
tester.publish_message(msg_a)
tester.publish_message(msg_b)
tester.publish_message(msg_c)
# Consume and verify FIFO order: A, B, C
tester.received_messages = []
tester.stop_consuming.clear()
messages = tester.consume_messages(max_messages=3)
assert len(messages) == 3, f"Expected 3 messages, got {len(messages)}"
assert (
messages[0]["graph_exec_id"] == "exec-A"
), f"First message should be A, got {messages[0]['graph_exec_id']}"
assert (
messages[1]["graph_exec_id"] == "exec-B"
), f"Second message should be B, got {messages[1]['graph_exec_id']}"
assert (
messages[2]["graph_exec_id"] == "exec-C"
), f"Third message should be C, got {messages[2]['graph_exec_id']}"
print("✅ FIFO order confirmed: A -> B -> C")
# Test 2: Rate limiting simulation - the key test!
print("2. Testing rate limiting fix scenario")
# Simulate the scenario where user1 is rate limited
user1_msg = tester.create_test_message("RATE-LIMITED", "user1")
user2_msg1 = tester.create_test_message("USER2-1", "user2")
user2_msg2 = tester.create_test_message("USER2-2", "user2")
# Initially publish user1 message (gets consumed, then rate limited on retry)
tester.publish_message(user1_msg)
# Other users publish their messages
tester.publish_message(user2_msg1)
tester.publish_message(user2_msg2)
# Now simulate: user1 message gets "requeued" using our new republishing method
# This is what happens in manager.py when requeue_by_republishing=True
tester.publish_message(user1_msg) # Goes to back via our method
# Expected order: RATE-LIMITED, USER2-1, USER2-2, RATE-LIMITED (republished to back)
# This shows that user2 messages get processed instead of being blocked
tester.received_messages = []
tester.stop_consuming.clear()
messages = tester.consume_messages(max_messages=4)
assert len(messages) == 4, f"Expected 4 messages, got {len(messages)}"
# The key verification: user2 messages are NOT blocked by user1's rate-limited message
user2_messages = [msg for msg in messages if msg["user_id"] == "user2"]
assert len(user2_messages) == 2, "Both user2 messages should be processed"
assert user2_messages[0]["graph_exec_id"] == "exec-USER2-1"
assert user2_messages[1]["graph_exec_id"] == "exec-USER2-2"
print("✅ Rate limiting fix confirmed: user2 executions NOT blocked by user1")
# Test 3: Verify our method behaves like going to back of queue
print("3. Testing republishing sends messages to back")
# Start with message X in queue
msg_x = tester.create_test_message("X")
tester.publish_message(msg_x)
# Add message Y
msg_y = tester.create_test_message("Y")
tester.publish_message(msg_y)
# Republish X (simulates requeue using our method)
tester.publish_message(msg_x)
# Expected: X, Y, X (X was republished to back)
tester.received_messages = []
tester.stop_consuming.clear()
messages = tester.consume_messages(max_messages=3)
assert len(messages) == 3
# Y should come before the republished X
y_index = next(
i for i, msg in enumerate(messages) if msg["graph_exec_id"] == "exec-Y"
)
republished_x_index = next(
i
for i, msg in enumerate(messages[1:], 1)
if msg["graph_exec_id"] == "exec-X"
)
assert (
y_index < republished_x_index
), f"Y should come before republished X, but got order: {[m['graph_exec_id'] for m in messages]}"
print("✅ Republishing confirmed: messages go to back of queue")
print("🎉 All integration tests passed!")
print("🎉 Our republishing method works correctly with real RabbitMQ")
print("🎉 Queue blocking issue is fixed!")
finally:
tester.cleanup()
def test_traditional_requeue_behavior():
"""
Test that traditional requeue (basic_nack with requeue=True) sends messages to FRONT of queue.
This validates our hypothesis about why queue blocking occurs.
"""
tester = QueueOrderTester()
try:
tester.setup_queue()
print("🧪 Testing traditional requeue behavior (basic_nack with requeue=True)")
# Step 1: Publish message A
msg_a = tester.create_test_message("A")
tester.publish_message(msg_a)
# Step 2: Publish message B
msg_b = tester.create_test_message("B")
tester.publish_message(msg_b)
# Step 3: Consume message A and requeue it using traditional method
channel = tester.queue_client.get_channel()
method_frame, header_frame, body = channel.basic_get(
queue=tester.test_queue_name, auto_ack=False
)
assert method_frame is not None, "Should have received message A"
consumed_msg = json.loads(body.decode())
assert (
consumed_msg["graph_exec_id"] == "exec-A"
), f"Should have consumed message A, got {consumed_msg['graph_exec_id']}"
# Traditional requeue: basic_nack with requeue=True (sends to FRONT)
channel.basic_nack(delivery_tag=method_frame.delivery_tag, requeue=True)
print(f"🔄 Traditional requeue (to FRONT): {consumed_msg['graph_exec_id']}")
# Step 4: Consume all messages using basic_get for reliability
received_messages = []
# Get first message
method_frame, header_frame, body = channel.basic_get(
queue=tester.test_queue_name, auto_ack=True
)
if method_frame:
msg = json.loads(body.decode())
received_messages.append(msg)
# Get second message
method_frame, header_frame, body = channel.basic_get(
queue=tester.test_queue_name, auto_ack=True
)
if method_frame:
msg = json.loads(body.decode())
received_messages.append(msg)
# CRITICAL ASSERTION: Traditional requeue should put A at FRONT
# Expected order: A (requeued to front), B
assert (
len(received_messages) == 2
), f"Expected 2 messages, got {len(received_messages)}"
first_msg = received_messages[0]["graph_exec_id"]
second_msg = received_messages[1]["graph_exec_id"]
# This is the critical test: requeued message A should come BEFORE B
assert (
first_msg == "exec-A"
), f"Traditional requeue should put A at FRONT, but first message was: {first_msg}"
assert (
second_msg == "exec-B"
), f"B should come after requeued A, but second message was: {second_msg}"
print(
"✅ HYPOTHESIS CONFIRMED: Traditional requeue sends messages to FRONT of queue"
)
print(f" Order: {first_msg} (requeued to front) → {second_msg}")
print(" This explains why rate-limited messages block other users!")
finally:
tester.cleanup()
if __name__ == "__main__":
test_queue_ordering_behavior()

View File

@@ -27,6 +27,7 @@ export function CopilotPage() {
createSession,
onSend,
isLoadingSession,
isSessionError,
isCreatingSession,
isUserLoading,
isLoggedIn,
@@ -71,6 +72,7 @@ export function CopilotPage() {
error={error}
sessionId={sessionId}
isLoadingSession={isLoadingSession}
isSessionError={isSessionError}
isCreatingSession={isCreatingSession}
isReconnecting={isReconnecting}
onCreateSession={createSession}

View File

@@ -13,6 +13,7 @@ export interface ChatContainerProps {
error: Error | undefined;
sessionId: string | null;
isLoadingSession: boolean;
isSessionError?: boolean;
isCreatingSession: boolean;
/** True when backend has an active stream but we haven't reconnected yet. */
isReconnecting?: boolean;
@@ -27,6 +28,7 @@ export const ChatContainer = ({
error,
sessionId,
isLoadingSession,
isSessionError,
isCreatingSession,
isReconnecting,
onCreateSession,
@@ -34,7 +36,12 @@ export const ChatContainer = ({
onStop,
headerSlot,
}: ChatContainerProps) => {
const isBusy = status === "streaming" || !!isReconnecting;
const isBusy =
status === "streaming" ||
status === "submitted" ||
!!isReconnecting ||
isLoadingSession ||
!!isSessionError;
const inputLayoutId = "copilot-2-chat-input";
return (

View File

@@ -10,9 +10,8 @@ import {
MessageResponse,
} from "@/components/ai-elements/message";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { toast } from "@/components/molecules/Toast/use-toast";
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { useEffect, useRef, useState } from "react";
import { useEffect, useState } from "react";
import { CreateAgentTool } from "../../tools/CreateAgent/CreateAgent";
import { EditAgentTool } from "../../tools/EditAgent/EditAgent";
import {
@@ -129,7 +128,6 @@ export const ChatMessagesContainer = ({
headerSlot,
}: ChatMessagesContainerProps) => {
const [thinkingPhrase, setThinkingPhrase] = useState(getRandomPhrase);
const lastToastTimeRef = useRef(0);
useEffect(() => {
if (status === "submitted") {
@@ -137,20 +135,6 @@ export const ChatMessagesContainer = ({
}
}, [status]);
// Show a toast when a new error occurs, debounced to avoid spam
useEffect(() => {
if (!error) return;
const now = Date.now();
if (now - lastToastTimeRef.current < 3_000) return;
lastToastTimeRef.current = now;
toast({
variant: "destructive",
title: "Something went wrong",
description:
"The assistant encountered an error. Please try sending your message again.",
});
}, [error]);
const lastMessage = messages[messages.length - 1];
const lastAssistantHasVisibleContent =
lastMessage?.role === "assistant" &&
@@ -314,13 +298,15 @@ export const ChatMessagesContainer = ({
</Message>
)}
{error && (
<div className="rounded-lg bg-red-50 p-4 text-sm text-red-700">
<p className="font-medium">Something went wrong</p>
<p className="mt-1 text-red-600">
<details className="rounded-lg bg-red-50 p-4 text-sm text-red-700">
<summary className="cursor-pointer font-medium">
The assistant encountered an error. Please try sending your
message again.
</p>
</div>
</summary>
<pre className="mt-2 max-h-40 overflow-auto whitespace-pre-wrap break-words text-xs text-red-600">
{error instanceof Error ? error.message : String(error)}
</pre>
</details>
)}
</ConversationContent>
<ConversationScrollButton />

View File

@@ -116,12 +116,10 @@ export function convertChatSessionMessagesToUiMessages(
output: "",
});
} else {
parts.push({
type: `tool-${toolName}`,
toolCallId,
state: "input-available",
input,
});
// Active stream exists: Skip incomplete tool calls during hydration.
// The resume stream will deliver them fresh with proper SDK state.
// This prevents "No tool invocation found" errors on page refresh.
continue;
}
}
}

View File

@@ -1,47 +0,0 @@
import { useEffect, useRef, useState } from "react";
/**
* Hook that returns a progress value that starts fast and slows down,
* asymptotically approaching but never reaching the max value.
*
* Uses a half-life formula: progress = max * (1 - 0.5^(time/halfLife))
* This creates a "loading bar" effect where:
* - 50% is reached at halfLifeSeconds
* - 75% is reached at 2 * halfLifeSeconds
* - 87.5% is reached at 3 * halfLifeSeconds
*
* @param isActive - Whether the progress should be animating
* @param halfLifeSeconds - Time in seconds to reach 50% progress (default: 30)
* @param maxProgress - Maximum progress value to approach (default: 100)
* @param intervalMs - Update interval in milliseconds (default: 100)
* @returns Current progress value (0maxProgress)
*/
export function useAsymptoticProgress(
isActive: boolean,
halfLifeSeconds = 30,
maxProgress = 100,
intervalMs = 100,
) {
const [progress, setProgress] = useState(0);
const elapsedTimeRef = useRef(0);
useEffect(() => {
if (!isActive) {
setProgress(0);
elapsedTimeRef.current = 0;
return;
}
const interval = setInterval(() => {
elapsedTimeRef.current += intervalMs / 1000;
const newProgress =
maxProgress *
(1 - Math.pow(0.5, elapsedTimeRef.current / halfLifeSeconds));
setProgress(newProgress);
}, intervalMs);
return () => clearInterval(interval);
}, [isActive, halfLifeSeconds, maxProgress, intervalMs]);
return progress;
}

View File

@@ -1,126 +0,0 @@
import { getGetV2GetSessionQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
import { useQueryClient } from "@tanstack/react-query";
import type { UIDataTypes, UIMessage, UITools } from "ai";
import { useCallback, useEffect, useRef } from "react";
import { convertChatSessionMessagesToUiMessages } from "../helpers/convertChatSessionToUiMessages";
const OPERATING_TYPES = new Set([
"operation_started",
"operation_pending",
"operation_in_progress",
]);
const POLL_INTERVAL_MS = 1_500;
/**
* Detects whether any message contains a tool part whose output indicates
* a long-running operation is still in progress.
*/
function hasOperatingTool(
messages: UIMessage<unknown, UIDataTypes, UITools>[],
) {
for (const msg of messages) {
for (const part of msg.parts) {
if (!part.type.startsWith("tool-")) continue;
const toolPart = part as { output?: unknown };
if (!toolPart.output) continue;
const output =
typeof toolPart.output === "string"
? safeParse(toolPart.output)
: toolPart.output;
if (
output &&
typeof output === "object" &&
"type" in output &&
OPERATING_TYPES.has((output as { type: string }).type)
) {
return true;
}
}
}
return false;
}
function safeParse(value: string): unknown {
try {
return JSON.parse(value);
} catch {
return null;
}
}
/**
* Polls the session endpoint while any tool is in an "operating" state
* (operation_started / operation_pending / operation_in_progress).
*
* When the session data shows the tool output has changed (e.g. to
* agent_saved), it calls `setMessages` with the updated messages.
*/
export function useLongRunningToolPolling(
sessionId: string | null,
messages: UIMessage<unknown, UIDataTypes, UITools>[],
setMessages: (
updater: (
prev: UIMessage<unknown, UIDataTypes, UITools>[],
) => UIMessage<unknown, UIDataTypes, UITools>[],
) => void,
) {
const queryClient = useQueryClient();
const intervalRef = useRef<ReturnType<typeof setInterval> | null>(null);
const stopPolling = useCallback(() => {
if (intervalRef.current) {
clearInterval(intervalRef.current);
intervalRef.current = null;
}
}, []);
const poll = useCallback(async () => {
if (!sessionId) return;
// Invalidate the query cache so the next fetch gets fresh data
await queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
});
// Fetch fresh session data
const data = queryClient.getQueryData<{
status: number;
data: { messages?: unknown[] };
}>(getGetV2GetSessionQueryKey(sessionId));
if (data?.status !== 200 || !data.data.messages) return;
const freshMessages = convertChatSessionMessagesToUiMessages(
sessionId,
data.data.messages,
);
if (!freshMessages || freshMessages.length === 0) return;
// Update when the long-running tool completed
if (!hasOperatingTool(freshMessages)) {
setMessages(() => freshMessages);
stopPolling();
}
}, [sessionId, queryClient, setMessages, stopPolling]);
useEffect(() => {
const shouldPoll = hasOperatingTool(messages);
// Always clear any previous interval first so we never leak timers
// when the effect re-runs due to dependency changes (e.g. messages
// updating as the LLM streams text after the tool call).
stopPolling();
if (shouldPoll && sessionId) {
intervalRef.current = setInterval(() => {
poll();
}, POLL_INTERVAL_MS);
}
return () => {
stopPolling();
};
}, [messages, sessionId, poll, stopPolling]);
}

View File

@@ -1120,56 +1120,6 @@ export default function StyleguidePage() {
/>
</SubSection>
<SubSection label="Output available (operation started)">
<CreateAgentTool
part={{
type: "tool-create_agent",
toolCallId: uid(),
state: "output-available",
output: {
type: ResponseType.operation_started,
operation_id: "op-create-123",
tool_name: "create_agent",
message:
"Agent creation has been started. This may take a moment.",
},
}}
/>
</SubSection>
<SubSection label="Output available (operation pending)">
<CreateAgentTool
part={{
type: "tool-create_agent",
toolCallId: uid(),
state: "output-available",
output: {
type: ResponseType.operation_pending,
operation_id: "op-create-123",
tool_name: "create_agent",
message:
"Agent creation is queued and will begin shortly.",
},
}}
/>
</SubSection>
<SubSection label="Output available (operation in progress)">
<CreateAgentTool
part={{
type: "tool-create_agent",
toolCallId: uid(),
state: "output-available",
output: {
type: ResponseType.operation_in_progress,
tool_call_id: "tc-456",
message:
"An agent creation operation is already in progress. Please wait for it to finish.",
},
}}
/>
</SubSection>
<SubSection label="Output available (agent preview)">
<CreateAgentTool
part={{
@@ -1292,22 +1242,6 @@ export default function StyleguidePage() {
/>
</SubSection>
<SubSection label="Output available (operation started)">
<EditAgentTool
part={{
type: "tool-edit_agent",
toolCallId: uid(),
state: "output-available",
output: {
type: ResponseType.operation_started,
operation_id: "op-edit-456",
tool_name: "edit_agent",
message: "Agent editing has started.",
},
}}
/>
</SubSection>
<SubSection label="Output available (agent preview)">
<EditAgentTool
part={{

View File

@@ -16,6 +16,7 @@ import {
ContentCardDescription,
ContentCodeBlock,
ContentGrid,
ContentHint,
ContentMessage,
} from "../../components/ToolAccordion/AccordionContent";
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
@@ -35,9 +36,6 @@ import {
isAgentSavedOutput,
isClarificationNeededOutput,
isErrorOutput,
isOperationInProgressOutput,
isOperationPendingOutput,
isOperationStartedOutput,
isSuggestedGoalOutput,
ToolIcon,
truncateText,
@@ -56,9 +54,18 @@ interface Props {
part: CreateAgentToolPart;
}
function getAccordionMeta(output: CreateAgentToolOutput) {
function getAccordionMeta(output: CreateAgentToolOutput | null) {
const icon = <AccordionIcon />;
if (!output) {
return {
icon,
title:
"Creating agent, this may take a few minutes. Play while you wait.",
expanded: true,
};
}
if (isAgentSavedOutput(output)) {
return { icon, title: output.agent_name, expanded: true };
}
@@ -85,16 +92,6 @@ function getAccordionMeta(output: CreateAgentToolOutput) {
expanded: true,
};
}
if (
isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output)
) {
return {
icon,
title: output.message || "Agent creation started",
};
}
return {
icon: (
<WarningDiamondIcon size={32} weight="light" className="text-red-500" />
@@ -116,23 +113,11 @@ export function CreateAgentTool({ part }: Props) {
const isError =
part.state === "output-error" || (!!output && isErrorOutput(output));
const isOperating =
!!output &&
(isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output));
const isOperating = !output;
const hasExpandableContent =
part.state === "output-available" &&
!!output &&
(isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output) ||
isAgentPreviewOutput(output) ||
isAgentSavedOutput(output) ||
isClarificationNeededOutput(output) ||
isSuggestedGoalOutput(output) ||
isErrorOutput(output));
// Show accordion for operating state and successful outputs, but not for errors
// (errors are shown inline so they get replaced when retrying)
const hasExpandableContent = !isError;
function handleUseSuggestedGoal(goal: string) {
onSend(`Please create an agent with this goal: ${goal}`);
@@ -158,33 +143,77 @@ export function CreateAgentTool({ part }: Props) {
return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
{isStreaming && (
<ToolAccordion
icon={<AccordionIcon />}
title="Creating agent, this may take a few minutes. Play while you wait."
expanded
>
<ContentGrid>
<MiniGame />
</ContentGrid>
</ToolAccordion>
{isOperating && (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
)}
{hasExpandableContent && output && (
{isError && output && isErrorOutput(output) && (
<div className="space-y-3 rounded-lg border border-red-200 bg-red-50 p-4">
<div className="flex items-start gap-2">
<WarningDiamondIcon
size={20}
weight="regular"
className="mt-0.5 shrink-0 text-red-500"
/>
<div className="flex-1 space-y-2">
<Text variant="body-medium" className="text-red-900">
{output.message ||
"Failed to generate the agent. Please try again."}
</Text>
{output.error && (
<details className="text-xs text-red-700">
<summary className="cursor-pointer font-medium">
Technical details
</summary>
<pre className="mt-2 max-h-40 overflow-auto whitespace-pre-wrap break-words rounded bg-red-100 p-2">
{formatMaybeJson(output.error)}
</pre>
</details>
)}
{output.details && (
<pre className="max-h-40 overflow-auto whitespace-pre-wrap break-words rounded bg-red-100 p-2 text-xs text-red-700">
{formatMaybeJson(output.details)}
</pre>
)}
</div>
</div>
<div className="flex gap-2">
<Button
variant="outline"
size="small"
onClick={() => onSend("Please try creating the agent again.")}
>
Try again
</Button>
<Button
variant="outline"
size="small"
onClick={() => onSend("Can you help me simplify this goal?")}
>
Simplify goal
</Button>
</div>
</div>
)}
{hasExpandableContent && (
<ToolAccordion {...getAccordionMeta(output)}>
{isOperating && output.message && (
<ContentMessage>{output.message}</ContentMessage>
{isOperating && (
<ContentGrid>
<MiniGame />
<ContentHint>
This could take a few minutes play while you wait!
</ContentHint>
</ContentGrid>
)}
{isAgentSavedOutput(output) && (
{output && isAgentSavedOutput(output) && (
<div className="rounded-xl border border-border/60 bg-card p-4 shadow-sm">
<div className="flex items-baseline gap-2">
<Image
@@ -230,7 +259,7 @@ export function CreateAgentTool({ part }: Props) {
</div>
)}
{isAgentPreviewOutput(output) && (
{output && isAgentPreviewOutput(output) && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
{output.description?.trim() && (
@@ -244,7 +273,7 @@ export function CreateAgentTool({ part }: Props) {
</ContentGrid>
)}
{isClarificationNeededOutput(output) && (
{output && isClarificationNeededOutput(output) && (
<ClarificationQuestionsCard
questions={(output.questions ?? []).map((q) => {
const item: ClarifyingQuestion = {
@@ -263,7 +292,7 @@ export function CreateAgentTool({ part }: Props) {
/>
)}
{isSuggestedGoalOutput(output) && (
{output && isSuggestedGoalOutput(output) && (
<SuggestedGoalCard
message={output.message}
suggestedGoal={output.suggested_goal}
@@ -272,38 +301,6 @@ export function CreateAgentTool({ part }: Props) {
onUseSuggestedGoal={handleUseSuggestedGoal}
/>
)}
{isErrorOutput(output) && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
{output.error && (
<ContentCodeBlock>
{formatMaybeJson(output.error)}
</ContentCodeBlock>
)}
{output.details && (
<ContentCodeBlock>
{formatMaybeJson(output.details)}
</ContentCodeBlock>
)}
<div className="flex gap-2">
<Button
variant="outline"
size="small"
onClick={() => onSend("Please try creating the agent again.")}
>
Try again
</Button>
<Button
variant="outline"
size="small"
onClick={() => onSend("Can you help me simplify this goal?")}
>
Simplify goal
</Button>
</div>
</ContentGrid>
)}
</ToolAccordion>
)}
</div>

View File

@@ -2,9 +2,6 @@ import type { AgentPreviewResponse } from "@/app/api/__generated__/models/agentP
import type { AgentSavedResponse } from "@/app/api/__generated__/models/agentSavedResponse";
import type { ClarificationNeededResponse } from "@/app/api/__generated__/models/clarificationNeededResponse";
import type { ErrorResponse } from "@/app/api/__generated__/models/errorResponse";
import type { OperationInProgressResponse } from "@/app/api/__generated__/models/operationInProgressResponse";
import type { OperationPendingResponse } from "@/app/api/__generated__/models/operationPendingResponse";
import type { OperationStartedResponse } from "@/app/api/__generated__/models/operationStartedResponse";
import { ResponseType } from "@/app/api/__generated__/models/responseType";
import type { SuggestedGoalResponse } from "@/app/api/__generated__/models/suggestedGoalResponse";
import {
@@ -16,9 +13,6 @@ import type { ToolUIPart } from "ai";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
export type CreateAgentToolOutput =
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
| AgentPreviewResponse
| AgentSavedResponse
| ClarificationNeededResponse
@@ -39,9 +33,6 @@ function parseOutput(output: unknown): CreateAgentToolOutput | null {
if (typeof output === "object") {
const type = (output as { type?: unknown }).type;
if (
type === ResponseType.operation_started ||
type === ResponseType.operation_pending ||
type === ResponseType.operation_in_progress ||
type === ResponseType.agent_preview ||
type === ResponseType.agent_saved ||
type === ResponseType.clarification_needed ||
@@ -50,9 +41,6 @@ function parseOutput(output: unknown): CreateAgentToolOutput | null {
) {
return output as CreateAgentToolOutput;
}
if ("operation_id" in output && "tool_name" in output)
return output as OperationStartedResponse | OperationPendingResponse;
if ("tool_call_id" in output) return output as OperationInProgressResponse;
if ("agent_json" in output && "agent_name" in output)
return output as AgentPreviewResponse;
if ("agent_id" in output && "library_agent_id" in output)
@@ -72,30 +60,6 @@ export function getCreateAgentToolOutput(
return parseOutput((part as { output?: unknown }).output);
}
export function isOperationStartedOutput(
output: CreateAgentToolOutput,
): output is OperationStartedResponse {
return (
output.type === ResponseType.operation_started ||
("operation_id" in output && "tool_name" in output)
);
}
export function isOperationPendingOutput(
output: CreateAgentToolOutput,
): output is OperationPendingResponse {
return output.type === ResponseType.operation_pending;
}
export function isOperationInProgressOutput(
output: CreateAgentToolOutput,
): output is OperationInProgressResponse {
return (
output.type === ResponseType.operation_in_progress ||
"tool_call_id" in output
);
}
export function isAgentPreviewOutput(
output: CreateAgentToolOutput,
): output is AgentPreviewResponse {
@@ -144,10 +108,6 @@ export function getAnimationText(part: {
case "output-available": {
const output = parseOutput(part.output);
if (!output) return "Creating a new agent";
if (isOperationStartedOutput(output)) return "Agent creation started";
if (isOperationPendingOutput(output)) return "Agent creation in progress";
if (isOperationInProgressOutput(output))
return "Agent creation already in progress";
if (isAgentSavedOutput(output)) return `Saved ${output.agent_name}`;
if (isAgentPreviewOutput(output)) return `Preview "${output.agent_name}"`;
if (isClarificationNeededOutput(output)) return "Needs clarification";

View File

@@ -1,18 +1,27 @@
"use client";
import { WarningDiamondIcon } from "@phosphor-icons/react";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import {
BookOpenIcon,
PencilSimpleIcon,
WarningDiamondIcon,
} from "@phosphor-icons/react";
import type { ToolUIPart } from "ai";
import Image from "next/image";
import NextLink from "next/link";
import { useCopilotChatActions } from "../../components/CopilotChatActionsProvider/useCopilotChatActions";
import sparklesImg from "../../components/MiniGame/assets/sparkles.png";
import { MiniGame } from "../../components/MiniGame/MiniGame";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import {
ContentCardDescription,
ContentCodeBlock,
ContentGrid,
ContentLink,
ContentHint,
ContentMessage,
} from "../../components/ToolAccordion/AccordionContent";
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
import { MiniGame } from "../../components/MiniGame/MiniGame";
import {
ClarificationQuestionsCard,
ClarifyingQuestion,
@@ -26,9 +35,6 @@ import {
isAgentSavedOutput,
isClarificationNeededOutput,
isErrorOutput,
isOperationInProgressOutput,
isOperationPendingOutput,
isOperationStartedOutput,
ToolIcon,
truncateText,
type EditAgentToolOutput,
@@ -46,7 +52,7 @@ interface Props {
part: EditAgentToolPart;
}
function getAccordionMeta(output: EditAgentToolOutput): {
function getAccordionMeta(output: EditAgentToolOutput | null): {
icon: React.ReactNode;
title: string;
titleClassName?: string;
@@ -55,8 +61,16 @@ function getAccordionMeta(output: EditAgentToolOutput): {
} {
const icon = <AccordionIcon />;
if (!output) {
return {
icon,
title: "Editing agent, this may take a few minutes. Play while you wait.",
expanded: true,
};
}
if (isAgentSavedOutput(output)) {
return { icon, title: output.agent_name };
return { icon, title: output.agent_name, expanded: true };
}
if (isAgentPreviewOutput(output)) {
return {
@@ -73,16 +87,6 @@ function getAccordionMeta(output: EditAgentToolOutput): {
description: `${questions.length} question${questions.length === 1 ? "" : "s"}`,
};
}
if (
isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output)
) {
return {
icon,
title: output.message || "Agent editing started",
};
}
return {
icon: (
<WarningDiamondIcon size={32} weight="light" className="text-red-500" />
@@ -101,21 +105,12 @@ export function EditAgentTool({ part }: Props) {
const output = getEditAgentToolOutput(part);
const isError =
part.state === "output-error" || (!!output && isErrorOutput(output));
const isOperating =
!!output &&
(isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output));
const hasExpandableContent =
part.state === "output-available" &&
!!output &&
(isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output) ||
isAgentPreviewOutput(output) ||
isAgentSavedOutput(output) ||
isClarificationNeededOutput(output) ||
isErrorOutput(output));
const isOperating = !output;
// Show accordion for operating state and successful outputs, but not for errors
// (errors are shown inline so they get replaced when retrying)
const hasExpandableContent = !isError;
function handleClarificationAnswers(answers: Record<string, string>) {
const questions =
@@ -137,53 +132,114 @@ export function EditAgentTool({ part }: Props) {
return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
{isStreaming && (
<ToolAccordion
icon={<AccordionIcon />}
title="Editing agent, this may take a few minutes. Play while you wait."
expanded
>
<ContentGrid>
<MiniGame />
</ContentGrid>
</ToolAccordion>
{isOperating && (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
)}
{hasExpandableContent && output && (
<ToolAccordion {...getAccordionMeta(output)}>
{isOperating && output.message && (
<ContentMessage>{output.message}</ContentMessage>
)}
{isError && output && isErrorOutput(output) && (
<div className="space-y-3 rounded-lg border border-red-200 bg-red-50 p-4">
<div className="flex items-start gap-2">
<WarningDiamondIcon
size={20}
weight="regular"
className="mt-0.5 shrink-0 text-red-500"
/>
<div className="flex-1 space-y-2">
<Text variant="body-medium" className="text-red-900">
{output.message ||
"Failed to edit the agent. Please try again."}
</Text>
{output.error && (
<details className="text-xs text-red-700">
<summary className="cursor-pointer font-medium">
Technical details
</summary>
<pre className="mt-2 max-h-40 overflow-auto whitespace-pre-wrap break-words rounded bg-red-100 p-2">
{formatMaybeJson(output.error)}
</pre>
</details>
)}
{output.details && (
<pre className="max-h-40 overflow-auto whitespace-pre-wrap break-words rounded bg-red-100 p-2 text-xs text-red-700">
{formatMaybeJson(output.details)}
</pre>
)}
</div>
</div>
<Button
variant="outline"
size="small"
onClick={() => onSend("Please try editing the agent again.")}
>
Try again
</Button>
</div>
)}
{isAgentSavedOutput(output) && (
{hasExpandableContent && (
<ToolAccordion {...getAccordionMeta(output)}>
{isOperating && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
<div className="flex flex-wrap gap-2">
<ContentLink href={output.library_agent_link}>
Open in library
</ContentLink>
<ContentLink href={output.agent_page_link}>
Open in builder
</ContentLink>
</div>
<ContentCodeBlock>
{truncateText(
formatMaybeJson({ agent_id: output.agent_id }),
800,
)}
</ContentCodeBlock>
<MiniGame />
<ContentHint>
This could take a few minutes play while you wait!
</ContentHint>
</ContentGrid>
)}
{isAgentPreviewOutput(output) && (
{output && isAgentSavedOutput(output) && (
<div className="rounded-xl border border-border/60 bg-card p-4 shadow-sm">
<div className="flex items-baseline gap-2">
<Image
src={sparklesImg}
alt="sparkles"
width={24}
height={24}
className="relative top-1"
/>
<Text
variant="body-medium"
className="mb-2 text-[16px] text-black"
>
Agent{" "}
<span className="text-violet-600">{output.agent_name}</span>{" "}
has been updated!
</Text>
</div>
<div className="mt-3 flex flex-wrap gap-4">
<Button variant="outline" size="small">
<NextLink
href={output.library_agent_link}
className="inline-flex items-center gap-1.5"
target="_blank"
rel="noopener noreferrer"
>
<BookOpenIcon size={14} weight="regular" />
Open in library
</NextLink>
</Button>
<Button variant="outline" size="small">
<NextLink
href={output.agent_page_link}
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center gap-1.5"
>
<PencilSimpleIcon size={14} weight="regular" />
Open in builder
</NextLink>
</Button>
</div>
</div>
)}
{output && isAgentPreviewOutput(output) && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
{output.description?.trim() && (
@@ -197,7 +253,7 @@ export function EditAgentTool({ part }: Props) {
</ContentGrid>
)}
{isClarificationNeededOutput(output) && (
{output && isClarificationNeededOutput(output) && (
<ClarificationQuestionsCard
questions={(output.questions ?? []).map((q) => {
const item: ClarifyingQuestion = {
@@ -215,22 +271,6 @@ export function EditAgentTool({ part }: Props) {
onSubmitAnswers={handleClarificationAnswers}
/>
)}
{isErrorOutput(output) && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
{output.error && (
<ContentCodeBlock>
{formatMaybeJson(output.error)}
</ContentCodeBlock>
)}
{output.details && (
<ContentCodeBlock>
{formatMaybeJson(output.details)}
</ContentCodeBlock>
)}
</ContentGrid>
)}
</ToolAccordion>
)}
</div>

View File

@@ -2,9 +2,6 @@ import type { AgentPreviewResponse } from "@/app/api/__generated__/models/agentP
import type { AgentSavedResponse } from "@/app/api/__generated__/models/agentSavedResponse";
import type { ClarificationNeededResponse } from "@/app/api/__generated__/models/clarificationNeededResponse";
import type { ErrorResponse } from "@/app/api/__generated__/models/errorResponse";
import type { OperationInProgressResponse } from "@/app/api/__generated__/models/operationInProgressResponse";
import type { OperationPendingResponse } from "@/app/api/__generated__/models/operationPendingResponse";
import type { OperationStartedResponse } from "@/app/api/__generated__/models/operationStartedResponse";
import { ResponseType } from "@/app/api/__generated__/models/responseType";
import {
NotePencilIcon,
@@ -15,9 +12,6 @@ import type { ToolUIPart } from "ai";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
export type EditAgentToolOutput =
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
| AgentPreviewResponse
| AgentSavedResponse
| ClarificationNeededResponse
@@ -37,9 +31,6 @@ function parseOutput(output: unknown): EditAgentToolOutput | null {
if (typeof output === "object") {
const type = (output as { type?: unknown }).type;
if (
type === ResponseType.operation_started ||
type === ResponseType.operation_pending ||
type === ResponseType.operation_in_progress ||
type === ResponseType.agent_preview ||
type === ResponseType.agent_saved ||
type === ResponseType.clarification_needed ||
@@ -47,9 +38,6 @@ function parseOutput(output: unknown): EditAgentToolOutput | null {
) {
return output as EditAgentToolOutput;
}
if ("operation_id" in output && "tool_name" in output)
return output as OperationStartedResponse | OperationPendingResponse;
if ("tool_call_id" in output) return output as OperationInProgressResponse;
if ("agent_json" in output && "agent_name" in output)
return output as AgentPreviewResponse;
if ("agent_id" in output && "library_agent_id" in output)
@@ -68,30 +56,6 @@ export function getEditAgentToolOutput(
return parseOutput((part as { output?: unknown }).output);
}
export function isOperationStartedOutput(
output: EditAgentToolOutput,
): output is OperationStartedResponse {
return (
output.type === ResponseType.operation_started ||
("operation_id" in output && "tool_name" in output)
);
}
export function isOperationPendingOutput(
output: EditAgentToolOutput,
): output is OperationPendingResponse {
return output.type === ResponseType.operation_pending;
}
export function isOperationInProgressOutput(
output: EditAgentToolOutput,
): output is OperationInProgressResponse {
return (
output.type === ResponseType.operation_in_progress ||
"tool_call_id" in output
);
}
export function isAgentPreviewOutput(
output: EditAgentToolOutput,
): output is AgentPreviewResponse {
@@ -132,10 +96,6 @@ export function getAnimationText(part: {
case "output-available": {
const output = parseOutput(part.output);
if (!output) return "Editing the agent";
if (isOperationStartedOutput(output)) return "Agent update started";
if (isOperationPendingOutput(output)) return "Agent update in progress";
if (isOperationInProgressOutput(output))
return "Agent update already in progress";
if (isAgentSavedOutput(output)) return `Saved "${output.agent_name}"`;
if (isAgentPreviewOutput(output)) return `Preview "${output.agent_name}"`;
if (isClarificationNeededOutput(output)) return "Needs clarification";

View File

@@ -686,17 +686,20 @@ export function GenericTool({ part }: Props) {
return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon
category={category}
isStreaming={isStreaming}
isError={isError}
/>
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
{/* Only show loading text when NOT showing accordion */}
{!showAccordion && (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon
category={category}
isStreaming={isStreaming}
isError={isError}
/>
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
)}
{showAccordion && accordionData ? (
<ToolAccordion

View File

@@ -69,13 +69,20 @@ export function RunAgentTool({ part }: Props) {
return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
{/* Only show loading text when NOT showing accordion or other content */}
{!isStreaming &&
!setupRequirementsOutput &&
!agentDetailsOutput &&
!needLoginOutput &&
!hasExpandableContent && (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
)}
{isStreaming && !output && (
<ToolAccordion

View File

@@ -20,7 +20,8 @@ export function useChatSession() {
enabled: !!sessionId,
staleTime: Infinity,
refetchOnWindowFocus: false,
refetchOnReconnect: false,
refetchOnReconnect: true,
refetchOnMount: true,
},
});
@@ -115,6 +116,7 @@ export function useChatSession() {
hydratedMessages,
hasActiveStream,
isLoadingSession: sessionQuery.isLoading,
isSessionError: sessionQuery.isError,
createSession,
isCreatingSession,
};

View File

@@ -14,7 +14,6 @@ import { DefaultChatTransport } from "ai";
import type { UIMessage } from "ai";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { useChatSession } from "./useChatSession";
import { useLongRunningToolPolling } from "./hooks/useLongRunningToolPolling";
const STREAM_START_TIMEOUT_MS = 12_000;
@@ -36,6 +35,46 @@ function resolveInProgressTools(
}));
}
/** Build a fingerprint from a message's role + text/tool content for cross-boundary dedup. */
function messageFingerprint(msg: UIMessage): string {
const fragments = msg.parts.map((p) => {
if ("text" in p && typeof p.text === "string") return p.text;
if ("toolCallId" in p && typeof p.toolCallId === "string")
return `tool:${p.toolCallId}`;
return "";
});
return `${msg.role}::${fragments.join("\n")}`;
}
/**
* Deduplicate messages by ID *and* by content fingerprint.
* ID-based dedup catches duplicates within the same source (e.g. two
* identical stream events). Fingerprint-based dedup catches duplicates
* across the hydration/stream boundary where IDs differ (synthetic
* `${sessionId}-${index}` vs AI SDK nanoid).
*
* NOTE: Fingerprint dedup only applies to assistant messages, not user messages.
* Users should be able to send the same message multiple times.
*/
function deduplicateMessages(messages: UIMessage[]): UIMessage[] {
const seenIds = new Set<string>();
const seenFingerprints = new Set<string>();
return messages.filter((msg) => {
if (seenIds.has(msg.id)) return false;
seenIds.add(msg.id);
// Only apply fingerprint deduplication to assistant messages
// User messages should allow duplicates (same text sent multiple times)
if (msg.role === "assistant") {
const fp = messageFingerprint(msg);
if (fp !== "::" && seenFingerprints.has(fp)) return false;
seenFingerprints.add(fp);
}
return true;
});
}
export function useCopilotPage() {
const { isUserLoading, isLoggedIn } = useSupabase();
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
@@ -52,6 +91,7 @@ export function useCopilotPage() {
hydratedMessages,
hasActiveStream,
isLoadingSession,
isSessionError,
createSession,
isCreatingSession,
} = useChatSession();
@@ -114,7 +154,7 @@ export function useCopilotPage() {
);
const {
messages,
messages: rawMessages,
sendMessage,
stop: sdkStop,
status,
@@ -129,6 +169,12 @@ export function useCopilotPage() {
// call resumeStream() manually after hydration + active_stream detection.
});
// Deduplicate messages continuously to prevent duplicates when resuming streams
const messages = useMemo(
() => deduplicateMessages(rawMessages),
[rawMessages],
);
// Wrap AI SDK's stop() to also cancel the backend executor task.
// sdkStop() aborts the SSE fetch instantly (UI feedback), then we fire
// the cancel API to actually stop the executor and wait for confirmation.
@@ -184,19 +230,26 @@ export function useCopilotPage() {
if (status === "streaming" || status === "submitted") return;
setMessages((prev) => {
if (prev.length >= hydratedMessages.length) return prev;
return hydratedMessages;
// Deduplicate to handle rare cases where duplicate streams might occur
return deduplicateMessages(hydratedMessages);
});
}, [hydratedMessages, setMessages, status]);
// Ref: tracks whether we've already resumed for a given session.
// Reset when the stream ends so re-resume is possible if the backend
// task is still running (SSE dropped but executor didn't finish).
const hasResumedRef = useRef<string | null>(null);
// Format: Map<sessionId, hasResumed>
const hasResumedRef = useRef<Map<string, boolean>>(new Map());
// When the stream ends (or drops), invalidate the session cache so the
// next hydration fetches fresh messages from the backend. Without this,
// staleTime: Infinity means the cache keeps the pre-stream data forever,
// and any messages added during streaming are lost on remount/navigation.
// Track status transitions for cache invalidation and auto-reconnect.
// Auto-reconnect: GCP's L7 load balancer kills SSE connections at ~5 min.
// When that happens the AI SDK goes "streaming" → "error". If the backend
// executor is still running (hasActiveStream), we call resumeStream() to
// reconnect via GET and replay from Redis.
const MAX_RECONNECT_ATTEMPTS = 3;
const reconnectAttemptsRef = useRef(0);
const prevStatusRef = useRef(status);
useEffect(() => {
const prev = prevStatusRef.current;
@@ -204,33 +257,63 @@ export function useCopilotPage() {
const wasActive = prev === "streaming" || prev === "submitted";
const isIdle = status === "ready" || status === "error";
// Invalidate session cache when stream ends so hydration fetches fresh data
if (wasActive && isIdle && sessionId) {
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
});
// Allow re-resume if the backend task is still running.
hasResumedRef.current = null;
}
}, [status, sessionId, queryClient]);
// Auto-reconnect on mid-stream SSE drop
if (
prev === "streaming" &&
status === "error" &&
sessionId &&
hasActiveStream
) {
if (reconnectAttemptsRef.current < MAX_RECONNECT_ATTEMPTS) {
reconnectAttemptsRef.current += 1;
const attempt = reconnectAttemptsRef.current;
console.info(
`[copilot] SSE dropped mid-stream, reconnecting (attempt ${attempt}/${MAX_RECONNECT_ATTEMPTS})...`,
);
const timer = setTimeout(() => resumeStream(), 1_000);
return () => clearTimeout(timer);
} else {
toast({
title: "Connection lost",
description:
"Could not reconnect to the stream. Please refresh the page.",
variant: "destructive",
});
}
}
// Reset reconnect counter when stream completes normally or resumes
if (status === "ready" || status === "streaming") {
reconnectAttemptsRef.current = 0;
}
}, [status, sessionId, hasActiveStream, queryClient, resumeStream]);
// Resume an active stream AFTER hydration completes.
// The backend returns active_stream info when a task is still running.
// We wait for hydration so the AI SDK has the conversation history
// before the resumed stream appends the in-progress assistant message.
// IMPORTANT: Only runs when page loads with existing active stream (reconnection).
// Does NOT run when new streams start during active conversation.
useEffect(() => {
if (!hasActiveStream || !sessionId) return;
if (!sessionId) return;
if (!hasActiveStream) return;
if (!hydratedMessages || hydratedMessages.length === 0) return;
if (status === "streaming" || status === "submitted") return;
// Only resume once per session to avoid re-triggering after stream ends
if (hasResumedRef.current === sessionId) return;
hasResumedRef.current = sessionId;
resumeStream();
}, [hasActiveStream, sessionId, hydratedMessages, status, resumeStream]);
// Poll session endpoint when a long-running tool (create_agent, edit_agent)
// is in progress. When the backend completes, the session data will contain
// the final tool output — this hook detects the change and updates messages.
useLongRunningToolPolling(sessionId, messages, setMessages);
// Never resume if currently streaming
if (status === "streaming" || status === "submitted") return;
// Only resume once per session
if (hasResumedRef.current.get(sessionId)) return;
// Mark as resumed immediately to prevent race conditions
hasResumedRef.current.set(sessionId, true);
resumeStream();
}, [sessionId, hasActiveStream, hydratedMessages, status, resumeStream]);
// Clear messages when session is null
useEffect(() => {
@@ -321,6 +404,7 @@ export function useCopilotPage() {
stop,
isReconnecting,
isLoadingSession,
isSessionError,
isCreatingSession,
isUserLoading,
isLoggedIn,

View File

@@ -1,64 +0,0 @@
import { environment } from "@/services/environment";
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
import { NextRequest } from "next/server";
import { normalizeSSEStream, SSE_HEADERS } from "../../../sse-helpers";
export async function GET(
request: NextRequest,
{ params }: { params: Promise<{ taskId: string }> },
) {
const { taskId } = await params;
const searchParams = request.nextUrl.searchParams;
const lastMessageId = searchParams.get("last_message_id") || "0-0";
try {
const token = await getServerAuthToken();
const backendUrl = environment.getAGPTServerBaseUrl();
const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl);
streamUrl.searchParams.set("last_message_id", lastMessageId);
const headers: Record<string, string> = {
Accept: "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
};
if (token) {
headers["Authorization"] = `Bearer ${token}`;
}
const response = await fetch(streamUrl.toString(), {
method: "GET",
headers,
});
if (!response.ok) {
const error = await response.text();
return new Response(error, {
status: response.status,
headers: { "Content-Type": "application/json" },
});
}
if (!response.body) {
return new Response(null, { status: 204 });
}
return new Response(normalizeSSEStream(response.body), {
headers: SSE_HEADERS,
});
} catch (error) {
console.error("Task stream proxy error:", error);
return new Response(
JSON.stringify({
error: "Failed to connect to task stream",
detail: error instanceof Error ? error.message : String(error),
}),
{
status: 500,
headers: { "Content-Type": "application/json" },
},
);
}
}

View File

@@ -961,63 +961,6 @@
}
}
},
"/api/chat/operations/{operation_id}/complete": {
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Complete Operation",
"description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.",
"operationId": "postV2CompleteOperation",
"parameters": [
{
"name": "operation_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Operation Id" }
},
{
"name": "x-api-key",
"in": "header",
"required": false,
"schema": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "X-Api-Key"
}
}
],
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/OperationCompleteRequest"
}
}
}
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": true,
"title": "Response Postv2Completeoperation"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/chat/schema/tool-responses": {
"get": {
"tags": ["v2", "chat", "chat"],
@@ -1057,12 +1000,7 @@
{ "$ref": "#/components/schemas/BlockDetailsResponse" },
{ "$ref": "#/components/schemas/BlockOutputResponse" },
{ "$ref": "#/components/schemas/DocSearchResultsResponse" },
{ "$ref": "#/components/schemas/DocPageResponse" },
{ "$ref": "#/components/schemas/OperationStartedResponse" },
{ "$ref": "#/components/schemas/OperationPendingResponse" },
{
"$ref": "#/components/schemas/OperationInProgressResponse"
}
{ "$ref": "#/components/schemas/DocPageResponse" }
],
"title": "Response Getv2[Dummy] Tool Response Type Export For Codegen"
}
@@ -1185,7 +1123,7 @@
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Session",
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.",
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns active_stream info for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.",
"operationId": "getV2GetSession",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -1283,7 +1221,9 @@
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/CancelTaskResponse" }
"schema": {
"$ref": "#/components/schemas/CancelSessionResponse"
}
}
}
},
@@ -1337,7 +1277,7 @@
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Stream Chat Post",
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. First chunk is a \"start\" event\n containing the task_id for reconnection.",
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to a per-turn Redis stream for reconnection support. If the client\ndisconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
"operationId": "postV2StreamChatPost",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -1375,94 +1315,6 @@
}
}
},
"/api/chat/tasks/{task_id}": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Task Status",
"description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
"operationId": "getV2GetTaskStatus",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "task_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Task Id" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": true,
"title": "Response Getv2Gettaskstatus"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/chat/tasks/{task_id}/stream": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Stream Task",
"description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.",
"operationId": "getV2StreamTask",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "task_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Task Id" }
},
{
"name": "last_message_id",
"in": "query",
"required": false,
"schema": {
"type": "string",
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
"default": "0-0",
"title": "Last Message Id"
},
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay."
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": { "application/json": { "schema": {} } }
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/credits": {
"get": {
"tags": ["v1", "credits"],
@@ -6562,13 +6414,11 @@
},
"ActiveStreamInfo": {
"properties": {
"task_id": { "type": "string", "title": "Task Id" },
"last_message_id": { "type": "string", "title": "Last Message Id" },
"operation_id": { "type": "string", "title": "Operation Id" },
"tool_name": { "type": "string", "title": "Tool Name" }
"turn_id": { "type": "string", "title": "Turn Id" },
"last_message_id": { "type": "string", "title": "Last Message Id" }
},
"type": "object",
"required": ["task_id", "last_message_id", "operation_id", "tool_name"],
"required": ["turn_id", "last_message_id"],
"title": "ActiveStreamInfo",
"description": "Information about an active stream for reconnection."
},
@@ -7575,13 +7425,9 @@
"required": ["file"],
"title": "Body_postV2Upload submission media"
},
"CancelTaskResponse": {
"CancelSessionResponse": {
"properties": {
"cancelled": { "type": "boolean", "title": "Cancelled" },
"task_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Task Id"
},
"reason": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Reason"
@@ -7589,8 +7435,8 @@
},
"type": "object",
"required": ["cancelled"],
"title": "CancelTaskResponse",
"description": "Response model for the cancel task endpoint."
"title": "CancelSessionResponse",
"description": "Response model for the cancel session endpoint."
},
"ChangelogEntry": {
"properties": {
@@ -10107,87 +9953,6 @@
],
"title": "OnboardingStep"
},
"OperationCompleteRequest": {
"properties": {
"success": { "type": "boolean", "title": "Success" },
"result": {
"anyOf": [
{ "additionalProperties": true, "type": "object" },
{ "type": "string" },
{ "type": "null" }
],
"title": "Result"
},
"error": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Error"
}
},
"type": "object",
"required": ["success"],
"title": "OperationCompleteRequest",
"description": "Request model for external completion webhook."
},
"OperationInProgressResponse": {
"properties": {
"type": {
"$ref": "#/components/schemas/ResponseType",
"default": "operation_in_progress"
},
"message": { "type": "string", "title": "Message" },
"session_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Session Id"
},
"tool_call_id": { "type": "string", "title": "Tool Call Id" }
},
"type": "object",
"required": ["message", "tool_call_id"],
"title": "OperationInProgressResponse",
"description": "Response when an operation is already in progress.\n\nReturned for idempotency when the same tool_call_id is requested again\nwhile the background task is still running."
},
"OperationPendingResponse": {
"properties": {
"type": {
"$ref": "#/components/schemas/ResponseType",
"default": "operation_pending"
},
"message": { "type": "string", "title": "Message" },
"session_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Session Id"
},
"operation_id": { "type": "string", "title": "Operation Id" },
"tool_name": { "type": "string", "title": "Tool Name" }
},
"type": "object",
"required": ["message", "operation_id", "tool_name"],
"title": "OperationPendingResponse",
"description": "Response stored in chat history while a long-running operation is executing.\n\nThis is persisted to the database so users see a pending state when they\nrefresh before the operation completes."
},
"OperationStartedResponse": {
"properties": {
"type": {
"$ref": "#/components/schemas/ResponseType",
"default": "operation_started"
},
"message": { "type": "string", "title": "Message" },
"session_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Session Id"
},
"operation_id": { "type": "string", "title": "Operation Id" },
"tool_name": { "type": "string", "title": "Tool Name" },
"task_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Task Id"
}
},
"type": "object",
"required": ["message", "operation_id", "tool_name"],
"title": "OperationStartedResponse",
"description": "Response when a long-running operation has been started in the background.\n\nThis is returned immediately to the client while the operation continues\nto execute. The user can close the tab and check back later.\n\nThe task_id can be used to reconnect to the SSE stream via\nGET /chat/tasks/{task_id}/stream?last_idx=0"
},
"Pagination": {
"properties": {
"total_items": {
@@ -10844,13 +10609,10 @@
"workspace_file_metadata",
"workspace_file_written",
"workspace_file_deleted",
"operation_started",
"operation_pending",
"operation_in_progress",
"input_validation_error",
"web_fetch",
"bash_exec",
"operation_status",
"feature_request_search",
"feature_request_created",
"suggested_goal"

View File

@@ -218,6 +218,17 @@ If you initially installed Docker with Hyper-V, you **dont need to reinstall*
For more details, refer to [Docker's official documentation](https://docs.docker.com/desktop/windows/wsl/).
### ⚠️ Podman Not Supported
AutoGPT requires **Docker** (Docker Desktop or Docker Engine). **Podman and podman-compose are not supported** and may cause path resolution issues, particularly on Windows.
If you see errors like:
```text
Error: the specified Containerfile or Dockerfile does not exist, ..\..\autogpt_platform\backend\Dockerfile
```
This indicates you're using Podman instead of Docker. Please install [Docker Desktop](https://docs.docker.com/desktop/) and use `docker compose` instead of `podman-compose`.
## Development